news 2026/6/26 21:16:57

《Nano-vLLM 源码解读》第 22 篇 · 张量并行(二)代码实现

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
《Nano-vLLM 源码解读》第 22 篇 · 张量并行(二)代码实现

nano-vllm 用千行代码拆解 vLLM 核心,是读懂大模型推理最快的捷径。

1. 介绍

上一篇讲清了张量并行的数学:一个线性层只有列切、行切两种拆法,行切之后要all_reduce求和,attention 按 head 切,RMSNorm 复制不切。

本篇实现张量并行。

真实代码靠多进程跑,每个进程占用一张卡。为了单机单进程也能把 tp=2 的切分跑出来,本篇把卡数tp_size、卡号tp_rank当构造参数显式传进去。真实代码从dist.get_world_size()dist.get_rank()取。

importtorchfromtorchimportnnimporttorch.nn.functionalasFdefdivide(numerator,denominator):assertnumerator%denominator==0returnnumerator//denominator

2. 总览

把模型切到多卡,代码上就是改Linear的 weight_loader 和 forward。

L16 里的LinearBase只管一张权重表加一个加载钩子;TP 版多存了三个数,整套切分都靠它们驱动:

  • tp_size:总卡数(真实代码dist.get_world_size())。
  • tp_rank:本卡编号(真实代码dist.get_rank())。
  • tp_dim:这一层沿权重的哪一维切——列切是 0、行切是 1、不切是None

为什么存成「转置」的[out, in]Linear的前向是x @ weight.T:weight 每一行就是「一个输出单元」的全部输入权重,按输出排。于是列切(按输出维切)落在第 0 维、切的是整行;行切(按输入维切)落在第 1 维、切的是列。上一篇按数学习惯把权重写成[in, out](输出在列),同一刀到代码里就从切列转成了切行——名字没变,维度转了 90°。

classLinearBase(nn.Module):def__init__(self,input_size,output_size,tp_size,tp_rank,bias=False,tp_dim=None):super().__init__()self.tp_dim=tp_dim self.tp_size=tp_size# 真实代码:dist.get_world_size()self.tp_rank=tp_rank# 真实代码:dist.get_rank()# 权重已是本卡那一块的形状(子类把 output/input 缩好后传进来)self.weight=nn.Parameter(torch.empty(output_size,input_size))self.weight.weight_loader=self.weight_loaderifbias:self.bias=nn.Parameter(torch.empty(output_size))self.bias.weight_loader=self.weight_loaderelse:self.register_parameter("bias",None)defforward(self,x):raiseNotImplementedError

3. ColumnParallelLinear:列切

列切按输出维度切——把权重的行(输出那一维)分给各卡,每卡算输出的一段,输入完整。

__init__:把output_size除以卡数、tp_dim设成 0。权重直接建成[out/tp, in]

weight_loader把磁盘整份切出本卡那块,三步(见上图):

  1. shard_size = param.size(0):本卡要几行——就是已缩好的out/tp(图里4/2 = 2)。
  2. start = tp_rank * shard_size:本卡从第几行起。rank0 从0、rank1 从2
  3. loaded_weight.narrow(0, start, shard_size):从第start行起、取shard_size行,copy_进 param。

两张卡读的是同一份磁盘权重,只是start不同,各取互不重叠的一段行。

forward:一句F.linear搞定。输入完整,输出是[N, out/tp]的一段——输出分片、matmul 阶段零通信。完整输出散在各卡上。

