TrainingClient talks directly to a dp_worker instance. Long-running operations use a submit-and-retrieve protocol: the submit request fires immediately on the calling thread, so validation errors surface at call time, and .result() long-polls the server until the operation finishes. You can submit multiple operations before awaiting any of them.
Construct one with ServiceClient.create_lora_training_client.
The example below runs one training step and saves a checkpoint. Each long-running call is a submit followed by .result():
from baseten.loops import Datum, ModelInput, TensorData, AdamParams

# tokens and targets come from tokenizing a masked prompt/answer pair;
# see the quickstart for the full tokenization step.
datum = Datum(
    model_input=ModelInput.from_ints(tokens),
    loss_fn_inputs={"target_tokens": TensorData(data=targets, dtype="int64", shape=[len(targets)])},
)

fb = training_client.forward_backward(data=[datum]).result(timeout=600.0)
training_client.optim_step(AdamParams(learning_rate=4e-5)).result(timeout=600.0)
save_resp = training_client.save_state(name="step-1").result(timeout=600.0)
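The submit-and-retrieve pattern can be modeled without a server. The sketch below is not the SDK's implementation: StubFuture and StubTrainer are hypothetical stand-ins, built on concurrent.futures, that validate at submit time and defer the wait to .result(). It illustrates submitting several operations before awaiting any of them.

```python
import concurrent.futures
import time


class StubFuture:
    """Hypothetical stand-in for an operation future; wraps a pending job."""

    def __init__(self, inner):
        self._inner = inner

    def result(self, timeout=None):
        # Blocks until the job finishes, standing in for the long-poll.
        return self._inner.result(timeout=timeout)


class StubTrainer:
    """Hypothetical stand-in for a trainer; not the real TrainingClient."""

    def __init__(self):
        # A single worker keeps jobs in submission order.
        self._pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        self.log = []

    def submit(self, name, seconds=0.01):
        # Validation runs here, on the calling thread, before any waiting.
        if not name:
            raise ValueError("operation name must be non-empty")
        return StubFuture(self._pool.submit(self._run, name, seconds))

    def _run(self, name, seconds):
        time.sleep(seconds)  # pretend the server is doing work
        self.log.append(name)
        return f"{name}: done"

    def close(self):
        self._pool.shutdown(wait=True)


trainer = StubTrainer()

# Submit every operation first, then retrieve the results.
ops = ("forward_backward", "optim_step", "save_state")
futures = [trainer.submit(op) for op in ops]
results = [f.result(timeout=5.0) for f in futures]

# An invalid argument fails at submit time, not at .result().
try:
    trainer.submit("")
    submit_failed = False
except ValueError:
    submit_failed = True

trainer.close()
```

In this model the three jobs run in submission order because the stub uses a single worker; the ordering guarantees of the real server are not stated here.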
Every long-running server operation on ServiceClient, TrainingClient, and SamplingClient (for example, forward_backward, sample, create_lora_training_client) has an awaitable *_async counterpart for callers running their own event loop. The async variants accept the same arguments as their synchronous counterparts. Simpler blocking calls like health, ensure_ready, get_tokenizer, and close (whose async form is aclose) have no *_async twin.
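As a rough illustration of the async naming convention, the sketch below uses a hypothetical StubAsyncTrainer rather than the real client. It assumes that awaiting a *_async call yields the final result directly; the return values and the dict standing in for AdamParams are invented for the example.

```python
import asyncio


class StubAsyncTrainer:
    """Hypothetical stand-in exposing *_async methods; not the real client."""

    async def forward_backward_async(self, data, loss_fn="cross_entropy"):
        await asyncio.sleep(0)  # yield to the event loop, as a network call would
        return {"loss_fn": loss_fn, "num_datums": len(data)}

    async def optim_step_async(self, adam_params):
        await asyncio.sleep(0)
        return {"learning_rate": adam_params["learning_rate"]}


async def train_step(client, batch):
    # Same arguments as the synchronous methods, awaited instead of .result().
    fb = await client.forward_backward_async(data=batch)
    # A plain dict stands in for AdamParams in this sketch.
    step = await client.optim_step_async({"learning_rate": 4e-5})
    return fb, step


fb, step = asyncio.run(train_step(StubAsyncTrainer(), batch=["datum-0", "datum-1"]))
```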
forward_backward(data, loss_fn="cross_entropy", loss_fn_config=None)
ForwardBackwardFuture
Run a forward and backward pass over data (a list of Datum objects) using the specified loss function. Returns a ForwardBackwardFuture; call .result() to block until the pass completes and retrieve the loss.
forward(data, loss_fn="cross_entropy", loss_fn_config=None)
ForwardBackwardFuture
Run a forward pass without gradient computation. Same inputs and output shape as forward_backward, but the gradient buffer is left untouched, so it is safe to interleave with gradient accumulation steps.
optim_step(adam_params)
OperationFuture[OptimStepResponse]
Apply the accumulated gradients using the Adam optimizer configured by AdamParams. Call this after one or more forward_backward calls.
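The interaction between forward_backward, forward, and optim_step can be pictured with a toy one-parameter model. Everything in this sketch is a hypothetical stand-in: StubTrainer is not the real client, plain gradient descent replaces Adam to keep the arithmetic short, and clearing the buffer inside optim_step is an assumption.

```python
class StubTrainer:
    """Hypothetical model of the gradient buffer; not the real TrainingClient."""

    def __init__(self, weight=1.0):
        self.weight = weight
        self.grad = 0.0  # accumulated gradient buffer

    def _loss_and_grad(self, x, y):
        # Squared error for a one-parameter model: loss = (w * x - y) ** 2
        err = self.weight * x - y
        return err * err, 2.0 * err * x

    def forward_backward(self, x, y):
        loss, grad = self._loss_and_grad(x, y)
        self.grad += grad  # the backward pass adds into the buffer
        return loss

    def forward(self, x, y):
        loss, _ = self._loss_and_grad(x, y)  # nothing is written to the buffer
        return loss

    def optim_step(self, learning_rate):
        # Gradient descent stands in for Adam; clearing the buffer is assumed.
        self.weight -= learning_rate * self.grad
        self.grad = 0.0


trainer = StubTrainer(weight=1.0)
micro_batches = [(1.0, 2.0), (2.0, 4.0)]

# Accumulate gradients over two micro-batches.
for x, y in micro_batches:
    trainer.forward_backward(x, y)
grad_after_accumulation = trainer.grad

# An evaluation pass between accumulation and the optimizer step.
eval_loss = trainer.forward(3.0, 6.0)
grad_after_forward = trainer.grad  # unchanged by forward()

# One optimizer step applies the accumulated gradient.
trainer.optim_step(learning_rate=0.05)
```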
save_state(name, ttl_seconds=None)
OperationFuture[SaveWeightsResponse]
Persist a local training checkpoint under name. When a weight sync URI is configured server-side, save_state also publishes the LoRA adapter so a polling sampler can hot-swap to the new weights.
save_weights_for_sampler(name, ttl_seconds=None)
OperationFuture[SaveWeightsResponse]
Publish the LoRA adapter to the paired sampling server under name without returning a snapshot-pinned SamplingClient. Use this when you don’t need the version gate that save_weights_and_get_sampling_client provides.
save_weights_and_get_sampling_client(name)
_ComposedFuture[SamplingClient]
Publish the LoRA adapter to the paired sampling server under name and return a future that resolves to a SamplingClient pinned to the newly published version. Calling .result() runs two stages: the trainer publishes weights, then the SDK polls the sampler until at least one replica reports the new version loaded. The sampler-wait phase has a fixed 600-second ceiling independent of the timeout= you pass to .result(); if no replica reports the new version by then, the call raises RuntimeError. The returned SamplingClient carries X-Min-Policy-Version on every subsequent sample() call, so requests only land on replicas that have the right weights.
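The sampler-wait stage described above amounts to a poll loop with its own deadline. The following is a minimal sketch of that idea, not the SDK's code: wait_for_version, its parameters, and the "version greater than or equal to the target" check are assumptions made for illustration. The clock and sleep functions are injectable so the example runs instantly.

```python
import time

SAMPLER_WAIT_CEILING_S = 600.0  # the fixed ceiling described above


def wait_for_version(get_replica_versions, target_version,
                     ceiling_s=SAMPLER_WAIT_CEILING_S, poll_interval_s=1.0,
                     clock=time.monotonic, sleep=time.sleep):
    """Poll until at least one replica reports target_version or newer.

    Hypothetical helper; the SDK's actual polling loop may differ.
    """
    deadline = clock() + ceiling_s
    while True:
        if any(v >= target_version for v in get_replica_versions()):
            return target_version
        if clock() >= deadline:
            raise RuntimeError(
                f"no replica loaded version {target_version} within {ceiling_s:.0f}s"
            )
        sleep(poll_interval_s)


# Simulated clock: sleeping just advances a counter.
now = [0.0]


def fake_clock():
    return now[0]


def fake_sleep(seconds):
    now[0] += seconds


# One replica catches up on the third poll.
polls = iter([[1, 1], [1, 1], [1, 2]])
loaded = wait_for_version(lambda: next(polls), 2, clock=fake_clock, sleep=fake_sleep)
elapsed_success = now[0]

# A replica set that never catches up hits the ceiling and raises.
now[0] = 0.0
try:
    wait_for_version(lambda: [1], 2, clock=fake_clock, sleep=fake_sleep)
    timed_out = False
except RuntimeError:
    timed_out = True
elapsed_timeout = now[0]
```

Because the deadline is computed inside the helper, a caller-side timeout has no effect on it, which mirrors the note that the ceiling is independent of the timeout= passed to .result().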
load_state(path)
OperationFuture[LoadWeightsResponse]
Load weights from a bt://loops:<run_id>/weights/<checkpoint> URI into this trainer. Use this to resume training from a checkpoint.
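Checkpoint URIs follow a fixed shape, so it can be handy to build and validate them before submitting a load. The helpers below are hypothetical and not part of the SDK; the regular expression assumes that a run ID contains no slash and that the path segment is either weights or sampler_weights, the two forms that appear in this reference.

```python
import re
from typing import NamedTuple

# Assumed grammar: bt://loops:<run_id>/<weights|sampler_weights>/<checkpoint>
_LOOPS_URI = re.compile(
    r"^bt://loops:(?P<run_id>[^/]+)/(?P<kind>weights|sampler_weights)/(?P<checkpoint>.+)$"
)


class LoopsUri(NamedTuple):
    run_id: str
    kind: str
    checkpoint: str


def parse_loops_uri(uri):
    """Split a checkpoint URI into its parts, or raise ValueError."""
    match = _LOOPS_URI.match(uri)
    if match is None:
        raise ValueError(f"not a bt://loops checkpoint URI: {uri!r}")
    return LoopsUri(match["run_id"], match["kind"], match["checkpoint"])


def training_weights_uri(run_id, checkpoint):
    """Build the URI form that load_state expects."""
    return f"bt://loops:{run_id}/weights/{checkpoint}"


uri = training_weights_uri("run123", "step-1")
parsed = parse_loops_uri(uri)

try:
    parse_loops_uri("s3://bucket/step-1")
    rejected = False
except ValueError:
    rejected = True
```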
load_state_with_optimizer(path)
OperationFuture[LoadWeightsResponse]
Same as load_state, but also restores the Adam moments. Use this when you want bit-exact resumption.
list_checkpoints()
list[Checkpoint]
List checkpoints for the run bound to this client. The client must have been constructed with ServiceClient.create_lora_training_client, which populates the required session and run IDs automatically. Returns a list of Checkpoint objects.
get_checkpoint_archive_url(checkpoint_id, page_size=1000, page_token=0)
CheckpointFilesResponse
Return presigned URLs for every file in a checkpoint folder. Same semantics as ServiceClient.get_checkpoint_archive_url.
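For checkpoints with more files than fit in one page, the page_size and page_token parameters suggest a simple paging loop. The sketch below shows the general pattern only. The response shape is an assumption: the "files" and "next_page_token" keys, and the use of None to mark the last page, are hypothetical and not taken from CheckpointFilesResponse. stub_get_page stands in for the real call.

```python
ALL_FILES = [f"adapter/shard-{i}.safetensors" for i in range(5)]
calls = []


def stub_get_page(checkpoint_id, page_size=1000, page_token=0):
    """Hypothetical stand-in for get_checkpoint_archive_url."""
    calls.append(page_token)
    chunk = ALL_FILES[page_token:page_token + page_size]
    next_token = page_token + page_size
    return {
        "files": chunk,
        # None marks the last page in this sketch (assumed convention).
        "next_page_token": next_token if next_token < len(ALL_FILES) else None,
    }


def fetch_all_files(get_page, checkpoint_id, page_size=1000):
    """Follow page tokens until the stub reports no further page."""
    files = []
    page_token = 0
    while page_token is not None:
        page = get_page(checkpoint_id, page_size=page_size, page_token=page_token)
        files.extend(page["files"])
        page_token = page["next_page_token"]
    return files


files = fetch_all_files(stub_get_page, "ckpt-1", page_size=2)
```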
create_sampling_client(model_path)
SamplingClient
Return a SamplingClient bound to the paired sampler, loading the weights at model_path (a bt://loops:<run_id>/sampler_weights/<checkpoint> URI). Distinct from ServiceClient.create_sampling_client, which provisions a fresh sampler.
get_tokenizer()
PreTrainedTokenizer
Return the Hugging Face PreTrainedTokenizer for the base model. Cached after the first load.
get_info()
GetInfoResponse
Return the model configuration for this training session (base model name, LoRA rank, and max sequence length) without a server round-trip.
run_id
str | None
Property. The run ID this client is bound to. Use this when filtering checkpoints or making HTTP API calls against the same run.
policy_version
int
Property. The current policy version the trainer has published. Incremented on each save_weights_and_get_sampling_client (or save_weights_for_sampler) call.
init_trainer_server(lora_rank)
OperationFuture[InitTrainerServerResponse]
Reset trainer state to a fresh LoRA adapter at lora_rank. Use this to start a new adapter on an existing trainer without reprovisioning.
health()
None
Check the trainer’s /health endpoint. Returns None on success and raises if the trainer is unreachable or unhealthy.
close()
None
Close the client’s HTTP connections and finish any active Weights & Biases run. Pure-async callers can use aclose() instead, which closes connections directly on the running event loop.
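Since the client exposes close() and aclose(), cleanup can be made exception-safe with standard Python idioms. The sketch below uses a hypothetical StubClient in place of the real client: contextlib.closing calls close() on exit for synchronous code, and a try/finally block awaits aclose() for async code.

```python
import asyncio
import contextlib


class StubClient:
    """Hypothetical stand-in with the same close()/aclose() surface."""

    def __init__(self):
        self.closed = False

    def close(self):
        self.closed = True

    async def aclose(self):
        self.closed = True


# Synchronous callers: close() runs even if the body raises.
with contextlib.closing(StubClient()) as sync_client:
    pass  # training calls would go here
sync_closed = sync_client.closed


# Async callers: await aclose() on the running event loop.
async def run_async():
    client = StubClient()
    try:
        pass  # awaited training calls would go here
    finally:
        await client.aclose()
    return client.closed


async_closed = asyncio.run(run_async())
```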