首页 / 模块 1 · 线性代数与微积分 / 第 8 课(共 10 课)

链式法则与计算图

从零到前沿 ML 自学课程 · 阶段0:数学与工具基础 · 能力点:反向传播为何成立、为何高效(反向模式 = 一串 VJP)

读完这一课,你将能够

  • 用"沿路径相乘、跨路径相加"复述多元链式法则,并说出扇出节点的梯度为什么要累加而非覆盖。
  • 给 \(f=\sigma(wx+b)\) 这类小表达式画出计算图、标注每条边的局部导数,手工跑一次前向 + 一次反向,得到各输入的伴随 \(\bar w,\bar x,\bar b\)。
  • 用"上游梯度 × 局部梯度"这一条规则,逐节点回传并核对它等价于多元链式法则。
  • 说清正向模式与反向模式的差别,并解释为什么"标量损失 + 海量参数"必须选反向模式(一趟反向,代价≈一次前向的常数倍)。
  • 用 numpy 手写前向+反向,并用中心差分做梯度检验,判断"反向传播算对了没有"。

上一课我们用雅可比 Jacobian海森 Hessian 描述了多元函数局部如何弯曲;这一课要回答一个更具体、也更"机械"的问题:当一个值经过层层运算最终变成损失 \(L\) 时,每个输入的微小改动对 \(L\) 的影响是怎样一步步传回来的。答案是链式法则 chain rule,而把它图形化、自动化的工具叫计算图 computational graph。这正是反向传播 backpropagation 的全部秘密。

说明一个记号约定:本课为了把例子讲简单,主例的最终输出记作 \(f\)(一个 sigmoid 输出);在真实训练里,图的最终输出是标量损失 \(L\)。二者在计算图里地位完全相同——都是"我们要对它求梯度的那个终点节点"。所以你看到 \(f\) 和 \(L\) 时,请把它们理解成同一个角色(本例恰好 \(L=f\))。第六节起讨论训练时会改用 \(L\)。

一、一元链式法则:先把"嵌套"拆开

如果 \(y\) 依赖 \(u\),\(u\) 又依赖 \(x\),即 \(y=g(u),\ u=h(x)\),那么

\[ \frac{dy}{dx}=\frac{dy}{du}\cdot\frac{du}{dx}. \]

直觉很物理:\(x\) 动一点 \(\Delta x\),它先把 \(u\) 推动 \(\frac{du}{dx}\,\Delta x\),这个 \(\Delta u\) 再把 \(y\) 推动 \(\frac{dy}{du}\,\Delta u\)。两级"放大率"相乘,就是总放大率。链条更长时继续连乘:

\[ \frac{dy}{dx}=\frac{dy}{du_1}\cdot\frac{du_1}{du_2}\cdots\frac{du_{k}}{dx}. \]

把导数想成放大率(敏感度):\(dy/du\) 是"在 \(u\) 这一站,输入扰动被放大多少倍"。链式法则就是"逐站放大率连乘"。这一句话,撑起了整个深度学习的梯度计算。

二、多元链式法则:对所有路径求和

当中间变量不止一个,\(x\) 通往输出可能有多条路径。例如 \(z\) 依赖 \(u\) 和 \(v\),而 \(u,v\) 都依赖 \(x\):

\[ \frac{\partial z}{\partial x} =\frac{\partial z}{\partial u}\frac{\partial u}{\partial x} +\frac{\partial z}{\partial v}\frac{\partial v}{\partial x}. \]

一般地,把每条"从 \(x\) 到 \(z\)"的路径上各边的偏导相乘,再把所有路径相加

\[ \boxed{\ \frac{\partial z}{\partial x}=\sum_{\text{paths }P:\,x\rightsquigarrow z}\ \prod_{(i\to j)\in P}\frac{\partial j}{\partial i}\ } \]

这里 \(\dfrac{\partial j}{\partial i}\) 表示沿边 \(i\to j\) 的局部偏导——即节点 \(j\) 对它的直接输入 \(i\) 求导(\(i,j\) 是节点名,不是下标)。

