Back-propagation explained

反向传导算法的本质是用链式法则求复合函数的导数,其数学基础就是复合函数的链式法则求导: 对函数 。 神经网络就是一个复合函数,网络的每一层就是一个函数,整个网络就是这么多函数的复合。 链式法则其实很简洁,只要求各函数可导,并没有对函数的具体形式有任何限制。 之前的很多介绍反向传导算法的文章里单独把“激活函数”、“损失函数”挑出来作为特例,使得整个数学表达显得特别复杂。

因此,一个简洁的推导应该将所有的层都抽象成成一个函数,而不去区分卷积层、全连接层和激活函数的特殊情况。 只要这个函数数值可导,那么在反向传导中就遵循相同的数学规则。 在神经网络中,这样一个函数又被称作“operator”。

在这篇文章里,我们试图根据链式法则,对网络所有类型的层(operator)用统一的形式解释清楚链式法则和反向传导算法。 相比之前的一些解释,这一套解释数学上非常清晰,涉及的数学符号非常少: 在前向传导时,每层只有输入 和输出 ;在反向传导时,各层只有梯度 。 并且所有的操作(卷积,全连接,激活函数,损失函数)都有相同的数学形式。

前向传导 Forward Pass

这部分没什么好讲的,为了引出数学符号的定义简单介绍一下。 如果对这部分比较熟悉,建议直接跳到 反向传导 部分。

对于一个神经网络, 表示“第 层的操作符的第 个输入”, 表示“第 层的操作符的第 个输出”。 对于一个操作而言,只有输入 和输出 ,请忘记之前那一套输入,输出,激活值等复杂的数学符号。 再次强调,别问,问就是只有输入和输出。 在一次前向传导中,多个操作序贯进行:前一层的输出是下一层的输入:

一个简单的线性操作(卷积) 和非线性操作 (sigmoid函数)。 左: ; 右:

一个简单的三层网络,每一个箭头的连接都对应一个权重(为了简便,每层的权重只用一个数学符号标注)。 最后一层接受两个输入(第 3 层的输出 和标注 ),并输出损失(loss)

上图是一个简单的 4 层神经网络,输入数据依次通过“卷积” “激活” “卷积”,最后输入损失层计算损失

反向传导 Back Propagation

假设神经网络第 层的前向传播表示为 , 为了不失一般性,这里的输入输出都用大写符号,表示是一个向量: 这里我们对 没有任何假设,也就是说 可以表示任意一个层(卷积,全连接,激活函数等), 唯一的要求是 必须是数值可导的 (不一定是符号可导的,比如 ReLU 激活函数在0点不可导,但只需要数值上倒数有定义即可)。

把网络的一层当做一个黑盒子,前向传导的时候输入 ,输出 ;反向传导的时候输入 ,输出 。 其中 是一个 雅克比矩阵 (jacobian matrix):

我们把 当成一个黑盒子,那么在反向传导的时候,层接后层 ( 层) 传过来的梯度 ,并计算向前层(层)传递的梯度 。 实际上 是损失 对 第 层的输出 (也是 层的输入 )的导数。 根据链式法则,我们可以根据以下规则计算

因为

这里我们不像其他的各种文章中分别把激活函数、损失函数当做特殊条件。在反向传导中,梯度传播的规则就只有 Eq. 这一种规则。 其中导函数 是一个 雅克比矩阵(Jacob), 其第 行第 列的元素 表示第 个输出对第 个输入的导数: 。 具体地讲:

  • 如果 一个激活函数,那么 就是一个对角阵,对角元素为每个输出对输入的导数:,其他元素为零:
  • 如果 是全连接或者卷积层,那么 对应第 个输出和第 个输入之间连接的权重。

下面我分别用激活函数、卷积层、全连接层和损失函数作为示例,说明这几种 operator 的反向传播规则看似各不一样, 其实都遵循 Eq. 的规则。

激活函数的反向传导

激活函数的前向传导和反向传导。在反向传导时,输入后一层的梯度,输出层的梯度。 如果 ,那么当输入 进入 sigmoid 函数的饱和区时,将趋近于0,导致前面的层得不到足够的梯度,出现“梯度消失现象”。

由于激活函数都是单输入单输出的,我们可以用数值符号(而非向量)描述激活函数的反向传导过程。 假设激活函数是 ,对于任意的输入,前向传导的输出都是。 在反向传导时,假设 层传回的梯度是 ,根据 Eq. 层的梯度为

Sigmoid 和 ReLU 反向传导,以及为什么 ReLU 可以进行 inplace 前向传导

这里我们列举化两个比较典型的激活函数:Sigmoid 和 ReLU。 首先写出这两个激活函数的函数式: 和导数:

结合 Eq.,在 backward 的时候向前传播的梯度分别是:

观察 Eq. 和 Eq.,不难发现 ReLU 和 Sigmoid 的区别: 计算 Sigmoid 的导数值必须要知道输入 ,而 ReLU 函数的导数值可以通过函数输出 来可以确定。 也就是说,对 ReLU 而言,输入 对于计算反传的梯度值是不必要的。 既然输入 在反向传播中没有用到,那么就可以被覆盖掉(在 forward 中可以进行 inplace 操作)。

