Press "Enter" to skip to content

循环神经网络中的反向传播和梯度消失问题(第二部分)

LSTM中如何减轻这个问题

https://unsplash.com/photos/B22I8wnon34

在本系列的第一部分中,我们介绍了RNN模型中的反向传播,并通过公式和数值展示了RNN中的梯度消失问题。在本文中,我们将解释如何使用LSTM部分地解决梯度消失问题,即使它并没有完全消失,在非常长的序列中仍然存在这个问题。

动机

正如我们在本系列的第一部分中所见,普通的RNN将时间信息存储在隐藏状态中,当新的信息被添加时,即在序列中处理一个新的标记时,隐藏状态会被更新。由于隐藏状态在每一步都会被更新,旧的信息会被覆盖,网络会忘记过去所见的内容。为了避免这种情况,我们需要一个单独的内存和一个机制来决定在给定新信息时要写入什么内容,要从过去删除哪些内容,这些内容在未来将不再有用,以及要传递给下一个状态的内容。LSTM恰好做到了这一点 – 它添加了一个存储长期信息的记忆单元,并具有一个门控机制,用于决定从过去中遗忘什么,从当前输入中添加什么,以及向前传递什么。

前向传播

Figure by author (0)

让我们看看在LSTM模型中如何执行时间上的前向传播。给定一系列N个标记,并假设我们从上一个单元获得了一个记忆单元c(t-1)和一个隐藏状态h(t-1),在时间步t,我们计算门来决定如何处理新的传入信息。首先,让我们计算激活值:

循环神经网络中的反向传播和梯度消失问题(第二部分) 四海 第3张

Figure by author (1)

请记住,所有的权重都是在时间步上共享的。然后,将激活矩阵分割成4个维度为H的矩阵,并对前3个应用sigmoid激活函数,对最后一个应用tanh函数,我们计算出门控值:

Figure by author (2)
Figure by author (3)

请注意,所有的门控值都是输入和上一个隐藏状态的函数。

最后,我们计算当前的记忆单元状态c(t)和隐藏状态h(t),它们将传递到下一步。

Figure by author (4)

计算得到的门控值具有以下功能:

  • 门控f:决定从以前的记忆单元c(t-1)中遗忘哪些信息。请注意,由于我们进行逐元素乘法(记住c(t-1)和h(t-1)是向量),并且f包含介于0和1之间的值(由于sigmoid激活函数),当f的值等于0或接近0时,它将取消或减少c(t-1)中的信息,并且当f的值等于1或接近1时,它将保留所有或几乎所有的信息。
  • 门控g:可以解释为与先前的记忆单元c(t-1)相结合的记忆单元更新向量。与其他门不同,激活a(g)应用了tanh函数,它的输出值介于-1和1之间。这样做是为了允许细胞记忆状态既增加又减少,因为如果我们使用sigmoid激活函数,记忆单元的元素永远不会减少。
  • 门控i:决定从记忆单元更新向量(门控g)中写入哪些信息到先前的记忆单元c(t-1)。
  • 门控o:决定将哪些信息包含在新的隐藏状态h(t)中

然后,这些门被组合在一起,如图4所示,来计算新的记忆单元c(t)和隐藏状态h(t)。然后,这些新的单元和隐藏状态被传递给下一个LSTM单元,再次重复相同的过程。所有这个过程可以在下面的图表中说明:

循环神经网络中的反向传播和梯度消失问题(第二部分) 四海 第8张

来源 http://colah.github.io/posts/2015-08-Understanding-LSTMs/ (5)

之后,对于每个隐藏状态,我们计算输出和损失:

图表作者 (6)

在代码中:

