Skip to main content

6.5.2 注意力机制

本节定位

RNN 是一步步传信息。注意力让一个 token 直接看其他 token,并判断哪些更重要。这是 Transformer 背后的核心转变。

学习目标

  • 解释为什么注意力有助于长距离依赖。
  • 通过检索类比理解 Query、Key、Value。
  • 手算 scaled dot-product attention。
  • 使用 causal mask 防止偷看未来。
  • 读懂 PyTorch 中 nn.MultiheadAttention 的 shape。

先看 Q/K/V

Self-Attention QKV 结构图

注意力是一种加权检索:

Q 提问 -> K 匹配 -> softmax 变成权重 -> V 提供内容 -> 加权求和

检索类比:

注意力 QKV 图书馆检索类比图

角色直觉在注意力中
Query Q我现在想找什么?当前 token 的问题
Key K每个条目匹配什么?用来打分的索引
Value V应该返回什么内容?实际被混合的信息

一句话:

Q 和 K 打分,然后用得到的权重混合 V。

为什么需要注意力

旧式序列模型中,远处信息要么沿很多个循环步骤传递,要么被压进一个固定向量。注意力缩短了路径:

当前 token -> 直接给所有 token 打分 -> 选择有用上下文

它带来三个实践优势:

  • 直接建立长距离连接;
  • 比一步步 RNN 更容易并行训练;
  • 得到可观察的 token-to-token 混合权重矩阵。

实验 1:手算注意力

为了教学,令 Q = K = V = X

import numpy as np

X = np.array(
[
[1.0, 0.0],
[0.0, 1.0],
[1.0, 1.0],
]
)

Q = K = V = X

scores = Q @ K.T
scaled_scores = scores / np.sqrt(K.shape[1])


def softmax(row):
e = np.exp(row - row.max())
return e / e.sum()


weights = np.apply_along_axis(softmax, 1, scaled_scores)
output = weights @ V

print("attention_lab")
print("scores")
print(np.round(scores, 3))
print("weights")
print(np.round(weights, 3))
print("output")
print(np.round(output, 3))

预期输出:

attention_lab
scores
[[1. 0. 1.]
[0. 1. 1.]
[1. 1. 2.]]
weights
[[0.401 0.198 0.401]
[0.198 0.401 0.401]
[0.248 0.248 0.503]]
output
[[0.802 0.599]
[0.599 0.802]
[0.752 0.752]]

读这三步:

步骤代码含义
打分Q @ K.T每个 token 和每个 token 有多匹配
归一化softmax(...)把分数变成和为 1 的权重
混合weights @ V按权重组合 token 内容

为什么要除以 sqrt(d_k)

Transformer 里的公式是:

Attention(Q, K, V) = softmax(QK^T / sqrt(d_k))V

当向量维度很大时,点积也容易变大。大分数会让 softmax 过于尖锐,某个 token 几乎拿走全部权重。除以 sqrt(d_k) 可以给分数降温,让训练更稳定。

Self-Attention

Self-attention 指 QKV 都来自同一个序列。每个 token 都能看同一个序列里的每个 token。

例子:

"Alex gave Sam the notebook because he trusted him."

要理解 “he” 和 “him”,当前 token 需要看其他 token。Self-attention 给了这种直接路径。

实验 2:Causal Mask

生成任务不能看未来 token。causal mask 只让下三角可见。

Causal Mask 防止偷看未来图

import numpy as np

scores = np.array(
[
[2.0, 1.0, 0.5],
[1.2, 2.1, 0.7],
[0.8, 1.3, 2.2],
]
)

mask = np.tril(np.ones_like(scores))
masked_scores = np.where(mask == 1, scores, -1e9)


def softmax(row):
e = np.exp(row - row.max())
return e / e.sum()


weights = np.apply_along_axis(softmax, 1, masked_scores)

print("mask_lab")
print(np.round(weights, 3))

预期输出:

mask_lab
[[1. 0. 0. ]
[0.289 0.711 0. ]
[0.149 0.246 0.605]]

读法:

  • 位置 1 只能看自己;
  • 位置 2 能看位置 1 和 2;
  • 位置 3 能看位置 1、2、3。

未来答案不可见。

Multi-Head Attention

一个 attention head 可能只学到一种关系。multi-head attention 让模型并行查看多个关系空间。

不同 head 可能关注:

  • 附近位置模式;
  • 主语 / 宾语关系;
  • 重复词;
  • 长距离引用。

多个 head 的结果会拼接,再投影回一个表示。

实验 3:PyTorch MultiheadAttention

import torch
from torch import nn

torch.manual_seed(42)

attention = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True)
tokens = torch.randn(1, 4, 8)
output, weights = attention(tokens, tokens, tokens)

print("mha_lab")
print("tokens:", tuple(tokens.shape))
print("output:", tuple(output.shape))
print("weights:", tuple(weights.shape))
print("row0_sum:", round(float(weights[0, 0].sum().detach()), 4))

预期输出:

mha_lab
tokens: (1, 4, 8)
output: (1, 4, 8)
weights: (1, 4, 4)
row0_sum: 1.0

shape 读法:

TensorShape含义
tokens[1, 4, 8]batch 1,4 个 token,embedding size 8
output[1, 4, 8]每个 token 得到新的上下文表示
weights[1, 4, 4]每个 query token 对 4 个 key token 分配权重

Attention 权重不是完整解释

Attention 权重很有用,但不要过度解读。

它能说明:

在这一层 / 这个 head 中,这个 query 从那些 key 位置混合了更多 value

它不能自动证明:

模型最终决策就是因为那个 token

把 attention 权重当作调试和观察工具,而不是完整因果解释。

常见错误

错误修复
把 Q/K/V 当神秘变量读成 问题 / 索引 / 内容
忘记 shape 含义追踪 [batch, seq_len, embed_dim] 和 attention [batch, query, key]
生成任务不用 mask用 causal mask 隐藏未来 token
在错误维度上 softmax应该在 key 位置上归一化
把 attention 当推理魔法记住:打分 -> softmax -> 加权求和

练习

  1. 把实验 1 的第三个 token 改成 [2.0, 0.0],weights 怎么变?
  2. 把 mask 实验扩展成 4 x 4 矩阵。
  3. 把实验 3 的 num_heads2 改成 1,哪些 shape 不变?
  4. 解释为什么 attention 比普通 RNN 更容易建模远距离 token 交互。
  5. 描述一个 attention 权重有帮助但不是完整解释的场景。

小结

  • Attention 让 token 直接选择相关上下文。
  • Q/K/V 把打分和内容检索分开。
  • Scaled dot-product attention 是打分、softmax、加权求和。
  • Causal mask 防止生成任务偷看未来。
  • Multi-head attention 从多个子空间查看关系。