classColumnParallelLinear(LinearBase):def__init__(self,input_size,output_size,tp_size,tp_rank,bias=False):super().__init__(input_size,divide(output_size,tp_size),tp_size,tp_rank,bias,tp_dim=0)defweight_loader(self,param,loaded_weight):shard_size=param.size(self.tp_dim)# out/tpstart=self.tp_rank*shard_size loaded_weight=loaded_weight.narrow(self.tp_dim,start,shard_size)param.data.copy_(loaded_weight)defforward(self,x):returnF.linear(x,self.weight,self.bias)# 磁盘整份 weight[out=4, in=3],第 r 行全填常数 r,便于辨认是哪一行W_full=torch.arange(4.).reshape(4,1).repeat(1,3)print("磁盘整份 W(行号):",W_full[:,0].tolist())# [0., 1., 2., 3.]# tp=2:每卡 weight 已缩成 [out/tp, in] = [2, 3]c0=ColumnParallelLinear(3,4,tp_size=2,tp_rank=0)c1=ColumnParallelLinear(3,4,tp_size=2,tp_rank=1)print("每卡 weight 形状:",tuple(c0.weight.shape)," tp_dim =",c0.tp_dim)c0.weight_loader(c0.weight,W_full)c1.weight_loader(c1.weight,W_full)print("rank0 拿到行:",c0.weight[:,0].tolist())# [0., 1.]print("rank1 拿到行:",c1.weight[:,0].tolist())# [2., 3.]x=torch.randn(5,3)# 输入完整print("rank0 输出分片:",tuple(c0(x).shape))# (5, 2)
磁盘整份 W(行号): [0.0, 1.0, 2.0, 3.0] 每卡 weight 形状: (2, 3) tp_dim = 0 rank0 拿到行: [0.0, 1.0] rank1 拿到行: [2.0, 3.0] rank0 输出分片: (5, 2)

4. RowParallelLinear:行切

行切按输入维度切——把权重的列(输入那一维)分给各卡,输入也跟着切开,每卡算一个部分和。比列切多两件事:forward末尾要通信,bias 只在一张卡上加。

__init__:把input_size除以卡数、tp_dim设成 1。权重建成[out, in/tp]

weight_loader:二维权重和列切同样三步,只是改切dim 1——shard_size = in/tpstart = tp_rank * shard_size(rank0 从 0、rank1 从 2)、narrow(1, start, shard_size)取本卡那in/tp列。多一个特例:bias 是一维(长度out、输出维没切),两卡各整份拷。

forwardF.linear各卡算出一个部分和,all_reduce跨卡求和才是完整输出。bias 这里只让 rank0 加——它每卡都存了完整一份,若每卡都加,all_reduce求和后会被加上tp次;只 rank0 加,求和后恰好算一次。

classRowParallelLinear(LinearBase):def__init__(self,input_size,output_size,tp_size,tp_rank,bias=False):super().__init__(divide(input_size,tp_size),output_size,tp_size,tp_rank,bias,tp_dim=1)defweight_loader(self,param,loaded_weight):ifparam.data.ndim==1:# bias:输出维没切,整份拷param.data.copy_(loaded_weight)returnshard_size=param.size(self.tp_dim)# in/tpstart=self.tp_rank*shard_size loaded_weight=loaded_weight.narrow(self.tp_dim,start,shard_size)param.data.copy_(loaded_weight)defforward(self,x):y=F.linear(x,self.weight,self.biasifself.tp_rank==0elseNone)# 真实代码此处 if self.tp_size > 1: dist.all_reduce(y)# 简单起见,手动计算 y0+y1 模拟 all_reducereturny# 磁盘整份:weight[out=2, in=4]、bias[2]W_full=torch.tensor([[1.,2,3,4],[5.,6,7,8]])b_full=torch.tensor([10.,20.])r0=RowParallelLinear(4,2,tp_size=2,tp_rank=0,bias=True)r1=RowParallelLinear(4,2,tp_size=2,tp_rank=1,bias=True)forrin(r0,r1):r.weight_loader(r.weight,W_full)# 沿 dim1 切列r.weight_loader(r.bias,b_full)# 一维:整份拷print("每卡 weight 形状:",tuple(r0.weight.shape))# (2, 2) = [out, in/tp]print("rank0 列段:",r0.weight.tolist())# 前 2 列print("rank1 列段:",r1.weight.tolist())# 后 2 列print("bias 两卡各一整份:",r0.bias.tolist(),r1.bias.tolist())# forward:输入 x 也按输入维切两段x=torch.tensor([[1.,1,1,1]])y0=r0(x[:,:2])# rank0:含 biasy1=r1(x[:,2:])# rank1:不含 biasprint("rank0 部分和(含bias):",y0.tolist())# [[13., 31.]]print("rank1 部分和(无bias):",y1.tolist())# [[7., 15.]]print("all_reduce 求和:",(y0+y1).tolist())# [[20., 46.]]print("单卡对照:",F.linear(x,W_full,b_full).tolist())# [[20., 46.]]
每卡 weight 形状: (2, 2) rank0 列段: [[1.0, 2.0], [5.0, 6.0]] rank1 列段: [[3.0, 4.0], [7.0, 8.0]] bias 两卡各一整份: [10.0, 20.0] [10.0, 20.0] rank0 部分和(含bias): [[13.0, 31.0]] rank1 部分和(无bias): [[7.0, 15.0]] all_reduce 求和: [[20.0, 46.0]] 单卡对照: [[20.0, 46.0]]

