Vision-Language-Action:LMDrive大语言模型(LLM)核心推理单元
LMDrive 大语言模型LLM核心推理单元1. 整体架构概览LMDrive 使用LLaVA-7B作为核心语言模型基于 Meta LLaMA 架构。LLM 在系统中扮演特征提取器角色而非直接生成文本。数据流文本指令 Q-Former特征 → LLM输入嵌入 → LlamaDecoderLayer × 32 → 隐藏状态 → Waypoint预测器 终点分类器2. 核心算法组件2.1 LlamaRMSNorm均方根归一化算法原理与传统 LayerNorm 不同RMSNorm 不减去均值仅对标准差进行归一化公式y x / sqrt(E[x²] ε) * weight优点计算更快、训练更稳定避免均值偏移问题代码实现modeling_llama.py:95-109classLlamaRMSNorm(nn.Module):def__init__(self,hidden_size,eps1e-6):super().__init__()self.weightnn.Parameter(torch.ones(hidden_size))self.variance_epsilonepsdefforward(self,hidden_states):input_dtypehidden_states.dtype hidden_stateshidden_states.to(torch.float32)variancehidden_states.pow(2).mean(-1,keepdimTrue)hidden_stateshidden_states*torch.rsqrt(varianceself.variance_epsilon)returnself.weight*hidden_states.to(input_dtype)举例说明输入[1.0, 2.0, 3.0]hidden_dim3方差计算(149)/3 14/3 ≈ 4.67归一化[1/2.16, 2/2.16, 3/2.16] ≈ [0.46, 0.93, 1.39]乘以可学习权重[0.46*w1, 0.93*w2, 1.39*w3]2.2 LlamaRotaryEmbeddingRoPE 旋转位置编码算法原理通过旋转矩阵对 Query/Key 进行位置编码实现相对位置信息公式q q × cos(θ) rotate_half(q) × sin(θ)支持动态 NTK 缩放扩展序列长度上限代码实现modeling_llama.py:115-148classLlamaRotaryEmbedding(nn.Module):def__init__(self,dim,max_position_embeddings2048,base10000):super().__init__()self.dimdim inv_freq1.0/(base**(torch.arange(0,dim,2).float()/dim))self.register_buffer(inv_freq,inv_freq)defforward(self,x,seq_lenNone):ttorch.arange(seq_len,dtypeself.inv_freq.dtype)freqstorch.einsum(i,j-ij,t,self.inv_freq)embtorch.cat((freqs,freqs),dim-1)returnemb.cos(),emb.sin()defrotate_half(x):x1x[...,:x.shape[-1]//2]x2x[...,x.shape[-1]//2:]returntorch.cat((-x2,x1),dim-1)举例说明假设 head_dim4位置 pos2inv_freq [1/(10000^(0/4)), 1/(10000^(2/4))] [1, 0.01]freqs [2×1, 2×0.01] [2, 0.02]cos [cos(2), cos(0.02)] ≈ [-0.416, 1.0]sin [sin(2), sin(0.02)] ≈ [0.909, 0.02]对 Q/K 向量[q0, q1, q2, q3]q0 q0 × cos(2) - q2 × sin(2)q1 q1 × cos(0.02) - q3 × sin(0.02)2.3 LlamaAttention因果多头自注意力算法原理标准多头注意力机制配合因果掩码防止未来信息泄露支持 KV 缓存加速推理past_key_value可选 GQAGrouped Query Attention通过repeat_kv实现代码实现modeling_llama.py:258-412classLlamaAttention(nn.Module):def__init__(self,config):super().__init__()self.num_headsconfig.num_attention_heads self.head_dimconfig.hidden_size//self.num_heads self.q_projnn.Linear(config.hidden_size,self.num_heads*self.head_dim)self.k_projnn.Linear(config.hidden_size,self.num_key_value_heads*self.head_dim)self.v_projnn.Linear(config.hidden_size,self.num_key_value_heads*self.head_dim)self.o_projnn.Linear(self.num_heads*self.head_dim,config.hidden_size)self._init_rope()defforward(self,hidden_states,attention_maskNone,position_idsNone,past_key_valueNone,use_cacheFalse):bsz,q_len,_hidden_states.size()query_statesself.q_proj(hidden_states)# [B, L, n_heads * head_dim]key_statesself.k_proj(hidden_states)# [B, L, n_kv_heads * head_dim]value_statesself.v_proj(hidden_states)# [B, L, n_kv_heads * head_dim]query_statesquery_states.view(bsz,q_len,self.num_heads,self.head_dim).transpose(1,2)key_stateskey_states.view(bsz,q_len,self.num_key_value_heads,self.head_dim).transpose(1,2)value_statesvalue_states.view(bsz,q_len,self.num_key_value_heads,self.head_dim).transpose(1,2)# 应用 RoPEcos,sinself.rotary_emb(value_states,seq_lenkv_seq_len)query_states,key_statesapply_rotary_pos_emb(query_states,key_states,cos,sin,position_ids)# KV缓存拼接ifpast_key_valueisnotNone:key_statestorch.cat([past_key_value[0],key_states],dim2)value_statestorch.cat([past_key_value[1],value_states],dim2)# GQA: 扩展 KV 到所有头key_statesrepeat_kv(key_states,self.num_key_value_groups)value_statesrepeat_kv(value_states,self.num_key_value_groups)# 注意力计算attn_weightstorch.matmul(query_states,key_states.transpose(2,3))/math.sqrt(self.head_dim)ifattention_maskisnotNone:attn_weightsattn_weightsattention_mask attn_weightsnn.functional.softmax(attn_weights,dim-1)attn_outputtorch.matmul(attn_weights,value_states)attn_outputattn_output.transpose(1,2).reshape(bsz,q_len,self.hidden_size)attn_outputself.o_proj(attn_output)returnattn_output,attn_weights,(key_states,value_states)ifuse_cacheelseNone关键机制机制作用因果掩码防止当前 token 关注未来 tokenRoPE注入相对位置信息KV缓存推理时复用历史 KV减少重复计算GQA通过共享 KV 头减少内存占用2.4 LlamaMLPSwiGLU 前馈网络算法原理使用 SwiGLUSwish-Gated Linear Unit激活函数结构down_proj(SiLU(gate_proj(x)) × up_proj(x))相比 ReLUSwiGLU 提供更平滑的梯度流动代码实现modeling_llama.py:212-243classLlamaMLP(nn.Module):def__init__(self,config):super().__init__()self.gate_projnn.Linear(config.hidden_size,config.intermediate_size,biasFalse)self.up_projnn.Linear(config.hidden_size,config.intermediate_size,biasFalse)self.down_projnn.Linear(config.intermediate_size,config.hidden_size,biasFalse)self.act_fnACT2FN[config.hidden_act]# siludefforward(self,x):returnself.down_proj(self.act_fn(self.gate_proj(x))*self.up_proj(x))SwiGLU 公式y W3 × (SiLU(W1 × x) ⊙ W2 × x)其中⊙表示逐元素相乘SiLU(x) x × σ(x)。2.5 LlamaDecoderLayer解码器层算法原理采用 Pre-Norm 架构归一化在注意力/MLP 之前残差连接模式x Attention(LN(x))和x MLP(LN(x))代码实现modeling_llama.py:597-663classLlamaDecoderLayer(nn.Module):def__init__(self,config):super().__init__()self.self_attnLlamaAttention(config)ifnot_flash_attn_2_enabledelseLlamaFlashAttention2(config)self.mlpLlamaMLP(config)self.input_layernormLlamaRMSNorm(config.hidden_size)self.post_attention_layernormLlamaRMSNorm(config.hidden_size)defforward(self,hidden_states,attention_maskNone,position_idsNone,past_key_valueNone,use_cacheFalse):residualhidden_states# 自注意力子层hidden_statesself.input_layernorm(hidden_states)hidden_states,_,present_key_valueself.self_attn(hidden_stateshidden_states,attention_maskattention_mask,position_idsposition_ids,past_key_valuepast_key_value,use_cacheuse_cache,)hidden_statesresidualhidden_states# MLP 子层residualhidden_states hidden_statesself.post_attention_layernorm(hidden_states)hidden_statesself.mlp(hidden_states)hidden_statesresidualhidden_statesreturn(hidden_states,present_key_value)ifuse_cacheelse(hidden_states,)前向传播流程输入 x → LN1(x) → Attention(LN1(x)) → x Attention(...) → LN2(xAttention) → MLP(LN2(...)) → x Attention(...) MLP(...) → 输出3. LMDrive 特定适配3.1 LoRA 微调原理在 Transformer 的q_proj和v_proj层插入低秩适配器冻结主模型参数仅训练适配器。配置drive.py:145-155loraconfigLoraConfig(r16,# 秩lora_alpha32,# 缩放因子target_modules[q_proj,v_proj],lora_dropout0.05,biasnone,task_typeCAUSAL_LM)self.llm_modelget_peft_model(self.llm_model,loraconfig)优点参数量大幅减少仅约 0.1% 的参数参与训练避免灾难性遗忘训练速度快、显存占用低3.2 Waypoints 预测器原理将 LLM 最后一层隐藏状态映射为 5 个未来轨迹点10 维每点 x,y 坐标。实现drive.py:125-129self.waypoints_predictornn.Sequential(nn.Linear(self.llm_model.config.hidden_size,self.llm_model.config.hidden_size),nn.ReLU(),nn.Linear(self.llm_model.config.hidden_size,10)# 5 waypoints × (x,y))推理流程LLM隐藏状态 [B, L, 4096] → waypoints_predictor → [B, L, 10] → 提取有效帧位置 → [N, 10]GRU Decoder 变体原理采用自回归方式迭代预测 5 个轨迹点每步将前一预测点作为输入反馈给 GRUCell。实现drive.py:116-123self.waypoints_fcnn.Sequential(nn.Linear(self.llm_model.config.hidden_size,self.llm_model.config.hidden_size),nn.ReLU(),nn.Linear(self.llm_model.config.hidden_size,64))self.waypoints_predictornn.GRUCell(input_size2,hidden_size64)# 输入: (x,y)坐标self.waypoints_outputnn.Linear(64,2)推理流程drive.py:469-494waypoints_featureself.waypoints_fc(hidden_states.reshape(-1,4096))# [B*L, 64]xtorch.zeros(size(bs*n_tokens,2))# 初始位置output_wp[]for_inrange(5):# 迭代预测5个waypointwaypoints_featureself.waypoints_predictor(x,waypoints_feature)# GRUCell更新dxself.waypoints_output(waypoints_feature)# 预测位移xdxx# 累加位移output_wp.append(x)predicted_waypointstorch.cat(output_wp,dim1)# [B*L, 10]对比方法特点适用场景MLP单次前馈并行计算训练速度快简单场景GRU自回归迭代考虑时序依赖复杂轨迹需要平滑性3.3 End 预测器原理分类器判断当前帧是否为轨迹终点。实现drive.py:130-134self.end_predictornn.Sequential(nn.Linear(self.llm_model.config.hidden_size,self.llm_model.config.hidden_size),nn.ReLU(),nn.Linear(self.llm_model.config.hidden_size,2)# [continue, end])损失计算CrossEntropyLoss(predicted_end_prob, gt_end_flags)3.4 联合损失函数总损失drive.py:514self.waypoints_losstorch.nn.L1Loss()# L1损失轨迹点回归self.end_losstorch.nn.CrossEntropyLoss()# 交叉熵终点分类losswaypoints_lossend_loss*0.2# 加权联合损失损失权重设计Waypoints 损失L1Loss主要优化轨迹点精度权重为 1.0End 损失CrossEntropyLoss辅助优化终点判断权重为 0.2原因自动驾驶任务中轨迹点精度是首要目标终点判断是次要目标但对安全至关重要因此给予较小权重防止过度拟合。3.5 LLM 作为特征提取器关键设计LMDrive 中LlamaForCausalLM.forward返回hidden_states第1065行而非 logits。原因自动驾驶任务需要回归连续轨迹点而非生成文本直接使用隐藏状态更灵活可通过 MLP 头进行任意下游任务避免 LM Head 的冗余计算代码证据modeling_llama.py:1065returnhidden_states# 直接返回隐藏状态而非 logits4. 输入数据构造4.1 文本-图像特征拼接拼接策略drive.py:169-224llm_inputs[input_embeds[i][:text_len],# 文本前缀image_embeds[i].view(t*n,-1),# Q-Former精炼的视觉特征 (t帧 × n个query)input_embeds[i][text_len:]# 文本后缀(padding)]Attention Mask文本部分[1, 1, ..., 1]有效图像帧[1, 1, ..., 1]无效图像帧padding[0, 0, ..., 0]5. 完整推理流程┌─────────────────────────────────────────────────────────────────────────┐ │ LMDrive LLM 推理流程 │ ├─────────────────────────────────────────────────────────────────────────┤ │ │ │ 1. 视觉编码器提取特征 │ │ [多视角图像 LiDAR] → Memfuser → [B*t, 65, 768] │ │ │ │ 2. Q-Former 精炼特征 │ │ [B*t, 65, 768] → Q-Former → [B*t, 4, 768] → llm_proj → [B*t, 4, 4096]│ │ │ │ 3. 构造 LLM 输入 │ │ 文本嵌入 [B, L, 4096] 视觉特征 [B, t, 4, 4096] → [B, Lt*4, 4096] │ │ │ │ 4. LLM 前向传播 │ │ [B, Lt*4, 4096] → LlamaModel(32层Decoder) → [B, Lt*4, 4096] │ │ └── 每层: RMSNorm → Attention(RoPE) → Residual → RMSNorm → MLP → Residual│ │ │ │ 5. Waypoint 预测 │ │ hidden_states[有效位置] → waypoints_predictor → [N, 10] │ │ │ │ 6. End 预测 │ │ hidden_states[有效位置] → end_predictor → [N, 2] → argmax → [N] │ │ │ └─────────────────────────────────────────────────────────────────────────┘6. 关键技术参数参数LLaVA-7B说明hidden_size4096隐藏层维度num_hidden_layers32解码器层数num_attention_heads32注意力头数head_dim128每头维度intermediate_size11008MLP中间层维度max_position_embeddings2048最大序列长度vocab_size32000词表大小hidden_actsilu激活函数SwiGLU7. 总结LMDrive 的 LLM 核心推理单元基于LLaMA 架构包含以下关键组件LlamaRMSNormPre-Norm 归一化无均值中心化RoPE旋转位置编码支持动态 NTK 缩放因果自注意力支持 KV 缓存和 GQASwiGLU MLP平滑梯度流动DecoderLayerPre-Norm 残差架构在 LMDrive 中LLM 被用作特征提取器而非文本生成器通过LoRA 微调适配自动驾驶任务Waypoints 预测器生成 5 个未来轨迹点End 预测器判断轨迹终点这种设计充分利用了 LLM 的强大语义理解能力同时避免了文本生成的冗余计算高效适配自动驾驶的回归任务需求。

相关新闻