Shortcuts

Source code for lightrft.strategy.utils.ckpt_utils

"""
Utility module for finding the latest checkpoint directory in machine learning training workflows.

This module provides functionality to locate the most recent checkpoint directory based on
a naming pattern that includes a step number. It's commonly used in deep learning training
scenarios where checkpoints are saved periodically with incremental step numbers.
"""

import os
import re
from typing import Optional


[docs]def find_latest_checkpoint_dir(load_dir: str, prefix: str = "global_step") -> Optional[str]: """ Finds the latest subdirectory within the specified directory whose name matches the '<prefix><number>' format. This function is particularly useful in machine learning training scenarios where checkpoints are saved with incremental step numbers. It searches through all subdirectories in the given path and returns the one with the highest step number that matches the specified prefix pattern. If no matching subdirectory is found, returns the original `load_dir`. :param load_dir: The path to the parent directory containing checkpoint subdirectories. :type load_dir: str :param prefix: The expected prefix string at the beginning of checkpoint directory names. Defaults to "global_step". :type prefix: str, optional :return: The full path to the latest checkpoint subdirectory. Returns `load_dir` if no matching subdirectory is found. Returns `None` if `load_dir` is invalid (does not exist or is not a directory). :rtype: str or None Example:: # Find latest checkpoint with default prefix "global_step" latest_dir = find_latest_checkpoint_dir("/path/to/checkpoints") # Returns: "/path/to/checkpoints/global_step1000" (if it's the highest numbered) # Find latest checkpoint with custom prefix latest_dir = find_latest_checkpoint_dir("/path/to/models", prefix="step_") # Returns: "/path/to/models/step_500" (if it's the highest numbered) # Handle case where directory doesn't exist result = find_latest_checkpoint_dir("/nonexistent/path") # Returns: None """ # Check if load_dir exists and is a directory if not os.path.isdir(load_dir): print(f"Error: Directory '{load_dir}' not found or is not a valid directory.") return None latest_step = -1 # Initialize with a step number lower than any possible step # Default return value is the original path; it will be overwritten if a match is found latest_ckpt_path = load_dir try: # Regex: Matches start (^), escaped prefix, one or more digits (\d+), end ($) pattern_str = rf"^{re.escape(prefix)}(\d+)$" pattern = re.compile(pattern_str) except re.error as e: # Invalid regex prefix print(f"Error: Invalid prefix '{prefix}' resulted in a regex error: {e}") return None try: # Iterate through all entries in load_dir for item_name in os.listdir(load_dir): item_path = os.path.join(load_dir, item_name) # Check if the current item is a directory if os.path.isdir(item_path): match = pattern.match(item_name) if match: try: # Extract the numeric part step_num = int(match.group(1)) # If the current step number is greater than the recorded latest step if step_num > latest_step: latest_step = step_num # Update to the path of the latest checkpoint directory latest_ckpt_path = item_path except ValueError: # If \d+ matched but couldn't be converted to int (should not happen theoretically) print(f"Warning: Could not parse step number from directory name '{item_name}'.") # Skip this directory continue except OSError as e: # Catch potential OS errors during listdir (e.g., permission issues) print(f"Error: An OS error occurred while accessing directory '{load_dir}': {e}") return None # If a matching directory was found, latest_ckpt_path was updated to the latest one. # If none was found, latest_ckpt_path remains the initial load_dir. return latest_ckpt_path