The Training SDK provides classes for configuring and managing machine learning model training jobs on Baseten. This reference documents the key classes used to define training configurations.

Deploy a TrainingJob

To deploy a training job, use the following command:

truss train deploy <path_to_config_file>

The following classes are used to configure and deploy training jobs:

TrainingJob

Defines a complete training job configuration.

class TrainingJob:
    image: Image       # Container image configuration
    compute: Compute   # Compute resource configuration
    runtime: Runtime   # Runtime environment configuration

Example usage:

training_job = TrainingJob(
    image=Image(base_image="pytorch/pytorch:2.0.0-cuda11.7-cudnn8-runtime"),
    compute=Compute(cpu_count=4, memory="16Gi"),
    runtime=Runtime(
        start_commands=["python train.py"],
        checkpointing_config=CheckpointingConfig(enabled=True),
        enable_cache=True,
    )
)

TrainingProject

Organizes training jobs and provides project-level configuration.

class TrainingProject:
    name: str          # Project name
    job: TrainingJob   # Training job configuration

Example usage:

project = TrainingProject(
    name="llm-fine-tuning",
    job=training_job
)

Image

Specifies the container image for the training environment.

class Image:
    base_image: str  # Docker image to use for training

Example usage:

image = Image(base_image="pytorch/pytorch:2.0.0-cuda11.7-cudnn8-runtime")

Compute

Specifies compute resources for training jobs.

class Compute:
    node_count: int = 1      # Number of nodes for distributed training
    cpu_count: int = 1       # Number of CPU cores
    memory: str = "2Gi"      # Memory allocation
    accelerator: Optional[AcceleratorSpec] = None  # GPU configuration

Example usage:

# Configure compute resources with 4 H100 GPUs
compute = Compute(
    accelerator=AcceleratorSpec(accelerator="H100", count=4)
)

Runtime

Defines the runtime environment for training jobs.

class Runtime:
    start_commands: List[str] = []                                      # Commands to run at job start
    environment_variables: Dict[str, Union[str, SecretReference]] = {}  # Plain strings or SecretReference values
    enable_cache: bool = False                                          # Enable the training cache
    checkpointing_config: CheckpointingConfig = CheckpointingConfig()   # Checkpointing behavior

Example usage:

runtime = Runtime(
    start_commands=["python train.py"],
    environment_variables={
        "BATCH_SIZE": "32",
        "WANDB_API_KEY": SecretReference(name="WANDB_KEY")
    },
    checkpointing_config=CheckpointingConfig(enabled=True)
)

Training Cache

Setting enable_cache=True in your Runtime enables the training cache.

The cache will be mounted at two locations:

  • /root/.cache/huggingface
  • $BT_RW_CACHE_DIR - Baseten will export this variable in your job’s environment.

The cache storage is separate from your training job’s ephemeral storage limits. Training Projects provide storage segregation within the cache: training jobs within the same project share the same cache, while training jobs in different projects cannot access each other’s data.
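In training code, the two mount points above can be resolved with a small helper. The following is a sketch; the local fallback path is an assumption for running outside Baseten, not platform behavior:

```python
import os

def cache_dirs(env=None):
    """Return the Hugging Face cache path and the read-write cache path.

    The Hugging Face path is a fixed mount location; BT_RW_CACHE_DIR is
    exported by Baseten. The fallback is only for local runs.
    """
    env = os.environ if env is None else env
    hf_cache = "/root/.cache/huggingface"
    rw_cache = env.get("BT_RW_CACHE_DIR", "/tmp/rw_cache")
    return hf_cache, rw_cache

# Example: place a dataset download under the read-write cache so that
# later jobs in the same Training Project can reuse it.
hf_cache, rw_cache = cache_dirs()
dataset_dir = os.path.join(rw_cache, "datasets", "my-corpus")
```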

SecretReference

Used to securely reference secrets stored in your Baseten workspace.

class SecretReference:
    name: str  # Name of the secret in your workspace

Example usage:

# Reference a secret named "WANDB_API_KEY" 
secret_ref = SecretReference(name="WANDB_API_KEY")

CheckpointingConfig

Configures model checkpointing behavior during training. Baseten exports the $BT_CHECKPOINT_DIR environment variable within the Training Job’s environment. The checkpointing storage is independent of the pod’s ephemeral storage.

class CheckpointingConfig:
    enabled: bool = False              # Enable/disable checkpointing
    checkpoint_path: Optional[str] = None  # Custom checkpoint directory path

Example usage:

# Enable checkpointing with default path
checkpointing = CheckpointingConfig(enabled=True)
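Inside the job, training code can then write checkpoints under $BT_CHECKPOINT_DIR. A sketch follows; the per-step subdirectory naming and the local fallback path are assumptions, not requirements:

```python
import os

def checkpoint_dir(step, env=None):
    """Return (and create) a per-step directory under the checkpoint mount.

    BT_CHECKPOINT_DIR is exported by Baseten when checkpointing is enabled;
    the fallback is only for local runs.
    """
    env = os.environ if env is None else env
    base = env.get("BT_CHECKPOINT_DIR", "/tmp/checkpoints")
    path = os.path.join(base, f"step-{step}")
    os.makedirs(path, exist_ok=True)
    return path
```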

Baseten Provided Environment Variables

Baseten automatically provides several environment variables in your training job’s environment to help integrate your code with the Baseten platform.

Environment Variables

Environment Variable   Description                                                   Example
BT_TRAINING_JOB_ID     ID of the Training Job                                        "gvpql31"
BT_NUM_GPUS            Number of available GPUs per node                             "4"
BT_RW_CACHE_DIR        Non-HuggingFace cache directory of the training cache mount   "/root/.cache/user_artifacts"
BT_CHECKPOINT_DIR      Directory of the automated checkpointing mount                "/tmp/checkpoints"
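These variables can be read once at the top of train.py. The sketch below collects them into a dictionary; the local-run fallback values are assumptions, not Baseten behavior:

```python
import os

def bt_env(env=None):
    """Collect Baseten-provided variables, with fallbacks for local runs."""
    env = os.environ if env is None else env
    return {
        "job_id": env.get("BT_TRAINING_JOB_ID", "local"),
        "num_gpus": int(env.get("BT_NUM_GPUS", "0")),
        "rw_cache_dir": env.get("BT_RW_CACHE_DIR", "/tmp/rw_cache"),
        "checkpoint_dir": env.get("BT_CHECKPOINT_DIR", "/tmp/checkpoints"),
    }
```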

Multinode Environment Variables

The following environment variables are particularly useful for multinode training jobs:

Environment Variable   Description                                   Example
BT_GROUP_SIZE          Number of nodes in the multinode deployment   "2"
BT_LEADER_ADDR         Address of the leader node                    "10.0.0.1"
BT_NODE_RANK           Rank of the node                              "0"

For multinode deployments, any commonly used port number (e.g. 29500) will work; Baseten does not require a specific port.
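These variables map naturally onto the rendezvous flags of PyTorch's torchrun launcher. The sketch below builds those flags; using torchrun and the port 29500 are assumptions for illustration:

```python
import os

def torchrun_args(env=None, port=29500):
    """Build torchrun distributed flags from Baseten's multinode variables.

    Fallbacks correspond to a single-node local run.
    """
    env = os.environ if env is None else env
    return [
        f"--nnodes={env.get('BT_GROUP_SIZE', '1')}",
        f"--node_rank={env.get('BT_NODE_RANK', '0')}",
        f"--master_addr={env.get('BT_LEADER_ADDR', '127.0.0.1')}",
        f"--master_port={port}",
    ]
```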

Deploy Checkpoints as a Model

These classes are used with the following command:

truss train deploy_checkpoints <path_to_config_file>

DeployCheckpointsRuntime

Configures the runtime environment for deployed checkpoints.

class DeployCheckpointsRuntime:
    environment_variables: Dict[str, Union[str, SecretReference]] = {}
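
Example usage (a sketch; HF_TOKEN and the secret name are illustrative, not required names):

```python
runtime = DeployCheckpointsRuntime(
    environment_variables={
        # Secret names refer to secrets stored in your Baseten workspace.
        "HF_TOKEN": SecretReference(name="hf_access_token"),
    }
)
```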

Checkpoint

Represents metadata for a saved model checkpoint.

class Checkpoint:
    training_job_id: str             # ID of the training job
    id: str                          # Checkpoint ID
    name: str                        # Checkpoint name
    lora_rank: Optional[int] = None  # LoRA rank if applicable; auto-detected if not specified

CheckpointList

Manages a collection of checkpoints and their download configuration.

class CheckpointList:
    download_folder: str = "training_checkpoints"  # Local download location upon deployment
    base_model_id: Optional[str] = None            # Base model identifier; auto-detected if not specified
    checkpoints: List[Checkpoint] = []             # List of checkpoints

DeployCheckpointsConfig

Specifies configuration for deploying trained model checkpoints.

class DeployCheckpointsConfig:
    checkpoint_details: Optional[CheckpointList] = None  # Checkpoints to deploy
    model_name: Optional[str] = None                    # Name for the deployed model
    deployment_name: Optional[str] = None               # Name for the deployment
    runtime: Optional[DeployCheckpointsRuntime] = None  # Runtime configuration
    compute: Optional[Compute] = None                   # Compute resources

Example usage:

deploy_config = DeployCheckpointsConfig(
    model_name="fine-tuned-llm",
    deployment_name="production-llm",
    checkpoint_details=CheckpointList(
        checkpoints=[
            Checkpoint(
                training_job_id="gvpql31",
                id="checkpoint_1",
                name="checkpoint_1"
            )
        ]
    ),
    compute=Compute(
        accelerator=AcceleratorSpec(accelerator="H100", count=1)
    )
)