5. 合并类的双重切

qkv_projgate_up_proj是合并投影:把 q/k/v(或 gate/up)几路拼成一张大权重表、一次矩阵乘算完,省去分开调用的开销。它们都继承ColumnParallelLinear,只重写weight_loader

列切、行切每装一份权重只切一刀;合并类要切两刀,因为它需要处理「合并」和「并行」两件事。磁盘上 q/k/v 分三份存,内存里却挤进同一张表(合并);这张表又按输出维列切到各卡(并行)。二者各要一刀,彼此正交,叠起来就是双重切:

  • 第一刀,切哪种投影:磁盘上 q/k/v 分开存,每装一路得先定位它落在合并表的哪一段。合并表已按卡数缩成[各路之和 / tp, in],所以段偏移、段长也都跟着//tp。这刀单卡就有——chunk那刀此时是空操作。
  • 第二刀,切哪张卡:合并表按输出维列切到各卡,本卡只要每一路里属于自己的那部分。磁盘那一路是整份,chunk(tp)切成tp块、取第tp_rank块——本卡那一条。

落到qkv_proj__init__先把头数按卡数分:num_headsnum_kv_heads各除以tpweight_loader按 q/k/v 三段定位偏移,再chunk取本卡那一条。

段偏移切的是「q 还是 k 还是 v」,chunk切的是 head——两刀各管一个维度、互不干扰。所以 tp=2 时,rank0 拿到 q 头的前一半 + kv 头的前一半、rank1 拿后一半。