实际上 Sigmoid 函数的输入 也可以通过输出 推断出来,因此输入 也可以被覆盖掉。 只是这样做增加了额外的计算量,相比而言 ReLU 函数 inplace 操作的话额外的代价几乎为0。 只要函数 是单调的,输如就可以用输出推断出来。

卷积层/全连接层的反向传导

卷积层和全连接层的反向传导法则:后一层的梯度 沿着连接向前传递: 等于 乘以对应的连接权重。 例如:\delta^{l+1}_1\cdot W^l_2 + \delta^{l+1}_2 \cdot W^l_3$。

在全连接层和卷积层中,前层的梯度等于后层梯度沿连接的权重回传。 也即在上图中, 等于 乘以对应的连接权重。 如果将卷积/全连接层看做函数,其实这个反传规则也是符合 Eq. 的。因为连接的权重本身可以视作该函数输出对输入的导数:

结合 Eq.,我们可以写出 卷积层/全连接层 反向传播的梯度。 以 为例(请对照上图): 其中 就等于 之间连接的权重 。 因此卷积层/全连接层的反向传播规则也遵循 Eq.。 Eq 写成向量形式为: 形式上与 Eq. 一致。

池化层层的反向传导

相比卷积/全连接层,池化层的反向传导就更简单了,因为池化操作本身可以看作输入 与一个 mask 之间的 点乘 (dot product)。 以 max-pooling 为例,如下图,矩阵 的 max pooling 等价于 与 mask 的点积,除了最大值处,其它位置的权重都为0。

Max-pooling 等价于输入 与 mask 的点积 (dot product,图中用 ⨀ 表示)。

因此,对于 max-pooling 而言,其反向传导规则为:

损失函数的反向传导

平方损失函数 的反向传导。损失函数直接产生梯度 , 不用接受后层传回的梯度。

损失函数是神经网络的最后一层,因此不用接收后层传回来的梯度。损失函数在反向传导时直接计算损失 对输入 的偏导数作为梯度,并反传到前面各层。 以平方损失 为例,损失函数产生的梯度为损失对函数输入 的导数:。 注意由于 label 并不需要求导,因此损失函数反向传播时只对 这一路径返回梯度。

损失函数对各层参数 的导数

经过前面的梯度反传,我们得到了最后的损失 对各层中间输入/输出特征的导数:。 但是我们最终要求的是损失函数 对网络参数 的导数,并用该梯度更新网络参数。

卷积层中对权重的导数:

我们以一个简单的卷积层为例,可以用很简单的推导得出损失函数对权重的导数

如上图所示,经过反向传导,损失函数 层输出 (也即 层输入 ) 的导数是 : 表示卷积运算。结合 Eq. 和 Eq., 很容易根据链式法则推出损失函数 对卷积权重 的导数 :

因为

Eq. 写成向量形式就是:

卷积层中对偏置 的导数:

考虑卷积中的偏置项: 很容易得到

至此,我们得到了损失函数对权重 以及偏置 的导数。接下来只需要用梯度下降法更新网络参数即可。

用 PyTorch 验证全连接层的反向传导规则:

下面我们在 pytorch 中用一个 的全连接层验证上面的推导。 输入向量 经过全连接层 linear 的变换得到 输出 ,最后损失 。 其中全连接层的权重 初始化为1:。 我们打印 、梯度 以及损失对全连接参数的导数 , 代码如下:

import torch.nn as nn
linear = nn.Linear(4, 2)
linear.weight.data.fill_(1)
linear.bias.data.fill_(0)

x = torch.ones(1, 4)
x.requires_grad = True

y = linear(x)
y.retain_grad()
loss = (y**2).sum()
loss.backward()

print("y = ", y)
print("y_grad = ", y.grad)
print("w_grad = ", linear.weight.grad)

由于 ,因此 。 运行结果如下:

y =  tensor([[4., 4.]], grad_fn=)
y_grad =  tensor([[8., 8.]])
w_grad =  tensor([[8., 8., 8., 8.],
                [8., 8., 8., 8.]])

##总结 Conclusion remarks

总的来说,反向传导算法从将数据 输入到网络,到得到损失对网络参数的导数,大致有以下几个步骤:

  1. 前向传导,得到各层的输出
  2. 从损失函数开始由后向前,根据 Eq. 进行梯度反传,得到损失函数对各层输出的导数。 再次强调没有什么激活函数损失函数,所有的层都是一个 operator,唯一的准则就是 Eq.
  3. 根据 Eq. 和 Eq. 计算损失函数对参数的导数。

补充阅读

全文完,感谢阅读。如果你有任何问题,或者发现有任何表述、数学表达的错误,请在下方留言。

title   = {Back-propagation Explained},
author  = {Kai Zhao},
year    = 2016,
note    = {\url{http://kaizhao.net/posts/bp}}
}