lightrft.trainer.srm_trainer_vl¶
Trainer utilities for scalar reward models (vision-language capable).
This module provides a trainer that supports pairwise preference training with several loss functions, including Bradley-Terry (BT) loss and Human Preference Score (HPS) loss. It integrates with the project’s Strategy abstraction for distributed training, gradient accumulation, and checkpointing, and can optionally log to Weights & Biases or TensorBoard.
- class lightrft.trainer.srm_trainer_vl.SRMTrainerVL(model: torch.nn.Module, strategy, optim: torch.optim.Optimizer, train_dataloader, scheduler, tokenizer, eval_dataloader=None, max_epochs: int = 2, loss: str = 'sigmoid', margin: float = 0.1)[source]¶
Bases: object
Trainer for scalar vision-language reward models.
- Parameters:
model (torch.nn.Module) – The model to be trained; expected to return a dict of head scores for each head type when called with token ids and optional image or video features.
strategy (Strategy) – Training strategy that manages distributed operations, gradient accumulation, checkpointing, logging helpers, and args.
optim (torch.optim.Optimizer) – Optimizer used for parameter updates.
train_dataloader (torch.utils.data.DataLoader) – Dataloader providing pairwise (A/B/Equal) batches.
scheduler (torch.optim.lr_scheduler._LRScheduler) – Learning rate scheduler.
tokenizer (Callable) – Tokenizer used to obtain pad token id for sequence padding when concatenating inputs.
eval_dataloader (Optional[torch.utils.data.DataLoader]) – Optional dataloader used for evaluation.
max_epochs (int) – Maximum number of training epochs.
loss (str) – Loss function to use. Choices are ‘sigmoid’ (PairWiseLoss), ‘logexp’ (LogExpLoss), and ‘hps’ (HPSLoss).
margin (float) – Margin value for BT loss (only used if loss is BT).
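The ‘sigmoid’ (Bradley-Terry) choice can be sketched as a log-sigmoid over score differences. The function name below and the exact way the margin enters are illustrative assumptions, not the actual PairWiseLoss implementation:

```python
import torch
import torch.nn.functional as F


def pairwise_bt_loss(chosen_scores: torch.Tensor,
                     rejected_scores: torch.Tensor,
                     margin: float = 0.1) -> torch.Tensor:
    # Bradley-Terry objective: -log sigmoid(s_chosen - s_rejected - margin).
    # A larger margin demands a wider gap before the loss goes to zero.
    return -F.logsigmoid(chosen_scores - rejected_scores - margin).mean()
```

A wider chosen/rejected gap yields a smaller loss, which is what drives the reward head to rank preferred samples higher.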
- concatenated_forward(model, input0_ids, input0_mask, input1_ids, input1_mask, input0_img_pixels, input0_img_grid_thws, input1_img_pixels, input1_img_grid_thws, input0_video_pixels, input0_video_grid_thws, input1_video_pixels, input1_video_grid_thws) Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]][source]¶
Run the model once on concatenated chosen/rejected inputs.
Concatenation reduces the number of forward passes.
- Parameters:
model (nn.Module) – Model to evaluate.
input0_ids (torch.Tensor) – Token ids for chosen sequences.
input0_mask (torch.Tensor) – Attention mask for chosen sequences.
input1_ids (torch.Tensor) – Token ids for rejected sequences.
input1_mask (torch.Tensor) – Attention mask for rejected sequences.
input0_img_pixels (Optional[torch.Tensor]) – Optional image features for chosen samples.
input0_img_grid_thws (Optional[torch.Tensor]) – Optional image grid meta for chosen.
input1_img_pixels (Optional[torch.Tensor]) – Optional image features for rejected samples.
input1_img_grid_thws (Optional[torch.Tensor]) – Optional image grid meta for rejected.
input0_video_pixels (Optional[torch.Tensor]) – Optional video features for chosen samples.
input0_video_grid_thws (Optional[torch.Tensor]) – Optional video grid meta for chosen.
input1_video_pixels (Optional[torch.Tensor]) – Optional video features for rejected samples.
input1_video_grid_thws (Optional[torch.Tensor]) – Optional video grid meta for rejected.
- Returns:
Tuple of dicts (scores0, scores1) separating the outputs for chosen and rejected samples.
- Return type:
Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]
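Since the model sees chosen and rejected samples in one batch of size 2N, the per-head outputs have to be split back afterwards. A minimal sketch of that split, assuming the model returns a dict mapping head name to a (2N,)-shaped score tensor:

```python
import torch
from typing import Dict, Tuple


def split_head_scores(all_scores: Dict[str, torch.Tensor],
                      n: int) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
    # First n rows came from the chosen samples, the rest from the rejected ones.
    scores0 = {k: v[:n] for k, v in all_scores.items()}
    scores1 = {k: v[n:] for k, v in all_scores.items()}
    return scores0, scores1
```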
- concatenated_inputs(input0_ids, input0_mask, input1_ids, input1_mask) Tuple[torch.Tensor, torch.Tensor][source]¶
Concatenate two inputs into a single batch.
- Parameters:
input0_ids (torch.Tensor) – Token ids for the first sequences, shape (N, Lc).
input0_mask (torch.Tensor) – Attention mask for the first sequences, shape (N, Lc).
input1_ids (torch.Tensor) – Token ids for the second sequences, shape (N, Lr).
input1_mask (torch.Tensor) – Attention mask for the second sequences, shape (N, Lr).
- Returns:
Tuple (input_ids, att_masks) where inputs are padded to a common max length across input0 and input1, then concatenated along the batch dimension to shape (2N, Lmax).
- Return type:
Tuple[torch.Tensor, torch.Tensor]
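The pad-then-concatenate step might look like the following sketch; the function name and the default pad value are assumptions (the trainer takes the real pad id from its tokenizer):

```python
import torch
import torch.nn.functional as F


def concat_pairs(ids0, mask0, ids1, mask1, pad_token_id=0):
    # Right-pad both batches to the longer of the two sequence lengths,
    # then stack them along the batch dimension: (N, Lc) + (N, Lr) -> (2N, Lmax).
    max_len = max(ids0.shape[1], ids1.shape[1])

    def pad(t, value):
        return F.pad(t, (0, max_len - t.shape[1]), value=value)

    input_ids = torch.cat([pad(ids0, pad_token_id), pad(ids1, pad_token_id)], dim=0)
    att_masks = torch.cat([pad(mask0, 0), pad(mask1, 0)], dim=0)
    return input_ids, att_masks
```

Padding the mask with zeros keeps the padded positions out of attention.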
- evaluate(args, eval_dataloader, steps=0) None[source]¶
Evaluate the model on the provided dataloader and write a JSONL of scores to the save path indicated by strategy.args.save_path. Also calculates and logs accuracy metrics to Weights & Biases or TensorBoard.
- Parameters:
args (Any) – Present for API compatibility with callers.
eval_dataloader (torch.utils.data.DataLoader) – Dataloader for evaluation samples.
steps (int) – Global step id for naming the output file.
- Returns:
None
- Return type:
NoneType
The output file name format is eval_scores_{steps}.jsonl.
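The accuracy metric and the JSONL dump can be sketched as follows; the record keys and helper names are illustrative, not the trainer’s actual output schema:

```python
import json
import torch


def pairwise_accuracy(scores0: torch.Tensor, scores1: torch.Tensor) -> float:
    # Fraction of pairs where the preferred sample outscores the other.
    return (scores0 > scores1).float().mean().item()


def write_scores_jsonl(path: str, scores0: torch.Tensor, scores1: torch.Tensor) -> None:
    # One JSON record per pair, one record per line (JSONL).
    with open(path, "w") as f:
        for a, b in zip(scores0.tolist(), scores1.tolist()):
            f.write(json.dumps({"score_chosen": a, "score_rejected": b}) + "\n")
```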
- fit(args, consumed_samples=0, num_update_steps_per_epoch=None) None[source]¶
Train the model for max_epochs using the provided dataloaders.
- Parameters:
args (Any) – Training arguments (typically strategy.args) including logging/eval/save intervals and batch sizing.
consumed_samples (int) – Number of samples already consumed (for resuming training); defaults to 0.
num_update_steps_per_epoch (Optional[int]) – Number of optimizer steps per epoch. If None, it is inferred externally and passed in by the caller.
- Returns:
None
- Return type:
NoneType
- Notes:
Supports HPS scale training via the scale_for_train flag in args.
Logs to Weights & Biases or TensorBoard when configured.
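One pairwise update, stripped of the strategy’s distributed and accumulation machinery, reduces to a standard PyTorch step. The reward head, data, and loss below are all stand-ins for illustration:

```python
import torch

# Stand-in scalar reward head over 4-dim features.
head = torch.nn.Linear(4, 1)
optim = torch.optim.SGD(head.parameters(), lr=0.1)

chosen = torch.randn(8, 4)
rejected = torch.randn(8, 4)

# Score both sides, apply the pairwise log-sigmoid objective, and step.
s0 = head(chosen).squeeze(-1)
s1 = head(rejected).squeeze(-1)
loss = -torch.nn.functional.logsigmoid(s0 - s1).mean()

optim.zero_grad()
loss.backward()
optim.step()
```

In the trainer itself, backward and optimizer stepping go through the Strategy so that gradient accumulation and distributed reduction are handled uniformly.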
- save_logs_and_checkpoints(args, global_step, step_bar, logs_dict={}, client_states={}) None[source]¶
Log metrics and optionally run evaluation and checkpointing.
- Parameters:
args (Any) – Training arguments providing logging/eval/save intervals and checkpoint configurations.
global_step (int) – Current global optimization step.
step_bar (tqdm.tqdm) – Progress bar for step-level updates.
logs_dict (Dict[str, float]) – Dictionary of metrics to log (reduced across ranks via strategy.all_reduce before display/logging).
client_states (Dict[str, Any]) – Extra state to persist with checkpoints (e.g., consumed samples).
- Returns:
None
- Return type:
NoneType