Source code for lightrft.trainer.srm_trainer_vl
"""
Trainer utilities for scalar reward models (vision-language capable).
This module provides a trainer that supports pairwise preference training with
different loss functions, including Bradley-Terry (BT) loss and Human Preference
Score (HPS) Loss. It integrates with the project's Strategy abstraction for distributed
training, gradient accumulation, and checkpointing, while optionally logging to
Weights & Biases or TensorBoard.
"""
import os
import json
from tqdm import tqdm
from typing import Tuple, Dict
import torch
import torch.nn as nn
from torch.optim import Optimizer
import torch.distributed as dist
from lightrft.models import LogExpLoss, LogSigmoidLoss, HPSLoss, pad_to_length
from lightrft.utils import DistributedSampler, all_gather_and_flatten, all_reduce_dict
class SRMTrainerVL:
    """
    Trainer for scalar vision-language reward models.

    Each batch holds pairs of candidates (``input0`` / ``input1``). For every
    head in ``model.head_types`` the per-sample label found in ``extras``
    decides which candidate is preferred: ``"A"`` means ``input0``, ``"B"``
    means ``input1``; any other value (or a missing key) is treated as a tie.

    :param model: The model to be trained; expected to return a dict of head
        scores for each head type when called with token ids and optional image
        or video features. Must expose ``head_types``.
    :type model: torch.nn.Module
    :param strategy: Training strategy that manages distributed operations,
        gradient accumulation, checkpointing, logging helpers, and args.
    :type strategy: Strategy
    :param optim: Optimizer used for parameter updates.
    :type optim: torch.optim.Optimizer
    :param train_dataloader: Dataloader providing pairwise (A/B/Equal) batches.
    :type train_dataloader: torch.utils.data.DataLoader
    :param scheduler: Learning rate scheduler.
    :type scheduler: torch.optim.lr_scheduler._LRScheduler
    :param tokenizer: Tokenizer used to obtain pad token id for sequence
        padding when concatenating inputs.
    :type tokenizer: Callable
    :param eval_dataloader: Optional dataloader used for evaluation.
    :type eval_dataloader: Optional[torch.utils.data.DataLoader]
    :param max_epochs: Maximum number of training epochs.
    :type max_epochs: int
    :param loss: Loss function to use. Choices are ``'sigmoid'``
        (LogSigmoidLoss), ``'logexp'`` (LogExpLoss), and ``'hps'`` (HPSLoss).
    :type loss: str
    :param margin: Margin value forwarded as the third argument to the
        selected loss function.
    :type margin: float
    :raises ValueError: If ``loss`` is not one of the supported names.
    """

    # Weight of the |score0 - score1| penalty applied to tied samples.
    _EQUAL_LOSS_WEIGHT = 0.01
    # Factor applied to the *logged* rewards when ``args.scale_for_train`` is set.
    # NOTE(review): presumably mirrors a logit scale used by the HPS head -- confirm.
    _SCALED_REWARD_FACTOR = 0.07

    def __init__(
        self,
        model: nn.Module,
        strategy,
        optim: Optimizer,
        train_dataloader,
        scheduler,
        tokenizer,
        eval_dataloader=None,
        max_epochs: int = 2,
        loss: str = "sigmoid",
        margin: float = 0.1,
    ) -> None:
        self.strategy = strategy
        self.epochs = max_epochs
        self.model = model
        self.train_dataloader = train_dataloader
        self.eval_dataloader = eval_dataloader
        self.scheduler = scheduler
        self.optimizer = optim
        self.tokenizer = tokenizer
        self.margin = margin
        self.args = strategy.args

        # Map the public loss name to (loss class, banner printed at start-up).
        loss_registry = {
            "sigmoid": (LogSigmoidLoss, "LogSigmoid Loss"),
            "logexp": (LogExpLoss, "LogExp Loss"),
            "hps": (HPSLoss, "HPS Loss"),
        }
        if loss not in loss_registry:
            raise ValueError(f"invalid loss type: {loss}")
        loss_cls, loss_banner = loss_registry[loss]
        self.loss = loss
        self.loss_fn = loss_cls()
        self.strategy.print(loss_banner)

        # wandb/tensorboard setting
        self._wandb = None
        self._tensorboard = None
        if self.strategy.args.use_wandb and self.strategy.is_rank_0():
            import wandb

            self._wandb = wandb
            if not wandb.api.api_key:
                # ``use_wandb`` doubles as the API key when no key is configured.
                wandb.login(key=strategy.args.use_wandb)
            wandb.init(
                entity=strategy.args.wandb_org,
                project=strategy.args.wandb_project,
                group=strategy.args.wandb_group,
                name=strategy.args.wandb_run_name,
                config=strategy.args.__dict__,
                reinit=True,
            )
            wandb.define_metric("train/global_step")
            wandb.define_metric("train/*", step_metric="train/global_step", step_sync=True)
            wandb.define_metric("eval/global_step")
            wandb.define_metric("eval/*", step_metric="eval/global_step", step_sync=True)

        # Initialize TensorBoard writer if wandb is not available
        if self.strategy.args.use_tensorboard and self._wandb is None and self.strategy.is_rank_0():
            from torch.utils.tensorboard import SummaryWriter

            os.makedirs(self.strategy.args.use_tensorboard, exist_ok=True)
            log_dir = os.path.join(self.strategy.args.use_tensorboard, strategy.args.wandb_run_name)
            self._tensorboard = SummaryWriter(log_dir=log_dir)

    @staticmethod
    def _batch_to_device(data):
        """
        Unpack one dataloader batch and move its tensors to the current CUDA device.

        :param data: 13-tuple produced by the dataloader: ids/masks for both
            candidates, image pixels/grids for both, video pixels/grids for
            both, and the per-sample ``extras`` list.
        :returns: ``(inputs, extras, device)`` where ``inputs`` is the 12-tuple
            of tensors in the positional order expected by
            :meth:`concatenated_forward` (after ``model``).
        """
        (
            ids0,
            mask0,
            ids1,
            mask1,
            img0,
            img_thw0,
            img1,
            img_thw1,
            vid0,
            vid_thw0,
            vid1,
            vid_thw1,
            extras,
        ) = data
        device = torch.cuda.current_device()

        ids0 = ids0.squeeze(1).to(device)
        mask0 = mask0.squeeze(1).to(device)
        ids1 = ids1.squeeze(1).to(device)
        mask1 = mask1.squeeze(1).to(device)

        # Presence of the first candidate's pixels gates the whole modality.
        if img0 is not None:
            img0 = img0.to(device)
            img_thw0 = img_thw0.to(device)
            img1 = img1.to(device)
            img_thw1 = img_thw1.to(device)
        if vid0 is not None:
            vid0 = vid0.to(device)
            vid_thw0 = vid_thw0.to(device)
            vid1 = vid1.to(device)
            vid_thw1 = vid_thw1.to(device)

        inputs = (
            ids0,
            mask0,
            ids1,
            mask1,
            img0,
            img_thw0,
            img1,
            img_thw1,
            vid0,
            vid_thw0,
            vid1,
            vid_thw1,
        )
        return inputs, extras, device

    @staticmethod
    def _split_by_label(scores0, scores1, extras, head_types):
        """
        Partition per-sample scores into chosen / rejected / tied groups per head.

        :param scores0: Head name -> scores of the first candidate.
        :param scores1: Head name -> scores of the second candidate.
        :param extras: Per-sample metadata; ``extras[i][head]`` is the label.
        :param head_types: Heads to process.
        :returns: ``(chosens, rejects, equals)``. ``chosens[head]`` and
            ``rejects[head]`` are stacked tensors, or empty lists when no
            sample carries an ``"A"``/``"B"`` label for that head.
            ``equals[head]`` is a list of ``[score0, score1]`` pairs.
        """
        chosens = {head: [] for head in head_types}
        rejects = {head: [] for head in head_types}
        equals = {head: [] for head in head_types}

        for idx, extra in enumerate(extras):
            for head in head_types:
                # A missing label is handled like an explicit tie.
                label = extra[head] if head in extra else "C"
                first, second = scores0[head][idx], scores1[head][idx]
                if label == "A":
                    chosens[head].append(first)
                    rejects[head].append(second)
                elif label == "B":
                    chosens[head].append(second)
                    rejects[head].append(first)
                else:
                    equals[head].append([first, second])

        for head in head_types:
            if len(chosens[head]) > 0:
                chosens[head] = torch.stack(chosens[head])
                rejects[head] = torch.stack(rejects[head])
        return chosens, rejects, equals

    def fit(self, args, consumed_samples=0, num_update_steps_per_epoch=None) -> None:
        """
        Train the model for ``max_epochs`` using the provided dataloaders.

        :param args: Training arguments (typically ``strategy.args``) including
            logging/eval/save intervals and batch sizing. ``eval_steps`` and
            ``save_steps`` are rewritten in place when they equal ``-1``.
        :type args: Any
        :param consumed_samples: Number of samples already consumed (for
            resuming training), defaults to ``0``.
        :type consumed_samples: int
        :param num_update_steps_per_epoch: Number of optimizer steps per epoch.
            If ``None``, it is derived from the train dataloader length and
            ``strategy.accumulated_gradient``.
        :type num_update_steps_per_epoch: Optional[int]
        :returns: None
        :rtype: NoneType

        Notes:
            - When ``args.scale_for_train`` is truthy, the logged chosen/reject
              rewards are multiplied by ``0.07`` and rounded to 4 decimals.
            - Logs to Weights & Biases or TensorBoard when configured.
        """
        if num_update_steps_per_epoch is None:
            # The default used to crash in the divisions below; fall back to the
            # number of optimizer steps one pass over the dataloader produces.
            num_update_steps_per_epoch = max(
                1, len(self.train_dataloader) // self.strategy.accumulated_gradient
            )

        # get eval and save steps
        if args.eval_steps == -1:
            args.eval_steps = num_update_steps_per_epoch  # Evaluate once per epoch
        if args.save_steps == -1:
            args.save_steps = float("inf")  # do not save ckpt

        # Restore step and start_epoch
        step = consumed_samples // args.train_batch_size * self.strategy.accumulated_gradient + 1
        start_epoch = consumed_samples // args.train_batch_size // num_update_steps_per_epoch
        consumed_samples = consumed_samples % (num_update_steps_per_epoch * args.train_batch_size)

        epoch_bar = tqdm(range(start_epoch, self.epochs), desc="Train epoch", disable=not self.strategy.is_rank_0())

        head_types = self.model.head_types
        # Exponential moving averages (decay 0.9) shown in the progress bar.
        loss_mean = {head: 0 for head in head_types}
        acc_mean = {head: 0 for head in head_types}
        acc = {head: 0 for head in head_types}
        total_loss_mean = 0.0

        for epoch in range(start_epoch, self.epochs):
            if isinstance(self.train_dataloader.sampler, DistributedSampler):
                # Only the first resumed epoch skips already-consumed samples.
                self.train_dataloader.sampler.set_epoch(
                    epoch, consumed_samples=0 if epoch > start_epoch else consumed_samples
                )

            step_bar = tqdm(
                range(len(self.train_dataloader)),
                desc="Train step of epoch %d" % epoch,
                disable=not self.strategy.is_rank_0(),
            )

            self.model.train()
            scale_for_train = args.scale_for_train
            for data in self.train_dataloader:
                inputs, extras, device = self._batch_to_device(data)
                scores0, scores1 = self.concatenated_forward(self.model, *inputs)
                chosens, rejects, equals = self._split_by_label(scores0, scores1, extras, head_types)

                # Tied pairs are pulled together with an L1 penalty on the score gap.
                equals_loss = {}
                for head in head_types:
                    if len(equals[head]) > 0:
                        pairs = torch.stack([torch.stack(pair) for pair in equals[head]])
                        equals_loss[head] = torch.abs(pairs[:, 0] - pairs[:, 1]).mean()
                    else:
                        equals_loss[head] = torch.tensor(0.0, device=device)

                # Compute per head loss based on the selected loss function
                head_loss = {
                    head: self.loss_fn(chosens[head], rejects[head], self.margin)
                    for head in head_types
                    if len(chosens[head]) > 0
                }

                total_loss = sum(head_loss.values()) + self._EQUAL_LOSS_WEIGHT * sum(equals_loss.values())
                self.strategy.backward(total_loss, self.model, self.optimizer)
                self.strategy.optimizer_step(self.optimizer, self.model, self.scheduler)

                for head in head_types:
                    if len(chosens[head]) == 0:
                        continue
                    acc[head] = (chosens[head] > rejects[head]).float().mean().item()
                    acc_mean[head] = acc_mean[head] * 0.9 + 0.1 * acc[head]
                    loss_mean[head] = loss_mean[head] * 0.9 + 0.1 * head_loss[head].item()
                total_loss_mean = total_loss_mean * 0.9 + 0.1 * total_loss.item()

                logs_dict = {
                    "loss": total_loss.item(),
                    "loss_mean": total_loss_mean,
                    "lr": self.scheduler.get_last_lr()[0],
                }
                for head in head_types:
                    if len(chosens[head]) > 0:
                        chosen_reward = chosens[head].mean().item()
                        reject_reward = rejects[head].mean().item()
                        if scale_for_train:
                            chosen_reward = round(chosen_reward * self._SCALED_REWARD_FACTOR, 4)
                            reject_reward = round(reject_reward * self._SCALED_REWARD_FACTOR, 4)
                        logs_dict[f"{head}_loss"] = head_loss[head].item()
                        logs_dict[f"{head}_acc"] = acc[head]
                        logs_dict[f"{head}_acc_mean"] = acc_mean[head]
                        logs_dict[f"{head}_loss_mean"] = loss_mean[head]
                        logs_dict[f"{head}_chosen_reward"] = chosen_reward
                        logs_dict[f"{head}_reject_reward"] = reject_reward
                    else:
                        # Keep the key set identical on every rank so the
                        # collective reductions below stay aligned.
                        logs_dict[f"{head}_loss"] = 0.0
                        logs_dict[f"{head}_acc"] = 0.0
                        logs_dict[f"{head}_acc_mean"] = 0.0
                        logs_dict[f"{head}_loss_mean"] = 0.0
                        logs_dict[f"{head}_chosen_reward"] = 0.0
                        logs_dict[f"{head}_reject_reward"] = 0.0

                # step bar: metrics whose key starts with "preference" are
                # max-reduced, everything else uses the strategy's default op.
                logs_dict = {
                    key: (
                        self.strategy.all_reduce(value, op="max")
                        if key.startswith("preference")
                        else self.strategy.all_reduce(value)
                    )
                    for key, value in logs_dict.items()
                }
                step_bar.set_postfix(logs_dict)
                step_bar.update()

                # logs/checkpoints/evaluation
                if step % self.strategy.accumulated_gradient == 0:
                    global_step = step // self.strategy.accumulated_gradient
                    client_states = {"consumed_samples": global_step * args.train_batch_size}
                    self.save_logs_and_checkpoints(args, global_step, step_bar, logs_dict, client_states)

                step += 1
            epoch_bar.update()

        # Shut the loggers down exactly once.
        if self._wandb is not None and self.strategy.is_rank_0():
            self._wandb.finish()
        if self._tensorboard is not None and self.strategy.is_rank_0():
            self._tensorboard.close()

    # logs/checkpoints/evaluate
    def save_logs_and_checkpoints(self, args, global_step, step_bar, logs_dict=None, client_states=None) -> None:
        """
        Log metrics and optionally run evaluation and checkpointing.

        :param args: Training arguments providing logging/eval/save intervals
            and checkpoint configurations.
        :type args: Any
        :param global_step: Current global optimization step.
        :type global_step: int
        :param step_bar: Progress bar for step-level updates (not used here;
            kept for interface compatibility).
        :type step_bar: tqdm.tqdm
        :param logs_dict: Dictionary of metrics to log; defaults to an empty
            dict.
        :type logs_dict: Dict[str, float]
        :param client_states: Extra state to persist with checkpoints (e.g.,
            consumed samples); defaults to an empty dict.
        :type client_states: Dict[str, Any]
        :returns: None
        :rtype: NoneType
        """
        # ``None`` sentinels avoid sharing one mutable default across calls.
        logs_dict = {} if logs_dict is None else logs_dict
        client_states = {} if client_states is None else client_states

        if global_step % args.logging_steps == 0:
            # wandb
            if self._wandb is not None and self.strategy.is_rank_0():
                logs = {"train/%s" % k: v for k, v in {**logs_dict, "global_step": global_step}.items()}
                self._wandb.log(logs)
            # TensorBoard
            elif self._tensorboard is not None and self.strategy.is_rank_0():
                for k, v in logs_dict.items():
                    self._tensorboard.add_scalar(f"train/{k}", v, global_step)

        # eval
        if global_step % args.eval_steps == 0:
            # do eval when len(dataloader) > 0, avoid zero division in eval.
            if self.eval_dataloader is not None and len(self.eval_dataloader) > 0:
                # Pass args first to match evaluate signature (args, dataloader, steps)
                self.evaluate(args, self.eval_dataloader, global_step)

        # save ckpt
        # TODO: save best model on dev, use loss/perplexity on whole dev dataset as metric
        if global_step % args.save_steps == 0:
            tag = f"global_step{global_step}"
            self.strategy.save_ckpt(
                self.model, args.ckpt_path, tag, args.max_ckpt_num, args.max_ckpt_mem, client_states
            )

    def evaluate(self, args, eval_dataloader, steps=0) -> None:
        """
        Evaluate the model on the provided dataloader and write a JSONL of
        scores below ``strategy.args.save_path``.

        Also calculates per-head accuracy and mean chosen/reject rewards and
        logs them to Weights & Biases or TensorBoard. The model is switched
        back to train mode before returning.

        :param args: present for API compatibility with callers.
        :type args: Any
        :param eval_dataloader: Dataloader for evaluation samples.
        :type eval_dataloader: torch.utils.data.DataLoader
        :param steps: Global step id for naming the output file.
        :type steps: int
        :returns: None
        :rtype: NoneType

        The output file is ``<save_path>/evals/eval_results_{steps}.jsonl``;
        it is written by rank 0 only.
        """
        step_bar = tqdm(
            range(len(eval_dataloader)),
            desc="Eval stage of steps %d" % steps,
            disable=not self.strategy.is_rank_0(),
        )
        self.model.eval()

        # Create JSONL file (only on rank 0); batches are appended below.
        if self.strategy.is_rank_0():
            self.strategy.print(f"Start Evaluation at global step {steps}...")
            output_file = os.path.join(self.strategy.args.save_path, "evals", f"eval_results_{steps}.jsonl")
            os.makedirs(os.path.dirname(output_file), exist_ok=True)
            with open(output_file, "w") as f:
                f.write("")  # Just create/clear the file

        head_types = self.model.head_types

        # Metrics accumulators (local to this rank until reduced at the end).
        eval_metrics = {"count": 0}
        for head in head_types:
            eval_metrics[f"{head}_correct"] = 0.0
            eval_metrics[f"{head}_count"] = 0
            eval_metrics[f"{head}_chosen_reward"] = 0.0
            eval_metrics[f"{head}_reject_reward"] = 0.0

        with torch.no_grad():
            world_size = dist.get_world_size()
            for data in eval_dataloader:
                inputs, extras, _ = self._batch_to_device(data)
                scores0, scores1 = self.concatenated_forward(self.model, *inputs)

                # Ties are not needed for accuracy/reward statistics.
                chosens, rejects, _ = self._split_by_label(scores0, scores1, extras, head_types)

                # Update local metrics
                eval_metrics["count"] += len(extras)
                for head in head_types:
                    if len(chosens[head]) > 0:
                        eval_metrics[f"{head}_correct"] += (chosens[head] > rejects[head]).float().sum().item()
                        eval_metrics[f"{head}_count"] += len(chosens[head])
                        eval_metrics[f"{head}_chosen_reward"] += chosens[head].sum().item()
                        eval_metrics[f"{head}_reject_reward"] += rejects[head].sum().item()

                # Gather scores from all GPUs for each head.
                # NOTE(review): dist.all_gather needs same-shaped tensors on
                # every rank, so this assumes equal per-rank batch sizes.
                gathered_scores0 = {}
                gathered_scores1 = {}
                for head in scores0:
                    buffers0 = [torch.zeros_like(scores0[head]) for _ in range(world_size)]
                    buffers1 = [torch.zeros_like(scores1[head]) for _ in range(world_size)]
                    dist.all_gather(buffers0, scores0[head])
                    dist.all_gather(buffers1, scores1[head])
                    # Concatenate all tensors along batch dimension
                    gathered_scores0[head] = torch.cat(buffers0, dim=0)
                    gathered_scores1[head] = torch.cat(buffers1, dim=0)

                # Gather extras
                all_extras = all_gather_and_flatten(extras)

                # write scores to JSONL file immediately (only on rank 0)
                if self.strategy.is_rank_0():
                    with open(output_file, "a") as f:
                        for idx, sample_info in enumerate(all_extras):
                            record = {
                                "info": sample_info,
                                "scores0": {h: gathered_scores0[h][idx].item() for h in gathered_scores0},
                                "scores1": {h: gathered_scores1[h][idx].item() for h in gathered_scores1},
                            }
                            f.write(json.dumps(record) + "\n")
                step_bar.update()

        # --- Aggregate and Log Metrics ---
        reduced_metrics = all_reduce_dict(eval_metrics, op="sum")
        logs_dict = {}
        for head in head_types:
            count = reduced_metrics[f"{head}_count"]
            if count > 0:
                logs_dict[f"eval/{head}_acc"] = reduced_metrics[f"{head}_correct"] / count
                chosen_reward = reduced_metrics[f"{head}_chosen_reward"] / count
                reject_reward = reduced_metrics[f"{head}_reject_reward"] / count
                logs_dict[f"eval/{head}_chosen_reward_mean"] = round(chosen_reward, 4)
                logs_dict[f"eval/{head}_reject_reward_mean"] = round(reject_reward, 4)

        if self.strategy.is_rank_0():
            self.strategy.print(f"Evaluation scores written to {output_file}")
            self.strategy.print(f"Eval metrics: {logs_dict}")
            # Log to wandb/tensorboard
            if self._wandb is not None:
                logs_dict["eval/global_step"] = steps
                self._wandb.log(logs_dict)
            elif self._tensorboard is not None:
                for k, v in logs_dict.items():
                    if k != "eval/global_step":
                        self._tensorboard.add_scalar(k, v, steps)

        self.model.train()  # reset model state

    def concatenated_forward(
        self,
        model,
        input0_ids,
        input0_mask,
        input1_ids,
        input1_mask,
        input0_img_pixels,
        input0_img_grid_thws,
        input1_img_pixels,
        input1_img_grid_thws,
        input0_video_pixels,
        input0_video_grid_thws,
        input1_video_pixels,
        input1_video_grid_thws,
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
        """
        Run the model once on the two candidates concatenated along the batch.

        Concatenation reduces the number of forward passes. Which candidate is
        preferred is decided later from the labels, not here.

        :param model: Model to evaluate.
        :type model: nn.Module
        :param input0_ids: Token ids for the first candidates.
        :type input0_ids: torch.Tensor
        :param input0_mask: Attention mask for the first candidates.
        :type input0_mask: torch.Tensor
        :param input1_ids: Token ids for the second candidates.
        :type input1_ids: torch.Tensor
        :param input1_mask: Attention mask for the second candidates.
        :type input1_mask: torch.Tensor
        :param input0_img_pixels: Optional image features for the first candidates.
        :type input0_img_pixels: Optional[torch.Tensor]
        :param input0_img_grid_thws: Optional image grid meta for the first candidates.
        :type input0_img_grid_thws: Optional[torch.Tensor]
        :param input1_img_pixels: Optional image features for the second candidates.
        :type input1_img_pixels: Optional[torch.Tensor]
        :param input1_img_grid_thws: Optional image grid meta for the second candidates.
        :type input1_img_grid_thws: Optional[torch.Tensor]
        :param input0_video_pixels: Optional video features for the first candidates.
        :type input0_video_pixels: Optional[torch.Tensor]
        :param input0_video_grid_thws: Optional video grid meta for the first candidates.
        :type input0_video_grid_thws: Optional[torch.Tensor]
        :param input1_video_pixels: Optional video features for the second candidates.
        :type input1_video_pixels: Optional[torch.Tensor]
        :param input1_video_grid_thws: Optional video grid meta for the second candidates.
        :type input1_video_grid_thws: Optional[torch.Tensor]
        :returns: Tuple of dicts ``(scores0, scores1)``: the model output of
            every head split back into the first and second candidates.
        :rtype: Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]
        """
        input_ids, att_masks = self.concatenated_inputs(input0_ids, input0_mask, input1_ids, input1_mask)

        pixel_values = None
        image_grid_thws = None
        pixel_values_videos = None
        video_grid_thws = None
        # Only the input concatenation is excluded from autograd; the model
        # call below must stay outside so training receives gradients.
        with torch.no_grad():
            if input0_img_pixels is not None:
                pixel_values = torch.cat((input0_img_pixels, input1_img_pixels), dim=0)
                image_grid_thws = torch.cat((input0_img_grid_thws, input1_img_grid_thws), dim=0)
            if input0_video_pixels is not None:
                pixel_values_videos = torch.cat((input0_video_pixels, input1_video_pixels), dim=0)
                video_grid_thws = torch.cat((input0_video_grid_thws, input1_video_grid_thws), dim=0)

        scores = model(
            input_ids,
            attention_mask=att_masks,
            pixel_values=pixel_values,
            image_grid_thw=image_grid_thws,
            pixel_values_videos=pixel_values_videos,
            video_grid_thw=video_grid_thws,
        )

        # The first N rows belong to input0, the remaining rows to input1.
        split_at = input0_ids.shape[0]
        scores0 = {head_type: score[:split_at] for head_type, score in scores.items()}
        scores1 = {head_type: score[split_at:] for head_type, score in scores.items()}
        return scores0, scores1

    def concatenated_inputs(self, input0_ids, input0_mask, input1_ids,
                            input1_mask) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Concatenate two inputs into a single batch.

        :param input0_ids: Token ids for the first sequences, shape ``(N, L0)``.
        :type input0_ids: torch.Tensor
        :param input0_mask: Attention mask for the first sequences, shape ``(N, L0)``.
        :type input0_mask: torch.Tensor
        :param input1_ids: Token ids for the second sequences, shape ``(N, L1)``.
        :type input1_ids: torch.Tensor
        :param input1_mask: Attention mask for the second sequences, shape ``(N, L1)``.
        :type input1_mask: torch.Tensor
        :returns: Tuple ``(input_ids, att_masks)`` where inputs are padded to
            a common max length across input0 and input1 (ids with the
            tokenizer's pad token id, masks with ``0``), then concatenated
            along the batch dimension to shape ``(2N, Lmax)``.
        :rtype: Tuple[torch.Tensor, torch.Tensor]
        """
        pad_id = self.tokenizer.pad_token_id
        ids_length = max(input0_ids.shape[1], input1_ids.shape[1])
        inputs_ids = torch.cat(
            (
                pad_to_length(input0_ids, ids_length, pad_id),
                pad_to_length(input1_ids, ids_length, pad_id),
            ),
            dim=0,
        )

        mask_length = max(input0_mask.shape[1], input1_mask.shape[1])
        att_masks = torch.cat(
            (
                pad_to_length(input0_mask, mask_length, 0),
                pad_to_length(input1_mask, mask_length, 0),
            ),
            dim=0,
        )
        return inputs_ids, att_masks