lightrft.trainer.grm_trainer_vl

Trainer for generative reward models (vision-language capable).

This module contains a trainer that optimizes a generative reward model using next-token prediction loss (GPTLMLoss). It integrates with the Strategy abstraction for distributed training, gradient accumulation, checkpointing, and logging via Weights & Biases or TensorBoard.
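
The next-token prediction objective can be sketched as follows. This is an illustrative GPTLMLoss-style shift-and-cross-entropy computation, not the library's exact implementation:

```python
import torch
import torch.nn.functional as F

def next_token_loss(logits: torch.Tensor, input_ids: torch.Tensor,
                    ignore_index: int = -100) -> torch.Tensor:
    """Shift logits and labels by one position, then average cross-entropy
    over the predicted tokens (sketch of a GPTLMLoss-style objective)."""
    shift_logits = logits[:, :-1, :].contiguous()  # predictions for token t+1
    shift_labels = input_ids[:, 1:].contiguous()   # the actual token t+1
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=ignore_index,
    )

# Tiny example: batch of 2, sequence length 5, vocabulary size 11.
logits = torch.randn(2, 5, 11)
ids = torch.randint(0, 11, (2, 5))
loss = next_token_loss(logits, ids)
```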

class lightrft.trainer.grm_trainer_vl.GRMTrainerVL(model: torch.nn.Module, strategy, optim: torch.optim.Optimizer, train_dataloader, scheduler, tokenizer, eval_dataloader=None, max_epochs: int = 2, loss: str = 'GPTLMLoss')[source]

Bases: object

Trainer for generative reward models.

Parameters:
  • model (torch.nn.Module) – The model to be trained. Expected to return logits for next-token prediction when called with token ids and optional image or video features.

  • strategy (Strategy) – The training strategy to apply, handling distributed setup, gradient accumulation, logging and checkpointing.

  • optim (torch.optim.Optimizer) – Optimizer to use during training.

  • train_dataloader (torch.utils.data.DataLoader) – Dataloader for the training dataset.

  • scheduler (torch.optim.lr_scheduler._LRScheduler) – Learning rate scheduler for dynamic adjustments during training.

  • tokenizer (Callable) – Tokenizer used to pad and otherwise process input sequences where helpers require it.

  • eval_dataloader (Optional[torch.utils.data.DataLoader]) – Dataloader for the evaluation dataset.

  • max_epochs (int) – Maximum number of training epochs.

  • loss (str) – Name of the loss function to use; defaults to 'GPTLMLoss' (next-token prediction).
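
A wiring sketch for the constructor, using toy stand-ins for the model, optimizer, scheduler, and dataloader. The strategy and tokenizer objects are assumed to come from your lightrft setup, so the trainer construction itself is shown commented out:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-ins; a real setup would use a vision-language model, its
# processor/tokenizer, and a lightrft Strategy instance.
model = nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1.0)
train_loader = DataLoader(TensorDataset(torch.randn(4, 8)), batch_size=2)

# from lightrft.trainer.grm_trainer_vl import GRMTrainerVL
# trainer = GRMTrainerVL(
#     model=model, strategy=strategy, optim=optimizer,
#     train_dataloader=train_loader, scheduler=scheduler,
#     tokenizer=tokenizer, max_epochs=2, loss="GPTLMLoss",
# )
# trainer.fit(args)
```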

evaluate(args, eval_dataloader, steps: int = 0) → None[source]

Evaluate the model on the provided dataloader by generating text responses and saving them to a JSON file. This method handles distributed gathering of generated text, extracted assistant responses, and extra metadata across all processes, with the rank 0 process writing the final results to disk. Evaluation results and metrics are also logged to Weights & Biases or TensorBoard if configured.

Parameters:
  • args (Any) – Training arguments containing generation configurations.

  • eval_dataloader (torch.utils.data.DataLoader) – Dataloader for evaluation samples.

  • steps (int) – Global step id for logging and naming the output file.

Returns:

None

Return type:

NoneType
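
The rank-0 write step described above can be sketched as follows; write_eval_results and the file-naming pattern are illustrative, not the trainer's actual helpers:

```python
import json
import os
import tempfile

def write_eval_results(records, steps, rank, out_dir):
    """Only rank 0 writes the gathered evaluation records to disk, mirroring
    the evaluate() behavior described above (names are illustrative)."""
    if rank != 0:
        return None  # non-zero ranks only contribute to the gather
    path = os.path.join(out_dir, f"eval_step_{steps}.json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump(records, f, ensure_ascii=False, indent=2)
    return path

# Demo: pretend the gather already happened and rank 0 writes the file.
out_dir = tempfile.mkdtemp()
records = [{"prompt": "p", "response": "r", "meta": {}}]
path = write_eval_results(records, steps=100, rank=0, out_dir=out_dir)
```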

fit(args, consumed_samples=0, num_update_steps_per_epoch=None) → None[source]

Main training loop for generative reward model.

Parameters:
  • args (argparse.Namespace) – Training arguments containing hyperparameters and configurations.

  • consumed_samples (int) – Number of samples already consumed (for resuming).

  • num_update_steps_per_epoch (Optional[int]) – Number of update steps per epoch. Used for eval scheduling, determining epoch boundaries, and resuming training from checkpoints.
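
The resume bookkeeping implied by consumed_samples and num_update_steps_per_epoch can be sketched as follows (illustrative; assumes one optimizer update per train_batch_size samples, which may differ from the trainer's exact accounting):

```python
def resume_position(consumed_samples: int, train_batch_size: int,
                    num_update_steps_per_epoch: int):
    """Map the number of already-consumed samples back to a
    (start_epoch, step_within_epoch) pair for resuming."""
    updates_done = consumed_samples // train_batch_size
    start_epoch = updates_done // num_update_steps_per_epoch
    step_in_epoch = updates_done % num_update_steps_per_epoch
    return start_epoch, step_in_epoch

# Fresh run: nothing consumed, start at epoch 0, step 0.
fresh = resume_position(0, train_batch_size=8, num_update_steps_per_epoch=10)
# Resumed run: 200 samples consumed -> 25 updates -> epoch 2, step 5.
resumed = resume_position(200, train_batch_size=8, num_update_steps_per_epoch=10)
```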

save_logs_and_checkpoints(args, global_step: int, step_bar, logs_dict: Dict[str, float] = {}, client_states: Dict[str, Any] = {}) → None[source]

Log metrics and optionally run evaluation and checkpointing.

Parameters:
  • args (Any) – Training arguments providing logging/eval/save intervals and checkpoint configurations.

  • global_step (int) – Current global optimization step.

  • step_bar (tqdm.tqdm) – Progress bar for step-level updates.

  • logs_dict (Dict[str, float]) – Dictionary of metrics to log.

  • client_states (Dict[str, Any]) – Extra state to persist with checkpoints (e.g., consumed samples).

Returns:

None

Return type:

NoneType
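
The interval checks this method performs can be sketched as follows; the parameter names (logging_steps, eval_steps, save_steps) are assumptions for illustration, not necessarily the actual fields on args:

```python
def should_act(global_step: int, logging_steps: int,
               eval_steps: int, save_steps: int) -> dict:
    """Decide, at a given global step, whether to log metrics, run
    evaluation, and save a checkpoint (illustrative interval logic)."""
    return {
        "log": global_step % logging_steps == 0,
        "eval": global_step % eval_steps == 0,
        "save": global_step % save_steps == 0,
    }

# At step 20 with intervals (10, 20, 50): log and eval, but don't save yet.
actions = should_act(20, logging_steps=10, eval_steps=20, save_steps=50)
```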