def softmax(x, axis=2):    p = np.exp(x - np.max(x, axis=axis,keepdims=True))    return p / np.sum(p, axis=axis, keepdims=True)def lstm_step_forward(x, prev_h, prev_c, Wx, Wh, b):    next_h, next_c, cache = None, None, None        h = x @ Wx + prev_h @ Wh + b    assert h.shape[-1] % 4 == 0    ai, af, ao, ag = np.array_split(h, 4, axis=-1)    i = sigmoid(ai)    f = sigmoid(af)    o = sigmoid(ao)    g = np.tanh(ag)    next_c = f * prev_c + i * g     next_h = o * np.tanh(next_c)        cache = (x, next_h, prev_h, prev_c, Wx, Wh, h, np.tanh(next_c), i, f, o ,g)    return next_h, next_c, cachenp.random.seed(232)# N - 批量大小# D - 嵌入维度# V - 词汇大小# H - 隐藏维度# T - 时间步N, D, T, H, V = 2, 5, 3, 4, 4x  = np.random.randn(N, T, D)h0 = np.random.randn(N, H)Wx = np.random.randn(D, H)Wh = np.random.randn(H, H)Wy = np.random.randn(H, V)b  = np.random.randn(H)y = np.random.randint(V, size=(N, T))mask = np.ones((N, T))all_cache = []h = np.zeros((N, T, H))    next_c = np.zeros((N, H))    for t in range(T):    xt = x[:, t , :]    if t == 0:        next_h, next_c, cache_s = lstm_step_forward(xt, h0, next_c, Wx, Wh, b)        all_cache.append(cache_s)    else:        next_h, next_c, cache_s = lstm_step_forward(xt, next_h, next_c, Wx, Wh, b)           all_cache.append(cache_s)    h[:, t, :] = next_h ft = h @ Wyout = softmax(ft)

反向传播

来源 https://www.iitg.ac.in/cseweb/osint/neural/slides/L8.pdf (7)

与普通RNN相比,反向传播的公式要复杂一些。在本教程中,我们将推导出相对于Wx的梯度,并展示LSTM如何处理消失梯度问题。相对于其他参数的导数可以类似地推导出来,作为读者的练习留给读者。然而,代码中包含了相对于所有梯度的导数,您可以根据代码检查结果。损失相对于隐藏状态的导数仍然与RNN相同,因为在这里没有变化,损失只接受隐藏状态作为输入:

作者提供的图片(8)

现在让我们找出其他单个组件的导数:

作者提供的图片(9)

请注意,为了方便起见,我们将dct/dat和dht/dat分开,无论我们在dht/dct dct/dat中有dht/dct dct/dat,我们都直接将其写为dht/dat。另外,因为我们将以矩阵形式进行反向传播,我们将以以下方式连接门的导数:

作者提供的图片(10)

dht/dat中的求和来自于我们有两个方向(参见图7)——一个方向进入前一个单元,另一个方向进入隐藏状态。根据梯度流的相同逻辑,dct/dc(t-1)的导数如下:

作者提供的图片(11)

现在,让我们求关于Wx的总梯度。这由与本系列第1部分中描述的关于Wx的单个损失的总和给出:

作者提供的图片(12)

关注单个损失,例如dL3/dWx,在我们从L3传播到Wx时,Wx出现在所有的时间步骤组件中,因此,我们需要将所有这些组件相加以获得L3关于Wx的完整梯度。稍微滥用数学符号,我们正在做这样的事情(记住Wx3 = Wx2 = Wx1):

作者提供的图片(13)

第一个组件如下。另外,我们用dht/dat替换了dht/dct dct/dat,因此我们直接使用了该导数

作者提供的图片(14)

为了简洁起见,我将跳过dL3/dWx2,并直接进入第三个组件。我们有:

作者提供的图片(15)

与之前一样,让我们用dht/dat替换dht/dct dct/dat,以便我们直接使用该导数:

作者提供的图片(16)

将它们相加,我们得到dL3/dWx的导数。要获得dWx关于总损失的导数,我们需要将dL3/dWx、dL2/dWx和dL1/dWx相加。

作者提供的图片(17)

代码如下:

