This section covers how to implement the logic for your model. As covered in Build your first model, you define model logic in a model/model.py file. The simplest directory structure is:
model/
  model.py
config.yaml
The model.py file must contain a class named Model with these methods:
model.py
class Model:
  def __init__(self):
    pass

  def load(self):
    pass

  def predict(self, input_data):
    pass
  • __init__ initializes the Model class. Read configuration parameters and other information here.
  • load initializes the model. Download model weights or load them onto a GPU here.
  • predict runs inference.
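The following toy model illustrates the order in which the three methods run. It is a minimal sketch, not a real model: the dictionary stands in for model weights, predict assumes input_data is a dict, and the __main__ block only imitates the calls the Truss model server makes. In a deployment you don't call these methods yourself.

```python
class Model:
    def __init__(self):
        # Runs once when the model server starts; keep this cheap.
        self._weights = None

    def load(self):
        # Runs once after __init__, before any request is served.
        # A real model would download weights or move them to a GPU here.
        self._weights = {"hello": "world"}

    def predict(self, input_data):
        # Runs once per request, after load has completed.
        word = input_data.get("word", "")
        return {"output": self._weights.get(word, "unknown")}


if __name__ == "__main__":
    # Imitate the server: construct, load once, then handle a request.
    model = Model()
    model.load()
    print(model.predict({"word": "hello"}))  # {'output': 'world'}
```

Keeping expensive work out of __init__ and predict means it happens exactly once, in load.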
The next sections cover each method in detail.

__init__

The __init__ method initializes the Model class. Use it to read configuration parameters and runtime information. The simplest signature for __init__ is:
model.py
def __init__(self):
  pass
To access configuration, data, secrets, or environment information, define your __init__ method to accept the following parameters:
model.py
from typing import Optional

def __init__(self, config: dict, data_dir: str, secrets: dict, environment: Optional[dict]):
  pass
  • config: A dictionary containing the parsed contents of the model's config.yaml.
  • data_dir: A string containing the path to the data directory for the model.
  • secrets: A dictionary containing the secrets for the model. At runtime, these are populated with the actual values stored on Baseten.
  • environment: A dictionary describing the environment the model is deployed to, or None if the model hasn't been deployed to an environment.
Save these as attributes to use them elsewhere in your model:
model.py
from typing import Optional

def __init__(self, config: dict, data_dir: str, secrets: dict, environment: Optional[dict]):
  self._config = config
  self._data_dir = data_dir
  self._secrets = secrets
  self._environment = environment
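Once saved, these attributes are available in load and predict. The sketch below reads a value from the model_metadata section of config.yaml, a secret, and the environment name. The "model_id" key, the "hf_access_token" secret name, and the "name" key inside environment are hypothetical examples; check what your own config and environment actually contain.

```python
from typing import Optional


class Model:
    def __init__(self, config: dict, data_dir: str, secrets: dict,
                 environment: Optional[dict]):
        self._config = config
        self._data_dir = data_dir
        self._secrets = secrets
        self._environment = environment

    def load(self):
        # model_metadata is a free-form section of config.yaml; the
        # "model_id" key inside it is a hypothetical example.
        metadata = self._config.get("model_metadata") or {}
        self._model_id = metadata.get("model_id", "default-model")
        # "hf_access_token" is a hypothetical secret name. With a plain dict,
        # indexing raises KeyError if the secret is missing, so
        # misconfiguration surfaces early, during load.
        self._token = self._secrets["hf_access_token"]
        # environment is None unless the model is deployed to an environment;
        # reading a "name" key from it is an assumption about its contents.
        env = self._environment or {}
        self._env_name = env.get("name", "none")

    def predict(self, input_data):
        # Never return secret values; report only whether one was provided.
        return {
            "model_id": self._model_id,
            "environment": self._env_name,
            "authenticated": bool(self._token),
        }
```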

load

The load method is where you define the logic for initializing the model, such as downloading model weights or loading them onto the GPU. Unlike __init__ and predict, load takes no parameters other than self:
model.py
def load(self):
  pass
After you deploy your model, the deployment isn't considered “Ready” until load completes successfully. If load hasn't completed within 30 minutes, the deployment is marked as failed.
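As a minimal sketch of work that belongs in load, the following model reads its parameters from a file in the data directory once, so predict only does per-request work. The weights.json file and its "scale" key are hypothetical; a real model would load framework-specific weight files instead. Wrapping data_dir in Path works whether it arrives as a string or a path object.

```python
import json
from pathlib import Path
from typing import Optional


class Model:
    def __init__(self, config: dict, data_dir: str, secrets: dict,
                 environment: Optional[dict]):
        self._data_dir = Path(data_dir)
        self._weights = None

    def load(self):
        # One-time, potentially slow work; it counts toward the load timeout.
        # weights.json is a hypothetical file bundled in the data directory.
        with open(self._data_dir / "weights.json") as f:
            self._weights = json.load(f)

    def predict(self, input_data):
        # Per-request work only: reuse the parameters loaded above.
        scale = self._weights["scale"]
        return {"result": [x * scale for x in input_data["values"]]}
```

Because load runs once before the deployment serves traffic, predict never pays the cost of reading the file.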

predict

The predict method is where you define the logic for performing inference. The simplest signature for predict is:
model.py
def predict(self, input_data) -> str:
  return "Hello"
The value returned by predict must be JSON-serializable, for example:
  • dict
  • list
  • str
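The sketch below assumes input_data is the parsed JSON request body (a dict) and returns a dict built only from JSON-serializable values. The is_json_serializable helper is hypothetical and not part of Truss; it simply calls json.dumps, which raises TypeError for values such as sets or arbitrary objects, so you can check a return value locally before deploying.

```python
import json


class Model:
    def predict(self, input_data):
        # input_data is assumed to be the parsed JSON request body (a dict).
        text = input_data.get("text", "")
        # Every value in the returned dict is JSON-serializable:
        # an int, a list of str, and a bool.
        return {
            "length": len(text),
            "tokens": text.split(),
            "is_empty": text == "",
        }


def is_json_serializable(value) -> bool:
    """Hypothetical local check: True if json.dumps accepts the value."""
    try:
        json.dumps(value)
        return True
    except (TypeError, ValueError):
        return False
```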
To return a more strictly typed object, define a Pydantic model:
model.py
from pydantic import BaseModel

class Result(BaseModel):
  value: str
You can then return an instance of this model from predict:
model.py
def predict(self, input_data) -> Result:
  return Result(value="Hello")

Streaming

In addition to supporting a single request/response cycle, Truss also supports streaming. See the Streaming guide for more information.

Async vs. sync

The predict method is synchronous by default. If your model inference depends on APIs that require asyncio, you can write predict as a coroutine:
model.py
import asyncio

async def predict(self, input_data) -> dict:
    # Async logic here

    await asyncio.sleep(1)
    return {"value": "Hello"}
If you write predict as a coroutine, don't perform blocking operations in it, such as a synchronous file download. A blocking call stalls the event loop, so other work on that loop can't make progress until the call returns, which degrades performance.
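When a blocking call is unavoidable, one common approach is to move it off the event loop with the standard library's asyncio.to_thread (Python 3.9+), which runs the function in a worker thread. In this sketch, blocking_download is a hypothetical stand-in for a synchronous download, and main only imitates two concurrent requests.

```python
import asyncio
import time


def blocking_download(url: str) -> str:
    # Hypothetical stand-in for a synchronous call, such as a file download.
    time.sleep(0.2)
    return f"contents of {url}"


class Model:
    async def predict(self, input_data) -> dict:
        # Run the blocking call in a worker thread so the event loop stays free.
        data = await asyncio.to_thread(blocking_download, input_data["url"])
        return {"value": data}


async def main():
    model = Model()
    # Both calls sleep in separate threads at the same time, so the pair
    # finishes in roughly 0.2 s instead of 0.4 s.
    return await asyncio.gather(
        model.predict({"url": "a"}), model.predict({"url": "b"})
    )


if __name__ == "__main__":
    print(asyncio.run(main()))
```

Calling blocking_download directly inside the coroutine would instead serialize the two requests, because the event loop could not switch between them while each one sleeps.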