lightrft.strategy.utils.parallel_utils

Sequence Parallelism Utilities for Distributed Training

This module provides utilities for sequence parallelism in distributed training environments. It includes functions for managing sequence parallel groups, data processing for sequence parallelism, and operations for tensor splitting, gathering, and all-to-all communication in sequence parallel contexts.

The module supports:
  • Setting and retrieving sequence parallel groups
  • Processing data for sequence parallel distribution
  • Slicing and padding inputs for sequence parallelism
  • Specialized tensor operations for sequence parallel training

set_sequence_parallel_group

lightrft.strategy.utils.parallel_utils.set_sequence_parallel_group(group)[source]

Set the global sequence parallel process group.

Parameters:

group (torch.distributed.ProcessGroup) – The process group to use for sequence parallelism.

get_sequence_parallel_group

lightrft.strategy.utils.parallel_utils.get_sequence_parallel_group()[source]

Get the current sequence parallel process group.

Returns:

The current sequence parallel process group.

Return type:

torch.distributed.ProcessGroup or None
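The set/get pair follows the usual module-global registry pattern. A minimal framework-free sketch of that pattern (the `_SEQUENCE_PARALLEL_GROUP` variable name and the plain-object stand-in for a `torch.distributed.ProcessGroup` are assumptions for illustration, not the module's actual internals):

```python
# Module-global holding the current sequence parallel group (None until set).
_SEQUENCE_PARALLEL_GROUP = None


def set_sequence_parallel_group(group):
    """Register `group` as the global sequence parallel process group."""
    global _SEQUENCE_PARALLEL_GROUP
    _SEQUENCE_PARALLEL_GROUP = group


def get_sequence_parallel_group():
    """Return the registered group, or None if none has been set."""
    return _SEQUENCE_PARALLEL_GROUP


# Usage: any object stands in for a torch.distributed.ProcessGroup here.
group = object()
set_sequence_parallel_group(group)
assert get_sequence_parallel_group() is group
```

Keeping the group in module state means downstream helpers can query it without threading a `group` argument through every call site.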

get_sequence_parallel_world_size

lightrft.strategy.utils.parallel_utils.get_sequence_parallel_world_size()[source]

Get the world size of the sequence parallel group.

Returns:

The world size of the sequence parallel group, or 1 if no group is set.

Return type:

int

get_sequence_parallel_rank

lightrft.strategy.utils.parallel_utils.get_sequence_parallel_rank()[source]

Get the rank of the current process in the sequence parallel group.

Returns:

The rank in the sequence parallel group, or 0 if no group is set.

Return type:

int
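Both helpers fall back to single-process defaults when no group has been set, so the same code paths work unchanged outside distributed runs. A hedged sketch of that fallback logic (the function bodies here are illustrative; the real implementations would query `torch.distributed` against the registered group):

```python
_SP_GROUP = None  # stand-in for the module-global sequence parallel group


def get_sequence_parallel_world_size():
    # With no group set, behave as a trivial "group" of size 1.
    if _SP_GROUP is None:
        return 1
    return _SP_GROUP["world_size"]  # real code: dist.get_world_size(group)


def get_sequence_parallel_rank():
    # With no group set, this process is rank 0 of its trivial group.
    if _SP_GROUP is None:
        return 0
    return _SP_GROUP["rank"]  # real code: dist.get_rank(group)


assert get_sequence_parallel_world_size() == 1
assert get_sequence_parallel_rank() == 0
```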

SPDataProcessor

class lightrft.strategy.utils.parallel_utils.SPDataProcessor[source]

A context manager that prepares data before, and redistributes results after, sequence parallel operations.

This class handles the distribution and collection of data across sequence parallel ranks, ensuring proper data sharding and gathering for sequence parallel training.

postprocess(data)[source]

Postprocess data after sequence parallel operations by distributing data to appropriate ranks.

Parameters:

data (Any) – The data to postprocess.

Returns:

The postprocessed data, distributed to the current rank.

Return type:

Any

Raises:

AssertionError – If the data length is not divisible by sp_size.

preprocess(data)[source]

Preprocess data for sequence parallelism by gathering data from all ranks.

Parameters:

data (Any) – The data to preprocess.

Returns:

The preprocessed data, gathered from all ranks if sp_size > 1.

Return type:

Any
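Together, preprocess and postprocess form a gather-then-scatter round trip: before a sequence parallel forward, every rank's local batch is gathered into one global batch; afterwards, results are split evenly so each rank keeps only its own shard. A single-process sketch of that sharding arithmetic (the explicit `sp_size`, `rank`, and `all_rank_data` parameters are illustrative stand-ins; real code would obtain these from the group and move data with collectives):

```python
def preprocess(data, sp_size, all_rank_data):
    """Gather: with sp_size > 1, the global batch is the concatenation
    of every rank's local data (here passed in explicitly)."""
    if sp_size <= 1:
        return data
    return [x for rank_data in all_rank_data for x in rank_data]


def postprocess(data, sp_size, rank):
    """Scatter: split the global result evenly and keep this rank's shard."""
    if sp_size <= 1:
        return data
    assert len(data) % sp_size == 0, "data length must be divisible by sp_size"
    chunk = len(data) // sp_size
    return data[rank * chunk:(rank + 1) * chunk]


# Two ranks with two samples each -> global batch of four, then back to two.
gathered = preprocess([0, 1], sp_size=2, all_rank_data=[[0, 1], [2, 3]])
assert gathered == [0, 1, 2, 3]
assert postprocess(gathered, sp_size=2, rank=1) == [2, 3]
```

The divisibility assertion mirrors the documented AssertionError: an even split is only possible when the global batch length is a multiple of sp_size.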

sp_slice_and_pad_input

lightrft.strategy.utils.parallel_utils.sp_slice_and_pad_input(input_ids: torch.Tensor, position_ids: torch.Tensor)[source]

Pad input_ids and position_ids so their sequence length is divisible by sp_size, then slice input_ids across sequence parallel ranks.

Note that both input_ids and position_ids are padded, but only input_ids is sliced. This is the pre-forward utility for Ulysses sequence parallelism.

Parameters:
  • input_ids (torch.Tensor) – Input tensor with shape [bsz, seqlen].

  • position_ids (torch.Tensor) – Position IDs tensor with shape [bsz, seqlen], where bsz must be 1.

Returns:

A tuple containing:
  • Padded and sliced input_ids
  • Padded position_ids
  • Size of padding added

Return type:

Tuple[torch.Tensor, torch.Tensor, int]
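The padding size follows from rounding seqlen up to a multiple of sp_size; each rank then keeps one contiguous slice of the padded input_ids while position_ids stay whole. A list-based sketch of the arithmetic (the `sp_size`, `sp_rank`, and `pad_id` parameters and the contiguous-slice layout are assumptions about Ulysses-style sharding, not the exact implementation):

```python
def sp_slice_and_pad_input(input_ids, position_ids, sp_size, sp_rank, pad_id=0):
    """Pad both sequences to a multiple of sp_size, then slice input_ids
    so each rank holds seqlen_padded // sp_size tokens."""
    seqlen = len(input_ids)
    padding_size = (-seqlen) % sp_size  # tokens needed to reach a multiple
    input_ids = input_ids + [pad_id] * padding_size
    # Position ids keep counting past the original length for the pad region.
    position_ids = position_ids + list(range(seqlen, seqlen + padding_size))
    shard = len(input_ids) // sp_size
    local_input_ids = input_ids[sp_rank * shard:(sp_rank + 1) * shard]
    return local_input_ids, position_ids, padding_size


# seqlen 5 with sp_size 2 -> 1 pad token; rank 0 keeps the first 3 tokens.
ids, pos, pad = sp_slice_and_pad_input([10, 11, 12, 13, 14],
                                       [0, 1, 2, 3, 4],
                                       sp_size=2, sp_rank=0)
assert ids == [10, 11, 12] and pad == 1 and len(pos) == 6
```

Returning padding_size lets the post-forward gather know how many trailing tokens to strip again.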

gather_forward_split_backward_and_unpad

lightrft.strategy.utils.parallel_utils.gather_forward_split_backward_and_unpad(input_, group, dim, padding_size=0, unpad_dim=None)[source]

Gather tensors in the forward pass, split gradients in the backward pass, and remove padding if needed.

Parameters:
  • input_ (torch.Tensor) – The input tensor from the current rank.

  • group (torch.distributed.ProcessGroup) – The process group for gathering.

  • dim (int) – The dimension along which to gather/split.

  • padding_size (int) – Size of padding to remove, defaults to 0.

  • unpad_dim (int or None) – Dimension from which to remove padding, defaults to None.

Returns:

The gathered (and possibly unpadded) tensor.

Return type:

torch.Tensor
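In the forward pass this op concatenates shards from all ranks along dim and trims padding_size entries from unpad_dim; in the backward pass it splits the incoming gradient so each rank receives only the slice matching its own shard. A framework-free sketch of those two halves on flat lists (the collective communication and autograd wiring are omitted, and the helper names here are illustrative):

```python
def gather_forward_and_unpad(local_shards, padding_size=0):
    """Forward: concatenate every rank's shard, then drop the trailing
    padding_size entries that sp_slice_and_pad_input added."""
    gathered = [x for shard in local_shards for x in shard]
    if padding_size > 0:
        gathered = gathered[:-padding_size]
    return gathered


def split_backward(grad, num_ranks, rank):
    """Backward: each rank keeps only the gradient slice for its shard."""
    chunk = len(grad) // num_ranks
    return grad[rank * chunk:(rank + 1) * chunk]


# Two ranks gathered [1, 2, 3] + [4, 5, 0], with one pad entry removed.
assert gather_forward_and_unpad([[1, 2, 3], [4, 5, 0]], padding_size=1) == [1, 2, 3, 4, 5]
assert split_backward([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], num_ranks=2, rank=1) == [0.4, 0.5, 0.6]
```

Pairing gather-in-forward with split-in-backward keeps the operation gradient-consistent: each rank's parameters only see gradients for the tokens it actually contributed.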