classMergedColumnParallelLinear(ColumnParallelLinear):# 合并投影(如 gate_up = gate + up):几路拼成一张列切表。# output_sizes = 各路输出宽度(如 [gate宽, up宽]),加载时按它定位每路的段。def__init__(self,input_size,output_sizes,tp_size,tp_rank,bias=False):self.output_sizes=output_sizes# 合并表总宽 = 各路之和;父类(列切)再把总宽 //tp 缩成本卡那块super().__init__(input_size,sum(output_sizes),tp_size,tp_rank,bias)# 一次只装一路。loaded_shard_id = 这路的序号(0=gate、1=up)defweight_loader(self,param,loaded_weight,loaded_shard_id):# 第一刀(合并):定位这路在本卡合并表里的段。# 偏移 = 前面各路宽度之和,段长 = 本路宽度,都 //tp(表已按卡缩小)。shard_offset=sum(self.output_sizes[:loaded_shard_id])//self.tp_size shard_size=self.output_sizes[loaded_shard_id]//self.tp_size# param_data 是 param 那一段的视图,改它就是改 paramparam_data=param.data.narrow(self.tp_dim,shard_offset,shard_size)# 第二刀(并行):磁盘这路是整份,沿输出维切 tp 块、取本卡那块loaded_weight=loaded_weight.chunk(self.tp_size,self.tp_dim)[self.tp_rank]param_data.copy_(loaded_weight)# 本卡块写进那一段classQKVParallelLinear(ColumnParallelLinear):# 合并 q/k/v 三路。GQA 下 kv 头比 q 头少、三路宽度不等,所以单列一类。def__init__(self,hidden_size,head_size,total_num_heads,total_num_kv_heads,tp_size,tp_rank,bias=False):self.head_size=head_size self.num_heads=divide(total_num_heads,tp_size)# 每卡 q 头self.num_kv_heads=divide(total_num_kv_heads,tp_size)# 每卡 kv 头# 合并表总宽 = q + k + v = (q头 + 2×kv头) × head_sizeoutput_size=(total_num_heads+2*total_num_kv_heads)*head_sizesuper().__init__(hidden_size,output_size,tp_size,tp_rank,bias)# loaded_shard_id = "q"/"k"/"v"。两刀同 Merged,只是段偏移按 q/k/v 排布算。defweight_loader(self,param,loaded_weight,loaded_shard_id):assertloaded_shard_idin["q","k","v"]# 第一刀:q、k、v 三段在合并表里依次排,算本路的段长、段偏移(行单位):# q 段 = q头×head_size,偏移 0# k 段 = kv头×head_size,偏移跳过 q# v 段 = kv头×head_size,偏移跳过 q、kifloaded_shard_id=="q":shard_size=self.num_heads*self.head_size shard_offset=0elifloaded_shard_id=="k":shard_size=self.num_kv_heads*self.head_size shard_offset=self.num_heads*self.head_sizeelse:# vshard_size=self.num_kv_heads*self.head_size shard_offset=(self.num_heads+self.num_kv_heads)*self.head_size param_data=param.data.narrow(self.tp_dim,shard_offset,shard_size)# 第二刀(并行):磁盘这路是整份,沿输出维切 tp 块、取本卡那块loaded_weight=loaded_weight.chunk(self.tp_size,self.tp_dim)[self.tp_rank]param_data.copy_(loaded_weight)# 本卡块写进那一段# 合成磁盘整份:q 4 头、k/v 各 2 头,head_size=2,hidden=3# 第 h 个头的两行都填一个头号常数(q:0~3、k:10~11、v:20~21)便于辨认defby_head(num_heads,base,head_size=2,hidden=3):rows=[[float(base+h)]*hiddenforhinrange(num_heads)for_inrange(head_size)]returntorch.tensor(rows)q_full,k_full,v_full=by_head(4,0),by_head(2,10),by_head(2,20)qkv0=QKVParallelLinear(3,2,total_num_heads=4,total_num_kv_heads=2,tp_size=2,tp_rank=0)qkv1=QKVParallelLinear(3,2,total_num_heads=4,total_num_kv_heads=2,tp_size=2,tp_rank=1)forqkvin(qkv0,qkv1):qkv.weight_loader(qkv.weight,q_full,"q")qkv.weight_loader(qkv.weight,k_full,"k")qkv.weight_loader(qkv.weight,v_full,"v")# 每卡合并 param[8,3]:q 段[0:4]、k 段[4:6]、v 段[6:8](读第 0 列的头号)a,b=qkv0.weight[:,0],qkv1.weight[:,0]print("rank0 q头",a[0:4].tolist()," k头",a[4:6].tolist()," v头",a[6:8].tolist())print("rank1 q头",b[0:4].tolist()," k头",b[4:6].tolist()," v头",b[6:8].tolist())# gate/up:两段等大,双重切同理(intermediate=4、hidden=3)m0=MergedColumnParallelLinear(3,[4,4],tp_size=2,tp_rank=0)m0.weight_loader(m0.weight,torch.full((4,3),1.),0)# gate → 段[0:2]m0.weight_loader(m0.weight,torch.full((4,3),2.),1)# up → 段[2:4]print("gate/up rank0 gate段",m0.weight[:2,0].tolist()," up段",m0.weight[2:4,0].tolist())
rank0 q头 [0.0, 0.0, 1.0, 1.0] k头 [10.0, 10.0] v头 [20.0, 20.0] rank1 q头 [2.0, 2.0, 3.0, 3.0] k头 [11.0, 11.0] v头 [21.0, 21.0] gate/up rank0 gate段 [1.0, 1.0] up段 [2.0, 2.0]

