news 2026/4/30 14:53:29

从零手写 FlashAttention(PyTorch实现 + 原理推导)

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
从零手写 FlashAttention(PyTorch实现 + 原理推导)

本文基于一个最小 PyTorch 示例,手写实现 FlashAttention
的核心计算流程,并详细解释其数值稳定性和分块计算原理。


1. 标准 Attention 回顾

标准 Attention 的计算公式:

Attention(Q,K,V)=softmax(QKT)V Attention(Q,K,V) = softmax(QK^T)VAttention(Q,K,V)=softmax(QKT)V

importtorch query=torch.randn(1,12,10)key=torch.randn(1,12,10)value=torch.randn(1,12,10)logits=torch.einsum('bqd,bkd->bqk',query,key)probs=torch.nn.functional.softmax(logits,dim=-1)softmax_output=torch.einsum('bqk,bkd->bqd',probs,value)

2. FlashAttention 核心思想

FlashAttention 的核心目标:

避免显式存储整个 attention matrix(QK^T)

关键手段:

  • 分块计算(block-wise)
  • 在线 Softmax(online softmax)

3. 数值稳定 Softmax

softmax(xj)=exj−m∑kexk−m,m=max(x) softmax(x_j) = \frac{e^{x_j - m}}{\sum_k e^{x_k - m}}, \quad m = max(x)softmax(xj)=kexkmexjm,m=max(x)


4. 核心递推

mi=max(mi−1,mij) m_i = max(m_{i-1}, m_{ij})mi=max(mi1,mij)

li=li−1emi−1−mi+∑exij−mi l_i = l_{i-1} e^{m_{i-1} - m_i} + \sum e^{x_{ij} - m_i}li=li1emi1mi+exijmi

oi=oi−1emi−1−mi+∑(exij−miVj) o_i = o_{i-1} e^{m_{i-1} - m_i} + \sum (e^{x_{ij} - m_i} V_j)oi=oi1emi1mi+(exijmiVj)


🔍 关键细节深入理解

很多人在理解这里时容易卡住:为什么需要对历史的oi−1o_{i-1}oi1
rescale?

我们一步一步拆解:

1️⃣oi−1o_{i-1}oi1并不是"最终正确的值"

在第i−1i-1i1次循环时:

  • 我们用的是局部最大值mi−1m_{i-1}mi1
  • 所以 softmax 实际是:

exi−1∑exi−1=exi−1−mi−1∑exi−1−mi−1 \frac{e^{x_{i-1}}}{\sum e^{x_{i-1}}} = \frac{e^{x_{i-1} - m_{i-1}}}{\sum e^{x_{i-1} - m_{i-1}}}exi1exi1=exi1mi1exi1mi1

👉 注意:这里的归一化是基于局部 block 的尺度


2️⃣ 当进入第iii个 block 时发生了什么?

我们得到了新的最大值:

mi=max(mi−1,mij) m_i = max(m_{i-1}, m_{ij})mi=max(mi1,mij)

👉 这个mim_imi更接近全局最大值


3️⃣ 问题的本质

此时出现一个不一致:

项目 使用的 max


oi−1o_{i-1}oi1mi−1m_{i-1}mi1
当前 blockmim_imi

👉 如果直接相加,会导致:

不同尺度的指数项被混合(数值错误)


4️⃣ 解决方法:统一尺度(rescale)

我们需要把旧的oi−1o_{i-1}oi1从:

ex−mi−1 e^{x - m_{i-1}}exmi1

转换到:

ex−mi e^{x - m_i}exmi

变换方式:

ex−mi−1=ex−mi⋅emi−mi−1 e^{x - m_{i-1}} = e^{x - m_i} \cdot e^{m_i - m_{i-1}}exmi1=exmiemimi1

👉 因此:

oi−1→oi−1⋅emi−1−mi o_{i-1} \rightarrow o_{i-1} \cdot e^{m_{i-1} - m_i}oi1oi1emi1mi


5️⃣ 对应代码

o_i = o_i_1 * torch.exp(m_i_1 - m_i)[…, None] + torch.einsum(‘bqk,bkd->bqd’, exp_term, v_i)

含义是:

  • 第一项:旧结果 rescale 到新尺度
  • 第二项:当前 block 的贡献

6️⃣ 一个直观理解

可以把整个过程理解为:

我们在不断"修正历史",让所有累积值都统一到"当前最稳定的坐标系(最大值)"下

随着循环进行:

  • mim_imi会逐步逼近全局最大值
  • 所有历史贡献都会被重新缩放到这个统一尺度

