nano-vllm 用千行代码拆解 vLLM 核心,是读懂大模型推理最快的捷径。
1. 介绍
上一篇讲清了张量并行的数学:一个线性层只有列切、行切两种拆法,行切之后要all_reduce求和,attention 按 head 切,RMSNorm 复制不切。
本篇实现张量并行。
真实代码靠多进程跑,每个进程占用一张卡。为了单机单进程也能把 tp=2 的切分跑出来,本篇把卡数tp_size、卡号tp_rank当构造参数显式传进去。真实代码从dist.get_world_size()、dist.get_rank()取。
importtorchfromtorchimportnnimporttorch.nn.functionalasFdefdivide(numerator,denominator):assertnumerator%denominator==0returnnumerator//denominator2. 总览
把模型切到多卡,代码上就是改Linear的 weight_loader 和 forward。
L16 里的LinearBase只管一张权重表加一个加载钩子;TP 版多存了三个数,整套切分都靠它们驱动:
tp_size:总卡数(真实代码dist.get_world_size())。tp_rank:本卡编号(真实代码dist.get_rank())。tp_dim:这一层沿权重的哪一维切——列切是 0、行切是 1、不切是None。
为什么存成「转置」的[out, in]?Linear的前向是x @ weight.T:weight 每一行就是「一个输出单元」的全部输入权重,按输出排。于是列切(按输出维切)落在第 0 维、切的是整行;行切(按输入维切)落在第 1 维、切的是列。上一篇按数学习惯把权重写成[in, out](输出在列),同一刀到代码里就从切列转成了切行——名字没变,维度转了 90°。
classLinearBase(nn.Module):def__init__(self,input_size,output_size,tp_size,tp_rank,bias=False,tp_dim=None):super().__init__()self.tp_dim=tp_dim self.tp_size=tp_size# 真实代码:dist.get_world_size()self.tp_rank=tp_rank# 真实代码:dist.get_rank()# 权重已是本卡那一块的形状(子类把 output/input 缩好后传进来)self.weight=nn.Parameter(torch.empty(output_size,input_size))self.weight.weight_loader=self.weight_loaderifbias:self.bias=nn.Parameter(torch.empty(output_size))self.bias.weight_loader=self.weight_loaderelse:self.register_parameter("bias",None)defforward(self,x):raiseNotImplementedError3. ColumnParallelLinear:列切
列切按输出维度切——把权重的行(输出那一维)分给各卡,每卡算输出的一段,输入完整。
__init__:把output_size除以卡数、tp_dim设成 0。权重直接建成[out/tp, in]。
weight_loader把磁盘整份切出本卡那块,三步(见上图):
shard_size = param.size(0):本卡要几行——就是已缩好的out/tp(图里4/2 = 2)。start = tp_rank * shard_size:本卡从第几行起。rank0 从0、rank1 从2。loaded_weight.narrow(0, start, shard_size):从第start行起、取shard_size行,copy_进 param。
两张卡读的是同一份磁盘权重,只是start不同,各取互不重叠的一段行。
forward:一句F.linear搞定。输入完整,输出是[N, out/tp]的一段——输出分片、matmul 阶段零通信。完整输出散在各卡上。
classColumnParallelLinear(LinearBase):def__init__(self,input_size,output_size,tp_size,tp_rank,bias=False):super().__init__(input_size,divide(output_size,tp_size),tp_size,tp_rank,bias,tp_dim=0)defweight_loader(self,param,loaded_weight):shard_size=param.size(self.tp_dim)# out/tpstart=self.tp_rank*shard_size loaded_weight=loaded_weight.narrow(self.tp_dim,start,shard_size)param.data.copy_(loaded_weight)defforward(self,x):returnF.linear(x,self.weight,self.bias)# 磁盘整份 weight[out=4, in=3],第 r 行全填常数 r,便于辨认是哪一行W_full=torch.arange(4.).reshape(4,1).repeat(1,3)print("磁盘整份 W(行号):",W_full[:,0].tolist())# [0., 1., 2., 3.]# tp=2:每卡 weight 已缩成 [out/tp, in] = [2, 3]c0=ColumnParallelLinear(3,4,tp_size=2,tp_rank=0)c1=ColumnParallelLinear(3,4,tp_size=2,tp_rank=1)print("每卡 weight 形状:",tuple(c0.weight.shape)," tp_dim =",c0.tp_dim)c0.weight_loader(c0.weight,W_full)c1.weight_loader(c1.weight,W_full)print("rank0 拿到行:",c0.weight[:,0].tolist())# [0., 1.]print("rank1 拿到行:",c1.weight[:,0].tolist())# [2., 3.]x=torch.randn(5,3)# 输入完整print("rank0 输出分片:",tuple(c0(x).shape))# (5, 2)磁盘整份 W(行号): [0.0, 1.0, 2.0, 3.0] 每卡 weight 形状: (2, 3) tp_dim = 0 rank0 拿到行: [0.0, 1.0] rank1 拿到行: [2.0, 3.0] rank0 输出分片: (5, 2)4. RowParallelLinear:行切
行切按输入维度切——把权重的列(输入那一维)分给各卡,输入也跟着切开,每卡算一个部分和。比列切多两件事:forward末尾要通信,bias 只在一张卡上加。
__init__:把input_size除以卡数、tp_dim设成 1。权重建成[out, in/tp]。
weight_loader:二维权重和列切同样三步,只是改切dim 1——shard_size = in/tp、start = tp_rank * shard_size(rank0 从 0、rank1 从 2)、narrow(1, start, shard_size)取本卡那in/tp列。多一个特例:bias 是一维(长度out、输出维没切),两卡各整份拷。
forward:F.linear各卡算出一个部分和,all_reduce跨卡求和才是完整输出。bias 这里只让 rank0 加——它每卡都存了完整一份,若每卡都加,all_reduce求和后会被加上tp次;只 rank0 加,求和后恰好算一次。
classRowParallelLinear(LinearBase):def__init__(self,input_size,output_size,tp_size,tp_rank,bias=False):super().__init__(divide(input_size,tp_size),output_size,tp_size,tp_rank,bias,tp_dim=1)defweight_loader(self,param,loaded_weight):ifparam.data.ndim==1:# bias:输出维没切,整份拷param.data.copy_(loaded_weight)returnshard_size=param.size(self.tp_dim)# in/tpstart=self.tp_rank*shard_size loaded_weight=loaded_weight.narrow(self.tp_dim,start,shard_size)param.data.copy_(loaded_weight)defforward(self,x):y=F.linear(x,self.weight,self.biasifself.tp_rank==0elseNone)# 真实代码此处 if self.tp_size > 1: dist.all_reduce(y)# 简单起见,手动计算 y0+y1 模拟 all_reducereturny# 磁盘整份:weight[out=2, in=4]、bias[2]W_full=torch.tensor([[1.,2,3,4],[5.,6,7,8]])b_full=torch.tensor([10.,20.])r0=RowParallelLinear(4,2,tp_size=2,tp_rank=0,bias=True)r1=RowParallelLinear(4,2,tp_size=2,tp_rank=1,bias=True)forrin(r0,r1):r.weight_loader(r.weight,W_full)# 沿 dim1 切列r.weight_loader(r.bias,b_full)# 一维:整份拷print("每卡 weight 形状:",tuple(r0.weight.shape))# (2, 2) = [out, in/tp]print("rank0 列段:",r0.weight.tolist())# 前 2 列print("rank1 列段:",r1.weight.tolist())# 后 2 列print("bias 两卡各一整份:",r0.bias.tolist(),r1.bias.tolist())# forward:输入 x 也按输入维切两段x=torch.tensor([[1.,1,1,1]])y0=r0(x[:,:2])# rank0:含 biasy1=r1(x[:,2:])# rank1:不含 biasprint("rank0 部分和(含bias):",y0.tolist())# [[13., 31.]]print("rank1 部分和(无bias):",y1.tolist())# [[7., 15.]]print("all_reduce 求和:",(y0+y1).tolist())# [[20., 46.]]print("单卡对照:",F.linear(x,W_full,b_full).tolist())# [[20., 46.]]每卡 weight 形状: (2, 2) rank0 列段: [[1.0, 2.0], [5.0, 6.0]] rank1 列段: [[3.0, 4.0], [7.0, 8.0]] bias 两卡各一整份: [10.0, 20.0] [10.0, 20.0] rank0 部分和(含bias): [[13.0, 31.0]] rank1 部分和(无bias): [[7.0, 15.0]] all_reduce 求和: [[20.0, 46.0]] 单卡对照: [[20.0, 46.0]]5. 合并类的双重切
qkv_proj、gate_up_proj是合并投影:把 q/k/v(或 gate/up)几路拼成一张大权重表、一次矩阵乘算完,省去分开调用的开销。它们都继承ColumnParallelLinear,只重写weight_loader。
列切、行切每装一份权重只切一刀;合并类要切两刀,因为它需要处理「合并」和「并行」两件事。磁盘上 q/k/v 分三份存,内存里却挤进同一张表(合并);这张表又按输出维列切到各卡(并行)。二者各要一刀,彼此正交,叠起来就是双重切:
- 第一刀,切哪种投影:磁盘上 q/k/v 分开存,每装一路得先定位它落在合并表的哪一段。合并表已按卡数缩成
[各路之和 / tp, in],所以段偏移、段长也都跟着//tp。这刀单卡就有——chunk那刀此时是空操作。 - 第二刀,切哪张卡:合并表按输出维列切到各卡,本卡只要每一路里属于自己的那部分。磁盘那一路是整份,
chunk(tp)切成tp块、取第tp_rank块——本卡那一条。
落到qkv_proj。__init__先把头数按卡数分:num_heads、num_kv_heads各除以tp。weight_loader按 q/k/v 三段定位偏移,再chunk取本卡那一条。
段偏移切的是「q 还是 k 还是 v」,chunk切的是 head——两刀各管一个维度、互不干扰。所以 tp=2 时,rank0 拿到 q 头的前一半 + kv 头的前一半、rank1 拿后一半。
classMergedColumnParallelLinear(ColumnParallelLinear):# 合并投影(如 gate_up = gate + up):几路拼成一张列切表。# output_sizes = 各路输出宽度(如 [gate宽, up宽]),加载时按它定位每路的段。def__init__(self,input_size,output_sizes,tp_size,tp_rank,bias=False):self.output_sizes=output_sizes# 合并表总宽 = 各路之和;父类(列切)再把总宽 //tp 缩成本卡那块super().__init__(input_size,sum(output_sizes),tp_size,tp_rank,bias)# 一次只装一路。loaded_shard_id = 这路的序号(0=gate、1=up)defweight_loader(self,param,loaded_weight,loaded_shard_id):# 第一刀(合并):定位这路在本卡合并表里的段。# 偏移 = 前面各路宽度之和,段长 = 本路宽度,都 //tp(表已按卡缩小)。shard_offset=sum(self.output_sizes[:loaded_shard_id])//self.tp_size shard_size=self.output_sizes[loaded_shard_id]//self.tp_size# param_data 是 param 那一段的视图,改它就是改 paramparam_data=param.data.narrow(self.tp_dim,shard_offset,shard_size)# 第二刀(并行):磁盘这路是整份,沿输出维切 tp 块、取本卡那块loaded_weight=loaded_weight.chunk(self.tp_size,self.tp_dim)[self.tp_rank]param_data.copy_(loaded_weight)# 本卡块写进那一段classQKVParallelLinear(ColumnParallelLinear):# 合并 q/k/v 三路。GQA 下 kv 头比 q 头少、三路宽度不等,所以单列一类。def__init__(self,hidden_size,head_size,total_num_heads,total_num_kv_heads,tp_size,tp_rank,bias=False):self.head_size=head_size self.num_heads=divide(total_num_heads,tp_size)# 每卡 q 头self.num_kv_heads=divide(total_num_kv_heads,tp_size)# 每卡 kv 头# 合并表总宽 = q + k + v = (q头 + 2×kv头) × head_sizeoutput_size=(total_num_heads+2*total_num_kv_heads)*head_sizesuper().__init__(hidden_size,output_size,tp_size,tp_rank,bias)# loaded_shard_id = "q"/"k"/"v"。两刀同 Merged,只是段偏移按 q/k/v 排布算。defweight_loader(self,param,loaded_weight,loaded_shard_id):assertloaded_shard_idin["q","k","v"]# 第一刀:q、k、v 三段在合并表里依次排,算本路的段长、段偏移(行单位):# q 段 = q头×head_size,偏移 0# k 段 = kv头×head_size,偏移跳过 q# v 段 = kv头×head_size,偏移跳过 q、kifloaded_shard_id=="q":shard_size=self.num_heads*self.head_size shard_offset=0elifloaded_shard_id=="k":shard_size=self.num_kv_heads*self.head_size shard_offset=self.num_heads*self.head_sizeelse:# vshard_size=self.num_kv_heads*self.head_size shard_offset=(self.num_heads+self.num_kv_heads)*self.head_size param_data=param.data.narrow(self.tp_dim,shard_offset,shard_size)# 第二刀(并行):磁盘这路是整份,沿输出维切 tp 块、取本卡那块loaded_weight=loaded_weight.chunk(self.tp_size,self.tp_dim)[self.tp_rank]param_data.copy_(loaded_weight)# 本卡块写进那一段# 合成磁盘整份:q 4 头、k/v 各 2 头,head_size=2,hidden=3# 第 h 个头的两行都填一个头号常数(q:0~3、k:10~11、v:20~21)便于辨认defby_head(num_heads,base,head_size=2,hidden=3):rows=[[float(base+h)]*hiddenforhinrange(num_heads)for_inrange(head_size)]returntorch.tensor(rows)q_full,k_full,v_full=by_head(4,0),by_head(2,10),by_head(2,20)qkv0=QKVParallelLinear(3,2,total_num_heads=4,total_num_kv_heads=2,tp_size=2,tp_rank=0)qkv1=QKVParallelLinear(3,2,total_num_heads=4,total_num_kv_heads=2,tp_size=2,tp_rank=1)forqkvin(qkv0,qkv1):qkv.weight_loader(qkv.weight,q_full,"q")qkv.weight_loader(qkv.weight,k_full,"k")qkv.weight_loader(qkv.weight,v_full,"v")# 每卡合并 param[8,3]:q 段[0:4]、k 段[4:6]、v 段[6:8](读第 0 列的头号)a,b=qkv0.weight[:,0],qkv1.weight[:,0]print("rank0 q头",a[0:4].tolist()," k头",a[4:6].tolist()," v头",a[6:8].tolist())print("rank1 q头",b[0:4].tolist()," k头",b[4:6].tolist()," v头",b[6:8].tolist())# gate/up:两段等大,双重切同理(intermediate=4、hidden=3)m0=MergedColumnParallelLinear(3,[4,4],tp_size=2,tp_rank=0)m0.weight_loader(m0.weight,torch.full((4,3),1.),0)# gate → 段[0:2]m0.weight_loader(m0.weight,torch.full((4,3),2.),1)# up → 段[2:4]print("gate/up rank0 gate段",m0.weight[:2,0].tolist()," up段",m0.weight[2:4,0].tolist())rank0 q头 [0.0, 0.0, 1.0, 1.0] k头 [10.0, 10.0] v头 [20.0, 20.0] rank1 q头 [2.0, 2.0, 3.0, 3.0] k头 [11.0, 11.0] v头 [21.0, 21.0] gate/up rank0 gate段 [1.0, 1.0] up段 [2.0, 2.0]6. 小结
Linear家族的 TP 改造,就三件事:__init__按卡数把权重缩小(列切缩输出维、行切缩输入维,第一刀切在构造时);weight_loader沿tp_dim从磁盘整份切出本卡那一片(列切切行、行切切列);行切层forward末尾all_reduce求和,bias 只让 rank0 加。
合并类把两刀叠起来:段偏移定位「切哪种投影」、chunk取「切哪张卡」。qkv_proj的这两刀正交在 q/k/v 与 head 两个维度上。
Linear之外,还有按词表切的embed、lm_head和按 head 切的 KV cache,下一篇接着看。