6. 小结

Linear家族的 TP 改造,就三件事:__init__按卡数把权重缩小(列切缩输出维、行切缩输入维,第一刀切在构造时);weight_loader沿tp_dim从磁盘整份切出本卡那一片(列切切行、行切切列);行切层forward末尾all_reduce求和,bias 只让 rank0 加。

合并类把两刀叠起来:段偏移定位「切哪种投影」、chunk取「切哪张卡」。qkv_proj的这两刀正交在 q/k/v 与 head 两个维度上。

Linear之外,还有按词表切的embedlm_head和按 head 切的 KV cache,下一篇接着看。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/6/26 21:16:52

SpringBoot + Redis 实现北极星日淘商品热点缓存优化(实战含源码)

摘要:北极星日淘平台日均承载数万件日系小众商品检索、下单、合箱业务,原生数据库直查模式下,热门限定商品、绝版孤品的高频访问会导致MySQL查询压力激增,接口响应延迟飙升。本文基于北极星日淘真实业务场景,采用Sprin…

作者头像 李华
网站建设 2026/6/26 21:14:11

大模型推理服务部署:从模型加载到弹性扩缩容的工程实践

大模型推理服务部署:从模型加载到弹性扩缩容的工程实践一、大模型推理部署的三大工程瓶颈:显存、延迟与冷启动 将大语言模型从实验环境推向生产服务,需要跨越三道工程瓶颈。第一道是显存瓶颈:一个 7B 参数模型在 FP16 精度下需要约…

作者头像 李华
网站建设 2026/6/26 21:11:56

2026年上半年软考信息系统项目管理师论文真题及答案解析(第二批)

请结合你所叙述的应用AI技术的信息系统项目,围绕以下要点论述你对 AI 时代的安全管理与风险管理的认识 (1)结合我国近期AI安全相关的政策法规角度,给出 AI 安全管理与风险管理需要重点关注的内容; (2)请根据你所描述的项目,按照…

作者头像 李华
网站建设 2026/6/26 21:11:56

ArchivePasswordTestTool:3步快速找回加密压缩包密码的完整指南

ArchivePasswordTestTool:3步快速找回加密压缩包密码的完整指南 【免费下载链接】ArchivePasswordTestTool 利用7zip测试压缩包的功能 对加密压缩包进行自动化测试密码 项目地址: https://gitcode.com/gh_mirrors/ar/ArchivePasswordTestTool 你是否曾经因为…

作者头像 李华
网站建设 2026/6/26 21:07:58

电影售票系统-springboot + vue

本项目为前几天收费帮学妹做的一个项目,在工作环境中基本使用不到,但是很多学校把这个当作编程入门的项目来做,故分享出本项目供初学者参考。 一、项目描述 基于springboot vue的电影售票系统 前台登录网址: http://localhost:8082/ 后台登…

作者头像 李华
网站建设 2026/6/26 21:07:22

CUA:让大模型操控电脑的开放框架——从原理到 Python 实战

不是只有 Claude Code 和 Hermes 才能操控桌面。CUA 提供了一套独立的开放基础设施,pip install 就能让你的 Python 程序拥有"看屏幕、点鼠标、敲键盘"的能力。 前言 2024 年 10 月 Anthropic 发布 Computer Use 的时候,"AI 操控电脑&quo…

作者头像 李华