首页 / 模块 3 · Python、NumPy、PyTorch 与实验工程 / 第 5 课(共 10 课)

nn.Module 与模型构建

从零到前沿 ML 自学课程 · 阶段0:数学与工具基础 · 能力点:nn.Module——把张量运算组织成可保存/迁移的模型(注册机制是关键)

上一课你用一个手写的标量自动微分引擎(micrograd)看清了反向传播:每个张量都记着自己怎么来的,.backward() 顺着这张图把梯度一路累加回去。但真实模型有成百上千个参数,你不可能手动 w1, b1, w2, b2 = ... 一个个拎着它们走——优化器要知道"该更新哪些张量",保存模型要知道"该存哪些张量",换设备(GPU)要知道"该搬哪些张量"。这一课讲的 nn.Module 就是这套"参数账本"的标准答案:它替你登记、收集、保存所有参数。把这一课啃透,下一课的训练循环里 optimizer.step() 才不会是黑魔法。

读完这一课,你将能够

  • 解释 nn.Module 为什么能"自动找到"所有参数——也就是 __setattr__ 拦截注册机制。
  • 手写一个迷你 Module 基类,实现 parameters() 递归收集和 state_dict() 雏形。
  • 指出用普通 Python list 装子模块的后果,并改用 nn.ModuleList 修复。
  • 区分 model(x)model.forward(x),说出 train()/eval() 分别影响什么。
  • state_dict() 保存与加载一个模型的权重。

nn.Module 的心智模型:__init__ 注册,forward 算账

一个 nn.Module 子类只有两件核心职责,分工清晰:

一个关键习惯:调用 model(x) 而不是 model.forward(x)。两者算出的结果一样,但 model(x) 会触发 nn.Module__call__,它在调你的 forward 前后还做了别的事——比如执行注册的钩子(hook)。直接调 forward 会绕过这些机制;很多调试工具、可视化、量化都靠钩子工作,绕过它们的 bug 极难排查。把"模型当函数调"当成肌肉记忆。

模型天然是树形结构:根 Module 下挂子 Module,子 Module 下还能挂子 Module,叶子上挂着参数。下图是一个两层网络的样子——注意右边那个用普通 list 装着的子模块是"游离"的。

nn.Module 是一棵树 根 Module 挂子模块,子模块挂参数;用普通 list 装的子模块游离在注册账本之外。 nn.Module 是一棵树:根挂子模块,子模块挂参数 右侧虚线灰色 = 用普通 Python list 装的子模块,游离在账本之外(parameters() / .to() / state_dict() 都看不见它) parameters() 只递归走实线分支(fc1/fc2 共 4 个参数) MLP (根 Module) self.extra = [Linear(...)] ← 普通 list fc1: Linear 子模块 fc2: Linear 子模块 extra[0]: Linear 未注册 weight (4, 3) bias (4,) weight (1, 4) bias (1,) weight (2, 1) bias (2,) ↑ 账本边界 已注册子树(实线) 未注册(游离在外) 图例 实线箭头 = 已注册(被 parameters() / .to() / state_dict() 遍历) 灰色虚线箭头 = 未注册(参数存在却收不到) 圆角矩形 = Module 椭圆 = Parameter 叶子
nn.Module 是一棵树:根模块下挂子模块,子模块挂参数。用普通 list 装的子模块(右侧虚线灰色)游离在账本之外——parameters()、.to()、state_dict() 都看不见它。

注册机制:手写一个迷你 Module 讲透"参数怎么被找到的"

核心问题:你在 __init__ 里写 self.fc1 = Linear(3, 4),为什么之后 parameters() 就能把 fc1 里的 weight、bias 都吐出来?魔法藏在 __setattr__:每当你给 self 的某个属性赋值,Python 都会调用 __setattr__nn.Module 重写了它——赋值时偷看一眼,如果你赋的是 Parameter 或子 Module,就顺手记进一本内部账本。下面这个迷你框架只用纯 Python + NumPy,把这套机制完整复刻出来。其中 state_dict() 返回的就是一本"参数名字符串 → 张量"的字典(保存模型时存的正是它,下面"初始化与保存"节会再用到),并故意埋一个"普通 list 装子模块"的坑:

import numpy as np

class Parameter:
    def __init__(self, data):
        self.data = np.asarray(data, dtype=float)

