Checkpoint loading lets you resume training from previously saved model states. When enabled, Baseten automatically downloads your specified checkpoints to the training environment before your training code starts.
Use cases:
- Resume failed training jobs
- Incremental training and fine-tuning
Accessing Downloaded Checkpoints
Checkpoints are available through the BT_LOAD_CHECKPOINT_DIR environment variable. For single-node training, they’re located in BT_LOAD_CHECKPOINT_DIR/rank-0/.
Checkpoint restoration currently does not support loading weights that are sharded across multiple nodes.
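As a quick orientation, here is a minimal sketch of how training code can locate the downloaded files; the rank-0 subdirectory follows the single-node layout described above, and the directory listing is only illustrative:
import os

# Minimal sketch: find the checkpoints Baseten downloaded before training started
load_dir = os.environ.get("BT_LOAD_CHECKPOINT_DIR")
if load_dir:
    # Single-node layout: downloaded checkpoints live under rank-0/
    rank0_dir = os.path.join(load_dir, "rank-0")
    print("Downloaded checkpoints:", os.listdir(rank0_dir))
else:
    print("Checkpoint loading is not configured for this job")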
Checkpoint Reference
Create references to checkpoints using the BasetenCheckpoint factory:
From Latest
# Load the latest checkpoint from a project
BasetenCheckpoint.from_latest_checkpoint(project_name="my-training-project")
# Load the latest checkpoint from a previous job
BasetenCheckpoint.from_latest_checkpoint(job_id="gvpql31")
Parameters:
- project_name: Load the latest checkpoint from the most recent job in this project
- job_id: Load the latest checkpoint from this specific job
- Both parameters: Load the latest checkpoint from that specific job in that project (see the example below)
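For instance, a call that scopes the lookup with both parameters might look like this (the project name and job ID are placeholders reused from the examples above):
# Latest checkpoint from job "gvpql31" within "my-training-project"
BasetenCheckpoint.from_latest_checkpoint(project_name="my-training-project", job_id="gvpql31")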
From Named
# Pin your starting point to a specific checkpoint
BasetenCheckpoint.from_named_checkpoint(checkpoint_name="checkpoint-20", job_id="gvpql31")
Parameters:
- checkpoint_name: The name of the specific checkpoint to load
- job_id: The job that contains the named checkpoint
- Both parameters: Load the named checkpoint from that specific job
Configuration Examples
Here are practical examples of how to configure checkpoint loading in your training jobs:
From Latest
# Latest checkpoint from project
load_config = LoadCheckpointConfig(
    enabled=True,
    checkpoints=[
        BasetenCheckpoint.from_latest_checkpoint(project_name="gpt-finetuning")
    ]
)

# Latest checkpoint from specific job
load_config = LoadCheckpointConfig(
    enabled=True,
    checkpoints=[
        BasetenCheckpoint.from_latest_checkpoint(job_id="gvpql31")
    ]
)
From Named
# Specific named checkpoint
load_config = LoadCheckpointConfig(
    enabled=True,
    checkpoints=[
        BasetenCheckpoint.from_named_checkpoint(
            checkpoint_name="checkpoint-20",
            job_id="gvpql31"
        )
    ]
)

# Named checkpoint with custom download location
load_config = LoadCheckpointConfig(
    enabled=True,
    download_folder="/tmp/my_checkpoints",
    checkpoints=[
        BasetenCheckpoint.from_named_checkpoint(
            checkpoint_name="checkpoint-20",
            job_id="rwnojdq"
        )
    ]
)
Configuration parameters for LoadCheckpointConfig:
- enabled: Set to True to enable checkpoint loading
- checkpoints: List containing checkpoint references (a multi-reference sketch follows this list)
- download_folder: Optional custom download location (defaults to /tmp/loaded_checkpoints)
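Because checkpoints accepts a list, it can in principle hold more than one reference; the sketch below only illustrates that shape (the checkpoint name and job IDs are placeholders from earlier examples) and is not a statement about how multiple downloads are laid out on disk:
# Hypothetical sketch: several checkpoint references in a single config
load_config = LoadCheckpointConfig(
    enabled=True,
    download_folder="/tmp/loaded_checkpoints",
    checkpoints=[
        BasetenCheckpoint.from_named_checkpoint(checkpoint_name="checkpoint-20", job_id="gvpql31"),
        BasetenCheckpoint.from_latest_checkpoint(job_id="rwnojdq"),
    ]
)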
Complete TrainingJob Setup
from truss_train import LoadCheckpointConfig, BasetenCheckpoint, CheckpointingConfig, TrainingJob, Image, Runtime, TrainingProject
from truss_train.definitions import CacheConfig
# Configure checkpoint loading
load_checkpoint_config = LoadCheckpointConfig(
    enabled=True,
    download_folder="/tmp/loaded_checkpoints",
    checkpoints=[
        BasetenCheckpoint.from_latest_checkpoint(job_id="previous_job_id")
    ]
)

# Configure checkpointing for saving new checkpoints
checkpointing_config = CheckpointingConfig(
    enabled=True,
    checkpoint_path="/tmp/training_checkpoints"
)

# Create TrainingJob
job = TrainingJob(
    image=Image(base_image="your-base-image"),
    runtime=Runtime(
        checkpointing_config=checkpointing_config,
        load_checkpoint_config=load_checkpoint_config,
        start_commands=["chmod +x ./run.sh && ./run.sh"],
        cache_config=CacheConfig(enabled=True)
    ),
)

project = TrainingProject(name="my-training-project", job=job)
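Note that load_checkpoint_config and checkpointing_config play complementary roles: the former seeds the job with weights from a previous run, while the latter saves new checkpoints under checkpoint_path so a later job can resume from this one. With the TrainingProject defined, submit the job through the truss CLI as usual; the exact invocation depends on your truss version, typically a truss train push style command pointed at this config file.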
Using Checkpoints in Your Training Code
Access loaded checkpoints using the BT_LOAD_CHECKPOINT_DIR environment variable:
from transformers import AutoModelForSequenceClassification, AutoTokenizer, TrainingArguments, Trainer
from transformers.trainer_utils import get_last_checkpoint
import os
def train():
    # Directory that Baseten populated with downloaded checkpoints (if configured)
    checkpoint_dir = os.environ.get("BT_LOAD_CHECKPOINT_DIR")
    last_checkpoint = None
    if checkpoint_dir:
        # get_last_checkpoint looks for checkpoint-* folders directly under this path;
        # for single-node jobs the checkpoints may sit in a rank-0/ subdirectory (see above)
        last_checkpoint = get_last_checkpoint(checkpoint_dir)
        if last_checkpoint:
            print(f"✅ Resuming from checkpoint: {last_checkpoint}")
            model = AutoModelForSequenceClassification.from_pretrained(last_checkpoint)
            # Assumes the tokenizer was saved at the root of the downloaded checkpoint directory
            tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)
        else:
            print("⚠️ No checkpoint found, starting from scratch")
            model = AutoModelForSequenceClassification.from_pretrained("your-base-model")
            tokenizer = AutoTokenizer.from_pretrained("your-base-model")
    else:
        print("ℹ️ No checkpoint loading configured")
        model = AutoModelForSequenceClassification.from_pretrained("your-base-model")
        tokenizer = AutoTokenizer.from_pretrained("your-base-model")

    training_args = TrainingArguments(
        output_dir=os.environ.get("BT_CHECKPOINT_DIR", "/tmp/training_checkpoints"),
        save_strategy="steps",
        save_steps=1000,
        # Note: load_best_model_at_end=True also requires a matching eval strategy and an eval dataset
        load_best_model_at_end=True,
    )
    # Dataset, metrics, and data collator arguments are omitted for brevity
    trainer = Trainer(model=model, args=training_args)
    trainer.train(resume_from_checkpoint=last_checkpoint)
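When no checkpoint is found, last_checkpoint stays None and trainer.train(resume_from_checkpoint=None) simply starts training from scratch, so the same script handles both fresh and resumed runs.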