# Source code for lightrft.trainer.srm_trainer_al
"""
Trainer utilities for scalar reward models (audio-language capable).
This module provides a trainer that supports pairwise preference training with
different loss functions, including Bradley-Terry (BT) style losses (log-sigmoid
and log-exp) and the Human Preference Score (HPS) loss. It integrates with the
project's Strategy abstraction for distributed
training, gradient accumulation, and checkpointing, while optionally logging to
Weights & Biases or TensorBoard.
"""
import os
import json
from tqdm import tqdm
from typing import Tuple, Dict
import torch
import torch.nn as nn
from torch.optim import Optimizer
import torch.distributed as dist
from lightrft.models import LogExpLoss, LogSigmoidLoss, HPSLoss, pad_to_length
from lightrft.utils import DistributedSampler, all_gather_and_flatten, all_reduce_dict
class SRMTrainerAL:
    """
    Trainer for scalar audio-language reward models.

    Each batch holds two candidate inputs (``input0`` / ``input1``) per sample plus a
    per-sample ``extras`` dict. For every head type, ``extras[head]`` is a preference
    label: ``"A"`` (input0 preferred), ``"B"`` (input1 preferred), or anything else /
    missing (treated as a tie). Strict preferences are trained with the selected
    pairwise loss. Ties add a small L1 penalty on the score difference.

    :param model: The model to be trained. When called with token ids, an attention
        mask and audio features, it must return a dict mapping each head type to
        per-sample scores. It must also expose a ``head_types`` attribute.
    :type model: torch.nn.Module
    :param strategy: Training strategy that manages distributed operations,
        gradient accumulation, checkpointing, logging helpers, and args.
    :type strategy: Strategy
    :param optim: Optimizer used for parameter updates.
    :type optim: torch.optim.Optimizer
    :param train_dataloader: Dataloader providing pairwise (A/B/Equal) batches.
    :type train_dataloader: torch.utils.data.DataLoader
    :param scheduler: Learning rate scheduler.
    :type scheduler: torch.optim.lr_scheduler._LRScheduler
    :param tokenizer: Tokenizer whose ``pad_token_id`` is used to pad token ids
        when concatenating the two inputs.
    :type tokenizer: Callable
    :param eval_dataloader: Optional dataloader used for evaluation.
    :type eval_dataloader: Optional[torch.utils.data.DataLoader]
    :param max_epochs: Maximum number of training epochs.
    :type max_epochs: int
    :param loss: Loss function to use. Choices are ``'sigmoid'`` (LogSigmoidLoss),
        ``'logexp'`` (LogExpLoss), and ``'hps'`` (HPSLoss).
    :type loss: str
    :param margin: Margin value passed as the third argument to the loss function
        (whether it is used depends on the loss implementation).
    :type margin: float
    :raises ValueError: If ``loss`` is not one of the supported choices.
    """

    # Weight of the tie penalty (label neither "A" nor "B") in the total training loss.
    _TIE_LOSS_WEIGHT = 0.01
    # Multiplier applied to logged chosen/reject rewards when ``args.scale_for_train`` is set.
    # NOTE(review): presumably undoes a ~1/0.07 logit scale used for HPS-style training; confirm.
    _REWARD_LOG_SCALE = 0.07

    def __init__(
        self,
        model: nn.Module,
        strategy,
        optim: Optimizer,
        train_dataloader,
        scheduler,
        tokenizer,
        eval_dataloader=None,
        max_epochs: int = 2,
        loss: str = "sigmoid",
        margin: float = 0.1,
    ) -> None:
        self.strategy = strategy
        self.epochs = max_epochs
        self.model = model
        self.train_dataloader = train_dataloader
        self.eval_dataloader = eval_dataloader
        self.scheduler = scheduler
        self.optimizer = optim
        self.tokenizer = tokenizer
        self.margin = margin
        self.args = strategy.args

        # loss name -> (loss class, message printed on selection)
        loss_registry = {
            "sigmoid": (LogSigmoidLoss, "LogSigmoid Loss"),
            "logexp": (LogExpLoss, "LogExp Loss"),
            "hps": (HPSLoss, "HPS Loss"),
        }
        if loss not in loss_registry:
            raise ValueError(f"invalid loss type: {loss}")
        loss_cls, loss_message = loss_registry[loss]
        self.loss = loss
        self.loss_fn = loss_cls()
        self.strategy.print(loss_message)

        # wandb/tensorboard setting (loggers only live on rank 0)
        self._wandb = None
        self._tensorboard = None
        if self.strategy.args.use_wandb and self.strategy.is_rank_0():
            import wandb

            self._wandb = wandb
            if not wandb.api.api_key:
                wandb.login(key=strategy.args.use_wandb)
            wandb.init(
                entity=strategy.args.wandb_org,
                project=strategy.args.wandb_project,
                group=strategy.args.wandb_group,
                name=strategy.args.wandb_run_name,
                config=strategy.args.__dict__,
                reinit=True,
            )
            wandb.define_metric("train/global_step")
            wandb.define_metric("train/*", step_metric="train/global_step", step_sync=True)
            wandb.define_metric("eval/global_step")
            wandb.define_metric("eval/*", step_metric="eval/global_step", step_sync=True)

        # Initialize TensorBoard writer if wandb is not available
        if self.strategy.args.use_tensorboard and self._wandb is None and self.strategy.is_rank_0():
            from torch.utils.tensorboard import SummaryWriter

            os.makedirs(self.strategy.args.use_tensorboard, exist_ok=True)
            log_dir = os.path.join(self.strategy.args.use_tensorboard, strategy.args.wandb_run_name)
            self._tensorboard = SummaryWriter(log_dir=log_dir)

    @staticmethod
    def _batch_to_device(data, device):
        """
        Unpack one pairwise batch and move its tensors onto ``device``.

        The batch is the 9-tuple yielded by the dataloader, in this order: input0
        ids/mask, input1 ids/mask, input0 audio features/feature mask, input1 audio
        features/feature mask, and the list of per-sample ``extras``. Dimension 1 is
        squeezed out of the token ids and masks (presumably they arrive as
        ``(N, 1, L)``; TODO confirm). Audio tensors are moved as-is.

        :returns: Tuple of the eight device tensors in the same order, followed by ``extras``.
        """
        (
            input0_ids,
            input0_mask,
            input1_ids,
            input1_mask,
            input0_input_features,
            input0_feature_attention_mask,
            input1_input_features,
            input1_feature_attention_mask,
            extras,
        ) = data
        text_tensors = [t.squeeze(1).to(device) for t in (input0_ids, input0_mask, input1_ids, input1_mask)]
        audio_tensors = [
            t.to(device) for t in (
                input0_input_features,
                input0_feature_attention_mask,
                input1_input_features,
                input1_feature_attention_mask,
            )
        ]
        return (*text_tensors, *audio_tensors, extras)

    @staticmethod
    def _split_by_label(scores0, scores1, extras, head_types):
        """
        Group per-sample scores by preference label, separately for each head.

        For sample ``i`` and head ``h`` the label is ``extras[i][h]``, or ``"C"`` when
        the key is missing. ``"A"`` puts ``(scores0[h][i], scores1[h][i])`` into
        chosen/rejected, ``"B"`` puts ``(scores1[h][i], scores0[h][i])`` there, and any
        other label records the pair as a tie.

        :returns: ``(chosens, rejects, equals)`` dicts keyed by head type.
            ``chosens[h]`` and ``rejects[h]`` are stacked tensors, or empty lists when
            the batch has no strict preference for ``h``. ``equals[h]`` is a list of
            ``(score0, score1)`` pairs.
        """
        chosens = {head_type: [] for head_type in head_types}
        rejects = {head_type: [] for head_type in head_types}
        equals = {head_type: [] for head_type in head_types}
        for i, extra in enumerate(extras):
            for head_type in head_types:
                label = extra[head_type] if head_type in extra else "C"
                score0 = scores0[head_type][i]
                score1 = scores1[head_type][i]
                if label == "A":
                    chosens[head_type].append(score0)
                    rejects[head_type].append(score1)
                elif label == "B":
                    chosens[head_type].append(score1)
                    rejects[head_type].append(score0)
                else:
                    equals[head_type].append((score0, score1))
        for head_type in head_types:
            if chosens[head_type]:
                chosens[head_type] = torch.stack(chosens[head_type])
                rejects[head_type] = torch.stack(rejects[head_type])
        return chosens, rejects, equals

    def fit(self, args, consumed_samples=0, num_update_steps_per_epoch=None) -> None:
        """
        Train the model for ``max_epochs`` using the provided dataloaders.

        :param args: Training arguments (typically ``strategy.args``), including
            logging/eval/save intervals and batch sizing.
        :type args: Any
        :param consumed_samples: Number of samples already consumed (for
            resuming training), defaults to ``0``.
        :type consumed_samples: int
        :param num_update_steps_per_epoch: Number of optimizer steps per epoch. If
            ``None``, it is derived as ``len(train_dataloader) // strategy.accumulated_gradient``
            (at least 1).
        :type num_update_steps_per_epoch: Optional[int]
        :returns: None
        :rtype: NoneType

        Notes:
            - If ``args.scale_for_train`` is truthy, the logged chosen/reject rewards are
              multiplied by ``0.07`` and rounded to 4 decimals. Losses are unaffected.
            - Logs to Weights & Biases or TensorBoard when configured.
        """
        accumulated_gradient = self.strategy.accumulated_gradient
        if num_update_steps_per_epoch is None:
            # A None value would otherwise break the ``start_epoch`` division below and the
            # ``global_step % args.eval_steps`` check when eval_steps == -1.
            num_update_steps_per_epoch = max(1, len(self.train_dataloader) // accumulated_gradient)

        # get eval and save steps
        if args.eval_steps == -1:
            args.eval_steps = num_update_steps_per_epoch  # evaluate once per epoch
        if args.save_steps == -1:
            args.save_steps = float("inf")  # do not save ckpt

        # Restore the micro-step counter and starting epoch when resuming.
        step = consumed_samples // args.train_batch_size * accumulated_gradient + 1
        start_epoch = consumed_samples // args.train_batch_size // num_update_steps_per_epoch
        consumed_samples = consumed_samples % (num_update_steps_per_epoch * args.train_batch_size)

        epoch_bar = tqdm(range(start_epoch, self.epochs), desc="Train epoch", disable=not self.strategy.is_rank_0())

        head_types = self.model.head_types
        # Exponential moving averages (decay 0.9) of per-head loss/accuracy and of the total loss.
        loss_mean = {head_type: 0 for head_type in head_types}
        acc_mean = {head_type: 0 for head_type in head_types}
        acc = {head_type: 0 for head_type in head_types}
        total_loss_mean = 0.0

        device = torch.cuda.current_device()
        scale_for_train = getattr(args, "scale_for_train", False)
        zero_log_suffixes = ("loss", "acc", "acc_mean", "loss_mean", "chosen_reward", "reject_reward")

        for epoch in range(start_epoch, self.epochs):
            if isinstance(self.train_dataloader.sampler, DistributedSampler):
                self.train_dataloader.sampler.set_epoch(
                    epoch, consumed_samples=0 if epoch > start_epoch else consumed_samples
                )

            step_bar = tqdm(
                range(len(self.train_dataloader)),
                desc="Train step of epoch %d" % epoch,
                disable=not self.strategy.is_rank_0(),
            )

            self.model.train()
            for data in self.train_dataloader:
                batch = self._batch_to_device(data, device)
                extras = batch[-1]
                scores0, scores1 = self.concatenated_forward(self.model, *batch[:-1])
                chosens, rejects, equals = self._split_by_label(scores0, scores1, extras, head_types)

                # Ties are pulled together with an L1 penalty on the score difference.
                equals_loss = {}
                for head_type in head_types:
                    if equals[head_type]:
                        pairs = torch.stack([torch.stack(pair) for pair in equals[head_type]])
                        equals_loss[head_type] = torch.abs(pairs[:, 0] - pairs[:, 1]).mean()
                    else:
                        equals_loss[head_type] = torch.tensor(0.0, device=device)

                # Pairwise loss only for heads with at least one strict preference in this batch.
                active_heads = [head_type for head_type in head_types if len(chosens[head_type]) > 0]
                head_loss = {
                    head_type: self.loss_fn(chosens[head_type], rejects[head_type], self.margin)
                    for head_type in active_heads
                }
                total_loss = sum(head_loss.values()) + self._TIE_LOSS_WEIGHT * sum(equals_loss.values())

                self.strategy.backward(total_loss, self.model, self.optimizer)
                self.strategy.optimizer_step(self.optimizer, self.model, self.scheduler)

                for head_type in active_heads:
                    acc[head_type] = (chosens[head_type] > rejects[head_type]).float().mean().item()
                    acc_mean[head_type] = acc_mean[head_type] * 0.9 + 0.1 * acc[head_type]
                    loss_mean[head_type] = loss_mean[head_type] * 0.9 + 0.1 * head_loss[head_type].item()
                total_loss_mean = total_loss_mean * 0.9 + 0.1 * total_loss.item()

                logs_dict = {
                    "loss": total_loss.item(),
                    "loss_mean": total_loss_mean,
                    "lr": self.scheduler.get_last_lr()[0],
                }
                # Every head always contributes the same keys in the same order, so the
                # per-key all_reduce calls below line up across ranks.
                for head_type in head_types:
                    if head_type in head_loss:
                        chosen_reward = chosens[head_type].mean().item()
                        reject_reward = rejects[head_type].mean().item()
                        if scale_for_train:
                            chosen_reward = round(chosen_reward * self._REWARD_LOG_SCALE, 4)
                            reject_reward = round(reject_reward * self._REWARD_LOG_SCALE, 4)
                        logs_dict[f"{head_type}_loss"] = head_loss[head_type].item()
                        logs_dict[f"{head_type}_acc"] = acc[head_type]
                        logs_dict[f"{head_type}_acc_mean"] = acc_mean[head_type]
                        logs_dict[f"{head_type}_loss_mean"] = loss_mean[head_type]
                        logs_dict[f"{head_type}_chosen_reward"] = chosen_reward
                        logs_dict[f"{head_type}_reject_reward"] = reject_reward
                    else:
                        for suffix in zero_log_suffixes:
                            logs_dict[f"{head_type}_{suffix}"] = 0.0

                # step bar
                for key in logs_dict:
                    if key.startswith("preference"):
                        logs_dict[key] = self.strategy.all_reduce(logs_dict[key], op="max")
                    else:
                        logs_dict[key] = self.strategy.all_reduce(logs_dict[key])
                step_bar.set_postfix(logs_dict)
                step_bar.update()

                # logs/checkpoints/evaluation once per optimizer step
                if step % accumulated_gradient == 0:
                    global_step = step // accumulated_gradient
                    client_states = {"consumed_samples": global_step * args.train_batch_size}
                    self.save_logs_and_checkpoints(args, global_step, step_bar, logs_dict, client_states)

                step += 1
            epoch_bar.update()

        if self._wandb is not None and self.strategy.is_rank_0():
            self._wandb.finish()
        if self._tensorboard is not None and self.strategy.is_rank_0():
            self._tensorboard.close()

    def save_logs_and_checkpoints(self, args, global_step, step_bar, logs_dict=None, client_states=None) -> None:
        """
        Log metrics and optionally run evaluation and checkpointing.

        :param args: Training arguments providing logging/eval/save intervals
            and checkpoint configurations.
        :type args: Any
        :param global_step: Current global optimization step.
        :type global_step: int
        :param step_bar: Progress bar for step-level updates (not used here).
        :type step_bar: tqdm.tqdm
        :param logs_dict: Metrics to log, defaults to an empty dict. ``fit`` has
            already all-reduced them across ranks before calling this method.
        :type logs_dict: Optional[Dict[str, float]]
        :param client_states: Extra state to persist with checkpoints (e.g.,
            consumed samples), defaults to an empty dict.
        :type client_states: Optional[Dict[str, Any]]
        :returns: None
        :rtype: NoneType
        """
        # Fresh dicts per call instead of shared mutable defaults.
        logs_dict = {} if logs_dict is None else logs_dict
        client_states = {} if client_states is None else client_states

        if global_step % args.logging_steps == 0:
            # wandb
            if self._wandb is not None and self.strategy.is_rank_0():
                logs = {"train/%s" % k: v for k, v in {**logs_dict, "global_step": global_step}.items()}
                self._wandb.log(logs)
            # TensorBoard
            elif self._tensorboard is not None and self.strategy.is_rank_0():
                for k, v in logs_dict.items():
                    self._tensorboard.add_scalar(f"train/{k}", v, global_step)

        # eval (skip empty dataloaders to avoid zero division in evaluation)
        if global_step % args.eval_steps == 0:
            if self.eval_dataloader and len(self.eval_dataloader) > 0:
                self.evaluate(args, self.eval_dataloader, global_step)

        # save ckpt
        # TODO: save best model on dev, use loss/perplexity on whole dev dataset as metric
        if global_step % args.save_steps == 0:
            tag = f"global_step{global_step}"
            self.strategy.save_ckpt(
                self.model, args.ckpt_path, tag, args.max_ckpt_num, args.max_ckpt_mem, client_states
            )

    def evaluate(self, args, eval_dataloader, steps=0) -> None:
        """
        Evaluate the model on the provided dataloader.

        Per-sample scores from all ranks are written (by rank 0) as JSONL to
        ``<strategy.args.save_path>/evals/eval_scores_{steps}.jsonl``. Per-head
        accuracy and mean chosen/reject rewards are computed over strictly
        preferred pairs, summed across ranks, and logged to Weights & Biases or
        TensorBoard. Ties are ignored for these metrics.

        :param args: Present for API compatibility with callers.
        :type args: Any
        :param eval_dataloader: Dataloader for evaluation samples.
        :type eval_dataloader: torch.utils.data.DataLoader
        :param steps: Global step id, used in the output file name and as the log step.
        :type steps: int
        :returns: None
        :rtype: NoneType
        """
        step_bar = tqdm(
            range(len(eval_dataloader)),
            desc="Eval stage of steps %d" % steps,
            disable=not self.strategy.is_rank_0(),
        )
        self.model.eval()

        # Path is computed on every rank; only rank 0 touches the file system.
        output_file = os.path.join(self.strategy.args.save_path, "evals", f"eval_scores_{steps}.jsonl")
        if self.strategy.is_rank_0():
            self.strategy.print(f"Start Evaluation at global step {steps}...")
            os.makedirs(os.path.dirname(output_file), exist_ok=True)
            with open(output_file, "w") as f:
                f.write("")  # create/truncate the file

        head_types = self.model.head_types
        # Metrics accumulators (summed across ranks at the end)
        eval_metrics = {"count": 0}
        for head_type in head_types:
            eval_metrics[f"{head_type}_correct"] = 0.0
            eval_metrics[f"{head_type}_count"] = 0
            eval_metrics[f"{head_type}_chosen_reward"] = 0.0
            eval_metrics[f"{head_type}_reject_reward"] = 0.0

        device = torch.cuda.current_device()
        world_size = dist.get_world_size()

        with torch.no_grad():
            for data in eval_dataloader:
                batch = self._batch_to_device(data, device)
                extras = batch[-1]
                scores0, scores1 = self.concatenated_forward(self.model, *batch[:-1])

                # Ties carry no accuracy signal, so they are discarded here.
                chosens, rejects, _ = self._split_by_label(scores0, scores1, extras, head_types)
                eval_metrics["count"] += len(extras)
                for head_type in head_types:
                    if len(chosens[head_type]) == 0:
                        continue
                    eval_metrics[f"{head_type}_correct"] += (
                        (chosens[head_type] > rejects[head_type]).float().sum().item()
                    )
                    eval_metrics[f"{head_type}_count"] += len(chosens[head_type])
                    eval_metrics[f"{head_type}_chosen_reward"] += chosens[head_type].sum().item()
                    eval_metrics[f"{head_type}_reject_reward"] += rejects[head_type].sum().item()

                # Gather per-sample scores from all ranks (concatenated in rank order).
                # NOTE(review): dist.all_gather needs equally shaped tensors on every rank,
                # i.e. all ranks must see the same batch size at each eval step.
                gathered_scores0 = {}
                gathered_scores1 = {}
                for head_type, local_scores0 in scores0.items():
                    parts0 = [torch.zeros_like(local_scores0) for _ in range(world_size)]
                    parts1 = [torch.zeros_like(scores1[head_type]) for _ in range(world_size)]
                    dist.all_gather(parts0, local_scores0)
                    dist.all_gather(parts1, scores1[head_type])
                    gathered_scores0[head_type] = torch.cat(parts0, dim=0)
                    gathered_scores1[head_type] = torch.cat(parts1, dim=0)
                all_extras = all_gather_and_flatten(extras)

                # Append this step's scores to the JSONL file (only on rank 0).
                if self.strategy.is_rank_0():
                    with open(output_file, "a") as f:
                        for i, extra in enumerate(all_extras):
                            results = {
                                "info": extra,
                                "scores0": {h: gathered_scores0[h][i].item() for h in gathered_scores0},
                                "scores1": {h: gathered_scores1[h][i].item() for h in gathered_scores1},
                            }
                            f.write(json.dumps(results) + "\n")
                step_bar.update()

        # --- Aggregate and Log Metrics ---
        reduced_metrics = all_reduce_dict(eval_metrics, op="sum")
        logs_dict = {}
        for head_type in head_types:
            count = reduced_metrics[f"{head_type}_count"]
            if count > 0:
                logs_dict[f"eval/{head_type}_acc"] = reduced_metrics[f"{head_type}_correct"] / count
                chosen_reward = reduced_metrics[f"{head_type}_chosen_reward"] / count
                reject_reward = reduced_metrics[f"{head_type}_reject_reward"] / count
                logs_dict[f"eval/{head_type}_chosen_reward_mean"] = round(chosen_reward, 4)
                logs_dict[f"eval/{head_type}_reject_reward_mean"] = round(reject_reward, 4)

        if self.strategy.is_rank_0():
            self.strategy.print(f"Evaluation scores written to {output_file}")
            self.strategy.print(f"Eval metrics: {logs_dict}")
            if self._wandb is not None:
                logs_dict["eval/global_step"] = steps
                self._wandb.log(logs_dict)
            elif self._tensorboard is not None:
                for k, v in logs_dict.items():
                    if k != "eval/global_step":
                        self._tensorboard.add_scalar(k, v, steps)

        self.model.train()  # reset model state

    def concatenated_forward(
        self,
        model,
        input0_ids,
        input0_mask,
        input1_ids,
        input1_mask,
        input0_input_features,
        input0_feature_attention_mask,
        input1_input_features,
        input1_feature_attention_mask,
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
        """
        Score both inputs with a single forward pass over their concatenation.

        :param model: Model to evaluate. It must return a dict of per-head scores
            whose first dimension is the (doubled) batch.
        :type model: nn.Module
        :param input0_ids: Token ids for the first sequences.
        :type input0_ids: torch.Tensor
        :param input0_mask: Attention mask for the first sequences.
        :type input0_mask: torch.Tensor
        :param input1_ids: Token ids for the second sequences.
        :type input1_ids: torch.Tensor
        :param input1_mask: Attention mask for the second sequences.
        :type input1_mask: torch.Tensor
        :param input0_input_features: Audio features for the first sequences.
        :type input0_input_features: torch.Tensor
        :param input0_feature_attention_mask: Attention mask for the first audio features.
        :type input0_feature_attention_mask: torch.Tensor
        :param input1_input_features: Audio features for the second sequences.
        :type input1_input_features: torch.Tensor
        :param input1_feature_attention_mask: Attention mask for the second audio features.
        :type input1_feature_attention_mask: torch.Tensor
        :returns: Tuple of dicts ``(scores0, scores1)``: the first ``N`` rows of each
            head's scores belong to input0, the rest to input1.
        :rtype: Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]
        """
        input_ids, att_masks = self.concatenated_inputs(input0_ids, input0_mask, input1_ids, input1_mask)
        # Only the input concatenation is excluded from autograd (inputs are data).
        with torch.no_grad():
            input_features = torch.cat((input0_input_features, input1_input_features), dim=0)
            feature_attention_mask = torch.cat((input0_feature_attention_mask, input1_feature_attention_mask), dim=0)
        # The forward pass must stay outside no_grad; otherwise the training loss in
        # ``fit`` would have no gradient path back to the model parameters.
        scores = model(
            input_ids,
            attention_mask=att_masks,
            input_features=input_features,
            feature_attention_mask=feature_attention_mask,
        )
        batch_size = input0_ids.shape[0]
        scores0 = {head_type: score[:batch_size] for head_type, score in scores.items()}
        scores1 = {head_type: score[batch_size:] for head_type, score in scores.items()}
        return scores0, scores1

    def concatenated_inputs(self, input0_ids, input0_mask, input1_ids,
                            input1_mask) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Concatenate two inputs into a single batch.

        :param input0_ids: Token ids for the first sequences, shape ``(N, L0)``.
        :type input0_ids: torch.Tensor
        :param input0_mask: Attention mask for the first sequences, shape ``(N, L0)``.
        :type input0_mask: torch.Tensor
        :param input1_ids: Token ids for the second sequences, shape ``(N, L1)``.
        :type input1_ids: torch.Tensor
        :param input1_mask: Attention mask for the second sequences, shape ``(N, L1)``.
        :type input1_mask: torch.Tensor
        :returns: Tuple ``(input_ids, att_masks)``. Ids are padded with the tokenizer's
            ``pad_token_id`` and masks with ``0`` to a common max length, then
            concatenated along the batch dimension to shape ``(2N, Lmax)``.
        :rtype: Tuple[torch.Tensor, torch.Tensor]
        """
        pad_id = self.tokenizer.pad_token_id
        max_length = max(input0_ids.shape[1], input1_ids.shape[1])
        inputs_ids = torch.cat(
            (
                pad_to_length(input0_ids, max_length, pad_id),
                pad_to_length(input1_ids, max_length, pad_id),
            ),
            dim=0,
        )
        max_length = max(input0_mask.shape[1], input1_mask.shape[1])
        att_masks = torch.cat(
            (pad_to_length(input0_mask, max_length, 0), pad_to_length(input1_mask, max_length, 0)),
            dim=0,
        )
        return inputs_ids, att_masks