lightrft.trainer.grm_trainer_vl¶
Trainer for generative reward models (vision-language capable).
This module contains a trainer that optimizes a generative reward model using
next-token prediction loss (GPTLMLoss). It integrates with the Strategy
abstraction for distributed training, gradient accumulation, checkpointing, and
logging via Weights & Biases or TensorBoard.
- class lightrft.trainer.grm_trainer_vl.GRMTrainerVL(model: torch.nn.Module, strategy, optim: torch.optim.Optimizer, train_dataloader, scheduler, tokenizer, eval_dataloader=None, max_epochs: int = 2, loss: str = 'GPTLMLoss')[source]¶
Bases: object

Trainer for generative reward models.
- Parameters:
model (torch.nn.Module) – The model to be trained. Expected to return logits for next-token prediction when called with token ids and optional image or video features.
strategy (Strategy) – The training strategy to apply, handling distributed setup, gradient accumulation, logging and checkpointing.
optim (torch.optim.Optimizer) – Optimizer to use during training.
train_dataloader (torch.utils.data.DataLoader) – Dataloader for the training dataset.
scheduler (torch.optim.lr_scheduler._LRScheduler) – Learning rate scheduler for dynamic adjustments during training.
tokenizer (Callable) – Tokenizer used to pad and process input sequences where helper routines require it.
eval_dataloader (Optional[torch.utils.data.DataLoader]) – Dataloader for the evaluation dataset.
max_epochs (int) – Maximum number of training epochs.
loss (str) – Name of the loss function to use; defaults to 'GPTLMLoss' (next-token prediction).
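The default 'GPTLMLoss' is a next-token prediction (causal language modeling) loss. As a minimal sketch of the idea, the pure-Python function below mirrors what such a loss computes, without torch: the prediction at position t is scored against the token at position t + 1, masked positions are skipped, and the negative log-likelihoods are averaged. The function and variable names are illustrative, not part of the lightrft API.

```python
import math

def next_token_loss(probs, token_ids, ignore_index=-100):
    """probs[t] is the predicted distribution over the vocab at position t;
    the target for position t is token_ids[t + 1] (the *next* token)."""
    total, count = 0.0, 0
    for t in range(len(token_ids) - 1):
        target = token_ids[t + 1]
        if target == ignore_index:  # skip masked positions (e.g. padding)
            continue
        total += -math.log(probs[t][target])
        count += 1
    return total / max(count, 1)

# Toy example: vocab of size 3, sequence [0, 2, 1].
probs = [
    [0.1, 0.1, 0.8],  # position 0: next token is 2, predicted with p=0.8
    [0.2, 0.7, 0.1],  # position 1: next token is 1, predicted with p=0.7
]
loss = next_token_loss(probs, [0, 2, 1])
```

In the real trainer the same shift-and-mask happens on batched logits, but the averaging over unmasked next-token positions is the same idea.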
- evaluate(args, eval_dataloader, steps: int = 0) None[source]¶
Evaluate the model on the provided dataloader by generating text responses and saving them to a JSON file. This method handles distributed gathering of generated text, extracted assistant responses, and extra metadata across all processes, with the rank 0 process writing the final results to disk. Evaluation results and metrics are also logged to Weights & Biases or TensorBoard if configured.
- Parameters:
args (Any) – Training arguments containing generation configurations.
eval_dataloader (torch.utils.data.DataLoader) – Dataloader for evaluation samples.
steps (int) – Global step id for logging and naming the output file.
- Returns:
None
- Return type:
NoneType
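The gather-then-write pattern that evaluate describes can be sketched as follows: each rank contributes its list of generated results, the lists are concatenated (standing in here for a distributed all-gather), and only rank 0 serializes the merged output. All names below are illustrative assumptions, not the lightrft API.

```python
import json

def gather_and_dump(per_rank_results, rank, steps):
    """Merge per-rank result lists; rank 0 returns the JSON payload,
    every other rank returns None (it writes nothing to disk)."""
    # Flatten the per-rank lists in rank order, mimicking an all-gather.
    merged = [item for rank_items in per_rank_results for item in rank_items]
    if rank != 0:
        return None
    return json.dumps({"step": steps, "results": merged}, indent=2)

per_rank = [["resp-a"], ["resp-b", "resp-c"]]  # outputs from two ranks
payload = gather_and_dump(per_rank, rank=0, steps=50)
```

Naming the output file by the global step, as the docstring notes, keeps successive evaluation snapshots from overwriting each other.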
- fit(args, consumed_samples=0, num_update_steps_per_epoch=None) None[source]¶
Main training loop for the generative reward model.
- Parameters:
args (argparse.Namespace) – Training arguments containing hyperparameters and configurations.
consumed_samples (int) – Number of samples already consumed (for resuming).
num_update_steps_per_epoch (Optional[int]) – Number of update steps per epoch. Used for eval scheduling, determining epoch boundaries, and resuming training from checkpoints.
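A small sketch of the resume bookkeeping this combination of parameters enables: given the number of samples already consumed, the global batch size, and the update steps per epoch, fit can recover which epoch to start from and how far into it to skip. The function and variable names are assumptions for illustration; the actual lightrft implementation may differ.

```python
def resume_position(consumed_samples, num_update_steps_per_epoch,
                    train_batch_size):
    """Map consumed_samples to (start_epoch, step_within_epoch)."""
    # One optimizer update consumes one global batch of samples.
    consumed_steps = consumed_samples // train_batch_size
    start_epoch = consumed_steps // num_update_steps_per_epoch
    step_in_epoch = consumed_steps % num_update_steps_per_epoch
    return start_epoch, step_in_epoch

# 10_000 samples seen, 125 update steps per epoch, global batch of 32:
# 10_000 // 32 = 312 update steps -> resume at epoch 2, step 62.
epoch, step = resume_position(10_000, 125, 32)
```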
- save_logs_and_checkpoints(args, global_step: int, step_bar, logs_dict: Dict[str, float] = {}, client_states: Dict[str, Any] = {}) None[source]¶
Log metrics and optionally run evaluation and checkpointing.
- Parameters:
args (Any) – Training arguments providing logging/eval/save intervals and checkpoint configurations.
global_step (int) – Current global optimization step.
step_bar (tqdm.tqdm) – Progress bar for step-level updates.
logs_dict (Dict[str, float]) – Dictionary of metrics to log.
client_states (Dict[str, Any]) – Extra state to persist with checkpoints (e.g., consumed samples).
- Returns:
None
- Return type:
NoneType
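The interval logic behind this method can be sketched as a step-divisibility check: logging, evaluation, and checkpointing each fire whenever the global step is a multiple of the corresponding configured interval. The attribute names below echo the documented args object but are assumptions, not guaranteed lightrft names.

```python
def due_actions(global_step, logging_steps, eval_steps, save_steps):
    """Return which of log / eval / save fire at this global step."""
    return {
        "log": global_step % logging_steps == 0,
        "eval": global_step % eval_steps == 0,
        "save": global_step % save_steps == 0,
    }

# At step 200 with intervals of 10 / 100 / 400: log and eval, but no save.
actions = due_actions(global_step=200, logging_steps=10,
                      eval_steps=100, save_steps=400)
```

Passing client_states (e.g. consumed samples) into the save path is what makes the resume arithmetic in fit possible on the next run.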