7️⃣ 最终结果

当所有 block 处理完:

  • mim_imi= 全局最大值
  • oi/lio_i / l_ioi/li= 完整 softmax 结果

5. PyTorch实现

flash_softmax_outputs=[]q_chunks=4q_chunk_size=query.shape[1]//q_chunks k_chunks=3k_chunk_size=key.shape[1]//k_chunksforiinrange(q_chunks):q_i=query[:,i*q_chunk_size:(i+1)*q_chunk_size]m_i_1=torch.full((q_i.shape[0],q_i.shape[1]),-float('inf'))l_i_1=torch.zeros_like(m_i_1)o_i_1=torch.zeros((q_i.shape[0],q_i.shape[1],value.shape[-1]))forjinrange(k_chunks):k_i=key[:,j*k_chunk_size:(j+1)*k_chunk_size]# (B, K_block, D)v_i=value[:,j*k_chunk_size:(j+1)*k_chunk_size]# (B, K_block, Dv)logits_i=torch.einsum('nqd,nkd->nqk',q_i,k_i)# (B, Q_block, K_block)# ---- 更新 m ----m_ij=torch.max(logits_i,dim=-1)[0]# (B, Q_block)m_i=torch.maximum(m_i_1,m_ij)# 计算Softmax分子e^(x_i - m_i)exp_term=torch.exp(logits_i-m_i[...,None])# (B, Q_block, K_block)# 更新Softmax分母# rescale * 旧的softmax分母 + 新的softmax分母l_i=l_i_1*torch.exp(m_i_1-m_i)+exp_term.sum(dim=-1)# ---- 更新 O(关键!)----# rescale * 旧的logit * v + 新的logit * vo_i=o_i_1*torch.exp(m_i_1-m_i)[...,None]+torch.einsum('nqk,nkd->nqd',exp_term,v_i)# ---- 状态更新 ----m_i_1=m_i l_i_1=l_i o_i_1=o_i# ---- 最后除以Softmax分母----output=o_i/l_i[...,None]flash_softmax_outputs.append(output)flash_softmax_outputs=torch.cat(flash_softmax_outputs,dim=1)

6. 正确性验证

torch.allclose(softmax_output,flash_softmax_outputs)

7. 总结

FlashAttention 本质:

  • 分块计算
  • 在线 softmax
  • 动态重标定(rescale)

复杂度从 O(N^2) 降到 O(N)

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/30 14:49:59

Linux进程资源泄漏自动清理:agent-reaper守护进程的设计与实践

1. 项目概述:一个守护进程的“清道夫”在开发和运维的日常里,我们经常会遇到一种让人头疼的情况:某个后台进程(Agent)因为各种原因卡死、僵死或者异常退出,但它留下的“烂摊子”却还在系统里。这些“烂摊子…

作者头像 李华
网站建设 2026/4/30 14:49:55

如何快速配置键盘映射:终极游戏操作优化指南

如何快速配置键盘映射:终极游戏操作优化指南 【免费下载链接】socd Key remapper for epic gamers 项目地址: https://gitcode.com/gh_mirrors/so/socd Hitboxer是一款专为游戏玩家设计的键盘重映射工具,它能够智能解决游戏中同时按下相反方向键时…

作者头像 李华
网站建设 2026/4/30 14:48:26

SQL必会的常用函数(三)文本函数

SQL文本函数详解一、基础查询函数1. LENGTH / LEN - 获取字符串长度-- MySQL SELECT LENGTH(Hello World); -- 返回 11-- SQL Server SELECT LEN(Hello World); -- 返回 112. CONCAT - 字符串拼接-- 标准语法(所有数据库通用) SELECT CONCAT(Hello, …

作者头像 李华
网站建设 2026/4/30 14:47:50

用particles.js创造动态粒子效果的3种实用场景指南

用particles.js创造动态粒子效果的3种实用场景指南 【免费下载链接】particles.js A lightweight JavaScript library for creating particles 项目地址: https://gitcode.com/gh_mirrors/pa/particles.js 你是否曾惊叹于那些科技感十足的网站背景,无数光点如…

作者头像 李华
网站建设 2026/4/30 14:44:43

告别臃肿模拟器:如何在Windows上轻松安装APK文件

告别臃肿模拟器:如何在Windows上轻松安装APK文件 【免费下载链接】APK-Installer An Android Application Installer for Windows 项目地址: https://gitcode.com/GitHub_Trending/ap/APK-Installer 你是否曾经想要在Windows电脑上运行安卓应用,却…

作者头像 李华