
lightrft.trainer.srm_trainer_vl

Trainer utilities for scalar reward models (vision-language capable).

This module provides a trainer that supports pairwise preference training with different loss functions, including Bradley-Terry (BT) loss and Human Preference Score (HPS) loss. It integrates with the project's Strategy abstraction for distributed training, gradient accumulation, and checkpointing, while optionally logging to Weights & Biases or TensorBoard.

class lightrft.trainer.srm_trainer_vl.SRMTrainerVL(model: torch.nn.Module, strategy, optim: torch.optim.Optimizer, train_dataloader, scheduler, tokenizer, eval_dataloader=None, max_epochs: int = 2, loss: str = 'sigmoid', margin: float = 0.1)[source]

Bases: object

Trainer for scalar vision-language reward models.

Parameters:
  • model (torch.nn.Module) – The model to be trained; expected to return a dict of head scores for each head type when called with token ids and optional image or video features.

  • strategy (Strategy) – Training strategy that manages distributed operations, gradient accumulation, checkpointing, logging helpers, and args.

  • optim (torch.optim.Optimizer) – Optimizer used for parameter updates.

  • train_dataloader (torch.utils.data.DataLoader) – Dataloader providing pairwise (A/B/Equal) batches.

  • scheduler (torch.optim.lr_scheduler._LRScheduler) – Learning rate scheduler.

  • tokenizer (Callable) – Tokenizer used to obtain pad token id for sequence padding when concatenating inputs.

  • eval_dataloader (Optional[torch.utils.data.DataLoader]) – Optional dataloader used for evaluation.

  • max_epochs (int) – Maximum number of training epochs.

  • loss (str) – Loss function to use. Choices are ‘sigmoid’ (PairWiseLoss), ‘logexp’ (LogExpLoss), and ‘hps’ (HPSLoss).

  • margin (float) – Margin value for the BT-style losses (ignored when loss is 'hps').
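The exact 'sigmoid' objective is implemented by the project's PairWiseLoss, which is not shown here; as a rough sketch under that assumption, the margin-adjusted Bradley-Terry loss for one chosen/rejected score pair can be written as:

```python
import math


def pairwise_bt_loss(score_chosen: float, score_rejected: float,
                     margin: float = 0.1) -> float:
    """Negative log-sigmoid of the margin-adjusted score gap.

    The loss shrinks as the chosen score exceeds the rejected score
    by more than `margin`.
    """
    gap = score_chosen - score_rejected - margin
    return -math.log(1.0 / (1.0 + math.exp(-gap)))
```

The function name and scalar interface are illustrative; the trainer operates on batched per-head score tensors.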

concatenated_forward(model, input0_ids, input0_mask, input1_ids, input1_mask, input0_img_pixels, input0_img_grid_thws, input1_img_pixels, input1_img_grid_thws, input0_video_pixels, input0_video_grid_thws, input1_video_pixels, input1_video_grid_thws) Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]][source]

Run the model once on concatenated chosen/rejected inputs.

Concatenating the chosen and rejected inputs into one batch halves the number of forward passes.

Parameters:
  • model (nn.Module) – Model to evaluate.

  • input0_ids (torch.Tensor) – Token ids for chosen sequences.

  • input0_mask (torch.Tensor) – Attention mask for chosen sequences.

  • input1_ids (torch.Tensor) – Token ids for rejected sequences.

  • input1_mask (torch.Tensor) – Attention mask for rejected sequences.

  • input0_img_pixels (Optional[torch.Tensor]) – Optional image features for chosen samples.

  • input0_img_grid_thws (Optional[torch.Tensor]) – Optional image grid meta for chosen.

  • input1_img_pixels (Optional[torch.Tensor]) – Optional image features for rejected samples.

  • input1_img_grid_thws (Optional[torch.Tensor]) – Optional image grid meta for rejected.

  • input0_video_pixels (Optional[torch.Tensor]) – Optional video features for chosen samples.

  • input0_video_grid_thws (Optional[torch.Tensor]) – Optional video grid meta for chosen.

  • input1_video_pixels (Optional[torch.Tensor]) – Optional video features for rejected samples.

  • input1_video_grid_thws (Optional[torch.Tensor]) – Optional video grid meta for rejected.

Returns:

Tuple of dicts (scores0, scores1) separating the outputs for chosen and rejected samples.

Return type:

Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]
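After the single forward pass over the concatenated batch, the per-head scores must be split back into chosen and rejected halves. A framework-free sketch of that split (the helper name `split_head_scores` is illustrative, not part of the module):

```python
def split_head_scores(all_scores: dict, batch_size: int):
    """Split per-head scores from a concatenated (2N, ...) batch back
    into chosen (first N rows) and rejected (last N rows) halves."""
    scores0 = {head: vals[:batch_size] for head, vals in all_scores.items()}
    scores1 = {head: vals[batch_size:] for head, vals in all_scores.items()}
    return scores0, scores1
```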

concatenated_inputs(input0_ids, input0_mask, input1_ids, input1_mask) Tuple[torch.Tensor, torch.Tensor][source]

Concatenate two inputs into a single batch.

Parameters:
  • input0_ids (torch.Tensor) – Token ids for the first sequences, shape (N, Lc).

  • input0_mask (torch.Tensor) – Attention mask for the first sequences, shape (N, Lc).

  • input1_ids (torch.Tensor) – Token ids for the second sequences, shape (N, Lr).

  • input1_mask (torch.Tensor) – Attention mask for the second sequences, shape (N, Lr).

Returns:

Tuple (input_ids, att_masks) where inputs are padded to a common max length across input0 and input1, then concatenated along the batch dimension to shape (2N, Lmax).

Return type:

Tuple[torch.Tensor, torch.Tensor]
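The pad-then-concatenate step can be illustrated without tensors; a minimal sketch on token-id lists, assuming right-padding with the tokenizer's pad token id:

```python
def concat_padded(ids0, ids1, pad_token_id=0):
    """Right-pad two batches of token-id lists to a shared max length,
    then stack them along the batch dimension: (N, Lc) + (N, Lr) -> (2N, Lmax)."""
    max_len = max(len(seq) for seq in ids0 + ids1)

    def pad(seq):
        return seq + [pad_token_id] * (max_len - len(seq))

    return [pad(seq) for seq in ids0] + [pad(seq) for seq in ids1]
```

The real method performs the same operation on torch tensors (and on the attention masks) and pads with the pad token id obtained from the tokenizer.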

evaluate(args, eval_dataloader, steps=0) None[source]

Evaluate the model on the provided dataloader and write a JSONL file of scores to the save path indicated by strategy.args.save_path. Accuracy metrics are also computed and logged to Weights & Biases or TensorBoard when configured.

Parameters:
  • args (Any) – Present for API compatibility with callers.

  • eval_dataloader (torch.utils.data.DataLoader) – Dataloader for evaluation samples.

  • steps (int) – Global step id for naming the output file.

Returns:

None

Return type:

NoneType

The output file name format is eval_scores_{steps}.jsonl.
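The JSONL schema is not specified here; as a sketch with assumed field names (`score_chosen`, `score_rejected`), pairwise accuracy is simply the fraction of pairs where the chosen sample outscores the rejected one:

```python
import json


def pairwise_accuracy(records):
    """Fraction of pairs where the chosen score beats the rejected score."""
    correct = sum(1 for r in records if r["score_chosen"] > r["score_rejected"])
    return correct / max(len(records), 1)


def to_jsonl(records):
    """Serialize evaluation records as JSON Lines (one JSON object per line)."""
    return "\n".join(json.dumps(r) for r in records)
```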

fit(args, consumed_samples=0, num_update_steps_per_epoch=None) None[source]

Train the model for max_epochs using the provided dataloaders.

Parameters:
  • args (Any) – Training arguments (typically strategy.args) including logging/eval/save intervals and batch sizing.

  • consumed_samples (int) – Number of samples already consumed (for resuming training), defaults to 0.

  • num_update_steps_per_epoch (Optional[int]) – Number of optimizer steps per epoch. If None, the caller is expected to compute it and pass it in.

Returns:

None

Return type:

NoneType

Notes:
  • Supports HPS Scale training via scale_for_train flag in args.

  • Logs to Weights & Biases or TensorBoard when configured.
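The exact resume arithmetic depends on the Strategy and batch sizing; as an assumption-labeled sketch, mapping consumed_samples back to a position in the training schedule looks like:

```python
def resume_position(consumed_samples: int, samples_per_step: int,
                    num_update_steps_per_epoch: int):
    """Map an absolute consumed-sample count back to (epoch, step-in-epoch)
    so training can resume mid-epoch after loading a checkpoint."""
    global_step = consumed_samples // samples_per_step
    epoch = global_step // num_update_steps_per_epoch
    step_in_epoch = global_step % num_update_steps_per_epoch
    return epoch, step_in_epoch
```

Here `samples_per_step` stands for the effective global batch size per optimizer step; both names are hypothetical.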

save_logs_and_checkpoints(args, global_step, step_bar, logs_dict={}, client_states={}) None[source]

Log metrics and optionally run evaluation and checkpointing.

Parameters:
  • args (Any) – Training arguments providing logging/eval/save intervals and checkpoint configurations.

  • global_step (int) – Current global optimization step.

  • step_bar (tqdm.tqdm) – Progress bar for step-level updates.

  • logs_dict (Dict[str, float]) – Dictionary of metrics to log (will be reduced across ranks via strategy.all_reduce before display/logging).

  • client_states (Dict[str, Any]) – Extra state to persist with checkpoints (e.g., consumed samples).

Returns:

None

Return type:

NoneType
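Logging, evaluation, and checkpointing are each gated by an interval from args; a minimal sketch of that gating (the helper and the `interval <= 0` disable convention are assumptions, not the module's API):

```python
def should_fire(global_step: int, interval: int) -> bool:
    """Interval gating for logging / evaluation / checkpointing:
    fire every `interval` optimizer steps; a non-positive interval disables."""
    return interval > 0 and global_step % interval == 0
```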