# Source code for lightrft.strategy.utils.data_utils

"""
Distributed Sampling Module for PyTorch

This module provides utilities for distributed data sampling in PyTorch, particularly
useful for distributed training scenarios. It includes a customized DistributedSampler
that extends PyTorch's sampling capabilities with additional features like handling
consumed samples for resuming training.

The module is designed to work seamlessly with PyTorch's distributed training
infrastructure and provides proper data partitioning across multiple processes.
"""

import math
from typing import Iterator, Optional, TypeVar

import torch
import torch.distributed as dist
from torch.utils.data.dataset import Dataset
from torch.utils.data.sampler import Sampler

__all__ = ["DistributedSampler"]

_T_co = TypeVar("_T_co", covariant=True)


class DistributedSampler(Sampler[_T_co]):
    r"""Sampler that restricts data loading to a subset of the dataset.

    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
    process can pass a :class:`~torch.utils.data.DistributedSampler` instance as a
    :class:`~torch.utils.data.DataLoader` sampler, and load a subset of the
    original dataset that is exclusive to it.

    .. note::
        Dataset is assumed to be of constant size and that any instance of it
        always returns the same elements in the same order.

    :param dataset: Dataset used for sampling.
    :type dataset: Dataset
    :param num_replicas: Number of processes participating in distributed training.
        By default, world_size is retrieved from the current distributed group.
    :type num_replicas: Optional[int]
    :param rank: Rank of the current process within num_replicas. By default,
        rank is retrieved from the current distributed group.
    :type rank: Optional[int]
    :param shuffle: If True (default), sampler will shuffle the indices.
    :type shuffle: bool
    :param seed: Random seed used to shuffle the sampler if shuffle=True. This
        number should be identical across all processes in the distributed group.
    :type seed: int
    :param drop_last: If True, then the sampler will drop the tail of the data to
        make it evenly divisible across the number of replicas. If False, the
        sampler will add extra indices to make the data evenly divisible across
        the replicas.
    :type drop_last: bool
    :param consumed_samples: Number of samples (summed over all replicas) that
        have been consumed already, useful for resuming training. Each replica
        skips ``consumed_samples // num_replicas`` of its own indices. Values
        larger than one epoch are clamped, so the epoch is treated as exhausted.
    :type consumed_samples: int

    .. warning::
        In distributed mode, calling the :meth:`set_epoch` method at the
        beginning of each epoch **before** creating the :class:`DataLoader`
        iterator is necessary to make shuffling work properly across multiple
        epochs. Otherwise, the same ordering will always be used.

    Example::

        >>> # xdoctest: +SKIP
        >>> sampler = DistributedSampler(dataset) if is_distributed else None
        >>> loader = DataLoader(dataset, shuffle=(sampler is None),
        ...                     sampler=sampler)
        >>> for epoch in range(start_epoch, n_epochs):
        ...     if is_distributed:
        ...         sampler.set_epoch(epoch)
        ...     train(loader)
    """

    def __init__(  # pylint: disable=W0231, R0917
        self,
        dataset: Dataset,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
        seed: int = 0,
        drop_last: bool = False,
        consumed_samples: int = 0,
    ) -> None:
        """
        Initialize the DistributedSampler.

        :param dataset: Dataset used for sampling.
        :type dataset: Dataset
        :param num_replicas: Number of processes participating in distributed training.
        :type num_replicas: Optional[int]
        :param rank: Rank of the current process within num_replicas.
        :type rank: Optional[int]
        :param shuffle: If True, sampler will shuffle the indices.
        :type shuffle: bool
        :param seed: Random seed used for shuffling.
        :type seed: int
        :param drop_last: If True, drop the tail of the data so it splits evenly.
        :type drop_last: bool
        :param consumed_samples: Number of samples already consumed, for resuming training.
        :type consumed_samples: int
        :raises RuntimeError: If distributed package is not available but required.
        :raises ValueError: If rank is invalid for the given number of replicas,
            or if ``consumed_samples`` is negative.
        """
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]")
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = drop_last

        dataset_len = len(self.dataset)  # type: ignore[arg-type]
        if self.drop_last:
            # Truncate to the largest length evenly divisible by num_replicas so
            # every rank receives the same amount of data. Integer floor division
            # equals ceil((len - n) / n) for the non-divisible case and is exact.
            self.num_samples = dataset_len // self.num_replicas
        else:
            # Ceil division: the tail is padded in __iter__ to reach this size.
            self.num_samples = (dataset_len + self.num_replicas - 1) // self.num_replicas
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle
        self.seed = seed
        # Attribute name (sic) is kept for compatibility with external readers.
        self.consumed_indicies = self._consumed_to_indices(consumed_samples)

    def _consumed_to_indices(self, consumed_samples: int) -> int:
        """
        Convert a global consumed-sample count into a per-replica skip count.

        :param consumed_samples: Number of samples consumed across all replicas.
        :type consumed_samples: int
        :return: Number of this replica's indices to skip, in ``[0, num_samples]``.
        :rtype: int
        :raises ValueError: If ``consumed_samples`` is negative.
        """
        if consumed_samples < 0:
            raise ValueError(f"consumed_samples should be non-negative, got {consumed_samples}")
        # Clamp so an over-large resume point yields an exhausted epoch (empty
        # iterator, length 0) rather than a negative __len__ / failing assert.
        return min(consumed_samples // self.num_replicas, self.num_samples)

    def __iter__(self) -> Iterator[_T_co]:
        """
        Iterate over the indices of the dataset assigned to this rank.

        :return: An iterator over this rank's remaining dataset indices.
        :rtype: Iterator[_T_co]
        """
        dataset_len = len(self.dataset)  # type: ignore[arg-type]
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(dataset_len, generator=g).tolist()
        else:
            indices = list(range(dataset_len))

        if not self.drop_last:
            # add extra samples to make it evenly divisible
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[:self.total_size]
        assert len(indices) == self.total_size

        # subsample: strided split so each rank gets a disjoint slice
        indices = indices[self.rank:self.total_size:self.num_replicas]

        # skip the indices this rank already consumed before resuming
        indices = indices[self.consumed_indicies:]
        assert len(indices) == self.num_samples - self.consumed_indicies

        return iter(indices)

    def __len__(self) -> int:
        """
        Return the number of indices this sampler will still yield.

        :return: The number of samples remaining in this sampler (never negative).
        :rtype: int
        """
        return self.num_samples - self.consumed_indicies

    def set_epoch(self, epoch: int, consumed_samples: int = 0) -> None:
        """
        Set the epoch for this sampler.

        When shuffle=True, this ensures all replicas use a different random
        ordering for each epoch. Otherwise, the next iteration of this sampler
        will yield the same ordering.

        :param epoch: Epoch number.
        :type epoch: int
        :param consumed_samples: Number of samples already consumed in this epoch
            (summed over all replicas).
        :type consumed_samples: int
        :raises ValueError: If ``consumed_samples`` is negative.
        """
        self.epoch = epoch
        self.consumed_indicies = self._consumed_to_indices(consumed_samples)