权重初始化
深度网络训练成败的一个关键因素是权重初始化。不好的初始化会导致梯度消失或梯度爆炸,让训练完全失败。好消息是:PyTorch 默认已经帮你选了合适的初始化。
学习目标
- 理解为什么不能全零初始化
- 理解 Xavier / Glorot 初始化
- 理解 He / Kaiming 初始化
- 观察初始化对训练的影响
先建立一张地图
初始化这节最容易让新人觉得“像额外细节”,但它其实直接关系到模型能不能顺利开始学。
这节真正想解决的是:
- 为什么权重不能乱设
- 为什么不同激活函数要搭不同初始化
- 第一次写网络时,什么时候可以放心用 PyTorch 默认值
这节和前面几节是怎么接上的
如果你把前面几节串起来看,会发现这一节其实在回答一个很自然的问题:
- 神经元会前向传播
- 反向传播会把梯度传回来
- 优化器会更新参数
但这一切都有个前提:
- 网络一开始的信号和梯度不能太离谱
所以初始化其实是在回答:
模型训练开始前,第一步棋要怎么摆,后面整盘棋才不容易崩。
一、为什么初始化很重要?
1.1 全零初始化的问题
如果所有权重都是 0,那所有神经元计算结果完全一样,梯度也一样,永远不会分化——等于只有一个神经元。
1.2 随机初始化也有坑
- 太大:激活值饱和 → 梯度消失(Sigmoid/Tanh)或梯度爆炸
- 太小:信号逐层衰减 → 梯度也衰减 → 训练极慢
1.2.1 一个更适合新人的直觉:先别让每层“太安静”或“太激 动”
可以先把初始化想成给每层一个起跑姿势:
- 太小:像刚开始就没力气,信号一层层传着传着就没了
- 太大:像一上来就用力过猛,输出和梯度都可能失控
所以好初始化的目标非常朴素:
- 让前向信号别迅速衰减或爆炸
- 让反向梯度也还能稳定传回来
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
# 观察不同初始化下的激活值分布
torch.manual_seed(42)
def observe_activations(init_fn, title, activation=nn.Tanh()):
"""观察 10 层网络中每层的激活值分布"""
layers = []
for i in range(10):
linear = nn.Linear(256, 256, bias=False)
init_fn(linear.weight)
layers.append(linear)
layers.append(activation)
model = nn.Sequential(*layers)
# 记录每层输出
x = torch.randn(200, 256)
activations = []
for i in range(0, len(layers), 2):
x = layers[i](x) # Linear
x = layers[i+1](x) # Activation
activations.append(x.detach().numpy().flatten())
fig, axes = plt.subplots(2, 5, figsize=(15, 5))
for i, (ax, act) in enumerate(zip(axes.ravel(), activations)):
ax.hist(act, bins=50, color='steelblue', alpha=0.7)
ax.set_title(f'Layer {i+1}')
ax.set_xlim(-1.5, 1.5)
plt.suptitle(title, fontsize=13)
plt.tight_layout()
plt.show()
# 太小的初始化
observe_activations(
lambda w: nn.init.normal_(w, 0, 0.01),
'太小初始化 (std=0.01) + Tanh → 信号衰减'
)
# 太大的初始化
observe_activations(
lambda w: nn.init.normal_(w, 0, 1.0),
'太大初始化 (std=1.0) + Tanh → 饱和'
)
二、Xavier / Glorot 初始化
2.1 核心思想
让每一层的输入和输出的方差保持一致,避免信号逐层放大或衰减。
权重从 N(0, 2/(fan_in + fan_out)) 中采样
fan_in = 输入维度, fan_out = 输出维度
2.1.1 Xavier 最值得先记的,不是公式而是什么?
最值得先记的是它的目标:
- 尽量让每层输入输出的尺度不要差太多
公式只是实现这个目标的一种方式。
所以第一次学时,先稳住这个直觉,比死记分母形式更重要。
2.2 适用:Sigmoid / Tanh
observe_activations(
lambda w: nn.init.xavier_normal_(w),
'Xavier 初始化 + Tanh → 信号稳定'
)
三、He / Kaiming 初始化
3.1 核心思想
Xavier 假设激活函数是线性的。但 ReLU 会把一半神经元置为 0,所以需要更大的方差来补偿。
权重从 N(0, 2/fan_in) 中采样
3.1.1 He 初始化为什么会比 Xavier 更适合 ReLU?
因为 ReLU 会把一部分信号直接截成 0。
如果还沿用更保守的初始化,信号就更容易一路衰减。
所以 He 初始化可以先朴素理解成:
- 为了适应 ReLU 的“截断特性”,把起始方差稍微放大一点
3.2 适用:ReLU 及其变体
observe_activations(
lambda w: nn.init.kaiming_normal_(w, mode='fan_in', nonlinearity='relu'),
'He 初始化 + ReLU → 信号稳定',
activation=nn.ReLU()
)
四、选择指南
| 激活函数 | 推荐初始化 | PyTorch 函数 |
|---|---|---|
| Sigmoid / Tanh | Xavier | nn.init.xavier_normal_ |
| ReLU / Leaky ReLU | He (Kaiming) | nn.init.kaiming_normal_ |
| GELU / Swish | He | nn.init.kaiming_normal_ |
PyTorch 默认行为
# PyTorch 的 nn.Linear 默认使用 Kaiming Uniform
linear = nn.Linear(256, 128)
print(f"默认初始化范围: [{linear.weight.min():.4f}, {linear.weight.max():.4f}]")
# 手动指定初始化
nn.init.kaiming_normal_(linear.weight, mode='fan_in', nonlinearity='relu')
nn.init.zeros_(linear.bias)
PyTorch 的 nn.Linear 默认使用 Kaiming Uniform 初始化,nn.Conv2d 也是。大多数情况下你不需要手动初始化——但理解原理能帮你诊断训练异常。
4.1 新人第一次做项目时到底要不要手动初始化?
大多数时候:
- 不需要一开始就自己写初始化
- 先用 PyTorch 默认值通常就够了
更值得手动初始化的情况通常是:
- 你在做更深的网络实验
- 你怀疑训练很不稳定
- 你想系统比较不同初始化策略
所以这节最重要的不是“今天就手写很多初始化代码”,而是先知道:
- 默认值为什么通常可用
- 什么时候该怀疑初始化有问题
4.2 一个更稳的默认判断顺序
如果你刚开始做项目,可以先按这个顺序判断:
- 先用 PyTorch 默认初 始化
- 如果训练明显不稳,再看学习率和优化器
- 还不对,再去怀疑初始化和激活函数搭配
这样会比“一遇到问题就先改初始化”更稳,因为初始化虽然重要,但不一定总是第一嫌疑人。
五、初始化对训练的影响
# 对比不同初始化的训练效果
from sklearn.datasets import make_moons
X, y = make_moons(500, noise=0.2, random_state=42)
X_t = torch.FloatTensor(X)
y_t = torch.LongTensor(y)
init_methods = {
'全零': lambda w: nn.init.zeros_(w),
'N(0, 0.01)': lambda w: nn.init.normal_(w, 0, 0.01),
'N(0, 1.0)': lambda w: nn.init.normal_(w, 0, 1.0),
'Xavier': lambda w: nn.init.xavier_normal_(w),
'He (Kaiming)': lambda w: nn.init.kaiming_normal_(w),
}
plt.figure(figsize=(10, 5))
for name, init_fn in init_methods.items():
model = nn.Sequential(
nn.Linear(2, 64), nn.ReLU(),
nn.Linear(64, 64), nn.ReLU(),
nn.Linear(64, 2),
)
# 初始化
for m in model:
if isinstance(m, nn.Linear):
init_fn(m.weight)
nn.init.zeros_(m.bias)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
losses = []
for epoch in range(200):
loss = criterion(model(X_t), y_t)
optimizer.zero_grad()
loss.backward()
optimizer.step()
losses.append(loss.item())
plt.plot(losses, label=name, linewidth=2)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('不同初始化方法的训练曲线')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
5.1 如果训练一开始就不对劲,初始化是该怀疑的一项
典型信号包括:
- loss 一开始就非常大
- 很多层输出几乎全 0 或极度饱和
- 梯度很快消失或爆炸
当然,初始化不是唯一原因,但它常常是值得优先排查的一层。
小结
| 初始化 | 原理 | 适用 |
|---|---|---|
| 全零 | 所有神经元相同 | ❌ 永远不要用 |
| 小随机 | 信号衰减 | ❌ 深层网络不适合 |
| 大随机 | 梯度爆炸/饱 和 | ❌ 不适合 |
| Xavier | 保持输入输出方差 | Sigmoid / Tanh |
| He (Kaiming) | ReLU 补偿 | ReLU 系列(最常用) |
这节最该带走什么
- 初始化不是装饰,而是在决定网络一开始能不能健康传播信号
- Xavier 更偏向 Sigmoid / Tanh,He 更偏向 ReLU 系列
- 第一次做项目时先用 PyTorch 默认值完全没问题,但要知道它背后的原理
如果再压成一句话,那就是:
初始化决定的是“训练能不能好好起跑”,而不是模型最后一定能跑多远。
动手练习
练习 1:深层网络对比
创建一个 20 层的 MLP(ReLU 激活),分别用全零、Xavier、He 初始化,观察前向传播后各层激活值的分布(打印均值和标准差)。
练习 2:训练深层 MNIST
用 10 层 MLP 训练 MNIST,对比 He 初始化和默认初始化的训练速度和最终准确率。