Introduction

In this multi-part series, we look inside the LSTM forward pass. If you haven’t already, I suggest you run through the previous parts (part-1, part-2) before you come back here. Once you are back, in this article we explore LSTM’s Backward Propagation. This usually involves a lot of math. I am not a trained mathematician, but I have decent intuition, and intuition alone can only get you so far. So I used my programming skills to validate the purely theoretical results often cited by mathematicians. I do this using a first-principles approach, for which I have Deep-Breathe, a pure Python implementation of most complex Deep Learning models.

What this article is not about

  • This article will not talk about the conceptual model of the LSTM, on which there is some great existing material here and here, in increasing order of difficulty.
  • This is not about the differences between vanilla RNNs and LSTMs, on which there is an awesome, if somewhat difficult, post by Andrej.
  • This is not about how LSTMs mitigate the vanishing gradient problem, on which there are slightly mathy but awesome posts here and here, in increasing order of difficulty.

What this article is about

  • Using first principles, we picture/visualize Backward Propagation.
  • Then we use a heady mixture of intuition, logic, and programming to prove the mathematical results.

Context

This is the same example, and the context is the same, as described in part-1. The focus this time, however, is on the LSTMCell.

LSTM Cell Backward Propagation

While there are many variations of the LSTM cell, in the current article we are looking at LayerNormBasicLSTMCell. The next section is divided into 3 parts.

  1. LSTM Cell Forward Propagation(Summary)
  2. LSTM Cell Backward Propagation(Summary)
  3. LSTM Cell Backward Propagation(Detailed)

1. LSTM Cell Forward Propagation(Summary)

  • Forward propagation: Summary with weights.
Image: figure-1: Forward propagation: Summary with weights.
  • Forward propagation: Summary as a simple flow diagram.
Image: figure-2: Forward propagation: Summary as a simple flow diagram.
  • Forward propagation: The complete picture.
Image: figure-3: Forward propagation: The complete picture.

2. LSTM Cell Backward Propagation(Summary)

  • Backward Propagation Through Time, or BPTT, is shown here in 2 steps.
  • Step-1 is depicted in figure-4, where we backpropagate through the feed-forward network, calculating the gradients of Wy and By.
Image: figure-4: Step-1: Wy and By first.
  • Step-2 is depicted in figure-5, figure-6 and figure-7, where we backpropagate through the LSTMCell. Shown first is time step-3, the last one, because BPTT is carried out in reverse order (a small sketch of that ordering follows figure-7).
Image: figure-5: Step-2: 3rd time step Wf, Wi, Wc, Wo and Bf, Bi, Bc, Bo.
  • Step-2: LSTMCell BPTT in a simple reverse flow diagram.
Image: figure-6: Step-2: LSTMCell BPTT in a simple reverse flow diagram.
  • Step-2 is repeated for the remaining time steps in the sequence.
Image: figure-7: Step-2: Repeat for the remaining time steps in the sequence.
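To make the "reverse order" concrete, here is a minimal sketch of the Step-2 loop. The per-step gradient math is detailed in section 3; the variable names below are mine, not the Deep-Breathe ones, and the loop body is deliberately left as a comment.

  import numpy as np

  seq_len, hidden_size = 3, 5              # 3 time steps, h of shape (5, 1) as in this example
  dh_next = np.zeros((hidden_size, 1))     # nothing flows in from beyond the last time step
  dc_next = np.zeros((hidden_size, 1))

  # BPTT visits the time steps backwards: t = 2, 1, 0 for a 3-step sequence.
  for t in reversed(range(seq_len)):
      # the per-step LSTMCell gradients of figure-5/figure-6 are computed here,
      # producing fresh dh_next and dc_next for time step t-1
      pass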

3. LSTM Cell Backward Propagation(Detailed)

The gradient calculation through the complete network is divided into 2 parts

  1. Feed-Forward Network layer(Step-1)
  2. Backward propagation through LSTMCell(Step-2)

Feed-Forward Network layer(Step-1)
  • We have already done this derivation for “softmax” and “cross_entropy_loss”.
  • The final derived form is in eq-3, which in turn depends on eq-1 and eq-2. We are calling it “pred” instead of “logits” in figure-5.
Image: figure-8: The proof, intuition and program implementation of the derivation can be found at softmax (/articles/2019-05/softmax-and-its-gradient) and cross_entropy_loss (/articles/2019-05/softmax-and-cross-entropy).
  • “DBy” is straightforward once the gradient for pred is clear from the previous step.
  • This is just the gradient calculation; the weights aren’t adjusted just yet. They will all be adjusted together at the end, as a matter of convenience, by an optimizer such as GradientDescentOptimizer.
  • The complete code listing can be found at code-1.
Image: figure-9: DBy
  • “DWy” is straightforward too, once the gradient for pred is clear from the previous step (figure-5).
  • The complete code listing can be found at code-2.
Image: figure-10: DWy
  • “DWy” and “Dby” together; a short NumPy sketch of Step-1 follows figure-11.
Image: figure-11: DWy and Dby. Only the gradients are calculated; the weights are adjusted once at the end.
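Putting figures 8 through 11 together, here is a minimal NumPy sketch of Step-1. It assumes a hidden size of 5 (the (5X1) shape used throughout this example) and 10 output classes (matching the Dense(10, ...) layer shown later in Listing-1); the variable names are illustrative rather than the actual Deep-Breathe ones.

  import numpy as np

  hidden_size, n_classes = 5, 10
  ht = np.random.randn(hidden_size, 1)       # last hidden state, shape (5, 1)
  labels = np.eye(n_classes)[[3]].T          # one-hot target (class 3 chosen arbitrarily), shape (10, 1)
  Wy = np.random.randn(n_classes, hidden_size)
  by = np.random.randn(n_classes, 1)

  logits = Wy @ ht + by                              # figure-5 labels this "pred" instead of "logits"
  probs = np.exp(logits) / np.sum(np.exp(logits))    # softmax output

  # softmax + cross_entropy_loss collapse into a single term (eq-3)
  d_logits = probs - labels                  # shape (10, 1)
  DWy = d_logits @ ht.T                      # shape (10, 5), figure-10/figure-11
  DBy = d_logits                             # shape (10, 1), figure-9
  dh_t = Wy.T @ d_logits                     # this is what flows back into the LSTMCell (Step-2)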

Backward propagation through LSTMCell(Step-2)

There are several gradients to calculate; figure-6 has the complete list.

  1. DHt
  2. DOt
  3. DCt
  4. DCprojt
  5. DFt
  6. DIt
  7. DHX
  8. DCt_recur

DHt
  • Probably the most confusing step of backward propagation, because there is an element of time.
  • Just as there were 2 recursive elements in the forward pass, h-state and c-state, there are 2 recursive elements while deriving the gradients backwards. We call them dh_next and dc_next; as you will see later, both of them, and many other gradients, are derived from dHt. A short sketch follows figure-14.
Image: figure-12: DHt
Image: figure-13: DHt
  • dHt schematically. The shape is the same as h; for this example it is (5X1).
Image: figure-14: DHt. The shape is the same as h; for this example it is (5X1).
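As a minimal sketch of figures 12 through 14 (names are mine, not Deep-Breathe's): dHt at the last time step comes straight from the feed-forward layer, while at earlier time steps it is the recursive dh_next term arriving from the step in front of it that carries the gradient. Whether the feed-forward term is present at every step or only the last depends on whether a loss is attached to every output or, as in this example, only to the final one.

  import numpy as np

  def d_ht(Wy, d_logits, dh_next=None):
      """dHt, shape (hidden_size, 1), e.g. (5, 1).

      Wy      : (n_classes, hidden_size) output-layer weights
      d_logits: (n_classes, 1) gradient at the logits ("pred"); zero if no
                loss is attached at this time step
      dh_next : (hidden_size, 1) recursive gradient arriving from time step
                t+1 (None/zero at the last time step)
      """
      dh = Wy.T @ d_logits
      if dh_next is not None:
          dh = dh + dh_next
      return dh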

DOt
  • The shape of DOt is the same as h; for this example it is (5X1).
  • The complete listing is at code-6.
Image: figure-15: DOt
  • DOt schematically.
Image: figure-16: DOt schematically.
  • DOt, DWo and DBo; a short sketch follows figure-17.
  • The complete listing can be found at code-7.
Image: figure-17: DOt, DWo and DBo and their shapes.
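A minimal sketch of this step, assuming the usual gate equations o_t = sigmoid(Wo·[h, x] + bo) and h_t = o_t * tanh(c_t). DOt here already includes the sigmoid derivative, i.e. it is the gradient at the gate's pre-activation, and the names are mine, not the Deep-Breathe ones.

  import numpy as np

  def d_output_gate(dh_t, c_t, o_t, hx):
      """DOt, DWo and DBo.

      dh_t, c_t, o_t : (hidden_size, 1), e.g. (5, 1)
      hx             : (hidden_size + input_size, 1), the concatenated [h_(t-1), x_t]
      """
      d_ot = dh_t * np.tanh(c_t) * o_t * (1.0 - o_t)   # h_t = o_t * tanh(c_t); sigmoid'(z) = o_t(1 - o_t)
      d_Wo = d_ot @ hx.T                               # (hidden, hidden + input)
      d_bo = d_ot                                      # (hidden, 1)
      return d_ot, d_Wo, d_bo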

DCt
  • The shape of DCt is the same as h; for this example it is (5X1).
  • The complete listing is at code-8.
Image: figure-18: DCt
  • DCt schematically; a short sketch follows figure-19.
Image: figure-19: DCt schematically.
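A minimal sketch of DCt, assuming h_t = o_t * tanh(c_t) and c_t = f_t * c_(t-1) + i_t * cproj_t (names are mine, not the Deep-Breathe ones):

  import numpy as np

  def d_ct(dh_t, o_t, c_t, dc_next=None):
      """DCt, shape (hidden_size, 1), e.g. (5, 1).

      The local term comes from h_t = o_t * tanh(c_t); dc_next is the
      recursive cell-state gradient arriving from time step t+1
      (None/zero at the last time step).
      """
      dc = dh_t * o_t * (1.0 - np.tanh(c_t) ** 2)
      if dc_next is not None:
          dc = dc + dc_next
      return dc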

DCprojt
  • The shape of DCprojt is the same as h; for this example it is (5X1).
  • The complete listing is at code-9.
Image: figure-20: DCprojt
  • DCprojt schematically.
Image: figure-21: DCprojt schematically.
  • DCprojt; a short sketch follows figure-22.
  • The complete listing can be found at code-10.
Image: figure-22: DCprojt
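A minimal sketch of DCprojt and the weight/bias gradients that follow from it, assuming cproj_t = tanh(Wc·[h, x] + bc) and c_t = f_t * c_(t-1) + i_t * cproj_t (names are mine):

  import numpy as np

  def d_cproj(dc_t, i_t, cproj_t, hx):
      """DCprojt plus DWc and DBc.

      dc_t, i_t, cproj_t : (hidden_size, 1), e.g. (5, 1)
      hx                 : (hidden_size + input_size, 1), concatenated [h_(t-1), x_t]
      """
      d_cp = dc_t * i_t * (1.0 - cproj_t ** 2)   # tanh'(z) = 1 - cproj_t^2 folded in
      d_Wc = d_cp @ hx.T                         # (hidden, hidden + input)
      d_bc = d_cp                                # (hidden, 1)
      return d_cp, d_Wc, d_bc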

DFt
  • The shape of DFt is the same as h; for this example it is (5X1).
  • The complete listing is at code-11.
Image: figure-23: DFt
  • DFt schematically.
Image: figure-24: DFt schematically.
  • DFt; a short sketch follows figure-25.
  • The complete listing can be found at code-12.
Image: figure-25: DFt
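A minimal sketch of DFt, assuming f_t = sigmoid(Wf·[h, x] + bf) and c_t = f_t * c_(t-1) + i_t * cproj_t; c_prev below stands for c_(t-1) (names are mine):

  import numpy as np

  def d_forget_gate(dc_t, c_prev, f_t, hx):
      """DFt plus DWf and DBf.

      dc_t, c_prev, f_t : (hidden_size, 1), e.g. (5, 1)
      hx                : (hidden_size + input_size, 1), concatenated [h_(t-1), x_t]
      """
      d_ft = dc_t * c_prev * f_t * (1.0 - f_t)   # sigmoid'(z) = f_t(1 - f_t) folded in
      d_Wf = d_ft @ hx.T                         # (hidden, hidden + input)
      d_bf = d_ft                                # (hidden, 1)
      return d_ft, d_Wf, d_bf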

DIt
  • The shape of DIt is the same as h; for this example it is (5X1).
  • The complete listing is at code-13.
Image: figure-26: DIt
  • DIt schematically.
Image: figure-27: DIt schematically.
  • DIt; a short sketch follows figure-28.
  • The complete listing can be found at code-14.
Image: figure-28: DIt
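A minimal sketch of DIt, assuming i_t = sigmoid(Wi·[h, x] + bi) and c_t = f_t * c_(t-1) + i_t * cproj_t (names are mine):

  import numpy as np

  def d_input_gate(dc_t, cproj_t, i_t, hx):
      """DIt plus DWi and DBi.

      dc_t, cproj_t, i_t : (hidden_size, 1), e.g. (5, 1)
      hx                 : (hidden_size + input_size, 1), concatenated [h_(t-1), x_t]
      """
      d_it = dc_t * cproj_t * i_t * (1.0 - i_t)  # sigmoid'(z) = i_t(1 - i_t) folded in
      d_Wi = d_it @ hx.T                         # (hidden, hidden + input)
      d_bi = d_it                                # (hidden, 1)
      return d_it, d_Wi, d_bi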

DHX
  • Dhx, Dh_next, dxt.
  • The complete listing can be found at code-15.
Image: figure-29: Dhx, Dh_next, dxt
  • Many models, like Neural Machine Translators (NMT) and Bidirectional Encoder Representations from Transformers (BERT), use word embeddings as their inputs (Xs). These embeddings often need to be learned themselves, and that is where we need dxt; a short sketch follows figure-30.
Image: figure-30: Dhx, Dh_next, dxt
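A minimal sketch of DHX. It assumes each gate weight matrix has shape (hidden_size, hidden_size + input_size) and that the concatenation order is [h_(t-1), x_t]; if part-1's convention puts x first, the split below is simply reversed (names are mine):

  import numpy as np

  def d_hx(Wf, Wi, Wc, Wo, d_ft, d_it, d_cp, d_ot, hidden_size):
      """DHX: gradient w.r.t. the concatenated [h_(t-1), x_t].

      Each gate contributes through its own weights; the result is then
      split into dh_next (sent to time step t-1) and dxt (the input
      gradient, needed when the inputs are learned embeddings).
      """
      d_hx_full = Wf.T @ d_ft + Wi.T @ d_it + Wc.T @ d_cp + Wo.T @ d_ot
      dh_next = d_hx_full[:hidden_size]   # (hidden, 1)
      dxt = d_hx_full[hidden_size:]       # (input, 1)
      return dh_next, dxt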

DCt_recur
  • The recursive part of DCt (dc_next), passed back to the previous time step; a short sketch follows figure-31.
  • The complete listing can be found at code-16.
Image: figure-31: DCt
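A minimal sketch of the recursive part of the cell-state gradient, which becomes dc_next for time step t-1. Because c_t = f_t * c_(t-1) + i_t * cproj_t, the piece of DCt that reaches c_(t-1) is just DCt scaled element-wise by the forget gate (names are mine):

  import numpy as np

  def d_ct_recur(dc_t, f_t):
      """The recursive cell-state gradient, shape (hidden_size, 1).

      Passed back as dc_next and added into DCt at time step t-1."""
      return dc_t * f_t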

Summary

There you go, complete documentation of the LSTMCell’s backward propagation. Also, when using Deep-Breathe you can enable forward or backward logging like so:

  # backpassdebug=True turns on backward-pass logging (debug controls forward-pass logging)
  cell = LSTMCell(n_hidden, debug=False, backpassdebug=True)
  out_l = Dense(10,
                kernel_initializer=init_ops.Constant(out_weights),
                bias_initializer=init_ops.Constant(out_biases),
                backpassdebug=True, debug=False)

  Listing-1

We went through “hell” and back in this article, but if you have come this far, the good news is that it can only get easier. When we use LSTMs in more complex architectures like Neural Machine Translators (NMT), Bidirectional Encoder Representations from Transformers (BERT), or Differentiable Neural Computers (DNC), it will be a walk in a blooming lovely park. Well, almost!! ;-). In a later article we’ll demystify multilayer LSTMCells and bi-directional LSTMCells.