def lstm_forward(x, h0, Wx, Wh, b, next_c=None):    h, cache = None, None       cache = []    N, T, _ = x.shape    H = h0.shape[-1]    h = np.zeros((N, T, H))    if next_c is None:        next_c = np.zeros((N, H))    for t in range(x.shape[1]):        xt = x[:, t , :]        if t == 0:            next_h, next_c, cache_s = lstm_step_forward(xt, h0, next_c, Wx, Wh, b)            cache.append(cache_s)        else:            next_h, next_c, cache_s = lstm_step_forward(xt, next_h, next_c, Wx, Wh, b)               cache.append(cache_s)            h[:, t, :] = next_h     return h, cachedef dc_da(h, prev_c, next_c_t, i, f, o, g):    dgrad_c = np.zeros((h.shape[0], 4 * h.shape[1]))    dgrad_h = np.zeros((h.shape[0], 4 * h.shape[1]))    # assert dgrad.shape[1] % 4 == 0    H = dgrad.shape[1] // 4        # 从next_h和next_c两个流中计算对于ai、af、ao和ag的梯度    dnextc_dai = (i * (1-i)) * g    dnextc_daf = (f * (1-f)) * prev_c    dnextc_dao = 0    dnextc_dag = (1 - g**2) * i        dh_dc = o * (1 - next_c_t**2)    dnexth_dai = dh_dc * dnextc_dai    dnexth_daf = dh_dc * dnextc_daf    dnexth_dao = (o * (1-o) * next_c_t)    dnexth_dag = dh_dc * dnextc_dag    # 将它们合并到一个矩阵中,以便方便地计算下游梯度     dgrad_c[:, 0:H] = dnextc_dai     dgrad_c[:, H:2*H] = dnextc_daf     dgrad_c[:, 2*H:3*H] = dnextc_dao     dgrad_c[:, 3*H:4*H] = dnextc_dag         dgrad_h[:, 0:H] =  dnexth_dai    dgrad_h[:, H:2*H] = dnexth_daf    dgrad_h[:, 2*H:3*H] = dnexth_dao    dgrad_h[:, 3*H:4*H] = dnexth_dag    return dgrad_c, dgrad_hnp.random.seed(1)N, D, T, H = 1, 3, 3, 1x = np.random.randn(N, T, D)h0 = np.random.randn(N, H)Wx = np.random.randn(D, 4 * H)Wh = np.random.randn(H, 4 * H)b = np.random.randn(4 * H)out, cache = lstm_forward(x, h0, Wx, Wh, b)# 为了简单起见,我们定义dout而不是推导它们dout = np.random.randn(*out.shape)    # dL3/dWvxdnext_c2 = np.zeros((h0.shape))dnext_h2 = dout[:, -1, :](x2, next_h, prev_h, prev_c, Wx, Wh, next_h, next_c_t2, i2, f2, o2 ,g2) = cache[2]dgrad_c2, dgrad_h2 = dc_da(h0, cache[2][3], cache[2][-5], cache[2][-4],  cache[2][-3], cache[2][-2], cache[2][-1]) dL3_dWx2 = x2.T @ (dgrad_h2 * dnext_h2 + dgrad_c2 * dnext_c2)print(dL3_dWx2)dnext_c1 = dnext_c2 * f2 + dnext_h2 * o2 * (1 - next_c_t2**2) * f2dnext_h1 = (dnext_h2 * dgrad_h2 +  dnext_c2 * dgrad_c2) @ Wh.T(x1, next_h, prev_h, prev_c, Wx, Wh, next_h, next_c_t1, i1, f1, o1 ,g1) = cache[1]dgrad_c1, dgrad_h1 = dc_da(h0, cache[1][3], cache[1][-5], cache[1][-4],  cache[1][-3], cache[1][-2], cache[1][-1])    dL3_dWx1 = x1.T @ (dnext_c1 * dgrad_c1 + dnext_h1 * dgrad_h1)print(dL3_dWx1)dnext_c0 = dnext_c1 * f1 + dnext_h1 * o1 * (1 - next_c_t1**2) * f1dnext_h0 = (dnext_h1 * dgrad_h1 + dnext_c1 * dgrad_c1) @ Wh.T(x0, next_h, prev_h, prev_c, Wx, Wh, next_h, next_c_t0, i0, f0, o0 ,g0) = cache[0]dgrad_c0, dgrad_h0 = dc_da(h0, cache[0][3], cache[0][-5], cache[0][-4],  cache[0][-3], cache[0][-2], cache[0][-1])    dL3_dWx0 = x0.T @ (dnext_c0 * dgrad_c0 + dnext_h0 * dgrad_h0)print(dL3_dWx0)

