5 Types of LSTM Recurrent Neural Networks and What to Do With Them
The Primordial Soup of Vanilla RNNs and Reservoir Computing
Using past experience for improved future performance is a cornerstone of deep learning and of machine learning in general. One definition of machine learning lays out the importance of improving with experience explicitly:
A computer program is said to learn from experience E with respect to some class of tasks T and performance measure P, if its performance at tasks in T, as measured by P, improves with experience E.
- Tom Mitchell, Machine Learning
In neural networks, performance improvement through experience is encoded as very long-term memory in the model parameters: the weights. After learning from a training set of annotated examples, a neural network is more likely to make the right decision when shown additional examples that are similar but previously unseen. This is the essence of supervised deep learning on data with a clear one-to-one mapping, e.g. a set of images that map to one class per image (cat, dog, hotdog, etc.).
There are many instances where data naturally forms sequences, and in those cases order and content are equally important. Examples of sequence data include text, video, music, DNA sequences, and many others. When learning from sequence data, short-term memory becomes useful for processing a series of related data points with ordered context. For this, machine learning researchers have long turned to the recurrent neural network, or RNN.
A standard RNN is essentially a feed-forward neural network unrolled in time. This arrangement can be attained simply by introducing weighted connections between one or more hidden states of the network and the same hidden states from the previous time step, providing some short-term memory. The challenge is that this short-term memory is fundamentally limited in the same way that training very deep networks is difficult, making the memory of vanilla RNNs very short indeed.
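To make that concrete, here is a minimal sketch of a vanilla RNN step in NumPy (all names, dimensions, and the initialization scale are illustrative choices, not taken from any particular library):

```python
import numpy as np

def rnn_step(x_t, h_prev, W_xh, W_hh, b_h):
    """One vanilla RNN step: the new hidden state mixes the current
    input with the hidden state from the previous time step."""
    return np.tanh(x_t @ W_xh + h_prev @ W_hh + b_h)

# Unroll over a toy sequence: the same weights are reused at every step.
rng = np.random.default_rng(0)
input_dim, hidden_dim, seq_len = 8, 16, 5
W_xh = rng.normal(scale=0.1, size=(input_dim, hidden_dim))
W_hh = rng.normal(scale=0.1, size=(hidden_dim, hidden_dim))
b_h = np.zeros(hidden_dim)

h = np.zeros(hidden_dim)
for x_t in rng.normal(size=(seq_len, input_dim)):
    h = rnn_step(x_t, h, W_xh, W_hh, b_h)  # h carries short-term memory forward
```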
Learning by back-propagation through many hidden layers is prone to the vanishing gradient problem. Without going into too much detail, the operation typically entails repeatedly multiplying an error signal by a series of values (the activation function gradients) less than 1.0, attenuating the signal at each layer. Back-propagating through time has the same problem, fundamentally limiting the ability to learn from relatively long-term dependencies.
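A toy calculation shows how quickly that attenuation bites: scaling an error signal by a per-layer gradient factor below 1.0 shrinks it exponentially with depth (the factor 0.9 below is an arbitrary illustrative value):

```python
# Toy illustration of the vanishing gradient: an error signal repeatedly
# multiplied by an activation-gradient factor below 1.0 fades exponentially.
error, gradient_factor = 1.0, 0.9
for depth in (10, 50, 100):
    print(depth, error * gradient_factor ** depth)
# 10 layers: ~0.35, 50 layers: ~0.0052, 100 layers: ~0.000027
```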
The Development of New Activation Functions in Deep Learning Neural Networks
For deep learning with feed-forward neural networks, the challenge of vanishing gradients led to the popularity of new activation functions (like ReLUs) and new architectures (like ResNet and DenseNet). For RNNs, one early solution was to skip training the recurrent layers altogether, instead initializing them in such a way that they perform a chaotic non-linear transformation of the input data into higher dimensional representations.
Recurrent feedback and parameter initialization are chosen such that the system is very nearly unstable, and a simple linear layer is added to the output. Learning is limited to that last linear layer, and in this way it's possible to get reasonably good performance on many tasks while sidestepping the vanishing gradient problem by ignoring it completely. This sub-field of computer science is called reservoir computing, and it even works (to some degree) using a bucket of water as a dynamic reservoir performing complex computations.
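A minimal sketch of the echo-state flavor of reservoir computing looks like the following (the reservoir size, input size, and the 0.95 spectral-radius target are illustrative assumptions): the recurrent weights are fixed at initialization near the edge of stability, and only a linear readout is fit.

```python
import numpy as np

rng = np.random.default_rng(0)
n_in, n_res = 4, 200

# Fixed random reservoir, rescaled so its spectral radius sits just below 1
# (nearly unstable, so the echo of each input persists for a long time).
W_in = rng.normal(size=(n_in, n_res))
W_res = rng.normal(size=(n_res, n_res))
W_res *= 0.95 / max(abs(np.linalg.eigvals(W_res)))

def run_reservoir(inputs):
    h = np.zeros(n_res)
    states = []
    for x_t in inputs:
        h = np.tanh(x_t @ W_in + h @ W_res)  # these weights are never trained
        states.append(h)
    return np.array(states)

# Only the linear readout is learned, e.g. by least squares on the states.
X = run_reservoir(rng.normal(size=(500, n_in)))
y = rng.normal(size=500)                      # stand-in targets for illustration
W_out, *_ = np.linalg.lstsq(X, y, rcond=None)
```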
Reservoir-Type RNNs Are Insufficient for a Few Reasons
With quasi-stable dynamic reservoirs, the effect of any given input can persist for a very long time. However, reservoir-type RNNs are still insufficient for several reasons: 1) the dynamic reservoir must be very nearly unstable for long-term dependencies to persist, so continued stimuli can cause the output to blow up over time, and 2) there's still no direct learning on the lower/earlier parts of the network. Sepp Hochreiter's 1991 diploma thesis (pdf in German) described the fundamental problem of vanishing gradients in deep neural networks, paving the way for the invention of Long Short-Term Memory (LSTM) recurrent neural networks by Sepp Hochreiter and Jürgen Schmidhuber in 1997.
LSTMs can learn long-term dependencies that “normal” RNNs fundamentally can't. The key insight behind this ability is a persistent module called the cell state, which forms a common thread through time, perturbed only by a few linear operations at each time step. Because the connection between cell states at adjacent time steps is interrupted only by elementwise multiplication and addition, LSTMs and their variants can retain short-term memories (i.e. activity belonging to the same “episode”) for a very long time.
A number of modifications to the original LSTM architecture have been suggested over the years, but it may come as a surprise that the classic variant continues to achieve state-of-the-art results on a variety of cutting-edge tasks more than 20 years later. That being said, what are some LSTM variants, and what are they good for?
1. LSTM Classic
The classic LSTM architecture is characterized by a persistent linear cell state surrounded by non-linear layers that feed input to it and parse output from it. Concretely, the cell state works in concert with four gating layers, often called the forget, (2x) input, and output gates.
The forget gate chooses which values of the old cell state to discard, based on the current input data. The two input gates (often denoted i and j) work together to decide what to add to the cell state depending on the input. i and j typically have different activation functions; intuitively, i supplies a scaling vector while j supplies the candidate values to add to the cell state.
Finally, the output gate determines what parts of the cell state should be passed on to the output. Note that in the case of classic LSTMs, the output h consists of hidden layer activations (which can be passed through further layers, e.g. for classification), and the input consists of the previous hidden state output and any new data x provided at the current time step.
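Putting the pieces together, here is a minimal NumPy sketch of one classic LSTM step (gate naming follows the description above; the weight shapes, gate stacking order, and initialization are illustrative choices):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, W, b):
    """One classic LSTM step. W maps [h_prev, x_t] to the four stacked
    gate pre-activations f, i, j, o."""
    z = np.concatenate([h_prev, x_t]) @ W + b
    f, i, j, o = np.split(z, 4)
    f = sigmoid(f)           # forget gate: what to drop from the cell state
    i = sigmoid(i)           # input gate: scaling vector for the candidate
    j = np.tanh(j)           # input gate: candidate values to write
    o = sigmoid(o)           # output gate: what the cell reveals as h
    c = f * c_prev + i * j   # only linear ops touch the cell state itself
    h = o * np.tanh(c)
    return h, c

hidden, inputs = 16, 8
rng = np.random.default_rng(0)
W = rng.normal(scale=0.1, size=(hidden + inputs, 4 * hidden))
b = np.zeros(4 * hidden)
h = c = np.zeros(hidden)
h, c = lstm_step(rng.normal(size=inputs), h, c, W, b)
```

Note that c is only ever scaled (by f) and incremented (by i * j): those two linear operations are the uninterrupted thread through time described above.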
The original LSTM immediately improved upon the state of the art on a set of synthetic experiments with long time lags between relevant pieces of data. Fast forward to today, and we still see the classic LSTM forming a core element of state-of-the-art reinforcement learning breakthroughs like the Dota 2 playing team OpenAI Five.
Examining the policy architecture in more detail (pdf), you can see that while each agent employs a number of dense ReLU layers for feature extraction and final decision classification, a 1024-unit LSTM forms the core representation of each agent’s experience of the game. A similar arrangement was used by OpenAI to train a Shadow robotic hand from scratch to manipulate a colored cube to achieve arbitrary rotations.
2. Peephole Connections
The classic LSTM overcomes the problem of gradients vanishing in a recurrent neural network unrolled in time by connecting all time points via a persistent cell state (often called a “constant error carousel” in early papers describing LSTMs). However, the gating layers that determine what to forget, what to add, and even what to take from the cell state as output don’t take into account the contents of the cell itself.
Intuitively, it makes sense that an agent or model would want to know what memories it already has in place before replacing them with new ones. Enter LSTM peephole connections. This modification (shown in dark purple in the figure above) simply concatenates the cell state contents to the gating layer inputs. In particular, this configuration was shown to offer an improved ability to count and time distances between rare events when the variant was originally introduced. Providing some cell-state connections to the layers in an LSTM remains a common practice, although specific variants differ in exactly which layers are given access.
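In code the change is small. Here is a sketch against the classic LSTM step above (again with illustrative shapes; published peephole variants differ in which gates see the cell state, and some feed the new cell state, rather than the old one, to the output gate):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def peephole_lstm_step(x_t, h_prev, c_prev, W, b):
    """LSTM step with peepholes: the gates also see c_prev, so the unit can
    condition forgetting, writing, and output on what it already stores."""
    # W must now map [h_prev, x_t, c_prev] to the four stacked gates.
    z = np.concatenate([h_prev, x_t, c_prev]) @ W + b
    f, i, j, o = np.split(z, 4)
    c = sigmoid(f) * c_prev + sigmoid(i) * np.tanh(j)
    h = sigmoid(o) * np.tanh(c)
    return h, c
```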
3. Gated Recurrent Unit
Diagrammatically, a Gated Recurrent Unit (GRU) looks more complicated than a classical LSTM. In fact, it's a bit simpler, and due to its relative simplicity it trains a little faster than the traditional LSTM. GRUs combine the gating functions of the input gate j and the forget gate f into a single update gate z.
Practically that means that cell state positions earmarked for forgetting will be matched by entry points for new data. Another key difference of the GRU is that the cell state and hidden output h have been combined into a single hidden state layer, while the unit also contains an intermediate, internal hidden state.
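A minimal sketch of one GRU step, in the same illustrative NumPy style as above (real implementations differ in small details, such as where the reset gate is applied):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def gru_step(x_t, h_prev, W_z, W_r, W_h, b_z, b_r, b_h):
    """One GRU step: a single update gate z couples forgetting and writing."""
    hx = np.concatenate([h_prev, x_t])
    z = sigmoid(hx @ W_z + b_z)   # update gate: forget and input gates in one
    r = sigmoid(hx @ W_r + b_r)   # reset gate, used for the candidate below
    h_tilde = np.tanh(np.concatenate([r * h_prev, x_t]) @ W_h + b_h)  # intermediate state
    # Wherever z is large, old state is dropped and the candidate is
    # written in its place: every forgotten position is an entry point.
    return (1.0 - z) * h_prev + z * h_tilde
```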
Gated Recurrent Units have been used as the basis for demonstrating exotic concepts like Neural GPUs, as well as a simpler model for sequence-to-sequence learning in general, such as machine translation. GRUs are a capable LSTM variant, and they have been fairly popular since their inception. While they can learn quickly on tasks like music or text generation, they have been described as ultimately less powerful than classic LSTMs due to limitations in counting.
4. Multiplicative LSTM (2017)
Multiplicative LSTMs (mLSTMs) were introduced by Krause et al. in 2016. Since then, this more complicated variant has been the centerpiece of a number of high-profile, state-of-the-art achievements in natural language processing. Perhaps the most well-known of these is OpenAI's unsupervised sentiment neuron.
Researchers on that project found that by pre-training a big mLSTM model on unsupervised text prediction, it became much more capable, performing at a high level on a battery of NLP tasks with minimal fine-tuning. A number of interesting features of the text (such as sentiment) were emergently mapped to specific neurons.
Remarkably, the same phenomenon of interpretable classification neurons emerging from unsupervised learning has been reported in end-to-end protein sequence learning. On next-residue prediction tasks for protein sequences, multiplicative LSTM models apparently learn internal representations corresponding to fundamental secondary structural motifs like alpha helices and beta sheets. Protein sequence and structure is an area ripe for major breakthroughs from unsupervised and semi-supervised sequence learning models.
Although the amount of sequence data has been increasing exponentially for the last few years, available protein structure data increases at a much more leisurely pace. Therefore, the next big AI upset in the protein folding field will probably involve some degree of unsupervised learning on pure sequences, and may even eclipse DeepMind's upset at the CASP13 protein folding challenge.
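As for what makes the mLSTM “multiplicative”: in Krause et al.'s formulation, an intermediate state formed by an elementwise product of input and hidden-state projections stands in for the previous hidden state inside the gates. A simplified sketch, reusing the lstm_step function from the classic LSTM example above (the weight names here are illustrative):

```python
# Simplified mLSTM step (after Krause et al.): the multiplicative state m
# gives each input its own effective recurrent transition.
def mlstm_step(x_t, h_prev, c_prev, W_mx, W_mh, W, b):
    m = (x_t @ W_mx) * (h_prev @ W_mh)      # elementwise product of projections
    return lstm_step(x_t, m, c_prev, W, b)  # classic gates, with m replacing h_prev
```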
5. LSTMs With Attention
Finally, we arrive at what is probably the most transformative innovation in sequence models in recent memory*. Attention in machine learning refers to a model’s ability to focus on specific elements in data, in our case the hidden state outputs of LSTMs. Wu et al. at Google used an architecture consisting of an attention network sandwiched between encoding and decoding LSTM layers to achieve state of the art Neural Machine Translation.
This likely continues to power Google Translate to this day. OpenAI’s demonstration of tool use in a hide-and-seek reinforcement learning environment is a recent example of the capability of LSTMs with attention on a complex, unstructured task.
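The core attention operation is simple, even if the production systems built around it are not. Here is a minimal sketch of dot-product attention over a stack of LSTM hidden state outputs (illustrative only; Wu et al.'s actual scoring network is considerably more elaborate):

```python
import numpy as np

def attend(query, encoder_states):
    """Dot-product attention: score each encoder hidden state against the
    query (e.g. the current decoder state), softmax the scores over time,
    and return the weighted average as a context vector."""
    scores = encoder_states @ query         # one score per time step
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()                # softmax over the sequence
    return weights @ encoder_states         # context: focus on what matters

enc = np.random.default_rng(0).normal(size=(7, 16))  # 7 steps of 16-unit LSTM output
context = attend(enc[-1], enc)                       # attend from the last state
```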
The significant successes of LSTMs with attention in natural language processing foreshadowed the decline of LSTMs in the best language models. With increasingly powerful computational resources available for NLP research, state-of-the-art models now routinely make use of a memory-hungry architectural style known as the transformer.
Transformers do away with LSTMs in favor of feed-forward encoders/decoders with attention. Attention transformers obviate the need for cell-state memory by picking and choosing from an entire sequence fragment at once, using attention to focus on the most important parts. BERT, ELMo, GPT-2, and other major language models all follow this approach.
On the other hand, state-of-the-art NLP models incur a significant economic and environmental cost to train from scratch, requiring resources available mainly to research labs associated with wealthy tech companies. The massive energy requirements of these big transformer models make transfer learning all the more important, but they also leave plenty of room for LSTM-based sequence-to-sequence models to make meaningful contributions on tasks sufficiently different from those the big language transformers are trained for.
What’s the Best LSTM for Your Next Project?
In this article, we’ve discussed a number of LSTM variants, all with their own pros and cons. We’ve covered a lot of ground, but in fact, we’ve only scratched the surface of both what is possible and what has been tried. The good news is that a well-put-together and bug-free LSTM will probably perform just as well as any of the more esoteric variants for most sequence-to-sequence learning tasks, and is definitely still capable of state-of-the-art performance in challenging reinforcement learning environments as we’ve discussed above.
Several articles have compared LSTM variants and their performance on a variety of typical tasks. In general, the 1997 original performs about as well as the newer variants, and paying attention to details like bias initialization matters more than the exact architecture used. Jozefowicz et al. analyzed the performance of more than 10,000 different LSTM permutations, some from the literature but most generated as LSTM “mutants,” and found that some of the mutations did perform better than both the classic LSTM and the GRU on some, but not all, of the tasks studied.
Importantly, they found that initializing the forget gate with a large bias term significantly improved LSTM performance. This bias matters for avoiding a handicapped forget gate: a naive initialization with small random values and a sigmoid activation leads to forget gate values centered around 0.5, which repeatedly attenuate the cell state and rapidly erode the ability to learn long-term dependencies.
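In a framework this is only a few lines. For example, in PyTorch each LSTM bias vector stacks the input, forget, cell, and output gate biases in that order, so the forget-gate slice is the second quarter (the layer sizes below are arbitrary):

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=32, hidden_size=64)
H = lstm.hidden_size
with torch.no_grad():
    for name, bias in lstm.named_parameters():
        if "bias" in name:            # applies to both bias_ih and bias_hh
            bias[H:2 * H].fill_(1.0)  # push forget gates toward "remember"
```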
Understanding LSTMs Is Crucial for Good Performance in Your Project
Ultimately, the best LSTM for your project will be the one that is best optimized and bug-free, so understanding how it works in detail is important. Architectures like the GRU offer good performance and simplified architecture, while variants like multiplicative LSTMs are generating intriguing results in unsupervised sequence-to-sequence tasks.
On the other hand, relatively complicated variants like mLSTM intrinsically introduce increased complexity, aka “bug-attracting surfaces.” To get started, it may be a good idea to build a simplified variant and work with a few toy problems to fully understand the interplay of cell states and gating layers.
Additionally, if your project has plenty of other complexity to consider (e.g. in a complex reinforcement learning problem) a simpler variant makes more sense to start with. With experience building and optimizing relatively simple LSTM variants and deploying these on reduced versions of your primary problem, building complex models with multiple LSTM layers and attention mechanisms becomes possible.
Finally, if your goals are more than simply didactic and your problem is well-framed by previously developed and trained models, “don't be a hero”. It's more efficient to take a pre-trained model and fine-tune it to your needs. The savings apply both in terms of economic and environmental cost as well as developer time, and performance can be as good or better thanks to the massive data and computational resources that major AI labs bring to bear to train a state-of-the-art NLP model.
* Both puns, unfortunately, were fully intended.