
NOTES ON STATISTICS, PROBABILITY and MATHEMATICS


Backpropagation in LONG SHORT-TERM MEMORY (LSTM):



Adapted from this online post:


The activation $a_t$ is not a gate, but rather an affine transformation followed by a tanh function. It proposes the new cell state:

$$a_t = \tilde C_t = \tanh\left(\begin{bmatrix}W_a & U_a\end{bmatrix}\begin{bmatrix}x_t \\ h_{t-1}\end{bmatrix} + b_a\right)$$

In the example from the linked tutorial (including the toy numbers that follow), at the step $t-1$,

$$a_{t-1} = \tanh\left(\begin{bmatrix}W_a & U_a\end{bmatrix}\begin{bmatrix}x_{t-1} \\ h_{t-2}\end{bmatrix} + b_a\right) = \tanh\left(\begin{bmatrix}0.45 & 0.25 & 0.15\end{bmatrix}\begin{bmatrix}1 \\ 2 \\ 0\end{bmatrix} + 0.2\right) = 0.82$$

W_a = c(0.45, 0.25)           # These are the given weights in the example in the linked post for the activation step
update_prior_output_a = 0.15  # This is the weight in the activation given to the output (h) in the prior layer - also given
bias_a = 0.2                  # The given bias in the activation.
x_t_minus_1 = c(1,2)          # The given input values at t - 1.
prior_output = 0              # Since it is the first layer, there is no prior output
(a_t_minus_1 = tanh(c(W_a, update_prior_output_a)%*%c(x_t_minus_1, prior_output) + bias_a ))
##           [,1]
## [1,] 0.8177541

Input gate:

$$i_t = \sigma\left(\begin{bmatrix}W_i & U_i\end{bmatrix}\begin{bmatrix}x_t \\ h_{t-1}\end{bmatrix} + b_i\right)$$

Therefore, for the first layer ($t-1$),

$$i_{t-1} = \sigma\left(\begin{bmatrix}W_i & U_i\end{bmatrix}\begin{bmatrix}x_{t-1} \\ h_{t-2}\end{bmatrix} + b_i\right) = \sigma\left(\begin{bmatrix}0.95 & 0.8 & 0.8\end{bmatrix}\begin{bmatrix}1 \\ 2 \\ 0\end{bmatrix} + 0.65\right) = 0.96$$

W_i = c(0.95, 0.8)            # These are the given weights in the example in the linked post for the input step
update_prior_output_i = 0.8   # This is the weight in the input given to the output (h) in the prior layer - also given
bias_i = 0.65                 # The given bias in the input gate
(i_t_minus_1 = 1/(1 + exp(-(c(W_i, update_prior_output_i)%*%c(x_t_minus_1, prior_output) + bias_i))))
##           [,1]
## [1,] 0.9608343

Forget gate:

$$f_t = \sigma\left(\begin{bmatrix}W_f & U_f\end{bmatrix}\begin{bmatrix}x_t \\ h_{t-1}\end{bmatrix} + b_f\right)$$

Therefore, for the first layer ($t-1$),

$$f_{t-1} = \sigma\left(\begin{bmatrix}W_f & U_f\end{bmatrix}\begin{bmatrix}x_{t-1} \\ h_{t-2}\end{bmatrix} + b_f\right) = \sigma\left(\begin{bmatrix}0.7 & 0.45 & 0.1\end{bmatrix}\begin{bmatrix}1 \\ 2 \\ 0\end{bmatrix} + 0.15\right) = 0.85$$

W_f = c(0.7, 0.45)           # These are the given weights in the example in the linked post for the forget gate step
update_prior_output_f = 0.1  # This is the weight in the forget gate given to the output (h) in the prior layer - also given
bias_f = 0.15                # The given bias in the forget gate
(f_t_minus_1 = 1/(1 + exp(-(c(W_f, update_prior_output_f)%*%c(x_t_minus_1, prior_output) + bias_f))))
##           [,1]
## [1,] 0.8519528

Output gate:

$$o_t = \sigma\left(\begin{bmatrix}W_o & U_o\end{bmatrix}\begin{bmatrix}x_t \\ h_{t-1}\end{bmatrix} + b_o\right)$$

Therefore, for the first layer ($t-1$),

$$o_{t-1} = \sigma\left(\begin{bmatrix}W_o & U_o\end{bmatrix}\begin{bmatrix}x_{t-1} \\ h_{t-2}\end{bmatrix} + b_o\right) = \sigma\left(\begin{bmatrix}0.6 & 0.4 & 0.25\end{bmatrix}\begin{bmatrix}1 \\ 2 \\ 0\end{bmatrix} + 0.1\right) = 0.82$$

W_o = c(0.6, 0.4)             # These are the given weights in the example in the linked post for the output gate step
update_prior_output_o = 0.25  # This is the weight in the output gate given to the output (h) in the prior layer - also given
bias_o = 0.1                  # The given bias in the output gate
(o_t_minus_1 = 1/(1 + exp(-(c(W_o, update_prior_output_o)%*%c(x_t_minus_1, prior_output) + bias_o))))
##           [,1]
## [1,] 0.8175745

After calculating the values of the gates (and the activation of the current input and prior output), we can calculate:

Cell state ($C_t$):

$$C_t = a_t \odot i_t + f_t \odot C_{t-1}$$

For $t-1$ the calculation is

$$C_{t-1} = 0.82 \times 0.96 + 0.85 \times 0 = 0.79$$

C_t_minus_2 = 0
(C_t_minus_1 = a_t_minus_1 * i_t_minus_1 + f_t_minus_1 * C_t_minus_2)
##           [,1]
## [1,] 0.7857261

Output ($h_t$):

$$h_t = \tanh(C_t) \odot o_t$$

For the $t-1$ layer the calculation is:

$$h_{t-1} = \tanh(0.79) \times 0.82 = 0.53$$

(h_t_minus_1 = tanh(C_t_minus_1) * o_t_minus_1)
##           [,1]
## [1,] 0.5363134

Now the calculations can be repeated for the layer t with the following inputs:

$$\begin{bmatrix}x_t \\ h_{t-1}\end{bmatrix} = \begin{bmatrix}0.5 \\ 3 \\ 0.53\end{bmatrix}$$

where $0.53$ is the output of the prior layer.

The matrix of weights is the same as in the prior step. In R code:

x_t = c(.5,3)
(a_t = tanh(c(W_a, update_prior_output_a)%*%c(x_t, h_t_minus_1) + bias_a ))
##          [,1]
## [1,] 0.849804
(i_t = 1/(1 + exp(-(c(W_i, update_prior_output_i)%*%c(x_t, h_t_minus_1) + bias_i))))
##          [,1]
## [1,] 0.981184
(f_t = 1/(1 + exp(-(c(W_f, update_prior_output_f)%*%c(x_t, h_t_minus_1) + bias_f))))
##          [,1]
## [1,] 0.870302
(o_t = 1/(1 + exp(-(c(W_o, update_prior_output_o)%*%c(x_t, h_t_minus_1) + bias_o))))
##           [,1]
## [1,] 0.8499333
(C_t = a_t * i_t + f_t * C_t_minus_1)
##          [,1]
## [1,] 1.517633
(h_t = tanh(C_t) * o_t)
##           [,1]
## [1,] 0.7719811
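
Before moving to backpropagation, the forward pass above can be bundled into a small helper function. This is a minimal sketch of my own (the name lstm_forward and its argument layout are not from the linked post), meant only to check that it reproduces the two steps just computed:

# A sketch wrapping the forward pass; it relies on the weights and inputs defined above.
lstm_forward = function(x, h_prev, C_prev) {
  sigmoid = function(z) 1 / (1 + exp(-z))
  a = tanh(sum(W_a * x) + update_prior_output_a * h_prev + bias_a)      # activation / candidate
  i = sigmoid(sum(W_i * x) + update_prior_output_i * h_prev + bias_i)   # input gate
  f = sigmoid(sum(W_f * x) + update_prior_output_f * h_prev + bias_f)   # forget gate
  o = sigmoid(sum(W_o * x) + update_prior_output_o * h_prev + bias_o)   # output gate
  C = a * i + f * C_prev                                                # new cell state
  h = tanh(C) * o                                                       # new output
  list(C = C, h = h)
}
step_1 = lstm_forward(x_t_minus_1, h_prev = 0, C_prev = 0)
step_2 = lstm_forward(x_t, h_prev = step_1$h, C_prev = step_1$C)
c(step_1$h, step_2$h)   # should reproduce h_t_minus_1 and h_t above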

Backpropagation:

We start with the loss function, and to make it simple, $J = \dfrac{(h_t - y_t)^2}{2}$ with derivative $\dfrac{dJ}{dh_t} = h_t - y_t$:

$$\Delta_t = 0.77197 - 1.25 = -0.47803$$

y_t = 1.25
(delta_t = h_t - y_t)
##            [,1]
## [1,] -0.4780189

Since there are no additional layers, there is no error from layers on top to add up to this value. If it weren’t the last layer, we’d need to add it: it is as though we were imputing “blame” to the output of a layer for the bad “karma” it has contributed upstream.

$$\Delta T_t = \Delta_t + \Delta_{\text{out}_t}$$

delta_out_t = 0
(delta_total_t = delta_t + delta_out_t)
##            [,1]
## [1,] -0.4780189

So in the layer $t-1$ the error will be:

$$\Delta T_{t-1} = \Delta_{t-1} + \Delta_{\text{out}_{t-1}} = 0.03631 - 0.01828 = 0.01803$$

(both terms are computed further below)


Backpropagating the error $\Delta T_t$ to the cell state $C_t$, through

$$h_t = \tanh(C_t) \odot o_t,$$

we can get the cost in relation to the cell state:

$$\frac{\partial J}{\partial C_t} = \Delta T_t \cdot \big(1 - \tanh^2(C_t)\big) \cdot o_t + \frac{\partial J}{\partial C_{t+1}} \cdot f_{t+1}$$

(the second term is zero here, since there is no later step $t+1$).

(delta_C_t = delta_total_t * (1 - tanh(C_t)^2) * o_t + 0)
##             [,1]
## [1,] -0.07110771

Or the activation:

$$\frac{\partial J}{\partial a_t} = \frac{\partial J}{\partial C_t} \cdot i_t \cdot (1 - a_t^2)$$

(delta_a_t = delta_C_t * i_t * (1 - a_t^2))
##             [,1]
## [1,] -0.01938435

the input gate:

$$\frac{\partial J}{\partial i_t} = \frac{\partial J}{\partial C_t} \cdot a_t \cdot i_t\,(1 - i_t)$$

(delta_i_t = delta_C_t * a_t * i_t *(1 - i_t))
##              [,1]
## [1,] -0.001115614

remembering that the derivative of the logistic function is $\sigma(x)\,[1 - \sigma(x)]$.
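
As a quick numerical check of that identity (my own addition, not part of the linked post), a central finite difference of the logistic function matches $\sigma(x)\,[1 - \sigma(x)]$:

sigmoid = function(z) 1 / (1 + exp(-z))               # logistic function
x0 = 0.7; eps = 1e-6
(sigmoid(x0 + eps) - sigmoid(x0 - eps)) / (2 * eps)   # numerical derivative, approx. 0.2217
sigmoid(x0) * (1 - sigmoid(x0))                       # analytic form, also approx. 0.2217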

the forget gate:

$$\frac{\partial J}{\partial f_t} = \frac{\partial J}{\partial C_t} \cdot C_{t-1} \cdot f_t\,(1 - f_t)$$

(delta_f_t = delta_C_t * C_t_minus_1 * f_t * (1 - f_t))
##              [,1]
## [1,] -0.006306542

the output gate:

$$\frac{\partial J}{\partial o_t} = \Delta T_t \cdot \tanh(C_t) \cdot o_t\,(1 - o_t)$$

(delta_o_t = delta_total_t * tanh(C_t) * o_t * (1 - o_t))
##             [,1]
## [1,] -0.05537783

Bundling these gate partial derivatives of the loss function:

$$\Delta\text{gate}_t = \begin{bmatrix} \partial J/\partial a_t \\ \partial J/\partial i_t \\ \partial J/\partial f_t \\ \partial J/\partial o_t \end{bmatrix} = \begin{bmatrix} \texttt{delta\_a\_t} \\ \texttt{delta\_i\_t} \\ \texttt{delta\_f\_t} \\ \texttt{delta\_o\_t} \end{bmatrix} = \begin{bmatrix} -0.019 \\ -0.0011 \\ -0.0063 \\ -0.055 \end{bmatrix}$$

(delta_gate_t = c(delta_a_t, delta_i_t, delta_f_t, delta_o_t))
## [1] -0.019384348 -0.001115614 -0.006306542 -0.055377831

The “karma” passed back to $t-1$ is $\Delta_{\text{out}_{t-1}} = U^\top \Delta\text{gate}_t$:

U = c(update_prior_output_a, update_prior_output_i, update_prior_output_f, update_prior_output_o)
(delta_out_t_minus_1 = U %*% delta_gate_t)
##             [,1]
## [1,] -0.01827526

And the delta at $t-1$:

y_t_minus_1 = .5
(delta_t_minus_1 = h_t_minus_1 - y_t_minus_1)
##           [,1]
## [1,] 0.0363134

for a total delta at $t-1$:

(delta_total_t_minus_1 = delta_t_minus_1 + delta_out_t_minus_1)
##            [,1]
## [1,] 0.01803814
(delta_C_t_minus_1 = delta_total_t_minus_1 * (1 - tanh(C_t_minus_1)^2) * o_t_minus_1 + delta_C_t * f_t)
##             [,1]
## [1,] -0.05348368
(delta_a_t_minus_1 = delta_C_t_minus_1 * i_t_minus_1 * (1 - a_t_minus_1^2))
##             [,1]
## [1,] -0.01702404
(delta_i_t_minus_1 = delta_C_t_minus_1 * a_t_minus_1 * i_t_minus_1 *(1 - i_t_minus_1))
##              [,1]
## [1,] -0.001645882
(delta_f_t_minus_1 = delta_C_t_minus_1 * C_t_minus_2 * f_t_minus_1 * (1 - f_t_minus_1))
##      [,1]
## [1,]    0
(delta_o_t_minus_1 = delta_total_t_minus_1 * tanh(C_t_minus_1) * o_t_minus_1 * (1 - o_t_minus_1))
##             [,1]
## [1,] 0.001764802

We can compile these results into

$$\Delta\text{gate}_{t-1} = \begin{bmatrix} \texttt{delta\_a\_t\_minus\_1} \\ \texttt{delta\_i\_t\_minus\_1} \\ \texttt{delta\_f\_t\_minus\_1} \\ \texttt{delta\_o\_t\_minus\_1} \end{bmatrix} = \begin{bmatrix} -0.017 \\ -0.0016 \\ 0 \\ 0.0017 \end{bmatrix}$$

These will then be used to calculate:

$$\Delta W = \sum_{t=1}^{T} \Delta\text{gate}_t \otimes x_t = \Delta\text{gate}_{t-1}\,x_{t-1}^\top + \Delta\text{gate}_t\,x_t^\top = \begin{bmatrix} -0.017 \\ -0.0016 \\ 0 \\ 0.0017 \end{bmatrix}\begin{bmatrix} 1 & 2 \end{bmatrix} + \begin{bmatrix} -0.019 \\ -0.0011 \\ -0.0063 \\ -0.055 \end{bmatrix}\begin{bmatrix} 0.5 & 3 \end{bmatrix}$$

Delta_t_minus_1 = c(delta_a_t_minus_1, delta_i_t_minus_1,
                    delta_f_t_minus_1, delta_o_t_minus_1)
Delta_t         = c(delta_a_t, delta_i_t,
                    delta_f_t, delta_o_t)
(Delta_W = outer(Delta_t_minus_1, x_t_minus_1,"*") + 
           outer(Delta_t, x_t, "*"))
##              [,1]         [,2]
## [1,] -0.026716218 -0.092201132
## [2,] -0.002203689 -0.006638606
## [3,] -0.003153271 -0.018919625
## [4,] -0.025924113 -0.162603889

$$\Delta U = \sum_{t=1}^{T} \Delta\text{gate}_t \otimes h_{t-1}$$

(only the $t$ term contributes here, since $h_{t-2} = 0$)

(Delta_U = outer(Delta_t, h_t_minus_1,"*"))
## , , 1
## 
##               [,1]
## [1,] -0.0103960853
## [2,] -0.0005983188
## [3,] -0.0033822828
## [4,] -0.0296998728

$$\Delta b = \sum_{t=1}^{T} \Delta\text{gate}_t$$

(Delta_bias = Delta_t_minus_1 + Delta_t)
## [1] -0.036408392 -0.002761496 -0.006306542 -0.053613029

to proceed with the update:

$$W_{\text{new}} = W_{\text{old}} - \lambda\,\Delta W$$

If we fix the learning rate at $\lambda = 0.1$,

(W = matrix(c(W_a, W_i, W_f, W_o), ncol=2, byrow=T))
##      [,1] [,2]
## [1,] 0.45 0.25
## [2,] 0.95 0.80
## [3,] 0.70 0.45
## [4,] 0.60 0.40
(W_new = W - 0.1 * Delta_W)
##           [,1]      [,2]
## [1,] 0.4526716 0.2592201
## [2,] 0.9502204 0.8006639
## [3,] 0.7003153 0.4518920
## [4,] 0.6025924 0.4162604
(U_new = U - 0.1 * Delta_U)
## , , 1
## 
##           [,1]
## [1,] 0.1510396
## [2,] 0.8000598
## [3,] 0.1003382
## [4,] 0.2529700
bias = matrix(c(bias_a, bias_i, bias_f, bias_o), ncol=1)
(bias_new = bias - 0.1 * Delta_bias)
##           [,1]
## [1,] 0.2036408
## [2,] 0.6502761
## [3,] 0.1506307
## [4,] 0.1053613


This video by Brandon Rohrer gave me some intuition into this concept. I still don’t understand this well; however, what follows may be a very rough, intuitive approximation with many loose ends, offered with the intention of spurring new, knowledgeable answers or feedback.

Using Google search autocomplete, which allegedly could have LSTM cells in its architecture, typing the letter $x_{t-1} =$ "t" returns these suggestions:

[Screenshot: Google autocomplete suggestions after typing "t"]

The actual process is, no doubt, much more complex. However, I suppose we can imagine these suggestions as a vector of library words with an attached vote (squashed to between $-1$ and $1$ via a tanh function). The weights are the result of news (e.g. the mega-chain of department stores Target has announced some changes in their fashion department), and/or the number of recent Google searches (e.g. Trump), applications starting with t (e.g. Google Translate), and local businesses (e.g. TD Bank). It is imaginable that a word like Tzuyu

Chou Tzu-yu (born June 14, 1999), known as Tzuyu, is a Taiwanese singer based in South Korea and a member of the K-pop girl group Twice, under JYP Entertainment.

is likely to have, on a search within the US and with the user’s history erased, a “negative vote”.

The cell state vector, $C_{t-1}$, before the second letter is typed could be something like

$$[\,\text{target}, 0.89;\ \text{trump}, 0.85;\ \dots;\ \text{tzuyu}, -0.89\,].$$

And looking for a place to get the story going, let’s say that at this point the cell state coincides with the output, $h_{t-1}$.

The next letter entered on the keyboard happens to be $x_t =$ "r", which will, through learned weights and the sigmoid activation, end up giving a value very close to zero to suggestions that are incompatible with the spelling: immediately, the word Target will be gone from the suggestion list, thanks to the forget gate layer. This gate is simply the Hadamard product of the initial $C_{t-1}$ with the output of the sigmoid function in this step:

[Diagram: the forget gate layer]
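
In R, that forgetting step might look like the following toy sketch (the words and the gate values are invented purely for illustration; they are not taken from the post):

votes  = c(target = 0.89, trump = 0.85, tzuyu = -0.89)   # toy prior cell state C_{t-1}
forget = c(target = 0.02, trump = 0.97, tzuyu = 0.95)    # toy sigmoid forget gate after typing "r"
votes * forget                                           # Hadamard product: "target" is all but erased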

The interesting part is that if the user history is cleared up before the search, and no account is logged on at the time of the search, the predictions now are:

[Screenshot: autocomplete suggestions for "tr" with history cleared]

whereas if I am logged on as a user and type tr, the predictions are completely different:

[Screenshot: autocomplete suggestions for "tr" while logged in]

because the weights have been adjusted based on a prior search, and Google remembers a prior inquiry into the African parasitic disease Trypanosomiasis. Trump is still an option, but voted much lower. This could possibly have been regulated by the matrix of weights for new candidate values, $\tilde W_C$, voting up that otherwise implausible prediction:

[Diagram: the input gate layer and new candidate values]

and updating the cell state. The first step ($\otimes$) will be a gate (the input gate), where the current input decides which of these prior memories in $\tilde W_C$ can be made available; I suppose the idea could be that if I had typed ty instead of tr, the word trypanosomiasis would never have appeared, even if it were the only search ever performed on my computer. The second step is the element-wise sum ($\oplus$) to update the state.

At this point, the cell state has been updated twice (once initially through a sigmoid activation in the forget gate; and a second time through this memory gate just analyzed):

$$C_t = f_t \odot C_{t-1} + i_t \odot \tilde C_t$$
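
Continuing the toy sketch (again with invented numbers), the input gate scales the new candidate votes before they are added, element-wise, to whatever survived the forget step:

C_forgotten = c(trump = 0.82, trypanosomiasis = 0.00)   # toy f_t * C_{t-1}, after forgetting
candidate   = c(trump = 0.10, trypanosomiasis = 0.90)   # toy tanh candidate votes, boosted by my search history
i_gate      = c(trump = 0.30, trypanosomiasis = 0.95)   # toy sigmoid input gate after typing "tr"
(C_t_toy = C_forgotten + i_gate * candidate)            # element-wise sum: trypanosomiasis enters the state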

The vector of “votes” in the cell state at this time will be further squashed by the last tanh function, and “gated” by the result of the activation of the sigmoid function applied to $[x_t, h_{t-1}]$ (output gate) to produce the output of the cell:

[Diagram: the output gate layer]

The operation $\odot$ in the equation above stands for the element-wise product; the output can also be expressed as

$$h_t = \sigma\left(W_o\,[h_{t-1}, x_t] + b_o\right) \odot \tanh(C_t)$$
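
And to close the toy sketch, the output gate squashes and filters the updated state to produce the cell’s output, i.e. the ranking of suggestions actually shown:

o_gate = c(trump = 0.60, trypanosomiasis = 0.90)   # toy sigmoid output gate
(h_t_toy = o_gate * tanh(C_t_toy))                 # the cell's output h_t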

Interestingly, the cell state $C_t$ is not squashed as it moves on to the next cell.



NOTE: These are tentative notes on different topics for personal use - expect mistakes and misunderstandings.