"沿路径相乘、跨路径相加"——这就是多元链式法则的全部内容。为什么相加?因为 \(x\) 同时通过多条途径影响 \(z\),这些影响是线性叠加的(一阶近似下,总扰动等于各路径扰动之和)。更具体地:\(x\) 变 \(\Delta x\),它通过 \(u\) 给 \(z\) 带来 \(\frac{\partial z}{\partial u}\frac{\partial u}{\partial x}\Delta x\),通过 \(v\) 带来 \(\frac{\partial z}{\partial v}\frac{\partial v}{\partial x}\Delta x\);两份扰动同时发生、彼此独立,所以 \(\Delta z\) 是两者之和——这正是全微分 \(dz=\frac{\partial z}{\partial u}\,du+\frac{\partial z}{\partial v}\,dv\) 的来源。

阅读指引: 这里先给出最一般的"多路径"公式,只需有个印象即可。第三~五节会用最简单的单路径主例 \(f=\sigma(wx+b)\) 把它落地,扇出(多路径)留到第五节再细讲——所以即使现在觉得一般式有点抽象也不用担心,后面是从简单到复杂逐步展开的。

要点

  • 一条路径内部:偏导连乘(链式)。
  • 多条路径之间:贡献相加(叠加)。
  • 路径数量会随网络深度爆炸,所以我们绝不"枚举路径",而是用计算图把求和自动地、不重复地组织起来——这正是反向传播。

三、计算图:节点是运算,边是依赖

任何一个由初等运算搭起来的表达式,都能画成一张有向无环图 DAG(directed acyclic graph)

"无环 acyclic" = 没有循环依赖,所以所有节点能排成一个先后顺序(拓扑序):前向按序计算、反向逆序回传,每个量只算一次。正因为无环,整套"一趟前向 + 一趟反向"才成立。

以本课主例 \(f=\sigma(wx+b)\)(\(\sigma\) 是 sigmoid 函数,\(\sigma(t)=\frac{1}{1+e^{-t}}\))为例,拆成最小运算:

\[ u=w\,x,\qquad z=u+b,\qquad f=\sigma(z). \]
表达式 f=σ(wx+b) 的计算图,从左到右的前向数据流输入wxbu = w·x×(乘)z = u + b+(加)f = σ(z)σ(sigmoid)输出箭头方向 = 前向数据流(从输入到输出求值)
图1:表达式 f=σ(wx+b) 的计算图。输入节点 w、x、b(左侧),运算节点 u=w·x、z=u+b、f=σ(z)(向右),箭头表示数据依赖(前向求值方向)。每个节点旁标注运算类型;这是后续前向/反向标注的骨架。

前向传播 forward pass = 从输入沿边求值,把每个节点的数值算出来并缓存(反向时要用)。给定 \(w=2,\ x=-1,\ b=0.5\):

\[ u=2\cdot(-1)=-2,\quad z=-2+0.5=-1.5,\quad f=\sigma(-1.5)=0.1824. \]

四、反向传播:从输出沿边回传梯度

我们想要的是输出对每个输入的梯度。反向传播的做法是给每个节点配一个"梯度信号"——通常记 \(\bar v \equiv \dfrac{\partial f}{\partial v}\)(读作"\(v\) 的伴随 adjoint"),表示输出对该节点的敏感度。(再次提醒:这里的输出 \(f\) 就是上文说的"终点节点";训练时把它换成损失 \(L\),记号同理。)规则只有两条:

  1. 起点:输出节点对自己的梯度是 1,即 \(\bar f=\dfrac{\partial f}{\partial f}=1\)。
  2. 沿边回传:若有边 \(i\to j\)(即 \(j\) 用到了 \(i\)),则 \(i\) 从这条边收到的贡献是 \[ \underbrace{\bar j}_{\text{上游梯度}}\times\underbrace{\frac{\partial j}{\partial i}}_{\text{局部梯度 local gradient}}. \] 一个节点把它所有出边收到的贡献相加,就得到自己的 \(\bar i\)(这正是"多路径求和"在图上的体现)。

每条边只需知道一个局部导数(只看这一步运算)。反向时把"传到下游的梯度"和"这条边的局部梯度"相乘——上游梯度 × 局部梯度——梯度就像水流一样从输出端逆流而上。复杂全在前向,反向只是机械地乘和加。

给本例标注每条边的局部导数(用前向缓存的值):

