nano-vllm 用千行代码拆解 vLLM 核心是读懂大模型推理最快的捷径。1. 介绍上一篇讲清了张量并行的数学一个线性层只有列切、行切两种拆法行切之后要all_reduce求和attention 按 head 切RMSNorm 复制不切。本篇实现张量并行。真实代码靠多进程跑每个进程占用一张卡。为了单机单进程也能把 tp2 的切分跑出来本篇把卡数tp_size、卡号tp_rank当构造参数显式传进去。真实代码从dist.get_world_size()、dist.get_rank()取。importtorchfromtorchimportnnimporttorch.nn.functionalasFdefdivide(numerator,denominator):assertnumerator%denominator0returnnumerator//denominator2. 总览把模型切到多卡代码上就是改Linear的 weight_loader 和 forward。L16 里的LinearBase只管一张权重表加一个加载钩子TP 版多存了三个数整套切分都靠它们驱动tp_size总卡数真实代码dist.get_world_size()。tp_rank本卡编号真实代码dist.get_rank()。tp_dim这一层沿权重的哪一维切——列切是 0、行切是 1、不切是None。为什么存成「转置」的[out, in]Linear的前向是x weight.Tweight 每一行就是「一个输出单元」的全部输入权重按输出排。于是列切按输出维切落在第 0 维、切的是整行行切按输入维切落在第 1 维、切的是列。上一篇按数学习惯把权重写成[in, out]输出在列同一刀到代码里就从切列转成了切行——名字没变维度转了 90°。classLinearBase(nn.Module):def__init__(self,input_size,output_size,tp_size,tp_rank,biasFalse,tp_dimNone):super().__init__()self.tp_dimtp_dim self.tp_sizetp_size# 真实代码dist.get_world_size()self.tp_ranktp_rank# 真实代码dist.get_rank()# 权重已是本卡那一块的形状子类把 output/input 缩好后传进来self.weightnn.Parameter(torch.empty(output_size,input_size))self.weight.weight_loaderself.weight_loaderifbias:self.biasnn.Parameter(torch.empty(output_size))self.bias.weight_loaderself.weight_loaderelse:self.register_parameter(bias,None)defforward(self,x):raiseNotImplementedError3. ColumnParallelLinear列切列切按输出维度切——把权重的行输出那一维分给各卡每卡算输出的一段输入完整。__init__把output_size除以卡数、tp_dim设成 0。权重直接建成[out/tp, in]。weight_loader把磁盘整份切出本卡那块三步见上图shard_size param.size(0)本卡要几行——就是已缩好的out/tp图里4/2 2。start tp_rank * shard_size本卡从第几行起。rank0 从0、rank1 从2。loaded_weight.narrow(0, start, shard_size)从第start行起、取shard_size行copy_进 param。两张卡读的是同一份磁盘权重只是start不同各取互不重叠的一段行。forward一句F.linear搞定。输入完整输出是[N, out/tp]的一段——输出分片、matmul 阶段零通信。完整输出散在各卡上。classColumnParallelLinear(LinearBase):def__init__(self,input_size,output_size,tp_size,tp_rank,biasFalse):super().__init__(input_size,divide(output_size,tp_size),tp_size,tp_rank,bias,tp_dim0)defweight_loader(self,param,loaded_weight):shard_sizeparam.size(self.tp_dim)# out/tpstartself.tp_rank*shard_size loaded_weightloaded_weight.narrow(self.tp_dim,start,shard_size)param.data.copy_(loaded_weight)defforward(self,x):returnF.linear(x,self.weight,self.bias)# 磁盘整份 weight[out4, in3]第 r 行全填常数 r便于辨认是哪一行W_fulltorch.arange(4.).reshape(4,1).repeat(1,3)print(磁盘整份 W行号:,W_full[:,0].tolist())# [0., 1., 2., 3.]# tp2每卡 weight 已缩成 [out/tp, in] [2, 3]c0ColumnParallelLinear(3,4,tp_size2,tp_rank0)c1ColumnParallelLinear(3,4,tp_size2,tp_rank1)print(每卡 weight 形状:,tuple(c0.weight.shape), tp_dim ,c0.tp_dim)c0.weight_loader(c0.weight,W_full)c1.weight_loader(c1.weight,W_full)print(rank0 拿到行:,c0.weight[:,0].tolist())# [0., 1.]print(rank1 拿到行:,c1.weight[:,0].tolist())# [2., 3.]xtorch.randn(5,3)# 输入完整print(rank0 输出分片:,tuple(c0(x).shape))# (5, 2)磁盘整份 W行号: [0.0, 1.0, 2.0, 3.0] 每卡 weight 形状: (2, 3) tp_dim 0 rank0 拿到行: [0.0, 1.0] rank1 拿到行: [2.0, 3.0] rank0 输出分片: (5, 2)4. RowParallelLinear行切行切按输入维度切——把权重的列输入那一维分给各卡输入也跟着切开每卡算一个部分和。比列切多两件事forward末尾要通信bias 只在一张卡上加。__init__把input_size除以卡数、tp_dim设成 1。权重建成[out, in/tp]。weight_loader二维权重和列切同样三步只是改切dim 1——shard_size in/tp、start tp_rank * shard_sizerank0 从 0、rank1 从 2、narrow(1, start, shard_size)取本卡那in/tp列。多一个特例bias 是一维长度out、输出维没切两卡各整份拷。forwardF.linear各卡算出一个部分和all_reduce跨卡求和才是完整输出。bias 这里只让 rank0 加——它每卡都存了完整一份若每卡都加all_reduce求和后会被加上tp次只 rank0 加求和后恰好算一次。classRowParallelLinear(LinearBase):def__init__(self,input_size,output_size,tp_size,tp_rank,biasFalse):super().__init__(divide(input_size,tp_size),output_size,tp_size,tp_rank,bias,tp_dim1)defweight_loader(self,param,loaded_weight):ifparam.data.ndim1:# bias输出维没切整份拷param.data.copy_(loaded_weight)returnshard_sizeparam.size(self.tp_dim)# in/tpstartself.tp_rank*shard_size loaded_weightloaded_weight.narrow(self.tp_dim,start,shard_size)param.data.copy_(loaded_weight)defforward(self,x):yF.linear(x,self.weight,self.biasifself.tp_rank0elseNone)# 真实代码此处 if self.tp_size 1: dist.all_reduce(y)# 简单起见手动计算 y0y1 模拟 all_reducereturny# 磁盘整份weight[out2, in4]、bias[2]W_fulltorch.tensor([[1.,2,3,4],[5.,6,7,8]])b_fulltorch.tensor([10.,20.])r0RowParallelLinear(4,2,tp_size2,tp_rank0,biasTrue)r1RowParallelLinear(4,2,tp_size2,tp_rank1,biasTrue)forrin(r0,r1):r.weight_loader(r.weight,W_full)# 沿 dim1 切列r.weight_loader(r.bias,b_full)# 一维整份拷print(每卡 weight 形状:,tuple(r0.weight.shape))# (2, 2) [out, in/tp]print(rank0 列段:,r0.weight.tolist())# 前 2 列print(rank1 列段:,r1.weight.tolist())# 后 2 列print(bias 两卡各一整份:,r0.bias.tolist(),r1.bias.tolist())# forward输入 x 也按输入维切两段xtorch.tensor([[1.,1,1,1]])y0r0(x[:,:2])# rank0含 biasy1r1(x[:,2:])# rank1不含 biasprint(rank0 部分和(含bias):,y0.tolist())# [[13., 31.]]print(rank1 部分和(无bias):,y1.tolist())# [[7., 15.]]print(all_reduce 求和:,(y0y1).tolist())# [[20., 46.]]print(单卡对照:,F.linear(x,W_full,b_full).tolist())# [[20., 46.]]每卡 weight 形状: (2, 2) rank0 列段: [[1.0, 2.0], [5.0, 6.0]] rank1 列段: [[3.0, 4.0], [7.0, 8.0]] bias 两卡各一整份: [10.0, 20.0] [10.0, 20.0] rank0 部分和(含bias): [[13.0, 31.0]] rank1 部分和(无bias): [[7.0, 15.0]] all_reduce 求和: [[20.0, 46.0]] 单卡对照: [[20.0, 46.0]]5. 合并类的双重切qkv_proj、gate_up_proj是合并投影把 q/k/v或 gate/up几路拼成一张大权重表、一次矩阵乘算完省去分开调用的开销。它们都继承ColumnParallelLinear只重写weight_loader。列切、行切每装一份权重只切一刀合并类要切两刀因为它需要处理「合并」和「并行」两件事。磁盘上 q/k/v 分三份存内存里却挤进同一张表合并这张表又按输出维列切到各卡并行。二者各要一刀彼此正交叠起来就是双重切第一刀切哪种投影磁盘上 q/k/v 分开存每装一路得先定位它落在合并表的哪一段。合并表已按卡数缩成[各路之和 / tp, in]所以段偏移、段长也都跟着//tp。这刀单卡就有——chunk那刀此时是空操作。第二刀切哪张卡合并表按输出维列切到各卡本卡只要每一路里属于自己的那部分。磁盘那一路是整份chunk(tp)切成tp块、取第tp_rank块——本卡那一条。落到qkv_proj。__init__先把头数按卡数分num_heads、num_kv_heads各除以tp。weight_loader按 q/k/v 三段定位偏移再chunk取本卡那一条。段偏移切的是「q 还是 k 还是 v」chunk切的是 head——两刀各管一个维度、互不干扰。所以 tp2 时rank0 拿到 q 头的前一半 kv 头的前一半、rank1 拿后一半。classMergedColumnParallelLinear(ColumnParallelLinear):# 合并投影如 gate_up gate up几路拼成一张列切表。# output_sizes 各路输出宽度如 [gate宽, up宽]加载时按它定位每路的段。def__init__(self,input_size,output_sizes,tp_size,tp_rank,biasFalse):self.output_sizesoutput_sizes# 合并表总宽 各路之和父类列切再把总宽 //tp 缩成本卡那块super().__init__(input_size,sum(output_sizes),tp_size,tp_rank,bias)# 一次只装一路。loaded_shard_id 这路的序号0gate、1updefweight_loader(self,param,loaded_weight,loaded_shard_id):# 第一刀合并定位这路在本卡合并表里的段。# 偏移 前面各路宽度之和段长 本路宽度都 //tp表已按卡缩小。shard_offsetsum(self.output_sizes[:loaded_shard_id])//self.tp_size shard_sizeself.output_sizes[loaded_shard_id]//self.tp_size# param_data 是 param 那一段的视图改它就是改 paramparam_dataparam.data.narrow(self.tp_dim,shard_offset,shard_size)# 第二刀并行磁盘这路是整份沿输出维切 tp 块、取本卡那块loaded_weightloaded_weight.chunk(self.tp_size,self.tp_dim)[self.tp_rank]param_data.copy_(loaded_weight)# 本卡块写进那一段classQKVParallelLinear(ColumnParallelLinear):# 合并 q/k/v 三路。GQA 下 kv 头比 q 头少、三路宽度不等所以单列一类。def__init__(self,hidden_size,head_size,total_num_heads,total_num_kv_heads,tp_size,tp_rank,biasFalse):self.head_sizehead_size self.num_headsdivide(total_num_heads,tp_size)# 每卡 q 头self.num_kv_headsdivide(total_num_kv_heads,tp_size)# 每卡 kv 头# 合并表总宽 q k v (q头 2×kv头) × head_sizeoutput_size(total_num_heads2*total_num_kv_heads)*head_sizesuper().__init__(hidden_size,output_size,tp_size,tp_rank,bias)# loaded_shard_id q/k/v。两刀同 Merged只是段偏移按 q/k/v 排布算。defweight_loader(self,param,loaded_weight,loaded_shard_id):assertloaded_shard_idin[q,k,v]# 第一刀q、k、v 三段在合并表里依次排算本路的段长、段偏移行单位# q 段 q头×head_size偏移 0# k 段 kv头×head_size偏移跳过 q# v 段 kv头×head_size偏移跳过 q、kifloaded_shard_idq:shard_sizeself.num_heads*self.head_size shard_offset0elifloaded_shard_idk:shard_sizeself.num_kv_heads*self.head_size shard_offsetself.num_heads*self.head_sizeelse:# vshard_sizeself.num_kv_heads*self.head_size shard_offset(self.num_headsself.num_kv_heads)*self.head_size param_dataparam.data.narrow(self.tp_dim,shard_offset,shard_size)# 第二刀并行磁盘这路是整份沿输出维切 tp 块、取本卡那块loaded_weightloaded_weight.chunk(self.tp_size,self.tp_dim)[self.tp_rank]param_data.copy_(loaded_weight)# 本卡块写进那一段# 合成磁盘整份q 4 头、k/v 各 2 头head_size2hidden3# 第 h 个头的两行都填一个头号常数q:0~3、k:10~11、v:20~21便于辨认defby_head(num_heads,base,head_size2,hidden3):rows[[float(baseh)]*hiddenforhinrange(num_heads)for_inrange(head_size)]returntorch.tensor(rows)q_full,k_full,v_fullby_head(4,0),by_head(2,10),by_head(2,20)qkv0QKVParallelLinear(3,2,total_num_heads4,total_num_kv_heads2,tp_size2,tp_rank0)qkv1QKVParallelLinear(3,2,total_num_heads4,total_num_kv_heads2,tp_size2,tp_rank1)forqkvin(qkv0,qkv1):qkv.weight_loader(qkv.weight,q_full,q)qkv.weight_loader(qkv.weight,k_full,k)qkv.weight_loader(qkv.weight,v_full,v)# 每卡合并 param[8,3]q 段[0:4]、k 段[4:6]、v 段[6:8]读第 0 列的头号a,bqkv0.weight[:,0],qkv1.weight[:,0]print(rank0 q头,a[0:4].tolist(), k头,a[4:6].tolist(), v头,a[6:8].tolist())print(rank1 q头,b[0:4].tolist(), k头,b[4:6].tolist(), v头,b[6:8].tolist())# gate/up两段等大双重切同理intermediate4、hidden3m0MergedColumnParallelLinear(3,[4,4],tp_size2,tp_rank0)m0.weight_loader(m0.weight,torch.full((4,3),1.),0)# gate → 段[0:2]m0.weight_loader(m0.weight,torch.full((4,3),2.),1)# up → 段[2:4]print(gate/up rank0 gate段,m0.weight[:2,0].tolist(), up段,m0.weight[2:4,0].tolist())rank0 q头 [0.0, 0.0, 1.0, 1.0] k头 [10.0, 10.0] v头 [20.0, 20.0] rank1 q头 [2.0, 2.0, 3.0, 3.0] k头 [11.0, 11.0] v头 [21.0, 21.0] gate/up rank0 gate段 [1.0, 1.0] up段 [2.0, 2.0]6. 小结Linear家族的 TP 改造就三件事__init__按卡数把权重缩小列切缩输出维、行切缩输入维第一刀切在构造时weight_loader沿tp_dim从磁盘整份切出本卡那一片列切切行、行切切列行切层forward末尾all_reduce求和bias 只让 rank0 加。合并类把两刀叠起来段偏移定位「切哪种投影」、chunk取「切哪张卡」。qkv_proj的这两刀正交在 q/k/v 与 head 两个维度上。Linear之外还有按词表切的embed、lm_head和按 head 切的 KV cache下一篇接着看。