diff --git a/src/instructlab/training/batch_loss_manager.py b/src/instructlab/training/batch_loss_manager.py index cc6da021..b199ac17 100644 --- a/src/instructlab/training/batch_loss_manager.py +++ b/src/instructlab/training/batch_loss_manager.py @@ -7,7 +7,8 @@ """ # Standard -from dataclasses import dataclass +from collections.abc import Callable +from dataclasses import dataclass, field import logging # Third Party @@ -33,6 +34,7 @@ class BatchMetrics: accumulated_aux_loss: torch.Tensor | None grad_accum_steps: int num_minibatches: int + interrupted: bool = field(default=False) class BatchLossManager: @@ -62,12 +64,21 @@ def __init__(self, model, accelerator, world_size: int, local_rank: int): self.local_rank: int = local_rank self.torch_device = torch.device("cuda", local_rank) - def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float]: + def process_batch( + self, + batch: list[CollatedItem], + interrupt_check: Callable[[], bool] | None = None, + ) -> tuple[BatchMetrics, float]: """ Process a batch of minibatches, computing losses and accumulating gradients. Args: batch: List of minibatches to process + interrupt_check: Optional callback invoked after each minibatch's + backward pass. If it returns ``True``, gradient accumulation + stops early and ``BatchMetrics.interrupted`` is set. Used by + on-demand checkpointing to react within one fwd+bwd cycle + instead of waiting for the full optimizer step. Returns: tuple: (BatchMetrics, average_loss_across_ranks) @@ -82,6 +93,7 @@ def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float] accumulated_loss = 0.0 accumulated_aux_loss = 0.0 grad_accum_steps = 0 + interrupted = False # process each minibatch for mb in batch: @@ -108,6 +120,11 @@ def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float] if raw_losses.aux_loss is not None: accumulated_aux_loss += raw_losses.aux_loss + # check for early exit (e.g. 
on-demand checkpoint requested) + if interrupt_check is not None and interrupt_check(): + interrupted = True + break + # reduce metrics across ranks batch_total_samples, batch_total_length = self._reduce_metrics( batch_total_samples, batch_total_length @@ -127,6 +144,7 @@ def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float] accumulated_aux_loss=accumulated_aux_loss, grad_accum_steps=grad_accum_steps, num_minibatches=num_minibatches, + interrupted=interrupted, ) return metrics, avg_loss_across_ranks diff --git a/src/instructlab/training/config.py b/src/instructlab/training/config.py index 911c3898..11182a0e 100644 --- a/src/instructlab/training/config.py +++ b/src/instructlab/training/config.py @@ -351,6 +351,19 @@ class TrainingArgs(BaseModel): description="How often to evaluate validation loss (in training steps). Required when validation_split > 0.", ) + on_demand_checkpointing: bool = Field( + default=False, + description=( + "Enable on-demand full-state checkpointing triggered by Unix signals. " + "When enabled, the parent process intercepts termination signals " + "(SIGTERM, SIGINT, SIGUSR1, SIGUSR2, SIGXCPU, SIGHUP) and writes a " + "trigger file to /dev/shm. Worker processes check for this trigger " + "after each minibatch backward pass and collectively save a distributed " + "checkpoint before exiting gracefully. Designed for OpenShift AI / " + "KubeFlow training jobs where preemption signals must be handled." 
+ ), + ) + @model_validator(mode="after") def validate_validation_config(self): if not 0.0 <= self.validation_split < 1.0: diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index 072b27c6..91f73008 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -173,6 +173,7 @@ def train( accelerator: Accelerator, val_data_loader=None, validation_frequency=None, + on_demand_checkpointing: bool = False, ): model.train() @@ -183,6 +184,18 @@ def train( metric_logger = logging.getLogger("instructlab.training.metrics") base_logger = logging.getLogger("instructlab.training") + # Import on-demand checkpointing utilities once if the feature is enabled + checkpoint_job_id = None + if on_demand_checkpointing: + # First Party + from instructlab.training.on_demand_checkpoint import ( + check_checkpoint_requested, + save_on_demand_checkpoint, + ) + + checkpoint_job_id = os.environ.get("INSTRUCTLAB_ON_DEMAND_JOB_ID") + base_logger.info("On-demand checkpointing is enabled in worker process.") + # Mini_trainer approach: batch_size will be determined dynamically by data loader # For save logic, use effective_batch_size since that's the target samples_seen = 0 @@ -220,11 +233,38 @@ def train( continue start = time.time() - # Process the batch using the BatchLossManager + # Process the batch using the BatchLossManager. + # When on-demand checkpointing is enabled, pass a callback so + # the check runs after every minibatch backward rather than + # waiting for the full optimizer step. 
+ _interrupt_check = ( + (lambda: check_checkpoint_requested(checkpoint_job_id)) + if on_demand_checkpointing + else None + ) batch_metrics, avg_loss_across_ranks = batch_loss_manager.process_batch( - batch + batch, interrupt_check=_interrupt_check ) + # If the batch was interrupted by an on-demand checkpoint + # request, save immediately and exit — skip the optimizer step + # since we want to preserve the pre-step model state for + # exact resumption. + if batch_metrics.interrupted: + save_on_demand_checkpoint( + args=args, + accelerator=accelerator, + model=model, + tokenizer=model.tokenizer, + samples_seen=samples_seen, + epoch=epoch, + is_lora=bool(args.lora_r), + ) + base_logger.info( + "On-demand checkpoint saved. Exiting training gracefully." + ) + return + # Update samples seen samples_seen += batch_metrics.total_samples @@ -561,6 +601,7 @@ def main(args): accelerator=accelerator, val_data_loader=val_loader, validation_frequency=validation_frequency, + on_demand_checkpointing=getattr(args, "on_demand_checkpointing", False), ) dist.barrier() @@ -791,7 +832,29 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: if train_args.keep_last_checkpoint_only: command.append("--keep_last_checkpoint_only") + if train_args.on_demand_checkpointing: + command.append("--on_demand_checkpointing") + logger.info("Running training command as subprocess: %s", " ".join(command)) + + # --- On-demand checkpointing: install signal handlers in the parent --- + signal_handler = None + if train_args.on_demand_checkpointing: + # First Party + from instructlab.training.on_demand_checkpoint import ParentSignalHandler + + # Use rdzv_id to namespace the trigger file so concurrent jobs + # sharing /dev/shm don't interfere with each other. 
+ checkpoint_job_id = str(torch_args.rdzv_id) + os.environ["INSTRUCTLAB_ON_DEMAND_JOB_ID"] = checkpoint_job_id + signal_handler = ParentSignalHandler(job_id=checkpoint_job_id) + signal_handler.install() + logger.info( + "On-demand checkpointing is ENABLED (job_id=%s). " + "Termination signals will trigger a full-state checkpoint before exit.", + checkpoint_job_id, + ) + process = None interrupt: KeyboardInterrupt | Exception | None = None failure = False @@ -811,36 +874,85 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: interrupt = e finally: if "process" not in locals() or process is None: + if signal_handler is not None: + signal_handler.uninstall() return - # wait for the process to exit so we can properly read the exit code - process.wait(timeout=60) - process_code = process.poll() - failure = process_code != 0 - - if not failure: - logger.info("Operation completed successfully! 🎉") - else: - logger.error( - f"Training subprocess has not exited yet. Sending SIGTERM. Process code: {process_code}" + # If a signal was caught by the on-demand checkpoint handler, give + # the workers time to detect the trigger file and save a checkpoint + # before we start sending our own signals to the subprocess. + if signal_handler is not None and signal_handler.signal_received is not None: + logger.info( + "On-demand checkpoint: signal %s received. Waiting for workers to " + "save checkpoint before proceeding with shutdown...", + signal_handler.signal_received.name, ) + # Give workers generous time to complete the checkpoint save. + # The workers will exit on their own after saving. + try: + process.wait(timeout=300) + except subprocess.TimeoutExpired: + logger.warning( + "On-demand checkpoint: workers did not finish within 300s. " + "Proceeding with shutdown." 
+ ) - process.terminate() + # wait for the process to exit so we can properly read the exit code try: - logger.info("Waiting for process to exit, 60s...") process.wait(timeout=60) except subprocess.TimeoutExpired: + pass + process_code = process.poll() + + if process_code is not None and process_code == 0: + logger.info("Operation completed successfully!") + elif process_code is None: + logger.error("Training subprocess has not exited yet. Sending SIGTERM.") + process.terminate() + try: + logger.info("Waiting for process to exit, 60s...") + process.wait(timeout=60) + except subprocess.TimeoutExpired: + logger.error( + "Training subprocess did not terminate before timeout, sending SIGKILL." + ) + process.kill() + try: + process.wait(timeout=10) + except subprocess.TimeoutExpired: + pass + else: logger.error( - "Training subprocess did not terminate before timeout, sending SIGKILL." + "Training subprocess exited with code %d.", + process_code, ) - process.kill() + + # Recompute final exit status after any forced shutdown + process_code = process.poll() + failure = process_code is None or process_code != 0 + + if signal_handler is not None: + signal_handler.uninstall() if interrupt: raise interrupt if failure: - raise RuntimeError( - "Suffered a failure during distributed training. Please see the training logs for more context." - ) + msg = "Suffered a failure during distributed training. Please see the training logs for more context." + if ( + signal_handler is not None + and signal_handler.signal_received is not None + ): + msg += ( + f"\n\nNote: signal {signal_handler.signal_received.name} was" + " received and on-demand checkpointing was enabled, but the" + " training subprocess did not exit cleanly. This usually" + " means the process was killed (SIGKILL) before the" + " checkpoint could be saved. 
To fix this, increase" + " terminationGracePeriodSeconds in your pod spec to give" + " workers more time, or reduce the model's forward/backward" + " pass time so the checkpoint check fires sooner." + ) + raise RuntimeError(msg) if __name__ == "__main__": @@ -1045,6 +1157,17 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None: ), ) + parser.add_argument( + "--on_demand_checkpointing", + action="store_true", + default=False, + help=( + "Enable on-demand full-state checkpointing triggered by Unix signals. " + "When enabled, workers check for a trigger file in /dev/shm after each " + "minibatch backward pass and collectively save a distributed checkpoint before " + "exiting. Designed for OpenShift AI / KubeFlow preemption handling." + ), + ) parser.add_argument( "--use_liger", action="store_true", diff --git a/src/instructlab/training/on_demand_checkpoint.py b/src/instructlab/training/on_demand_checkpoint.py new file mode 100644 index 00000000..b9643531 --- /dev/null +++ b/src/instructlab/training/on_demand_checkpoint.py @@ -0,0 +1,284 @@ +# SPDX-License-Identifier: Apache-2.0 + +""" +On-demand checkpointing for distributed training. + +This module enables graceful checkpoint-and-exit when termination signals are +received. It is designed for environments like OpenShift AI / KubeFlow where +training jobs can be preempted at any time and the platform sends Unix signals +before killing the pod. + +Architecture +------------ +There are two sides to this feature: + +**Parent process** (``run_training`` in ``main_ds.py``): + Installs signal handlers that catch every signal OpenShift / Kubernetes can + send before a SIGKILL. When a signal arrives the handler writes a small + *trigger file* to ``/dev/shm`` (a tmpfs shared between containers in the + same pod). Because ``/dev/shm`` is node-local, every worker on the **same + node** can see the file instantly with zero network I/O. 
+
+**Worker processes** (torchrun children):
+    After each minibatch's backward pass the training loop calls
+    ``check_checkpoint_requested()``. Each rank checks its local ``/dev/shm``
+    for the trigger file, converts the boolean to a tensor, and does an
+    ``all_reduce(MAX)`` so that if *any* rank on *any* node detected the
+    trigger, *every* rank agrees to save a checkpoint. This works correctly in
+    multi-node training because all_reduce is a global collective.
+
+Signals handled
+---------------
+We intercept every signal that Kubernetes / OpenShift can deliver before the
+hard SIGKILL (which cannot be caught):
+
+* **SIGTERM** – the standard graceful-shutdown signal. Kubernetes sends this
+  first (configurable via ``terminationGracePeriodSeconds``).
+* **SIGINT** – sent on Ctrl-C or by some job controllers.
+* **SIGUSR1 / SIGUSR2** – commonly used by batch schedulers and custom
+  preemption controllers to signal upcoming eviction.
+* **SIGXCPU** – sent when CPU time limits are exceeded (relevant for jobs
+  with resource quotas).
+* **SIGHUP** – sent when the controlling terminal disconnects; some
+  container runtimes forward this on pod eviction.
+"""
+
+# Standard
+from pathlib import Path
+from typing import Callable, Optional, Union
+import logging
+import os
+import signal
+import tempfile
+import types
+
+# Third Party
+import torch
+import torch.distributed as dist
+
+# Type alias matching the return type of signal.getsignal().
+_SignalHandler = Union[
+    Callable[[int, Optional[types.FrameType]], None], int, signal.Handlers, None
+]
+
+logger = logging.getLogger("instructlab.training")
+
+# ---------------------------------------------------------------------------
+# Trigger file helpers
+# ---------------------------------------------------------------------------
+
+# The trigger file lives in /dev/shm which is a tmpfs (RAM-backed filesystem).
+# It is:
+#    1. Extremely fast (no disk I/O).
+#    2. Shared between the launcher process and the torchrun workers it spawns (NOTE(review): separate containers see separate /dev/shm unless a shared memory-backed volume is mounted — confirm for multi-container pods).
+#    3. 
Automatically cleaned up when the pod is destroyed.
+_TRIGGER_DIR = Path("/dev/shm")
+_TRIGGER_FILENAME = "instructlab_checkpoint_requested"
+
+
+def _get_trigger_path(job_id: Optional[str] = None) -> Path:
+    """Return the path to the checkpoint trigger file.
+
+    An optional *job_id* can be supplied to avoid collisions if multiple
+    training jobs share the same ``/dev/shm`` (unlikely but possible).
+    """
+    name = f"{_TRIGGER_FILENAME}_{job_id}" if job_id else _TRIGGER_FILENAME
+    return _TRIGGER_DIR / name
+
+
+def write_trigger_file(job_id: Optional[str] = None) -> Path:
+    """Create the trigger file that tells workers to checkpoint.
+
+    This is called from the *parent* process signal handler.
+    Returns the path that was written.
+    """
+    path = _get_trigger_path(job_id)
+    # Use an atomic write via tempfile + rename to avoid partial reads.
+    fd, tmp = tempfile.mkstemp(dir=_TRIGGER_DIR, prefix=".ckpt_trigger_")
+    try:
+        os.write(fd, b"1")
+    finally:
+        os.close(fd)
+    os.rename(tmp, path)
+    logger.info(
+        "On-demand checkpoint trigger file written: %s",
+        path,
+    )
+    return path
+
+
+def trigger_file_exists(job_id: Optional[str] = None) -> bool:
+    """Check whether the trigger file exists (worker-side)."""
+    return _get_trigger_path(job_id).exists()
+
+
+def remove_trigger_file(job_id: Optional[str] = None) -> None:
+    """Remove the trigger file after the checkpoint has been saved."""
+    path = _get_trigger_path(job_id)
+    try:
+        path.unlink(missing_ok=True)
+    except OSError:
+        pass
+
+
+# ---------------------------------------------------------------------------
+# Parent-side signal handling
+# ---------------------------------------------------------------------------
+
+# Signals that OpenShift / Kubernetes / batch schedulers may send before
+# the hard SIGKILL. SIGKILL (9) and SIGSTOP (19) cannot be caught. 
+_CATCHABLE_SIGNALS = ( + signal.SIGTERM, # Kubernetes default graceful shutdown signal + signal.SIGINT, # Ctrl-C / some job controllers + signal.SIGUSR1, # Custom preemption controllers + signal.SIGUSR2, # Custom preemption controllers + signal.SIGXCPU, # CPU time limit exceeded (resource quotas) + signal.SIGHUP, # Terminal disconnect / some eviction paths +) + + +class ParentSignalHandler: + """Installs signal handlers in the parent (launcher) process. + + When any of the catchable signals fire, the handler: + 1. Writes the trigger file to ``/dev/shm``. + 2. Records that a signal was received (so the caller can decide to + wait for the child process to finish checkpointing). + + The handler is idempotent – multiple signals will not create multiple + trigger files. + + Parameters + ---------- + job_id : str, optional + Unique identifier for this training job. Used to namespace the + trigger file. + """ + + def __init__(self, job_id: Optional[str] = None): + self.job_id = job_id + self.signal_received: Optional[signal.Signals] = None + self._original_handlers: dict[signal.Signals, _SignalHandler] = {} + self._trigger_written = False + + def install(self) -> None: + """Register signal handlers for all catchable signals.""" + for sig in _CATCHABLE_SIGNALS: + try: + self._original_handlers[sig] = signal.getsignal(sig) + signal.signal(sig, self._handle) + except (OSError, ValueError): + # Some signals may not be available on all platforms + logger.debug("Could not install handler for %s", sig.name) + + logger.info( + "On-demand checkpoint signal handlers installed for: %s", + ", ".join(s.name for s in self._original_handlers), + ) + + def uninstall(self) -> None: + """Restore original signal handlers.""" + for sig, handler in self._original_handlers.items(): + try: + signal.signal(sig, handler) # type: ignore[arg-type] + except (OSError, ValueError): + pass + self._original_handlers.clear() + + def _handle(self, signum: int, _frame) -> None: + """Signal handler 
callback.""" + sig = signal.Signals(signum) + logger.info( + "On-demand checkpoint: received signal %s (%d). " + "Writing trigger file for workers to checkpoint before exit.", + sig.name, + signum, + ) + self.signal_received = sig + + if not self._trigger_written: + write_trigger_file(self.job_id) + self._trigger_written = True + + +# --------------------------------------------------------------------------- +# Worker-side synchronization +# --------------------------------------------------------------------------- + + +def check_checkpoint_requested(job_id: Optional[str] = None) -> bool: + """Check across all ranks whether an on-demand checkpoint was requested. + + This function must be called by **all ranks** at the same point in the + training loop (it contains a collective all_reduce). + + Returns ``True`` if any rank detected the trigger file, meaning all + ranks should save a checkpoint. + """ + local_trigger = trigger_file_exists(job_id) + + # Convert to a tensor and all-reduce (MAX) so that if ANY rank on ANY + # node saw the trigger, every rank gets True. + trigger_tensor = torch.tensor( + [1 if local_trigger else 0], + dtype=torch.int32, + device=torch.cuda.current_device(), + ) + dist.all_reduce(trigger_tensor, op=dist.ReduceOp.MAX) + + requested = trigger_tensor.item() > 0 + + if requested: + if dist.is_initialized() and dist.get_rank() == 0: + logger.info( + "On-demand checkpoint: global consensus reached – " + "all ranks will save a checkpoint." + ) + # Clean up the trigger file so that if the process somehow + # continues, we don't save again immediately. + remove_trigger_file(job_id) + + return requested + + +def save_on_demand_checkpoint( + args, + accelerator, + model, + tokenizer, + samples_seen: int, + epoch: int, + is_lora: bool, +) -> None: + """Save a full-state distributed checkpoint for on-demand resume. 
+ + This is a thin wrapper that calls the existing ``save_checkpoint`` + utility with ``full_state=True`` so that optimizer + LR scheduler + state are also persisted, enabling exact training resumption. + """ + # First Party + from instructlab.training.utils import save_checkpoint + + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + if local_rank == 0: + logger.info( + "On-demand checkpoint: saving full-state checkpoint at " + "epoch=%d, samples_seen=%d", + epoch, + samples_seen, + ) + + save_checkpoint( + args=args, + accelerator=accelerator, + model=model, + tokenizer=tokenizer, + samples_seen=samples_seen, + is_lora=is_lora, + full_state=True, + hf_format=True, + epoch=epoch, + ) + + if local_rank == 0: + logger.info("On-demand checkpoint: checkpoint saved successfully.")