本文基于一个最小 PyTorch 示例,手写实现 FlashAttention
的核心计算流程,并详细解释其数值稳定性和分块计算原理。
1. 标准 Attention 回顾
标准 Attention 的计算公式:
Attention(Q,K,V)=softmax(QKT)V Attention(Q,K,V) = softmax(QK^T)VAttention(Q,K,V)=softmax(QKT)V
importtorch query=torch.randn(1,12,10)key=torch.randn(1,12,10)value=torch.randn(1,12,10)logits=torch.einsum('bqd,bkd->bqk',query,key)probs=torch.nn.functional.softmax(logits,dim=-1)softmax_output=torch.einsum('bqk,bkd->bqd',probs,value)2. FlashAttention 核心思想
FlashAttention 的核心目标:
避免显式存储整个 attention matrix(QK^T)
关键手段:
- 分块计算(block-wise)
- 在线 Softmax(online softmax)
3. 数值稳定 Softmax
softmax(xj)=exj−m∑kexk−m,m=max(x) softmax(x_j) = \frac{e^{x_j - m}}{\sum_k e^{x_k - m}}, \quad m = max(x)softmax(xj)=∑kexk−mexj−m,m=max(x)
4. 核心递推
mi=max(mi−1,mij) m_i = max(m_{i-1}, m_{ij})mi=max(mi−1,mij)
li=li−1emi−1−mi+∑exij−mi l_i = l_{i-1} e^{m_{i-1} - m_i} + \sum e^{x_{ij} - m_i}li=li−1emi−1−mi+∑exij−mi
oi=oi−1emi−1−mi+∑(exij−miVj) o_i = o_{i-1} e^{m_{i-1} - m_i} + \sum (e^{x_{ij} - m_i} V_j)oi=oi−1emi−1−mi+∑(exij−miVj)
🔍 关键细节深入理解
很多人在理解这里时容易卡住:为什么需要对历史的oi−1o_{i-1}oi−1做
rescale?
我们一步一步拆解:
1️⃣oi−1o_{i-1}oi−1并不是"最终正确的值"
在第i−1i-1i−1次循环时:
- 我们用的是局部最大值mi−1m_{i-1}mi−1
- 所以 softmax 实际是:
exi−1∑exi−1=exi−1−mi−1∑exi−1−mi−1 \frac{e^{x_{i-1}}}{\sum e^{x_{i-1}}} = \frac{e^{x_{i-1} - m_{i-1}}}{\sum e^{x_{i-1} - m_{i-1}}}∑exi−1exi−1=∑exi−1−mi−1exi−1−mi−1
👉 注意:这里的归一化是基于局部 block 的尺度
2️⃣ 当进入第iii个 block 时发生了什么?
我们得到了新的最大值:
mi=max(mi−1,mij) m_i = max(m_{i-1}, m_{ij})mi=max(mi−1,mij)
👉 这个mim_imi更接近全局最大值
3️⃣ 问题的本质
此时出现一个不一致:
项目 使用的 max
oi−1o_{i-1}oi−1mi−1m_{i-1}mi−1
当前 blockmim_imi
👉 如果直接相加,会导致:
不同尺度的指数项被混合(数值错误)
4️⃣ 解决方法:统一尺度(rescale)
我们需要把旧的oi−1o_{i-1}oi−1从:
ex−mi−1 e^{x - m_{i-1}}ex−mi−1
转换到:
ex−mi e^{x - m_i}ex−mi
变换方式:
ex−mi−1=ex−mi⋅emi−mi−1 e^{x - m_{i-1}} = e^{x - m_i} \cdot e^{m_i - m_{i-1}}ex−mi−1=ex−mi⋅emi−mi−1
👉 因此:
oi−1→oi−1⋅emi−1−mi o_{i-1} \rightarrow o_{i-1} \cdot e^{m_{i-1} - m_i}oi−1→oi−1⋅emi−1−mi
5️⃣ 对应代码
o_i = o_i_1 * torch.exp(m_i_1 - m_i)[…, None] + torch.einsum(‘bqk,bkd->bqd’, exp_term, v_i)
含义是:
- 第一项:旧结果 rescale 到新尺度
- 第二项:当前 block 的贡献
6️⃣ 一个直观理解
可以把整个过程理解为:
我们在不断"修正历史",让所有累积值都统一到"当前最稳定的坐标系(最大值)"下
随着循环进行:
- mim_imi会逐步逼近全局最大值
- 所有历史贡献都会被重新缩放到这个统一尺度
7️⃣ 最终结果
当所有 block 处理完:
- mim_imi= 全局最大值
- oi/lio_i / l_ioi/li= 完整 softmax 结果
5. PyTorch实现
flash_softmax_outputs=[]q_chunks=4q_chunk_size=query.shape[1]//q_chunks k_chunks=3k_chunk_size=key.shape[1]//k_chunksforiinrange(q_chunks):q_i=query[:,i*q_chunk_size:(i+1)*q_chunk_size]m_i_1=torch.full((q_i.shape[0],q_i.shape[1]),-float('inf'))l_i_1=torch.zeros_like(m_i_1)o_i_1=torch.zeros((q_i.shape[0],q_i.shape[1],value.shape[-1]))forjinrange(k_chunks):k_i=key[:,j*k_chunk_size:(j+1)*k_chunk_size]# (B, K_block, D)v_i=value[:,j*k_chunk_size:(j+1)*k_chunk_size]# (B, K_block, Dv)logits_i=torch.einsum('nqd,nkd->nqk',q_i,k_i)# (B, Q_block, K_block)# ---- 更新 m ----m_ij=torch.max(logits_i,dim=-1)[0]# (B, Q_block)m_i=torch.maximum(m_i_1,m_ij)# 计算Softmax分子e^(x_i - m_i)exp_term=torch.exp(logits_i-m_i[...,None])# (B, Q_block, K_block)# 更新Softmax分母# rescale * 旧的softmax分母 + 新的softmax分母l_i=l_i_1*torch.exp(m_i_1-m_i)+exp_term.sum(dim=-1)# ---- 更新 O(关键!)----# rescale * 旧的logit * v + 新的logit * vo_i=o_i_1*torch.exp(m_i_1-m_i)[...,None]+torch.einsum('nqk,nkd->nqd',exp_term,v_i)# ---- 状态更新 ----m_i_1=m_i l_i_1=l_i o_i_1=o_i# ---- 最后除以Softmax分母----output=o_i/l_i[...,None]flash_softmax_outputs.append(output)flash_softmax_outputs=torch.cat(flash_softmax_outputs,dim=1)6. 正确性验证
torch.allclose(softmax_output,flash_softmax_outputs)7. 总结
FlashAttention 本质:
- 分块计算
- 在线 softmax
- 动态重标定(rescale)
复杂度从 O(N^2) 降到 O(N)