Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions src/instructlab/training/batch_loss_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
"""

# Standard
from dataclasses import dataclass
from collections.abc import Callable
from dataclasses import dataclass, field
import logging

# Third Party
Expand All @@ -33,6 +34,7 @@ class BatchMetrics:
accumulated_aux_loss: torch.Tensor | None
grad_accum_steps: int
num_minibatches: int
interrupted: bool = field(default=False)


class BatchLossManager:
Expand Down Expand Up @@ -62,12 +64,21 @@ def __init__(self, model, accelerator, world_size: int, local_rank: int):
self.local_rank: int = local_rank
self.torch_device = torch.device("cuda", local_rank)

def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float]:
def process_batch(
self,
batch: list[CollatedItem],
interrupt_check: Callable[[], bool] | None = None,
) -> tuple[BatchMetrics, float]:
"""
Process a batch of minibatches, computing losses and accumulating gradients.

Args:
batch: List of minibatches to process
interrupt_check: Optional callback invoked after each minibatch's
backward pass. If it returns ``True``, gradient accumulation
stops early and ``BatchMetrics.interrupted`` is set. Used by
on-demand checkpointing to react within one fwd+bwd cycle
instead of waiting for the full optimizer step.

Returns:
tuple: (BatchMetrics, average_loss_across_ranks)
Expand All @@ -82,6 +93,7 @@ def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float]
accumulated_loss = 0.0
accumulated_aux_loss = 0.0
grad_accum_steps = 0
interrupted = False

# process each minibatch
for mb in batch:
Expand All @@ -108,6 +120,11 @@ def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float]
if raw_losses.aux_loss is not None:
accumulated_aux_loss += raw_losses.aux_loss

# check for early exit (e.g. on-demand checkpoint requested)
if interrupt_check is not None and interrupt_check():
interrupted = True
break

# reduce metrics across ranks
batch_total_samples, batch_total_length = self._reduce_metrics(
batch_total_samples, batch_total_length
Expand All @@ -127,6 +144,7 @@ def process_batch(self, batch: list[CollatedItem]) -> tuple[BatchMetrics, float]
accumulated_aux_loss=accumulated_aux_loss,
grad_accum_steps=grad_accum_steps,
num_minibatches=num_minibatches,
interrupted=interrupted,
)

return metrics, avg_loss_across_ranks
Expand Down
13 changes: 13 additions & 0 deletions src/instructlab/training/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,19 @@ class TrainingArgs(BaseModel):
description="How often to evaluate validation loss (in training steps). Required when validation_split > 0.",
)

on_demand_checkpointing: bool = Field(
default=False,
description=(
"Enable on-demand full-state checkpointing triggered by Unix signals. "
"When enabled, the parent process intercepts termination signals "
"(SIGTERM, SIGINT, SIGUSR1, SIGUSR2, SIGXCPU, SIGHUP) and writes a "
"trigger file to /dev/shm. Worker processes check for this trigger "
"after each minibatch backward pass and collectively save a distributed "
"checkpoint before exiting gracefully. Designed for OpenShift AI / "
"KubeFlow training jobs where preemption signals must be handled."
),
)

@model_validator(mode="after")
def validate_validation_config(self):
if not 0.0 <= self.validation_split < 1.0:
Expand Down
161 changes: 142 additions & 19 deletions src/instructlab/training/main_ds.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ def train(
accelerator: Accelerator,
val_data_loader=None,
validation_frequency=None,
on_demand_checkpointing: bool = False,
):
model.train()

