lightrft.utils.utils¶
get_current_device¶
- lightrft.utils.utils.get_current_device(num_device_per_node: int = 8) → torch.device[source]¶
Returns the current CUDA device.
This function is a convenience helper for retrieving the CUDA device that PyTorch is currently using.
- Parameters:
num_device_per_node (int) – Number of devices per node for distributed training
- Returns:
Current CUDA device
- Return type:
torch.device
Example:
>>> device = get_current_device()
>>> model = model.to(device)
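As a rough illustration of how `num_device_per_node` might come into play, the sketch below maps a process's global rank onto a per-node device index. This is an assumption about the implementation, not the library's actual code; the helper name `get_current_device_index` and the use of the `RANK` environment variable (as set by common distributed launchers) are hypothetical.

```python
import os


def get_current_device_index(num_device_per_node: int = 8) -> int:
    # Hypothetical sketch: derive a per-node device index from the
    # global rank, read here from the RANK environment variable.
    rank = int(os.environ.get("RANK", "0"))
    return rank % num_device_per_node
```

Under this assumption, rank 10 on 8-GPU nodes would land on local device 2.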
print_rank_0¶
- lightrft.utils.utils.print_rank_0(msg: str, *args: Any, **kwargs: Any) → None[source]¶
Prints message only from rank 0 process in distributed training.
This function helps avoid duplicate prints in multi-GPU training by only printing from the main process (rank 0).
- Parameters:
msg (str) – The message to print
args – Additional positional arguments to include in the message
kwargs – Additional keyword arguments to include in the message
Example:
>>> print_rank_0("Training started", epoch=1)
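The rank-0 gating described above can be sketched as follows. This is a simplified stand-in, not the library's implementation: it reads the rank from the `RANK` environment variable rather than from `torch.distributed`, which is an assumption made so the sketch runs without a distributed setup.

```python
import os


def print_rank_0(msg: str, *args, **kwargs) -> None:
    # Simplified sketch: only the rank-0 process prints; all other
    # ranks return silently, avoiding duplicate log lines.
    if int(os.environ.get("RANK", "0")) == 0:
        print(msg, *args, **kwargs)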
get_torch_profiler¶
- lightrft.utils.utils.get_torch_profiler(output_file: str, warmup: int = 1, active: int = 1, repeat: int = 1) → torch.profiler.profile | DummyProfile[source]¶
Creates and returns a PyTorch profiler configured for distributed training.
This function returns a properly configured PyTorch profiler for the current process. For rank 0 process, it returns a full-featured profiler that records CPU and CUDA activities. For other ranks, it returns a dummy profiler that does nothing.
For more details on PyTorch Profiler, see: https://docs.pytorch.org/docs/stable/profiler.html
- Parameters:
output_file (str) – Path where profiling results will be saved (only for rank 0)
warmup (int) – Number of steps to wait before profiling starts
active (int) – Number of steps with active profiling
repeat (int) – Number of times to repeat the profiling cycle
- Returns:
A PyTorch profiler object or a dummy profiler
- Return type:
torch.profiler.profile or DummyProfile
Example:
>>> with get_torch_profiler("./profiler_output", warmup=5, active=10) as prof:
...     for step in range(100):
...         train_step()
...         prof.step()
DummyProfile¶
- class lightrft.utils.utils.DummyProfile(*args, **kwargs)[source]¶
Dummy profiler class that mimics the PyTorch profiler API but does nothing.
This class is used as a placeholder for non-rank-0 processes where profiling is not needed, allowing the same code to be used across all processes without conditional branches.
Example:
>>> prof = DummyProfile() if rank != 0 else torch.profiler.profile(...)
>>> with prof:
...     ...  # code to be profiled
- __enter__() → DummyProfile[source]¶
Context manager entry method.
- Returns:
Self instance
- Return type:
DummyProfile
- __exit__(a: Any, b: Any, c: Any) → None[source]¶
Context manager exit method.
- Parameters:
a – Exception type
b – Exception value
c – Exception traceback
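Putting the documented interface together, a minimal sketch of such a no-op profiler might look like this. The constructor, `__enter__`, and `__exit__` follow the signatures documented above; the `step` method is an assumption, added because the `get_torch_profiler` example calls `prof.step()`.

```python
class DummyProfile:
    # No-op profiler: accepts and ignores any constructor arguments so it
    # can be swapped in wherever torch.profiler.profile(...) is expected.
    def __init__(self, *args, **kwargs):
        pass

    def __enter__(self):
        # Return self so `with DummyProfile() as prof:` binds an object
        # exposing the same methods as the real profiler.
        return self

    def __exit__(self, a, b, c):
        # a, b, c: exception type, value, and traceback (all ignored).
        pass

    def step(self):
        # Assumed no-op counterpart to torch.profiler.profile.step().
        pass
```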