news 2026/4/23 16:07:40

【LLaVA-NeXT】LLaVATrainer说明

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
【LLaVA-NeXT】LLaVATrainer说明

LLaVATrainer

classllava.train.llava_trainer.LLaVATrainer(Trainer)

用于训练 LLaVA (Large Language and Vision Assistant) 多模态模型的训练器类,继承自transformers.Trainer

该类在标准 Transformer Trainer 基础上扩展了以下功能:

  • 支持MeZO (Memory-efficient Zeroth-Order Optimization)零阶优化训练模式
  • 提供多种基于长度和模态的数据采样策略
  • 支持DeepSpeedFSDP分布式训练
  • 提供针对多模态适配器 (MM Adapter) 的特定检查点保存功能

参数

该类接受所有transformers.Trainer支持的关键字参数,同时支持以下额外参数(通过args传入):

参数类型默认值描述
trainer_modestr"regular"训练模式。可选"regular"(常规反向传播训练)或"zo"(MeZO 零阶优化训练)。
zo_epsfloat1e-3MeZO 超参数 epsilon,控制参数扰动的幅度。
zo_num_directionsint1MeZO 优化中使用的随机方向数量。
group_by_lengthboolFalse是否按序列长度分组采样。
group_by_modality_lengthboolFalse是否按模态长度分组采样。
group_by_modality_length_autoboolFalse是否使用自动模态长度分组采样。
group_by_varlenboolFalse是否使用可变长度分组采样。
mm_projector_lrfloat,optionalNone多模态投影层的独立学习率。
mm_vision_tower_lrfloat,optionalNone视觉编码器的独立学习率。

属性

属性类型描述
trainer_modestr当前训练模式("regular""zo")。
zo_epsfloatMeZO epsilon 超参数。
zo_num_directionsintMeZO 随机方向数量。
trainable_paramsList[Tuple[str, Parameter]]可训练参数列表,包含参数名称和参数本身。
mezo_update_historyList[Dict]MeZO 更新历史记录,用于检查点恢复。

方法

zo_perturb_parameters

zo_perturb_parameters(scaling_factor:float=1.0)->None

使用随机向量z zz扰动模型参数。

参数:

  • scaling_factor(float) – 扰动的缩放因子。正值表示正向扰动,负值表示反向扰动。

示例:

# 正向扰动trainer.zo_perturb_parameters(scaling_factor=1.0)# 反向扰动(恢复原始参数后再扰动)trainer.zo_perturb_parameters(scaling_factor=-2.0)

zo_forward

zo_forward(model:nn.Module,inputs:Dict)->torch.Tensor

在推理模式下计算前向传播损失。

参数:

  • model(nn.Module) – 需要计算损失的模型。
  • inputs(Dict) – 输入批次数据。

返回:

  • torch.Tensor– 计算得到的损失值(已 detach)。

zo_step

zo_step(model:nn.Module,inputs:Dict)->torch.Tensor

使用 MeZO 算法执行单步梯度估计。通过正向和反向扰动的损失差来近似梯度。

参数:

  • model(nn.Module) – 模型实例。
  • inputs(Dict) – 输入批次数据。

返回:

  • torch.Tensor– 归一化后的损失值。

注意事项:

该方法在gradient_accumulation_steps期间累积多个方向的梯度估计,在zo_update中统一应用。


zo_update

zo_update(learning_rate:float)->None

根据累积的梯度估计更新模型参数。

参数:

  • learning_rate(float) – 当前学习率。

注意事项:

  • 该方法自动处理 weight decay
  • biaslayer_normlayernorm参数不会应用 weight decay
  • 调用后会清空累积的梯度估计

save_model

save_model(output_dir:Optional[str]=None,_internal_call:bool=False)

保存模型检查点。当使用 MeZO 模式时,会额外保存轻量级的 MeZO 状态检查点。

参数:

  • output_dir(str,optional) – 保存路径。默认使用args.output_dir
  • _internal_call(bool) – 是否为内部调用。

_save_checkpoint

_save_checkpoint(model,trial,metrics=None)->None

保存训练检查点。该方法重写了父类的检查点保存逻辑,以支持仅保存多模态适配器 (MM Adapter) 权重的场景。

参数:

  • model– 需要保存的模型实例。
  • trial– 超参数搜索试验对象(用于确定输出目录)。
  • metrics(Dict,optional) – 评估指标字典。

行为说明:

当满足以下任一条件时,仅保存适配器权重:

  • args.tune_mm_mlp_adapter=True
  • args.mm_tunable_parts仅包含"mm_mlp_adapter""mm_vision_resampler"

在这种情况下,会保存:

  • 模型配置文件 (config.json)
  • 适配器权重文件 (mm_projector.bin)

保存的权重包括:

  • mm_projector相关参数
  • vision_resampler相关参数
  • 如果use_im_start_end=True,还包括embed_tokensembed_in