Expand All @@ -183,6 +184,18 @@ def train(
metric_logger = logging.getLogger("instructlab.training.metrics")
base_logger = logging.getLogger("instructlab.training")

# Import on-demand checkpointing utilities once if the feature is enabled
checkpoint_job_id = None
if on_demand_checkpointing:
# First Party
from instructlab.training.on_demand_checkpoint import (
check_checkpoint_requested,
save_on_demand_checkpoint,
)

checkpoint_job_id = os.environ.get("INSTRUCTLAB_ON_DEMAND_JOB_ID")
base_logger.info("On-demand checkpointing is enabled in worker process.")

# Mini_trainer approach: batch_size will be determined dynamically by data loader
# For save logic, use effective_batch_size since that's the target
samples_seen = 0
Expand Down Expand Up @@ -220,11 +233,38 @@ def train(
continue
start = time.time()

# Process the batch using the BatchLossManager
# Process the batch using the BatchLossManager.
# When on-demand checkpointing is enabled, pass a callback so
# the check runs after every minibatch backward rather than
# waiting for the full optimizer step.
_interrupt_check = (
(lambda: check_checkpoint_requested(checkpoint_job_id))
if on_demand_checkpointing
else None
)
batch_metrics, avg_loss_across_ranks = batch_loss_manager.process_batch(
batch
batch, interrupt_check=_interrupt_check
)

# If the batch was interrupted by an on-demand checkpoint
# request, save immediately and exit — skip the optimizer step
# since we want to preserve the pre-step model state for
# exact resumption.
if batch_metrics.interrupted:
save_on_demand_checkpoint(
args=args,
accelerator=accelerator,
model=model,
tokenizer=model.tokenizer,
samples_seen=samples_seen,
epoch=epoch,
is_lora=bool(args.lora_r),
)
base_logger.info(
"On-demand checkpoint saved. Exiting training gracefully."
)
return

# Update samples seen
samples_seen += batch_metrics.total_samples

Expand Down Expand Up @@ -561,6 +601,7 @@ def main(args):
accelerator=accelerator,
val_data_loader=val_loader,
validation_frequency=validation_frequency,
on_demand_checkpointing=getattr(args, "on_demand_checkpointing", False),
)

dist.barrier()
Expand Down Expand Up @@ -791,7 +832,29 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
if train_args.keep_last_checkpoint_only:
command.append("--keep_last_checkpoint_only")

if train_args.on_demand_checkpointing:
command.append("--on_demand_checkpointing")

logger.info("Running training command as subprocess: %s", " ".join(command))

# --- On-demand checkpointing: install signal handlers in the parent ---
signal_handler = None
if train_args.on_demand_checkpointing:
# First Party
from instructlab.training.on_demand_checkpoint import ParentSignalHandler

# Use rdzv_id to namespace the trigger file so concurrent jobs
# sharing /dev/shm don't interfere with each other.
checkpoint_job_id = str(torch_args.rdzv_id)
os.environ["INSTRUCTLAB_ON_DEMAND_JOB_ID"] = checkpoint_job_id
signal_handler = ParentSignalHandler(job_id=checkpoint_job_id)
signal_handler.install()
logger.info(
"On-demand checkpointing is ENABLED (job_id=%s). "
"Termination signals will trigger a full-state checkpoint before exit.",
checkpoint_job_id,
)

process = None
interrupt: KeyboardInterrupt | Exception | None = None
failure = False
Expand All @@ -811,36 +874,85 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
interrupt = e
finally:
if "process" not in locals() or process is None:
if signal_handler is not None:
signal_handler.uninstall()
return

# wait for the process to exit so we can properly read the exit code
process.wait(timeout=60)
process_code = process.poll()
failure = process_code != 0

if not failure:
logger.info("Operation completed successfully! 🎉")
else:
logger.error(
f"Training subprocess has not exited yet. Sending SIGTERM. Process code: {process_code}"
# If a signal was caught by the on-demand checkpoint handler, give
# the workers time to detect the trigger file and save a checkpoint
# before we start sending our own signals to the subprocess.
if signal_handler is not None and signal_handler.signal_received is not None:
logger.info(
"On-demand checkpoint: signal %s received. Waiting for workers to "
"save checkpoint before proceeding with shutdown...",
signal_handler.signal_received.name,
)
# Give workers generous time to complete the checkpoint save.
# The workers will exit on their own after saving.
try:
process.wait(timeout=300)
except subprocess.TimeoutExpired:
logger.warning(
"On-demand checkpoint: workers did not finish within 300s. "
"Proceeding with shutdown."
)

process.terminate()
# wait for the process to exit so we can properly read the exit code
try:
logger.info("Waiting for process to exit, 60s...")
process.wait(timeout=60)
except subprocess.TimeoutExpired:
pass
process_code = process.poll()

if process_code is not None and process_code == 0:
logger.info("Operation completed successfully!")
elif process_code is None:
logger.error("Training subprocess has not exited yet. Sending SIGTERM.")
process.terminate()
try:
logger.info("Waiting for process to exit, 60s...")
process.wait(timeout=60)
except subprocess.TimeoutExpired:
logger.error(
"Training subprocess did not terminate before timeout, sending SIGKILL."
)
process.kill()
try:
process.wait(timeout=10)
except subprocess.TimeoutExpired:
pass
else:
logger.error(
"Training subprocess did not terminate before timeout, sending SIGKILL."
"Training subprocess exited with code %d.",
process_code,
)
process.kill()

# Recompute final exit status after any forced shutdown
process_code = process.poll()
failure = process_code is None or process_code != 0

if signal_handler is not None:
signal_handler.uninstall()

if interrupt:
raise interrupt
if failure:
raise RuntimeError(
"Suffered a failure during distributed training. Please see the training logs for more context."
)
msg = "Suffered a failure during distributed training. Please see the training logs for more context."
if (
signal_handler is not None
and signal_handler.signal_received is not None
):
msg += (
f"\n\nNote: signal {signal_handler.signal_received.name} was"
" received and on-demand checkpointing was enabled, but the"
" training subprocess did not exit cleanly. This usually"
" means the process was killed (SIGKILL) before the"
" checkpoint could be saved. To fix this, increase"
" terminationGracePeriodSeconds in your pod spec to give"
" workers more time, or reduce the model's forward/backward"
" pass time so the checkpoint check fires sooner."
)
raise RuntimeError(msg)


if __name__ == "__main__":
Expand Down Expand Up @@ -1045,6 +1157,17 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
),
)

parser.add_argument(
"--on_demand_checkpointing",
action="store_true",
default=False,
help=(
"Enable on-demand full-state checkpointing triggered by Unix signals. "
"When enabled, workers check for a trigger file in /dev/shm after each "
"minibatch backward pass and collectively save a distributed checkpoint before "
"exiting. Designed for OpenShift AI / KubeFlow preemption handling."
),
)
parser.add_argument(
"--use_liger",
action="store_true",
Expand Down
Loading
Loading