The Model class in model/model.py is the imperative surface you reach for when config.yaml alone can’t express your logic. It gives you a Python class with lifecycle methods (__init__, load, and predict) that control how your model initializes, loads weights, and handles each request. When you need custom preprocessing, postprocessing, response shaping, or want to run an architecture that Baseten’s built-in engines don’t support, you write that logic here.

When to write a Model class

Most deployments don’t need custom Python. If you’re deploying a supported open-source model, the config-only approach in Build your first model is faster. Write a custom Model class when you need to:
  • Run a model architecture that Baseten’s engines don’t support.
  • Add custom preprocessing or postprocessing around inference.
  • Combine multiple models or libraries in a single endpoint.
  • Control the HTTP response directly, including status codes and streaming.
You define this logic in a model/model.py file. The simplest project structure is:
model/
  model.py
config.yaml

The class skeleton

model.py must contain a class with three methods:
model.py
class Model:
    def __init__(self, **kwargs):
        pass

    def load(self):
        pass

    def predict(self, model_input):
        return model_input
  • __init__ runs once, when the class is instantiated. Read configuration parameters and runtime information here.
  • load runs once at startup, before any requests. Download model weights or load them onto a GPU here. Doing this work in load, rather than in __init__ or predict, keeps object construction cheap and keeps expensive operations out of the request path.
  • predict runs on every API request. Process input, run inference, and return the response.
load and predict don’t run on the same thread, which matters for GPU workloads where state can be tied to the creating thread (such as CUDA contexts). With sync predict and the default predict_concurrency of 1, successive predict calls often reuse the same worker thread, but Baseten doesn’t guarantee it.
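As a rough illustration of the lifecycle, the sketch below walks through the three methods in order. The driver code at the bottom is a hypothetical local stand-in for the serving runtime, not how Baseten actually invokes the class, and str.upper stands in for a real model.

```python
class Model:
    def __init__(self, **kwargs):
        # Cheap setup only: remember configuration, don't load weights yet.
        self._config = kwargs.get("config") or {}
        self._model = None

    def load(self):
        # Stand-in for an expensive step such as loading weights onto a GPU.
        self._model = str.upper

    def predict(self, model_input):
        # Runs per request and relies on the state prepared in load().
        return {"output": self._model(model_input["text"])}


# Hypothetical local driver that mimics the call order: __init__, load, predict.
model = Model(config={})
model.load()
result = model.predict({"text": "hello"})
print(result)
```

Keeping self._model unset until load makes it obvious in local testing when predict is called before the model is ready.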

__init__

The __init__ method initializes the Model class. Use it to read configuration parameters and runtime information. The simplest signature accepts nothing:
model.py
def __init__(self):
    pass
If you need more information, define __init__ to accept these parameters:
model.py
def __init__(self, config: dict, data_dir: str, secrets: dict, environment: dict):
    pass
  • config: A dictionary containing the config.yaml for the model.
  • data_dir: A string containing the path to the data directory for the model.
  • secrets: A dictionary containing the secrets for the model. At runtime, these are populated with the actual values stored on Baseten.
  • environment: A dictionary containing the environment for the model, if the model has been deployed to an environment. None otherwise.
Save these as attributes to use them elsewhere in your model:
model.py
def __init__(self, config: dict, data_dir: str, secrets: dict, environment: dict):
    self._config = config
    self._data_dir = data_dir
    self._secrets = secrets
    self._environment = environment
You can also accept these through **kwargs and pull out only what you need:
model.py
def __init__(self, **kwargs):
    self._data_dir = kwargs["data_dir"]
    self._secrets = kwargs.get("secrets")
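Because environment can be None and some keys may be absent, it can help to read these values defensively. The following is a minimal sketch under two assumptions that aren't stated above: that the environment dictionary carries a "name" key, and that custom settings live under a model_metadata section of config.yaml. The batch_size setting is a hypothetical name used only for illustration.

```python
class Model:
    def __init__(self, **kwargs):
        self._config = kwargs.get("config") or {}
        self._secrets = kwargs.get("secrets") or {}

        # environment is None when the model isn't deployed to an environment.
        environment = kwargs.get("environment") or {}
        # Assumption: the environment dict exposes its name under "name".
        self._environment_name = environment.get("name")

        # Assumption: custom settings sit under model_metadata in config.yaml.
        # batch_size is a hypothetical setting, defaulting to 1 when absent.
        metadata = self._config.get("model_metadata") or {}
        self._batch_size = metadata.get("batch_size", 1)
```

With this shape, the same class can be constructed in a local test with no arguments at all, which makes it easier to exercise predict outside a deployment.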

load

The load method initializes the model. This might include downloading model weights or loading them onto the GPU. Unlike the other methods, load accepts no parameters:
model.py
def load(self):
    pass
After you deploy your model, the deployment isn’t considered “Ready” until load completes successfully. If load hasn’t completed within 30 minutes, the deployment is marked as failed.

predict

The predict method runs inference. The simplest signature returns a value directly:
model.py
def predict(self, model_input) -> str:
    return "Hello"
The value that predict returns must be JSON-serializable, such as a dict, list, or str. See Response objects for stricter typing and direct control over the HTTP response.
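Inference code often produces values that the json module can't encode directly, such as tuples nested in unexpected places, datetimes, or bytes. One way to handle this is a small conversion step before returning. The to_jsonable helper below is a hypothetical name, not part of Truss, and it covers only a few common types.

```python
import json
from datetime import datetime, timezone


def to_jsonable(value):
    """Recursively convert a few common non-JSON types into JSON-friendly ones."""
    if isinstance(value, dict):
        return {str(k): to_jsonable(v) for k, v in value.items()}
    if isinstance(value, (list, tuple, set)):
        return [to_jsonable(v) for v in value]
    if isinstance(value, datetime):
        return value.isoformat()
    if isinstance(value, bytes):
        return value.decode("utf-8", errors="replace")
    return value


class Model:
    def predict(self, model_input):
        # Stand-in for raw inference output containing non-JSON types.
        raw = {
            "labels": ("cat", "dog"),
            "created_at": datetime(2024, 1, 1, tzinfo=timezone.utc),
        }
        return to_jsonable(raw)


result = Model().predict({})
# json.dumps raises TypeError if anything non-serializable slipped through.
encoded = json.dumps(result)
```

Calling json.dumps on the result in a local test is a quick way to catch serialization problems before deploying.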

Async vs. sync

The predict method is synchronous by default. If your inference depends on APIs that require asyncio, write predict as a coroutine:
model.py
import asyncio

async def predict(self, model_input) -> dict:
    # Async logic here.
    await asyncio.sleep(1)
    return {"value": "Hello"}
If predict is a coroutine, don’t perform blocking operations in it, such as a synchronous file download. A blocking call stalls the event loop, so other in-flight requests can’t make progress until it returns.
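When a blocking call is unavoidable inside an async predict, one common approach is to hand it to a worker thread with asyncio.to_thread (Python 3.9+). In this sketch, blocking_download is a hypothetical stand-in for a synchronous download, and asyncio.run is used only to exercise the coroutine locally.

```python
import asyncio
import time


def blocking_download(url: str) -> bytes:
    # Stand-in for a synchronous call such as a blocking file download.
    time.sleep(0.1)
    return f"contents of {url}".encode()


class Model:
    async def predict(self, model_input) -> dict:
        # Run the blocking call in a worker thread so the event loop
        # stays free to serve other requests while it waits.
        data = await asyncio.to_thread(blocking_download, model_input["url"])
        return {"num_bytes": len(data)}


# Local check only; a serving runtime would await predict itself.
result = asyncio.run(Model().predict({"url": "x"}))
```

Where an async client library exists for the operation, awaiting it directly is usually simpler than offloading to a thread.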

Pre/post-processing

To separate I/O from inference and maximize throughput, define optional preprocess and postprocess methods alongside predict. Tasks like downloading images or formatting responses then run without blocking GPU or CPU execution:
model.py
class Model:
    def __init__(self, **kwargs): ...
    def load(self): ...

    def preprocess(self, request):
        # Handle I/O before inference, such as downloading images.
        ...

    def predict(self, request):
        # Perform model inference.
        ...

    def postprocess(self, response):
        # Handle I/O after inference, such as formatting outputs.
        ...
Pre/post-processing runs in separate threads and isn’t subject to Truss’s concurrency limits, so I/O-heavy tasks don’t bottleneck compute resources.
Truss enforces concurrency limits on predict to prevent GPU or CPU overload:
config.yaml
runtime:
  predict_concurrency: 5
If the model receives 10 requests with predict_concurrency: 5, all 10 start preprocessing concurrently, but only 5 run inference at a time. The rest wait until a slot frees up.
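The behavior described above can be approximated locally with a semaphore. This is a simulation under stated assumptions, not Truss's actual implementation: handle_request is a hypothetical function, the sleeps stand in for real work, and the semaphore plays the role of predict_concurrency: 5.

```python
import threading
import time
from concurrent.futures import ThreadPoolExecutor

PREDICT_CONCURRENCY = 5
predict_slots = threading.Semaphore(PREDICT_CONCURRENCY)

lock = threading.Lock()
in_predict = 0
peak_in_predict = 0


def handle_request(request_id: int) -> int:
    global in_predict, peak_in_predict
    time.sleep(0.01)  # preprocess: no limit, every request starts right away
    with predict_slots:  # predict: at most PREDICT_CONCURRENCY at a time
        with lock:
            in_predict += 1
            peak_in_predict = max(peak_in_predict, in_predict)
        time.sleep(0.05)  # stand-in for inference
        with lock:
            in_predict -= 1
    return request_id  # postprocess would run here, again without the limit


# Ten requests arrive at once, each on its own worker thread.
with ThreadPoolExecutor(max_workers=10) as pool:
    results = list(pool.map(handle_request, range(10)))

print(f"peak concurrent predict calls: {peak_in_predict}")
```

The printed peak never exceeds the configured limit, even though all ten requests begin preprocessing together.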

Streaming

Truss also supports streaming output incrementally instead of waiting for the full response. For the full pattern, see Streaming output and endpoints.

Response objects

By default, Truss wraps prediction results into an HTTP response. For advanced use cases, create response objects manually to:
  • Control HTTP status codes.
  • Use server-sent events (SSEs) for streaming responses.
To return a more strictly typed object than a dict, list, or str, return a Pydantic model:
model.py
from pydantic import BaseModel

class Result(BaseModel):
    value: str

class Model:
    def predict(self, model_input) -> Result:
        return Result(value="Hello")
To control the raw HTTP response, return any subclass of starlette.responses.Response:
model.py
import fastapi

class Model:
    def predict(self, inputs) -> fastapi.Response:
        return fastapi.Response(...)
For server-sent events, return a StreamingResponse. See Streaming output and endpoints for a complete SSE example.
You can return a response from predict or postprocess, but not both. If predict returns a response or a generator, postprocess cannot be used.
Response headers aren’t fully propagated. Include any metadata in the response body.
To handle raw incoming requests, see Using request objects.
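Since response headers aren't fully propagated, metadata that might otherwise travel in a header can be placed in the response body. The sketch below shows one possible shape. The metadata, request_id, and latency_ms field names are hypothetical, and the string reversal stands in for inference.

```python
import time
import uuid


class Model:
    def predict(self, model_input) -> dict:
        started = time.perf_counter()
        output = model_input["text"][::-1]  # stand-in for inference
        return {
            "output": output,
            # Hypothetical metadata block carried in the body instead of headers.
            "metadata": {
                "request_id": str(uuid.uuid4()),
                "latency_ms": round((time.perf_counter() - started) * 1000, 3),
            },
        }


response = Model().predict({"text": "abc"})
```

Keeping the metadata under a single key lets clients ignore it without changing how they read the main output.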

Bundled data

Most models need additional files at runtime, such as weights, tokenizers, configs, or reference datasets. For local files under ~1 GB total, bundle them in your Truss’s data/ directory. The contents are copied into your container image at build time and mounted at /app/data at runtime. Access them from model.py through kwargs["data_dir"]:
model.py
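from transformers import AutoTokenizer

class Model:
    def __init__(self, **kwargs):
        self._data_dir = kwargs["data_dir"]

    def load(self):
        # Load the tokenizer files bundled under data/ instead of fetching them remotely.
        self.tokenizer = AutoTokenizer.from_pretrained(str(self._data_dir))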
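The same data_dir pattern works for any bundled file, not only tokenizers. As a dependency-free sketch, the example below reads a hypothetical labels.json file from the data directory in load and uses it in predict. A temporary directory stands in for data/ so the code can run locally.

```python
import json
import tempfile
from pathlib import Path


class Model:
    def __init__(self, **kwargs):
        self._data_dir = Path(kwargs["data_dir"])
        self._labels = None

    def load(self):
        # labels.json is a hypothetical file bundled under data/.
        with open(self._data_dir / "labels.json") as f:
            self._labels = json.load(f)

    def predict(self, model_input):
        return {"label": self._labels[model_input["class_id"]]}


# Local check: a temporary directory stands in for the bundled data/ directory.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "labels.json").write_text(json.dumps(["cat", "dog"]))
    model = Model(data_dir=tmp)
    model.load()
    result = model.predict({"class_id": 1})
```

Wrapping data_dir in pathlib.Path keeps path joins readable and works whether the runtime passes a string or a path object.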
A bundled Truss might lay its data/ directory out like this Stable Diffusion 2.1 example:
data/
    scheduler/
        scheduler_config.json
    text_encoder/
        config.json
        diffusion_pytorch_model.bin
    tokenizer/
        merges.txt
        tokenizer_config.json
        vocab.json
    unet/
        config.json
        diffusion_pytorch_model.bin
    vae/
        config.json
        diffusion_pytorch_model.bin
    model_index.json
Use the data/ directory only when it’s under ~1 GB total. The files ship inside the container image, so every cold start re-pulls them, not just the first deploy. Larger bundles compound into slower scale-ups, and truss push itself slows down as the bundle grows.
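To check a bundle against this guideline before pushing, you could total the file sizes under data/. The helper names below are hypothetical, and the 1,000,000,000-byte limit is only an approximation of the "~1 GB" guidance above.

```python
from pathlib import Path

SIZE_LIMIT_BYTES = 1_000_000_000  # approximation of the ~1 GB guideline


def bundle_size_bytes(data_dir) -> int:
    """Total size in bytes of all regular files under data_dir."""
    return sum(p.stat().st_size for p in Path(data_dir).rglob("*") if p.is_file())


def fits_in_bundle(data_dir, limit: int = SIZE_LIMIT_BYTES) -> bool:
    """True when the directory's contents are at or under the limit."""
    return bundle_size_bytes(data_dir) <= limit
```

For example, calling fits_in_bundle("data") from the project root before a push gives an early signal that the weights may belong somewhere other than the image.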
For larger weights or remote sources (Hugging Face, S3, GCS, R2), use the Baseten Delivery Network (BDN) instead. BDN mirrors weights once and serves them from caches close to your pods, so cold starts read from local or nearby caches instead of pulling from the source on every scale-up.

Download files at runtime

Use this pattern when you need fine-grained control over the download, such as decrypting files on the fly or lazily fetching a subset of a larger dataset. The example below loads weights from a private S3 bucket using boto3.
To load private S3 weights at deploy time, prefer BDN with IAM credentials. BDN mirrors the weights once and serves them from a multi-tier cache; the pattern below re-downloads on every cold start unless you add caching.
Define AWS secrets in config.yaml:
secrets:
  aws_access_key_id: null
  aws_secret_access_key: null
  aws_region: null # for example, us-east-1
  aws_bucket: null
Do not store actual credentials in config.yaml. Add them securely to the Baseten secrets manager.
Authenticate with AWS in model.py, then deploy with truss push --watch:
model.py
import boto3

class Model:
    def __init__(self, **kwargs):
        self._config = kwargs.get("config")
        secrets = kwargs.get("secrets")
        self.s3_client = boto3.client(
            "s3",
            aws_access_key_id=secrets["aws_access_key_id"],
            aws_secret_access_key=secrets["aws_secret_access_key"],
            region_name=secrets["aws_region"],
        )
        self.s3_bucket = secrets["aws_bucket"]
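The snippet above only creates the client. A download step, typically called from load, might look like the following sketch. The download_weights function and its prefix and dest_dir parameters are hypothetical. The client is passed in rather than created here, so the function relies only on the boto3 S3 client's list_objects_v2 paginator and download_file method.

```python
import os


def download_weights(s3_client, bucket: str, prefix: str, dest_dir: str) -> list:
    """Download every object under prefix into dest_dir and return local paths."""
    os.makedirs(dest_dir, exist_ok=True)
    local_paths = []
    paginator = s3_client.get_paginator("list_objects_v2")
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        # Pages with no matching objects omit the "Contents" key.
        for obj in page.get("Contents", []):
            key = obj["Key"]
            if key.endswith("/"):
                continue  # skip folder placeholder objects
            # Keep the path relative to the prefix, falling back to the file name.
            relative = key[len(prefix):].lstrip("/") or os.path.basename(key)
            local_path = os.path.join(dest_dir, relative)
            os.makedirs(os.path.dirname(local_path), exist_ok=True)
            s3_client.download_file(bucket, key, local_path)
            local_paths.append(local_path)
    return local_paths
```

Inside the Model class, load could call download_weights(self.s3_client, self.s3_bucket, "weights/", "/tmp/weights"), where both the prefix and the destination path are placeholder values.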
If your model downloads weights at runtime via custom code, BDN proxy can cache those downloads across replicas. Available by request.

Next steps