Implementation
How to implement your model.
In this section, we’ll cover how to implement the actual logic for your model.
As was mentioned in Your First Model, the
logic for the model itself is specified in a model/model.py
file. To recap, the simplest
directory structure for a model is:
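A minimal layout might look like the following (the top-level directory name is illustrative; the config.yaml and model/model.py names follow the convention described in this guide):

```
my-model/
  config.yaml
  model/
    model.py
```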
It’s expected that the model.py
file contains a class with particular methods:
- The __init__ method is used to initialize the Model class, and allows you to read in configuration parameters and other information.
- The load method is where you define the logic for initializing the model. This might include downloading model weights, or loading them onto a GPU.
- The predict method is where you define the logic for inference.
In the next sections, we’ll cover each of these methods in more detail.
__init__
As mentioned above, the __init__
method is used to initialize the Model
class, and allows you to
read in configuration parameters and runtime information.
The simplest signature for __init__
is:
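A minimal sketch, assuming the Model class convention described above:

```python
class Model:
    def __init__(self, **kwargs):
        # No configuration is read in the simplest case.
        pass
```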
If you need more information, however, you have the option to define your init method such that it accepts the following parameters:
- config: A dictionary containing the config.yaml for the model.
- data_dir: A string containing the path to the data directory for the model.
- secrets: A dictionary containing the secrets for the model. Note that at runtime, these will be populated with the actual values as stored on Baseten.
- environment: A string containing the environment for the model, if the model has been deployed to an environment.
You can then make use of these parameters in the rest of your model by saving them as attributes:
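A sketch using the parameter names listed above; defaulting environment to None and accepting extra keyword arguments are assumptions for illustration:

```python
class Model:
    def __init__(self, config, data_dir, secrets, environment=None, **kwargs):
        # Keep runtime information around for use in load() and predict().
        self._config = config
        self._data_dir = data_dir
        self._secrets = secrets
        self._environment = environment
```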
load
The load
method is where you define the logic for initializing the model. As
mentioned before, this might include downloading model weights or loading them
onto the GPU.
load
, unlike the other methods mentioned, does not accept any parameters:
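A minimal sketch; the lambda is a stand-in for whatever your real model object is (for example, the result of loading weights onto a GPU):

```python
class Model:
    def __init__(self, **kwargs):
        self._model = None

    def load(self):
        # In a real model this is where you would download weights
        # and move them onto a GPU; a stand-in is used here.
        self._model = lambda x: x
```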
After deploying your model, the deployment will not be considered “Ready” until load
has
completed successfully. Note that load has a 30-minute timeout; if it has not completed
by then, the deployment will be marked as failed.
predict
The predict
method is where you define the logic for performing inference.
The simplest signature for predict
is:
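A minimal sketch; echoing the input back is purely illustrative:

```python
class Model:
    def predict(self, model_input):
        # model_input is the parsed request body; the return value
        # is serialized and sent back to the client.
        return {"output": model_input}
```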
The return type of predict
must be JSON-serializable, so it can be:
- dict
- list
- str
If you would like to return a more strictly typed object, you can return a
Pydantic
object.
You can then return an instance of this model from predict
:
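A sketch using Pydantic; the ModelOutput name and its fields are illustrative choices, not part of any API:

```python
from pydantic import BaseModel


class ModelOutput(BaseModel):
    # Illustrative fields; declare whatever your model actually returns.
    text: str
    num_tokens: int


class Model:
    def predict(self, model_input):
        # Returning a Pydantic object gives the response a strict schema.
        return ModelOutput(text="hello", num_tokens=2)
```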
Streaming
In addition to supporting a single request/response cycle, Truss also supports streaming.
See the Streaming guide for more information.
Async vs. Sync
Note that the predict
method is synchronous by default. However, if your model inference
depends on APIs that require asyncio
, predict
can also be written as a coroutine.
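A sketch of a coroutine-based predict; asyncio.sleep stands in for a real async API call:

```python
import asyncio


class Model:
    async def predict(self, model_input):
        # await yields control to the event loop instead of blocking,
        # letting other requests make progress during the call.
        await asyncio.sleep(0)  # stand-in for an async API call
        return {"output": model_input}
```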
If you are using asyncio
in your predict
method, be sure not to perform any blocking
operations, such as a synchronous file download. This can result in degraded performance.