自动求导
学习目标
- 理解梯度到底是什么
- 掌握
requires_grad=True的作用 - 明白
loss.backward()做了什么 - 理解梯度累计、清零和
torch.no_grad()
一、为什么要有自动求导?
训练模型的核心目标只有一句话:
让模型参数朝“损失更小”的方向移动。
问题是:怎么知道该往哪个方向移动?
答案就是梯度(gradient)。
你可以把梯度想成“山坡的坡度”:
- 梯度大,说明这里很陡
- 梯度方向告诉你损失增长最快的方向
- 我们想要让损失下降,所以要沿着负梯度方向更新参数
如果每次都手工推导梯度,会非常痛苦。
PyTorch 的 autograd 就像一个自动记账员:
- 你只管写“怎么算出 loss”
- 它会帮你把梯度链路记录下来
- 你调用
backward(),它就自动把梯度算出来
二、一个最小例子
import torch
# 一个需要学习的参数
w = torch.tensor(2.0, requires_grad=True)
# 定义一个简单函数:loss = (w * 3 - 10)^2
loss = (w * 3 - 10) ** 2
print("loss:", loss.item())
# 自动求导
loss.backward()
print("w 的梯度:", w.grad.item())
这里发生了什么?
PyTorch 记录了这条计算链:
w -> w*3 -> w*3-10 -> (w*3-10)^2
当你执行:
loss.backward()
它会沿着这条链,按链式法则把梯度一路传回来,最后得到:
w.grad
这就是“当前 w 再往前走一点点,loss 会怎么变”的信息。
三、从梯度到参数更新
有了梯度,就能做一次最简单的梯度下降:
import torch
w = torch.tensor(2.0, requires_grad=True)
lr = 0.1
for step in range(5):
loss = (w * 3 - 10) ** 2
loss.backward()
with torch.no_grad():
w -= lr * w.grad
print(f"step={step}, w={w.item():.4f}, loss={loss.item():.4f}")
w.grad.zero_()
这一段每步在干什么?
| 代码 | 作用 |
|---|---|
loss = ... | 计算当前损失 |
loss.backward() | 求当前损失对 w 的梯度 |
w -= lr * w.grad | 用梯度更新参数 |
w.grad.zero_() | 把旧梯度清掉,准备下轮计算 |
四、为什么要清零梯度?
这是 PyTorch 初学者最容易踩坑的点之一。
PyTorch 默认会累计梯度,而不是自动覆盖。
看下面的例子:
import torch
x = torch.tensor(3.0, requires_grad=True)
y1 = x ** 2
y1.backward()
print("第一次 backward 后的梯度:", x.grad.item())
y2 = 2 * x
y2.backward()
print("第二次 backward 后的梯度:", x.grad.item())
你会发现第二次梯度不是新的结果,而是“第一次 + 第二次”的和。
这就是为什么训练循环里通常都会写:
optimizer.zero_grad()
或者 :
tensor.grad.zero_()
五、requires_grad=True 到底控制了什么?
只有被标记为 requires_grad=True 的张量,PyTorch 才会为它追踪梯度。
import torch
a = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(3.0, requires_grad=False)
y = a * b + 1
y.backward()
print("a.grad:", a.grad.item())
print("b.grad:", b.grad)
输出里你会看到:
a.grad有值b.grad是None
这很符合直觉:
如果某个值不是“需要学习的参数”,就没必要对它求梯度。
六、torch.no_grad() 是干什么的?
训练时我们要记录梯度。
但推理、评估、参数手动更新时,我们往往不需要梯度。
这时就可以用:
with torch.no_grad():
...
它的作用是:
- 关闭梯度追踪
- 节省内存
- 加快推理
import torch
w = torch.tensor(5.0, requires_grad=True)
with torch.no_grad():
y = w * 2
print("y.requires_grad:", y.requires_grad)
七、把它放回“模型训练”的语境里
真实训练时,我们通常不是只更新一个数字 w,而是更新一整组参数。
比如一个线性模型:
y = wx + b
这里的 w 和 b 都是参数,都要学习。
训练时发生的事其实还是一样:
- 用当前参数做预测
- 计算预测和真实值之间的损失
- 自动求出每个参数的梯度
- 用优化器按梯度方向更新参数
所以自动求导不是“额外功能”,而是深度学习训练的发动机。