输出:

[[-0.02349287  0.00135057 -0.11156069 -0.05284914] [ 0.01024921 -0.00058921  0.04867045  0.02305643] [-0.00429567  0.00024695 -0.02039889 -0.00966347]][[-9.83990139e-03  6.78775168e-05 -1.10660923e-03  4.20773125e-04] [ 7.93641636e-03 -5.47469140e-05  8.92540613e-04 -3.39376441e-04] [-2.11067811e-02  1.45598602e-04 -2.37369846e-03  9.02566589e-04]][[-1.95768961e-05  0.00000000e+00  2.77411349e-05 -9.76467796e-03] [ 7.37299593e-06  0.00000000e+00 -1.04477887e-05  3.67754574e-03] [ 6.36561888e-06  0.00000000e+00 -9.02030083e-06  3.17508036e-03]]

losses_dWx = {i : {x_comp : 0 for x_comp in range(i)} for i in range(T)}dWx = np.zeros((D, 4 * H))dWh = np.zeros((H, 4 * H))db = np.zeros((4 * H, ))for idx in range(T-1, -1, -1):    print(f"Loss {idx + 1}")    dnext_c = np.zeros((h0.shape))    dnext_h =  dout[:, idx, :]    for j in range(idx, -1, -1):        (x, next_h, prev_h, prev_c, Wx, Wh, next_h, next_c_t, i, f, o ,g) = cache[j]                dgrad_c, dgrad_h = dc_da(h0, prev_c, next_c_t, i, f, o, g)                 dgrad = dnext_c * dgrad_c + dnext_h * dgrad_h        losses_dWx[idx][j] = x.T @ dgrad                dnext_c = dnext_c * f + dnext_h * o * (1 - next_c_t**2) * f        dnext_h = (dnext_h * dgrad_h +  dnext_c * dgrad_c) @ Wh.T        dnext_h = dgrad @ Wh.T          # 累积每个损失的dWx和其他参数的梯度        dWx += x.T @ dgrad        dWh += prev_h.T @ dgrad        db += dgrad.sum(0)        print(f"分量 {j} - ", np.linalg.norm(losses_dWx[idx][j]))

LSTM中的梯度消失

与RNN的第一部分一样,让我们看看每个分量的损失L3的梯度:

损失3分量0 - 0.010906688399113558分量1 - 0.02478099846737857分量2 - 0.13901933055672275

从上面可以看出,最接近L3的X3更新量最大,而X1和X2对Wx1的更新贡献较小。然而,对于RNN来说,这种差异要大得多。实际上,通过隐藏状态的梯度仍会因为与RNN相同的原因而受到梯度消失的影响 —— Wh项 (dat/dh(t-1)) 仍然出现在反向传播中,例如这里的 dL3/dW(x-1):

Figure by author from Figure 15 (18)

然而,通过仍然是输入和隐藏状态的函数的单元流动的梯度没有 Wh 项,而是 sigmoid 项(见 Figure 3 中遗忘门 ft 的公式):

作者的图15 (18)

回想一下,dct/dc(t-1) = ft。因此,如果遗忘门很高,即接近1,那么消失梯度的速率比普通的RNN要慢得多,但除非所有的遗忘门都恰好为1,否则它仍然会发生,而这在实践中是不会发生的。

结论

本文的主要观点是通过反向传播推导出,LSTM在实践中仍然存在消失梯度的问题,但速率要比普通的RNN低得多,这要归功于细胞状态,它使得梯度在遗忘门速率而不是Wx速率下衰减。如果您发现任何错误,请在评论中告诉我。

参考资料

  • https://web.stanford.edu/class/cs224n/slides/cs224n-2021-lecture06-fancy-rnn.pdf
  • http://cs231n.stanford.edu/assignments.html
Leave a Reply

Your email address will not be published. Required fields are marked *