一、小白先懂:自注意力是怎么“打分”的?
自注意力的核心,是给每个词(Token)计算和其他词的匹配度分数,步骤就3步:
- 生成3个向量:给每个词生成 Query(查询向量,好比“我要找什么”)、Key(键向量,好比“我有什么”)、Value(值向量,好比“我要输出的内容”)。
- 算匹配分数:用 Query 和 Key 的点积,计算两个词的相似度,公式就是:
分数=Q×KT\text{分数} = Q \times K^T分数=Q×KT
(KTK^TKT是 Key 的转置,简单理解就是为了让矩阵能相乘) - 分数归一化:用 softmax 把分数变成 0~1 之间的权重,权重越高,说明两个词关联越强,最后用权重乘以 Value 得到结果。
二、问题来了:直接算分数会“出bug”
如果不除以dk\sqrt{d_k}dk,会遇到两个大麻烦,我们用比喻讲明白:
1. 第一个麻烦:dkd_kdk越大,分数值越“离谱”
dkd_kdk是 Query 和 Key 的维度(比如 BERT 里dk=64d_k=64dk=64)。
我们可以把 Q 和 K 的每个维度值,想象成独立的随机数(均值0,方差1)。
两个向量的点积,就是把对应维度的值相乘再相加:
Q⋅K=q1k1+q2k2+...+qdkkdkQ \cdot K = q_1k_1 + q_2k_2 + ... + q_{d_k}k_{d_k}Q⋅K=q1k1+q2k2+...+qdkkdk
根据统计学知识:独立随机变量相加,方差会累加。
- 单个qikiq_ik_iqiki的方差是1×1=11 \times 1 = 11×1=1
- dkd_kdk个加起来,总方差就是dkd_kdk,标准差就是dk\sqrt{d_k}dk
举个例子:
- 当dk=64d_k=64dk=64时,点积的标准差是 8 → 分数值可能会跑到±几十
- 当dk=1024d_k=1024dk=1024时,标准差是 32 → 分数值可能会跑到±几百
2. 第二个麻烦:分数太大会让 softmax “罢工”
softmax 函数的特点是:输入值越大,输出越极端。
比如有两个分数:10和5,经过 softmax 后:
softmax([10,5])=[e10e10+e5,e5e10+e5]≈[0.993,0.007]softmax([10,5]) = [\frac{e^{10}}{e^{10}+e^5}, \frac{e^5}{e^{10}+e^5}] ≈ [0.993, 0.007]softmax([10,5])=[e10+e5e10,e10+e5e5]≈[0.993,0.007]
几乎所有权重都集中在最大的分数上,其他分数的权重接近 0。
更要命的是:极端的输入会导致 softmax 的梯度消失。
softmax 的梯度和输入值的大小成反比,输入值越大,梯度越接近 0 → 模型训练时学不到东西,相当于“罢工”了。
三、除以dk\sqrt{d_k}dk的“魔法”:给分数“降温”
这个操作的核心目的,就是把点积的方差归一化到 1,让分数值的范围变得合理。
因为点积的方差是dkd_kdk,标准差是dk\sqrt{d_k}dk,所以除以dk\sqrt{d_k}dk后:
归一化分数=Q×KTdk\text{归一化分数} = \frac{Q \times K^T}{\sqrt{d_k}}归一化分数=dkQ×KT
- 新的方差 =dkdk=1\frac{d_k}{d_k} = 1dkdk=1
- 新的标准差 = 1 → 分数值会稳定在±几的范围
还是上面的例子,分数10和5除以 8(dk=64d_k=64dk=64)后变成1.25和0.625,再经过 softmax:
softmax([1.25,0.625])≈[0.65,0.35]softmax([1.25, 0.625]) ≈ [0.65, 0.35]softmax([1.25,0.625])≈[0.65,0.35]
权重分布更合理,既突出了重要的词,又不会完全忽略其他词,同时梯度也能正常传递,模型就能顺利学习了。
四、总结
小白一句话总结
除以dk\sqrt{d_k}dk是为了防止点积分数太大,避免 softmax 输出极端值导致梯度消失,让模型能正常训练。
技术一句话总结
自注意力中Q⋅KTQ \cdot K^TQ⋅KT的方差与dkd_kdk成正比,除以dk\sqrt{d_k}dk可将方差归一化到 1,保证 softmax 输出的权重分布合理且梯度稳定。
完整自注意力公式
Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q,K,V) = \text{softmax}\left( \frac{QK^T}{\sqrt{d_k}} \right) VAttention(Q,K,V)=softmax(dkQKT)V