其他情况下,调用父类Trainer._save_checkpoint()进行完整模型保存。

注意事项:

  • 该方法支持DeepSpeed ZeRO-3模式,会正确收集分布在多个 GPU 上的参数
  • 仅在主进程(local_rank == 0local_rank == -1)上执行实际的保存操作

示例:

# 仅微调 MM Adapter 时的配置training_args.tune_mm_mlp_adapter=True# 或者通过 mm_tunable_parts 指定training_args.mm_tunable_parts="mm_mlp_adapter"# 训练过程中的检查点将只包含适配器权重# 保存路径示例: output_dir/checkpoint-1000/mm_projector.bin

create_optimizer

create_optimizer()->torch.optim.Optimizer

创建优化器。支持为不同模块设置独立学习率(如mm_projectorvision_tower)。

返回:

  • torch.optim.Optimizer– 配置好的优化器实例。

注意事项:

在 MeZO 模式下,会创建一个虚拟优化器(dummy optimizer),实际参数更新由zo_update方法执行。


get_train_dataloader

get_train_dataloader()->DataLoader

创建并返回训练数据加载器。

返回:

  • torch.utils.data.DataLoader– 训练数据加载器。

示例

基本使用

fromllava.train.llava_trainerimportLLaVATrainerfromtransformersimportTrainingArguments# 配置训练参数training_args=TrainingArguments(output_dir="./output",per_device_train_batch_size=4,gradient_accumulation_steps=8,learning_rate=2e-5,num_train_epochs=1,group_by_modality_length=True,# 启用模态长度分组)# 创建训练器trainer=LLaVATrainer(model=model,tokenizer=tokenizer,args=training_args,train_dataset=train_dataset,data_collator=data_collator,)# 开始训练trainer.train()# 保存模型trainer.save_model("./final_model")

使用 MeZO 训练模式

fromllava.train.llava_trainerimportLLaVATrainer# 配置 MeZO 相关参数training_args.trainer_mode="zo"training_args.zo_eps=1e-3training_args.zo_num_directions=1# 创建训练器trainer=LLaVATrainer(model=model,tokenizer=tokenizer,args=training_args,train_dataset=train_dataset,)# MeZO 模式训练trainer.train()

设置模块独立学习率

# 为多模态投影层和视觉编码器设置独立学习率training_args.mm_projector_lr=1e-4training_args.mm_vision_tower_lr=2e-6trainer=LLaVATrainer(model=model,args=training_args,...)

参见

  • transformers.Trainer– 基类文档
  • LengthGroupedSampler– 长度分组采样器
  • LLaVADPOTrainer– 用于 DPO (Direct Preference Optimization) 训练的变体
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/23 11:35:15

计算机Java毕设实战-基于springboot的医药药品管理系统【完整源码+LW+部署说明+演示视频,全bao一条龙等】

博主介绍:✌️码农一枚 ,专注于大学生项目实战开发、讲解和毕业🚢文撰写修改等。全栈领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java、小程序技术领域和毕业项目实战 ✌️技术范围:&am…

作者头像 李华
网站建设 2026/4/23 13:19:44

实用丨维普AIGC降AI工具推荐 + 操作顺序

维普AIGC检测高?6款工具帮你降到合格线 TL;DR:维普AIGC检测算法和知网不同,很多知网能过的工具在维普可能过不了。实测对维普效果最好的是嘎嘎降AI(67%→9%),其次是比话降AI(60%→12%&#xff0…

作者头像 李华
网站建设 2026/4/23 14:43:26

React 高阶组件

作为一名前端工程师,日常开发中我们总会遇到组件逻辑复用的需求。在 React Hooks 出现之前,高阶组件(Higher-Order Component,简称 HOC)是实现这一需求的核心方案之一;即便在 Hooks 普及的当下,…

作者头像 李华
网站建设 2026/4/23 14:39:57

【课程设计/毕业设计】基于SpringBoot与Web的数学库组卷系统设计与实现基于springboot的小学数学错题管理及推荐系统设计与实现【附源码、数据库、万字文档】

博主介绍:✌️码农一枚 ,专注于大学生项目实战开发、讲解和毕业🚢文撰写修改等。全栈领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java、小程序技术领域和毕业项目实战 ✌️技术范围:&am…

作者头像 李华
网站建设 2026/4/18 12:40:34

Java计算机毕设之基于Springboot在线错题本管理系统springboot的小学数学错题管理及推荐系统设计与实现(完整前后端代码+说明文档+LW,调试定制等)

博主介绍:✌️码农一枚 ,专注于大学生项目实战开发、讲解和毕业🚢文撰写修改等。全栈领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java、小程序技术领域和毕业项目实战 ✌️技术范围:&am…

作者头像 李华