class Module:
    def __init__(self):
        # 用 object.__setattr__ 绕过下面拦截,先把两本账本建好
        object.__setattr__(self, "_params", {})
        object.__setattr__(self, "_modules", {})

    def __setattr__(self, name, value):       # 赋值时拦截
        if isinstance(value, Parameter):
            self._params[name] = value        # 是参数 -> 记进参数账本
        elif isinstance(value, Module):
            self._modules[name] = value        # 是子模块 -> 记进子模块账本
        object.__setattr__(self, name, value)  # 照常真正赋值

    def parameters(self):                      # 递归收集所有参数
        for p in self._params.values():
            yield p
        for m in self._modules.values():
            yield from m.parameters()

    def state_dict(self, prefix=""):           # 参数名字符串 -> 张量
        sd = {}
        for name, p in self._params.items():
            sd[prefix + name] = p.data
        for name, m in self._modules.items():
            sd.update(m.state_dict(prefix + name + "."))
        return sd

    # 基类不实现 forward,留给子类定义;__call__ 调的就是子类的 forward
    def __call__(self, x):                     # 让 model(x) 等价于 forward
        return self.forward(x)

class Linear(Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.weight = Parameter(np.random.randn(d_out, d_in) * 0.1)
        self.bias = Parameter(np.zeros(d_out))
    def forward(self, x):
        return x @ self.weight.data.T + self.bias.data

class MLP(Module):
    def __init__(self):
        super().__init__()
        self.fc1 = Linear(3, 4)
        self.fc2 = Linear(4, 1)
        self.extra = [Linear(1, 1)]            # 经典坑:list 装子模块!
    def forward(self, x):
        h = np.maximum(self.fc1(x), 0)         # ReLU
        return self.fc2(h)

np.random.seed(0)
net = MLP()
x = np.random.randn(2, 3)
y = net(x)
print("输出形状:", y.shape)
print("注册到的参数个数:", sum(1 for _ in net.parameters()))
print("state_dict 的键:", list(net.state_dict().keys()))
print("list 里的 Linear 被注册了吗:", any(
    p is net.extra[0].weight for p in net.parameters()))

运行输出:

输出形状: (2, 1)
注册到的参数个数: 4
state_dict 的键: ['fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias']
list 里的 Linear 被注册了吗: False

读懂这段输出,本课就过了一大半:

经典坑:普通 list 装子模块会"消失"

真实 nn.Module__setattr__ 和上面那个迷你版一模一样:只认 nn.Parameternn.Module,不认普通 list/dict。所以下面这段在 PyTorch 里看似无害的代码,藏着一个会让你 debug 一下午的 bug:

import torch.nn as nn

class BadNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = [nn.Linear(4, 4) for _ in range(3)]   # 错!普通 list
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

net = BadNet()
print(len(list(net.parameters())))   # 0 —— 一个参数都没有!

后果是连锁的,全都源于"没被注册进账本":

修复只需把容器换成会注册的版本——nn.ModuleList(按下标用)或 nn.ModuleDict(按名字用)。它们内部对每个元素都走了一遍注册:

class GoodNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

print(len(list(GoodNet().parameters())))   # 6 = 3 层 × (weight + bias)

记一条规则:凡是要装"会学习的零件"的容器,必须用 nn.ModuleList/nn.ModuleDict,绝不用裸 list/dict。如果是装多个张量参数,对应的是 nn.ParameterList/nn.ParameterDict

nn.Parameter:什么样的张量算"可学习的"

在迷你框架里,我们用一个 Parameter 类来标记"这是要学的张量"。真实 PyTorch 里 nn.ParameterTensor 的子类,被赋给 Module 属性时自动登记进 parameters()、自动 requires_grad=True、会被优化器更新、会进 state_dict。普通 torch.Tensor 赋给属性则不会被当成参数。

那"需要随模型搬设备、要存档,但又不该被优化器更新"的张量怎么办(比如 BatchNorm 的滑动均值)?用 register_buffer 注册成缓冲区(buffer):进 state_dict、随 .to() 搬,但不进 parameters()、不被梯度更新。

import torch, torch.nn as nn

class Affine(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(d))     # 学习
        self.register_buffer("running_mean", torch.zeros(d))  # 存档不学习
    def forward(self, x):
        return (x - self.running_mean) * self.scale

train() / eval():一个开关,两种行为

有些层在训练和推理时行为不同,最典型的是 Dropout(训练时随机丢弃一部分激活以防过拟合,推理时全部保留)和 BatchNorm(训练时用当前 batch 的统计量,推理时用累计的滑动统计量)。nn.Module 用一个布尔标志 self.training 控制,model.train() 把整棵树设为 Truemodel.eval() 设为 False。下面用 NumPy 手写 Dropout,看清这个开关到底改变了什么:

import numpy as np

def dropout(x, p, training):
    if not training:
        return x                                  # 推理:原样通过
    mask = (np.random.rand(*x.shape) > p) / (1 - p)  # 训练:随机丢 + 放大补偿
    return x * mask

np.random.seed(0)
x = np.ones((1, 10000))
train_out = dropout(x, p=0.5, training=True)
eval_out  = dropout(x, p=0.5, training=False)
print("training=True  输出均值:", round(float(train_out.mean()), 3))
print("training=False 输出均值:", round(float(eval_out.mean()), 3))
print("两次 training=True 的输出相同吗:",
      np.allclose(dropout(x, 0.5, True), dropout(x, 0.5, True)))
training=True  输出均值: 0.987
training=False 输出均值: 1.0
两次 training=True 的输出相同吗: False

注意 training=True 时输出均值仍接近 1(因为存活的激活被除以 1-p 放大补偿了),但每次都不一样——它是随机的。这就是为什么验证/推理前必须 model.eval():忘了切,模型在验证时还在随机丢神经元,每次预测都抖动,验证 loss 虚高且不可复现;BatchNorm 还会用单个 batch 的统计量算出错误结果。这个坑下一课写训练循环时会再踩一遍——记住验证段的标准开头是 model.eval(),并配 with torch.no_grad():

初始化与保存:nn.init 和 state_dict

参数的初始值不是小事。如果权重太大,前向传播时激活值会层层放大直至爆炸;太小则层层衰减到几乎为零,梯度也跟着消失。下面把初始化标准差从"随便取 1"换成"按维度缩放"(Xavier/He 的思想:让每层输出方差大致保持不变),跑 10 层就能看到天壤之别:

import numpy as np
np.random.seed(0)
d = 100
naive  = np.random.randn(d, d) * 1.0                  # 标准差 1,太大
scaled = np.random.randn(d, d) * np.sqrt(2.0 / d)     # He 初始化(配 ReLU)
h_n = h_s = np.random.randn(1, d)
for _ in range(10):
    h_n = np.maximum(h_n @ naive,  0)                 # 10 层 ReLU
    h_s = np.maximum(h_s @ scaled, 0)
print("朴素初始化 10层后 激活均方根:", round(float(np.sqrt((h_n**2).mean())), 2))
print("缩放初始化 10层后 激活均方根:", round(float(np.sqrt((h_s**2).mean())), 2))
朴素初始化 10层后 激活均方根: 99664490.48
缩放初始化 10层后 激活均方根: 1.24

朴素初始化的激活在 10 层后已飙到约 1e8(数值溢出的前兆,再深下去就会爆掉),而缩放后的稳定在 1.24——也就是 O(1) 的合理量级(不是无限缩小,而是"既没爆炸也没消失")。PyTorch 的 nn.init 提供了现成的初始化函数;nn.Linear 等层也内置了合理的默认初始化,多数时候你不必手动调,但要知道"初始化坏了"是训练不动的常见原因之一。

import torch.nn as nn

layer = nn.Linear(100, 100)
nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")  # He 初始化
nn.init.zeros_(layer.bias)

训练好的模型靠 state_dict(参数名字符串 → 张量 的字典,就是迷你框架里那本账本)保存和加载。注意保存的是 state_dict 而非整个模型对象——后者依赖类定义和文件路径,脆弱且不安全。下图展示这张映射表的样子:

键(字符串) 值(张量 Tensor) state_dict OrderedDict: 参数名 → 张量 'fc1.weight' (4, 3) 'fc1.bias' (4,) 'fc2.weight' (1, 4) 'fc2.bias' (1,) 'fc1' (父模块名) + '.' + 'weight' (参数名) 点号 = 父子层级的分隔符 torch.save(model.state_dict(), ...) 存的就是这张表
state_dict 是一张「参数名字符串 → 张量」的映射表。键名由父模块名用点号拼接子参数名得来,因此能唯一定位树中任意一个张量。
import torch

# 保存:只存权重字典
torch.save(model.state_dict(), "model.pt")

# 加载:先构建结构相同的模型,再灌入权重
model = MyModel()
model.load_state_dict(torch.load("model.pt"))
model.eval()      # 推理前别忘了切换模式

调一调,观察现象

微任务 1:把 list 换成 ModuleList(在迷你框架里模拟)。预期现象:用普通 list 装的那个对象 parameters() 数到 0 个,用 ModuleList 装的数到 2 个(= 1 层 ×(weight + bias))。为什么:list 不被 __setattr__ 识别;只要让容器里每个子模块都注册进 _modules 账本,parameters() 就能递归收到它们。(把这个 0→2 的修复推广到上面三层的 MLP,注册数就会从 4 变成 6。)

import numpy as np

class Parameter:
    def __init__(self, data): self.data = np.asarray(data, float)

class Module:
    def __init__(self):
        object.__setattr__(self, "_params", {})
        object.__setattr__(self, "_modules", {})
    def __setattr__(self, name, value):
        if isinstance(value, Parameter): self._params[name] = value
        elif isinstance(value, Module):  self._modules[name] = value
        object.__setattr__(self, name, value)
    def parameters(self):
        for p in self._params.values(): yield p
        for m in self._modules.values(): yield from m.parameters()

class ModuleList(Module):                       # 关键:注册每个子模块
    def __init__(self, mods):
        super().__init__()
        for i, m in enumerate(mods):
            setattr(self, str(i), m)            # 触发 __setattr__ 注册

class Linear(Module):
    def __init__(self, a, b):
        super().__init__()
        self.weight = Parameter(np.zeros((b, a)))
        self.bias = Parameter(np.zeros(b))

bad = Module(); bad.layers = [Linear(2,2)]                  # list
good = Module(); good.layers = ModuleList([Linear(2,2)])    # ModuleList
print("list 装:", sum(1 for _ in bad.parameters()), "个参数")
print("ModuleList 装:", sum(1 for _ in good.parameters()), "个参数")
list 装: 0 个参数
ModuleList 装: 2 个参数

微任务 2:把 Dropout 概率 p 从 0.5 调到 0.9。预期现象:training=True 时大量激活被置零,但输出均值仍接近 1。为什么:存活概率只有 1-p=0.1,但每个存活的激活被除以 0.1 放大了 10 倍,期望上正好补偿,所以均值不变、方差暴涨。

import numpy as np
np.random.seed(1)
def dropout(x, p, training):
    if not training: return x
    return x * (np.random.rand(*x.shape) > p) / (1 - p)
x = np.ones((1, 10000))
for p in [0.1, 0.5, 0.9]:
    out = dropout(x, p, training=True)
    print(f"p={p}: 被置零比例≈{round(float((out==0).mean()),2)}, 输出均值≈{round(float(out.mean()),3)}")
p=0.1: 被置零比例≈0.1, 输出均值≈1.0
p=0.5: 被置零比例≈0.49, 输出均值≈1.02
p=0.9: 被置零比例≈0.9, 输出均值≈1.03

微任务 3:把初始化标准差从 He 缩放改回 1.0,层数从 10 调到 30。预期现象:朴素初始化 30 层后激活均方根飙到约 1e24 量级(一个极大但仍有限的数,指数级膨胀;再往深堆几十层就会溢出成 inf),而 He 缩放仍稳定在约 1.15。为什么:每层把激活放大约固定倍数,层数越多指数累积越猛,足够深就会超出浮点数表示范围。

import numpy as np
np.random.seed(0)
d = 100
for scale, tag in [(1.0, "朴素"), (np.sqrt(2.0/d), "He缩放")]:
    W = np.random.randn(d, d) * scale
    h = np.random.randn(1, d)
    for _ in range(30):
        h = np.maximum(h @ W, 0)
    print(f"{tag}初始化 30层后 激活均方根:", round(float(np.sqrt((h**2).mean())), 2))
朴素初始化 30层后 激活均方根: 2.2656832467872677e+24
He缩放初始化 30层后 激活均方根: 1.15

动手练习

掌握自检