局部导数数值
\(z\to f\)\(\sigma'(z)=f(1-f)\)\(0.1824(1-0.1824)=0.1491\)
\(u\to z\)\(\partial z/\partial u=1\)\(1\)
\(b\to z\)\(\partial z/\partial b=1\)\(1\)
\(w\to u\)\(\partial u/\partial w=x\)\(-1\)
\(x\to u\)\(\partial u/\partial x=w\)\(2\)
前向求值(绿色向右)与反向回传(橙色向左)双向标注的计算图图2:前向值(绿,向右)与反向梯度(橙,向左)wxbu ×z +f σw=2x=-1b=0.5u=-2.0z=-1.5f=0.1824x=-1w=211σ′=0.1491f̄=1z̄=0.1491ū=0.1491b̄=0.1491w̄=-0.1491x̄=0.2983前向求值(向右)反向回传(向左)
图2:同一张图上的前向值(绿色,向右)与反向梯度(橙色,向左)双向标注。绿色给出 u=-2、z=-1.5、f=0.1824;橙色给出每条边的局部导数与回传的伴随 f̄=1、z̄=0.1491、ū=0.1491、w̄=-0.1491、x̄=0.2983、b̄=0.1491。直观展示『上游梯度×局部梯度』如何逆流。

例题:对 \(f=\sigma(wx+b)\) 手工跑一次前向 + 一次反向

输入 \(w=2,\ x=-1,\ b=0.5\)。前向已得 \(u=-2,\ z=-1.5,\ f=0.1824\)。现在反向,从 \(\bar f=1\) 开始:

  1. \(\bar z=\bar f\cdot\sigma'(z)=1\times 0.1491=0.1491.\)
  2. \(\bar u=\bar z\cdot\dfrac{\partial z}{\partial u}=0.1491\times 1=0.1491;\qquad \bar b=\bar z\cdot\dfrac{\partial z}{\partial b}=0.1491\times 1=0.1491.\)
  3. \(\bar w=\bar u\cdot\dfrac{\partial u}{\partial w}=0.1491\times x=0.1491\times(-1)=-0.1491.\)
  4. \(\bar x=\bar u\cdot\dfrac{\partial u}{\partial x}=0.1491\times w=0.1491\times 2=0.2983.\)

所以最终梯度(保留 4 位):

\[ \frac{\partial f}{\partial w}=-0.1491,\qquad \frac{\partial f}{\partial x}=0.2983,\qquad \frac{\partial f}{\partial b}=0.1491. \]

核对:用闭式 \(\partial f/\partial w=\sigma'(z)\,x\),\(\partial f/\partial x=\sigma'(z)\,w\),\(\partial f/\partial b=\sigma'(z)\),数值完全一致。后面的代码还会用有限差分再验一遍。

五、扇出节点:被多处使用,梯度要相加

如果一个变量被多个下游用到,它在图上有多条出边,这叫扇出 fan-out。规则第 2 条说得很清楚:把所有出边收到的梯度相加。这其实就是"多路径求和"最常见的来源。

看一个最小例子 \(g=x^2+xy\),这里 \(x\) 同时进了 \(p=x^2\) 和 \(q=xy\) 两个节点(扇出度为 2,指 \(x\) 有两条出边 \(x\to p\) 和 \(x\to q\)):

\[ \frac{\partial g}{\partial x} =\underbrace{\frac{\partial g}{\partial p}\frac{\partial p}{\partial x}}_{\text{路径 }x\to p\to g} +\underbrace{\frac{\partial g}{\partial q}\frac{\partial q}{\partial x}}_{\text{路径 }x\to q\to g} =1\cdot 2x+1\cdot y=2x+y. \]

在 \(x=3,\ y=2\) 处:\(\partial g/\partial x=2\cdot3+2=8\),\(\partial g/\partial y=x=3\)。

扇出节点梯度相加示意图:x 的梯度由两条路径累加 x̄ = 2x + y = 8 扇出节点梯度相加:x̄ = 2x + y ∂p/∂x = 2x ∂q/∂x = y x + p p = x² q q = x·y y g g = p + q 扇出 fan-out:累加 x̄ = 2x + y = 2·3 + 2 = 6 + 2 = 8
图3:扇出节点梯度相加示意。变量 x 同时进入 p=x² 和 q=x·y 两个节点(两条出边即扇出),反向时 x 的梯度由两条路径回传的梯度累加而成:x̄ = 2x + y。在 x=3,y=2 处为 6+2=8。

易错

扇出节点的梯度是累加而不是覆盖。手写反向传播时最常见的 bug 就是:第二条出边的梯度把第一条冲掉了(写成赋值 grad = ... 而非累加 grad += ...)。把每个节点的梯度初始化为 0、所有入流一律 +=,可一劳永逸地避免。这也解释了为什么框架里梯度默认累加、需要你每步手动 zero_grad() 清零。

六、前向模式 vs 反向模式:为什么深度学习选反向

自动微分 automatic differentiation(autodiff)有两种走法,区别在于"沿图传播的是什么":

设函数 \(f:\mathbb{R}^n\to\mathbb{R}^m\),雅可比是 \(m\times n\) 矩阵。要拼出整个雅可比:

要点:深度学习为什么用反向模式

训练时损失 \(L\) 是一个标量(\(m=1\)),而参数有几百万上亿个(\(n\) 极大)。反向模式只需一趟反向(= 一次前向 + 一次反向回传)就拿到 \(L\) 对全部参数的梯度,总工作量只是一次前向的常数倍(约 2~3 倍),与 \(n\) 无关。正向模式则要跑 \(n\) 趟,完全不可行。一句话:多输入、单输出(标量损失)→ 反向模式

代价由"输入维 vs 输出维"决定,与"网络多深"无关:很深但单标量输出,反向模式依旧便宜。反过来,若要算一个少输入、多输出函数的全部导数(\(n\) 小 \(m\) 大),正向模式反而更划算。

向量-雅可比积 VJP 视角

反向传播的每一步,本质都是一个向量-雅可比积 vector-Jacobian product(VJP)。设某节点运算 \(y=\phi(x)\)(\(x\in\mathbb{R}^n,\ y\in\mathbb{R}^m\)),上游传来的梯度(伴随)是行向量 \(\bar y\)(\(1\times m\),因为它是标量输出对向量 \(y\) 的偏导)。这一步把它变成 \(\bar x\):

\[ \bar x=\bar y\,J_\phi,\qquad J_\phi=\frac{\partial y}{\partial x}\ (m\times n). \]

(维度核对:行向量 \(1\times m\) 乘 \(m\times n\) 矩阵 = 行向量 \(1\times n\),正是 \(\bar x\)。)

这里先不纠结行/列向量与转置约定——下一课"矩阵求导"会统一声明布局(分母布局 denominator layout)并把上式精确化。本节只需记住直觉:每步反向 = 拿上游梯度去乘这一步的局部雅可比,且永不显式构造那个大矩阵 \(J_\phi\)。

关键在于:我们从不显式构造大矩阵 \(J_\phi\),只实现"给定 \(\bar y\) 直接算出 \(\bar y\,J_\phi\)"的函数。对加法、乘法、sigmoid 这些运算,VJP 都有简单闭式(比如 sigmoid 的 VJP 就是 \(\bar y\odot f\odot(1-f)\))。整个反向传播 = 把这些 VJP 沿图从输出到输入串起来。

ML 和 ML 的联系

PyTorch / JAX / TensorFlow 的 autograd 引擎做的就是这件事:前向时偷偷记录计算图并缓存中间值,loss.backward() 时从标量损失 \(L\) 出发,按 VJP 规则沿图回传,给每个参数张量填上 .grad。"算梯度只比前向贵一个常数倍"正是反向模式的承诺,也是大模型能用梯度下降训练的根本原因。你这一课手算的伴随 \(\bar w,\bar b\)(在本例中是 \(\partial f/\partial w,\ \partial f/\partial b\);训练里就是 \(\partial L/\partial w,\ \partial L/\partial b\)),就是框架里 w.grad, b.grad 的值。

七、detach / stop-gradient:截断梯度流

有时我们希望某个量在前向照常参与计算,但反向时不让梯度穿过它。这就是 detach()(PyTorch)/ stop_gradient()(JAX/TF)。在计算图里,它相当于把那条边的局部梯度强制设成 0:值照传,梯度到此为止。

易错

detach 只切梯度,不改数值。前向结果和不 detach 时一模一样,区别只在反向:被 detach 的子图拿不到梯度。如果你发现某些参数 .grad 始终是 0 或 None,第一嫌疑就是某处不小心 detach 了(或用了 with torch.no_grad())。注意区分二者:torch.no_grad() 是临时关掉整段计算的建图(连前向都不记录图),而 detach 只切断单个张量这一条边——两者都会让下游拿不到梯度,但作用范围不同。

八、代码:numpy 手写计算图 + 有限差分校验

下面用纯 numpy 实现 \(f=\sigma(wx+b)\) 的前向与反向,并用中心差分 central difference \(\frac{F(\theta+h)-F(\theta-h)}{2h}\) 作为"数值真值"对照,打印相对误差。

import numpy as np

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

def forward(w, x, b):
    # 前向:逐节点求值并缓存
    u = w * x          # 节点 u
    z = u + b          # 节点 z
    f = sigmoid(z)     # 节点 f(输出)
    cache = (w, x, b, u, z, f)
    return f, cache

def backward(cache):
    w, x, b, u, z, f = cache
    fbar = 1.0                       # d f / d f = 1
    zbar = fbar * (f * (1.0 - f))    # 经过 sigmoid:局部导数 f(1-f)
    ubar = zbar * 1.0                # z = u + b
    bbar = zbar * 1.0                # b 的出边
    wbar = ubar * x                  # u = w*x,对 w 局部导数是 x
    xbar = ubar * w                  # 对 x 局部导数是 w
    return dict(w=wbar, x=xbar, b=bbar)

# 取一组具体输入
w, x, b = 2.0, -1.0, 0.5
f, cache = forward(w, x, b)
grads = backward(cache)
print("forward: u=%.4f z=%.4f f=%.6f" % (cache[3], cache[4], f))
print("analytic grads:", {k: round(v, 6) for k, v in grads.items()})

# 有限差分校验(中心差分)
def F_only(w, x, b):
    return sigmoid(w * x + b)

h = 1e-6
num = {}
num["w"] = (F_only(w+h, x, b) - F_only(w-h, x, b)) / (2*h)
num["x"] = (F_only(w, x+h, b) - F_only(w, x-h, b)) / (2*h)
num["b"] = (F_only(w, x, b+h) - F_only(w, x, b-h)) / (2*h)

print("numeric  grads:", {k: round(v, 6) for k, v in num.items()})
for k in ["w", "x", "b"]:
    a, n = grads[k], num[k]
    rel = abs(a - n) / (abs(n) + 1e-12)
    print("relative error %s: %.2e" % (k, rel))

运行后会看到解析梯度 \(\{w:-0.1491,\ x:0.2983,\ b:0.1491\}\) 与数值梯度几乎相同,三个相对误差都在 \(10^{-8}\) 量级——这就是"反向传播算对了"的标准自检方法(gradient check 梯度检验)。

调一调,观察现象

下面三个微改都基于本课已验证的主例,改一个数、跑几秒,就能把"梯度怎么流"看成可复现的现象。

任务 1:改 sigmoid 的输入 z,看局部梯度如何"饱和衰减"

改什么:把进入 sigmoid 的 \(z\) 从 0 一路调大到 10。预期现象:局部梯度 \(\sigma'(z)=f(1-f)\) 从 0.25 单调塌向 0——z=1.5 时约 0.149,z=3 时约 0.045,z=6 时已不足 0.003(约 0.00247),z=10 时只剩约 0.000045。为什么:反向时每条经过 sigmoid 的边都要乘这个数,一旦神经元饱和(\(f\to 0\) 或 1),梯度就被乘没了——这就是"梯度消失"的最小现场。

import numpy as np
def sigmoid(t): return 1.0/(1.0+np.exp(-t))
for z in [0.0, 1.5, 3.0, 6.0, 10.0]:
    f = sigmoid(z)
    print("z=%5.1f  f=%.6f  sigma'(z)=f(1-f)=%.6f" % (z, f, f*(1-f)))

任务 2:改扇出节点 x 的两条出边,看梯度是累加还是覆盖

改什么:在 \(g=x^2+xy\) 里,先把 x 的梯度按本课规则用 += 从两条出边累加,再故意改成只取第二条(赋值覆盖)。预期现象:累加版始终给出正确的 \(\partial g/\partial x=2x+y\)(在 \(x=3,y=2\) 得 8,在 \(x=1,y=5\) 得 7,在 \(x=-2,y=4\) 得 0);覆盖版只剩第二条出边的 \(y\)(分别是 2、5、4),凭空丢了第一条 \(2x\)(如 \(x=3\) 处丢了 6)。为什么:扇出节点同时影响多个下游,各路贡献必须叠加;写成赋值会让后一条出边把前一条冲掉——这正是手写反向最常见的 bug。

import numpy as np
for (x, y) in [(3.0, 2.0), (1.0, 5.0), (-2.0, 4.0)]:
    pbar, qbar = 1.0, 1.0          # g = p + q, p=x^2, q=x*y
    accumulate = pbar*(2*x) + qbar*y   # 累加(正确)
    overwrite  = qbar*y                # 只取第二条出边(错误)
    print("x=%.1f y=%.1f  累加=%.1f (2x+y=%.1f)  覆盖=%.1f" %
          (x, y, accumulate, 2*x+y, overwrite))

任务 3:对 x 做 detach,看"值照传、梯度截断"

改什么:在主例 \(f=\sigma(wx+b)\) 的反向里,把回到 x 那条边的局部梯度强制设为 0(即 detach x),其余不动。预期现象:\(\partial f/\partial x\) 从 0.2983 变成 0,而 \(\partial f/\partial w=-0.1491\) 和前向输出 \(f=0.182426\) 完全不变。为什么:detach 只把那条边的梯度截断,数值照常前向参与,所以其他参数的梯度毫发无损——这就是 stop-gradient 的本质。

import numpy as np
def sigmoid(t): return 1.0/(1.0+np.exp(-t))
w, x, b = 2.0, -1.0, 0.5
u = w*x; z = u + b; f = sigmoid(z)
zbar = 1.0*(f*(1-f))               # 反向到 z
print("前向 f=%.6f(detach 不改前向)" % f)
print("正常: df/dw=%.4f  df/dx=%.4f" % (zbar*x, zbar*w))
print("detach x: df/dw=%.4f  df/dx=%.4f" % (zbar*x, zbar*0.0))

动手练习

  1. 线性回归损失。 设 \(L=(wx-y)^2\),给定 \(w=1.5,\ x=2,\ y=1\)。先画计算图(节点:\(p=wx,\ r=p-y,\ L=r^2\)),手算 \(\partial L/\partial w,\ \partial L/\partial x\),再用下面骨架的反向与有限差分对照。
    import numpy as np
    w, x, y = 1.5, 2.0, 1.0
    def forward(w, x, y):
        p = w * x
        r = p - y
        L = r * r
        return L, (w, x, y, p, r)
    def backward(c):
        w, x, y, p, r = c
        Lbar = 1.0
        rbar = Lbar * (2 * r)     # dL/dr = 2r
        pbar = rbar * 1.0         # r = p - y
        wbar = pbar * x           # TODO: 确认局部导数
        xbar = pbar * w
        return dict(w=wbar, x=xbar)
    L, c = forward(w, x, y)
    print("L=", L, "grads=", backward(c))
    # TODO: 加中心差分校验 w 和 x
    
  2. 扇出累加。 实现 \(g=x^2+xy\) 的反向,要求 \(x\) 的梯度用 += 从两条出边累加(节点 \(p=x^2,\ q=xy\))。在 \(x=3,y=2\) 验证得到 \(\partial g/\partial x=8,\ \partial g/\partial y=3\);故意把累加改成赋值,观察 \(x\) 的梯度变成多少、错在哪。
  3. detach 实验。 在第 1 题里把 \(x\) "detach":反向时令来自 \(x\) 那条边的梯度为 0(即 xbar = 0.0),其余不变。验证 \(\partial L/\partial w\) 不受影响而 \(\partial L/\partial x=0\),体会"值照传、梯度截断"。
  4. 正向 vs 反向的代价。 对 \(f=\sigma(wx+b)\),用分别对 \(w,x,b\) 各做一次单边扰动来"模拟正向模式"(3 趟),数一数你做了几次 sigmoid 求值;再对比一次反向传播只算了 1 次 sigmoid。写一句话总结:输入越多,反向模式省得越多。
  5. (挑战)两层链。 设 \(a=\sigma(w_1 x),\ f=\sigma(w_2 a)\),求 \(\partial f/\partial w_1,\ \partial f/\partial w_2,\ \partial f/\partial x\),并用有限差分校验。建议节点拆分: \(t_1=w_1\cdot x,\ a=\sigma(t_1)\);\(t_2=w_2\cdot a,\ f=\sigma(t_2)\)。注意 \(t_2\) 经 \(a\) 同时依赖 \(w_1\) 和 \(x\),所以回到 \(a\) 的梯度要继续往 \(t_1\) 这一路回传——确认梯度路径没漏。

掌握自检

下一课,我们把这套"节点—边—局部梯度"的机制升级成矩阵形式(矩阵求导):当节点不再是标量而是向量、矩阵时,每个运算的 VJP 怎么写成干净的矩阵公式,从而一次性给整层网络求出梯度。届时会正式声明分母布局 denominator layout(\(\partial L/\partial W\) 与 \(W\) 同形)这一深度学习惯例,并把本课 VJP 的行/列向量与转置约定彻底敲定。