My takeaway series follows a Q&A format to explain AI concepts at three levels:
- Anyone with general knowledge can understand them.
- For anyone who wants to dive into the code implementation details of the concept.
- For anyone who wants to understand the mathematics behind the technique.
Long Short-Term Memory (LSTM) is an advanced recurrent neural network (RNN) architecture designed to capture long-range dependencies in sequential data better than a vanilla RNN.
LSTM is essentially an enhanced version of the vanilla RNN, with a more complex internal structure for managing information flow.


\(\mathbf{h}_t\) plays the same role as the hidden state in a vanilla RNN: it is the state that connects to the output \(y_t\).
Unlike a vanilla RNN unit, which directly updates the hidden state \(\mathbf{h}_t\) from the previous hidden state \(\mathbf{h}_{t-1}\) and the current input \(\mathbf{x}_t\), the LSTM unit uses a more complex mechanism to update the hidden state.
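For comparison, a vanilla RNN unit written in the same bracket notation (assuming the common tanh activation; \(\mathbf{W}_h\) and \(\mathbf{b}_h\) are names chosen here for illustration) updates its hidden state in a single step:
\[\mathbf{h}_t = \tanh(\mathbf{W}_h [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_h)\]
Each of Equations 3 to 6 reuses this affine-then-nonlinearity pattern with its own weights and bias.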
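- The hidden state \(\mathbf{h}_t\) is now a gated readout of an intermediate state \(\mathbf{c}_t\), where \(\odot\) denotes elementwise multiplication and \(\sigma_h\) is typically the hyperbolic tangent (tanh):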
\[\mathbf{h}_t = \mathbf{o}_{t} \odot \sigma_h (\mathbf{c}_t) \tag{1}\]
- \(\mathbf{c}_t\) (called the cell state) is the intermediate state from which \(\mathbf{h}_t\) is gated. It serves as a long-term memory that can carry information across many time steps, because each update blends the previous cell state \(\mathbf{c}_{t-1}\), scaled by the forget gate \(\mathbf{f}_t\), with a new candidate cell state \(\tilde{\mathbf{c}}_t\), scaled by the input gate \(\mathbf{i}_t\):
\[\mathbf{c}_t = \mathbf{f}_t \odot \mathbf{c}_{t-1} + \mathbf{i}_t \odot \tilde{\mathbf{c}}_t \tag{2}\]
- \(\mathbf{o}_{t}\) (called the output gate) is a gate variable that controls how much of the cell state \(\mathbf{c}_t\) is exposed to the hidden state \(\mathbf{h}_t\). It is computed similarly to the hidden state in a vanilla RNN, with \(\sigma_g\) the sigmoid function so that every entry lies between 0 and 1:
\[\mathbf{o}_{t} = \sigma_g(\mathbf{W}_o [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_o) \tag{3}\]
- The candidate cell state \(\tilde{\mathbf{c}}_t\) and the two gates in Equation 2, \(\mathbf{f}_t\) (the forget gate) and \(\mathbf{i}_t\) (the input gate), are also computed similarly to the hidden state in a vanilla RNN, with \(\sigma_c\) typically tanh:
\[\tilde{\mathbf{c}}_t = \sigma_c(\mathbf{W}_c [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_c) \tag{4}\]
\[\mathbf{f}_t = \sigma_g(\mathbf{W}_f [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_f) \tag{5}\]
\[\mathbf{i}_t = \sigma_g(\mathbf{W}_i [\mathbf{h}_{t-1}, \mathbf{x}_t] + \mathbf{b}_i) \tag{6}\]
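Equations 1 to 6 can be sketched as a single NumPy time step. This is a minimal illustration, not a reference implementation: it assumes \(\sigma_g\) is the sigmoid and \(\sigma_c\), \(\sigma_h\) are tanh, and the `params` dictionary layout and the helper names `lstm_step` and `init_params` are hypothetical choices for this sketch.

```python
import numpy as np

def sigmoid(z):
    # sigma_g in Equations 3, 5 and 6
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, params):
    """One LSTM time step following Equations 1-6.

    params (hypothetical layout): W_f, W_i, W_o, W_c of shape
    (hidden_size, hidden_size + input_size) and b_f, b_i, b_o, b_c of shape (hidden_size,).
    """
    z = np.concatenate([h_prev, x_t])                     # [h_{t-1}, x_t]
    f_t = sigmoid(params["W_f"] @ z + params["b_f"])      # forget gate, Eq. 5
    i_t = sigmoid(params["W_i"] @ z + params["b_i"])      # input gate, Eq. 6
    o_t = sigmoid(params["W_o"] @ z + params["b_o"])      # output gate, Eq. 3
    c_tilde = np.tanh(params["W_c"] @ z + params["b_c"])  # candidate cell state, Eq. 4
    c_t = f_t * c_prev + i_t * c_tilde                    # cell state, Eq. 2
    h_t = o_t * np.tanh(c_t)                              # hidden state, Eq. 1
    return h_t, c_t

def init_params(input_size, hidden_size, seed=0):
    # Small random weights and zero biases, purely for illustration
    rng = np.random.default_rng(seed)
    params = {}
    for name in ("f", "i", "o", "c"):
        params[f"W_{name}"] = rng.normal(0.0, 0.1, size=(hidden_size, hidden_size + input_size))
        params[f"b_{name}"] = np.zeros(hidden_size)
    return params

# Usage: run the cell over a random sequence of length 5
input_size, hidden_size = 3, 4
params = init_params(input_size, hidden_size)
xs = np.random.default_rng(1).normal(size=(5, input_size))
h = np.zeros(hidden_size)
c = np.zeros(hidden_size)
for x_t in xs:
    h, c = lstm_step(x_t, h, c, params)
print(h.shape, c.shape)  # (4,) (4,)
```

Separate weight matrices mirror the equations one-to-one; many frameworks instead stack the four matrices into one larger matrix and split the result, which is equivalent but faster.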
In short, LSTM is an intricate architecture hand-designed by researchers to solve the problems of the vanilla RNN, and its details are not easy to remember.
LSTM has both long-term and short-term memory capabilities. Long-term memory is represented by the cell state \(\mathbf{c}_t\), which can carry information across many time steps (see Equation 2). Short-term memory is represented by the other variables that are computed like the hidden state of a vanilla RNN, including \(\mathbf{o}_t\) (see Equation 3), \(\mathbf{f}_t\) (see Equation 5), and \(\mathbf{i}_t\) (see Equation 6). Because these variables are recomputed at every step from only \(\mathbf{h}_{t-1}\) and \(\mathbf{x}_t\), they capture short-term dependencies in the sequence.
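Unrolling Equation 2 over \(k\) steps (a derivation added here for illustration) makes the long-term-memory claim concrete. All products are elementwise, and the empty product for \(m = t\) equals 1:
\[\mathbf{c}_t = \Big(\prod_{j=t-k+1}^{t} \mathbf{f}_j\Big) \odot \mathbf{c}_{t-k} + \sum_{m=t-k+1}^{t} \Big(\prod_{j=m+1}^{t} \mathbf{f}_j\Big) \odot \mathbf{i}_m \odot \tilde{\mathbf{c}}_m\]
For \(k = 1\) this reduces to Equation 2. The content of \(\mathbf{c}_{t-k}\) reaches \(\mathbf{c}_t\) scaled only by the intervening forget gates, so when those gates stay close to 1 (and the input gates close to 0) the old information passes through almost unchanged.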
Vanilla RNN suffers from the vanishing / exploding gradient problem (see this article), which makes the model hard to train when the sequence is long. This is because the hidden state \(\mathbf{h}_t\) is updated from the previous hidden state \(\mathbf{h}_{t-1}\) through a non-linear activation function, so backpropagation through time multiplies one Jacobian per step, each combining the activation's derivative with the recurrent weights. Over many steps this product tends to shrink toward zero or blow up (see this article).
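A small numerical sketch of the vanishing case, assuming a bias-free vanilla RNN \(\mathbf{h}_t = \tanh(\mathbf{W}_h \mathbf{h}_{t-1} + \mathbf{U}\mathbf{x}_t)\) with hypothetical random weights. \(\mathbf{W}_h\) is chosen with spectral norm 0.5, and since \(\tanh' \le 1\), each one-step Jacobian has norm at most 0.5, so the gradient norm of \(\mathbf{h}_t\) with respect to \(\mathbf{h}_0\) must shrink at least as fast as \(0.5^t\). Recurrent weights with a large norm can instead make the product grow (exploding gradients), although tanh saturation can counteract that.

```python
import numpy as np

def rnn_gradient_norms(W_h, U, xs, h0):
    """Spectral norm of dh_t/dh_0 after each step of h_t = tanh(W_h h_{t-1} + U x_t)."""
    h = h0
    J = np.eye(len(h0))  # accumulated Jacobian dh_t/dh_0, starts as the identity
    norms = []
    for x_t in xs:
        h = np.tanh(W_h @ h + U @ x_t)
        # One-step Jacobian: dh_t/dh_{t-1} = diag(tanh'(a_t)) W_h, with tanh'(a_t) = 1 - h_t**2
        J = (np.diag(1.0 - h**2) @ W_h) @ J
        norms.append(np.linalg.norm(J, 2))  # ord=2: largest singular value
    return norms

rng = np.random.default_rng(0)
hidden, inp, T = 8, 3, 50
Q, _ = np.linalg.qr(rng.normal(size=(hidden, hidden)))  # random orthogonal matrix
W_h = 0.5 * Q                                         # hypothetical weights, spectral norm 0.5
U = rng.normal(0.0, 0.5, size=(hidden, inp))
xs = rng.normal(size=(T, inp))
norms = rnn_gradient_norms(W_h, U, xs, np.zeros(hidden))
for t in (1, 10, 50):
    print(f"t={t:2d}  ||dh_t/dh_0|| = {norms[t - 1]:.2e}")
```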
In LSTM, by contrast, the hidden state \(\mathbf{h}_t\) is computed from the cell state \(\mathbf{c}_t\), which is updated from the previous cell state \(\mathbf{c}_{t-1}\) through an additive, linear interaction (see Equation 2). This linear path lets gradients flow across many time steps during backpropagation, mainly alleviating the vanishing gradient problem. Specifically, along this direct path \(\frac{\partial \mathbf{c}_t}{\partial \mathbf{c}_{t-1}} = \operatorname{diag}(\mathbf{f}_t)\), where \(\mathbf{f}_t\) is the forget gate (see Equation 5). This factor is learnt from the data rather than being forcibly squeezed by the derivative of a sigmoid or tanh function (see this article), so when the forget gate stays close to 1 the gradient passes through almost unchanged.
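As a rough worked example (the numbers are illustrative, not taken from a trained model), follow only the direct cell-state path and chain the per-step factor over \(k\) steps:
\[\frac{\partial \mathbf{c}_T}{\partial \mathbf{c}_{T-k}}\Big|_{\text{direct path}} = \operatorname{diag}\Big(\prod_{j=T-k+1}^{T} \mathbf{f}_j\Big)\]
If every forget-gate entry stays at 0.99 for \(k = 100\) steps, the factor is \(0.99^{100} \approx 0.37\). If each step instead contributed a sigmoid derivative, which is bounded by \(\sigma'(z) \le 0.25\), the product over 100 steps would be at most \(0.25^{100} \approx 6.2 \times 10^{-61}\).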
Second, the vanilla RNN can in practice capture only short-range dependencies (see this article). LSTM, as its name suggests, captures both long-range and short-range dependencies in the sequence (see the question above) better than the vanilla RNN.