Backward 总结

完成 cs231n assignment 2 中 bachnorm_backward 函数花费了不少时间, 稍微总结下.

计算 Backward 主要分为两种方法 1). On paper 2). Computation graph

本质是一样的, 都是利用 Chain rule. 但我个人更偏爱 Computation graph, 对于复杂的函数会更清晰, 尤其是当我们在处理矩阵时.

下面我们以计算方差 var = np.var(x, axis=0) 为例来说明两种方法 (x.shape(n, d)).

On paper

首先我们要明确一点, 因为 varx 都是多维的, 所以我们最终要求的并不是 , 而是 , out 是在 var 基础上得到的一个标量, 假设 out = np.sum(var).

我们有:

根据 Chain rule:

下面我们推导最关键的 :

最后就是看怎么推广到矩阵. 这个步骤极容易出错, 这也是为什么我偏爱用 Computation graph 的原因!

Computation graph

首先我们把计算过程变成图, 图的节点为操作数或操作符. 尽量把计算过程分解成比较简单的运算.

computation graph

然后就是按照从后往前的顺序一步步计算. 因为都是简单操作, 所以我们可以尝试根据 shape 来思考.

比如 x, w, doutshape 分别为 (N, D), (D, M)(N, M), 因为 dxshape 应该和 x 一样为 (N, D), 所以我们就可以得出 dx = dout.dot(w.T). (假设这里 out = x.dot(w))

下面我们就一个个来计算:

dsquare = dvar / x.sahpe[0]
ddiff = 2 * (x - mean) * dsquare
dmean = -np.sum(ddiff, axis=0)
dx1 = ddiff
dx2 = dmean / x.shape[0]
dx = dx1 + dx2

其实 Computation graph 就是在每个节点上应用 On paper, 只是每个节点都是简单操作, 所以不容易出错也好理解.

Backpropagation, Intuitions