没有废话,直接上图。

看着很复杂,其实很简单。首先明确这些符号的意义,最底部的表示t时刻的输入,
表示t-1时刻的隐藏状态,
表示t时刻的隐藏状态。
、
、
、
分别表示遗忘门、输入门、更新门和输出门的权重矩阵。
和
表示t-1时刻的细胞(Cell)状态和t时刻的细胞状态。什么叫细胞状态?GeminiPro给的解释是这样的:

这样说可能还是不好理解,你可以把细胞状态理解为一条传送带,遗忘门、输入门负责在传送带上拿走一些退回的快递(旧事物),放入一些新发的快递(新事物),所以t-1时刻到t时刻的传送带状态发生了更新。再回顾一下两个激活函数:sigmoid和tanh
表示的是sigmoid激活函数,返回的值在(0,1)范围内,因此可以作为闸门控制信息流出的比例。

表示的是tanh激活函数,返回的值在(-1,1)范围内,因此可以控制信息的增减方向。

完成了对这些概念的理解后,下面我直接开始用一个例子来说明LSTM的计算过程。假设输入维度d=2,隐藏状态维度h=3,序列长度T=1。
x_1 = [0.5, 0.8] 当前时刻输入
h_0 = [0.1, 0.2, 0.3] 上一时刻隐藏状态
C_0 = [0.0, 0.0, 0.0] 上一时刻Cell状态(初始为零)
正向传播:
1.拼接输入向量:[ ,
]=[0.1, 0.2, 0.3, 0.5, 0.8],得到了一个5维的拼接向量。
2.计算四个门的原始值(记作net)
计算方法也很简单,就是分别用四个门各自的权重矩阵乘上拼接输入向量的转置,然后加上偏置。
该例子的权重矩阵大小为3*5,[ ,
]的转置矩阵大小为5*1,偏置值是0,因此可以得到四个3*1大小的矩阵(长度为3的列向量),对应四个门的原始值(net)。


[ ,
]的转置就是一个列向量,用权重矩阵的每一行乘以这个列向量就得到了原始值,结果也是一个列向量,长度为3。

3.通过激活函数得到四个门的值
对原始值通过激活函数和转置操作就得到了四个门的结果值(4个1*3的行向量)
后面我不特意标注行向量和列向量,本质是一样的,怎么方便怎么写。
f_1 = σ(net_f) = σ([0.74, 0.57, 0.30]) = [0.677, 0.639, 0.574]
i_1 = σ(net_i) = σ([0.19, 0.38, 0.57]) = [0.547, 0.594, 0.639]
u_1 = tanh(net_u) = tanh([1.04, 1.23, 1.42]) = [0.862, 0.831, 0.890]
o_1 = σ(net_o) = σ([0.76, 0.95, 1.14]) = [0.681, 0.721, 0.758]
遗忘门:每个维度保留的比例(67.7%,63.9%,57.4%)
输入门:每个维度写入的比例(54.7%,59.4%,63.9%)
更新门:待写入的新记忆内容 (-1~1),0.862,0.831,0.890不是某个维度保留或写入的比例,而是反映了新记忆内容的更新程度。
输出门:每个维度输出的比例(68.1%,72.1%,75.8%)
4.更新细胞状态
C1就等于f1和C0进行Hadamard积,也就是逐元素相乘,(0.677*0,0.639*0,0.574*0),加上i1和u1逐元素相乘,(0.547*0.862,0.594*0.831,0.639*0.890),二者相加就是更新后的细胞状态,即C1。

5.计算隐藏状态
新的细胞状态计算出来了就可以计算下一个隐藏状态了。
把C1放到tanh激活函数中计算出结果,然后与o1逐元素相乘就得到了h1的值。

如此我们可以继续计算后面的门控值、细胞状态和隐藏层状态。
反向求导:

还是使用链式法则,L对o1求偏导就等于L对h1求偏导,乘以h1对o1求偏导。L对求偏导就在前面的基础上多乘以一个o1对
求偏导。








以上就是整个LSTM的计算过程,希望能对你有所帮助。
1万+

被折叠的 条评论
为什么被折叠?



