Introduction

In this multi-part series, we look inside the LSTM forward pass. If you haven't already, I suggest reading the previous part (part-1) before you come back here. Once you are back, this article explores the meaning, the math and the implementation of an LSTM cell. I do this using a first-principles approach, for which I have Deep-Breathe, a pure-Python implementation of some of the more complex Deep Learning models.

What this article is not about

  • This article does not cover the conceptual model of an LSTM, for which there is some great existing material here and here, in order of difficulty.
  • This is not about the differences between a vanilla RNN and an LSTM, on which there is an awesome, if somewhat difficult, post by Andrej.
  • This is not about how LSTMs mitigate the vanishing gradient problem, on which there are slightly mathy but awesome posts here and here, in order of difficulty.

What this article is about

  • Using first principles, we visualize the forward pass of an LSTM cell.
  • Then we associate code with the pictures to make our understanding more concrete.

Context

This is the same example, and the context is the same, as described in part-1. This time, however, the focus is on the LSTMCell.

LSTM Cell

While there are many variations of the LSTM cell, in the current article we are looking at LayerNormBasicLSTMCell.

  • Weights and biases of the LSTM cell.
  • Shapes color-coded.
figure-1: Batch=1, State=5 (internal states C and H), input_size=10 (assuming a vocab size of 10, which translates to a one-hot of size 10).
  • Weights, biases of the LSTM cell. Also shown is the cell state (c-state) over and above the h-state of a vanilla RNN.
  • ht-1 may be initialized by default with zeros.
figure-2: Weights and biases for all the gates f(forget), i(input), c(candidatecellstate), o(output). The Shapes of these weights are 5X15 individually and stacked vertically as shown they measure 20X15. The Shapes of the biases are 5X1 individually and stacked up 20X1.
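To make those shapes concrete, here is a minimal numpy sketch (the variable names are mine, not Deep-Breathe's) that allocates one 5X15 weight matrix and one 5X1 bias per gate and stacks them vertically, as in figure-2:

```python
import numpy as np

hidden_size, input_size = 5, 10   # state size and one-hot input size from figure-1

# one (5x15) weight matrix and one (5x1) bias per gate: f, i, c(candidate), o
gate_weights = [np.random.randn(hidden_size, input_size + hidden_size) for _ in range(4)]
gate_biases  = [np.zeros((hidden_size, 1)) for _ in range(4)]

W = np.vstack(gate_weights)       # (20, 15) -- the four gates stacked vertically
b = np.vstack(gate_biases)        # (20, 1)
print(W.shape, b.shape)           # (20, 15) (20, 1)
```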
  • Weights and biases in the code.
  • Often the code accepts constants as weights so that the results can be compared with Tensorflow.
figure-3: 4*hidden_size(5) by input_size(10)+hidden_size(5)+1(biases), making it 20X16. Allocating the biases together with the weights is a matter of convenience and performance.
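Under the same assumptions, a minimal sketch of the combined 20X16 allocation of figure-3, where the last column carries the biases; the exact layout inside Deep-Breathe may differ, this only illustrates the convenience of a single allocation:

```python
import numpy as np

hidden_size, input_size = 5, 10
Wb = np.random.randn(4 * hidden_size, input_size + hidden_size + 1)   # (20, 16)

W = Wb[:, :-1]   # (20, 15) weights applied to the [x, h] concatenation
b = Wb[:, -1:]   # (20, 1)  biases kept in the same allocation
print(W.shape, b.shape)
```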
  • Equations of LSTM and their interplay with weights and biases.
  • ht-1 and ct-1 may be initialized by default with zeros (language model) or with some non-zero state (conditional language model).
figure-4: Equations of LSTM and their interplay with weights and biases.
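For reference, the equations of figure-4 written out, where σ is the sigmoid, ⊙ is element-wise multiplication, and [x_t, h_{t-1}] is the concatenation fed to the single GEMM:

```latex
\begin{aligned}
f_t &= \sigma\!\left(W_f\,[x_t, h_{t-1}] + b_f\right) \\
i_t &= \sigma\!\left(W_i\,[x_t, h_{t-1}] + b_i\right) \\
\tilde{c}_t &= \tanh\!\left(W_c\,[x_t, h_{t-1}] + b_c\right) \\
o_t &= \sigma\!\left(W_o\,[x_t, h_{t-1}] + b_o\right) \\
c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
```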
  • GEMM (General Matrix Multiplication) in code.
  • The h and x concatenation is only a performance convenience.
figure-5: GEMM and running the results through the various gates.
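A minimal numpy sketch of that single GEMM, assuming the gate ordering f, i, candidate, o (Deep-Breathe and Tensorflow may order the gates differently):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

hidden_size, input_size = 5, 10
W = np.random.randn(4 * hidden_size, input_size + hidden_size)   # (20, 15)
b = np.zeros((4 * hidden_size, 1))                               # (20, 1)

x_t    = np.eye(input_size)[:, [3]]      # a one-hot input, (10, 1)
h_prev = np.zeros((hidden_size, 1))      # previous h-state, (5, 1)

z = W @ np.vstack([x_t, h_prev]) + b     # one GEMM over the (15, 1) concatenation -> (20, 1)

# split into the four (5, 1) gate pre-activations and apply the non-linearities
zf, zi, zc, zo = np.split(z, 4, axis=0)
f, i, c_bar, o = sigmoid(zf), sigmoid(zi), np.tanh(zc), sigmoid(zo)
```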
  • The mathematical reason why the vanishing gradient is mitigated.
  • There are many theories, though. We will look into this in detail along with back-propagation.
figure-6: New cell state is a function of old cell state and new candidate state.
  • New cell state (c-state) in code.
  • This is used in 2 ways as illustrated in figure-8.
figure-7: The previous steps already decided what to do, we just need to actually do it.
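In numpy, the cell-state update of figure-7 is a single element-wise expression; the gate values below are stand-ins for the activations produced in the GEMM step above:

```python
import numpy as np

hidden_size = 5
f, i   = np.full((hidden_size, 1), 0.5), np.full((hidden_size, 1), 0.5)  # stand-in forget/input gates
c_bar  = np.tanh(np.random.randn(hidden_size, 1))                        # stand-in candidate cell state
c_prev = np.zeros((hidden_size, 1))                                      # previous cell state (zeros at t=0)

c_t = f * c_prev + i * c_bar   # keep part of the old state, add part of the candidate
```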
  • The c-state (ct) is used to calculate the “externally visible” h-state (ht).
  • Sometimes referred to as ct and ht, these are used as the states at the next time step.
figure-8: Conceptually it is the same cell transitioning into the next state. In this case ct.
  • New h-state in code
  • The transition to the next time step is complete with ht.
figure-9: Conceptually it is the same cell transitioning into the next state. In this case ht.
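Likewise, a minimal sketch of the h-state update of figure-9, again with stand-in values for o and ct:

```python
import numpy as np

hidden_size = 5
o   = np.full((hidden_size, 1), 0.5)             # stand-in output gate activation
c_t = np.tanh(np.random.randn(hidden_size, 1))   # stand-in new cell state

h_t = o * np.tanh(c_t)   # (5, 1) "externally visible" state, carried to the next time step
```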
  • ht is multiplied by “Wy” and “by” is added to it.
  • Traditionally we call this value “logits” or “preds”.
figure-10: Wy, whose shape is 10X5, is multiplied by ht, which is 5X1, and by, which is 10X1, is added to it, as shown above.
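A minimal sketch of the projection in figure-10; Wy, by and ht are filled with stand-in values just to show the shapes:

```python
import numpy as np

hidden_size, vocab_size = 5, 10
Wy  = np.random.randn(vocab_size, hidden_size)   # (10, 5)
by  = np.zeros((vocab_size, 1))                  # (10, 1)
h_t = np.random.randn(hidden_size, 1)            # (5, 1)

logits = Wy @ h_t + by                           # (10, 1), the "logits"/"preds"
```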
  • Usually, this is done in the client-side code, because only the client knows the model and when the loss calculation can be done.
  • Tensorflow version of the code.
figure-11: DEEP-Breathe version of the code (https://github.com/slowbreathing/Deep-Breathe/blob/f9585bde9cbb61e71f67ccd936aa22a155c36709/org/mk/training/dl/LSTMMainGraph.py#L98-L109).
  • Usually done internally by the “cross_entropy_loss” function.
  • Shown here just so that you get an idea.
figure-12: Traditionally called yhat.
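A minimal sketch of the softmax step behind figure-12; subtracting the maximum logit is a common numerical-stability trick and is my addition, not necessarily what the library does internally:

```python
import numpy as np

logits = np.random.randn(10, 1)        # (vocab_size, 1) stand-in preds

exp  = np.exp(logits - logits.max())   # shift by the max for numerical stability
yhat = exp / exp.sum()                 # probabilities over the vocabulary, sums to 1
```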
  • Does a combination of softmax (figure-12) and loss calculation (figure-13).
  • Softmax and cross_entropy_loss, to jog your memory.
figure-13: Similar to logistic regression and its cost function.
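And a minimal sketch of the cross-entropy part behind figure-13, with a stand-in one-hot label:

```python
import numpy as np

vocab_size = 10
yhat = np.full((vocab_size, 1), 1.0 / vocab_size)   # stand-in softmax output
y    = np.eye(vocab_size)[:, [3]]                   # stand-in one-hot label

loss = -np.sum(y * np.log(yhat))   # cross-entropy, same form as the logistic-regression cost
```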
  • Summary with weights.
figure-15: Summary with weights.
  • Summary as flow diagram.
figure-14: Summary as simple flow diagram.
  • A sequence of 3 is shown below, but the length depends on the use-case.
  • Tensorflow version of the code and DEEP-Breathe version of the code.
figure-16: The code below is merely shown to give you an idea of the sequence. In practice, the 'dynamic_rnn' function does the iteration, the preds are multiplied by 'Wy' and added to 'By' in client code, and are then handed over to the 'cross_entropy_loss' function.
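Putting it all together, a minimal end-to-end numpy sketch of a 3-step forward pass. All names, the gate ordering and the zero initialisation are assumptions made for illustration; this is neither the Deep-Breathe nor the Tensorflow API:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

hidden_size, input_size, vocab_size, seq_len = 5, 10, 10, 3
W  = np.random.randn(4 * hidden_size, input_size + hidden_size) * 0.1   # (20, 15) gate weights
b  = np.zeros((4 * hidden_size, 1))                                     # (20, 1)  gate biases
Wy = np.random.randn(vocab_size, hidden_size) * 0.1                     # (10, 5)
by = np.zeros((vocab_size, 1))                                          # (10, 1)

h, c = np.zeros((hidden_size, 1)), np.zeros((hidden_size, 1))           # zero initial states
xs = [np.eye(input_size)[:, [t]] for t in range(seq_len)]               # toy one-hot sequence

for x in xs:
    z = W @ np.vstack([x, h]) + b                 # single GEMM per time step
    zf, zi, zc, zo = np.split(z, 4, axis=0)       # gate order assumed: f, i, candidate, o
    f, i, c_bar, o = sigmoid(zf), sigmoid(zi), np.tanh(zc), sigmoid(zo)
    c = f * c + i * c_bar                         # new cell state
    h = o * np.tanh(c)                            # new hidden state
    logits = Wy @ h + by                          # per-step preds, fed to softmax/loss in client code
```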

Summary

In summary, then, that was the walk-through of the LSTM's forward pass. As a study in contrast, if we were building a language model that predicts the next word in the sequence, the training would be similar, but we would calculate the loss at every step, and the label would be the 'X' advanced by one time step. However, let's not get ahead of ourselves: before we get to a language model, let's look at the backward pass (back-propagation) of the LSTM in the next post.