# Source code for lightrft.trainer.grm_trainer_vl
"""
Trainer for generative reward models (vision-language capable).
This module contains a trainer that optimizes a generative reward model using
next-token prediction loss (``GPTLMLoss``). It integrates with the Strategy
abstraction for distributed training, gradient accumulation, checkpointing, and
logging via Weights & Biases or TensorBoard.
"""
import json
import os
from typing import Any, Dict, Optional

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.optim import Optimizer
from tqdm import tqdm

from lightrft.datasets.utils import extract_answer
from lightrft.models import GPTLMLoss
from lightrft.utils import DistributedSampler, all_gather_and_flatten, all_reduce_dict
class GRMTrainerVL:
    """
    Trainer for generative reward models (vision-language capable).

    Optimizes a generative reward model with next-token prediction loss
    (``GPTLMLoss``), delegating distributed setup, gradient accumulation,
    checkpointing, and logging to the Strategy abstraction. Metrics go to
    Weights & Biases when configured, otherwise to TensorBoard.

    :param model: The model to be trained. Expected to return logits for
        next-token prediction when called with token ids and optional image or
        video features.
    :type model: torch.nn.Module
    :param strategy: The training strategy to apply, handling distributed
        setup, gradient accumulation, logging and checkpointing.
    :type strategy: Strategy
    :param optim: Optimizer to use during training.
    :type optim: torch.optim.Optimizer
    :param train_dataloader: Dataloader for the training dataset.
    :type train_dataloader: torch.utils.data.DataLoader
    :param scheduler: Learning rate scheduler for dynamic adjustments during
        training.
    :type scheduler: torch.optim.lr_scheduler._LRScheduler
    :param tokenizer: Tokenizer for input data (used when padding/processing
        sequences as needed by some helpers).
    :type tokenizer: Callable
    :param eval_dataloader: Dataloader for the evaluation dataset.
    :type eval_dataloader: Optional[torch.utils.data.DataLoader]
    :param max_epochs: Maximum number of training epochs.
    :type max_epochs: int
    :param loss: The loss function selector. Only ``"GPTLMLoss"`` is supported.
    :type loss: str
    :raises ValueError: If ``loss`` is not a recognized loss type.
    """

    def __init__(
        self,
        model: nn.Module,
        strategy,
        optim: Optimizer,
        train_dataloader,
        scheduler,
        tokenizer,
        eval_dataloader=None,
        max_epochs: int = 2,
        loss: str = "GPTLMLoss",
    ) -> None:
        self.strategy = strategy
        self.epochs = max_epochs
        self.model = model
        self.train_dataloader = train_dataloader
        self.eval_dataloader = eval_dataloader
        self.scheduler = scheduler
        self.optimizer = optim
        self.tokenizer = tokenizer
        self.args = strategy.args

        if loss == 'GPTLMLoss':
            self.loss_fn = GPTLMLoss()
            self.strategy.print("GPT Language Model Loss")
        else:
            raise ValueError(f"invalid loss type: {loss}")

        # wandb/tensorboard setting (only the rank-0 process logs)
        self._wandb = None
        self._tensorboard = None
        if self.strategy.args.use_wandb and self.strategy.is_rank_0():
            import wandb

            self._wandb = wandb
            if not wandb.api.api_key:
                # ``use_wandb`` doubles as the API key when no cached login exists
                wandb.login(key=strategy.args.use_wandb)
            wandb.init(
                entity=strategy.args.wandb_org,
                project=strategy.args.wandb_project,
                group=strategy.args.wandb_group,
                name=strategy.args.wandb_run_name,
                config=strategy.args.__dict__,
                reinit=True,
            )
            # Use explicit step metrics so train/eval curves are indexed by
            # global step rather than wandb's internal step counter.
            wandb.define_metric("train/global_step")
            wandb.define_metric("train/*", step_metric="train/global_step", step_sync=True)
            wandb.define_metric("eval/global_step")
            wandb.define_metric("eval/*", step_metric="eval/global_step", step_sync=True)

        # Initialize TensorBoard writer only if wandb is not available
        if self.strategy.args.use_tensorboard and self._wandb is None and self.strategy.is_rank_0():
            from torch.utils.tensorboard import SummaryWriter

            os.makedirs(self.strategy.args.use_tensorboard, exist_ok=True)
            log_dir = os.path.join(self.strategy.args.use_tensorboard, strategy.args.wandb_run_name)
            self._tensorboard = SummaryWriter(log_dir=log_dir)

    def fit(self, args, consumed_samples=0, num_update_steps_per_epoch=None) -> None:
        """
        Main training loop for generative reward model.

        :param args: Training arguments containing hyperparameters and configurations.
        :type args: argparse.Namespace
        :param consumed_samples: Number of samples already consumed (for resuming).
        :type consumed_samples: int
        :param num_update_steps_per_epoch: Number of update steps per epoch. Used for
            eval scheduling, determining epoch boundaries, and resuming training
            from checkpoints. Required — the loop cannot schedule without it.
        :type num_update_steps_per_epoch: int
        :raises ValueError: If ``num_update_steps_per_epoch`` is ``None``.
        """
        # Fail fast with a clear message instead of an opaque TypeError from
        # the integer arithmetic below, which uses this value unconditionally.
        if num_update_steps_per_epoch is None:
            raise ValueError("num_update_steps_per_epoch is required for eval/save scheduling and resuming")

        # get eval and save steps
        if args.eval_steps == -1:
            args.eval_steps = num_update_steps_per_epoch  # Evaluate once per epoch
        if args.save_steps == -1:
            args.save_steps = float("inf")  # do not save ckpt

        # Restore step and start_epoch from the number of consumed samples
        step = consumed_samples // args.train_batch_size * self.strategy.accumulated_gradient + 1
        start_epoch = consumed_samples // args.train_batch_size // num_update_steps_per_epoch
        consumed_samples = consumed_samples % (num_update_steps_per_epoch * args.train_batch_size)

        epoch_bar = tqdm(range(start_epoch, self.epochs), desc="Train epoch", disable=not self.strategy.is_rank_0())
        loss_mean = 0.0
        for epoch in range(start_epoch, self.epochs):
            if isinstance(self.train_dataloader.sampler, DistributedSampler):
                # Only skip already-consumed samples when resuming mid-epoch
                self.train_dataloader.sampler.set_epoch(
                    epoch, consumed_samples=0 if epoch > start_epoch else consumed_samples
                )
            step_bar = tqdm(
                range(len(self.train_dataloader)),
                desc="Train step of epoch %d" % epoch,
                disable=not self.strategy.is_rank_0(),
            )
            self.model.train()
            for data in self.train_dataloader:
                ids, mask, pixel_values, image_grid_thws, pixel_values_videos, video_grid_thws, labels, extras = data
                device = torch.cuda.current_device()
                # Collate adds a singleton dim; squeeze it before moving to GPU
                ids = ids.squeeze(1).to(device)
                mask = mask.squeeze(1).to(device)
                labels = labels.squeeze(1).to(device)
                # Vision inputs are optional (text-only batches pass None)
                if pixel_values is not None:
                    pixel_values = pixel_values.to(device)
                    image_grid_thws = image_grid_thws.to(device)
                if pixel_values_videos is not None:
                    pixel_values_videos = pixel_values_videos.to(device)
                    video_grid_thws = video_grid_thws.to(device)

                logits = self.model(
                    ids,
                    attention_mask=mask,
                    pixel_values=pixel_values,
                    image_grid_thw=image_grid_thws,
                    pixel_values_videos=pixel_values_videos,
                    video_grid_thw=video_grid_thws,
                    return_outputs=False,
                )

                gpt_loss = self.loss_fn(logits, labels)
                total_loss = gpt_loss
                self.strategy.backward(total_loss, self.model, self.optimizer)
                self.strategy.optimizer_step(self.optimizer, self.model, self.scheduler)

                # Exponential moving average of the loss for smoother logging
                loss_mean = loss_mean * 0.9 + 0.1 * total_loss.item()
                logs_dict = {
                    "loss": total_loss.item(),
                    "loss_mean": loss_mean,
                    "lr": self.scheduler.get_last_lr()[0],
                }
                # Average metrics across ranks. A single dict-level reduction
                # suffices; previously each key was also reduced individually
                # beforehand, issuing redundant collective ops.
                logs_dict = all_reduce_dict(logs_dict, op="mean")

                step_bar.set_postfix(logs_dict)
                step_bar.update()

                # logs/checkpoints/evaluation on optimizer-step boundaries only
                if step % self.strategy.accumulated_gradient == 0:
                    global_step = step // self.strategy.accumulated_gradient
                    client_states = {"consumed_samples": global_step * args.train_batch_size}
                    self.save_logs_and_checkpoints(args, global_step, step_bar, logs_dict, client_states)

                step += 1
            epoch_bar.update()

        if self._wandb is not None and self.strategy.is_rank_0():
            self._wandb.finish()
        if self._tensorboard is not None and self.strategy.is_rank_0():
            self._tensorboard.close()

    # logs/checkpoints/evaluate
    def save_logs_and_checkpoints(
        self,
        args,
        global_step: int,
        step_bar,
        logs_dict: Optional[Dict[str, float]] = None,
        client_states: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Log metrics and optionally run evaluation and checkpointing.

        :param args: Training arguments providing logging/eval/save intervals
            and checkpoint configurations.
        :type args: Any
        :param global_step: Current global optimization step.
        :type global_step: int
        :param step_bar: Progress bar for step-level updates.
        :type step_bar: tqdm.tqdm
        :param logs_dict: Dictionary of metrics to log.
        :type logs_dict: Optional[Dict[str, float]]
        :param client_states: Extra state to persist with checkpoints (e.g.,
            consumed samples).
        :type client_states: Optional[Dict[str, Any]]
        :returns: None
        :rtype: NoneType
        """
        # Avoid mutable default arguments (shared across calls)
        logs_dict = {} if logs_dict is None else logs_dict
        client_states = {} if client_states is None else client_states

        if global_step % args.logging_steps == 0:
            # wandb
            if self._wandb is not None and self.strategy.is_rank_0():
                logs = {"train/%s" % k: v for k, v in {**logs_dict, "global_step": global_step}.items()}
                self._wandb.log(logs)
            # TensorBoard
            elif self._tensorboard is not None and self.strategy.is_rank_0():
                for k, v in logs_dict.items():
                    self._tensorboard.add_scalar(f"train/{k}", v, global_step)

        # eval
        if global_step % args.eval_steps == 0:
            # do eval when len(dataloader) > 0, avoid zero division in eval.
            if self.eval_dataloader and len(self.eval_dataloader) > 0:
                # Pass args first to match evaluate signature (args, dataloader, steps)
                self.evaluate(args, self.eval_dataloader, global_step)

        # save ckpt
        # TODO: save best model on dev, use loss/perplexity on whole dev dataset as metric
        if global_step % args.save_steps == 0:
            tag = f"global_step{global_step}"
            self.strategy.save_ckpt(
                self.model, args.ckpt_path, tag, args.max_ckpt_num, args.max_ckpt_mem, client_states
            )

    def evaluate(self, args, eval_dataloader, steps: int = 0) -> None:
        """
        Evaluate the model by generating text responses and saving them to JSON.

        Generates a response for every evaluation sample, extracts the
        assistant's reply and the predicted answer, gathers records from all
        ranks, and has rank 0 write the combined results to disk and compute
        accuracy against the ground-truth answers. Results are also logged to
        Weights & Biases or TensorBoard if configured.

        :param args: Training arguments containing generation configurations.
        :type args: Any
        :param eval_dataloader: Dataloader for evaluation samples.
        :type eval_dataloader: torch.utils.data.DataLoader
        :param steps: Global step id for logging and naming the output file.
        :type steps: int
        :returns: None
        :rtype: NoneType
        """
        step_bar = tqdm(
            range(len(eval_dataloader)),
            desc="Eval stage of steps %d" % steps,
            disable=not self.strategy.is_rank_0(),
        )
        self.model.eval()

        # Create JSON file path (only on rank 0)
        if self.strategy.is_rank_0():
            output_file = f"eval_result_{steps}.json"
            output_file = os.path.join(self.strategy.args.save_path, "eval", output_file)
            os.makedirs(os.path.dirname(output_file), exist_ok=True)
            all_eval_records = []

        with torch.no_grad():
            for data in eval_dataloader:
                ids, mask, pixel_values, image_grid_thws, pixel_values_videos, video_grid_thws, labels, extras = data
                device = torch.cuda.current_device()
                ids = ids.squeeze(1).to(device)
                mask = mask.squeeze(1).to(device)
                if pixel_values is not None:
                    pixel_values = pixel_values.to(device)
                    image_grid_thws = image_grid_thws.to(device)
                if pixel_values_videos is not None:
                    pixel_values_videos = pixel_values_videos.to(device)
                    video_grid_thws = video_grid_thws.to(device)

                # Generation
                # Unwrap the model if it is wrapped in a DistributedDataParallel or similar wrapper
                unwrapped_model = self.model.module if hasattr(self.model, "module") else self.model
                generated_ids = unwrapped_model.generate(
                    input_ids=ids,
                    attention_mask=mask,
                    pixel_values=pixel_values,
                    image_grid_thw=image_grid_thws,
                    pixel_values_videos=pixel_values_videos,
                    video_grid_thw=video_grid_thws,
                    max_new_tokens=args.generate_max_len,
                    synced_gpus=True,  # Use synced_gpus=True for Zero-3 compatibility
                    eos_token_id=self.tokenizer.eos_token_id,
                    pad_token_id=self.tokenizer.pad_token_id,
                )
                generated_text = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

                responses_text = []
                predicted_answers = []
                for gen_text in generated_text:
                    # Extract only the assistant's response part
                    # Qwen/Llama-3 templates use <|im_start|>assistant\n or assistant\n
                    if "<|im_start|>assistant" in gen_text:
                        response = gen_text.split("<|im_start|>assistant")[-1]
                    elif "assistant\n" in gen_text:
                        response = gen_text.split("assistant\n")[-1]
                    else:
                        response = gen_text
                    responses_text.append(response)
                    # Extract predicted answer from the response
                    predicted_answers.append(extract_answer(response))

                # Construct records locally and gather them across all ranks
                local_records = []
                for gen_text, resp_text, pred_ans, extra in zip(
                    generated_text, responses_text, predicted_answers, extras
                ):
                    local_records.append({
                        "info": extra,
                        "generated_text": gen_text,
                        "response_text": resp_text,
                        "predicted_answer": pred_ans,
                        "gt_answer": extract_answer(extra["response"])
                    })
                gathered_records = all_gather_and_flatten(local_records)
                if self.strategy.is_rank_0():
                    all_eval_records.extend(gathered_records)
                step_bar.update()

        if self.strategy.is_rank_0():
            # Write JSON file; explicit UTF-8 so ensure_ascii=False cannot fail
            # on platforms whose default encoding is not UTF-8.
            with open(output_file, 'w', encoding='utf-8') as f:
                json.dump(all_eval_records, f, indent=4, ensure_ascii=False)

            # Calculate accuracy
            correct = 0
            total = 0
            for r in all_eval_records:
                if r["gt_answer"] == r["predicted_answer"]:
                    correct += 1
                elif r["predicted_answer"] is None:
                    self.strategy.print(f"Could not extract answer from generated text: {r['generated_text']}")
                total += 1
            accuracy = correct / total if total > 0 else 0
            self.strategy.print(f"Step {steps} Evaluation Accuracy: {accuracy:.4f} ({correct}/{total})")

            # wandb/tensorboard logging
            if self._wandb is not None:
                columns = ["info", "generated_text", "response_text", "predicted_answer", "gt_answer"]
                # Log a subset of samples
                data = [[
                    str(r["info"]), r["generated_text"], r["response_text"], r["predicted_answer"], r["gt_answer"]
                ] for r in all_eval_records[:10]]
                self._wandb.log({
                    "eval/samples": self._wandb.Table(columns=columns, data=data),
                    "eval/accuracy": accuracy,
                    "eval/global_step": steps
                })
            if self._tensorboard is not None:
                self._tensorboard.add_scalar("eval/accuracy", accuracy, steps)
                for i, r in enumerate(all_eval_records[:5]):
                    text = (
                        f"Info: {r['info']}\n\nGenerated: {r['generated_text']}\n\n"
                        f"Response: {r['response_text']}\n\n"
                        f"Predicted Answer: {r['predicted_answer']}\n\n"
                        f"GT Answer: {r['gt_answer']}"
                    )
                    self._tensorboard.add_text(f"eval/sample_{i}", text, steps)
            self.strategy.print(f"Evaluation generations written to {output_file}")

        self.model.train()  # reset model state