# Any deployment by ID
post /v1/models/{model_id}/deployments/{deployment_id}/activate
Activates an inactive deployment and returns the activation status.
# 🆕 Activate environment deployment
post /v1/models/{model_id}/environments/{env_name}/activate
Activates an inactive deployment associated with an environment and returns the activation status.
# Development deployment
post /v1/models/{model_id}/deployments/development/activate
Activates an inactive development deployment and returns the activation status.
# Production deployment
post /v1/models/{model_id}/deployments/production/activate
Activates an inactive production deployment and returns the activation status.
# Cancel async request
DELETE https://model-{model_id}.api.baseten.co/async_request/{request_id}
Use this endpoint to cancel a queued async request.
Only `QUEUED` requests may be canceled.
### Parameters
The ID of the model that executed the request.
The ID of the async request.
### Headers
Your Baseten API key, formatted with prefix `Api-Key` (e.g. `{"Authorization": "Api-Key abcd1234.abcd1234"}`).
### Response
The ID of the async request.
Whether the request was canceled.
Additional details about whether the request was canceled.
### Rate limits
Calls to the cancel async request endpoint are limited to **20 requests per second**. If this limit is exceeded, subsequent requests will receive a 429 status code.
```py Python
import requests
import os
model_id = ""
request_id = ""
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]
resp = requests.delete(
f"https://model-{model_id}.api.baseten.co/async_request/{request_id}",
headers={"Authorization": f"Api-Key {baseten_api_key}"}
)
print(resp.json())
```
```sh cURL
curl --request DELETE \
--url https://model-{model_id}.api.baseten.co/async_request/{request_id} \
--header "Authorization: Api-Key $BASETEN_API_KEY"
```
```js Node.js
const fetch = require('node-fetch');
const resp = await fetch(
'https://model-{model_id}.api.baseten.co/async_request/{request_id}',
{
method: 'DELETE',
headers: { Authorization: 'Api-Key YOUR_API_KEY' }
}
);
const data = await resp.json();
console.log(data);
```
# Create a model environment
post /v1/models/{model_id}/environments
Creates an environment for the specified model and returns the environment.
# Any deployment by ID
post /v1/models/{model_id}/deployments/{deployment_id}/deactivate
Deactivates a deployment and returns the deactivation status.
# 🆕 Deactivate environment deployment
post /v1/models/{model_id}/environments/{env_name}/deactivate
Deactivates a deployment associated with an environment and returns the deactivation status.
# Development deployment
post /v1/models/{model_id}/deployments/development/deactivate
Deactivates a development deployment and returns the deactivation status.
# Production deployment
post /v1/models/{model_id}/deployments/production/deactivate
Deactivates a production deployment and returns the deactivation status.
# Published deployment
POST https://model-{model_id}.api.baseten.co/deployment/{deployment_id}/async_predict
Use this endpoint to call any [published deployment](/deploy/lifecycle) of your model.
### Parameters
The ID of the model you want to call.
The ID of the specific deployment you want to call.
### Headers
Your Baseten API key, formatted with prefix `Api-Key` (e.g. `{"Authorization": "Api-Key abcd1234.abcd1234"}`).
### Body
There is a 256 KiB size limit to `/async_predict` request payloads.
JSON-serializable model input.
Baseten **does not** store model outputs. If `webhook_endpoint` is empty, your model must save prediction outputs so they can be accessed later.
URL of the webhook endpoint. We require that webhook endpoints use HTTPS.
Priority of the request. A lower value corresponds to a higher priority (e.g. requests with priority 0 are scheduled before requests of priority 1).
`priority` is between 0 and 2, inclusive.
Maximum time a request will spend in the queue before expiring.
`max_time_in_queue_seconds` must be between 10 seconds and 72 hours, inclusive.
Exponential backoff parameters used to retry the model predict request.
Number of predict request attempts.
`max_attempts` must be between 1 and 10, inclusive.
Minimum time between retries in milliseconds.
`initial_delay_ms` must be between 0 and 10,000 milliseconds, inclusive.
Maximum time between retries in milliseconds.
`max_delay_ms` must be between 0 and 60,000 milliseconds, inclusive.
### Response
The ID of the async request.
### Rate limits
Two types of rate limits apply when making async requests:
* Calls to the `/async_predict` endpoint are limited to **200 requests per second**.
* Each organization is limited to **50,000 `QUEUED` or `IN_PROGRESS` async requests**, summed across all deployments.
If either limit is exceeded, subsequent `/async_predict` requests will receive a 429 status code.
To avoid hitting these rate limits, we advise:
* Implementing a backpressure mechanism, such as calling `/async_predict` with exponential backoff in response to 429 errors (see the sketch after this list).
* Monitoring the [async queue size metric](/observability/metrics#async-queue-size). If your model is accumulating a backlog of requests, consider increasing the number of requests your model can process at once by increasing the number of max replicas or the concurrency target in your autoscaling settings.
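To illustrate the backpressure approach above, here is a minimal sketch that retries `/async_predict` with exponential backoff whenever a 429 is returned. It assumes the `requests` library and the same placeholder IDs as the examples below; the retry count and delay values are arbitrary choices.
```python
import os
import time

import requests

model_id = ""       # Replace this with your model ID
deployment_id = ""  # Replace this with your deployment ID
baseten_api_key = os.environ["BASETEN_API_KEY"]


def async_predict_with_backoff(payload, max_attempts=5, base_delay_s=1.0):
    """POST to /async_predict, backing off exponentially on 429 responses."""
    url = f"https://model-{model_id}.api.baseten.co/deployment/{deployment_id}/async_predict"
    headers = {"Authorization": f"Api-Key {baseten_api_key}"}
    for attempt in range(max_attempts):
        resp = requests.post(url, headers=headers, json=payload)
        if resp.status_code != 429:
            resp.raise_for_status()
            return resp.json()
        # Rate limited: wait 1s, 2s, 4s, ... before trying again
        time.sleep(base_delay_s * (2 ** attempt))
    raise RuntimeError("Still rate limited after retries")
```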
```py Python
import requests
import os
model_id = ""
deployment_id = ""
webhook_endpoint = ""
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]
resp = requests.post(
f"https://model-{model_id}.api.baseten.co/deployment/{deployment_id}/async_predict",
headers={"Authorization": f"Api-Key {baseten_api_key}"},
json={
"model_input": {"prompt": "hello world!"},
"webhook_endpoint": webhook_endpoint
# Optional fields for priority, max_time_in_queue_seconds, etc
},
)
print(resp.json())
```
```sh cURL
curl --request POST \
--url https://model-{model_id}.api.baseten.co/deployment/{deployment_id}/async_predict \
--header "Authorization: Api-Key $BASETEN_API_KEY" \
--data '{
"model_input": {"prompt": "hello world!"},
"webhook_endpoint": "https://my_webhook.com/webhook",
"priority": 1,
"max_time_in_queue_seconds": 100,
"inference_retry_config": {
"max_attempts": 3,
"initial_delay_ms": 1000,
"max_delay_ms": 5000
}
}'
```
```js Node.js
const fetch = require('node-fetch');
const resp = await fetch(
'https://model-{model_id}.api.baseten.co/deployment/{deployment_id}/async_predict',
{
method: 'POST',
headers: { Authorization: 'Api-Key YOUR_API_KEY' },
body: JSON.stringify({
"model_input": {"prompt": "hello world!"},
"webhook_endpoint": "https://my_webhook.com/webhook",
"priority": 1,
"max_time_in_queue_seconds": 100,
"inference_retry_config": {
"max_attempts": 3,
"initial_delay_ms": 1000,
"max_delay_ms": 5000
}
}),
}
);
const data = await resp.json();
console.log(data);
```
```json 201
{
"request_id": ""
}
```
# Published deployment
GET https://model-{model_id}.api.baseten.co/deployment/{deployment_id}/async_queue_status
Use this endpoint to get the status of a published deployment's async queue.
### Parameters
The ID of the model.
The ID of the deployment.
### Headers
Your Baseten API key, formatted with prefix `Api-Key` (e.g. `{"Authorization": "Api-Key abcd1234.abcd1234"}`).
### Response
The ID of the model.
The ID of the deployment.
The number of requests in the deployment's async queue with `QUEUED` status (i.e. awaiting processing by the model).
The number of requests in the deployment's async queue with `IN_PROGRESS` status (i.e. currently being processed by the model).
```json 200
{
"model_id": "",
"deployment_id": "",
"num_queued_requests": 12,
"num_in_progress_requests": 3
}
```
### Rate limits
Calls to the `/async_queue_status` endpoint are limited to **20 requests per second**. If this limit is exceeded, subsequent requests will receive a 429 status code.
To gracefully handle hitting this rate limit, we advise implementing a backpressure mechanism, such as calling `/async_queue_status` with exponential backoff in response to 429 errors.
```py Python
import requests
import os
model_id = ""
deployment_id = ""
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]
resp = requests.get(
f"https://model-{model_id}.api.baseten.co/deployment/{deployment_id}/async_queue_status",
headers={"Authorization": f"Api-Key {baseten_api_key}"}
)
print(resp.json())
```
```sh cURL
curl --request GET \
--url https://model-{model_id}.api.baseten.co/deployment/{deployment_id}/async_queue_status \
--header "Authorization: Api-Key $BASETEN_API_KEY"
```
```js Node.js
const fetch = require('node-fetch');
const resp = await fetch(
'https://model-{model_id}.api.baseten.co/deployment/{deployment_id}/async_queue_status',
{
method: 'GET',
headers: { Authorization: 'Api-Key YOUR_API_KEY' }
}
);
const data = await resp.json();
console.log(data);
```
# Published deployment
POST https://model-{model_id}.api.baseten.co/deployment/{deployment_id}/predict
Use this endpoint to call any [published deployment](/deploy/lifecycle) of your model.
```sh
https://model-{model_id}.api.baseten.co/deployment/{deployment_id}/predict
```
### Parameters
The ID of the model you want to call.
The ID of the specific deployment you want to call.
Your Baseten API key, formatted with prefix `Api-Key` (e.g. `{"Authorization": "Api-Key abcd1234.abcd1234"}`).
### Body
JSON-serializable model input.
```py Python
import urllib3
import os
model_id = ""
deployment_id = ""
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]
resp = urllib3.request(
"POST",
f"https://model-{model_id}.api.baseten.co/deployment/{deployment_id}/predict",
headers={"Authorization": f"Api-Key {baseten_api_key}"},
json={}, # JSON-serializable model input
)
print(resp.json())
```
```sh cURL
curl -X POST https://model-{model_id}.api.baseten.co/deployment/{deployment_id}/predict \
-H 'Authorization: Api-Key YOUR_API_KEY' \
-d '{}' # JSON-serializable model input
```
```sh Truss
truss predict --model-version DEPLOYMENT_ID -d '{}' # JSON-serializable model input
```
```js Node.js
const fetch = require('node-fetch');
const resp = await fetch(
'https://model-{model_id}.api.baseten.co/deployment/{deployment_id}/predict',
{
method: 'POST',
headers: { Authorization: 'Api-Key YOUR_API_KEY' },
body: JSON.stringify({}), // JSON-serializable model input
}
);
const data = await resp.json();
console.log(data);
```
```json Example Response
// JSON-serializable output varies by model
{}
```
# Published deployment
POST https://chain-{chain_id}.api.baseten.co/deployment/{deployment_id}/run_remote
Use this endpoint to call any [published deployment](/deploy/lifecycle) of your
chain.
```sh
https://chain-{chain_id}.api.baseten.co/deployment/{deployment_id}/run_remote
```
### Parameters
The ID of the chain you want to call.
The ID of the specific deployment you want to call.
Your Baseten API key, formatted with prefix `Api-Key` (e.g. `{"Authorization": "Api-Key abcd1234.abcd1234"}`).
### Body
JSON-serializable chain input. The input schema corresponds to the signature of the entrypoint's `run_remote` method: the top-level keys are the argument names, and the values are the JSON representations of the corresponding types.
```py Python
import urllib3
import os
chain_id = ""
deployment_id = ""
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]
resp = urllib3.request(
"POST",
f"https://chain
-{chain_id}.api.baseten.co/deployment/{deployment_id}/run_remote",
headers={"Authorization": f"Api-Key {baseten_api_key}"},
json={}, # JSON-serializable chain input
)
print(resp.json())
```
```sh cURL
curl -X POST https://chain-{chain_id}.api.baseten.co/deployment/{deployment_id}/run_remote \
-H 'Authorization: Api-Key YOUR_API_KEY' \
-d '{}' # JSON-serializable chain input
```
```js Node.js
const fetch = require('node-fetch');
const resp = await fetch(
'https://chain-{chain_id}.api.baseten.co/deployment/{deployment_id}/run_remote',
{
method: 'POST',
headers: { Authorization: 'Api-Key YOUR_API_KEY' },
body: JSON.stringify({}), // JSON-serializable chain input
}
);
const data = await resp.json();
console.log(data);
```
```json Example Response
// JSON-serializable output varies by chain
{}
```
# Published deployment
POST https://model-{model_id}.api.baseten.co/deployment/{deployment_id}/wake
Use this endpoint to wake any scaled-to-zero [published deployment](/deploy/lifecycle) of your model.
```sh
https://model-{model_id}.api.baseten.co/deployment/{deployment_id}/wake
```
### Parameters
The ID of the model you want to wake.
The ID of the specific deployment you want to wake.
Your Baseten API key, formatted with prefix `Api-Key` (e.g. `{"Authorization": "Api-Key abcd1234.abcd1234"}`).
```py Python
import urllib3
import os
model_id = ""
deployment_id = ""
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]
resp = urllib3.request(
"POST",
f"https://model-{model_id}.api.baseten.co/deployment/{deployment_id}/wake",
headers={"Authorization": f"Api-Key {baseten_api_key}"},
)
print(resp.json())
```
```sh cURL
curl -X POST https://model-{model_id}.api.baseten.co/deployment/{deployment_id}/wake \
-H 'Authorization: Api-Key YOUR_API_KEY'
```
```js Node.js
const fetch = require('node-fetch');
const resp = await fetch(
'https://model-{model_id}.api.baseten.co/deployment/{deployment_id}/wake',
{
method: 'POST',
headers: { Authorization: 'Api-Key YOUR_API_KEY' },
}
);
const data = await resp.json();
console.log(data);
```
```json Example Response
// Returns a 202 response code
{}
```
# Development deployment
POST https://model-{model_id}.api.baseten.co/development/async_predict
Use this endpoint to call the [development deployment](/deploy/lifecycle) of your model asynchronously.
### Parameters
The ID of the model you want to call.
### Headers
Your Baseten API key, formatted with prefix `Api-Key` (e.g. `{"Authorization": "Api-Key abcd1234.abcd1234"}`).
### Body
There is a 256 KiB size limit to `/async_predict` request payloads.
JSON-serializable model input.
Baseten **does not** store model outputs. If `webhook_endpoint` is empty, your model must save prediction outputs so they can be accessed later.
URL of the webhook endpoint. We require that webhook endpoints use HTTPS.
Priority of the request. A lower value corresponds to a higher priority (e.g. requests with priority 0 are scheduled before requests of priority 1).
`priority` is between 0 and 2, inclusive.
Maximum time a request will spend in the queue before expiring.
`max_time_in_queue_seconds` must be between 10 seconds and 72 hours, inclusive.
Exponential backoff parameters used to retry the model predict request.
Number of predict request attempts.
`max_attempts` must be between 1 and 10, inclusive.
Minimum time between retries in milliseconds.
`initial_delay_ms` must be between 0 and 10,000 milliseconds, inclusive.
Maximum time between retries in milliseconds.
`max_delay_ms` must be between 0 and 60,000 milliseconds, inclusive.
### Response
The ID of the async request.
### Rate limits
Two types of rate limits apply when making async requests:
* Calls to the `/async_predict` endpoint are limited to **200 requests per second**.
* Each organization is limited to **50,000 `QUEUED` or `IN_PROGRESS` async requests**, summed across all deployments.
If either limit is exceeded, subsequent `/async_predict` requests will receive a 429 status code.
To avoid hitting these rate limits, we advise:
* Implementing a backpressure mechanism, such as calling `/async_predict` with exponential backoff in response to 429 errors.
* Monitoring the [async queue size metric](/observability/metrics#async-queue-size). If your model is accumulating a backlog of requests, consider increasing the number of requests your model can process at once by increasing the number of max replicas or the concurrency target in your autoscaling settings.
```py Python
import requests
import os
model_id = ""
webhook_endpoint = ""
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]
resp = requests.post(
f"https://model-{model_id}.api.baseten.co/development/async_predict",
headers={"Authorization": f"Api-Key {baseten_api_key}"},
json={
"model_input": {"prompt": "hello world!"},
"webhook_endpoint": webhook_endpoint
# Optional fields for priority, max_time_in_queue_seconds, etc
},
)
print(resp.json())
```
```sh cURL
curl --request POST \
--url https://model-{model_id}.api.baseten.co/development/async_predict \
--header "Authorization: Api-Key $BASETEN_API_KEY" \
--data '{
"model_input": {"prompt": "hello world!"},
"webhook_endpoint": "https://my_webhook.com/webhook",
"priority": 1,
"max_time_in_queue_seconds": 100,
"inference_retry_config": {
"max_attempts": 3,
"initial_delay_ms": 1000,
"max_delay_ms": 5000
}
}'
```
```js Node.js
const fetch = require('node-fetch');
const resp = await fetch(
'https://model-{model_id}.api.baseten.co/development/async_predict',
{
method: 'POST',
headers: { Authorization: 'Api-Key YOUR_API_KEY' },
body: JSON.stringify({
"model_input": {"prompt": "hello world!"},
"webhook_endpoint": "https://my_webhook.com/webhook",
"priority": 1,
"max_time_in_queue_seconds": 100,
"inference_retry_config": {
"max_attempts": 3,
"initial_delay_ms": 1000,
"max_delay_ms": 5000
}
}),
}
);
const data = await resp.json();
console.log(data);
```
```json 201
{
"request_id": ""
}
```
# Development deployment
GET https://model-{model_id}.api.baseten.co/development/async_queue_status
Use this endpoint to get the status of a development deployment's async queue.
### Parameters
The ID of the model.
### Headers
Your Baseten API key, formatted with prefix `Api-Key` (e.g. `{"Authorization": "Api-Key abcd1234.abcd1234"}`).
### Response
The ID of the model.
The ID of the deployment.
The number of requests in the deployment's async queue with `QUEUED` status (i.e. awaiting processing by the model).
The number of requests in the deployment's async queue with `IN_PROGRESS` status (i.e. currently being processed by the model).
```json 200
{
"model_id": "",
"deployment_id": "",
"num_queued_requests": 12,
"num_in_progress_requests": 3
}
```
### Rate limits
Calls to the `/async_queue_status` endpoint are limited to **20 requests per second**. If this limit is exceeded, subsequent requests will receive a 429 status code.
To gracefully handle hitting this rate limit, we advise implementing a backpressure mechanism, such as calling `/async_queue_status` with exponential backoff in response to 429 errors.
```py Python
import requests
import os
model_id = ""
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]
resp = requests.get(
f"https://model-{model_id}.api.baseten.co/development/async_queue_status",
headers={"Authorization": f"Api-Key {baseten_api_key}"}
)
print(resp.json())
```
```sh cURL
curl --request GET \
--url https://model-{model_id}.api.baseten.co/development/async_queue_status \
--header "Authorization: Api-Key $BASETEN_API_KEY"
```
```js Node.js
const fetch = require('node-fetch');
const resp = await fetch(
'https://model-{model_id}.api.baseten.co/development/async_queue_status',
{
method: 'GET',
headers: { Authorization: 'Api-Key YOUR_API_KEY' }
}
);
const data = await resp.json();
console.log(data);
```
# Development deployment
POST https://model-{model_id}.api.baseten.co/development/predict
Use this endpoint to call the [development deployment](/deploy/lifecycle) of your model.
```sh
https://model-{model_id}.api.baseten.co/development/predict
```
### Parameters
The ID of the model you want to call.
Your Baseten API key, formatted with prefix `Api-Key` (e.g. `{"Authorization": "Api-Key abcd1234.abcd1234"}`).
### Body
JSON-serializable model input.
```py Python
import urllib3
import os
model_id = ""
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]
resp = urllib3.request(
"POST",
f"https://model-{model_id}.api.baseten.co/development/predict",
headers={"Authorization": f"Api-Key {baseten_api_key}"},
json={}, # JSON-serializable model input
)
print(resp.json())
```
```sh cURL
curl -X POST https://model-{model_id}.api.baseten.co/development/predict \
-H 'Authorization: Api-Key YOUR_API_KEY' \
-d '{}' # JSON-serializable model input
```
```sh Truss
truss predict --model-version DEPLOYMENT_ID -d '{}' # JSON-serializable model input
```
```js Node.js
const fetch = require('node-fetch');
const resp = await fetch(
'https://model-{model_id}.api.baseten.co/development/predict',
{
method: 'POST',
headers: { Authorization: 'Api-Key YOUR_API_KEY' },
body: JSON.stringify({}), // JSON-serializable model input
}
);
const data = await resp.json();
console.log(data);
```
```json Example Response
// JSON-serializable output varies by model
{}
```
# Development deployment
POST https://chain-{chain_id}.api.baseten.co/development/run_remote
Use this endpoint to call the [development deployment](/deploy/lifecycle) of
your chain.
```sh
https://chain-{chain_id}.api.baseten.co/development/run_remote
```
### Parameters
The ID of the chain you want to call.
Your Baseten API key, formatted with prefix `Api-Key` (e.g. `{"Authorization": "Api-Key abcd1234.abcd1234"}`).
### Body
JSON-serializable chain input. The input schema corresponds to the signature of the entrypoint's `run_remote` method: the top-level keys are the argument names, and the values are the JSON representations of the corresponding types.
```py Python
import urllib3
import os
chain_id = ""
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]
resp = urllib3.request(
"POST",
f"https://chain-{chain_id}.api.baseten.co/development/run_remote",
headers={"Authorization": f"Api-Key {baseten_api_key}"},
json={}, # JSON-serializable chain input
)
print(resp.json())
```
```sh cURL
curl -X POST https://chain-{chain_id}.api.baseten.co/development/run_remote \
-H 'Authorization: Api-Key YOUR_API_KEY' \
-d '{}' # JSON-serializable chain input
```
```js Node.js
const fetch = require('node-fetch');
const resp = await fetch(
'https://chain-{chain_id}.api.baseten.co/development/run_remote',
{
method: 'POST',
headers: { Authorization: 'Api-Key YOUR_API_KEY' },
body: JSON.stringify({}), // JSON-serializable chain input
}
);
const data = await resp.json();
console.log(data);
```
```json Example Response
// JSON-serializable output varies by chain
{}
```
# Development deployment
POST https://model-{model_id}.api.baseten.co/development/wake
Use this endpoint to wake the [development deployment](/deploy/lifecycle) of your model if it is scaled to zero.
```sh
https://model-{model_id}.api.baseten.co/development/wake
```
### Parameters
The ID of the model you want to wake.
Your Baseten API key, formatted with prefix `Api-Key` (e.g. `{"Authorization": "Api-Key abcd1234.abcd1234"}`).
```py Python
import urllib3
import os
model_id = ""
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]
resp = urllib3.request(
"POST",
f"https://model-{model_id}.api.baseten.co/development/wake",
headers={"Authorization": f"Api-Key {baseten_api_key}"},
)
print(resp.json())
```
```sh cURL
curl -X POST https://model-{model_id}.api.baseten.co/development/wake \
-H 'Authorization: Api-Key YOUR_API_KEY'
```
```js Node.js
const fetch = require('node-fetch');
const resp = await fetch(
'https://model-{model_id}.api.baseten.co/development/wake',
{
method: 'POST',
headers: { Authorization: 'Api-Key YOUR_API_KEY' },
}
);
const data = await resp.json();
console.log(data);
```
```json Example Response
// Returns a 202 response code
{}
```
# 🆕 Async inference by environment
POST https://model-{model_id}.api.baseten.co/environments/{env_name}/async_predict
Use this endpoint to call the model associated with the specified environment asynchronously.
### Parameters
The ID of the model you want to call.
The name of the model's environment you want to call.
### Headers
Your Baseten API key, formatted with prefix `Api-Key` (e.g. `{"Authorization": "Api-Key abcd1234.abcd1234"}`).
### Body
There is a 256 KiB size limit to `/async_predict` request payloads.
JSON-serializable model input.
Baseten **does not** store model outputs. If `webhook_endpoint` is empty, your model must save prediction outputs so they can be accessed later.
URL of the webhook endpoint. We require that webhook endpoints use HTTPS.
Priority of the request. A lower value corresponds to a higher priority (e.g. requests with priority 0 are scheduled before requests of priority 1).
`priority` is between 0 and 2, inclusive.
Maximum time a request will spend in the queue before expiring.
`max_time_in_queue_seconds` must be between 10 seconds and 72 hours, inclusive.
Exponential backoff parameters used to retry the model predict request.
Number of predict request attempts.
`max_attempts` must be between 1 and 10, inclusive.
Minimum time between retries in milliseconds.
`initial_delay_ms` must be between 0 and 10,000 milliseconds, inclusive.
Maximum time between retries in milliseconds.
`max_delay_ms` must be between 0 and 60,000 milliseconds, inclusive.
### Response
The ID of the async request.
### Rate limits
Two types of rate limits apply when making async requests:
* Calls to the `/async_predict` endpoint are limited to **200 requests per second**.
* Each organization is limited to **50,000 `QUEUED` or `IN_PROGRESS` async requests**, summed across all deployments.
If either limit is exceeded, subsequent `/async_predict` requests will receive a 429 status code.
To avoid hitting these rate limits, we advise:
* Implementing a backpressure mechanism, such as calling `/async_predict` with exponential backoff in response to 429 errors.
* Monitoring the [async queue size metric](/observability/metrics#async-queue-size). If your model is accumulating a backlog of requests, consider increasing the number of requests your model can process at once by increasing the number of max replicas or the concurrency target in your autoscaling settings.
```py Python
import requests
import os
model_id = "" # Replace this with your model ID
webhook_endpoint = "" # Replace this with your webhook endpoint URL
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]
# Call the async_predict endpoint of the production deployment
resp = requests.post(
f"https://model-{model_id}.api.baseten.co/production/async_predict",
headers={"Authorization": f"Api-Key {baseten_api_key}"},
json={
"model_input": {"prompt": "hello world!"},
"webhook_endpoint": webhook_endpoint
# Optional fields for priority, max_time_in_queue_seconds, etc
},
)
print(resp.json())
```
```sh cURL
curl --request POST \
--url https://model-{model_id}.api.baseten.co/environments/{env_name}/async_predict \
--header "Authorization: Api-Key $BASETEN_API_KEY" \
--data '{
"model_input": {"prompt": "hello world!"},
"webhook_endpoint": "https://my_webhook.com/webhook"
}'
```
```js Node.js
const fetch = require('node-fetch');
const resp = await fetch(
'https://model-{model_id}.api.baseten.co/environments/{env_name}/async_predict',
{
method: 'POST',
headers: { Authorization: 'Api-Key YOUR_API_KEY' },
body: JSON.stringify({
"model_input": {"prompt": "hello world!"},
"webhook_endpoint": "https://my_webhook.com/webhook"
}),
}
);
const data = await resp.json();
console.log(data);
```
```json 201
{
"request_id": ""
}
```
# Environment deployment
GET https://model-{model_id}.api.baseten.co/environments/{env_name}/async_queue_status
Use this endpoint to get the async queue status for a model associated with the specified environment.
### Parameters
The ID of the model.
The name of the environment.
### Headers
Your Baseten API key, formatted with prefix `Api-Key` (e.g. `{"Authorization": "Api-Key abcd1234.abcd1234"}`).
### Response
The ID of the model.
The ID of the deployment.
The number of requests in the deployment's async queue with `QUEUED` status (i.e. awaiting processing by the model).
The number of requests in the deployment's async queue with `IN_PROGRESS` status (i.e. currently being processed by the model).
```json 200
{
"model_id": "",
"deployment_id": "",
"num_queued_requests": 12,
"num_in_progress_requests": 3
}
```
### Rate limits
Calls to the `/async_queue_status` endpoint are limited to **20 requests per second**. If this limit is exceeded, subsequent requests will receive a 429 status code.
To gracefully handle hitting this rate limit, we advise implementing a backpressure mechanism, such as calling `/async_queue_status` with exponential backoff in response to 429 errors.
```py Python
import requests
import os
model_id = ""
env_name = ""
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]
resp = requests.get(
f"https://model-{model_id}.api.baseten.co/environments/{env_name}/async_queue_status",
headers={"Authorization": f"Api-Key {baseten_api_key}"}
)
print(resp.json())
```
```sh cURL
curl --request GET \
--url https://model-{model_id}.api.baseten.co/environments/{env_name}/async_queue_status \
--header "Authorization: Api-Key $BASETEN_API_KEY"
```
```js Node.js
const fetch = require('node-fetch');
const resp = await fetch(
'https://model-{model_id}.api.baseten.co/environments/{env_name}/async_queue_status',
{
method: 'GET',
headers: { Authorization: 'Api-Key YOUR_API_KEY' }
}
);
const data = await resp.json();
console.log(data);
```
# 🆕 Inference by environment
POST https://model-{model_id}.api.baseten.co/environments/{env_name}/predict
Use this endpoint to call the deployment associated with the specified [environment](/deploy/lifecycle#what-is-an-environment).
```sh
https://model-{model_id}.api.baseten.co/environments/{env_name}/predict
```
### Parameters
The ID of the model you want to call.
The name of the model's environment you want to call.
Your Baseten API key, formatted with prefix `Api-Key` (e.g. `{"Authorization": "Api-Key abcd1234.abcd1234"}`).
### Body
JSON-serializable model input.
```py Python
import urllib3
import os
model_id = ""
env_name = "staging"
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]
resp = urllib3.request(
"POST",
f"https://model-{model_id}.api.baseten.co/environments/{env_name}/predict",
headers={"Authorization": f"Api-Key {baseten_api_key}"},
json={}, # JSON-serializable model input
)
print(resp.json())
```
```sh cURL
curl -X POST https://model-{model_id}.api.baseten.co/environments/{env_name}/predict \
-H 'Authorization: Api-Key YOUR_API_KEY' \
-d '{}' # JSON-serializable model input
```
```js Node.js
const fetch = require('node-fetch');
const resp = await fetch(
'https://model-{model_id}.api.baseten.co/environments/{env_name}/predict',
{
method: 'POST',
headers: { Authorization: 'Api-Key YOUR_API_KEY' },
body: JSON.stringify({}), // JSON-serializable model input
}
);
const data = await resp.json();
console.log(data);
```
```json Example Response
// JSON-serializable output varies by model
{}
```
# 🆕 Inference by environment
POST https://chain-{chain_id}.api.baseten.co/environments/{env_name}/run_remote
Use this endpoint to call the deployment associated with the specified environment.
```sh
https://chain-{chain_id}.api.baseten.co/environments/{env_name}/run_remote
```
### Parameters
The ID of the chain you want to call.
The name of the chain's environment you want to call.
Your Baseten API key, formatted with prefix `Api-Key` (e.g. `{"Authorization": "Api-Key abcd1234.abcd1234"}`).
### Body
JSON-serializable chain input. The input schema corresponds to the signature of the entrypoint's `run_remote` method: the top-level keys are the argument names, and the values are the JSON representations of the corresponding types.
```py Python
import urllib3
import os
chain_id = ""
env_name = "staging"
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]
resp = urllib3.request(
"POST",
f"https://chain-{chain_id}.api.baseten.co/environments/{env_name}/run_remote",
headers={"Authorization": f"Api-Key {baseten_api_key}"},
json={}, # JSON-serializable chain input
)
print(resp.json())
```
```sh cURL
curl -X POST https://chain-{chain_id}.api.baseten.co/environments/{env_name}/run_remote \
-H 'Authorization: Api-Key YOUR_API_KEY' \
-d '{}' # JSON-serializable chain input
```
```js Node.js
const fetch = require('node-fetch');
const resp = await fetch(
'https://chain-{chain_id}.api.baseten.co/environments/{env_name}/run_remote',
{
method: 'POST',
headers: { Authorization: 'Api-Key YOUR_API_KEY' },
body: JSON.stringify({}), // JSON-serializable chain input
}
);
const data = await resp.json();
console.log(data);
```
```json Example Response
// JSON-serializable output varies by chain
{}
```
# Get chain environment
get /v1/chains/{chain_id}/environments/{env_name}
Gets a chain environment's details and returns the chain environment.
# Get all chain environments
get /v1/chains/{chain_id}/environments
Gets all chain environments for a given chain.
# Get all model environments
get /v1/models/{model_id}/environments
Gets all environments for a given model.
# Get model environment
get /v1/models/{model_id}/environments/{env_name}
Gets an environment's details and returns the environment.
# Get async request status
GET https://model-{model_id}.api.baseten.co/async_request/{request_id}
Use this endpoint to get the status of an async request.
### Parameters
The ID of the model that executed the request.
The ID of the async request.
### Headers
Your Baseten API key, formatted with prefix `Api-Key` (e.g. `{"Authorization": "Api-Key abcd1234.abcd1234"}`).
### Response
The ID of the async request.
The ID of the model that executed the request.
The ID of the deployment that executed the request.
An enum representing the status of the request.
Available options: `QUEUED`, `IN_PROGRESS`, `SUCCEEDED`, `FAILED`, `EXPIRED`, `CANCELED`, `WEBHOOK_FAILED`
An enum representing the status of sending the predict result to the provided webhook.
Available options: `PENDING`, `SUCCEEDED`, `FAILED`, `CANCELED`, `NO_WEBHOOK_PROVIDED`
The time in UTC at which the async request was created.
The time in UTC at which the async request's status was updated.
Any errors that occurred in processing the async request. Empty if no errors occurred.
An enum representing the type of error that occurred.
Available options: `MODEL_PREDICT_ERROR`, `MODEL_PREDICT_TIMEOUT`, `MODEL_NOT_READY`, `MODEL_DOES_NOT_EXIST`, `MODEL_UNAVAILABLE`, `MODEL_INVALID_INPUT`, `ASYNC_REQUEST_NOT_SUPPORTED`, `INTERNAL_SERVER_ERROR`
A message containing details of the error that occurred.
### Rate limits
Calls to the get async request status endpoint are limited to **20 requests per second**. If this limit is exceeded, subsequent requests will receive a 429 status code.
To avoid hitting this rate limit, we recommend [configuring a webhook endpoint](/invoke/async#configuring-the-webhook-endpoint) to receive async predict results instead of frequently polling this endpoint for async request statuses.
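If you opt for webhooks, the receiver only needs to be an HTTPS endpoint that accepts a JSON POST. The Flask sketch below is illustrative only: it makes no assumptions about the payload fields and simply logs the body, so consult the webhook configuration guide linked above for the actual result schema and signature verification.
```python
# Illustrative webhook receiver for async predict results (assumes Flask is installed).
# Payload fields are not enumerated here; see the webhook configuration guide.
from flask import Flask, request

app = Flask(__name__)


@app.route("/webhook", methods=["POST"])
def handle_async_result():
    result = request.get_json()  # JSON body delivered to your webhook
    print(result)                # Replace with your own storage or handling
    return "", 200


if __name__ == "__main__":
    # Serve this behind HTTPS in production; webhook endpoints must use HTTPS.
    app.run(port=8000)
```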
```py Python
import requests
import os
model_id = ""
request_id = ""
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]
resp = requests.get(
f"https://model-{model_id}.api.baseten.co/async_request/{request_id}",
headers={"Authorization": f"Api-Key {baseten_api_key}"}
)
print(resp.json())
```
```sh cURL
curl --request GET \
--url https://model-{model_id}.api.baseten.co/async_request/{request_id} \
--header "Authorization: Api-Key $BASETEN_API_KEY"
```
```js Node.js
const fetch = require('node-fetch');
const resp = await fetch(
'https://model-{model_id}.api.baseten.co/async_request/{request_id}',
{
method: 'GET',
headers: { Authorization: 'Api-Key YOUR_API_KEY' }
}
);
const data = await resp.json();
console.log(data);
```
# Get a chain by ID
get /v1/chains/{chain_id}
# Any chain deployment by ID
get /v1/chains/{chain_id}/deployments/{chain_deployment_id}
# Get a model by ID
get /v1/models/{model_id}
# Any model deployment by ID
get /v1/models/{model_id}/deployments/{deployment_id}
Gets a model's deployment by ID and returns the deployment.
# Development model deployment
get /v1/models/{model_id}/deployments/development
Gets a model's development deployment and returns the deployment.
# Production model deployment
get /v1/models/{model_id}/deployments/production
Gets a model's production deployment and returns the deployment.
# Get all chain deployments
get /v1/chains/{chain_id}/deployments
# Get all chains
get /v1/chains
# Get all model deployments
get /v1/models/{model_id}/deployments
# Get all models
get /v1/models
# Get all secrets
get /v1/secrets
# Model endpoint migration guide
No more JSON wrapper with model output
This guide covers the new predict endpoints in two parts:
* Showing the format of the new model predict endpoint.
* Showing the change that you must make to how you parse model output when you switch to the new endpoint.
The change to model output format only applies when you switch to the new endpoints. Model output is unchanged for the old endpoints.
## Updates to endpoint paths
The new endpoint uses `model_id` as part of the subdomain, where formerly it was part of the path:
```sh
# Old endpoint (for production deployment)
https://app.baseten.co/models/{model_id}/predict
# New endpoint (for production deployment)
https://model-{model_id}.api.baseten.co/production/predict
```
Updated endpoints:
* The old `/models/{model_id}/predict` endpoint is now the [production deployment endpoint](/api-reference/production-predict).
* The old `/model_versions/{version_id}/predict` endpoint is now the [published deployment endpoint](/api-reference/deployment-predict).
* There's a new endpoint just for the development deployment of a model, the [development deployment endpoint](/api-reference/development-predict).
## Model output response format
With the new model endpoints, we've changed the output format of the model response. This change simplifies model responses and removes a step in parsing model output.
### Old endpoint response format
For the old endpoint, formatted `https://app.baseten.co/models/{model_id}/predict`, the model output was wrapped in a JSON dictionary with the model ID and model version ID (which is now the deployment ID):
```json Old response
{
"model_id":"MODEL_ID",
"model_version_id":"VERSION_ID",
"model_output": {
// Output varies by model, this is just an example
"prediction": true,
"confidence": 0.7839
}
}
```
These old endpoints will stay available and the response format for these old endpoints will not change. You only need to change the way you parse your model output when switching to the new endpoints.
### New endpoint response format
For the new endpoint, formatted `https://model-{model_id}.api.baseten.co/production/predict`, the model output is no longer wrapped:
```json New response
// Output varies by model, this is just an example
{
"prediction": true,
"confidence": 0.7839
}
```
So, when you switch to the new endpoints, also update any code that parses model responses, since the output is no longer wrapped in an additional dictionary:
```python
# On old endpoints:
model_output = resp.json()["model_output"]
# On new endpoints:
model_output = resp.json()
```
# Call primary version
POST https://app.baseten.co/models/{model_id}/predict
This is an old endpoint. Update to the endpoint for a [production deployment](/api-reference/production-predict) and the new model response format based on the [migration guide](/api-reference/migration-guide).
Use this endpoint to call the primary version of a model (now known as the production deployment).
```sh
https://app.baseten.co/models/{model_id}/predict
```
### Parameters
The ID of the model you want to call.
Your Baseten API key, formatted with prefix `Api-Key` (e.g. `{"Authorization": "Api-Key abcd1234.abcd1234"}`).
JSON-serializable model input.
```py Python
import urllib3
resp = urllib3.request(
"POST",
"https://app.baseten.co/models/MODEL_ID/predict",
headers={"Authorization": "Api-Key YOUR_API_KEY"},
json={}, # JSON-serializable model input
)
print(resp.json())
```
```sh cURL
curl -X POST https://app.baseten.co/models/MODEL_ID/predict \
-H 'Authorization: Api-Key YOUR_API_KEY' \
-d '{}' # JSON-serializable model input
```
```sh Truss
truss predict --model MODEL_ID -d '{}' # JSON-serializable model input
```
```js Node.js
const fetch = require('node-fetch');
const resp = await fetch(
'https://app.baseten.co/models/MODEL_ID/predict',
{
method: 'POST',
headers: { Authorization: 'Api-Key YOUR_API_KEY' },
body: JSON.stringify({}), // JSON-serializable model input
}
);
const data = await resp.json();
console.log(data);
```
```json Example Response
{
"model_id":"MODEL_ID",
"model_version_id":"VERSION_ID",
"model_output": {
// Output varies by model
}
}
```
# Wake primary version
POST https://app.baseten.co/models/{model_id}/wake
This is an old endpoint. Update to the wake endpoint for the [production deployment](/api-reference/production-wake).
Use this endpoint to wake a scaled-to-zero model version (now known as a model deployment).
```sh
https://app.baseten.co/models/{model_id}/wake
```
### Parameters
The ID of the model you want to wake.
Your Baseten API key, formatted with prefix `Api-Key` (e.g. `{"Authorization": "Api-Key abcd1234.abcd1234"}`).
```py Python
import urllib3
import os
model_id = ""
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]
resp = urllib3.request(
"POST",
f"https://app.baseten.co/models/{model_id}/wake",
headers={"Authorization": f"Api-Key {baseten_api_key}"},
)
print(resp.json())
```
```sh cURL
curl -X POST https://app.baseten.co/models/{model_id}/wake \
-H 'Authorization: Api-Key YOUR_API_KEY'
```
```js Node.js
const fetch = require('node-fetch');
const resp = await fetch(
'https://app.baseten.co/models/{model_id}/wake',
{
method: 'POST',
headers: { Authorization: 'Api-Key YOUR_API_KEY' },
}
);
const data = await resp.json();
console.log(data);
```
```json Example Response
// Returns a 202 response code
{}
```
# ChatCompletions
POST https://bridge.baseten.co/v1/direct
Use this endpoint with the OpenAI Python client and any [deployment](/deploy/lifecycle) of a [compatible](#output) model deployed on Baseten.
If you're serving a vLLM model in [OpenAI compatible mode](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html), this endpoint will support that model out of the box.
If your model does not have an OpenAI compatible mode, you can use the [previous version of the bridge](/api-reference/openai-deprecated) to make it compatible with OpenAI's client, but with a more limited set of supported features.
## Calling the model
```sh
https://bridge.baseten.co/v1/direct
```
### Parameters
Parameters supported by the OpenAI ChatCompletions request can be found in the [OpenAI documentation](https://github.com/openai/openai-python/blob/main/src/openai/types/chat/completion_create_params.py).
Below are details about Baseten-specific arguments that must be passed into the bridge.
Typically the Hugging Face repo name (e.g. `meta-llama/Meta-Llama-3.1-70B-Instruct`). In some cases, it may be another default specified by your inference engine.
Python dictionary that enables extra arguments to be supplied to the chat completion request.
Baseten-specific parameters that should be passed to the bridge. The arguments should be passed as a dictionary.
The string identifier for the target model.
The string identifier for the target deployment. When `deployment_id` is not provided, the [production deployment](/deploy/lifecycle) will be used.
### Output
Streaming and non-streaming responses are supported. The [vLLM OpenAI Server](https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/openai/serving_chat.py)
is a good example of how to serve your model results.
For streaming outputs, the data format must comply with the Server-Sent Events (SSE) format. A helpful example for JSON payloads can be found [here](https://hpbn.co/server-sent-events-sse/#event-stream-protocol).
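For reference, an SSE stream is a sequence of `data:` lines, each terminated by a blank line. The sketch below shows one way a server might format JSON chunks this way; it is illustrative only and not tied to any particular serving framework.
```python
import json


def sse_chunks(deltas):
    """Yield JSON payloads as Server-Sent Events: a `data: <json>` line plus a blank line."""
    for delta in deltas:
        yield f"data: {json.dumps(delta)}\n\n"
    # OpenAI-style streams typically end with a [DONE] sentinel.
    yield "data: [DONE]\n\n"
```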
### Best Practices
* Pin your `openai` package version in your requirements.txt file. This helps avoid breaking changes introduced through package upgrades.
* If you must make breaking changes to your Truss server (e.g. to introduce a new feature), first publish a new model deployment, then update your API call on the client side.
```py OpenAI Python client
from openai import OpenAI
import os
model_id = "abcd1234" # Replace with your model ID
deployment_id = "4321cbda" # [Optional] Replace with your deployment ID
client = OpenAI(
api_key=os.environ["BASETEN_API_KEY"],
base_url=f"https://bridge.baseten.co/v1/direct"
)
response = client.chat.completions.create(
model=f"meta-llama/Meta-Llama-3.1-70B-Instruct", # Replace with your model name
messages=[
{"role": "user", "content": "Who won the world series in 2020?"},
{"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
{"role": "user", "content": "Where was it played?"}
],
extra_body={
"baseten": {
"model_id": model_id,
"deployment_id": deployment_id
}
}
)
print(response.choices[0].message.content)
```
```json Example Response
{
"choices": [
{
"finish_reason": null,
"index": 0,
"message": {
"content": "The 2020 World Series was played in Texas at Globe Life Field in Arlington.",
"role": "assistant"
}
}
],
"created": 1700584611,
"id": "chatcmpl-eedbac8f-f68d-4769-a1a7-a1c550be8d08",
"model": "abcd1234",
"object": "chat.completion",
"usage": {
"completion_tokens": 0,
"prompt_tokens": 0,
"total_tokens": 0
}
}
```
# ChatCompletions (deprecated)
POST https://bridge.baseten.co/v1
Follow this step-by-step guide for using the OpenAI-compatible bridge endpoint.
Use this endpoint with the OpenAI Python client and any [deployment](/deploy/lifecycle) of a compatible model deployed on Baseten.
```sh
https://bridge.baseten.co/v1
```
### Parameters
Special attention should be given to the Baseten-specific arguments that must be passed into the bridge via the `extra_body` argument.
The name of the model you want to call, such as `"mistral-7b"`.
A list of dictionaries containing the chat history to complete.
The maximum number of tokens to generate. [Learn more](https://platform.openai.com/docs/api-reference/chat/create#chat-create-max_tokens)
Set `stream=True` to stream model output.
How deterministic to make the model. [Learn more](https://platform.openai.com/docs/api-reference/chat/create#chat-create-temperature)
Alternative to temperature. [Learn more](https://platform.openai.com/docs/api-reference/chat/create#chat-create-top_p)
Increase or decrease the model's likelihood to talk about new topics. [Learn more](https://platform.openai.com/docs/api-reference/chat/create#chat-create-presence_penalty)
Python dictionary that enables extra arguments to be supplied to the request.
Baseten-specific parameters that should be passed to the bridge. The arguments should be passed as a dictionary.
The string identifier for the target model.
The string identifier for the target deployment. When `deployment_id` is not provided, the [production deployment](/deploy/lifecycle) will be used.
### Output
The output will match the ChatCompletions API output format (shown on the right) with two caveats:
1. The output `id` is just a UUID. Baseten API requests are stateless, so this ID would not be meaningful.
2. Values for the `usage` dictionary are not calculated and are set to `0`. Baseten charges for compute directly rather than charging for inference by token.
### Streaming
You can also stream your model response by passing `stream=True` to the `client.chat.completions.create()` call. To parse your output, run:
```py
for chunk in response:
print(chunk.choices[0].delta)
```
```py OpenAI Python client
from openai import OpenAI
import os
model_id = "abcd1234" # Replace with your model ID
deployment_id = "4321dcba" # Optional,eplace with your deployment ID
client = OpenAI(
api_key=os.environ["BASETEN_API_KEY"],
base_url=f"https://bridge.baseten.co/{model_id}/v1"
)
response = client.chat.completions.create(
model="mistral-7b",
messages=[
{"role": "user", "content": "Who won the world series in 2020?"},
{"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
{"role": "user", "content": "Where was it played?"}
],
extra_body={
"baseten": {
"model_id": model_id,
"deployment_id": deployment_id
}
}
)
print(response.choices[0].message.content)
```
```json Example Response
{
"choices": [
{
"finish_reason": null,
"index": 0,
"message": {
"content": "The 2020 World Series was played in Texas at Globe Life Field in Arlington.",
"role": "assistant"
}
}
],
"created": 1700584611,
"id": "chatcmpl-eedbac8f-f68d-4769-a1a7-a1c550be8d08",
"model": "abcd1234",
"object": "chat.completion",
"usage": {
"completion_tokens": 0,
"prompt_tokens": 0,
"total_tokens": 0
}
}
```
# API reference
Details on model inference and management APIs
Baseten provides two sets of API endpoints:
1. An inference API for calling deployed models
2. A management API for managing your models and workspace
Many inference and management API endpoints have different routes for the three types of deployments — `development`, `production`, and individual published deployments — which are listed separately in the sidebar.
## Inference API
Each model deployed on Baseten has its own subdomain on `api.baseten.co` to enable faster routing. This subdomain is used for inference endpoints, which are formatted as follows:
```
https://model-{model_id}.api.baseten.co/{deployment_type_or_id}/{endpoint}
```
Where:
* `model_id` is the alphanumeric ID of the model, which you can find in your model dashboard.
* `deployment_type_or_id` is one of `development`, `production`, or a separate alphanumeric ID for a specific published deployment of the model.
* `endpoint` is a supported endpoint such as `predict` that you want to call.
The inference API also supports [asynchronous inference](/api-reference/production-async-predict) for long-running tasks and priority queuing.
## Management API
Management API endpoints all run through the base `api.baseten.co` subdomain. Use management API endpoints for monitoring, CI/CD, and building both model-level and workspace-level automations.
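For example, listing all models in your workspace is a single authenticated GET. This is a minimal sketch using the `requests` library against the `/v1/models` endpoint listed elsewhere in this reference:
```python
import os

import requests

baseten_api_key = os.environ["BASETEN_API_KEY"]

# List every model in your workspace via the management API
resp = requests.get(
    "https://api.baseten.co/v1/models",
    headers={"Authorization": f"Api-Key {baseten_api_key}"},
)
print(resp.json())
```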
# Production deployment
POST https://model-{model_id}.api.baseten.co/production/async_predict
Use this endpoint to call the [production deployment](/deploy/lifecycle) of your model asynchronously.
### Parameters
The ID of the model you want to call.
### Headers
Your Baseten API key, formatted with prefix `Api-Key` (e.g. `{"Authorization": "Api-Key abcd1234.abcd1234"}`).
### Body
There is a 256 KiB size limit to `/async_predict` request payloads.
JSON-serializable model input.
Baseten **does not** store model outputs. If `webhook_endpoint` is empty, your model must save prediction outputs so they can be accessed later.
URL of the webhook endpoint. We require that webhook endpoints use HTTPS.
Priority of the request. A lower value corresponds to a higher priority (e.g. requests with priority 0 are scheduled before requests of priority 1).
`priority` is between 0 and 2, inclusive.
Maximum time a request will spend in the queue before expiring.
`max_time_in_queue_seconds` must be between 10 seconds and 72 hours, inclusive.
Exponential backoff parameters used to retry the model predict request.
Number of predict request attempts.
`max_attempts` must be between 1 and 10, inclusive.
Minimum time between retries in milliseconds.
`initial_delay_ms` must be between 0 and 10,000 milliseconds, inclusive.
Maximum time between retries in milliseconds.
`max_delay_ms` must be between 0 and 60,000 milliseconds, inclusive.
### Response
The ID of the async request.
### Rate limits
Two types of rate limits apply when making async requests:
* Calls to the `/async_predict` endpoint are limited to **200 requests per second**.
* Each organization is limited to **50,000 `QUEUED` or `IN_PROGRESS` async requests**, summed across all deployments.
If either limit is exceeded, subsequent `/async_predict` requests will receive a 429 status code.
To avoid hitting these rate limits, we advise:
* Implementing a backpressure mechanism, such as calling `/async_predict` with exponential backoff in response to 429 errors.
* Monitoring the [async queue size metric](/observability/metrics#async-queue-size). If your model is accumulating a backlog of requests, consider increasing the number of requests your model can process at once by increasing the number of max replicas or the concurrency target in your autoscaling settings.
```py Python
import requests
import os
model_id = "" # Replace this with your model ID
webhook_endpoint = "" # Replace this with your webhook endpoint URL
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]
# Call the async_predict endpoint of the production deployment
resp = requests.post(
f"https://model-{model_id}.api.baseten.co/production/async_predict",
headers={"Authorization": f"Api-Key {baseten_api_key}"},
json={
"model_input": {"prompt": "hello world!"},
"webhook_endpoint": webhook_endpoint
# Optional fields for priority, max_time_in_queue_seconds, etc
},
)
print(resp.json())
```
```sh cURL
curl --request POST \
--url https://model-{model_id}.api.baseten.co/production/async_predict \
--header "Authorization: Api-Key $BASETEN_API_KEY" \
--data '{
"model_input": {"prompt": "hello world!"},
"webhook_endpoint": "https://my_webhook.com/webhook"
}'
```
```js Node.js
const fetch = require('node-fetch');
const resp = await fetch(
'https://model-{model_id}.api.baseten.co/production/async_predict',
{
method: 'POST',
headers: { Authorization: 'Api-Key YOUR_API_KEY' },
body: JSON.stringify({
"model_input": {"prompt": "hello world!"},
"webhook_endpoint": "https://my_webhook.com/webhook"
}),
}
);
const data = await resp.json();
console.log(data);
```
```json 201
{
"request_id": ""
}
```
# Production deployment
GET https://model-{model_id}.api.baseten.co/production/async_queue_status
Use this endpoint to get the status of a production deployment's async queue.
### Parameters
The ID of the model.
### Headers
Your Baseten API key, formatted with prefix `Api-Key` (e.g. `{"Authorization": "Api-Key abcd1234.abcd1234"}`).
### Response
The ID of the model.
The ID of the deployment.
The number of requests in the deployment's async queue with `QUEUED` status (i.e. awaiting processing by the model).
The number of requests in the deployment's async queue with `IN_PROGRESS` status (i.e. currently being processed by the model).
```json 200
{
"model_id": "",
"deployment_id": "",
"num_queued_requests": 12,
"num_in_progress_requests": 3
}
```
### Rate limits
Calls to the `/async_queue_status` endpoint are limited to **20 requests per second**. If this limit is exceeded, subsequent requests will receive a 429 status code.
To gracefully handle hitting this rate limit, we advise implementing a backpressure mechanism, such as calling `/async_queue_status` with exponential backoff in response to 429 errors.
```py Python
import requests
import os
model_id = ""
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]
resp = requests.get(
f"https://model-{model_id}.api.baseten.co/production/async_queue_status",
headers={"Authorization": f"Api-Key {baseten_api_key}"}
)
print(resp.json())
```
```sh cURL
curl --request GET \
--url https://model-{model_id}.api.baseten.co/production/async_queue_status \
--header "Authorization: Api-Key $BASETEN_API_KEY"
```
```js Node.js
const fetch = require('node-fetch');
const resp = await fetch(
'https://model-{model_id}.api.baseten.co/production/async_queue_status',
{
method: 'GET',
headers: { Authorization: 'Api-Key YOUR_API_KEY' }
}
);
const data = await resp.json();
console.log(data);
```
# Production deployment
POST https://model-{model_id}.api.baseten.co/production/predict
Use this endpoint to call the [production deployment](/deploy/lifecycle) of your model.
```sh
https://model-{model_id}.api.baseten.co/production/predict
```
### Parameters
The ID of the model you want to call.
Your Baseten API key, formatted with prefix `Api-Key` (e.g. `{"Authorization": "Api-Key abcd1234.abcd1234"}`).
### Body
JSON-serializable model input.
```py Python
import urllib3
import os
model_id = ""
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]
resp = urllib3.request(
"POST",
f"https://model-{model_id}.api.baseten.co/production/predict",
headers={"Authorization": f"Api-Key {baseten_api_key}"},
json={}, # JSON-serializable model input
)
print(resp.json())
```
```sh cURL
curl -X POST https://model-{model_id}.api.baseten.co/production/predict \
-H 'Authorization: Api-Key YOUR_API_KEY' \
-d '{}' # JSON-serializable model input
```
```sh Truss
truss predict --model MODEL_ID -d '{}' # JSON-serializable model input
```
```js Node.js
const fetch = require('node-fetch');
const resp = await fetch(
'https://model-{model_id}.api.baseten.co/production/predict',
{
method: 'POST',
headers: { Authorization: 'Api-Key YOUR_API_KEY' },
body: JSON.stringify({}), // JSON-serializable model input
}
);
const data = await resp.json();
console.log(data);
```
```json Example Response
// JSON-serializable output varies by model
{}
```
# Production deployment
POST https://chain-{chain_id}.api.baseten.co/production/run_remote
Use this endpoint to call the [production deployment](/deploy/lifecycle) of
your chain.
```sh
https://chain-{chain_id}.api.baseten.co/production/run_remote
```
### Parameters
The ID of the chain you want to call.
Your Baseten API key, formatted with prefix `Api-Key` (e.g. `{"Authorization": "Api-Key abcd1234.abcd1234"}`).
### Body
JSON-serializable chain input. The input schema corresponds to the
signature of the entrypoint's `run_remote` method, i.e. the top-level keys
are the argument names and the values are the corresponding JSON representations of
the types.
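For example (hypothetical argument names and values), an entrypoint with the signature `run_remote(self, text: str, num_repetitions: int)` would be called with a body whose top-level keys match those argument names:
```python
# Hypothetical example: the top-level keys match the `run_remote` argument names.
chain_input = {"text": "hello world", "num_repetitions": 3}
```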
```py Python
import urllib3
import os
chain_id = ""
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]
resp = urllib3.request(
"POST",
f"https://chain-{chain_id}.api.baseten.co/production/run_remote",
headers={"Authorization": f"Api-Key {baseten_api_key}"},
json={}, # JSON-serializable chain input
)
print(resp.json())
```
```sh cURL
curl -X POST https://chain-{chain_id}.api.baseten.co/production/run_remote \
-H 'Authorization: Api-Key YOUR_API_KEY' \
-d '{}' # JSON-serializable chain input
```
```js Node.js
const fetch = require('node-fetch');
const resp = await fetch(
'https://chain-{chain_id}.api.baseten.co/production/run_remote',
{
method: 'POST',
headers: { Authorization: 'Api-Key YOUR_API_KEY' },
body: JSON.stringify({}), // JSON-serializable chain input
}
);
const data = await resp.json();
console.log(data);
```
```json Example Response
// JSON-serializable output varies by chain
{}
```
# Production deployment
POST https://model-{model_id}.api.baseten.co/production/wake
Use this endpoint to wake the [production deployment](/deploy/lifecycle) of your model if it is scaled to zero.
```sh
https://model-{model_id}.api.baseten.co/production/wake
```
### Parameters
The ID of the model you want to wake.
Your Baseten API key, formatted with prefix `Api-Key` (e.g. `{"Authorization": "Api-Key abcd1234.abcd1234"}`).
```py Python
import urllib3
import os
model_id = ""
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]
resp = urllib3.request(
"POST",
f"https://model-{model_id}.api.baseten.co/production/wake",
headers={"Authorization": f"Api-Key {baseten_api_key}"},
)
print(resp.json())
```
```sh cURL
curl -X POST https://model-{model_id}.api.baseten.co/production/wake \
-H 'Authorization: Api-Key YOUR_API_KEY'
```
```js Node.js
const fetch = require('node-fetch');
const resp = await fetch(
'https://model-{model_id}.api.baseten.co/production/wake',
{
method: 'POST',
headers: { Authorization: 'Api-Key YOUR_API_KEY' },
}
);
const data = await resp.json();
console.log(data);
```
```json Example Response
// Returns a 202 response code
{}
```
# 🆕 Promote to chain environment
post /v1/chains/{chain_id}/environments/{env_name}/promote
Promotes an existing chain deployment to an environment and returns the promoted chain deployment.
# 🆕 Promote to model environment
post /v1/models/{model_id}/environments/{env_name}/promote
Promotes an existing deployment to an environment and returns the promoted deployment.
# Any model deployment by ID
post /v1/models/{model_id}/deployments/{deployment_id}/promote
Promotes an existing deployment to production and returns the same deployment.
# Development model deployment
post /v1/models/{model_id}/deployments/development/promote
Creates a new production deployment from the development deployment and returns the deployment that is currently building.
# Update model environment
patch /v1/models/{model_id}/environments/{env_name}
Updates an environment's settings and returns the updated environment.
# Any model deployment by ID
patch /v1/models/{model_id}/deployments/{deployment_id}/autoscaling_settings
Updates a deployment's autoscaling settings and returns the update status.
# Development model deployment
patch /v1/models/{model_id}/deployments/development/autoscaling_settings
Updates a development deployment's autoscaling settings and returns the update status.
# Production model deployment
patch /v1/models/{model_id}/deployments/production/autoscaling_settings
Updates a production deployment's autoscaling settings and returns the update status.
# Upsert a secret
post /v1/secrets
Creates a new secret or updates an existing secret if one with the provided name already exists. The name and creation date of the created or updated secret are returned.
# Call model version
POST https://app.baseten.co/model_versions/{version_id}/predict
This is an old endpoint. Update to the endpoint for a [published deployment](/api-reference/deployment-predict) and the new model response format based on the [migration guide](/api-reference/migration-guide).
Use this endpoint to call any model version (now known as a model deployment).
```sh
https://app.baseten.co/model_versions/{version_id}/predict
```
### Parameters
The version ID of the model you want to call.
Your Baseten API key, formatted with prefix `Api-Key` (e.g. `{"Authorization": "Api-Key abcd1234.abcd1234"}`).
JSON-serializable model input.
```py Python
import urllib3
resp = urllib3.request(
"POST",
"https://app.baseten.co/model_versions/VERSION_ID/predict",
headers={"Authorization": "Api-Key YOUR_API_KEY"},
json={}, # JSON-serializable model input
)
print(resp.json())
```
```sh cURL
curl -X POST https://app.baseten.co/model_versions/VERSION_ID/predict \
-H 'Authorization: Api-Key YOUR_API_KEY' \
-d '{}' # JSON-serializable model input
```
```sh Truss
truss predict --model-version VERSION_ID -d '{}' # JSON-serializable model input
```
```js Node.js
const fetch = require('node-fetch');
const resp = await fetch(
'https://app.baseten.co/model_versions/VERSION_ID/predict',
{
method: 'POST',
headers: { Authorization: 'Api-Key YOUR_API_KEY' },
body: JSON.stringify({}), // JSON-serializable model input
}
);
const data = await resp.json();
console.log(data);
```
```json Example Response
{
"model_id":"MODEL_ID",
"model_version_id":"VERSION_ID",
"model_output": {
// Output varies by model
}
}
```
# Wake model version
POST https://app.baseten.co/model_versions/{version_id}/wake
This is an old endpoint. Update to the wake endpoint for a [published deployment](/api-reference/deployment-wake).
Use this endpoint to wake a scaled-to-zero model version (now known as a model deployment).
```sh
https://app.baseten.co/model_versions/{version_id}/wake
```
### Parameters
The ID of the model version you want to wake.
Your Baseten API key, formatted with prefix `Api-Key` (e.g. `{"Authorization": "Api-Key abcd1234.abcd1234"}`).
```py Python
import urllib3
import os
version_id = ""
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]
resp = urllib3.request(
"POST",
f"https://app.baseten.co/model_versions/{version_id}/wake",
headers={"Authorization": f"Api-Key {baseten_api_key}"},
)
print(resp.json())
```
```sh cURL
curl -X POST https://app.baseten.co/model_versions/{version_id}/wake \
-H 'Authorization: Api-Key YOUR_API_KEY'
```
```js Node.js
const fetch = require('node-fetch');
const resp = await fetch(
'https://app.baseten.co/model_versions/{version_id}/wake',
{
method: 'POST',
headers: { Authorization: 'Api-Key YOUR_API_KEY' },
}
);
const data = await resp.json();
console.log(data);
```
```json Example Response
// Returns a 202 response code
{}
```
# Chains CLI reference
Details on Chains CLI
Chains is part of the Truss CLI.
# `push`
✨ \[new name]
```sh
truss chains push [OPTIONS] SOURCE [ENTRYPOINT]
```
Deploys a chain remotely.
* `SOURCE`: Path to a python file that contains the entrypoint chainlet.
* `ENTRYPOINT`: Class name of the entrypoint chainlet in the source file. May be
omitted if a chainlet definition in `SOURCE` is tagged with
`@chains.mark_entrypoint`.
Options:
* `--name` (TEXT): Name of the chain to be deployed; if not given, the
entrypoint name is used.
* `--publish / --no-publish`: Create chainlets as a published deployment.
* `--promote / --no-promote`: Promote newly deployed chainlets into production.
* `--environment` (TEXT): Deploy chainlets into a particular environment.
* `--wait / --no-wait`: Wait until all chainlets are ready (or deployment
failed).
* `--watch / --no-watch`: Watches the chains source code and applies live
patches. Using this option will wait for the chain to be deployed
(i.e. the `--wait` flag is applied) before starting to watch for changes.
This option requires the deployment to be a development deployment.
* `--dryrun`: Produces only generated files, but doesn't deploy anything.
* `--remote` (TEXT): Name of the remote in .trussrc to push to.
* `--user_env`(TEXT): Key-value-pairs (as JSON str) that can be used to
control deployment-specific chainlet behavior.
* `--log` `[humanfriendly|I|INFO|D|DEBUG]`: Customizes logging.
* `--help`: Show this message and exit.
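For example, a typical invocation might look like this (file and chain names are hypothetical):
```sh
truss chains push my_chain.py --name my-chain --publish --wait
```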
# `watch`
```sh
truss chains watch [OPTIONS] SOURCE [ENTRYPOINT]
```
Watches the chains source code and applies live patches to a development
deployment.
The development deployment must have been deployed before running this
command.
`SOURCE`: Path to a python file that contains the entrypoint chainlet.
`ENTRYPOINT`: Class name of the entrypoint chainlet in the source file. May be
omitted if a chainlet definition in SOURCE is tagged with
`@chains.mark_entrypoint`.
Options:
* `--name` (TEXT): Name of the chain to be deployed; if not given, the
entrypoint name is used.
* `--remote` (TEXT): Name of the remote in .trussrc to
push to.
* `--user_env` (TEXT): Key-value-pairs (as JSON str) that can be used to
control deployment-specific chainlet behavior.
* `--log [humanfriendly|W|WARNING|I|INFO|D|DEBUG]`: Customizes logging.
* `--help`: Show this message and exit.
# `init`
```sh
truss chains init [OPTIONS] [DIRECTORY]
```
Initializes a chains project directory.
* `DIRECTORY`: The name of a new or existing directory to create the chain in;
it must be empty. If not specified, the current directory is used.
Options:
* `--log` `[humanfriendly|I|INFO|D|DEBUG]`: Customizes logging.
* `--help`: Show this message and exit.
# `deploy`
🚫 \[deprecated] see `push` above.
# Chains reference
Details on Chains CLI and configuration options
[Chains](/chains/overview) is an abstraction for multi-model inference.
The [Chains SDK Reference](/chains-reference/sdk) documents all public
Python APIs of chains and configuration options.
The [Chains CLI reference](/chains-reference/cli) details the command line interface.
# Chains SDK Reference
Python SDK Reference for Chains
{/*
This file is autogenerated, do not edit manually, see:
https://github.com/basetenlabs/truss/tree/main/docs/chains/doc_gen
*/}
# Chainlet classes
APIs for creating user-defined Chainlets.
### *class* `truss_chains.ChainletBase`
Base class for all chainlets.
Inheriting from this class adds validations to make sure subclasses adhere to the
chainlet pattern and facilitates remote chainlet deployment.
Refer to [the docs](https://docs.baseten.co/chains/getting-started) and this
[example chainlet](https://github.com/basetenlabs/truss/blob/main/truss-chains/truss_chains/example_chainlet.py)
for more guidance on how to create subclasses.
### `truss_chains.depends`
Sets a “symbolic marker” to indicate to the framework that a chainlet is a
dependency of another chainlet. The return value of `depends` is intended to be
used as a default argument in a chainlet’s `__init__`-method.
When deploying a chain remotely, a corresponding stub to the remote is injected in
its place. In [`run_local`](#truss-chains-run-local) mode an instance of a local chainlet is injected.
Refer to [the docs](https://docs.baseten.co/chains/getting-started) and this
[example chainlet](https://github.com/basetenlabs/truss/blob/main/truss-chains/truss_chains/example_chainlet.py)
for more guidance on how to make one chainlet depend on another chainlet.
Despite the type annotation, this does *not* immediately provide a
chainlet instance. Only when deploying remotely or using `run_local` is a
chainlet instance provided.
**Parameters:**
| Name | Type | Description |
| -------------- | --------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------ |
| `chainlet_cls` | *Type\[[ChainletBase](#class-truss-chains-chainletbase)]* | The chainlet class of the dependency. |
| `retries` | *int* | The number of times to retry the remote chainlet in case of failures (e.g. due to transient network issues). |
| `timeout_sec` | *int* | Timeout for the HTTP request to this chainlet. |
* **Returns:**
A “symbolic marker” to be used as a default argument in a chainlet’s
initializer.
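A minimal usage sketch (assuming a dependency chainlet `SayHello`, as in the concepts guide):
```python
import truss_chains as chains

class HelloAll(chains.ChainletBase):
    # `depends` marks `SayHello` as a dependency; when deployed remotely a stub
    # is injected in its place, in `run_local` mode a local instance is injected.
    def __init__(self, say_hello=chains.depends(SayHello, retries=2, timeout_sec=60)) -> None:
        self._say_hello = say_hello
```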
### `truss_chains.depends_context`
Sets a “symbolic marker” for injecting a context object at runtime.
Refer to [the docs](https://docs.baseten.co/chains/getting-started) and this
[example chainlet](https://github.com/basetenlabs/truss/blob/main/truss-chains/truss_chains/example_chainlet.py)
for more guidance on the `__init__`-signature of chainlets.
Despite the type annotation, this does *not* immediately provide a
context instance. Only when deploying remotely or using `run_local` is a
context instance provided.
* **Returns:**
A “symbolic marker” to be used as a default argument in a chainlet’s
initializer.
### *class* `truss_chains.DeploymentContext`
Bases: `pydantic.BaseModel`
Bundles config values and resources needed to instantiate Chainlets.
The context can optionally be added as a trailing argument in a Chainlet’s
`__init__` method and then used to set up the chainlet (e.g. using a secret as
an access token for downloading model weights).
**Parameters:**
| Name | Type | Description |
| --------------------- | -------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| `data_dir` | *Path\|None* | The directory where the chainlet can store and access data, e.g. for downloading model weights. |
| `chainlet_to_service` | *Mapping\[str,[ServiceDescriptor](#class-truss-chains-servicedescriptor)]* | A mapping from chainlet names to service descriptors. This is used to create RPC sessions to dependency chainlets. It contains only the chainlet services that are dependencies of the current chainlet. |
| `secrets` | *Mapping\[str,str]* | A mapping from secret names to secret values. It contains only the secrets that are listed in `remote_config.assets.secret_keys` of the current chainlet. |
| `environment` | *[Environment](#class-truss-chains-definitions-environment)\|None* | The environment that the chainlet is deployed in. None if the chainlet is not associated with an environment. |
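A minimal sketch of using the context in a chainlet's `__init__` (assuming a secret named `hf_access_token` is listed in the chainlet's assets):
```python
import truss_chains as chains

class MyChainlet(chains.ChainletBase):
    def __init__(
        self,
        context: chains.DeploymentContext = chains.depends_context(),
    ) -> None:
        # Only secrets listed in `remote_config.assets.secret_keys` are available here.
        self._hf_token = context.secrets["hf_access_token"]
```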
#### get\_baseten\_api\_key()
* **Return type:**
str
#### get\_service\_descriptor(chainlet\_name)
**Parameters:**
| Name | Type | Description |
| --------------- | ----- | ------------------------- |
| `chainlet_name` | *str* | The name of the chainlet. |
* **Return type:**
[*ServiceDescriptor*](#class-truss-chains-servicedescriptor)
### *class* `truss_chains.definitions.Environment`
Bases: `pydantic.BaseModel`
The environment the chainlet is deployed in.
* **Parameters:**
**name** (*str*) – The name of the environment.
### *class* `truss_chains.ChainletOptions`
Bases: `pydantic.BaseModel`
**Parameters:**
| Name | Type | Description |
| -------------------- | ------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `enable_b10_tracing` | *bool* | Enables baseten-internal trace data collection. This helps baseten engineers better analyze chain performance in case of issues. It is independent of a potentially user-configured tracing instrumentation. Turning this on could add performance overhead. |
| `env_variables` | *Mapping\[str,str]* | Static environment variables available to the deployed chainlet. |
### *class* `truss_chains.RPCOptions`
Bases: `pydantic.BaseModel`
Options to customize RPCs to dependency chainlets.
**Parameters:**
| Name | Type | Description |
| ------------- | ----- | ----------- |
| `timeout_sec` | *int* | |
| `retries` | *int* | |
### `truss_chains.mark_entrypoint`
Decorator to mark a chainlet as the entrypoint of a chain.
This decorator can be applied to *one* chainlet in a source file; the CLI push
command then becomes simpler because only the file, not the chainlet class
in the file, needs to be specified.
Example usage:
```python
import truss_chains as chains
@chains.mark_entrypoint
class MyChainlet(ChainletBase):
...
```
**Parameters:**
| Name | Type | Description |
| ----- | --------------------------------------------------------- | ------------------- |
| `cls` | *Type\[[ChainletBase](#class-truss-chains-chainletbase)]* | The chainlet class. |
* **Return type:**
*Type*\[*ChainletBase*]
# Remote Configuration
These data structures specify for each chainlet how it gets deployed remotely, e.g. dependencies and compute resources.
### *class* `truss_chains.RemoteConfig`
Bases: `pydantic.BaseModel`
Bundles config values needed to deploy a chainlet remotely.
This is specified as a class variable for each chainlet class, e.g.:
```python
import truss_chains as chains
class MyChainlet(chains.ChainletBase):
remote_config = chains.RemoteConfig(
docker_image=chains.DockerImage(
pip_requirements=["torch==2.0.1", ...]
),
compute=chains.Compute(cpu_count=2, gpu="A10G", ...),
assets=chains.Assets(secret_keys=["hf_access_token"], ...),
)
```
**Parameters:**
| Name | Type | Description |
| -------------- | -------------------------------------------------------- | ----------- |
| `docker_image` | *[DockerImage](#class-truss-chains-dockerimage)* | |
| `compute` | *[Compute](#class-truss-chains-compute)* | |
| `assets` | *[Assets](#class-truss-chains-assets)* | |
| `name` | *str\|None* | |
| `options` | *[ChainletOptions](#class-truss-chains-chainletoptions)* | |
### *class* `truss_chains.DockerImage`
Bases: `pydantic.BaseModel`
Configures the docker image in which a remote chainlet is deployed.
Any paths are relative to the source file where `DockerImage` is
defined and must be created with the helper function [`make_abs_path_here`](#truss-chains-make-abs-path-here).
This allows you, for example, to organize chainlets in different (potentially nested)
modules and keep their requirement files right next to their python source files.
**Parameters:**
| Name | Type | Description |
| ----------------------- | -------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `base_image` | *[BasetenImage](#class-truss-chains-basetenimage)\|[CustomImage](#class-truss-chains-customimage)* | The base image used by the chainlet. Other dependencies and assets are included as additional layers on top of that image. You can choose a baseten default image for a supported python version (e.g. `BasetenImage.PY311`), this will also include GPU drivers if needed, or provide a custom image (e.g. `CustomImage(image="python:3.11-slim")`). |
| `pip_requirements_file` | *AbsPath\|None* | Path to a file containing pip requirements. The file content is naively concatenated with `pip_requirements`. |
| `pip_requirements` | *list\[str]* | A list of pip requirements to install. The items are naively concatenated with the content of the `pip_requirements_file`. |
| `apt_requirements` | *list\[str]* | A list of apt requirements to install. |
| `data_dir` | *AbsPath\|None* | Data from this directory is copied into the docker image and accessible to the remote chainlet at runtime. |
| `external_package_dirs` | *list\[AbsPath]\|None* | A list of directories containing additional python packages outside the chain’s workspace dir, e.g. a shared library. This code is copied into the docker image and importable at runtime. |
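For example, a docker image spec might look like this (the specific requirements are illustrative):
```python
import truss_chains as chains

docker_image = chains.DockerImage(
    base_image=chains.BasetenImage.PY311,
    # Relative paths must be wrapped with `make_abs_path_here`.
    pip_requirements_file=chains.make_abs_path_here("requirements.txt"),
    pip_requirements=["torch==2.0.1"],
    apt_requirements=["ffmpeg"],
)
```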
### *class* `truss_chains.BasetenImage`
Bases: `Enum`
Default images, curated by baseten, for different python versions. If a Chainlet
uses GPUs, drivers will be included in the image.
| Enum Member | Value |
| ----------- | ------- |
| `PY310` | *py310* |
| `PY311` | *py311* |
| `PY39` | *py39* |
### *class* `truss_chains.CustomImage`
Bases: `pydantic.BaseModel`
Configures the usage of a custom image hosted on dockerhub.
**Parameters:**
| Name | Type | Description |
| ------------------------ | -------------------------- | -------------------------------------------------------------------------------------------------------- |
| `image` | *str* | Reference to image on dockerhub. |
| `python_executable_path` | *str\|None* | Absolute path to python executable (if default `python` is ambiguous). |
| `docker_auth` | *DockerAuthSettings\|None* | See [corresponding truss config](https://docs.baseten.co/truss-reference/config#base-image-docker-auth). |
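A minimal example, using the image name mentioned in the `DockerImage` description above:
```python
import truss_chains as chains

custom_image = chains.CustomImage(image="python:3.11-slim")
```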
### *class* `truss_chains.Compute`
Specifies which compute resources a chainlet has in the *remote* deployment.
Not all combinations can be exactly satisfied by available hardware; in some
cases more powerful machine types are chosen to make sure requirements are met
or over-provisioned. Refer to the
[baseten instance reference](https://docs.baseten.co/performance/instances).
**Parameters:**
| Name | Type | Description |
| --------------------- | ----------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `cpu_count` | *int* | Minimum number of CPUs to allocate. |
| `memory` | *str* | Minimum memory to allocate, e.g. “2Gi” (2 gibibytes). |
| `gpu` | *str\|Accelerator\|None* | GPU accelerator type, e.g. “A10G”, “A100”, refer to the [truss config](https://docs.baseten.co/reference/config#resources-accelerator) for more choices. |
| `gpu_count` | *int* | Number of GPUs to allocate. |
| `predict_concurrency` | *int\|Literal\['cpu\_count']* | Number of concurrent requests a single replica of a deployed chainlet handles. |
Concurrency concepts are explained
in [this guide](https://docs.baseten.co/deploy/guides/concurrency#predict-concurrency).
It is important to understand the difference between predict\_concurrency and
the concurrency target (used for autoscaling, i.e. adding or removing replicas).
Furthermore, the `predict_concurrency` of a single instance is implemented in
two ways:
* Via python’s `asyncio`, if `run_remote` is an async def. This
requires that `run_remote` yields to the event loop.
* With a threadpool if it’s a synchronous function. This requires
that the threads don’t have significant CPU load (due to the GIL).
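For example (values are illustrative):
```python
import truss_chains as chains

compute = chains.Compute(
    cpu_count=2,
    memory="2Gi",
    gpu="A10G",
    gpu_count=1,
    predict_concurrency=8,
)
```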
### *class* `truss_chains.Assets`
Specifies which assets a chainlet can access in the remote deployment.
For example, model weight caching can be used like this:
```python
import truss_chains as chains
from truss.base import truss_config
mistral_cache = truss_config.ModelRepo(
repo_id="mistralai/Mistral-7B-Instruct-v0.2",
allow_patterns=["*.json", "*.safetensors", ".model"]
)
chains.Assets(cached=[mistral_cache], ...)
```
See [truss caching guide](https://docs.baseten.co/deploy/guides/model-cache#enabling-caching-for-a-model)
for more details on caching.
**Parameters:**
| Name | Type | Description |
| --------------- | ----------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| `cached` | *Iterable\[ModelRepo]* | One or more `truss_config.ModelRepo` objects. |
| `secret_keys` | *Iterable\[str]* | Names of secrets stored on baseten, that the chainlet should have access to. You can manage secrets on baseten [here](https://app.baseten.co/settings/secrets). |
| `external_data` | *Iterable\[ExternalDataItem]* | Data to be downloaded from public URLs and made available in the deployment (via `context.data_dir`). See [here](https://docs.baseten.co/reference/config#external-data) for more details. |
# Core
General framework and helper functions.
### `truss_chains.push`
Deploys a chain remotely (with all dependent chainlets).
**Parameters:**
| Name | Type | Description |
| ----------------------- | --------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- |
| `entrypoint` | *Type\[[ChainletBase](#class-truss-chains-chainletbase)]* | The chainlet class that serves as the entrypoint to the chain. |
| `chain_name` | *str* | The name of the chain. |
| `publish` | *bool* | Whether to publish the chain as a published deployment (it is a draft deployment otherwise). |
| `promote` | *bool* | Whether to promote the chain to be the production deployment (this implies publishing as well). |
| `only_generate_trusses` | *bool* | Used for debugging purposes. If set to True, only the underlying truss models for the chainlets are generated in `/tmp/.chains_generated`. |
| `remote` | *str\|None* | Name of a remote config in `.trussrc`. If not provided, you will be prompted for it. |
| `environment` | *str\|None* | The name of an environment to promote deployment into. |
* **Returns:**
A chain service handle to the deployed chain.
* **Return type:**
[*ChainService*](#class-truss-chains-remote-chainservice)
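A minimal usage sketch, based on the parameters documented above (assuming `HelloWorld` is an entrypoint chainlet defined in the same file):
```python
import truss_chains as chains

if __name__ == "__main__":
    # Returns a `ChainService` handle for the deployed chain.
    service = chains.push(HelloWorld, chain_name="hello-world", publish=True)
    print(service.run_remote_url)
```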
### `truss_chains.deploy_remotely`
Deprecated, use [`push`](#truss-chains-push) instead.
### *class* `truss_chains.remote.ChainService`
Handle for a deployed chain.
A `ChainService` is created and returned when using `push`. It
bundles the individual services for each chainlet in the chain, and provides
utilities to query their status, invoke the entrypoint etc.
#### get\_info()
Queries the statuses of all chainlets in the chain.
* **Returns:**
List of `DeployedChainlet`, `(name, is_entrypoint, status, logs_url)`
for each chainlet.
* **Return type:**
list\[*DeployedChainlet*]
#### *property* name *: str*
#### run\_remote(json)
Invokes the entrypoint with JSON data.
**Parameters:**
| Name | Type | Description |
| ------ | ----------- | ---------------------------- |
| `json` | *JSON dict* | Input data to the entrypoint |
* **Returns:**
The JSON response.
* **Return type:**
*Any*
#### *property* run\_remote\_url *: str*
URL to invoke the entrypoint.
#### *property* status\_page\_url *: str*
Link to status page on Baseten.
### `truss_chains.make_abs_path_here`
Helper to specify file paths relative to the *immediately calling* module.
E.g. if you have a project structure like this:
```default
root/
chain.py
common_requirements.text
sub_package/
chainlet.py
chainlet_requirements.txt
```
You can now in `root/sub_package/chainlet.py` point to the requirements
file like this:
```python
shared = make_abs_path_here("../common_requirements.text")
specific = make_abs_path_here("chainlet_requirements.txt")
```
This helper uses the directory of the immediately calling module as an
absolute reference point for resolving the file location. Therefore,
you MUST NOT wrap the instantiation of `make_abs_path_here` into a
function (e.g. applying decorators) or use dynamic code execution.
Ok:
```python
def foo(path: AbsPath):
abs_path = path.abs_path
foo(make_abs_path_here("./somewhere"))
```
Not Ok:
```python
def foo(path: str):
dangerous_value = make_abs_path_here(path).abs_path
foo("./somewhere")
```
**Parameters:**
| Name | Type | Description |
| ----------- | ----- | -------------------------- |
| `file_path` | *str* | Absolute or relative path. |
* **Return type:**
*AbsPath*
### `truss_chains.run_local`
Context manager for local debug execution of a chain.
The arguments only need to be provided if the chainlets explicitly access any of the
corresponding fields of
[`DeploymentContext`](#class-truss-chains-deploymentcontext).
**Parameters:**
| Name | Type | Description |
| --------------------- | ------------------------------------------------------------------------- | -------------------------------------------------------------- |
| `secrets` | *Mapping\[str,str]\|None* | A dict of secrets keys and values to provide to the chainlets. |
| `data_dir` | *Path\|str\|None* | Path to a directory with data files. |
| `chainlet_to_service` | *Mapping\[str,[ServiceDescriptor](#class-truss-chains-servicedescriptor)]* | A dict of chainlet names to service descriptors. |
* **Return type:**
*ContextManager*\[None]
Example usage (as trailing main section in a chain file):
```python
import os
import truss_chains as chains
class HelloWorld(chains.ChainletBase):
...
if __name__ == "__main__":
with chains.run_local(
secrets={"some_token": os.environ["SOME_TOKEN"]},
chainlet_to_service={
"SomeChainlet": chains.ServiceDescriptor(
name="SomeChainlet",
predict_url="https://...",
options=chains.RPCOptions(),
)
},
):
hello_world_chain = HelloWorld()
result = hello_world_chain.run_remote(max_value=5)
print(result)
```
Refer to the
[local debugging guide](https://docs.baseten.co/chains/guide#test-a-chain-locally)
for more details.
### *class* `truss_chains.ServiceDescriptor`
Bases: `pydantic.BaseModel`
Bundles values to establish an RPC session to a dependency chainlet,
specifically with `StubBase`.
**Parameters:**
| Name | Type | Description |
| ------------- | ---------------------------------------------- | ----------- |
| `name` | *str* | |
| `predict_url` | *str* | |
| `options` | *[RPCOptions](#class-truss-chains-rpcoptions)* | |
## *class* `truss_chains.StubBase`
Base class for stubs that invoke remote chainlets.
It is used internally for RPCs to dependency chainlets, but it can also be used
in user-code for wrapping a deployed truss model into the chains framework, e.g.
like this:
```python
import pydantic
import truss_chains as chains
class WhisperOutput(pydantic.BaseModel):
...
class DeployedWhisper(chains.StubBase):
async def run_remote(self, audio_b64: str) -> WhisperOutput:
resp = await self._remote.predict_async(
json_payload={"audio": audio_b64})
return WhisperOutput(text=resp["text"], language=resp["language"])
class MyChainlet(chains.ChainletBase):
def __init__(self, ..., context=chains.depends_context()):
...
self._whisper = DeployedWhisper.from_url(
WHISPER_URL,
context,
options=chains.RPCOptions(retries=3),
)
```
**Parameters:**
| Name | Type | Description |
| -------------------- | ------------------------------------------------------------ | ----------------------------------------- |
| `service_descriptor` | *[ServiceDescriptor](#class-truss-chains-servicedescriptor)* | Contains the URL and other configuration. |
| `api_key` | *str* | A baseten API key to authorize requests. |
#### *classmethod* from\_url(predict\_url, context, options=None)
Factory method, convenient for use in a chainlet’s `__init__`-method.
**Parameters:**
| Name | Type | Description |
| ------------- | ------------------------------------------------------------ | ----------------------------------------------------------------- |
| `predict_url` | *str* | URL to predict endpoint of another chain / truss model. |
| `context` | *[DeploymentContext](#class-truss-chains-deploymentcontext)* | Deployment context object, obtained in the chainlet’s `__init__`. |
| `options` | *[RPCOptions](#class-truss-chains-rpcoptions)* | RPC options, e.g. retries. |
### *class* `truss_chains.RemoteErrorDetail`
Bases: `pydantic.BaseModel`
When a remote chainlet raises an exception, this pydantic model contains
information about the error and stack trace and is included in JSON form in the
error response.
**Parameters:**
| Name | Type | Description |
| ----------------------- | ------------------- | ----------- |
| `remote_name` | *str* | |
| `exception_cls_name` | *str* | |
| `exception_module_name` | *str\|None* | |
| `exception_message` | *str* | |
| `user_stack_trace` | *list\[StackFrame]* | |
#### format()
Format the error for printing, similar to how Python formats exceptions
with stack traces.
* **Return type:**
str
# Concepts
Glossary of Chains concepts and terminology
Chains is in beta mode. Read our [launch blog post](https://www.baseten.co/blog/introducing-baseten-chains/).
## Chainlet
A Chainlet is the basic building block of Chains. A Chainlet is a Python class
that specifies:
* A set of compute resources.
* A Python environment with software dependencies.
* A typed interface [`run_remote()`](/chains/concepts#run-remote-chaining-chainlets)
for other Chainlets to call.
This is the simplest possible Chainlet — only the
[`run_remote()`](/chains/concepts#run-remote-chaining-chainlets) method is
required — and we can layer in other concepts to create a more capable Chainlet.
```python
import truss_chains as chains
class SayHello(chains.ChainletBase):
def run_remote(self, name: str) -> str:
return f"Hello, {name}"
```
### Remote configuration
Chainlets are meant for deployment as remote services. Each Chainlet specifies
its own requirements for compute hardware (CPU count, GPU type and count, etc)
and software dependencies (Python libraries or system packages). This
configuration is built into a Docker image automatically as part of the
deployment process.
When no configuration is provided, the Chainlet will be deployed on a basic
instance with one vCPU, 2GB of RAM, no GPU, and a standard set of Python and
system packages.
Configuration is set using the
[`remote_config`](/chains-reference/sdk#remote-configuration) class variable
within the Chainlet:
```python
import truss_chains as chains
class MyChainlet(chains.ChainletBase):
remote_config = chains.RemoteConfig(
docker_image=chains.DockerImage(
pip_requirements=["torch==2.3.0", ...]
),
compute=chains.Compute(gpu="H100", ...),
assets=chains.Assets(secret_keys=["hf_access_token"], ...),
)
```
See the
[remote configuration reference](/chains-reference/sdk#remote-configuration)
for a complete list of options.
### Initialization
Chainlets are implemented as classes because we often want to set up expensive
static resources once at startup and then re-use them with each invocation of the
Chainlet. For example, we only want to initialize an AI model and download its
weights once, then re-use it every time we run inference.
We do this setup in `__init__()`, which is run exactly once when the Chainlet is
deployed or scaled up.
```python
import truss_chains as chains
class PhiLLM(chains.ChainletBase):
def __init__(self) -> None:
import torch
import transformers
self._model = transformers.AutoModelForCausalLM.from_pretrained(
PHI_HF_MODEL,
torch_dtype=torch.float16,
device_map="auto",
)
self._tokenizer = transformers.AutoTokenizer.from_pretrained(
PHI_HF_MODEL,
)
```
Chainlet initialization also has two important features: context and dependency
injection of other Chainlets, explained below.
#### Context (access information)
You can add a
[`DeploymentContext`](/chains-reference/sdk#class-truss-chains-deploymentcontext-generic-userconfigt)
object as an optional argument to the `__init__`-method of a Chainlet.
This allows you to use secrets within your Chainlet, such as using
a `hf_access_token` to access a gated model on Hugging Face (note that when
using secrets, they also need to be added to the `assets`).
```python
import truss_chains as chains
class MistralLLM(chains.ChainletBase):
remote_config = chains.RemoteConfig(
...
assets=chains.Assets(secret_keys=["hf_access_token"], ...),
)
def __init__(
self,
# Adding the `context` argument allows us to access secrets
context: chains.DeploymentContext = chains.depends_context(),
) -> None:
import transformers
# Using the secret from context to access a gated model on HF
self._model = transformers.AutoModelForCausalLM.from_pretrained(
"mistralai/Mistral-7B-Instruct-v0.2",
use_auth_token=context.secrets["hf_access_token"],
)
```
#### Depends (call other Chainlets)
The Chains framework uses the
[`chains.depends()`](/chains-reference/sdk#truss-chains-depends) function in
Chainlets' `__init__()` method to track the dependency relationship between
different Chainlets within a Chain.
This syntax, inspired by dependency injection, is used to translate local Python
function calls into calls to the remote Chainlets in production.
Once a dependency Chainlet is added with
[`chains.depends()`](/chains-reference/sdk#truss-chains-depends), its
[`run_remote()`](/chains/concepts#run-remote-chaining-chainlets) method can
call this dependency Chainlet, e.g. in `HelloAll` below we can make calls to
`SayHello`:
```python
import truss_chains as chains
class HelloAll(chains.ChainletBase):
def __init__(self, say_hello_chainlet=chains.depends(SayHello)) -> None:
self._say_hello = say_hello_chainlet
def run_remote(self, names: list[str]) -> str:
output = []
for name in names:
output.append(self._say_hello.run_remote(name))
return "\n".join(output)
```
## Run remote (chaining Chainlets)
The `run_remote()` method is run each time the Chainlet is called. It is the
sole public interface for the Chainlet (though you can have as many private
helper functions as you want) and its inputs and outputs must have type
annotations.
In `run_remote()` you implement the actual work of the Chainlet, such as model
inference or data chunking:
```python
import truss_chains as chains
class PhiLLM(chains.ChainletBase):
def run_remote(self, messages: Messages) -> str:
import torch
model_inputs = self._tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
inputs = self._tokenizer(model_inputs, return_tensors="pt")
input_ids = inputs["input_ids"].to("cuda")
with torch.no_grad():
outputs = self._model.generate(input_ids=input_ids,
**self._generate_args)
output_text = self._tokenizer.decode(outputs[0],
skip_special_tokens=True)
return output_text
```
If `run_remote()` makes calls to other Chainlets, e.g. invoking a dependency
Chainlet for each element in a list, you can benefit from concurrent execution
by making `run_remote()` an `async` method and starting the calls as
concurrent tasks, e.g. `asyncio.ensure_future(self._dep_chainlet.run_remote(...))`.
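A sketch of this pattern, with a hypothetical `Worker` dependency Chainlet:
```python
import asyncio
import truss_chains as chains

class Worker(chains.ChainletBase):
    # Hypothetical dependency; an async `run_remote` returns an awaitable,
    # so calls to it can be scheduled as concurrent tasks.
    async def run_remote(self, item: str) -> str:
        return item.upper()

class Orchestrator(chains.ChainletBase):
    def __init__(self, worker=chains.depends(Worker)) -> None:
        self._worker = worker

    async def run_remote(self, items: list[str]) -> list[str]:
        # Dispatch all dependency calls concurrently instead of awaiting them one by one.
        tasks = [asyncio.ensure_future(self._worker.run_remote(item)) for item in items]
        return list(await asyncio.gather(*tasks))
```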
## Entrypoint
The entrypoint is called directly from the deployed Chain's API endpoint and
kicks off the entire chain. The entrypoint is also responsible for returning the
final result back to the client.
Using the
[`@chains.mark_entrypoint`](/chains-reference/sdk#truss-chains-mark-entrypoint)
decorator, one Chainlet within a file is set as the entrypoint to the chain.
```python
@chains.mark_entrypoint
class HelloAll(chains.ChainletBase):
```
## Stub
Chains can be combined with existing Truss models using Stubs.
A Stub acts as a substitute (client-side proxy) for a remotely deployed
dependency, either a Chainlet or a Truss model. The Stub performs the remote
invocations as if they were local by taking care of the transport layer,
authentication, data serialization and retries.
Stubs can be integrated into Chainlets by passing in a URL of the deployed
model. They also require
[`context`](/chains/concepts#context-access-information) to be initialized
(for authentication).
```python
import truss_chains as chains
class LLMClient(chains.StubBase):
async def run_remote(
self,
prompt: str
) -> str:
# Call the deployed model
resp = await self._remote.predict_async(json_payload={
"messages": [{"role": "user", "content": prompt}],
"stream" : False
})
# Return a string with the model output
return resp["output"]
LLM_URL = ...
class MyChainlet(chains.ChainletBase):
def __init__(
self,
context: chains.DeploymentContext = chains.depends_context(),
):
self._llm = LLMClient.from_url(LLM_URL, context)
```
See the
[StubBase reference](/chains-reference/sdk#class-truss-chains-stubbase)
for details on the `StubBase` implementation.
## Pydantic data types
To make orchestrating multiple remotely deployed services possible, Chains
relies heavily on typed inputs and outputs. Values must be serialized to a safe
exchange format to be sent over the network.
The Chains framework uses the type annotations to infer how data should be
serialized and is currently restricted to JSON-compatible types. Types
can be:
* Direct type annotations for simple types such as `int`, `float`,
or `list[str]`.
* Pydantic models to define a schema for nested data structures or multiple
arguments.
An example of pydantic input and output types for a Chainlet is given below:
```python
import enum
import pydantic
class Modes(enum.Enum):
MODE_0 = "MODE_0"
MODE_1 = "MODE_1"
class SplitTextInput(pydantic.BaseModel):
data: str
num_partitions: int
mode: Modes
class SplitTextOutput(pydantic.BaseModel):
parts: list[str]
part_lens: list[int]
```
Refer to the [pydantic docs](https://docs.pydantic.dev/latest/) for more
details on how
to define custom pydantic data models.
We are working on more efficient support for numeric data and bytes; for the
time being, a workaround for dealing with these types is to use base64 encoding
and add them as a string-valued field to a pydantic model.
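For example, a sketch of this workaround (field names are hypothetical):
```python
import base64
import pydantic

class AudioChunk(pydantic.BaseModel):
    # Raw bytes are base64-encoded into a string field to stay JSON-compatible.
    wav_b64: str

chunk = AudioChunk(wav_b64=base64.b64encode(b"\x00\x01\x02").decode("utf-8"))
raw_bytes = base64.b64decode(chunk.wav_b64)
```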
## Chains compared to Truss
Chains is an alternate SDK for packaging and deploying AI models. It carries over many features and concepts from Truss and gives you access to the benefits of Baseten (resource provisioning, autoscaling, fast cold starts, etc), but it is not a 1-1 replacement for Truss.
Here are some key differences:
* Rather than running `truss init` and creating a Truss in a directory, a Chain is a single file, giving you more flexibility for implementing multi-step model inference. Create an example with `truss chains init`.
* Configuration is done inline in typed Python code rather than in a `config.yaml` file.
* While Chainlets are converted to Truss models when run on Baseten, `Chainlet != TrussModel`.
Chains is designed for compatibility and incremental adoption, with a stub function for wrapping existing deployed models.
# Audio Transcription Chain
Transcribe hours of audio to text in a few seconds
Chains is in beta mode. Read our [launch blog post](https://www.baseten.co/blog/introducing-baseten-chains/).
[Learn more about Chains](/chains/overview).
## Prerequisites
To use Chains, install a recent Truss version and ensure pydantic is v2:
```bash
pip install --upgrade truss 'pydantic>=2.0.0'
```
Truss requires python `>=3.8,<3.13`. To set up a fresh development environment,
you can use the following commands, creating an environment named `chains_env`
using `pyenv`:
```bash
curl https://pyenv.run | bash
echo 'export PYENV_ROOT="$HOME/.pyenv"' >> ~/.bashrc
echo '[[ -d $PYENV_ROOT/bin ]] && export PATH="$PYENV_ROOT/bin:$PATH"' >> ~/.bashrc
echo 'eval "$(pyenv init -)"' >> ~/.bashrc
source ~/.bashrc
pyenv install 3.11.0
ENV_NAME="chains_env"
pyenv virtualenv 3.11.0 $ENV_NAME
pyenv activate $ENV_NAME
pip install --upgrade truss 'pydantic>=2.0.0'
```
To deploy Chains remotely, you also need a
[Baseten account](https://app.baseten.co/signup).
It is handy to export your API key to the current shell session or permanently in your `.bashrc`:
```bash ~/.bashrc
export BASETEN_API_KEY="nPh8..."
```
# Overview
This example shows how to transcribe audio media files to text blazingly fast
and at high quality using a Chain. To achieve this we will combine a
number of methods:
* Partitioning large input files (10h+) into smaller chunks.
* Analyzing the audio for silence to find optimal split points of the chunks.
* Distributing the chunk tasks across auto-scaling Baseten deployments.
* Using batching with a highly optimized transcription model to maximize
GPU utilization.
* Range downloads and pipelining of audio extraction to minimize latency.
* `asyncio` for concurrent execution of tasks.
The implementation is quite a bit of code, located in the
[Chains examples repo](https://github.com/basetenlabs/truss/tree/main/truss-chains/examples/audio-transcription).
This guide is a commentary on the code, pointing out critical parts or
explaining design choices.
If you want to try out this Chain and create a customized version of it, check
out the [try it yourself section](#try-it-yourself) below.
## The Chain structure
The chunking has a 2-step hierarchy:
"macro chunks" partition the full media into segments of in the range of
\~300s. This ensures that for very long files, the workload of a single
`MacroChunkWorker` is limited by that duration and the source data for the
different macro chunks is downloaded in parallel, making processing very
long files much faster. For shorter inputs, there will be only a single
"macro chunk".
"micro chunks" have durations in the range of 5-30s. These are sent to the
transcription model.
More details in the explanations of the Chainlets below.
The `WhisperModel` is split off from the transcription Chain. This is
optional, but has some advantages:
* A lot of "business logic", which might more frequently be changed, is
implemented in the Chain. When developing or changing the Chain and making
frequent re-deployments, it's a faster dev loop to not re-deploy the
Whisper model, since as a large GPU model with heavy dependencies, this is
slower.
* The Whisper model can be used in other Chains, or standalone, if it's not
part of this Chain. Specifically, the same model can be used by the dev and
prod versions of a Chain - otherwise a separate Whisper model would need to
be deployed with each environment.
* When making changes and improvements to the Whisper model, the development
can be split off from the development of the Chain - think of a separation of
concerns into high-level (the Chain) and low-level (the model) development.
More information on how to use and deploy non-Chain models within a Chain is
given in the [WhisperModel section](#whispermodel) below.
### `Transcribe`
This Chainlet is the "entrypoint" to the Chain; external clients send
transcription requests to it. Its endpoint implementation has the following
signature:
```python
async def run_remote(
self,
media_url: str,
params: data_types.TranscribeParams
) -> data_types.TranscribeOutput:
```
The input arguments are separated into `media_url`, the audio source to work on,
and `params` that control the execution, e.g. the chunk sizes. You can find the
exact schemas and docstrings of these arguments in
[data\_types.py](https://github.com/basetenlabs/truss/blob/main/truss-chains/examples/transcribe/data_types.py).
An example request looks like this:
```bash
curl -X POST $INVOCATION_URL \
-H "Authorization: Api-Key $BASETEN_API_KEY" \
-d ''
```
with JSON input:
```json
{
"media_url": "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/TearsOfSteel.mp4",
"params": {
"micro_chunk_size_sec": 30,
"macro_chunk_size_sec": 300
}
}
```
The output looks like this (truncated):
```json
{
"segments": [
...
{
"start_time_sec": 517.9465,
"end_time_sec": 547.70975,
"text": "The world's changed, Celia. Maybe we can too. Memory override complete!",
"language": "english",
"bcp47_key": "en"
},
{
"start_time_sec": 547.70975,
"end_time_sec": 567.0716874999999,
"text": "You know, there's a lesson to be learned from this. Could've gone worse.",
"language": "english",
"bcp47_key": "en"
},
...
],
"input_duration_sec": 734.261406,
"processing_duration_sec": 82.42135119438171,
"speedup": 8.908631020478238
}
```
The `Transcribe` Chainlet does the following:
* Asserts that the media URL supports range downloads. This is usually a given
for video / audio hosting services.
* Uses `FFMPEG` to query the length of the medium (both video and audio
files are supported).
* Generates a list of "macro chunks", defined by their start and end times.
The length is defined by `macro_chunk_size_sec` in `TranscribeParams`.
This will soon be upgraded to find silence-aware split points, so
that a chunk does not end in the middle of a spoken word. To do this, a
small segment around the desired chunk boundary is downloaded (e.g. +/- 5
seconds) and the most silent timestamp within is determined.
* Sends the media URL with chunk limits as "tasks" to `MacroChunkWorker`. Using
`asyncio.ensure_future`, these tasks are dispatched concurrently - meaning
that the loop over the chunks does not wait for each chunk to complete first,
before dispatching the task for the next chunk. These "calls" are network
requests (RPCs) to the `MacroChunkWorker` Chainlet which runs on its own
deployment and can auto-scale, depending on the load.
* Once all tasks are dispatched, it waits for the results and concatenates
all the partial transcriptions from the chunks to a final output.
### `MacroChunkWorker`
The `MacroChunkWorker` Chainlet works on chunk tasks it receives from the
`Transcribe` Chainlet. For each chunk it does the following:
* It starts a `DownloadSubprocess` asynchronously (i.e. this will need CPU on
the machine, but not block the event loop of the main process, making it
possible to serve multiple concurrent requests).
* In `DownloadSubprocess`, `FFMPEG` is used to download the relevant time
range from the source. It extracts the audio wave form and streams the raw
wave `bytes` back to the main process. This happens on-the-fly (i.e. not
waiting for the full download to complete) - so the initial latency until
receiving wave bytes is minimized. Furthermore, it resamples the wave form
to the sampling rate expected by the transcription model and averages
multichannel audio to a mono signal.
{/*
One detail is that when streaming the wave bytes to the main
process, we need to intercept the wave metadata from the header. There is
a function in `helpers.py` for this: `_extract_wav_info`.
Quite a lot of case distinctions and logging is done for error handling and
resource cleanup in case of failures, e.g. in the exiting of the
`DownloadSubprocess`-context.
*/}
* The main process reads as many bytes from the wave stream as needed for
`micro_chunk_size_sec` (5-30s).
* A helper function `_find_silent_split_point` analyzes the wave form to
find the most silent point in the *second half* of the chunk. E.g. if the
`micro_chunk_size_sec` is 5s, then it searches the most silent point
between 2.5 and 5.0s and uses this time to partition the chunk.
* The wave bytes are converted to wave file format (i.e. including metadata
in the header) and then b64-encoded, so they can be sent as JSON via HTTP.
* For each b64-encoded "micro" chunk, the transcription model is invoked.
* Like in the `Transcribe` Chainlet, these tasks are concurrent RPCs; the
transcription model deployment can auto-scale with the load.
* Finally, we wait for all "micro chunk" results, concatenate them to a
"macro chunk" result and return it to `Transcribe`.
### `WhisperModel`
As mentioned in the [structure section](#the-chain-structure), the
`WhisperModel` is separately deployed from the transcription Chain.
In the Chain implementation we only need to define a small "adapter" class
`WhisperModel`, mainly for integrating the I/O types of that model with our
Chain. This is a subclass of `chains.StubBase` which abstracts sending
requests, retries etc. away from us (this class is also used for all RPCs that
the Chains framework makes internally). Furthermore, we need to take the
invocation URL of that model (e.g.
`https://model-5woz91z3.api.baseten.co/production/predict`) and pass it
along when initializing this adapter class with the `from_url` factory-method.
There are two options for deploying a model separately from a Chain:
**As a Chainlet**
This is done in this example.
As a Chainlet it can even be in the same file, but not "wired" into the Chain
with the `chains.depends`-directive. In this example we put it into a separate
file `whisper_chainlet.py`.
* It will not be included in the deployment when running the `truss chains
push transcribe.py` command for the entrypoint, since it's not formally a
tracked dependency of that Chain.
* It is separately deployed, with a deploy command specifically targeting
that class i.e. `truss chains push whisper_chainlet.py`.
Using a structure like this has the advantage of high code coherence, e.g.
the pydantic models for the input and output are
shared in both files (defined in the common `data_types.py`), while still
allowing independent deployment cycles.
**As a conventional Truss model**
This is not done in this example.
This could be anything, from the
[model library](https://www.baseten.co/library/), the
[Truss examples repo](https://github.com/basetenlabs/truss-examples) or your
[own Truss model](https://truss.baseten.co/quickstart).
This might be the better choice, if the model has a substantial code base
itself and if you want to avoid mixing that (and the development of it) with
the Chain code.
# Performance considerations
Even for very large files, e.g. 10h+, the end-to-end runtime is still bounded:
since the `macro_chunk_size_sec` is fixed, each sub-task has a
bounded runtime.
So provided all Chainlet components have enough resources
to auto-scale horizontally and the network bandwidth of the source hosting is
sufficient, the overall runtime is still relatively small. Note that
auto-scaling, e.g. the transcription model, to a large number of replicas can
take a while, so you'll only see the full speedup after a "warm-up" phase.
Depending on the distribution of your input durations and the "spikiness" of
your traffic, there are a few knobs to tweak:
* `micro_chunk_size_sec`: using too small "micro" chunks creates more
overhead and leaves the GPU underutilized; with too large ones, the
processing of a single chunk might take too long or overflow the GPU model
\-- the sweet spot is in the middle.
* `macro_chunk_size_sec`: larger chunks mean less overhead, but also less
download parallelism.
* Predict-concurrency and autoscaling settings of all deployed components.
Specifically, make sure that the WhisperModel can scale up to enough
replicas (but is also not underutilized). Look at the GPU and CPU
utilization metrics of the deployments.
# Try it yourself
If you want to try this yourself, follow the steps below:
All code can be found and copied in this
[example directory](https://github.com/basetenlabs/truss/tree/main/truss-chains/examples/audio-transcription).
* Download the example code.
* Deploy the Whisper Chainlet first: `truss chains push whisper_chainlet.py`.
* Note the invocation URL of the form `https://chain-{chain_id}.api.baseten.co/production/run_remote`
and insert that URL as a value for `WHISPER_URL` in `transcribe.py`.
You can find the URL in the output of the push command or on the status
page.
* Deploy the transcription Chain with `truss chains push transcribe.py`.
As the media source URL, you can pass either video or audio sources, as long as the
format can be handled by `FFMPEG` and the hosted file supports range downloads.
A public test file you can use is shown in the example below.
```bash
curl -X POST $INVOCATION_URL \
-H "Authorization: Api-Key $BASETEN_API_KEY" \
-d ''
```
with JSON input:
```json
{
"media_url": "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/TearsOfSteel.mp4",
"params": {
"micro_chunk_size_sec": 30,
"macro_chunk_size_sec": 300
}
}
```
# RAG Chain
Build a RAG (retrieval-augmented generation) pipeline with Chains
Chains is in beta mode. Read our [launch blog post](https://www.baseten.co/blog/introducing-baseten-chains/).
[Learn more about Chains](/chains/overview)
## Prerequisites
To use Chains, install a recent Truss version and ensure pydantic is v2:
```bash
pip install --upgrade truss 'pydantic>=2.0.0'
```
Truss requires python `>=3.8,<3.13`. To set up a fresh development environment,
you can use the following commands, creating an environment named `chains_env`
using `pyenv`:
```bash
curl https://pyenv.run | bash
echo 'export PYENV_ROOT="$HOME/.pyenv"' >> ~/.bashrc
echo '[[ -d $PYENV_ROOT/bin ]] && export PATH="$PYENV_ROOT/bin:$PATH"' >> ~/.bashrc
echo 'eval "$(pyenv init -)"' >> ~/.bashrc
source ~/.bashrc
pyenv install 3.11.0
ENV_NAME="chains_env"
pyenv virtualenv 3.11.0 $ENV_NAME
pyenv activate $ENV_NAME
pip install --upgrade truss 'pydantic>=2.0.0'
```
To deploy Chains remotely, you also need a
[Baseten account](https://app.baseten.co/signup).
It is handy to export your API key to the current shell session or permanently in your `.bashrc`:
```bash ~/.bashrc
export BASETEN_API_KEY="nPh8..."
```
If you want to run this example in
[local debugging mode](/chains/guide#test-a-chain-locally), you'll also need to
install chromadb:
```shell
pip install chromadb
```
The complete code used in this tutorial can also be found in the
[Chains examples repo](https://github.com/basetenlabs/truss/tree/main/truss-chains/examples/rag).
# Overview
Retrieval-augmented generation (RAG) is a multi-model pipeline for generating
context-aware answers from LLMs.
There are a number of ways to build a RAG system. This tutorial shows a minimum
viable implementation with a basic vector store and retrieval function. It's
intended as a starting point to show how Chains helps you flexibly combine model
inference and business logic.
In this tutorial, we'll build a simple RAG pipeline for a hypothetical alumni
matching service for a university. The system:
1. Takes a bio with information about a new graduate
2. Uses a vector database to retrieve semantically similar bios of other alums
3. Uses an LLM to explain why the new graduate should meet the selected alums
4. Returns the writeup from the LLM
Let's dive in!
## Building the Chain
Create a file `rag.py` in a new directory with:
```sh
mkdir rag
touch rag/rag.py
cd rag
```
Our RAG Chain is composed of three parts:
* `VectorStore`, a Chainlet that implements a vector database with a retrieval
function.
* `LLMClient`, a Stub for connecting to a deployed LLM.
* `RAG`, the entrypoint Chainlet that orchestrates the RAG pipeline and
has `VectorStore` and `LLMClient` as dependencies.
We'll examine these components one by one and then see how they all work
together.
### Vector store Chainlet
A real production RAG system would use a hosted vector database with a massive
number of stored embeddings. For this example, we're using a small local vector
store built with `chromadb` to stand in for a more complex system.
The Chainlet has three parts:
* [`remote_config`](/chains-reference/sdk#remote-configuration), which
configures a Docker image on deployment with dependencies.
* `__init__()`, which runs once when the Chainlet is spun up, and creates the
vector database with ten sample bios.
* [`run_remote()`](/chains/concepts#run-remote-chaining-chainlets), which runs
each time the Chainlet is called and is the sole public interface for the
Chainlet.
```python rag/rag.py
import truss_chains as chains
# Create a Chainlet to serve as our vector database.
class VectorStore(chains.ChainletBase):
# Add chromadb as a dependency for deployment.
remote_config = chains.RemoteConfig(
docker_image=chains.DockerImage(
pip_requirements=["chromadb"]
)
)
# Runs once when the Chainlet is deployed or scaled up.
def __init__(self):
# Import Chainlet-specific dependencies in init, not at the top of
# the file.
import chromadb
self._chroma_client = chromadb.EphemeralClient()
self._collection = self._chroma_client.create_collection(name="bios")
# Sample documents are hard-coded for your convenience
documents = [
"Angela Martinez is a tech entrepreneur based in San Francisco. As the founder and CEO of a successful AI startup, she is a leading figure in the tech community. Outside of work, Angela enjoys hiking the trails around the Bay Area and volunteering at local animal shelters.",
"Ravi Patel resides in New York City, where he works as a financial analyst. Known for his keen insight into market trends, Ravi spends his weekends playing chess in Central Park and exploring the city's diverse culinary scene.",
"Sara Kim is a digital marketing specialist living in San Francisco. She helps brands build their online presence with creative strategies. Outside of work, Sara is passionate about photography and enjoys hiking the trails around the Bay Area.",
"David O'Connor calls New York City his home and works as a high school teacher. He is dedicated to inspiring the next generation through education. In his free time, David loves running along the Hudson River and participating in local theater productions.",
"Lena Rossi is an architect based in San Francisco. She designs sustainable and innovative buildings that contribute to the city's skyline. When she's not working, Lena enjoys practicing yoga and exploring art galleries.",
"Akio Tanaka lives in Tokyo and is a software developer specializing in mobile apps. Akio is an avid gamer and enjoys attending eSports tournaments. He also has a passion for cooking and often experiments with new recipes in his spare time.",
"Maria Silva is a nurse residing in New York City. She is dedicated to providing compassionate care to her patients. Maria finds joy in gardening and often spends her weekends tending to her vibrant flower beds and vegetable garden.",
"John Smith is a journalist based in San Francisco. He reports on international politics and has a knack for uncovering compelling stories. Outside of work, John is a history buff who enjoys visiting museums and historical sites.",
"Aisha Mohammed lives in Tokyo and works as a graphic designer. She creates visually stunning graphics for a variety of clients. Aisha loves to paint and often showcases her artwork in local exhibitions.",
"Carlos Mendes is an environmental engineer in San Francisco. He is passionate about developing sustainable solutions for urban areas. In his leisure time, Carlos enjoys surfing and participating in beach clean-up initiatives."
]
# Add all documents to the database
self._collection.add(
documents=documents,
ids=[f"id{n}" for n in range(len(documents))]
)
# Runs each time the Chainlet is called
async def run_remote(self, query: str) -> list[str]:
# This call includes embedding the query string.
results = self._collection.query(query_texts=[query], n_results=2)
if results is None or not results:
raise ValueError("No bios returned from the query")
if not results["documents"] or not results["documents"][0]:
raise ValueError("Bios are empty")
return results["documents"][0]
```
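Before adding the rest of the Chain, you can optionally sanity-check the retrieval logic on its own in [local debugging mode](/chains/guide#test-a-chain-locally). This assumes `chromadb` is installed locally, as noted in the prerequisites; the query string is just an example, and you can replace this temporary `__main__` block later with the full local test shown in "Testing locally":
```python
# Optional quick check of the VectorStore Chainlet in isolation.
if __name__ == "__main__":
    import asyncio

    with chains.run_local():
        store = VectorStore()
        bios = asyncio.get_event_loop().run_until_complete(
            store.run_remote("Software developer in Tokyo who enjoys gaming")
        )
        print(bios)  # Expect the two most similar sample bios.
```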
### LLM inference stub
Now that we can retrieve relevant bios from the vector database, we need to pass
that information to an LLM to generate our final output.
Chains can integrate previously deployed models using a Stub. Like Chainlets,
Stubs implement
[`run_remote()`](/chains/concepts#run-remote-chaining-chainlets), but as a call
to the deployed model.
For our LLM, we'll use Phi-3 Mini Instruct, a small-but-mighty open source LLM
that can be deployed in one click from Baseten's model library.
While the model is deploying, be sure to note down the model's invocation URL from
the model dashboard for use in the next step.
To use our deployed LLM in the RAG Chain, we define a Stub:
```python rag/rag.py
class LLMClient(chains.StubBase):
# Runs each time the Stub is called
async def run_remote(self, new_bio: str, bios: list[str]) -> str:
# Use the retrieved bios to augment the prompt -- here's the "A" in RAG!
prompt = f"""You are matching alumni of a college to help them make connections. Explain why the person described first would want to meet the people selected from the matching database.
Person you're matching: {new_bio}
People from database: {" ".join(bios)}"""
# Call the deployed model.
resp = await self._remote.predict_async(json_payload={
"messages": [{"role": "user", "content": prompt}],
"stream" : False
})
return resp["output"][len(prompt) :].strip()
```
### RAG entrypoint Chainlet
The entrypoint to a Chain is the Chainlet that specifies the public-facing input
and output of the Chain and orchestrates calls to dependencies.
The `__init__` function in this Chainlet takes two new arguments:
* Add dependencies to any Chainlet with
[`chains.depends()`](/chains-reference/sdk#truss-chains-depends). Only
Chainlets, not Stubs, need to be added in this fashion.
* Use
[`chains.depends_context()`](/chains-reference/sdk#truss-chains-depends-context)
to inject a context object at runtime. This context object is required to
initialize the `LLMClient` stub.
* Visit your [Baseten workspace](https://app.baseten.co/models) to find
  the URL of the previously deployed Phi-3 model and insert it as the value
  for `LLM_URL`.
```python rag/rag.py
# Insert the URL from the previously deployed Phi-3 model.
LLM_URL = ...
@chains.mark_entrypoint
class RAG(chains.ChainletBase):
# Runs once when the Chainlet is spun up
def __init__(
self,
# Declare dependency chainlets.
vector_store: VectorStore = chains.depends(VectorStore),
context: chains.DeploymentContext = chains.depends_context(),
):
self._vector_store = vector_store
# The stub needs the context for setting up authentication.
self._llm = LLMClient.from_url(LLM_URL, context)
# Runs each time the Chain is called
async def run_remote(self, new_bio: str) -> str:
# Use the VectorStore Chainlet for context retrieval.
bios = await self._vector_store.run_remote(new_bio)
# Use the LLMClient Stub for augmented generation.
contacts = await self._llm.run_remote(new_bio, bios)
return contacts
```
## Testing locally
Because our Chain uses a Stub for the LLM call, we can run the whole Chain
locally without any GPU resources.
Before running the Chainlet, make sure to set your Baseten API key as an
environment variable `BASETEN_API_KEY`.
```python rag/rag.py
if __name__ == "__main__":
import os
import asyncio
with chains.run_local(
# This secret is needed even locally, because part of this chain
# calls the separately deployed Phi-3 model. Only the Chainlets
# actually run locally.
secrets={"baseten_chain_api_key": os.environ["BASETEN_API_KEY"]}
):
rag_client = RAG()
result = asyncio.get_event_loop().run_until_complete(
rag_client.run_remote(
"""
Sam just moved to Manhattan for his new job at a large bank.
In college, he enjoyed building sets for student plays.
"""
)
)
print(result)
```
We can run our Chain locally:
```sh
python rag.py
```
After a few moments, we should get a recommendation for why Sam should meet the
alumni selected from the database.
## Deploying to production
Once we're satisfied with our Chain's local behavior, we can deploy it to
production on Baseten. To deploy the Chain, run:
```sh
truss chains push rag.py
```
This will deploy our Chain as a development deployment. Once the Chain is
deployed, we can call it from its API endpoint.
You can do this in the console with cURL:
```sh
curl -X POST 'https://chain-5wo86nn3.api.baseten.co/development/run_remote' \
-H "Authorization: Api-Key $BASETEN_API_KEY" \
-d '{"new_bio": "Sam just moved to Manhattan for his new job at a large bank.In college, he enjoyed building sets for student plays."}'
```
Alternatively, you can also integrate this in a Python application:
```python call_chain.py
import requests
import os
# Insert the URL from the deployed rag chain. You can get it from the CLI
# output or the status page, e.g.
# "https://chain-6wgeygoq.api.baseten.co/production/run_remote".
RAG_CHAIN_URL = ""
baseten_api_key = os.environ["BASETEN_API_KEY"]
if not RAG_CHAIN_URL:
raise ValueError("Please insert the URL for the RAG chain.")
# Example bio to match; the same one used when testing locally.
new_bio = (
    "Sam just moved to Manhattan for his new job at a large bank. "
    "In college, he enjoyed building sets for student plays."
)
resp = requests.post(
RAG_CHAIN_URL,
headers={"Authorization": f"Api-Key {baseten_api_key}"},
json={"new_bio": new_bio},
)
print(resp.json())
```
When we're happy with the deployed Chain, we can promote it to production via
the UI or by running:
```sh
truss chains push --promote rag.py
```
Once in production, the Chain will have access to full autoscaling settings.
Both the development and production deployments will scale to zero when not in
use.
# Build your first Chain
Build and deploy two example Chains
Chains is in beta mode. Read our [launch blog post](https://www.baseten.co/blog/introducing-baseten-chains/).
This quickstart guide contains instructions for creating two Chains:
1. A simple CPU-only “hello world” Chain.
2. A Chain that implements Phi-3 Mini and uses it to write poems.
## Prerequisites
To use Chains, install a recent Truss version and ensure pydantic is v2:
```bash
pip install --upgrade truss 'pydantic>=2.0.0'
```
Truss requires Python `>=3.8,<3.13`. To set up a fresh development environment,
you can use the following commands, creating an environment named `chains_env`
using `pyenv`:
```bash
curl https://pyenv.run | bash
echo 'export PYENV_ROOT="$HOME/.pyenv"' >> ~/.bashrc
echo '[[ -d $PYENV_ROOT/bin ]] && export PATH="$PYENV_ROOT/bin:$PATH"' >> ~/.bashrc
echo 'eval "$(pyenv init -)"' >> ~/.bashrc
source ~/.bashrc
pyenv install 3.11.0
ENV_NAME="chains_env"
pyenv virtualenv 3.11.0 $ENV_NAME
pyenv activate $ENV_NAME
pip install --upgrade truss 'pydantic>=2.0.0'
```
To deploy Chains remotely, you also need a
[Baseten account](https://app.baseten.co/signup).
It is handy to export your API key to the current shell session or permanently in your `.bashrc`:
```bash ~/.bashrc
export BASETEN_API_KEY="nPh8..."
```
## Example: Hello World
Chains are written in Python files. In your working directory,
create `hello_chain/hello.py`:
```sh
mkdir hello_chain
cd hello_chain
touch hello.py
```
In the file, we'll specify a basic Chain. It has two Chainlets:
* `HelloWorld`, the entrypoint, which handles the input and output.
* `RandInt`, which generates a random integer. It is used as a dependency
  by `HelloWorld`.
Via the entrypoint, the Chain takes a maximum value and returns the string
"Hello World!" repeated a variable number of times.
```python hello.py
import random
import truss_chains as chains
class RandInt(chains.ChainletBase):
def run_remote(self, max_value: int) -> int:
return random.randint(1, max_value)
@chains.mark_entrypoint
class HelloWorld(chains.ChainletBase):
def __init__(self, rand_int=chains.depends(RandInt, retries=3)) -> None:
self._rand_int = rand_int
def run_remote(self, max_value: int) -> str:
num_repetitions = self._rand_int.run_remote(max_value)
return "Hello World! " * num_repetitions
```
### The Chainlet class-contract
Exactly one Chainlet must be marked as the entrypoint with
the [`@chains.mark_entrypoint`](/chains-reference/sdk#truss-chains-mark-entrypoint)
decorator. This Chainlet is responsible for
handling public-facing input and output for the whole Chain in response to an
API call.
A Chainlet class has a single public method,
[`run_remote()`](/chains/concepts#run-remote-chaining-chainlets), which is
the API
endpoint for the entrypoint Chainlet and the function that other Chainlets can
use as a dependency. The
[`run_remote()`](/chains/concepts#run-remote-chaining-chainlets)
method must be fully type-annotated
with primitive python
types
or [pydantic models](https://docs.pydantic.dev/latest/).
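For illustration, a fully annotated signature can mix primitives and pydantic models. The names below are made up for this sketch:
```python
import pydantic
import truss_chains as chains


class SummaryResult(pydantic.BaseModel):
    text: str
    num_words: int


class Summarizer(chains.ChainletBase):
    # A primitive input and a pydantic model output, both type-annotated.
    async def run_remote(self, document: str) -> SummaryResult:
        summary = document[:100]  # Placeholder logic for the sketch.
        return SummaryResult(text=summary, num_words=len(summary.split()))
```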
Chainlets cannot be
naively
instantiated. The only correct usages are:
1. Make one Chainlet depend on another one via the
[`chains.depends()`](/chains-reference/sdk#truss-chains-depends) directive
as an `__init__`-argument as shown above for the `RandInt` Chainlet.
2. In the [local debugging mode](/chains/guide#test-a-chain-locally).
Beyond that, you can structure your code as you like, with private methods,
imports from other files, and so forth.
Keep in mind that Chainlets are intended for distributed, replicated, remote
execution, so using global variables, global state, and certain Python features
like importing modules dynamically at runtime should be avoided as they may not
work as intended.
### Deploy your Chain to Baseten
To deploy your Chain to Baseten, run:
```bash
truss chains push hello.py
```
The deploy command results in an output like this:
```
⛓️ HelloWorld - Chainlets ⛓️
╭──────────────────────┬─────────────────────────┬─────────────╮
│ Status │ Name │ Logs URL │
├──────────────────────┼─────────────────────────┼─────────────┤
│ 💚 ACTIVE │ HelloWorld (entrypoint) │ https://... │
├──────────────────────┼─────────────────────────┼─────────────┤
│ 💚 ACTIVE │ RandInt (dep) │ https://... │
╰──────────────────────┴─────────────────────────┴─────────────╯
Deployment succeeded.
You can run the chain with:
curl -X POST 'https://chain-.../run_remote' \
-H "Authorization: Api-Key $BASETEN_API_KEY" \
-d ''
```
Wait for the status to turn to `ACTIVE` and test invoking your Chain (replace
`$INVOCATION_URL` in the command below):
```bash
curl -X POST $INVOCATION_URL \
-H "Authorization: Api-Key $BASETEN_API_KEY" \
-d '{"max_value": 10}'
# "Hello World! Hello World! Hello World! "
```
## Example: Poetry with LLMs
Our second example also has two Chainlets, but is somewhat more complex and
realistic. The Chainlets are:
* `PoemGenerator`, the entrypoint, which handles the input and output and
orchestrates calls to the LLM.
* `PhiLLM`, which runs inference on Phi-3 Mini.
This Chain takes a list of words and returns a poem about each word, written by
Phi-3.
We build this Chain in a new working directory (if you are still inside
`hello_chain/`, go up one level with `cd ..` first):
```sh
mkdir poetry_chain
cd poetry_chain
touch poems.py
```
A similar end-to-end code example, using Mistral as an LLM, is available in the
[examples repo](https://github.com/basetenlabs/truss/tree/main/truss-chains/examples/mistral).
### Building the LLM Chainlet
The main difference between this Chain and the previous one is that we now have
an LLM that needs a GPU and more complex dependencies.
Copy the following code into `poems.py`:
```python poems.py
import asyncio
from typing import List
import pydantic
import truss_chains as chains
from truss import truss_config
PHI_HF_MODEL = "microsoft/Phi-3-mini-4k-instruct"
# This configures caching of the model weights from the Hugging Face repo
# in the docker image that is used for deploying the Chainlet.
PHI_CACHE = truss_config.ModelRepo(
repo_id=PHI_HF_MODEL, allow_patterns=["*.json", "*.safetensors", ".model"]
)
class Messages(pydantic.BaseModel):
messages: List[dict[str, str]]
class PhiLLM(chains.ChainletBase):
# `remote_config` defines the resources required for this chainlet.
remote_config = chains.RemoteConfig(
docker_image=chains.DockerImage(
# The phi model needs some extra python packages.
pip_requirements=[
"accelerate==0.30.1",
"einops==0.8.0",
"transformers==4.41.2",
"torch==2.3.0",
]
),
# The phi model needs a GPU and more CPUs.
compute=chains.Compute(cpu_count=2, gpu="T4"),
# Cache the model weights in the image
assets=chains.Assets(cached=[PHI_CACHE]),
)
def __init__(self) -> None:
# Note the imports of the *specific* python requirements are
# pushed down to here. This code will only be executed on the
# remotely deployed Chainlet, not in the local environment,
# so we don't need to install these packages in the local
# dev environment.
import torch
import transformers
self._model = transformers.AutoModelForCausalLM.from_pretrained(
PHI_HF_MODEL,
torch_dtype=torch.float16,
device_map="auto",
)
self._tokenizer = transformers.AutoTokenizer.from_pretrained(
PHI_HF_MODEL,
)
self._generate_args = {
"max_new_tokens": 512,
"temperature": 1.0,
"top_p": 0.95,
"top_k": 50,
"repetition_penalty": 1.0,
"no_repeat_ngram_size": 0,
"use_cache": True,
"do_sample": True,
"eos_token_id": self._tokenizer.eos_token_id,
"pad_token_id": self._tokenizer.pad_token_id,
}
async def run_remote(self, messages: Messages) -> str:
import torch
model_inputs = self._tokenizer.apply_chat_template(
messages.messages, tokenize=False, add_generation_prompt=True
)
inputs = self._tokenizer(model_inputs, return_tensors="pt")
input_ids = inputs["input_ids"].to("cuda")
with torch.no_grad():
outputs = self._model.generate(input_ids=input_ids, **self._generate_args)
output_text = self._tokenizer.decode(outputs[0], skip_special_tokens=True)
return output_text
```
### Building the entrypoint
Now that we have an LLM, we can use it in a poem generator Chainlet. Add the
following code to `poems.py`:
```python poems.py
@chains.mark_entrypoint
class PoemGenerator(chains.ChainletBase):
def __init__(self, phi_llm: PhiLLM = chains.depends(PhiLLM)) -> None:
self._phi_llm = phi_llm
async def run_remote(self, words: list[str]) -> list[str]:
tasks = []
for word in words:
messages = Messages(
messages=[
{
"role": "system",
"content": (
"You are poet who writes short, "
"lighthearted, amusing poetry."
),
},
{"role": "user", "content": f"Write a poem about {word}"},
]
)
tasks.append(asyncio.ensure_future(self._phi_llm.run_remote(messages)))
return list(await asyncio.gather(*tasks))
```
Note that we use `asyncio.ensure_future` around each RPC to the LLM Chainlet.
This makes the current Python process start these remote calls concurrently,
i.e. the next call is started before the previous one has finished, which
minimizes our overall runtime. To await the results of all calls,
`asyncio.gather` is used, which gives us back normal Python objects.
If the LLM is hit with many concurrent requests, it can auto-scale up (if
autoscaling is configured). More advanced LLM servers have batching capabilities,
so even a single instance can serve concurrent requests.
### Deploy your Chain to Baseten
To deploy your Chain to Baseten, run:
```bash
truss chains push poems.py
```
Wait for the status to turn to `ACTIVE` and test invoking your Chain (replace
`$INVOCATION_URL` in the command below):
```bash
curl -X POST $INVOCATION_URL \
-H "Authorization: Api-Key $BASETEN_API_KEY" \
-d '{"words": ["bird", "plane", "superman"]}'
#[[
#" [INST] Generate a poem about: bird [/INST] In the quiet hush of...",
#" [INST] Generate a poem about: plane [/INST] In the vast, boudl...",
#" [INST] Generate a poem about: superman [/INST] In the realm where..."
#]]
```
# User Guides
Using the full potential of Chains
Chains is in beta mode. Read our [launch blog post](https://www.baseten.co/blog/introducing-baseten-chains/).
## Designing the architecture of a Chain
A Chain is composed of multiple connecting Chainlets working together to perform
a task.
For example, consider a Chain that takes a large audio file, splits it into
smaller chunks, transcribes each chunk in parallel to speed up the
transcription process, and finally aggregates and returns the results.
To build an efficient end-to-end Chain, we recommend drafting your high level
structure as a flowchart or diagram. This will help you identify the Chainlets
needed and how to link them.
If one Chainlet creates many "sub-tasks" by calling other dependency
Chainlets (e.g. in a loop over partial work items),
these calls should be made as `asyncio` tasks that run concurrently.
That way you get the most out of the parallelism that Chains offers. This
design pattern is extensively used in the
[audio transcription example](/chains/examples/audio-transcription).
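As an illustration of this fan-out pattern (the Chainlet names here are hypothetical), the orchestrating Chainlet starts one task per work item and awaits them together:
```python
import asyncio
import truss_chains as chains


class Worker(chains.ChainletBase):
    async def run_remote(self, item: str) -> str:
        return item.upper()  # Placeholder for the real per-chunk work.


@chains.mark_entrypoint
class FanOut(chains.ChainletBase):
    def __init__(self, worker: Worker = chains.depends(Worker)) -> None:
        self._worker = worker

    async def run_remote(self, items: list[str]) -> list[str]:
        # Start all sub-tasks concurrently, then await the results together.
        tasks = [asyncio.ensure_future(self._worker.run_remote(x)) for x in items]
        return list(await asyncio.gather(*tasks))
```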
## Local development
Chains are designed for production in replicated remote deployments. But
alongside that production-ready power, we need great local development and
deployment experiences.
Locally, a Chain is just Python files in a source tree. While that gives you a
lot of flexibility in how you structure your code, there are some constraints
and rules to follow to ensure successful distributed, remote execution in
production.
The best thing you can do while developing locally with Chains is to run your
code frequently, even if you do not have a `__main__` section: the Chains
framework runs various validations at module initialization to help
you catch issues early.
Additionally, running `mypy` and fixing reported type errors can help you
find problems early and in a rapid feedback loop, before attempting a (much
slower) deployment.
Complementary to purely local development, Chains also has a "watch" mode,
similar to Truss; see the [watch section below](#chains-watch).
### Test a Chain locally
Let's revisit our "Hello World" Chain:
```python hello_chain/hello.py
import asyncio
import truss_chains as chains
# This Chainlet does the work
class SayHello(chains.ChainletBase):
async def run_remote(self, name: str) -> str:
return f"Hello, {name}"
# This Chainlet orchestrates the work
@chains.mark_entrypoint
class HelloAll(chains.ChainletBase):
def __init__(self, say_hello_chainlet=chains.depends(SayHello)) -> None:
self._say_hello = say_hello_chainlet
async def run_remote(self, names: list[str]) -> str:
tasks = []
for name in names:
tasks.append(asyncio.ensure_future(
self._say_hello.run_remote(name)))
return "\n".join(await asyncio.gather(*tasks))
# Test the Chain locally
if __name__ == "__main__":
with chains.run_local():
hello_chain = HelloAll()
result = asyncio.get_event_loop().run_until_complete(
hello_chain.run_remote(["Marius", "Sid", "Bola"]))
print(result)
```
When the file is run as a script (i.e. as the `__main__` module), local
instances of the Chainlets are created, allowing you to test the functionality
of your chain just by executing the Python file:
```bash
cd hello_chain
python hello.py
# Hello, Marius
# Hello, Sid
# Hello, Bola
```
### Mock execution of GPU Chainlets
Using `run_local()` to run your code locally requires that your development
environment have the compute resources and dependencies that each Chainlet
needs. But that often isn't possible when building with AI models.
Chains offers a workaround, mocking, to let you test the coordination and
business logic of your multi-step inference pipeline without worrying about
running the model locally.
The second example in the [getting started guide](/chains/getting-started)
implements a Truss Chain for generating poems with Phi-3.
This Chain has two Chainlets:
1. The `PhiLLM` Chainlet, which requires an NVIDIA T4 GPU.
2. The `PoemGenerator` Chainlet, which easily runs on a CPU.
If you have an NVIDIA T4 under your desk, good for you. For the rest of us, we
can mock the `PhiLLM` Chainlet, which is infeasible to run locally, so that we
can quickly test the `PoemGenerator` Chainlet.
To do this, we define a mock Phi-3 model in our `__main__` module and give it
a [`run_remote()`](/chains/concepts#run-remote-chaining-chainlets) method that
produces a test output that matches the output type we expect from the real
Chainlet. Then, we inject an instance of this mock Chainlet into our Chain:
```python poems.py
if __name__ == "__main__":
class FakePhiLLM:
def run_remote(self, prompt: str) -> str:
return f"Here's a poem about {prompt.split(" ")[-1]}"
with chains.run_local():
poem_generator = PoemGenerator(phi_llm=FakePhiLLM())
result = poem_generator.run_remote(words=["bird", "plane", "superman"])
print(result)
```
And run your Python file:
```bash
python poems.py
# ['Here's a poem about bird', 'Here's a poem about plane', 'Here's a poem about superman']
```
You may notice that the argument `phi_llm` expects a type `PhiLLM`, while we are passing it an instance of `FakePhiLLM`. These aren't the same, which should be a type error.
However, this works at runtime because we constructed `FakePhiLLM` to use the
same protocol as the real thing. We can make this explicit by defining
a `Protocol` as a type annotation:
```python
from typing import Protocol
class PhiProtocol(Protocol):
def run_remote(self, data: str) -> str:
...
```
and changing the argument type in `PoemGenerator`:
```python
@chains.mark_entrypoint
class PoemGenerator(chains.ChainletBase):
def __init__(self, phi_llm: PhiProtocol = chains.depends(PhiLLM)) -> None:
self._phi_llm = phi_llm
```
This resolves the apparent type error.
## Chains Watch
The [watch command](/chains-reference/cli#watch) (`truss chains watch`) combines
the best of local development and full deployment. `watch` lets you run on an
exact copy of the production hardware and interface but gives you live reload
that lets you test changes in seconds without creating a new deployment.
To use `truss chains watch`:
1. Push a chain in development mode (i.e. `publish` and `promote` flags are
false).
2. Run the watch command `truss chains watch SOURCE`. You can also add the
   `watch` option to the `push` command to combine both in a single step.
3. Each time you edit a file and save the changes, the watcher patches the
remote deployments. Updating the deployments might take a moment, but is
generally *much* faster than creating a new deployment.
4. You can call the chain with test data via `cURL` or the call dialogue
in the UI and observe the result and logs.
5. Iterate steps 3 and 4 until your chain behaves as desired.
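For example, a typical watch session with the "Hello World" Chain from the quickstart could look like this:
```bash
# 1. Create a development deployment (the default for push).
truss chains push hello.py
# 2. Watch for local file changes and patch them onto the running deployment.
truss chains watch hello.py
```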
## Deploy a Chain
Deploying a Chain is an atomic action that deploys every Chainlet
within the chain separately. Each Chainlet specifies its own remote
environment — hardware resources, Python and system dependencies, autoscaling
settings.
### Development
To deploy a Chain as a development deployment, run:
```sh
truss chains push ./my_chain.py
```
Where `my_chain.py` contains the entrypoint Chainlet for your Chain.
Development deployments are intended for testing and can't scale past one
replica. Each time you make a development deployment, it overwrites the existing
development deployment.
Development deployments support rapid iteration with `watch` - see [above
guide](#chains-watch).
### 🆕 Environments
To deploy a Chain to an environment, run:
```sh
truss chains push ./my_chain.py --environment {env_name}
```
Environments are intended for live traffic and have access to full
autoscaling settings. Each time you deploy to an environment, a new deployment is
created. Once the new deployment is live, it replaces the previous deployment,
which is relegated to the published deployments list.
[Learn more](/deploy/lifecycle#what-is-an-environment) about environments.
## Call a Chain's API endpoint
Once your Chain is deployed, you can call it via its API endpoint. Chains use
the same inference API as models:
* [Development endpoint](/api-reference/development-run-remote)
* [Production endpoint](/api-reference/production-run-remote)
* [🆕 Environment endpoint](/api-reference/environments-run-remote)
* [Endpoint by ID](/api-reference/deployment-run-remote)
Here's an example which calls the development deployment:
```python call_chain.py
import requests
import os
# From the Chain overview page on Baseten
# E.g. "https://chain-.api.baseten.co/development/run_remote"
CHAIN_URL = ""
baseten_api_key = os.environ["BASETEN_API_KEY"]
# JSON keys and types match the `run_remote` method signature.
data = {...}
resp = requests.post(
CHAIN_URL,
headers={"Authorization": f"Api-Key {baseten_api_key}"},
json=data,
)
print(resp.json())
```
### How to pass chain input
The data schema of the inference request corresponds to the function
signature of [`run_remote()`](/chains/concepts#run-remote-chaining-chainlets)
in your entrypoint Chainlet.
For example, for the Hello Chain, `HelloAll.run_remote()`:
```python
def run_remote(self, names: list[str]) -> str:
```
You'd pass the following JSON payload:
```json
{"names": ["Marius", "Sid", "Bola"]}
```
I.e., the keys in the JSON record match the argument names and types of
`run_remote`.
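For instance, invoking the deployed Hello Chain with that payload (where `$INVOCATION_URL` is the Chain's endpoint URL) looks like:
```bash
curl -X POST $INVOCATION_URL \
  -H "Authorization: Api-Key $BASETEN_API_KEY" \
  -d '{"names": ["Marius", "Sid", "Bola"]}'
```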
### Async chain inference
Like Truss models, Chains support async invocation. The [guide for
models](/invoke/async) largely applies - in particular for how to wrap the
input and set up the webhook to process results.
The following additional points are Chains-specific:
* Use chain-based URLs:
* `https://chain-{chain}.api.baseten.co/production/async_run_remote`
* `https://chain-{chain}.api.baseten.co/development/async_run_remote`
* `https://chain-{chain}.api.baseten.co/deployment/{deployment}/async_run_remote`.
* `https://chain-{chain}.api.baseten.co/environments/{env_name}/async_run_remote`.
* Only the entrypoint is invoked asynchronously. Internal Chainlet-Chainlet
calls are still run synchronously.
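As a rough sketch, and assuming the request body is wrapped the same way as for async model requests (a `model_input` field plus an optional `webhook_endpoint`, per the async guide linked above), an async chain call could look like this:
```python
import os
import requests

# Placeholder chain URL; replace {chain} with your chain ID.
CHAIN_ASYNC_URL = "https://chain-{chain}.api.baseten.co/production/async_run_remote"
baseten_api_key = os.environ["BASETEN_API_KEY"]

resp = requests.post(
    CHAIN_ASYNC_URL,
    headers={"Authorization": f"Api-Key {baseten_api_key}"},
    json={
        # Keys inside `model_input` match the entrypoint's `run_remote` arguments.
        "model_input": {"names": ["Marius", "Sid", "Bola"]},
        # Assumed optional webhook that receives the result when it's ready.
        "webhook_endpoint": "https://your-webhook.example/baseten-results",
    },
)
print(resp.json())  # Returns a request ID, not the final output.
```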
## Subclassing for code reuse
Sometimes you want to write one "main" implementation of a complicated inference
task, but then re-use it for similar variations. For example:
* Deploy it on different hardware and with different concurrency.
* Replace a dependency (e.g. silence detection in audio files) with a
different implementation of that step - while keeping all other processing
the same.
* Deploy the same inference flow, but exchange the model weights used, e.g. for
  a large and a small version of an LLM, or model weights fine-tuned for
  different domains.
* Add an adapter to convert between different input/output schemas.
In all of those cases, you can create lightweight subclasses of your main
chainlet.
Below are some example code snippets - they can all be combined with each other!
### Example base class
```python
import truss_chains as chains
class Preprocess2x(chains.ChainletBase):
def run_remote(self, number: int) -> int:
return 2 * number
class MyBaseChainlet(chains.ChainletBase):
remote_config = chains.RemoteConfig(
compute=chains.Compute(cpu_count=1, memory="100Mi"),
options=chains.ChainletOptions(enable_b10_tracing=True),
)
def __init__(self, preprocess=chains.depends(Preprocess2x)):
self._preprocess = preprocess
def run_remote(self, number: int) -> float:
return 1.0 / self._preprocess.run_remote(number)
# Assert base behavior.
with chains.run_local():
chainlet = MyBaseChainlet()
assert chainlet.run_remote(4) == 1 / (4 * 2)
```
### Adapter for different I/O
The base class `MyBaseChainlet` works with integer inputs and returns floats. If you
want to reuse the computation, but provide an alternative interface (e.g.
for a different client with different request/response schema), you can create
a subclass which does the I/O conversion. The actual computation is delegated to
the base class above.
```python
class ChainletStringIO(MyBaseChainlet):
def run_remote(self, number: str) -> str:
return str(super().run_remote(int(number)))
# Assert new behavior.
with chains.run_local():
chainlet_string_io = ChainletStringIO()
assert chainlet_string_io.run_remote("4") == "0.125"
```
### Chain with substituted dependency
The base class `MyBaseChainlet` uses preprocessing that doubles the input. If
you want to use a different variant of preprocessing - while keeping
`MyBaseChainlet.run_remote` and everything else as is - you can define a shallow
subclass of `MyBaseChainlet` where you use a different dependency `Preprocess8x`,
which multiplies by 8 instead of 2.
```python
class Preprocess8x(chains.ChainletBase):
def run_remote(self, number: int) -> int:
return 8 * number
class Chainlet8xPreprocess(MyBaseChainlet):
def __init__(self, preprocess=chains.depends(Preprocess8x)):
super().__init__(preprocess=preprocess)
# Assert new behavior.
with chains.run_local():
chainlet_8x_preprocess = Chainlet8xPreprocess()
assert chainlet_8x_preprocess.run_remote(4) == 1 / (4 * 8)
```
### Override remote config
If you want to re-deploy a chain, but change some deployment options, e.g. run
on different hardware, you can create a subclass and override `remote_config`.
```python
class Chainlet16Core(MyBaseChainlet):
remote_config = chains.RemoteConfig(
compute=chains.Compute(cpu_count=16, memory="100Mi"),
options=chains.ChainletOptions(enable_b10_tracing=True),
)
```
Be aware that `remote_config` is a class variable. In the example above we
created a completely new `RemoteConfig` value, because changing fields
*in place* would also affect the base class.
If you want to share config between the base class and subclasses, you can
define the shared parts in additional variables, e.g. for the docker image:
```python
DOCKER_IMAGE = chains.DockerImage(pip_requirements=[...], ...)
class MyBaseChainlet(chains.ChainletBase):
remote_config = chains.RemoteConfig(docker_image=DOCKER_IMAGE, ...)
class Chainlet16Core(MyBaseChainlet):
remote_config = chains.RemoteConfig(docker_image=DOCKER_IMAGE, ...)
```
# Overview
Chains: A new DX for deploying multi-component ML workflows
Chains is in beta mode. Read our [launch blog post](https://www.baseten.co/blog/introducing-baseten-chains/).
Chains is a framework for building robust, performant multi-step and multi-model
inference pipelines and deploying them to production. It addresses the common
challenges of managing latency, cost and dependencies for complex workflows,
while leveraging Truss’ existing battle-tested performance, reliability and
developer toolkit.
## From model to system
Some models are actually pipelines (e.g. invoking an LLM involves sequentially
tokenizing the input, predicting the next token, and then decoding the predicted
tokens). These pipelines generally make sense to bundle together in a monolithic
deployment because they have the same dependencies, require the same compute
resources, and have a robust ecosystem of tooling to improve efficiency and
performance in a single deployment.
Many other pipelines and systems do not share these properties. Some examples
include:
* Running multiple different models in sequence.
* Chunking/partitioning a set of files and concatenating/organizing results.
* Pulling inputs from or saving outputs to a database or vector store.
Each step in these workflows has different hardware requirements, software
dependencies, and scaling needs so it doesn’t make sense to bundle them in a
monolithic deployment. That’s where Chains comes in!
## Six principles behind Chains
Chains exists to help you build multi-step, multi-model pipelines. The
abstractions that Chains introduces are based on six opinionated principles:
three for architecture and three for developer experience.
**Architecture principles**
Each step in the pipeline can set its own hardware requirements and
software dependencies, separating GPU and CPU workloads.
Each component has independent autoscaling parameters for targeted
resource allocation, removing bottlenecks from your pipelines.
Components specify a single public interface for flexible-but-safe
composition and are reusable between projects
**Developer experience principles**
Eliminate entire taxonomies of bugs by writing typed Python code and
validating inputs, outputs, module initializations, function signatures,
and even remote server configurations.
Seamless local testing and cloud deployments: test Chains locally with
support for mocking the output of any step and simplify your cloud
deployment loops by separating large model deployments from quick
updates to glue code.
Use Chains to orchestrate existing model deployments, like pre-packaged
models from Baseten’s model library, alongside new model pipelines built
entirely within Chains.
## Hello World with Chains
Here’s a simple Chain that says “hello” to each person in a list of provided
names:
```python hello_chain/hello.py
import asyncio
import truss_chains as chains
# This Chainlet does the work.
class SayHello(chains.ChainletBase):
async def run_remote(self, name: str) -> str:
return f"Hello, {name}"
# This Chainlet orchestrates the work.
@chains.mark_entrypoint
class HelloAll(chains.ChainletBase):
def __init__(self, say_hello_chainlet=chains.depends(SayHello)) -> None:
self._say_hello = say_hello_chainlet
async def run_remote(self, names: list[str]) -> str:
tasks = []
for name in names:
tasks.append(asyncio.ensure_future(
self._say_hello.run_remote(name)))
return "\n".join(await asyncio.gather(*tasks))
```
This is a toy example, but it shows how Chains can be used to separate
preprocessing steps like chunking from workload execution steps. If SayHello
were an LLM instead of a simple string template, we could do a much more complex
action for each person on the list.
## What to build with Chains
Connect to a vector database and augment LLM results with additional
context information without introducing overhead to the model inference
step.
Try it yourself: [RAG Chain](/chains/examples/build-rag).
Transcribe large audio files by splitting them into smaller chunks and
processing them in parallel — we've used this approach to process 10-hour
files in minutes.
Try it yourself: [Audio Transcription Chain](/chains/examples/audio-transcription).
Build powerful experiences with optimal scaling in each step, like:
* AI phone calling (transcription + LLM + speech synthesis)
* Multi-step image generation (SDXL + LoRAs + ControlNets)
* Multimodal chat (LLM + vision + document parsing + audio)
Since each stage runs on its own hardware with independent auto-scaling,
you can achieve better hardware utilization and save costs.
Get started by
[building and deploying your first chain](/chains/getting-started).
# Autoscaling
Scale from internal testing to the top of Hacker News
Autoscaling lets you handle highly variable traffic while minimizing spend on idle compute resources.
## Autoscaling configuration
Autoscaling settings are configurable for each deployment of a model. New production deployments will inherit the autoscaling settings of the previous production deployment or be set to the default configuration if no prior production deployment exists.
Autoscaling settings can be configured in two ways:
* Manually, through the UI in your Baseten workspace.
* Programmatically, through the [autoscaling configuration management API endpoints](/api-reference/updates-a-production-deployments-autoscaling-settings).
### Min and max replicas
Every deployment can scale between a range of replicas:
* **Minimum count**: the deployment will not scale below this many active replicas.
* Lowest possible value: 0.
* Default value: 0.
* Highest possible value: the maximum replica count
* **Maximum count**: the deployment will not scale above this many active replicas.
* Lowest possible value: 1 or the minimum replica count, whichever is greater.
* Default value: 1.
* Highest possible value: 10 by default, contact us to unlock higher replica maximums.
When the model is first deployed, it will be deployed on one replica or the minimum number of replicas, whichever is greater. As it receives traffic, it will scale up to use additional replicas as necessary, up to the maximum replica count, then scale down to fewer replicas as traffic subsides.
### Autoscaler settings
There are three autoscaler settings:
* **Autoscaling window**: The timeframe of traffic considered for scaling replicas up and down. Default: 60 seconds.
* **Scale down delay**: The additional time the autoscaler waits before spinning down a replica. Default: 900 seconds (15 minutes).
* **Concurrency target**: The number of concurrent requests you want each replica to be responsible for handling. Default: 1 request.
Autoscaler settings aren't universal, but we generally recommend a shorter autoscaling window with a longer scale down delay to respond quickly to traffic spikes while maintaining capacity through variable traffic. This is reflected in the default values.
### Autoscaling in action
Here's how the autoscaler handles spikes in traffic without wasting money on unnecessary model resources:
* The autoscaler analyzes incoming traffic to your model. When the average number of requests divided by the number of active replicas exceeds the **concurrency target** for the duration of the **autoscaling window**, additional replicas are created until:
* The average requests divided by the number of active replicas drops below the **concurrency target**, or
* The **maximum count** of replicas is reached.
* When traffic dies down, fewer replicas are needed to stay below the **concurrency target**. When this has been true for the duration of the **autoscaling window**, excess replicas are marked for removal. The autoscaler waits for the **scale down delay** before gracefully spinning down any unneeded replicas. Replicas will not spin down if:
* Traffic picks back up during the **scale down delay**, or
* The deployment's **minimum count** of replicas is reached.
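To make the replica math concrete, here is an illustrative, simplified sketch of the scaling target; the real autoscaler additionally applies the autoscaling window and scale down delay described above:
```python
import math

def target_replicas(avg_concurrent_requests: float,
                    concurrency_target: int,
                    min_replicas: int,
                    max_replicas: int) -> int:
    # Enough replicas so that requests per replica stays at or below the
    # concurrency target, clamped to the configured replica range.
    desired = math.ceil(avg_concurrent_requests / concurrency_target)
    return max(min_replicas, min(desired, max_replicas))

# Example: 10 concurrent requests with a concurrency target of 2 suggests
# scaling toward 5 replicas (if the maximum count allows it).
print(target_replicas(10, 2, min_replicas=0, max_replicas=10))  # 5
```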
## Scale to zero
If you're just testing your model or anticipate light and inconsistent traffic, scale to zero can save you substantial amounts of money.
Scale to zero means that when a deployed model is not receiving traffic, it scales down to zero replicas. When the model is called, Baseten spins up a new instance to serve model requests.
To turn on scale to zero, just set a deployment's minimum replica count to zero. Scale to zero is enabled by default in the standard autoscaling config.
Models that have not received any traffic for more than two weeks will be automatically deactivated. These models will need to be activated manually before they can serve requests again.
## Cold starts
A "cold start" is the time it takes to spin up a new instance of a model server. Cold starts apply in two situations:
* When a model is scaled to zero and receives a request
* When the number of concurrent requests trigger the autoscaler to increase the number of active replicas
Cold starts are especially noticeable for scaled-to-zero models as the time to process the first request includes the cold start time. Baseten has heavily invested in reducing cold start times for all models.
### Network accelerator
Baseten uses a network accelerator to speed up model loads from common model artifact stores, including HuggingFace, CloudFront, S3, and OpenAI. Our accelerator employs byte range downloads in the background to maximize the parallelism of downloads. This improves cold start times by reducing the amount of time it takes to load model weights and other required data.
### Cold start pods
To shorten cold start times, we spin up specifically designated pods to accelerate model loading that are not counted toward your ordinary model resources. You may see these pods in your logs and metrics.
Coldboost logs have `[Coldboost]` as a prefix to signify that a cold start pod is in use:
```md Example coldboost log line
Oct 09 9:20:25pm [Coldboost] Completed model.load() execution in 12650 ms
```
### Further optimizations
Read our [how-to guide for optimizing cold starts](/performance/cold-starts) to learn how you can edit your Truss and application to reduce the impact of cold starts.
## Autoscaling for development deployments
Autoscaling settings for development deployments are optimized for live reload workflows and a simplified testing setup. The standard configuration is:
* Minimum replicas: 0.
* Maximum replicas: 1.
* Autoscaling window: 60 seconds.
* Scale down delay: 900 seconds (15 minutes).
* Concurrency target: 1 request.
Development deployments cannot scale beyond 1 replica. To unlock full autoscaling for your deployment, promote it to production.
# Deployments and environments
Deployment lifecycle on Baseten
There are two special concepts related to models on Baseten: deployments and environments. Both have different features to match their role in the model lifecycle.
| Feature | Development deployment | Production environment | Custom environments |
| --------------------- | ----------------------- | ---------------------- | ------------------------- |
| API: Deployment ID | ☑️ | ☑️ | ☑️ |
| API: Model ID | - | ☑️ | - |
| Live reload | ☑️ | - | - |
| Scale to zero | ☑️ | ☑️ | ☑️ |
| Full autoscaling | - | ☑️ | ☑️ |
| Zero-downtime updates | - | ☑️ | ☑️ |
| Deactivate | ☑️ | ☑️ | ☑️ |
| Delete | ☑️ | - | ☑️ |
## What is a development deployment?
A development deployment is designed to make it easier for you to iterate on your model. As such, development deployments have three special properties:
* Development deployments have live reload so you can patch changes onto the model server while it runs.
* Development deployments don't have access to full autoscaling. They have a maximum of one replica and always scale to zero when not in use.
* Development deployments **do not** guarantee zero-downtime updates. A development deployment may be updated at any time, which may cause active requests to fail.
* Development deployments are always named *development* and cannot be renamed.
Live reload lets you use the Truss CLI to patch changes onto your running model server, rather than waiting for an entirely new deployment.
## What is an environment?
Environments encapsulate deployments, enabling you to manage your model’s release cycles. By providing a stable URL and autoscaling settings, environments allow you to create repeatable release processes for your model, ensuring its quality, stability, and scalability before it reaches end users.
Let’s say you’ve made some changes to a model, and you want to better understand the efficacy of its outputs without changing any behavior in your user-facing app. To take advantage of environments, you can create an environment with a custom name (e.g., "staging") and promote a candidate deployment to that environment. This deployment now receives any requests you make to the “staging” environment. Now, you can verify the quality of your changes before promoting the deployment to production.
Some common methods of verifying the quality of the deployment:
* Tests/Evals
* Manual testing in pre-production environment
* Canarying/Gradual rollout
* Shadow serving traffic
When promoting a deployment to an environment, including production, there are a few key differences:
* The environment uses the [environment-specific endpoint](/api-reference/environments-predict)
* The environment has full access to autoscaling settings.
* [Traffic ramp up](/deploy/lifecycle#canary-deployments) can be enabled on the environment.
* [Metrics can be exported](/observability/export-metrics/overview) for each environment.
A production environment is just like any other environment, with a couple differences:
* A production environment is designated for production use; you can't create additional custom environments with the name "production."
* A production environment cannot be deleted (unless you delete the entire model).
### Environments API
Each model's environment comes with its own:
* [Predict endpoint](/api-reference/environments-predict)
* [Async inference endpoint](/api-reference/environments-async-predict)
* Set of management endpoints for:
* [Promoting deployments](/api-reference/promotes-a-deployment-to-an-environment)
* [Activating deployments](/api-reference/activates-a-deployment-associated-with-an-environment)
* [Deactivating deployments](/api-reference/deactivates-a-deployment-associated-with-an-environment)
* [Updating settings](/api-reference/update-an-environments-settings)
## Promotion
Any deployment can be promoted to any environment, whether it is a development deployment, a published deployment, or a deployment that's already in an environment.
* Ensure that you have created an environment. The production environment will exist by default for every model.
* Deployments can be promoted from the UI or via the [REST API](/api-reference/promotes-a-deployment-to-an-environment).
### Promoting a deployment to an environment
Promoting the development deployment to an environment triggers a three-step process:
* A new deployment is created, with a new deployment ID and name.
* The new deployment is allocated resources and started up.
* Once active, the new deployment becomes associated with the environment, replacing any previous deployment.
* If there was no previous deployment, the new deployment is created with [standard autoscaling settings](/deploy/autoscaling).
* If there was a previous deployment, the new deployment is created with the same autoscaling settings as the previous deployment. The previous deployment is demoted but keeps its ID, autoscaling settings, and is by default scaled to zero.
Promoting the development deployment to an environment does not change the development deployment's ID, autoscaling settings, or activity status. You can continue to iterate on the development deployment with live reload.
### Promoting another published deployment to an environment
When promoting an already published deployment to an environment, keep the following in mind:
* The published deployment's autoscaling settings will be updated to match the previous deployment in the environment.
* If the deployment is inactive, you must activate it and wait for it to start up before promoting it.
The previous deployment is demoted and joins other deployments in the deployment list, but keeps its deployment ID, autoscaling settings, and is by default scaled to zero.
### Deploying directly to an environment
You can deploy a model directly to an environment, skipping the development stage and starting a promotion to any existing environment, by adding `--environment` to `truss push`:
```sh
cd my_model/
truss push --environment {environment_name}
```
There can only be one active promotion per environment at any given time.
### Canary deployments
Canary deployments allow you to ramp up traffic to existing environments, ensuring a smooth transition with minimal disruption to ongoing traffic.
Once this is enabled and a new deployment is promoted, traffic is shifted in 10 evenly-spaced stages over a configurable time window.
This gradual shift allows the deployment to scale in response to real-time demand, guided by your autoscaling settings, thus maintaining stability for existing users.
The traffic ramp-up can be enabled via the UI or [REST API](/api-reference/update-an-environments-settings).
If you cancel it, incoming traffic will revert to your existing deployment.
![UI for traffic ramp up (canary deployments)](https://mintlify.s3.us-west-1.amazonaws.com/baseten-preview/deploy/images/ramp.png)
Check out our [launch blog](https://www.baseten.co/blog/canary-deployments-on-baseten/) for more information.
### Accessing environments in your code
You can access the environment name from the `environment` keyword argument in the [`init` function of your `model.py`](../truss/guides/environments).
In order to guarantee that this variable is up to date, you should set the "Re-deploy when promoting" option for your environment. This can be done via the UI or [REST API](/api-reference/update-an-environments-settings).
When this option is enabled, all promotions to the environment will trigger a re-deployment. The image of the source deployment will be copied and deployed to the environment.
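For example, a minimal sketch of reading the environment name inside `model.py` (using the `environment` keyword argument described above; the staging check is just an illustration):
```python
class Model:
    def __init__(self, **kwargs):
        # Provided by Baseten when the deployment runs in an environment,
        # e.g. "production" or "staging"; may be absent otherwise.
        self._environment = kwargs.get("environment")

    def load(self):
        pass

    def predict(self, model_input):
        if self._environment == "staging":
            # Hypothetical example: extra logging outside of production.
            print("Running in the staging environment")
        return model_input
```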
## Deactivating deployments
Any active deployment, including those in environments, can be deactivated.
* Deactivated deployments remain visible in the model dashboard.
* Deactivated deployments do not consume model resources.
* Requests to a deactivated deployment's endpoints will result in a 404 error.
* A deactivated deployment can be manually activated at any time from the model dashboard.
If you're regularly activating or deactivating deployments in response to traffic, consider using [autoscaling with scale to zero](/deploy/autoscaling) instead.
## Deleting deployments and environments
Any deployment and environment of a model can be deleted except for production. To delete the production deployment, first promote a different deployment to production (or delete the entire model).
* Deleted deployments are removed from the model dashboard but will appear in your billing and usage dashboard.
* Deleted deployments do not consume model resources.
* Requests to a deleted deployment's endpoints will result in a 404 error.
* Deleting a deployment is a permanent action and cannot be undone.
If you aren't completely certain about deleting a deployment, consider deactivating it instead until you're ready to delete.
# Setting GPU resources
Serve your model on the right instance type
Every ML model served on Baseten runs on an "instance," meaning a set of hardware allocated to the model server. You can choose how much hardware is allocated to your model server by choosing an instance type for your model.
Picking the right model server instance is all about making smart tradeoffs:
* If you don't allocate enough resources, model deployment and inference will be slow or may fail altogether.
* Picking an instance that's too big leaves you paying for unnecessary overhead.
This document will help you navigate that tradeoff and pick the appropriate instance when deploying your model.
## Glossary
* **Instance**: a fixed set of hardware for running ML model inference.
* **vCPU**: virtual CPU cores for general computing tasks.
* **RAM**: memory for the CPU.
* **GPU**: the graphics card for ML inference tasks.
* **VRAM**: specialized memory attached to the GPU.
## Setting model server resources
There are two ways to specify model server resources:
* Before initial deployment in your Truss.
* After initial deployment in the Baseten UI.
### In your Truss
You can specify resources in your Truss. You must configure these resources **before** running `truss push` on the Truss for the first time; any changes to the resources field after the first deployment will be ignored.
Here's an example for Stable Diffusion XL:
```yaml config.yaml
resources:
  accelerator: A10G
  cpu: "4"
  memory: 16Gi
  use_gpu: true
```
On deployment, your model will be assigned the smallest and cheapest available instance type that satisfies the resource constraints. For example, for `resources.cpu`, a Truss that specifies `"3"` or `"4"` will be assigned a 4-core instance, while specifying `"5"`, `"6"`, `"7"`, or `"8"` will yield an 8-core instance.
The `Gi` in `resources.memory` stands for gibibytes, which are slightly larger than gigabytes.
### In the model dashboard
After the model has been deployed, the only way to update the instance type it uses is in the model dashboard on Baseten.
For more information on picking the right model resources, see the [instance type reference](/performance/instances).
# Troubleshooting
Fixing common problems during model deployment
## Issue: `truss push` can't find `config.yaml`
```sh
[Errno 2] No such file or directory: '/Users/philipkiely/Code/demo_docs/config.yaml'
```
### Fix: set correct target directory
The directory `truss push` is looking at is not a Truss. Make sure you're giving `truss push` access to the correct directory by:
* Running `truss push` from the directory containing the Truss. You should see the file `config.yaml` when you run `ls` in your working directory.
* Or passing the target directory as an argument, such as `truss push /path/to/my-truss`.
## Issue: unexpected failure during model build
During the model build step, transient issues can cause unexpected failures, such as a network error while downloading model weights from Hugging Face or installing a Python package from PyPI.
### Fix: restart deploy from Baseten UI
First, check your model logs to determine the exact cause of the error. If it's an error during model download, package installation, or similar, you can try restarting the deploy from the model dashboard in your workspace.
# Async inference user guide
Run asynchronous inference on deployed models
Async requests are a "fire and forget" way of executing model inference requests. Instead of waiting for a response from a model, making an async request queues the request and immediately returns a request identifier. Optionally, async request results are sent via a `POST` request to a user-defined webhook upon completion.
Use async requests for:
* Long-running inference tasks that may otherwise hit request timeouts.
* Batched inference jobs.
* Prioritizing certain inference requests.
Async fast facts:
* Async requests can be made to any model—**no model code changes necessary**.
* Async requests can remain queued for up to 72 hours and run for up to 1 hour.
* Async requests are **not** compatible with streaming model output.
* Async request inputs and model outputs are **not** stored after an async request has been completed. Instead, model outputs will be sent to your webhook via a `POST` request.
## Quick start
There are two ways to use async inference:
1. Provide a webhook endpoint where model outputs will be sent via a `POST` request. If providing a webhook, you can **use async inference on any model, without making any changes to your model code**.
2. Inside your Truss' `model.py`, save prediction results to cloud storage. If a webhook endpoint is provided, your model outputs will also be sent to your webhook.
Note that Baseten **does not** store model outputs. If you do not wish to use a webhook, your `model.py` must write model outputs to a cloud storage bucket or database as part of its implementation.
Set up a webhook endpoint for handling completed async requests. Since Baseten doesn't store model outputs, model outputs from async requests will be sent to your webhook endpoint.
Before creating your first async request, try running a sample request against your webhook endpoint to ensure that it can consume async predict results properly. Check out [this example webhook test](https://replit.com/@baseten-team/Baseten-Async-Inference-Starter-Code#test_webhook.py).
We recommend using [this Repl](https://replit.com/@baseten-team/Baseten-Async-Inference-Starter-Code) as a starting point.
Call `/async_predict` on your model. The body of an `/async_predict` request includes the model input in the `model_input` field and the webhook endpoint (from the previous step) in the `webhook_endpoint` field.
```py Python
import requests
import os
model_id = "" # Replace this with your model ID
webhook_endpoint = "" # Replace this with your webhook endpoint URL
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]
# Call the async_predict endpoint of the production deployment
resp = requests.post(
f"https://model-{model_id}.api.baseten.co/production/async_predict",
headers={"Authorization": f"Api-Key {baseten_api_key}"},
json={
"model_input": {"prompt": "hello world!"},
"webhook_endpoint": webhook_endpoint
# Optional fields for priority, max_time_in_queue_seconds, etc
},
)
print(resp.json())
```
Save the `request_id` from the `/async_predict` response to check its status or cancel it.
```json 201
{
"request_id": "9876543210abcdef1234567890fedcba"
}
```
See the [async inference API reference](/api-reference/production-async-predict) for more endpoint details.
Using the `request_id` saved from the previous step, check the status of your async predict request:
```py Python
import requests
import os
model_id = ""
request_id = ""
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]
resp = requests.get(
f"https://model-{model_id}.api.baseten.co/async_request/{request_id}",
headers={"Authorization": f"Api-Key {baseten_api_key}"}
)
print(resp.json())
```
Once your model has finished executing the request, the async predict result will be sent to your webhook in a `POST` request.
```json
{
"request_id": "9876543210abcdef1234567890fedcba",
"model_id": "my_model_id",
"deployment_id": "my_deployment_id",
"type": "async_request_completed",
"time": "2024-04-30T01:01:08.883423Z",
"data": {
"my_model_output": "hello world!"
},
"errors": []
}
```
We strongly recommend securing the requests sent to your webhooks to validate that they are from Baseten.
For instructions, see our [guide to securing async requests](/invoke/async-secure).
Update your Truss's `model.py` to save prediction results to cloud storage, such as S3 or GCS. We recommend implementing this in your model's `postprocess()` method, which will run on CPU after the prediction has completed.
Optionally, set up a webhook endpoint so Baseten can notify you when your async request completes.
Before creating your first async request, try running a sample request against your webhook endpoint to ensure that it can consume async predict results properly. Check out [this example webhook test](https://replit.com/@baseten-team/Baseten-Async-Inference-Starter-Code#test_webhook.py).
We recommend using [this Repl](https://replit.com/@baseten-team/Baseten-Async-Inference-Starter-Code) as a starting point.
Call `/async_predict` on your model. The body of an `/async_predict` request includes the model input in the `model_input` field and the webhook endpoint (from the previous step) in the `webhook_endpoint` field.
```py Python
import requests
import os
model_id = "" # Replace this with your model ID
webhook_endpoint = "" # Replace this with your webhook endpoint URL
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]
# Call the async_predict endpoint of the production deployment
resp = requests.post(
f"https://model-{model_id}.api.baseten.co/production/async_predict",
headers={"Authorization": f"Api-Key {baseten_api_key}"},
json={
"model_input": {"prompt": "hello world!"},
"webhook_endpoint": webhook_endpoint
# Optional fields for priority, max_time_in_queue_seconds, etc
},
)
print(resp.json())
```
Save the `request_id` from the `/async_predict` response to check its status or cancel it.
```json 201
{
"request_id": "9876543210abcdef1234567890fedcba"
}
```
See the [async inference API reference](/api-reference/production-async-predict) for more endpoint details.
Using the `request_id` saved from the previous step, check the status of your async predict request:
```py Python
import requests
import os
model_id = ""
request_id = ""
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]
resp = requests.get(
f"https://model-{model_id}.api.baseten.co/async_request/{request_id}",
headers={"Authorization": f"Api-Key {baseten_api_key}"}
)
print(resp.json())
```
Once your model has finished executing the request, the async predict result will be sent to your webhook in a `POST` request.
```json
{
"request_id": "9876543210abcdef1234567890fedcba",
"model_id": "my_model_id",
"deployment_id": "my_deployment_id",
"type": "async_request_completed",
"time": "2024-04-30T01:01:08.883423Z",
"data": {
"my_model_output": "hello world!"
},
"errors": []
}
```
We strongly recommend securing the requests sent to your webhooks to validate that they are from Baseten.
For instructions, see our [guide to securing async requests](/invoke/async-secure).
**Chains**: this guide is written for Truss models, but [Chains](/chains/overview) also support async inference. A Chain
entrypoint can be invoked via its `async_run_remote` endpoint, e.g.
`https://chain-{chain_id}.api.baseten.co/production/async_run_remote`.
Internal Chainlet-to-Chainlet calls will still run synchronously.
## User guide
### Configuring the webhook endpoint
Configure your webhook endpoint to handle `POST` requests with [async predict results](/invoke/async#processing-async-predict-results). We require that webhook endpoints use HTTPS.
We recommend running a sample request against your webhook endpoint to ensure that it can consume async predict results properly. Try running [this webhook test](https://replit.com/@baseten-team/Baseten-Async-Inference-Starter-Code#test_webhook.py).
For local development, we recommend using [this Repl](https://replit.com/@baseten-team/Baseten-Async-Inference-Starter-Code) as a starting point. This code validates the webhook request and logs the payload.
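As a reference point, a minimal webhook handler might look like the sketch below. FastAPI is used here purely as an example framework, and signature validation (see [securing async requests](/invoke/async-secure)) is omitted for brevity:
```python
from fastapi import FastAPI, Request
app = FastAPI()
@app.post("/webhook")
async def handle_async_result(request: Request):
    result = await request.json()
    # `result` follows the async predict result schema described later in this guide
    print(result["request_id"], result.get("errors"))
    # Process or persist result["data"] here
    return {"status": "received"}
```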
### Making async requests
```py Python
import requests
import os
model_id = "" # Replace this with your model ID
webhook_endpoint = "" # Replace this with your webhook endpoint URL
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]
# Call the async_predict endpoint of the production deployment
resp = requests.post(
f"https://model-{model_id}.api.baseten.co/production/async_predict",
headers={"Authorization": f"Api-Key {baseten_api_key}"},
json={
"model_input": {"prompt": "hello world!"},
"webhook_endpoint": webhook_endpoint
# Optional fields for priority, max_time_in_queue_seconds, etc
},
)
print(resp.json())
```
```py Python
import requests
import os
model_id = "" # Replace this with your model ID
webhook_endpoint = "" # Replace this with your webhook endpoint URL
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]
# Call the async_predict endpoint of the development deployment
resp = requests.post(
f"https://model-{model_id}.api.baseten.co/development/async_predict",
headers={"Authorization": f"Api-Key {baseten_api_key}"},
json={
"model_input": {"prompt": "hello world!"},
"webhook_endpoint": webhook_endpoint
# Optional fields for priority, max_time_in_queue_seconds, etc
},
)
print(resp.json())
```
```py Python
import requests
import os
model_id = "" # Replace this with your model ID
deployment_id = "" # Replace this with your deployment ID
webhook_endpoint = "" # Replace this with your webhook endpoint URL
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]
# Call the async_predict endpoint of the given deployment
resp = requests.post(
f"https://model-{model_id}.api.baseten.co/deployment/{deployment_id}/async_predict",
headers={"Authorization": f"Api-Key {baseten_api_key}"},
json={
"model_input": {"prompt": "hello world!"},
"webhook_endpoint": webhook_endpoint
# Optional fields for priority, max_time_in_queue_seconds, etc
},
)
print(resp.json())
```
Create an async request by calling a model's `/async_predict` endpoint. See the [async inference API reference](/api-reference/production-async-predict) for more endpoint details.
### Getting and canceling async requests
You may get the status of an async request for up to 1 hour after the request has been completed.
```py Python
import requests
import os
model_id = ""
request_id = ""
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]
resp = requests.get(
f"https://model-{model_id}.api.baseten.co/async_request/{request_id}",
headers={"Authorization": f"Api-Key {baseten_api_key}"}
)
print(resp.json())
```
```py Python
import requests
import os
model_id = ""
request_id = ""
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]
resp = requests.delete(
f"https://model-{model_id}.api.baseten.co/async_request/{request_id}",
headers={"Authorization": f"Api-Key {baseten_api_key}"}
)
print(resp.json())
```
Manage async requests using the [get async request API endpoint](/api-reference/get-async-request-status) and the [cancel async request API endpoint](/api-reference/cancel-async-request).
### Processing async predict results
Baseten does not store async predict results. Ensure that prediction outputs are either processed by your webhook, or saved to cloud storage in your model code (for example, in your model's `postprocess` method).
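For example, a minimal sketch of a `postprocess()` method that uploads outputs to S3 might look like this (`boto3`, the bucket name, and the key format are illustrative assumptions, not requirements):
```python
import json
import uuid
import boto3
def postprocess(self, model_output):
    # Runs on CPU after predict() completes; persist the output here
    s3 = boto3.client("s3")  # credentials e.g. via Truss secrets or an instance role
    key = f"predictions/{uuid.uuid4()}.json"
    s3.put_object(
        Bucket="my-model-outputs",  # hypothetical bucket name
        Key=key,
        Body=json.dumps(model_output),
    )
    # Return a small payload; Baseten does not store model outputs
    return {"output_location": f"s3://my-model-outputs/{key}"}
```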
If a webhook endpoint was provided in the `/async_predict` request, the async predict results will be sent in a `POST` request to the webhook endpoint. Errors in executing the async prediction will be included in the `errors` field of the async predict result.
Async predict result schema:
* `request_id` (string): the ID of the completed async request. This matches the `request_id` field of the `/async_predict` response.
* `model_id` (string): the ID of the model that executed the request
* `deployment_id` (string): the ID of the deployment that executed the request
* `type` (string): the type of the async predict result. This will always be `"async_request_completed"`, even in error cases.
* `time` (datetime): the time in UTC at which the request was sent to the webhook
* `data` (dict or string): the prediction output
* `errors` (list): any errors that occurred in processing the async request
Example async predict result:
```json
{
"request_id": "9876543210abcdef1234567890fedcba",
"model_id": "my_model_id",
"deployment_id": "my_deployment_id",
"type": "async_request_completed",
"time": "2024-04-30T01:01:08.883423Z",
"data": {
"my_model_output": "hello world!"
},
"errors": []
}
```
## Observability
Metrics for async request execution are available on the [Metrics tab](/observability/metrics#time-in-async-queue) of your model dashboard.
* Async requests are included in inference latency and volume metrics.
* A time in async queue chart displays the time an async predict request spent in the `QUEUED` state before being processed by the model.
* An async queue size chart displays the current number of queued async predict requests.
# Securing async inference
Secure the asynchronous inference results sent to your webhook
Since async predict results are sent to a webhook endpoint that anyone on the internet can reach, you'll want to verify that the results sent to your webhook are actually coming from Baseten.
We recommend leveraging webhook signatures to secure webhook payloads and ensure they are from Baseten.
This is a two-step process:
1. Create a webhook secret.
2. Validate a webhook signature sent as a header along with the webhook request payload.
## Creating webhook secrets
Webhook secrets can be generated via the [Secrets tab](https://app.baseten.co/settings/secrets).
A webhook secret looks like:
```
whsec_AbCdEf123456GhIjKlMnOpQrStUvWxYz12345678
```
Ensure this webhook secret is saved securely. It can be viewed at any time and [rotated if necessary](/invoke/async-secure#keeping-webhook-secrets-secure) in the Secrets tab.
## Validating webhook signatures
If a webhook secret exists, Baseten will include a webhook signature in the `"X-BASETEN-SIGNATURE"` header of the webhook request so you can verify that it is coming from Baseten.
A Baseten signature header looks like:
`"X-BASETEN-SIGNATURE": "v1=signature"`
Where `signature` is an [HMAC](https://docs.python.org/3.12/library/hmac.html#module-hmac) generated using a [SHA-256](https://en.wikipedia.org/wiki/SHA-2) hash function calculated over the whole async predict result and signed using a webhook secret.
If multiple webhook secrets are active, a signature will be generated using each webhook secret. In the example below, the newer webhook secret was used to create `newsignature` and the older (soon to expire) webhook secret was used to create `oldsignature`.
`"X-BASETEN-SIGNATURE": "v1=newsignature,v1=oldsignature"`
To validate a Baseten signature, we recommend the following. A full Baseten signature validation example can be found in [this Repl](https://replit.com/@baseten-team/Baseten-Async-Inference-Starter-Code#validation.py).
Compare the async predict result timestamp with the current time and decide if it was received within an acceptable tolerance window.
```python
from datetime import datetime, timezone
import logging
TIMESTAMP_TOLERANCE_SECONDS = 300
# Check timestamp in async predict result against current time to ensure it's within our tolerance
if (datetime.now(timezone.utc) - async_predict_result.time).total_seconds() > TIMESTAMP_TOLERANCE_SECONDS:
    logging.error(
        f"Async predict result was received after {TIMESTAMP_TOLERANCE_SECONDS} seconds and is considered stale; Baseten signature was not validated."
    )
```
Recreate the Baseten signature using webhook secret(s) and the async predict result.
```python
import hashlib
import hmac
WEBHOOK_SECRETS = []  # Add your webhook secrets here
async_predict_result_json = async_predict_result.model_dump_json()
# We recompute the expected Baseten signature with each webhook secret
for webhook_secret in WEBHOOK_SECRETS:
    for actual_signature in baseten_signature.replace("v1=", "").split(","):
        expected_signature = hmac.digest(
            webhook_secret.encode("utf-8"),
            async_predict_result_json.encode("utf-8"),
            hashlib.sha256,
        ).hex()
```
Compare the expected Baseten signature with the actual computed signature using [`compare_digest`](https://docs.python.org/3/library/hmac.html#hmac.compare_digest), which will return a boolean representing whether the signatures are indeed the same.
```python
hmac.compare_digest(expected_signature, actual_signature)
```
## Keeping webhook secrets secure
We recommend periodically rotating webhook secrets.
In the event that a webhook secret is exposed, you're able to rotate or remove it.
Rotating a secret in the UI will set the existing webhook secret to expire in 24 hours, and generate a new webhook secret. During this period, Baseten will include multiple signatures in the signature headers.
Removing a webhook secret could cause your signature validation to fail. If you delete a secret, recreate one and make sure your signature validation code is updated with the new webhook secret.
# How to parse base64 output
Decode and save model output
Text-to-image and text-to-audio models like [Stable Diffusion XL](https://www.baseten.co/library/stable-diffusion-xl) and [MusicGen](https://www.baseten.co/library/musicgen-large) return the image or audio they create as base64-encoded strings, which then need to be parsed and saved as files. This guide provides examples for working with base64 output from these models.
## Example: Parsing Stable Diffusion output into a file
To follow this example, deploy [Stable Diffusion XL from the model library](https://www.baseten.co/library/stable-diffusion-xl).
### Python invocation
In this example, we'll use a Python script to call the model and parse the response.
```python call.py
import urllib3
import base64
import os, sys
# Model ID for production deployment
model_id = ""
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]
# Call the model
resp = urllib3.request(
"POST",
# Endpoint for production deployment, see API reference for more
f"https://model-{model_id}.api.baseten.co/production/predict",
headers={"Authorization": f"Api-Key {baseten_api_key}"},
json={"prompt": "A tree in a field under the night sky"}
)
image = resp.json()["data"]
# Decode image from base64 model output
img = base64.b64decode(image)
# Give the file a random name using the base64 string
file_name = f'{image[-10:].replace("/", "")}.jpeg'
# Save image to file
with open(file_name, "wb") as img_file:
    img_file.write(img)
```
### Truss CLI invocation
You can also use the Truss CLI and pipe the results into a similar Python script.
Command line:
```sh
truss predict -d '{"prompt": "A tree in a field under the night sky"}' | python save.py
```
Script:
```python save.py
import json
import base64
import sys
# Read piped input from truss predict
resp = json.loads(sys.stdin.read())
image = resp["data"]
# Decode image from base64 model output
img = base64.b64decode(image)
# Give the file a random name using the base64 string
file_name = f'{image[-10:].replace("/", "")}.jpeg'
# Save image to file
with open(file_name, "wb") as img_file:
    img_file.write(img)
```
## Example: Parsing MusicGen output into multiple files
To follow this example, deploy [MusicGen from the model library](https://www.baseten.co/library/musicgen-large).
### Python invocation
In this example, we'll use a Python script to call the model and parse the response.
```python call.py
import urllib3
import base64
import os, sys
# Model ID for production deployment
model_id = ""
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]
# Call the model
resp = urllib3.request(
"POST",
# Endpoint for production deployment, see API reference for more
f"https://model-{model_id}.api.baseten.co/production/predict",
headers={"Authorization": f"Api-Key {baseten_api_key}"},
json={"prompts": ["happy rock", "energetic EDM", "sad jazz"], "duration": 8}
)
clips = resp.json()["data"]
# Decode clips from base64 and save output to files
for idx, clip in enumerate(clips):
    with open(f"clip_{idx}.wav", "wb") as f:
        f.write(base64.b64decode(clip))
```
### Truss CLI invocation
You can also use the Truss CLI and pipe the results into a similar Python script.
Command line:
```sh
truss predict -d '{"prompts": ["happy rock", "energetic EDM", "sad jazz"], "duration": 8}' | python save.py
```
Script:
```python save.py
import json
import base64
import sys
# Read piped input from truss predict
resp = json.loads(sys.stdin.read())
clips = resp["data"]
# Decode clips from base64 and save output to files
for idx, clip in enumerate(clips):
    with open(f"clip_{idx}.wav", "wb") as f:
        f.write(base64.b64decode(clip))
```
# How to do model I/O in binary
Decode and save binary model output
Baseten and Truss natively support model I/O in binary and use msgpack encoding for efficiency.
## Deploy a basic Truss for binary I/O
If you need a deployed model to try the invocation examples below, follow these steps to create and deploy a super basic Truss that accepts and returns binary data. The Truss performs no operations and is purely illustrative.
To create a Truss, run:
```sh
truss init binary_test
```
This creates a Truss in a new directory `binary_test`. By default, newly created Trusses implement an identity function that returns the exact input they are given.
Optionally, modify `binary_test/model/model.py` to log that the data received is of type `bytes`:
```python binary_test/model/model.py
def predict(self, model_input):
    # Run model inference here
    print(f"Input type: {type(model_input['byte_data'])}")
    return model_input
```
Deploy the Truss to Baseten with:
```sh
truss push
```
## Send raw bytes as model input
To send binary data as model input:
1. Set the `content-type` HTTP header to `application/octet-stream`
2. Use `msgpack` to encode the data or file
3. Make a POST request to the model
This code sample assumes you have a file `Gettysburg.mp3` in the current working directory. You can download the [11-second file from our CDN](https://cdn.baseten.co/docs/production/Gettysburg.mp3) or replace it with your own file.
```python call_model.py
import os
import requests
import msgpack
model_id = "MODEL_ID" # Replace with your model ID
deployment = "development" # `development`, `production`, or a deployment ID
baseten_api_key = os.environ["BASETEN_API_KEY"]
# Specify the URL to which you want to send the POST request
url = f"https://model-{model_id}.api.baseten.co/{deployment}/predict"
headers = {
    "Authorization": f"Api-Key {baseten_api_key}",
    "content-type": "application/octet-stream",
}
with open('Gettysburg.mp3', 'rb') as file:
    response = requests.post(
        url,
        headers=headers,
        data=msgpack.packb({'byte_data': file.read()})
    )
print(response.status_code)
print(response.headers)
```
To support certain types like numpy and datetime values, you may need to extend client-side `msgpack` encoding with the same [encoder and decoder used by Truss](https://github.com/basetenlabs/truss/blob/main/truss/templates/shared/serialization.py).
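As an illustration, a client-side sketch of extending `msgpack` for numpy arrays could look like this (the encoding scheme below is made up for the example; for full compatibility with Truss, reuse the exact encoder and decoder linked above):
```python
import msgpack
import numpy as np
def encode_custom(obj):
    # Teach msgpack how to pack numpy arrays (illustrative scheme)
    if isinstance(obj, np.ndarray):
        return {"__ndarray__": obj.tobytes(), "dtype": str(obj.dtype), "shape": list(obj.shape)}
    return obj
def decode_custom(obj):
    # Rebuild numpy arrays on the way out
    if "__ndarray__" in obj:
        return np.frombuffer(obj["__ndarray__"], dtype=obj["dtype"]).reshape(obj["shape"])
    return obj
payload = msgpack.packb({"array": np.zeros((2, 2))}, default=encode_custom)
restored = msgpack.unpackb(payload, object_hook=decode_custom)
```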
## Parse raw bytes from model output
To use the output of a non-streaming model response, decode the response content.
```python call_model.py
# Continues `call_model.py` from above
binary_output = msgpack.unpackb(response.content)
# Change extension if not working with mp3 data
with open('output.mp3', 'wb') as file:
    file.write(binary_output["byte_data"])
```
## Streaming binary outputs
You can also stream output as binary. This is useful for sending large files or reading binary output as it is generated.
In the `model.py`, you must create a streaming output.
```python model/model.py
# Replace the predict function in your Truss
def predict(self, model_input):
    import os
    current_dir = os.path.dirname(__file__)
    file_path = os.path.join(current_dir, "tmpfile.txt")
    with open(file_path, mode="wb") as file:
        file.write(bytes(model_input["text"], encoding="utf-8"))
    def iterfile():
        # Get the directory of the current file
        current_dir = os.path.dirname(__file__)
        # Construct the full path to the temp file
        file_path = os.path.join(current_dir, "tmpfile.txt")
        with open(file_path, mode="rb") as file_like:
            yield from file_like
    return iterfile()
```
Then, in your client, you can use streaming output directly without decoding.
```python stream_model.py
import os
import requests
import json
model_id = "MODEL_ID" # Replace with your model ID
deployment = "development" # `development`, `production`, or a deployment ID
baseten_api_key = os.environ["BASETEN_API_KEY"]
# Specify the URL to which you want to send the POST request
url = f"https://model-{model_id}.api.baseten.co/{deployment}/predict"
headers = {
    "Authorization": f"Api-Key {baseten_api_key}",
}
s = requests.Session()
with s.post(
    url,
    headers=headers,
    data=json.dumps({"text": "Lorem Ipsum"}),
    # Include stream=True as an argument so the requests library knows to stream
    stream=True,
) as response:
    for token in response.iter_content(1):
        print(token)  # Prints bytes
```
# How to do model I/O with files
Call models by passing a file or URL
Baseten supports a wide variety of file-based I/O approaches. These examples show our recommendations for working with files during model inference, whether local or remote, public or private, in the Truss or in your invocation code.
## Files as input
### Example: Send a file with JSON-serializable content
The Truss CLI has a `-f` flag to pass file input. If you're using the API endpoint via Python, get file contents with the standard `f.read()` function.
```sh Truss CLI
truss predict -f input.json
```
```python Python script
import urllib3
import json
import os
model_id = ""
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]
# Read input as JSON
with open("input.json", "r") as f:
    data = json.loads(f.read())
resp = urllib3.request(
"POST",
# Endpoint for production deployment, see API reference for more
f"https://model-{model_id}.api.baseten.co/production/predict",
headers={"Authorization": f"Api-Key {baseten_api_key}"},
json=data
)
print(resp.json())
```
### Example: Send a file with non-serializable content
The `-f` flag for `truss predict` only applies to JSON-serializable content. For other files, like the audio files required by [MusicGen Melody](https://www.baseten.co/library/musicgen-melody), the file content needs to be base64 encoded before it is sent.
```python
import base64
import os
import urllib3
model_id = ""
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]
# Open a local file
with open("mymelody.wav", "rb") as f:  # mono wav file, 48kHz sample rate
    # Convert file contents into JSON-serializable format
    encoded_data = base64.b64encode(f.read())
    encoded_str = encoded_data.decode("utf-8")
# Define the data payload
data = {"prompts": ["happy rock", "energetic EDM", "sad jazz"], "melody": encoded_str, "duration": 8}
# Make the POST request
resp = urllib3.request(
    "POST",
    # Endpoint for production deployment, see API reference for more
    f"https://model-{model_id}.api.baseten.co/production/predict",
    headers={"Authorization": f"Api-Key {baseten_api_key}"},
    json=data
)
data = resp.json()["data"]
# Save output to files
for idx, clip in enumerate(data):
    with open(f"clip_{idx}.wav", "wb") as f:
        f.write(base64.b64decode(clip))
```
### Example: Send a URL to a public file
Rather than encoding and serializing a file to send in the HTTP request, you can instead write a Truss that takes a URL as input and loads the content in the `preprocess()` function.
Here's an example from [Whisper in the model library](https://www.baseten.co/library/whisper-v3).
```python
from tempfile import NamedTemporaryFile
import requests
# Get file content without blocking GPU
def preprocess(self, request):
    resp = requests.get(request["url"])
    return {"content": resp.content}
# Use file content in model inference
def predict(self, model_input):
    with NamedTemporaryFile() as fp:
        fp.write(model_input["content"])
        result = whisper.transcribe(
            self._model,
            fp.name,
            temperature=0,
            best_of=5,
            beam_size=5,
        )
        segments = [
            {"start": r["start"], "end": r["end"], "text": r["text"]}
            for r in result["segments"]
        ]
    return {
        "language": whisper.tokenizer.LANGUAGES[result["language"]],
        "segments": segments,
        "text": result["text"],
    }
```
## Files as output
### Example: Save model output to local file
When saving model output to a local file, there's nothing Baseten-specific about the code. Just use the standard `>` operator in bash or `file.write()` function in Python to save the model output.
```sh Truss CLI
truss predict -d '"Model input!"' > output.json
```
```python Python script
import urllib3
import json
import os
model_id = ""
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]
# Call model
resp = urllib3.request(
    "POST",
    # Endpoint for production deployment, see API reference for more
    f"https://model-{model_id}.api.baseten.co/production/predict",
    headers={"Authorization": f"Api-Key {baseten_api_key}"},
    json="Model input!"
)
# Write results to file
with open("output.json", "w") as f:
    json.dump(resp.json(), f)
```
Output for some models, like image and audio generation models, may need to be decoded before you save it. See [how to parse base64 output](/invoke/base64) for detailed examples.
{/*
### Example: Save model output to remote file
TODO: an example using a post-process function to save output to a file and upload it to a service
## Working with private files
TODO: Explain how to use secrets to securely authenticate with file hosts.
TODO: an example that modifies the pre-process function in Whisper Truss to read a file stored in an S3 bucket, secured by secrets. */}
# Function calling (tool use)
Use an LLM to select amongst provided tools
Function calling requires an LLM deployed using the [TensorRT-LLM Engine Builder](/performance/engine-builder-overview).
If you want to try this function calling example code for yourself, deploy [this implementation of Llama 3.1 8B](/performance/examples/llama-trt).
To use function calling:
1. Define a set of functions/tools in Python.
2. Pass the function set to the LLM with the `tools` argument.
3. Receive selected function(s) as output.
With function calling, it's essential to understand that **the LLM itself is not capable of executing the code in the function**. Instead, the LLM is used to suggest appropriate function(s), if they exist, based on the prompt. Any code execution must be handled outside of the LLM call – a great use for [chains](/chains/overview).
## Define functions in Python
Functions can be anything: API calls, ORM access, SQL queries, or just a script. It's essential that functions are well-documented; the LLM relies on the docstrings to select the correct function.
As a simple example, consider the four basic functions of a calculator:
```python
def multiply(a: float, b: float):
    """
    A function that multiplies two numbers
    Args:
        a: The first number to multiply
        b: The second number to multiply
    """
    return a * b
def divide(a: float, b: float):
    """
    A function that divides two numbers
    Args:
        a: The dividend
        b: The divisor
    """
    return a / b
def add(a: float, b: float):
    """
    A function that adds two numbers
    Args:
        a: The first number
        b: The second number
    """
    return a + b
def subtract(a: float, b: float):
    """
    A function that subtracts two numbers
    Args:
        a: The number to subtract from
        b: The number to subtract
    """
    return a - b
```
These functions must be serialized into LLM-accessible tools:
```python
from transformers.utils import get_json_schema
calculator_functions = {
'multiply': multiply,
'divide': divide,
'add': add,
'subtract': subtract
}
tools = [get_json_schema(f) for f in calculator_functions.values()]
```
## Pass functions to the LLM
The input spec for models like Llama 3.1 includes a `tools` key that we use to pass the functions:
```python
import json
import requests
payload = {
"messages": [
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "What is 3.14+3.14?"},
],
"tools": tools, # tools are provided in the same format as OpenAI's API
"tool_choice": "auto", # auto is default - the model will choose whether or not to make a function call
}
MODEL_ID = ""
BASETEN_API_KEY = ""
resp = requests.post(
f"https://model-{MODEL_ID}.api.baseten.co/production/predict",
headers={"Authorization": f"Api-Key {BASETEN_API_KEY}"},
json=payload,
)
```
### tool\_choice: auto (default) – may return a function
The default `tool_choice` option, `auto`, leaves it up to the LLM whether to return one function, multiple functions, or no functions at all, depending on what the model feels is most appropriate based on the prompt.
### tool\_choice: required – will always return a function
The `required` option for `tool_choice` means that the LLM is guaranteed to choose at least one function, no matter what.
### tool\_choice: none – will never return a function
The `none` option for `tool_choice` means that the LLM will **not** return a function, and will instead produce ordinary text output. This is useful when you want to provide the full context of a conversation without adding and dropping the `tools` parameter call-by-call.
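For reference, both modes are set on the same `tool_choice` field of the payload shown above, using string values in the OpenAI-style format:
```python
# Guarantee at least one function call:
payload["tool_choice"] = "required"
# Or force plain text output while keeping tools in the request:
payload["tool_choice"] = "none"
```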
### tool\_choice: direct – will return a specified function
You can also pass a specific function directly into the call, which is guaranteed to be returned. This is useful if you want to hardcode specific behavior into your model call for testing or conditional execution.
```python
"tool_choice": {"type": "function", "function": {"name": "subtract"}}
```
## Receive function(s) as output
When the model returns functions, they'll be a list that can be parsed as follows:
```python
func_calls = json.loads(resp.text)
# In this example, we execute the first function (one of +-/*) on the provided parameters
func_call = func_calls[0]
calculator_functions[func_call["name"]](**func_call["parameters"])
```
After reading the LLM's selection, your execution environment can run the necessary functions. For more on combining LLMs with other logic, see the [chains documentation](/chains/overview).
# Baseten model integrations
Use your Baseten models with tools like LangChain
Build your own open-source ChatGPT with Baseten and Chainlit.
Use your Baseten models in the LangChain ecosystem.
Use your Baseten models in LiteLLM projects.
Build an AI-powered Twilio SMS chatbot with a Baseten-hosted LLM.
Want to integrate Baseten with your platform or project? Reach out to [support@baseten.co](mailto:support@baseten.co) and we'll help with building and marketing the integration.
# How to call your model
Run inference on deployed models
Once you've deployed your model, it's time to use it! Every model on Baseten is served behind an API endpoint. To call a model, you need:
* The model's ID.
* An [API key](https://app.baseten.co/settings/account/api_keys) for your Baseten account.
* JSON-serializable model input.
You can call a model using the:
* `/predict` endpoint for the [production deployment](/api-reference/production-predict), [development deployment](/api-reference/development-predict) or other [published deployment](/api-reference/deployment-predict).
* `/async_predict` endpoint for the [production deployment](/api-reference/production-async-predict), [development deployment](/api-reference/development-async-predict) or other [published deployment](/api-reference/deployment-async-predict).
* [Truss CLI](/truss-reference/cli/predict) command `truss predict`.
* "Call model" button on the model dashboard within your Baseten workspace.
## Call by API endpoint
```python
import requests
import os
model_id = ""
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]
resp = requests.post(
f"https://model-{model_id}.api.baseten.co/production/predict",
headers={"Authorization": f"Api-Key {baseten_api_key}"},
json={}, # JSON-serializable model input
)
print(resp.json())
```
```python
import requests
import os
model_id = ""
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]
resp = requests.post(
f"https://model-{model_id}.api.baseten.co/development/predict",
headers={"Authorization": f"Api-Key {baseten_api_key}"},
json={}, # JSON-serializable model input
)
print(resp.json())
```
```python
import requests
import os
model_id = ""
deployment_id = ""
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]
resp = requests.post(
f"https://model-{model_id}.api.baseten.co/deployment/{deployment_id}/predict",
headers={"Authorization": f"Api-Key {baseten_api_key}"},
json={}, # JSON-serializable model input
)
print(resp.json())
```
See the [inference API reference](/api-reference/production-predict) for more details.
## Call by async API endpoint
```python
import requests
import os
model_id = ""
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]
resp = requests.post(
f"https://model-{model_id}.api.baseten.co/production/async_predict",
headers={"Authorization": f"Api-Key {baseten_api_key}"},
json={
"model_input": {"prompt": "hello world!"},
"webhook_endpoint": "https://my_webhook.com/webhook"
# Optional fields for priority, max_time_in_queue_seconds, etc
}
)
print(resp.json())
```
```python
import requests
import os
model_id = ""
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]
resp = requests.post(
f"https://model-{model_id}.api.baseten.co/development/async_predict",
headers={"Authorization": f"Api-Key {baseten_api_key}"},
json={
"model_input": {"prompt": "hello world!"},
"webhook_endpoint": "https://my_webhook.com/webhook"
# Optional fields for priority, max_time_in_queue_seconds, etc
},
)
print(resp.json())
```
```python
import requests
import os
model_id = ""
deployment_id = ""
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]
resp = requests.post(
f"https://model-{model_id}.api.baseten.co/deployment/{deployment_id}/async_predict",
headers={"Authorization": f"Api-Key {baseten_api_key}"},
json={
"model_input": {"prompt": "hello world!"},
"webhook_endpoint": "https://my_webhook.com/webhook"
# Optional fields for priority, max_time_in_queue_seconds, etc
},
)
print(resp.json())
```
See the [async inference API reference](/api-reference/production-async-predict) for API details and the [async guide](/invoke/async) for more information about running async inference.
## Call with Truss CLI
```sh
truss predict --model $MODEL_ID -d '$MODEL_INPUT'
```
```sh
truss predict --model-version $DEPLOYMENT_ID -d '$MODEL_INPUT'
```
```sh
cd ~/path/to/my_truss
truss predict -d '$MODEL_INPUT'
```
See the [Truss CLI reference](/truss-reference/cli/predict) for more details.
# How to stream model output
Reduce time to first token for LLMs
For instructions on packaging and deploying a model with streaming output, see [this Truss example](/truss/examples/03-llm-with-streaming). This guide covers how to call a model that has a streaming-capable endpoint.
Any model can be packaged with support for streaming output, but it only makes sense to do so for models where:
* Generating a complete output takes a relatively long time.
* The first tokens of output are useful without the context of the rest of the output.
* Reducing the time to first token improves the user experience.
LLMs in chat applications are the perfect use case for streaming model output.
## Example: Streaming with Mistral
[Mistral 7B Instruct](https://www.baseten.co/library/mistral-7b-instruct) from Baseten's model library is a recent LLM with streaming support. Invocation should be the same for any other model library LLM as well as any Truss that follows the same standard.
[Deploy Mistral 7B Instruct](https://www.baseten.co/library/mistral-7b-instruct) or a similar LLM to run the following examples.
### Truss CLI
The Truss CLI has built-in support for streaming model output.
```sh
truss predict -d '{"prompt": "What is the Mistral wind?", "stream": true}'
```
### API endpoint
When using a streaming endpoint with cURL, use the `--no-buffer` flag to stream output as it is received.
As with all cURL invocations, you'll need a model ID and API key.
```sh
curl -X POST https://app.baseten.co/models/MODEL_ID/predict \
-H 'Authorization: Api-Key YOUR_API_KEY' \
-d '{"prompt": "What is the Mistral wind?", "stream": true}' \
--no-buffer
```
### Python application
Let's take things a step further and look at how to integrate streaming output with a Python application.
```python
import requests
import json
import os
# Model ID for production deployment
model_id = ""
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]
# Open session to enable streaming
s = requests.Session()
with s.post(
    # Endpoint for production deployment, see API reference for more
    f"https://model-{model_id}.api.baseten.co/production/predict",
    headers={"Authorization": f"Api-Key {baseten_api_key}"},
    # Include "stream": True in the data dict so the model knows to stream
    data=json.dumps({
        "prompt": "What even is AGI?",
        "stream": True,
        "max_new_tokens": 4096
    }),
    # Include stream=True as an argument so the requests library knows to stream
    stream=True,
) as resp:
    # Print the generated tokens as they get streamed
    for content in resp.iter_content():
        print(content.decode("utf-8"), end="", flush=True)
```
# Structured output (JSON mode)
Enforce an output schema on LLM inference
Structured outputs requires an LLM deployed using the [TensorRT-LLM Engine Builder](/performance/engine-builder-overview).
If you want to try this structured output example code for yourself, deploy [this implementation of Llama 3.1 8B](/performance/examples/llama-trt).
To generate structured outputs:
1. Define an object schema with [Pydantic](https://docs.pydantic.dev/latest/).
2. Pass the schema to the LLM with the `response_format` argument.
3. Receive output that is guaranteed to match the provided schema, including types and validations like `max_length`.
Using structured output, you should observe approximately equivalent tokens per second output speed to an ordinary call to the model after an initial delay for schema processing. If you're interested in the mechanisms behind structured output, check out this [engineering deep dive on our blog](https://www.baseten.co/blog/how-to-build-function-calling-and-json-mode-for-open-source-and-fine-tuned-llms).
## Schema generation with Pydantic
[Pydantic](https://docs.pydantic.dev/latest/) is an industry standard Python library for data validation. With Pydantic, we'll build precise schemas for LLM output to match.
For example, here's a schema for a basic `Person` object.
```python
from pydantic import BaseModel, Field
from typing import Optional
from datetime import date
class Person(BaseModel):
    first_name: str = Field(..., description="The person's first name", max_length=50)
    last_name: str = Field(..., description="The person's last name", max_length=50)
    age: int = Field(..., description="The person's age, must be a non-negative integer")
    email: str = Field(..., description="The person's email address")
```
Structured output supports multiple data types, required and optional fields, and additional validations like `max_length`.
## Add response format to LLM call
The first time that you pass a given schema for the model, it can take a minute for the schema to be processed and cached. Subsequent calls with the same schema will run at normal speeds.
Once your object is defined, you can add it as a parameter to your LLM call with the `response_format` field:
```python
import json
import requests
payload = {
"messages": [
{"role": "system", "content": "You are a helpful assistant"},
{ "role": "user", "content": "Make up a new person!"},
],
"max_tokens": 512,
"response_format": { # Add this parameter to use structured outputs
"type": "json_schema",
"json_schema": {"schema": Person.model_json_schema()},
},
}
MODEL_ID = ""
BASETEN_API_KEY = ""
resp = requests.post(
f"https://model-{MODEL_ID}.api.baseten.co/production/predict",
headers={"Authorization": f"Api-Key {BASETEN_API_KEY}"},
json=payload,
)
json.loads(resp.text)
```
The response may have an end of sequence token, which will need to be removed before the JSON can be parsed.
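For example, continuing from the request above, you might strip the end of sequence token before parsing (the token string below is illustrative; use your model's actual end of sequence token):
```python
raw = resp.text
eos_token = "<|eot_id|>"  # illustrative; replace with your model's end of sequence token
person_data = json.loads(raw.replace(eos_token, "").strip())
```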
## Parsing LLM output
From the LLM, we expect output in the following format:
```json
{
"first_name": "Astrid",
"last_name": "Nyxoria",
"age": 28,
"email": "astrid.nyxoria@starlightmail.com"
}
```
This example output is valid, which can be double-checked with:
```python
Person.parse_raw(resp.text)
```
# Troubleshooting
Fixing common problems during model inference
## Model I/O issues
### Error: JSONDecodeError
```
json.decoder.JSONDecodeError: Expecting value: line 1 column 1 (char 0)
```
This error means you're attempting to pass a model input that is not JSON-serializable. For example, you might have left out the double quotes required for a valid string:
```sh
truss predict -d 'This is not a string' # Wrong
truss predict -d '"This is a string"' # Correct
```
## Model version issues
### Error: No OracleVersion matches the given query
```
No OracleVersion matches the given query.
```
Make sure that the model ID or deployment ID you're passing is correct and that the associated model has not been deleted.
Additionally, make sure you're using the correct endpoint:
* [Production deployment endpoint](/api-reference/production-predict).
* [Development deployment endpoint](/api-reference/development-predict).
* [Published deployment endpoint](/api-reference/deployment-predict).
## Authentication issues
### Error: Service provider not found
```
ValueError: Service provider example-service-provider not found in ~/.trussrc
```
This error means your `~/.trussrc` is incomplete or incorrect. It should be formatted as follows:
```
[baseten]
remote_provider = baseten
api_key = YOUR.API_KEY
remote_url = https://app.baseten.co
```
### Error: You have to log in to perform the request
```
You have to log in to perform the request.
```
This error occurs on `truss predict` when the API key in `~/.trussrc` for a given host is missing or incorrect. To fix it, update your API key in the `~/.trussrc` file.
### Error: Please check the API key you provided
```
{
"error": "please check the api-key you provided"
}
```
This error occurs when using `curl` or similar to call the model via its API endpoint when the API key passed in the request header is not valid. Make sure you're using a valid API key then try again.
# Workspace access control
Share your Baseten workspace with your team
Workspaces on the Startup plan are limited to five users. [Contact us](mailto:support@baseten.co) if you need to invite more than five users to your workspace.
## Roles and permissions
Baseten workspaces have basic role-based access control. There are two workspace roles:
| | Admin | Creator |
| -------------- | ----- | ------- |
| Manage members | ☑️ | - |
| Manage billing | ☑️ | - |
| Deploy models | ☑️ | ☑️ |
| Call models | ☑️ | ☑️ |
# Best practices for API keys
Securely access your Baseten models
API keys are used to:
* Deploy models to your Baseten account from the Truss CLI.
* Call models via the [inference API](/api-reference/overview#inference-api) or `truss predict` in the CLI.
* Make requests to other model endpoints such as `/wake`.
* Manage models via the [management API](/api-reference/overview#management-api).
* [Export metrics](/observability/export-metrics/overview) to your observability stack via the `/metrics` endpoint.
You can create and revoke API keys from your [Baseten account](https://app.baseten.co/settings/account/api_keys).
## API key scope: account vs workspace
There are two types of API keys on Baseten:
* **Personal** keys are tied to your Baseten account and have the full permissions associated with your account in the workspace. They can be used to deploy, call, and manage models, and can also be used to export model metrics. Every action taken with a personal API key is associated with the matching user account.
* **Workspace** keys are shared across your entire Baseten workspace. When you create a workspace API key, you can grant it full access to the workspace or limit it to only being able to perform actions on selected models.
Use account-level API keys for deploying and testing models and use workspace-level API keys in automated actions and production environments.
## Using API keys with Truss
To use an API key for authentication with commands like `truss push` and `truss predict`, set it in your `~/.trussrc` file:
```sh ~/.trussrc
[baseten]
remote_provider = baseten
api_key = abcdefgh.1234567890ABCDEFGHIJKL1234567890
remote_url = https://app.baseten.co
```
If you rotate your API key, just open the file in a text editor and paste the new API key to update.
### Using API keys with endpoints
To use an API key for requests to model endpoints, pass it as a header in the HTTP request:
```sh
curl -X POST https://app.baseten.co/models/MODEL_ID/predict \
-H 'Authorization: Api-Key abcdefgh.1234567890ABCDEFGHIJKL1234567890' \
-d 'MODEL_INPUT'
```
The header is a key-value pair:
```python
headers = {"Authorization": "Api-Key abcdefgh.1234567890ABCDEFGHIJKL1234567890"}
```
## Tips for managing API keys
Best practices for API key use apply to your Baseten API keys:
* Always store API keys securely.
* Never commit API keys to your codebase.
* Never share or leak API keys in notebooks or screenshots.
The [API key list on your Baseten account](https://app.baseten.co/settings/account/api_keys) shows when each key was first created and last used. Rotate API keys regularly and remove any unused API keys to reduce the risk of accidental leaks.
# Export metrics to Datadog
Export metrics from Baseten to Datadog
Exporting metrics is in beta mode.
The Baseten metrics endpoint can be integrated with [OpenTelemetry Collector](https://opentelemetry.io/docs/collector/) by configuring a Prometheus receiver that scrapes the endpoint. This allows Baseten metrics to be pushed to a variety of popular exporters—see the [OpenTelemetry registry](https://opentelemetry.io/ecosystem/registry/?component=exporter) for a full list.
**Using OpenTelemetry Collector to push to Datadog**
```yaml config.yaml
receivers:
  # Configure a Prometheus receiver to scrape the Baseten metrics endpoint.
  prometheus:
    config:
      scrape_configs:
        - job_name: 'baseten'
          scrape_interval: 60s
          metrics_path: '/metrics'
          scheme: https
          authorization:
            type: "Api-Key"
            credentials: "{BASETEN_API_KEY}"
          static_configs:
            - targets: ['app.baseten.co']
processors:
  batch:
exporters:
  # Configure a Datadog exporter.
  datadog:
    api:
      key: "{DATADOG_API_KEY}"
service:
  pipelines:
    metrics:
      receivers: [prometheus]
      processors: [batch]
      exporters: [datadog]
```
# Export metrics to Grafana Cloud
Export metrics from Baseten to Grafana Cloud
The Baseten + Grafana Cloud integration enables you to get real-time inference metrics within your existing Grafana setup.
## Video tutorial
See below for step-by-step details from the video.
## Set up the integration
For a visual guide, please follow along with the video above.
Open your Grafana Cloud account:
1. Navigate to "Home > Connections > Add new connection".
2. In the search bar, type `Metrics Endpoint` and select it.
3. Give your scrape job a name like `baseten_metrics_scrape`.
4. Set the scrape job URL to `https://app.baseten.co/metrics`.
5. Leave the scrape interval set to `Every minute`.
6. Select `Bearer` for authentication credentials.
7. In your Baseten account, generate a metrics-only workspace API key.
8. In Grafana, enter the Bearer Token as `Api-Key abcd.1234567890` where the latter value is replaced by your API key.
9. Use the "Test Connection" button to ensure everything is entered correctly.
10. Click "Save Scrape Job."
11. Click "Install."
12. In your integrations list, select your new export and go through the "Enable" flow shown on video.
Now, you can navigate to your Dashboards tab, where you will see your data! Please note that it can take a couple of minutes for data to arrive and only new data will be scraped, not historical metrics.
## Build a Grafana dashboard
Importing the data is a great first step, but you'll need a dashboard to properly visualize the incoming information.
We've prepared a basic dashboard to get you started, which you can import by:
1. Downloading `baseten_grafana_dashboard.json` from [this GitHub Gist](https://gist.github.com/philipkiely-baseten/9952e7592775ce1644944fb644ba2a9c).
2. Selecting "New > Import" from the dropdown in the top-right corner of the Dashboard page.
3. Dropping in the provided JSON file.
For visual reference in navigating the dashboard, please see the video above.
# Export metrics to New Relic
Export metrics from Baseten to New Relic
Exporting metrics is in beta mode.
Export Baseten metrics to New Relic by integrating with [OpenTelemetry Collector](https://opentelemetry.io/docs/collector/). This involves configuring a Prometheus receiver that scrapes Baseten's metrics endpoint and configuring a New Relic exporter to send the metrics to your observability backend.
**Using OpenTelemetry Collector to push to New Relic**
```yaml config.yaml
receivers:
  # Configure a Prometheus receiver to scrape the Baseten metrics endpoint.
  prometheus:
    config:
      scrape_configs:
        - job_name: 'baseten'
          scrape_interval: 60s
          metrics_path: '/metrics'
          scheme: https
          authorization:
            type: "Api-Key"
            credentials: "{BASETEN_API_KEY}"
          static_configs:
            - targets: ['app.baseten.co']
processors:
  batch:
exporters:
  # Configure a New Relic exporter. Visit New Relic documentation to get your regional otlp endpoint.
  otlphttp/newrelic:
    endpoint: https://otlp.nr-data.net
    headers:
      api-key: "{NEW_RELIC_KEY}"
service:
  pipelines:
    metrics:
      receivers: [prometheus]
      processors: [batch]
      exporters: [otlphttp/newrelic]
```
# Metrics export overview
Export metrics from Baseten to your observability stack
Exporting metrics is in beta mode.
Baseten exposes an endpoint that returns real-time metrics in the [Prometheus format](https://github.com/prometheus/docs/blob/main/content/docs/instrumenting/exposition_formats.md). By using this endpoint as a Prometheus / OpenMetrics scrape endpoint, you can integrate Baseten metrics with compatible software like Prometheus, OpenTelemetry Collector / Agent, Datadog Agent, Vector, and more.
To scrape metrics from Baseten:
1. Set the scrape endpoint to `app.baseten.co/metrics`.
2. Configure the `Authorization` header for scrape requests. This should be set to your Baseten API key, prefixed with **`Api-Key`** (e.g. **`{"Authorization": "Api-Key abcd1234.abcd1234"}`**).
* We recommend using [workspace API keys](/observability/api-keys) with export metrics permission.
3. Set an appropriate scrape interval. We recommend scraping metrics every minute.
Note that the metrics endpoint is updated every 30 seconds.
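As a quick sanity check before wiring up a scraper, you can fetch the endpoint directly. A minimal sketch, assuming your metrics-scoped API key is stored in the `BASETEN_API_KEY` environment variable:
```py Python
import os
import requests

# Read the metrics-scoped workspace API key from an environment variable
baseten_api_key = os.environ["BASETEN_API_KEY"]

resp = requests.get(
    "https://app.baseten.co/metrics",
    headers={"Authorization": f"Api-Key {baseten_api_key}"},
)
resp.raise_for_status()

# The response body is plain-text Prometheus exposition format
print(resp.text[:500])
```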
## Supported services
The Baseten metrics endpoint can be integrated with [OpenTelemetry Collector](https://opentelemetry.io/docs/collector/) by configuring a Prometheus receiver that scrapes the endpoint. Baseten metrics can be pushed to any exporter on the [OpenTelemetry registry](https://opentelemetry.io/ecosystem/registry/?component=exporter). We integrate with providers including:
* [Prometheus](/observability/export-metrics/prometheus)
* [Datadog](/observability/export-metrics/datadog)
* [Grafana](/observability/export-metrics/grafana)
* [New Relic](/observability/export-metrics/new-relic)
For a list of supported metrics, see the [supported metrics reference](/observability/export-metrics/supported-metrics).
## Rate limits
Calls to the Baseten metrics endpoint are limited on a per-organization basis to **6 requests per minute**. If this limit is exceeded, subsequent calls will result in 429 responses.
To avoid hitting this limit, we recommend setting your scrape interval to 1 minute.
# Export metrics to Prometheus
Export metrics from Baseten to Prometheus
Exporting metrics is in beta mode.
To integrate with Prometheus, specify the Baseten metrics endpoint in a scrape config. For example:
```yaml prometheus.yml
global:
  scrape_interval: 60s
scrape_configs:
  - job_name: 'baseten'
    metrics_path: '/metrics'
    authorization:
      type: "Api-Key"
      credentials: "{BASETEN_API_KEY}"
    static_configs:
      - targets: ['app.baseten.co']
    scheme: https
```
See the Prometheus docs for more details on [getting started](https://prometheus.io/docs/prometheus/latest/getting_started/) and [configuration options](https://prometheus.io/docs/prometheus/latest/configuration/configuration/).
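Once Prometheus is scraping the endpoint, you can confirm data is flowing by querying its HTTP API. This is a sketch that assumes a Prometheus server at `localhost:9090` and uses the `baseten_inference_requests_total` metric from the supported metrics reference:
```py Python
import requests

# Ask the local Prometheus server for the overall request rate
# over the last 5 minutes (adjust the address for your setup).
resp = requests.get(
    "http://localhost:9090/api/v1/query",
    params={"query": "sum(rate(baseten_inference_requests_total[5m]))"},
)
resp.raise_for_status()

for result in resp.json()["data"]["result"]:
    print(result["metric"], result["value"])
```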
# Metrics support matrix
Which metrics can be exported
Exporting metrics is in beta mode.
## `baseten_inference_requests_total`
Cumulative number of requests to the model.
Type: `counter`
Labels:
The ID of the model.
The name of the model.
The ID of the deployment.
The status code of the response.
Whether the request was an [async inference request](/invoke/async).
The environment that the deployment corresponds to. Empty if the deployment is not associated with an environment.
The phase of the deployment in the [promote to production process](/deploy/lifecycle#promoting-to-production). Empty if the deployment is not associated with an environment.
Possible values:
* `"promoting"`
* `"stable"`
## `baseten_end_to_end_response_time_seconds`
End-to-end response time in seconds.
Type: `histogram`
Labels:
The ID of the model.
The name of the model.
The ID of the deployment.
The status code of the response.
Whether the request was an [async inference request](/invoke/async).
The environment that the deployment corresponds to. Empty if the deployment is not associated with an environment.
The phase of the deployment in the [promote to production process](/deploy/lifecycle#promoting-to-production). Empty if the deployment is not associated with an environment.
Possible values:
* `"promoting"`
* `"stable"`
## `baseten_container_cpu_usage_seconds_total`
Cumulative CPU time consumed by the container in core-seconds.
Type: `counter`
Labels:
The ID of the model.
The name of the model.
The ID of the deployment.
The ID of the replica.
The environment that the deployment corresponds to. Empty if the deployment is not associated with an environment.
The phase of the deployment in the [promote to production process](/deploy/lifecycle#promoting-to-production). Empty if the deployment is not associated with an environment.
Possible values:
* `"promoting"`
* `"stable"`
## `baseten_replicas_active`
Number of replicas ready to serve model requests.
Type: `gauge`
Labels:
The ID of the model.
The name of the model.
The ID of the deployment.
The environment that the deployment corresponds to. Empty if the deployment is not associated with an environment.
The phase of the deployment in the [promote to production process](/deploy/lifecycle#promoting-to-production). Empty if the deployment is not associated with an environment.
Possible values:
* `"promoting"`
* `"stable"`
## `baseten_replicas_starting`
Number of replicas starting up, i.e. either waiting for resources to become available or loading the model.
Type: `gauge`
Labels:
The ID of the model.
The name of the model.
The ID of the deployment.
The environment that the deployment corresponds to. Empty if the deployment is not associated with an environment.
The phase of the deployment in the [promote to production process](/deploy/lifecycle#promoting-to-production). Empty if the deployment is not associated with an environment.
Possible values:
* `"promoting"`
* `"stable"`
## `baseten_container_cpu_memory_working_set_bytes`
Current working set memory used by the container in bytes.
Type: `gauge`
Labels:
The ID of the model.
The name of the model.
The ID of the deployment.
The ID of the replica.
The environment that the deployment corresponds to. Empty if the deployment is not associated with an environment.
The phase of the deployment in the [promote to production process](/deploy/lifecycle#promoting-to-production). Empty if the deployment is not associated with an environment.
Possible values:
* `"promoting"`
* `"stable"`
## `baseten_gpu_memory_used`
GPU memory used in MiB.
Type: `gauge`
Labels:
The ID of the model.
The name of the model.
The ID of the deployment.
The ID of the replica.
The ID of the GPU.
The environment that the deployment corresponds to. Empty if the deployment is not associated with an environment.
The phase of the deployment in the [promote to production process](/deploy/lifecycle#promoting-to-production). Empty if the deployment is not associated with an environment.
Possible values:
* `"promoting"`
* `"stable"`
## `baseten_gpu_utilization`
GPU utilization as a percentage (between 0 and 100).
Type: `gauge`
Labels:
The ID of the model.
The name of the model.
The ID of the deployment.
The ID of the replica.
The ID of the GPU.
The environment that the deployment corresponds to. Empty if the deployment is not associated with an environment.
The phase of the deployment in the [promote to production process](/deploy/lifecycle#promoting-to-production). Empty if the deployment is not associated with an environment.
Possible values:
* `"promoting"`
* `"stable"`
# Monitoring model health
Diagnose and fix model server issues
Every model deployment in your Baseten workspace has a status to represent its activity and health.
## Model statuses
**Healthy states:**
* **Active**: The deployment is active and available. It can be called with `truss predict` or from its API endpoints.
* **Scaled to zero**: The deployment is active but is not consuming resources. It will automatically start up when called, then scale back to zero after traffic ceases.
* **Starting up**: The deployment is starting up from a scaled to zero state after receiving a request.
* **Inactive**: The deployment is unavailable and is not consuming resources. It may be manually reactivated.
**Error states:**
* **Unhealthy**: The deployment is active but is in an unhealthy state due to errors while running, such as an external service it relies on going down or a problem in your Truss that prevents it from responding to requests.
* **Build failed**: The deployment is not active due to a Docker build failure.
* **Deployment failed**: The deployment is not active due to a model deployment failure.
## Debug logging
See [this](/truss-reference/config#enable-debug-logs) Truss config option.
## Fixing unhealthy deployments
If you have an unhealthy or failed deployment, check the model logs to see if there's any indication of what the problem is. You can try deactivating and reactivating your deployment to see if the issue goes away. In the case of an external service outage, you may need to wait for the service to come back up before your deployment works again. For issues inside your Truss, you'll need to diagnose your code to see what is making it unresponsive.
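If you want to script the deactivate/reactivate cycle, a sketch using the deactivate and activate deployment endpoints looks like the following. The IDs are placeholders, and it assumes the `api.baseten.co` management API host described in the API reference:
```py Python
import os
import requests

model_id = ""       # your model ID
deployment_id = ""  # the unhealthy deployment's ID
headers = {"Authorization": f"Api-Key {os.environ['BASETEN_API_KEY']}"}

# Deactivate the deployment, then activate it again
for action in ("deactivate", "activate"):
    resp = requests.post(
        f"https://api.baseten.co/v1/models/{model_id}/deployments/{deployment_id}/{action}",
        headers=headers,
    )
    print(action, resp.json())
```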
# Reading model metrics
Understand the load and performance of your model
The metrics tab on the model dashboard shows charts for each deployment of your model. These metrics help to understand the relationship between model load and model performance.
Metrics are shown per deployment. Use the dropdowns on the metrics page to switch between deployments and time ranges.
## Inference volume
Inference volume shows the rate of requests to the model over time. It's broken out into 2xx, 4xx, and 5xx, representing ranges of HTTP response status codes. Any exceptions thrown in your model's predict code are counted under the 5xx responses.
Some older models may not show the 2xx, 4xx, and 5xx breakdown. If you don't see this breakdown:
* Get the latest version of Truss with `pip install --upgrade truss`.
* Re-deploy your model with `truss push`.
* Promote the updated model to production after testing it.
## Response time
* **End-to-end response time** includes time for cold starts, queuing, and inference (but not client-side latency). This most closely mirrors the performance of your model as experienced by your users.
* **Inference time** includes just the time spent running the model, including pre- and post-processing. This is useful for optimizing the performance of your model code at the single replica level.
Response time is broken out into p50, p90, p95, and p99, referring to the 50th, 90th, 95th, and 99th percentile of response times.
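As an illustration of how these percentiles relate to raw measurements (the dashboard computes this for you; the numbers below are hypothetical):
```py Python
import numpy as np

# Hypothetical end-to-end response times in seconds
response_times = [0.8, 0.9, 1.1, 1.2, 1.4, 2.0, 2.3, 3.1, 4.8, 9.5]

for p in (50, 90, 95, 99):
    print(f"p{p}: {np.percentile(response_times, p):.2f}s")
```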
## Replicas
The replicas chart shows the number of replicas in both active and starting up states:
* The starting up count includes replicas that are waiting for resources to be available and replicas that are in the process of loading the model.
* The active count includes replicas that are ready to serve requests.
For development deployments, the replica shows as active while loading the model and running the live reload server.
## CPU usage and memory
These charts show the CPU and memory usage of your deployment. If you have multiple replicas, the charts show the average across all replicas. Note that this data is not instantaneous, so sharp spikes in usage may not appear on the graph.
What to look out for:
* When CPU load or memory usage gets too high, the performance of your deployment may degrade. Consider switching your model's instance type to one with more vCPUs or RAM.
* If CPU load and memory usage are consistently very low, you may be using an instance with too many vCPU cores and too much RAM. If you're using a CPU-only instance, or a GPU instance where a smaller instance type with the same GPU is available, you may be able to save money by switching.
## GPU usage and memory
These charts show the GPU utilization and GPU memory usage of your deployment. If you have multiple replicas, they show the average across all replicas. Note that this data is not instantaneous, so sharp spikes in usage may not appear on the graph.
In technical terms, the GPU usage is a measure of the fraction of time within a cycle that a kernel function is occupying GPU resources.
What to look out for:
* When the load on the GPU gets too high, model inference can slow down. Look for corresponding increases in inference time.
* When GPU memory usage gets too high, requests can fail with out-of-memory errors.
* If GPU load and memory usage are consistently very low, you may be using an overpowered GPU and could save money with a less powerful card.
## Time in async queue
The time in async queue chart shows the time in seconds that an async predict request spent in the async queue before getting processed by the model. This chart is broken out into p50, p90, p95, and p99, referring to the 50th, 90th, 95th, and 99th percentile of time spent in the async queue.
## Async queue size
The async queue size chart shows the number of async predict requests that are currently queued to be executed.
What to look out for:
* If the queue size is large, async requests are being queued faster than they can be executed. In this case, requests may take longer to complete or expire after the user-specified `max_time_in_queue_seconds`.
* To increase the number of async requests your model can process, increase the max number of replicas or concurrency target in your autoscaling settings.
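For reference, a queued async request that sets an expiry might look like the sketch below. It assumes the production deployment's async endpoint; see the async inference docs for the full request schema:
```py Python
import os
import requests

model_id = ""
baseten_api_key = os.environ["BASETEN_API_KEY"]

resp = requests.post(
    f"https://model-{model_id}.api.baseten.co/production/async_predict",
    headers={"Authorization": f"Api-Key {baseten_api_key}"},
    json={
        "model_input": {"prompt": "hello world"},
        # Receive the result at your own endpoint; Baseten does not store outputs
        "webhook_endpoint": "https://your-app.example.com/webhook",
        # Expire the request if it waits in the queue longer than 10 minutes
        "max_time_in_queue_seconds": 600,
    },
)
print(resp.json())
```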
# Best practices for secrets
Securely store and access passwords, tokens, keys, and more
Use the [secrets dashboard in your Baseten workspace](https://app.baseten.co/settings/secrets) to store sensitive data like access tokens, API keys, and passwords.
Every secret is a key-value pair with a "name" and "token." Tokens can span multiple lines, which is useful for secrets like SSH and PGP keys. Note that for the secret name, all non-alphanumeric characters are treated as equivalent (i.e. `hf_access_token` and `hf-access-token` map to the same underlying secret). This means that if a secret named `hf_access_token` already exists, creating a secret named `hf-access-token` will overwrite it.
Adding, updating, and deleting secrets immediately affects all models that use said secrets.
## Deploying models with secrets
When you deploy a model, use the `--trusted` flag to give it access to secrets in your Baseten workspace:
```sh
truss push --trusted
```
## Using secrets in Truss
In your Truss, add the secret name in `config.yaml` but set the value to `null`:
```yaml config.yaml
...
secrets:
hf_access_token: null
...
```
Never set the actual value of the secret in `config.yaml` or any other file that gets committed to your codebase.
Then, access the secret from the `secrets` keyword argument in your `model.py` initialization:
```py model/model.py
def __init__(self, **kwargs):
self._secrets = kwargs["secrets"]
```
You can then use the `self._secrets` dictionary in the `load` and `predict` functions:
```py model/model.py
def load(self):
self._model = pipeline(
"fill-mask",
model="baseten/docs-example-gated-model",
use_auth_token=self._secrets["hf_access_token"]
)
```
# Secure model inference
Keeping your models safe and private
We take the security of your models and data seriously. Baseten maintains a [SOC 2 Type II certification](https://www.baseten.co/blog/soc-2-type-2) and [HIPAA compliance](https://www.baseten.co/blog/baseten-announces-hipaa-compliance), but we're aware that these certifications don't guarantee the security of the system. This doc provides a more specific look at Baseten's security posture.
## Data privacy
Baseten is not in the business of using customers' data. We are in the business of providing ML inference infrastructure. We provide strong data privacy for your workloads.
### Model inputs and outputs
By default, Baseten never stores models' inputs or outputs.
Model inputs sent via [async inference](/invoke/async) are stored until the async request has been processed by the model. Model outputs from async requests are never stored.
Baseten used to offer, and maintains for existing users, a hosted Postgres data table system. A user could store model inputs and outputs in these data tables, which means they'd be stored on Baseten. Baseten's hosted Postgres data tables are secured with the same level of care as the rest of our infrastructure, and information in those tables can be permanently deleted by the user at any time.
### Model weights
By default, Baseten does not store models' weights.
By default, when a model is loaded, the model weights are simply downloaded from the source of truth (e.g. private HuggingFace repo, GCS, S3, etc) and moved from CPU memory to GPU memory (i.e. never stored on disk).
A user may explicitly instruct Baseten to store model weights using the [caching mechanism in Truss](/truss/examples/06-high-performance-cached-weights). If a user stores weights on Baseten with this mechanism, they can request for those weights to be permanently erased. Baseten will process these requests for any models specified by the user within 1 business day.
For open-source models from Baseten's model library, model weights are stored with this caching mechanism by default to speed up cold starts.
Additionally, Baseten uses a network accelerator that we developed to speed up model loads from common model artifact stores, including Hugging Face, S3, and GCS. Our accelerator employs byte range downloads in the background to maximize the parallelism of downloads. If you prefer to disable this network acceleration for your Baseten workspace, please contact our support team at [support@baseten.co](mailto:support@baseten.co).
## Workload security
Baseten runs ML inference workloads on users' behalf. This necessitates creating the right level of isolation to protect users' workloads from each other, and to protect Baseten's core services from the users' workloads. This is achieved through:
* Container security via enforcing security policies and the principle of least privilege.
* Network security policies including giving each customer their own Kubernetes namespace.
* Keeping our infrastructure up-to-date with the latest security patches.
### Container security
No two users' models share the same GPU. To mitigate container-related risks, we use security tooling such as Falco (via Sysdig) and pod security policies (via Gatekeeper), and we run pods in a security context with minimal privileges. Furthermore, we ensure that the nodes themselves have no privileges that could affect other users' workloads: nodes have the lowest possible privileges within the Kubernetes cluster to minimize the blast radius of a security incident.
### Network security policies
There exists a 1-1 relationship between a customer and a Kubernetes namespace. For each customer, all of their workloads live within that namespace.
We use this architecture to ensure isolation between customers' workloads through network isolation enforced through [Calico](https://docs.tigera.io/calico/latest/about). Customers' workloads are further isolated at the network level from the rest of Baseten's infrastructure as the nodes run in a private subnet and are firewalled from public access.
### Extended pentesting
While some pentesting is required for the SOC 2 certification, Baseten exceeds these requirements in both the scope of pentests and the access we give our testers.
We've contracted with ex-OpenAI and Crowdstrike security experts at [RunSybil](https://www.runsybil.com/) to perform extended pentesting including deploying malicious models on a dedicated prod-like Baseten environment, with the goal of breaking through the security measures described on this page.
## Self-hosted model inference
We do offer:
* Single-tenant environments.
* Self-hosting Baseten within your own infrastructure.
Given the security measures we have already put in place, we recommend the cloud version of Baseten for most customers as it provides faster setup, lower cost, and elastic GPU availability. However, if Baseten's [self-hosted plan](https://www.baseten.co/pricing/) sounds right for your needs, please contact our support team at [support@baseten.co](mailto:support@baseten.co).
# Tracing
Investigate the prediction flow in detail
The Truss server has [OpenTelemetry](https://opentelemetry.io/) (OTEL) instrumentation built in, and you can add your own custom instrumentation on top of it.
Traces are useful for investigating performance bottlenecks and other issues.
By default, tracing is disabled because it can add minor performance overhead, which is undesirable in some use cases. Follow the guides below to collect trace data.
## Exporting builtin trace data to Honeycomb
To enable data export, create a Honeycomb API key and add it as a [secret in your Baseten workspace](https://app.baseten.co/settings/secrets). Then add the following settings to the Truss config of the model you want to enable tracing for:
```yaml config.yaml
environment_variables:
HONEYCOMB_DATASET: your_dataset_name
runtime:
enable_tracing_data: true
secrets:
HONEYCOMB_API_KEY: '***'
```
When making requests to the model, you can provide trace parent IDs with the standard OTEL header key `traceparent`. If not provided, Baseten generates random IDs.
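For example, to propagate your own trace context when calling the model (the trace ID below is a placeholder):
```py Python
import os
import requests

model_id = ""
baseten_api_key = os.environ["BASETEN_API_KEY"]

resp = requests.post(
    f"https://model-{model_id}.api.baseten.co/production/predict",
    headers={
        "Authorization": f"Api-Key {baseten_api_key}",
        # W3C trace context: version-trace_id-span_id-flags (placeholder values)
        "traceparent": "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01",
    },
    json={"prompt": "hello"},
)
print(resp.json())
```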
Below is an example trace, visualized in Honeycomb, resolving preprocessing, predict, and postprocessing. These traces also include span events that time the (de-)serialization of inputs and outputs.
![trace-example-honeycomb](https://mintlify.s3.us-west-1.amazonaws.com/baseten-preview/observability/observability/trace-example-honeycomb.png)
## Adding custom OTEL instrumentation
If you want a different resolution of tracing spans and event recording, you can also add your own OTEL tracing implementation. Our built-in tracing instrumentation is designed not to mix its trace context with user-defined tracing.
```python model.py
import time
from typing import Any, Generator

import opentelemetry.exporter.otlp.proto.http.trace_exporter as oltp_exporter
import opentelemetry.sdk.resources as resources
import opentelemetry.sdk.trace as sdk_trace
import opentelemetry.sdk.trace.export as trace_export
from opentelemetry import trace
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider

trace.set_tracer_provider(
    TracerProvider(resource=Resource.create({resources.SERVICE_NAME: "UserModel"}))
)
tracer = trace.get_tracer(__name__)
trace_provider = trace.get_tracer_provider()


class Model:
    def __init__(self, **kwargs) -> None:
        honeycomb_api_key = kwargs["secrets"]["HONEYCOMB_API_KEY"]
        honeycomb_exporter = oltp_exporter.OTLPSpanExporter(
            endpoint="https://api.honeycomb.io/v1/traces",
            headers={
                "x-honeycomb-team": honeycomb_api_key,
                "x-honeycomb-dataset": "marius_testing_user",
            },
        )
        honeycomb_processor = sdk_trace.export.BatchSpanProcessor(honeycomb_exporter)
        trace_provider.add_span_processor(honeycomb_processor)

    @tracer.start_as_current_span("load_model")
    def load(self):
        ...

    def preprocess(self, model_input):
        with tracer.start_as_current_span("preprocess"):
            ...
            return model_input

    @tracer.start_as_current_span("predict")
    def predict(self, model_input: Any) -> Generator[str, None, None]:
        with tracer.start_as_current_span("start-predict") as span:

            def inner():
                time.sleep(0.01)
                for i in range(5):
                    span.add_event("yield")
                    yield str(i)

            return inner()
```
# Billing and usage
Manage payments and track overall Baseten usage
The [billing and usage dashboard](https://app.baseten.co/settings/billing) shows each model that has been active in the current billing period, how long that model has run for, and the bill for that model's compute time. Model usage is broken down by deployment for models with multiple deployments. The model usage dashboard is updated hourly.
If your account has credits, they will be applied against your bill automatically and shown in the model usage dashboard.
## Billing
### Credits
Every new Baseten workspace is issued free credits to get started with model deployment and serving. Use these credits to explore open-source models or test your own model deployment.
If your credits run out and there is no payment method on your account, all active models will be deactivated and you won't be able to deploy new models until a payment method is added.
### Payment method
On the [billing page](https://app.baseten.co/settings/billing), you can set and update your payment method. Your payment information, including credit card numbers and bank information, is always stored securely with our payments processor and not by Baseten directly.
### Invoice history
Your invoice history shows prior invoices and payments for your records. [Contact us](mailto:support@baseten.co) if you have any questions about your invoice history.
## Usage and billing FAQs
For complete information, see our [pricing page](https://www.baseten.co/pricing/), but here are answers to some common questions:
### How exactly is usage calculated?
Model usage is calculated by the minute for the time your model is actively deploying, scaling up or down, or being called. You are billed based on the [instance type](/performance/instances) that your model uses.
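As a rough illustration of the math (using the A10Gx4x16 price from the instance type reference; actual bills depend on your exact usage and any credits):
```py Python
# Illustrative estimate: one A10Gx4x16 replica ($0.02012/minute)
# active 6 hours per day for a 30-day month.
cost_per_minute = 0.02012
active_minutes = 6 * 60 * 30  # 10,800 minutes

print(f"Estimated compute cost: ${cost_per_minute * active_minutes:.2f}")  # ≈ $217.30
```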
### How often are payments due?
We automatically charge the credit card you have on file for model resource usage. When you first sign up for Baseten, your workspace is charged as soon as you exceed \$50.00 in usage or at the end of the monthly billing period, whichever happens first. Once you've established a history of successful payments, your workspace is only charged at the end of each monthly billing cycle.
### Do you offer volume discounts?
We offer volume discounts for workspaces on our Pro plan. [Contact us](mailto:support@baseten.co) to learn more.
### Do you offer education and non-profit discounts?
Yes, we are happy to support ML efforts for education and non-profit organizations. [Contact us](mailto:support@baseten.co) to learn more.
# How to get faster cold starts
Engineering your Truss and application for faster cold starts
A "cold start" is the time it takes to spin up a new instance of a model server. Fast cold starts are essential for useful autoscaling, especially scale to zero.
While Baseten has platform-level features that speed up cold starts, a lot of the possible optimizations are model-specific. This guide provides techniques for making cold starts faster for a given model.
## Use caching to reduce cold start time
Everything that happens in the `load()` function in the Truss is part of the cold start time. This generally includes loading binaries for one or more models from a data store like Hugging Face, which can be one of the longest-running steps.
Caching model weights can dramatically improve cold start times. Learn how to cache model weights in Truss with this guide:
Accelerate cold starts by caching your weights
## Use wake to hide cold start time
Every deployment has a [wake endpoint](/api-reference/version-wake) that can be used to activate the model when it's scaled to zero. This can be used to hide the cold start time from the end user.
Imagine you have an app where the user can enter a prompt and get an image from Stable Diffusion. The app has inconsistent traffic, so you have a minimum replica count of zero. Here's what happens when the model is scaled to zero and the app gets a user:
1. The user loads the app
2. The user enters input and the app calls the model endpoint
3. Baseten spins up an instance and loads the model (the time this takes is the cold start)
4. Model inference runs
5. After waiting, the user receives the image they requested
But, you can use the wake endpoint to hide the cold start time from the user. Instead:
1. The user loads the app
2. The app calls the wake endpoint for the scaled-to-zero model
3. Baseten spins up an instance and loads the model (the time this takes is the cold start)
4. Meanwhile, the user enters input and the app calls the model endpoint
5. Model inference runs
6. The user receives the image they requested
Wake is also useful when you have predictable traffic, such as starting up the model during business hours. It can also be triggered manually from the model dashboard when needed, like for a demo.
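Here's a sketch of that flow in application code. The wake URL is a placeholder (see the linked wake endpoint reference for the exact path); the predict call uses the standard production endpoint:
```py Python
import os
import requests

# Placeholder: substitute the wake endpoint from the API reference linked above
WAKE_URL = "https://<wake-endpoint-from-api-reference>"
headers = {"Authorization": f"Api-Key {os.environ['BASETEN_API_KEY']}"}

def on_app_load():
    # Fire the wake call as soon as the user opens the app, so the
    # cold start happens while they are still typing their prompt.
    requests.post(WAKE_URL, headers=headers)

def on_user_submit(model_id: str, prompt: str):
    # By the time the user submits, the replica should already be warm.
    resp = requests.post(
        f"https://model-{model_id}.api.baseten.co/production/predict",
        headers=headers,
        json={"prompt": prompt},
    )
    return resp.json()
```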
# Setting concurrency
Handle variable throughput with this autoscaling parameter
Concurrency is one of the major knobs available for getting the most performance out of your model. In this doc, we'll cover the options available to you.
## Configuring concurrency
At a very high level, "concurrency" on Baseten refers to how many requests a single replica can
process at the same time. There's no universal best value for concurrency; it depends on your model and the metrics you are optimizing for (like throughput or latency).
In Baseten and Truss, there are two levers for managing concurrency:
* **Concurrency target**: set in the Baseten UI, the number of requests that will be sent to a model at the same time
* **Predict concurrency**: set in the Truss config, governs how many requests can go through the `predict` function on your Truss at once after they've made it to the model container.
### Concurrency target
The concurrency target is set in the Baseten UI and governs the maximum number of requests that will be sent to a single model replica. Once the concurrency target is exceeded across all active replicas, the autoscaler will add more replicas (unless the max replica count is reached).
Let's dive into a concrete example. Let's say that we have:
* A model deployment with exactly 1 replica.
* A concurrency target of 2 requests.
* 5 incoming requests.
In this situation, the first 2 requests will be sent to the model container, while the other 3 are placed in a queue. As the requests in the container are completed, requests are sent in from the queue.
However, if the model deployment's autoscaling settings were to allow for more than one replica, this situation would trigger another replica to be created as there are requests in the queue.
### Predict concurrency
Predict concurrency operates within the model container and governs how many requests will go through the Truss' `predict` function concurrently.
A Truss can implement three functions to process a request:
* **preprocess**: processes model input before inference. For example, in a Truss for Whisper, this function might download the audio file for transcription from a URL in the request body.
* **predict**: performs model inference. This is the only function that blocks the GPU.
* **postprocess**: processes model output after inference, such as uploading the results of a text-to-image model like Stable Diffusion to S3.
The predict concurrency setting lets you limit access to the GPU-blocking `predict` function while still handling pre- and post-processing steps with higher concurrency.
Predict concurrency is set in the Truss' `config.yaml` file:
```yaml config.yaml
model_name: "My model with concurrency limits"
...
runtime:
predict_concurrency: 2 # the default is 1
...
```
To better understand this, let's extend our previous example by zooming in on the model pod:
* A model deployment with exactly 1 replica.
* A concurrency target of 2 requests.
* **New**: a predict concurrency of 1 request.
* 5 incoming requests.
Here's what happens:
1. Two requests enter the model container.
2. Both requests begin pre-processing immediately.
3. When one request finishes pre-processing, it is let into the GPU to run inference. The other request will be queued if it finishes pre-processing before the first request finishes inference.
4. After the first request finishes inference, it moves to post-processing and the second request begins inference on the GPU.
5. After the second request finishes inference, it can immediately move to post-processing whether or not the first request is still in post-processing.
This shows how predict concurrency protects the GPU resources in the model container while still allowing for high concurrency in the CPU-bound pre- and post-processing steps.
Concurrency target must be greater than or equal to predict concurrency, or your maximum predict concurrency will never be reached.
# Engine Builder configuration
Configure your TensorRT-LLM inference engine
This reference lists every configuration option for the TensorRT-LLM Engine Builder. These options are used in `config.yaml`, such as for this Llama 3.1 8B example:
```yaml config.yaml
model_name: Llama 3.1 8B Engine
resources:
  accelerator: H100:1
secrets:
  hf_access_token: "set token in baseten workspace"
trt_llm:
  build:
    base_model: llama
    checkpoint_repository:
      repo: meta-llama/Llama-3.1-8B-Instruct
      source: HF
    max_seq_len: 8000
```
## `trt_llm.build`
TRT-LLM engine build configuration. TensorRT-LLM attempts to build a highly optimized network based on input shapes representative of your workload.
### `base_model`
The base model architecture of your model checkpoint. Supported architectures include:
* `llama`
* `mistral`
* `deepseek`
* `qwen`
### `checkpoint_repository`
Specification of the model checkpoint to be leveraged for engine building. E.g.
```yaml
checkpoint_repository:
  source: HF | GCS | REMOTE_URL
  repo: meta-llama/Llama-3.1-8B-Instruct | gs://bucket_name | https://your-checkpoint.com
```
To configure access to private model checkpoints, [register secrets in your Baseten workspace](https://docs.baseten.co/observability/secrets#best-practices-for-secrets): the `hf_access_token` secret for Hugging Face, or the `trt_llm_gcs_service_account` secret (a valid service account JSON) for GCS.
Ensure that you push your truss with the `--trusted` flag to enable access to your secrets.
#### `checkpoint_repository.source`
Source where the checkpoint is stored. Supported sources include:
* `HF` (HuggingFace)
* `GCS` (Google Cloud Storage)
* `REMOTE_URL`
#### `checkpoint_repository.repo`
Checkpoint repository name, bucket, or url.
### `kv_cache_free_gpu_mem_fraction`
(default: `0.9`)
Controls the fraction of free GPU memory allocated for the KV cache. For more information, refer to the documentation [here](https://nvidia.github.io/TensorRT-LLM/performance/perf-best-practices.html#max-tokens-in-paged-kv-cache-and-kv-cache-free-gpu-memory-fraction).
### `max_batch_size`
(default: `256`)
Maximum number of input sequences to pass through the engine concurrently. Larger batch sizes generally increase throughput but also increase single-request latency.
Tune this value according to your SLAs and latency budget.
### `max_beam_width`
(default: `1`)
Maximum number of candidate sequences with which to conduct beam search. This value should act as a minimal upper bound for beam candidates.
Currently, only a beam width of 1 is supported.
### `max_seq_len`
Defines the maximum sequence length (context) of a single request.
### `max_num_tokens`
(default: `8192`)
Defines the maximum number of batched input tokens after padding is removed in each batch. Tuning this value allows memory to be allocated to the KV cache more efficiently and more requests to be executed together.
### `max_prompt_embedding_table_size`
(default: `0`)
Maximum prompt embedding table size for [prompt tuning](https://developer.nvidia.com/blog/an-introduction-to-large-language-models-prompt-engineering-and-p-tuning/).
### `num_builder_gpus`
(default: `auto`)
Number of GPUs to use at build time. Defaults to the configured `resources.accelerator` count. This is useful for FP8 quantization in particular, when more GPU memory is required at build time than at inference time.
### `enable_chunked_context`
(default: `False`)
Enables chunked context, increasing the chance of batching the context and generation phases together, which can increase throughput.
Note that one must set `plugin_configuration.use_paged_context_fmha: True` in order to leverage this feature.
### `plugin_configuration`
Config for inserting plugin nodes into network graph definition for execution of user-defined kernels.
#### `plugin_configuration.paged_kv_cache`
(default: `True`)
Decompose KV cache into page blocks. Read more about what this does [here](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/advanced/gpt-attention.md#paged-kv-cache).
#### `plugin_configuration.gemm_plugin`
(default: `auto`)
Utilize NVIDIA cuBLASLt for GEMM ops. Read more about when to enable this [here](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/performance/perf-best-practices.md#gemm-plugin).
#### `plugin_configuration.use_paged_context_fmha`
(default: `False`)
Utilize paged context for fused multihead attention. This configuration is necessary to enable KV cache reuse. Read more about this configuration [here](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/kv_cache_reuse.md#how-to-enable-kv-cache-reuse).
#### `plugin_configuration.use_fp8_context_fmha`
(default: `False`)
Utilize FP8 quantization for context fused multihead attention to accelerate attention. To use this configuration, also set `plugin_configuration.use_paged_context_fmha`. Read more about when to enable this [here](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/advanced/gpt-attention.md#fp8-context-fmha).
### `quantization_type`
(default: `no_quant`)
Quantization format with which to build the engine. Supported formats include:
* `no_quant` (meaning fp16)
* `weights_int8`
* `weights_kv_int8`
* `weights_int4`
* `weights_int4_kv_int8`
* `smooth_quant`
* `fp8`
* `fp8_kv`
Read more about different post training quantization techniques supported by TRT-LLM [here](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/blogs/quantization-in-TRT-LLM.md).
Additionally, refer to the hardware and quantization technique [support matrix](https://nvidia.github.io/TensorRT-LLM/reference/support-matrix.html).
### `strongly_typed`
(default: `False`)
Whether to build the engine using strong typing, enabling TensorRT's optimizer to statically infer intermediate tensor types which can speed up build time for some formats.
Weak typing enables the optimizer to elect tensor types, which may result in a faster runtime. For more information refer to TensorRT documentation [here](https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#strong-vs-weak-typing).
### `tensor_parallel_count`
(default: `1`)
Tensor parallelism count. For more information refer to NVIDIA documentation [here](https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/features/parallelisms.html#tensor-parallelism).
# Engine control in Python
Use `model.py` to customize engine behavior
When you create a new Truss with `truss init`, it creates two files: `config.yaml` and `model/model.py`. While you configure the Engine Builder in `config.yaml`, you may use `model/model.py` to access and control the engine object during inference.
You have two options:
1. Delete the `model/model.py` file and your TensorRT-LLM engine will run according to its base spec.
2. Update the code to support TensorRT-LLM.
You must either update `model/model.py` to pass `trt_llm` as an argument to the `__init__` method OR delete the file. Otherwise you will get an error on deployment as the default `model/model.py` file is not written for TensorRT-LLM.
The `engine` object is a property of the `trt_llm` argument and must be initialized in `__init__` to be accessed in `load()` (which runs once on server start-up) and `predict()` (which runs for each request handled by the server).
This example applies a chat template with the Llama 3.1 8B tokenizer to the model prompt:
```python model/model.py
from typing import Any

from transformers import AutoTokenizer


class Model:
    def __init__(self, trt_llm, **kwargs) -> None:
        self._secrets = kwargs["secrets"]
        self._engine = trt_llm["engine"]
        self._model = None
        self._tokenizer = None

    def load(self) -> None:
        self._tokenizer = AutoTokenizer.from_pretrained(
            "meta-llama/Llama-3.1-8B-Instruct", token=self._secrets["hf_access_token"]
        )

    async def predict(self, model_input: Any) -> Any:
        # Apply chat template to prompt
        model_input["prompt"] = self._tokenizer.apply_chat_template(
            model_input["prompt"], tokenize=False
        )
        return await self._engine.predict(model_input)
```
# Engine Builder overview
Deploy optimized model inference servers in minutes
If you have a foundation model like Llama 3 or a fine-tuned variant and want to create a low-latency, high-throughput model inference server, TensorRT-LLM via the Engine Builder is likely the tool for you.
TensorRT-LLM is an open source performance optimization toolbox created by NVIDIA. It helps you build TensorRT engines for large language models like Llama and Mistral as well as certain other models like Whisper and large vision models.
Baseten's TensorRT-LLM Engine Builder simplifies and automates the process of using TensorRT-LLM for development and production. All you need to do is write a few lines of configuration and an optimized model serving engine will be built automatically during the model deployment process.
Get started with an [end-to-end tutorial](/performance/engine-builder-tutorial) or jump straight in with reference implementations for [Llama](/performance/examples/llama-trt), [Mistral](/performance/examples/mistral-trt), and [Qwen](/performance/examples/qwen-trt). Check the [Engine Builder config reference](/performance/engine-builder-config) for a complete set of configuration options.
## FAQs
### Where are the engines stored?
The engines are stored in Baseten but owned by the user — we're working on a mechanism for downloading them. In the meantime, reach out if you need access to an engine that you created using the Engine Builder.
### Does the Engine Builder support quantization?
Yes. The Engine Builder can perform post-training quantization during the building process. For supported options, see [quantization in the config reference](/performance/engine-builder-config/#quantization_type).
### Can I customize the engine behavior?
For further control over the TensorRT-LLM engine during inference, use the `model/model.py` file to access the engine object at runtime. See [controlling engines with Python](/performance/engine-builder-customization) for details.
# Build your first LLM engine
Automatically build and deploy a TensorRT-LLM model serving engine
Deploying a TensorRT-LLM model with the Engine Builder is a three-step process:
1. Pick a model and GPU instance
2. Write your engine configuration and optional model serving code
3. Deploy your packaged model and the engine will be built automatically
In this guide, we'll walk through the process of using the engine builder end-to-end. To make this tutorial as quick and cheap as possible, we'll use a [1.1 billion parameter TinyLlama model](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0) on an A10G GPU.
We also have production-ready examples for [Llama 3](/performance/examples/llama-trt) and [Mistral](/performance/examples/mistral-trt).
## Prerequisites
Before you deploy a model, you'll need three quick setup steps.
Create an [API key](https://app.baseten.co/settings/api_keys) and save it as an environment variable:
```sh
export BASETEN_API_KEY="abcd.123456"
```
Some models require that you accept terms and conditions on Hugging Face before deployment. To prevent issues:
1. Accept the license for any gated models you wish to access, like [Llama 3](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct).
2. Create a read-only [user access token](https://huggingface.co/docs/hub/en/security-tokens) from your Hugging Face account.
3. Add the `hf_access_token` secret [to your Baseten workspace](https://app.baseten.co/settings/secrets).
Install the latest version of Truss, our open-source model packaging framework, with:
```sh
pip install --upgrade truss
```
## Configure your engine
We'll start by creating a new Truss:
```sh
truss init tinyllama-trt
cd tinyllama-trt
```
In the newly created `tinyllama-trt/` folder, open `config.yaml`. In this file, we'll configure our model serving engine:
```yaml config.yaml
model_name: tinyllama-trt
python_version: py310
resources:
  accelerator: A10G
  use_gpu: True
trt_llm:
  build:
    max_seq_len: 4096
    base_model: llama
    quantization_type: no_quant
    checkpoint_repository:
      repo: TinyLlama/TinyLlama-1.1B-Chat-v1.0
      source: HF
```
This build configuration sets a number of important parameters:
* `max_seq_len` controls the maximum number of total tokens supported by the engine. We want to match this as closely as possible to expected real-world use to improve engine performance.
* `base_model` determines which type of supported model architecture to build the engine for.
* `quantization_type` asks if the model should be quantized on deployment. `no_quant` will run the model in standard `fp16` precision.
* `checkpoint_repository` determines where to load the weights from, in this case [a Hugging Face repository for TinyLlama](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0).
The `config.yaml` file also contains Baseten-specific configuration for model name, GPU type, and model serving environment.
## Delete or update `model.py`
The `config.yaml` file above specifies a complete TensorRT-LLM engine. However, we also provide further control in the `model/model.py` file in Truss.
If you do not need any custom logic, delete `model/model.py`. If you keep the default file unchanged, you'll get the following error on deployment:
```
truss.errors.ValidationError: Model class `__init__` method
is required to have `trt_llm` as an argument.
Please add that argument.
```
The `model/model.py` file is useful for custom behaviors like applying a prompt template.
```python model/model.py
from typing import Any

from transformers import AutoTokenizer


class Model:
    def __init__(self, trt_llm, **kwargs) -> None:
        self._engine = trt_llm["engine"]
        self._model = None
        self._tokenizer = None

    def load(self) -> None:
        self._tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

    async def predict(self, model_input: Any) -> Any:
        # Apply chat template to prompt
        model_input["prompt"] = self._tokenizer.apply_chat_template(
            model_input["prompt"], tokenize=False
        )
        return await self._engine.predict(model_input)
```
Including a `model/model.py` file is optional. If the file is not present, the TensorRT-LLM engine will run according to its base spec.
## Deploy and build
To deploy your model and have the TensorRT-LLM engine automatically build, run:
```sh
truss push --publish
```
This will create a new deployment in your Baseten workspace. Navigate to the model dashboard to see engine building and model deployment logs.
The engines are stored in Baseten but owned by the user — we're working on a mechanism for downloading them. In the meantime, reach out if you need access to an engine that you created using the Engine Builder.
## Call deployed model
When your model is deployed, you can call it via its API endpoint:
```python call_model.py
import requests
import os

# Model ID for production deployment
model_id = ""
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]

# Call model endpoint
resp = requests.post(
    f"https://model-{model_id}.api.baseten.co/production/predict",
    headers={"Authorization": f"Api-Key {baseten_api_key}"},
    json={
        "messages": [{"role": "user", "content": "How awesome is TensorRT-LLM?"}],
        "max_tokens": 1024
    },
    stream=True
)

# Print the generated tokens as they get streamed
for content in resp.iter_content():
    print(content.decode("utf-8"), end="", flush=True)
```
Supported parameters for LLMs:
The input text prompt to guide the language model's generation.
One of `prompt` XOR `messages` is required.
A list of dictionaries representing the message history, typically used in conversational contexts.
One of `prompt` XOR `messages` is required.
The maximum number of tokens to generate in the output. Controls the length of the generated text.
The number of beams used in beam search. Maximum of `1`.
A penalty applied to repeated tokens to discourage the model from repeating the same words or phrases.
A penalty applied to tokens already present in the prompt to encourage the generation of new topics.
Controls the randomness of the output. Lower values make the output more deterministic, while higher values increase randomness.
A penalty applied to the length of the generated sequence to control verbosity. Higher values make the model favor shorter outputs.
The token ID that indicates the end of the generated sequence.
The token ID used for padding sequences to a uniform length.
Limits the sampling pool to the top `k` tokens, ensuring the model only considers the most likely tokens at each step.
Applies nucleus sampling to limit the sampling pool to a cumulative probability `p`, ensuring only the most likely tokens are considered.
# Llama 3 with TensorRT-LLM
Build an optimized inference engine for Llama 3.1 8B
This configuration builds an inference engine to serve Llama 3.1 8B on an A100 GPU. It is very similar to the configuration for any other Llama model, including fine-tuned variants.
## Setup
See the [end-to-end engine builder tutorial](/performance/engine-builder-tutorial) prerequisites for full setup instructions.
Make sure you have [accessed the gated model on Hugging Face](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) and set your `hf_access_token` in your Baseten workspace secrets.
Please upgrade to the latest version of Truss with `pip install --upgrade truss` before following this example.
```sh
pip install --upgrade truss
truss init llama-3-1-8b-trt-llm
cd llama-3-1-8b-trt-llm
rm model/model.py
```
## Configuration
This is a well-rounded configuration that balances latency and throughput. It supports long sequence lengths for multi-step chat and has a batch size of 32 as that's reasonable for an eight-billion-parameter model on an A100 GPU. The model is served unquantized in `fp16`.
```yaml config.yaml
model_name: Llama 3.1 8B Engine
resources:
  accelerator: A100
secrets:
  hf_access_token: "set token in baseten workspace"
trt_llm:
  build:
    base_model: llama
    checkpoint_repository:
      repo: meta-llama/Llama-3.1-8B-Instruct
      source: HF
    max_seq_len: 8192
```
## Deployment
```sh
truss push --publish --trusted
```
## Usage
```python call_model.py
import requests
import os

# Model ID for production deployment
model_id = ""
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]

# Call model endpoint
resp = requests.post(
    f"https://model-{model_id}.api.baseten.co/production/predict",
    headers={"Authorization": f"Api-Key {baseten_api_key}"},
    json={
        "messages": [{"role": "user", "content": "How awesome is TensorRT-LLM?"}],
        "max_tokens": 1024
    },
    stream=True
)

# Print the generated tokens as they get streamed
for content in resp.iter_content():
    print(content.decode("utf-8"), end="", flush=True)
```
The input text prompt to guide the language model's generation.
One of `prompt` XOR `messages` is required.
A list of dictionaries representing the message history, typically used in conversational contexts.
One of `prompt` XOR `messages` is required.
The maximum number of tokens to generate in the output. Controls the length of the generated text.
The number of beams used in beam search. Maximum of `1`.
A penalty applied to repeated tokens to discourage the model from repeating the same words or phrases.
A penalty applied to tokens already present in the prompt to encourage the generation of new topics.
Controls the randomness of the output. Lower values make the output more deterministic, while higher values increase randomness.
A penalty applied to the length of the generated sequence to control verbosity. Higher values make the model favor shorter outputs.
The token ID that indicates the end of the generated sequence.
The token ID used for padding sequences to a uniform length.
Limits the sampling pool to the top `k` tokens, ensuring the model only considers the most likely tokens at each step.
Applies nucleus sampling to limit the sampling pool to a cumulative probability `p`, ensuring only the most likely tokens are considered.
# Mistral with TensorRT-LLM
Build an optimized inference engine for Mistral
This configuration builds an inference engine to serve Mistral 7B on an H100 GPU. It is very similar to the configuration for any other Mistral model, including fine-tuned variants.
## Setup
See the [end-to-end engine builder tutorial](/performance/engine-builder-tutorial) prerequisites for full setup instructions.
Make sure you have [accessed the gated model on Hugging Face](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2) and set your `hf_access_token` in your Baseten workspace secrets.
Please upgrade to the latest version of Truss with `pip install --upgrade truss` before following this example.
```sh
pip install --upgrade truss
truss init mistral-7b-trt-llm
cd mistral-7b-trt-llm
rm model/model.py
```
## Configuration
This configuration is optimized for low latency, with a batch size of 8. It applies post-training quantization to `fp8` for further speed gains.
```yaml config.yaml
model_name: Mistral Engine 7B v0.2 6
python_version: py39
resources:
  accelerator: H100:1
  use_gpu: true
secrets:
  hf_access_token: "set token in baseten workspace"
trt_llm:
  build:
    base_model: mistral
    checkpoint_repository:
      repo: mistralai/Mistral-7B-Instruct-v0.2
      source: HF
    max_input_len: 8192
    quantization_type: fp8
```
## Deployment
```sh
truss push --publish --trusted
```
## Usage
```python call_model.py
import requests
import os

# Model ID for production deployment
model_id = ""
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]

# Call model endpoint
resp = requests.post(
    f"https://model-{model_id}.api.baseten.co/production/predict",
    headers={"Authorization": f"Api-Key {baseten_api_key}"},
    json={
        "messages": [{"role": "user", "content": "How awesome is TensorRT-LLM?"}],
        "max_tokens": 1024
    },
    stream=True
)

# Print the generated tokens as they get streamed
for content in resp.iter_content():
    print(content.decode("utf-8"), end="", flush=True)
print(content.decode("utf-8"), end="", flush=True)
```
The input text prompt to guide the language model's generation.
One of `prompt` XOR `messages` is required.
A list of dictionaries representing the message history, typically used in conversational contexts.
One of `prompt` XOR `messages` is required.
The maximum number of tokens to generate in the output. Controls the length of the generated text.
The number of beams used in beam search. Maximum of `1`.
A penalty applied to repeated tokens to discourage the model from repeating the same words or phrases.
A penalty applied to tokens already present in the prompt to encourage the generation of new topics.
Controls the randomness of the output. Lower values make the output more deterministic, while higher values increase randomness.
A penalty applied to the length of the generated sequence to control verbosity. Higher values make the model favor shorter outputs.
The token ID that indicates the end of the generated sequence.
The token ID used for padding sequences to a uniform length.
Limits the sampling pool to the top `k` tokens, ensuring the model only considers the most likely tokens at each step.
Applies nucleus sampling to limit the sampling pool to a cumulative probability `p`, ensuring only the most likely tokens are considered.
# Qwen with TensorRT-LLM
Build an optimized inference engine for Qwen
This configuration builds an inference engine to serve Qwen 2.5 3B on an A10G GPU. It is very similar to the configuration for any other Qwen model, including fine-tuned variants.
Recommended basic GPU configurations for Qwen 2.5 sizes:
| Size and variant | FP16 unquantized | FP8 quantized |
| ------------------------- | ---------------- | ------------- |
| 3B (Instruct) | `A10G` | N/A |
| 7B (Instruct, Math, Code) | `H100_40GB` | N/A |
| 14B (Instruct) | `H100` | `H100_40GB` |
| 32B (Instruct) | `H100:2` | `H100` |
| 72B (Instruct, Math) | `H100:4` | `H100:2` |
If you use multiple GPUs, make sure to match `num_builder_gpus` and `tensor_parallel_count` in the config. When quantizing, you may need to double the number of builder GPUs.
## Setup
See the [end-to-end engine builder tutorial](/performance/engine-builder-tutorial) prerequisites for full setup instructions.
Please upgrade to the latest version of Truss with `pip install --upgrade truss` before following this example.
```sh
pip install --upgrade truss
mkdir qwen-engine
touch qwen-engine/config.yaml
```
## Configuration
This configuration file specifies model information and Engine Builder arguments. For a different Qwen model, change the `model_name`, `accelerator`, and `repo` fields, along with any changes to the `build` arguments.
```yaml config.yaml
model_name: Qwen 2.5 3B Instruct
resources:
  accelerator: A10G
  use_gpu: true
trt_llm:
  build:
    base_model: qwen
    checkpoint_repository:
      repo: Qwen/Qwen2.5-3B-Instruct
      source: HF
    max_seq_len: 8192
    num_builder_gpus: 1
    quantization_type: no_quant
    tensor_parallel_count: 1
```
## Deployment
```sh
truss push --publish
```
## Usage
```python call_model.py
import requests
import os

# Model ID for production deployment
model_id = ""
# Read secrets from environment variables
baseten_api_key = os.environ["BASETEN_API_KEY"]

# Call model endpoint
resp = requests.post(
    f"https://model-{model_id}.api.baseten.co/production/predict",
    headers={"Authorization": f"Api-Key {baseten_api_key}"},
    json={
        "messages": [
            {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."},
            {"role": "user", "content": "What does Tongyi Qianwen mean?"}
        ],
        "max_tokens": 512
    },
    stream=True
)

# Print the generated tokens as they get streamed
for content in resp.iter_content():
    print(content.decode("utf-8"), end="", flush=True)
```
The input text prompt to guide the language model's generation.
One of `prompt` XOR `messages` is required.
A list of dictionaries representing the message history, typically used in conversational contexts.
One of `prompt` XOR `messages` is required.
The maximum number of tokens to generate in the output. Controls the length of the generated text.
The number of beams used in beam search. Maximum of `1`.
A penalty applied to repeated tokens to discourage the model from repeating the same words or phrases.
A penalty applied to tokens already present in the prompt to encourage the generation of new topics.
Controls the randomness of the output. Lower values make the output more deterministic, while higher values increase randomness.
A penalty applied to the length of the generated sequence to control verbosity. Higher values make the model favor shorter outputs.
The token ID that indicates the end of the generated sequence.
The token ID used for padding sequences to a uniform length.
Limits the sampling pool to the top `k` tokens, ensuring the model only considers the most likely tokens at each step.
Applies nucleus sampling to limit the sampling pool to a cumulative probability `p`, ensuring only the most likely tokens are considered.
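To pass sampling parameters like these when calling the deployed engine, include them in the request body alongside `messages`. The sketch below extends the `call_model.py` example above; the exact field names (`temperature`, `top_p`, `top_k`, `repetition_penalty`) are assumptions based on the descriptions here, so confirm them against your deployment before relying on them.
```python
# Hedged sketch: same endpoint as call_model.py above, with assumed sampling fields added.
import requests
import os

model_id = ""  # your model ID
baseten_api_key = os.environ["BASETEN_API_KEY"]

resp = requests.post(
    f"https://model-{model_id}.api.baseten.co/production/predict",
    headers={"Authorization": f"Api-Key {baseten_api_key}"},
    json={
        "messages": [{"role": "user", "content": "What does Tongyi Qianwen mean?"}],
        "max_tokens": 512,
        # Field names below are assumptions based on the parameter list above
        "temperature": 0.7,
        "top_p": 0.95,
        "top_k": 40,
        "repetition_penalty": 1.0,
    },
    stream=True,
)

# Print the generated tokens as they stream in
for content in resp.iter_content():
    print(content.decode("utf-8"), end="", flush=True)
```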
# Instance type reference
Specs and recommendations for every instance type on Baseten
Choosing [the right resources for your model inference workload](/deploy/resources) requires carefully balancing performance and cost. This page lists every instance type currently available on Baseten to help you pick the best fit for serving your model.
## CPU-only instance reference
Instances with no GPU start at \$0.00058 per minute.
**Available instance types**
| Instance | Cost/minute | vCPU | RAM |
| -------- | ----------- | ---- | ------ |
| 1×2 | \$0.00058 | 1 | 2 GiB |
| 1×4 | \$0.00086 | 1 | 4 GiB |
| 2×8 | \$0.00173 | 2 | 8 GiB |
| 4×16 | \$0.00346 | 4 | 16 GiB |
| 8×32 | \$0.00691 | 8 | 32 GiB |
| 16×64 | \$0.01382 | 16 | 64 GiB |
**What can it run?**
CPU-only instances are a cost-effective way to run inference on a variety of models like:
* Many `transformers` pipeline models, such as the [Text classification pipeline from the Truss quickstart](/quickstart), run well on the smallest instance, the 1×2.
* Smaller extractive question answering models like [LayoutLM Document QA](https://www.baseten.co/library/layoutlm-document-qa/) run well on the midsize 4×16 instance.
* Many text embeddings models, like [this sentence transformers model](https://www.baseten.co/library/all-minilm-l6-v2/), don't need a GPU to run. Pick a 4×16 or larger instance for best performance, especially when creating embeddings for a larger corpus of text.
## GPU instance reference
**Available instance types**
| Instance | Cost/minute | vCPU | RAM | GPU | VRAM |
| ----------------- | ----------- | ---- | -------- | ---------------------- | ------- |
| T4x4x16 | \$0.01052 | 4 | 16 GiB | NVIDIA T4 | 16 GiB |
| T4x8x32 | \$0.01504 | 8 | 32 GiB | NVIDIA T4 | 16 GiB |
| T4x16x64 | \$0.02408 | 16 | 64 GiB | NVIDIA T4 | 16 GiB |
| L4x4x16 | \$0.01414 | 4 | 16 GiB | NVIDIA L4 | 24 GiB |
| L4:2x4x16 | \$0.04002 | 24 | 96 GiB | 2 NVIDIA L4s | 48 GiB |
| L4:4x48x192 | \$0.08003 | 48 | 192 GiB | 4 NVIDIA L4s | 96 GiB |
| A10Gx4x16 | \$0.02012 | 4 | 16 GiB | NVIDIA A10G | 24 GiB |
| A10Gx8x32 | \$0.02424 | 8 | 32 GiB | NVIDIA A10G | 24 GiB |
| A10Gx16x64 | \$0.03248 | 16 | 64 GiB | NVIDIA A10G | 24 GiB |
| A10G:2x24x96 | \$0.05672 | 24 | 96 GiB | 2 NVIDIA A10Gs | 48 GiB |
| A10G:4x48x192 | \$0.11344 | 48 | 192 GiB | 4 NVIDIA A10Gs | 96 GiB |
| A10G:8x192x768 | \$0.32576 | 192 | 768 GiB | 8 NVIDIA A10Gs | 188 GiB |
| V100x8x61 | \$0.06120 | 16 | 61 GiB | NVIDIA V100 | 16 GiB |
| A100x12x144 | \$0.10240 | 12 | 144 GiB | 1 NVIDIA A100 | 80 GiB |
| A100:2x24x288 | \$0.20480 | 24 | 288 GiB | 2 NVIDIA A100s | 160 GiB |
| A100:3x36x432 | \$0.30720 | 36 | 432 GiB | 3 NVIDIA A100s | 240 GiB |
| A100:4x48x576 | \$0.40960 | 48 | 576 GiB | 4 NVIDIA A100s | 320 GiB |
| A100:5x60x720 | \$0.51200 | 60 | 720 GiB | 5 NVIDIA A100s | 400 GiB |
| A100:6x72x864 | \$0.61440 | 72 | 864 GiB | 6 NVIDIA A100s | 480 GiB |
| A100:7x84x1008 | \$0.71680 | 84 | 1008 GiB | 7 NVIDIA A100s | 560 GiB |
| A100:8x96x1152 | \$0.81920 | 96 | 1152 GiB | 8 NVIDIA A100s | 640 GiB |
| H100x26x234 | \$0.16640 | 26 | 234 GiB | 1 NVIDIA H100 | 80 GiB |
| H100:2x52x468 | \$0.33280 | 52 | 468 GiB | 2 NVIDIA H100s | 160 GiB |
| H100:4x104x936 | \$0.66560 | 104 | 936 GiB | 4 NVIDIA H100s | 320 GiB |
| H100:8x208x1872 | \$1.33120 | 208 | 1872 GiB | 8 NVIDIA H100s | 640 GiB |
| H100MIG:3gx13x117 | \$0.08250 | 13 | 117 GiB | Fractional NVIDIA H100 | 40 GiB |
### NVIDIA T4
Instances with an NVIDIA T4 GPU start at \$0.01052 per minute.
**GPU specs**
The T4 is a [Turing-series GPU](https://en.wikipedia.org/wiki/Turing_\(microarchitecture\)) with:
* 2,560 CUDA cores
* 320 Tensor cores
* 16 GiB VRAM
**Available instance types**
| Instance | Cost/minute | vCPU | RAM | GPU | VRAM |
| -------- | ----------- | ---- | ------ | --------- | ------ |
| T4x4x16 | \$0.01052 | 4 | 16 GiB | NVIDIA T4 | 16 GiB |
| T4x8x32 | \$0.01504 | 8 | 32 GiB | NVIDIA T4 | 16 GiB |
| T4x16x64 | \$0.02408 | 16 | 64 GiB | NVIDIA T4 | 16 GiB |
**What can it run?**
T4-equipped instances can run inference for models like:
* [Whisper](https://www.baseten.co/library/whisper-v3), transcribing 5 minutes of audio in 31.4 seconds with Whisper small.
* While the T4's 16 GiB of VRAM is insufficient for 7 billion parameter LLMs, it can run smaller 3B parameter models like [StableLM](https://github.com/basetenlabs/stablelm-truss).
### NVIDIA L4
The L4 is an [Ada Lovelace GPU](https://en.wikipedia.org/wiki/Ada_Lovelace_\(microarchitecture\)) with:
* 7,680 CUDA cores
* 240 Tensor cores
* 24 GiB VRAM
* 300 GiB/s Memory bandwidth
This enables the card to reach 121 teraFLOPS in fp16 operations, the most common quantization for large language models.
**Available instance types**
| Instance | Cost/minute | vCPU | RAM | GPU | VRAM |
| ----------- | ----------- | ---- | ------- | ------------ | ------ |
| L4x4x16 | \$0.01414 | 4 | 16 GiB | NVIDIA L4 | 24 GiB |
| L4:2x4x16 | \$0.04002 | 24 | 96 GiB | 2 NVIDIA L4s | 48 GiB |
| L4:4x48x192 | \$0.08003 | 48 | 192 GiB | 4 NVIDIA L4s | 96 GiB |
**What can it run?**
The L4 is a great choice for running inference on models like Stable Diffusion XL, but not for LLMs due to its limited memory bandwidth.
### NVIDIA A10G
Instances with the NVIDIA A10G GPU start at \$0.02012 per minute.
**GPU specs**
The A10G is an [Ampere-series GPU](https://en.wikipedia.org/wiki/Ampere_\(microarchitecture\)) with:
* 9,216 CUDA cores
* 288 Tensor cores
* 24 GiB VRAM
* 600 GiB/s Memory bandwidth
This enables the card to reach 70 teraFLOPS in fp16 operations, the most common quantization for large language models.
**Available instance types**
| Instance | Cost/minute | vCPU | RAM | GPU | VRAM |
| -------------- | ----------- | ---- | ------- | -------------- | ------- |
| A10Gx4x16 | \$0.02012 | 4 | 16 GiB | NVIDIA A10G | 24 GiB |
| A10Gx8x32 | \$0.02424 | 8 | 32 GiB | NVIDIA A10G | 24 GiB |
| A10Gx16x64 | \$0.03248 | 16 | 64 GiB | NVIDIA A10G | 24 GiB |
| A10G:2x24x96 | \$0.05672 | 24 | 96 GiB | 2 NVIDIA A10Gs | 48 GiB |
| A10G:4x48x192 | \$0.11344 | 48 | 192 GiB | 4 NVIDIA A10Gs | 96 GiB |
| A10G:8x192x768 | \$0.32576 | 192 | 768 GiB | 8 NVIDIA A10Gs | 188 GiB |
**What can it run?**
Single A10Gs are great for running 7 billion parameter LLMs, and multi-A10G instances can work together to run larger models.
A10G-equipped instances can run inference for models like:
* Most 7-billion-parameter LLMs, such as [Mistral 7B](https://www.baseten.co/library/mistral-7b-instruct), at float16 precision.
* [Stable Diffusion](https://www.baseten.co/library/stable-diffusion) in 1.77 seconds for 50 steps and [Stable Diffusion XL](https://www.baseten.co/library/stable-diffusion-xl) in 6 seconds for 20 steps.
* [Whisper](https://www.baseten.co/library/whisper-v3), transcribing 5 minutes of audio in 23.9 seconds with Whisper small.
### NVIDIA V100
Instances with the NVIDIA V100 GPU start at \$0.06120 per minute.
**GPU specs**
The V100 is a [Volta-series GPU](https://en.wikipedia.org/wiki/Volta_\(microarchitecture\)) with 16 GiB of VRAM.
**Available instance types**
| Instance | Cost/minute | vCPU | RAM | GPU | VRAM |
| --------- | ----------- | ---- | ------ | ----------- | ------ |
| V100x8x61 | \$0.06120 | 16 | 61 GiB | NVIDIA V100 | 16 GiB |
### NVIDIA A100
Instances with the NVIDIA A100 GPU start at \$0.10240 per minute.
**GPU specs**
The A100 is an [Ampere-series GPU](https://en.wikipedia.org/wiki/Ampere_\(microarchitecture\)) with:
* 6,912 CUDA cores
* 432 Tensor cores
* 80 GiB VRAM
* 1,935 GiB/s Memory bandwidth
This enables the card to reach 312 teraFLOPS in fp16 operations, the most common quantization for large language models.
**Available instance types**
| Instance | Cost/minute | vCPU | RAM | GPU | VRAM |
| -------------- | ----------- | ---- | -------- | -------------- | ------- |
| A100x12x144 | \$0.10240 | 12 | 144 GiB | 1 NVIDIA A100 | 80 GiB |
| A100:2x24x288 | \$0.20480 | 24 | 288 GiB | 2 NVIDIA A100s | 160 GiB |
| A100:3x36x432 | \$0.30720 | 36 | 432 GiB | 3 NVIDIA A100s | 240 GiB |
| A100:4x48x576 | \$0.40960 | 48 | 576 GiB | 4 NVIDIA A100s | 320 GiB |
| A100:5x60x720 | \$0.51200 | 60 | 720 GiB | 5 NVIDIA A100s | 400 GiB |
| A100:6x72x864 | \$0.61440 | 72 | 864 GiB | 6 NVIDIA A100s | 480 GiB |
| A100:7x84x1008 | \$0.71680 | 84 | 1008 GiB | 7 NVIDIA A100s | 560 GiB |
| A100:8x96x1152 | \$0.81920 | 96 | 1152 GiB | 8 NVIDIA A100s | 640 GiB |
**What can it run?**
A100s are the second most powerful GPUs currently available on Baseten, behind only the H100. They're great for large language models, high-performance image generation, and other demanding tasks.
A100-equipped instances can run inference for models like:
* [Mixtral 8x7B](https://www.baseten.co/library/mixtral-8x7b/) on a single A100 in int8 precision.
* [Stable Diffusion](https://www.baseten.co/library/stable-diffusion) in 0.89 seconds for 50 steps and [Stable Diffusion XL](https://www.baseten.co/library/stable-diffusion-xl) in 1.92 seconds for 20 steps (with `torch.compile` and max autotune).
* Most 70-billion-parameter LLMs, such as [Llama-2-chat 70B](https://github.com/basetenlabs/truss-examples/tree/main/model_library/llama-2-70b-chat), in fp16 precision, on 2 A100s.
* The 180-billion-parameter LLM [Falcon 180B](https://huggingface.co/tiiuae/falcon-180B), in fp16 precision, on 5 A100s.
### NVIDIA H100
Instances with the NVIDIA H100 GPU start at \$0.16640 per minute.
**GPU specs**
The H100 is a [Hopper-series GPU](https://en.wikipedia.org/wiki/Hopper_\(microarchitecture\)) with:
* 16,896 CUDA cores
* 640 Tensor cores
* 80 GiB VRAM
* 3.35 TB/s Memory bandwidth
This enables the card to reach 990 teraFLOPS in fp16 operations, the most common quantization for large language models.
**Available instance types**
| Instance | Cost/minute | vCPU | RAM | GPU | VRAM |
| --------------- | ----------- | ---- | -------- | -------------- | ------- |
| H100x26x234 | \$0.16640 | 26 | 234 GiB | 1 NVIDIA H100 | 80 GiB |
| H100:2x52x468 | \$0.33280 | 52 | 468 GiB | 2 NVIDIA H100s | 160 GiB |
| H100:4x104x936 | \$0.66560 | 104 | 936 GiB | 4 NVIDIA H100s | 320 GiB |
| H100:8x208x1872 | \$1.33120 | 208 | 1872 GiB | 8 NVIDIA H100s | 640 GiB |
**What can it run?**
H100s are the most powerful GPUs currently available on Baseten. They're great for large language models, high-performance image generation, and other demanding tasks.
H100-equipped instances can run inference for models like:
* [Mixtral 8x7B](https://www.baseten.co/library/mixtral-8x7b/) on a single H100 in fp16 precision.
* 20 steps of [Stable Diffusion XL](https://www.baseten.co/library/stable-diffusion-xl) in 1.31 seconds.
* Most 70-billion-parameter LLMs, such as [Llama-2-chat 70B](https://github.com/basetenlabs/truss-examples/tree/main/model_library/llama-2-70b-chat), in fp16 precision, on 2 H100s.
### NVIDIA H100mig
Instances with the NVIDIA H100mig GPU start at \$0.08250 per minute.
**GPU specs**
The H100mig family of instances runs on a fractional share of an [H100 GPU](/performance/instances#nvidia-h100) using NVIDIA's [Multi-Instance GPU](https://www.nvidia.com/en-us/technologies/multi-instance-gpu/) (MIG) virtualization technology. Currently, we support a single instance type, `H100MIG:3gx13x117`, with access to 1/2 the memory and 3/7 the compute of a full H100. This results in:
* 7,242 CUDA cores
* 40 GiB VRAM
* 1.675 TB/s Memory bandwidth
**Available instance types**
| Instance | Cost/minute | vCPU | RAM | GPU | VRAM |
| ----------------- | ----------- | ---- | ------- | ---------------------- | ------ |
| H100MIG:3gx13x117 | \$0.08250 | 13 | 117 GiB | Fractional NVIDIA H100 | 40 GiB |
**What can it run?**
H100mig provides access to the same state-of-the-art AI inference architecture as the H100 in a smaller package. Based on our benchmarks, it can achieve higher throughput than a single A100 GPU at a lower cost per minute.
# Model performance overview
Improve your latency and throughput
Model performance means optimizing every layer of your model serving infrastructure to balance four goals:
1. **Latency**: on a per-request basis, how quickly does each user get output from the model?
2. **Throughput**: how many requests or users can the deployment handle at once?
3. **Cost**: how much does a standardized unit of work (e.g. 1M tokens from an LLM) cost?
4. **Quality**: does your model consistently deliver high-quality output after optimization?
## Model performance tooling
Baseten's TensorRT-LLM engine builder simplifies and automates the process of using TensorRT-LLM for development and production.
## Full-stack model performance
### Model and GPU selection
Two of the highest-impact choices for model performance come before the optimization process: picking the best model size and implementation and picking the right GPU to run it on.
*Tradeoff: Latency/Throughput/Cost vs Quality*
The biggest factor in your latency, throughput, cost, and quality is what model you use. Before you jump into optimizing a foundation model, consider:
* Can you use a smaller size, like Llama 8B instead of 70B? Can you fine-tune the smaller model for your use case?
* Can you use a different model, like [SDXL Lightning](https://www.baseten.co/library/sdxl-lightning/) instead of SDXL?
* Can you use a different implementation, like [Faster Whisper](https://github.com/basetenlabs/truss-examples/tree/main/whisper/faster-whisper-v3) instead of Whisper?
Usually, model selection is bound by quality. For example, SDXL Lightning generates images incredibly quickly, but they may not be detailed enough for your use case.
Experiment with alternative models to see if they can reset your performance expectations while meeting your quality bar.
*Tradeoff: Latency/Throughput vs Cost*
The minimum requirement for a GPU instance is that it must have enough VRAM to load model weights with headroom left for inference.
It often makes sense to use a more powerful (but more expensive) GPU than the minimum requirement, especially if you have ambitious latency goals and/or high utilization.
For example, you might choose:
* (Multiple) H100 GPUs for [deployments optimized with TensorRT/TensorRT-LLM](https://www.baseten.co/blog/unlocking-the-full-power-of-nvidia-h100-gpus-for-ml-inference-with-tensorrt/)
* H100 MIGs for [high-throughput deployments of smaller models like Llama 3 8B and SDXL](https://www.baseten.co/blog/using-fractional-h100-gpus-for-efficient-model-serving/)
* L4 GPUs for autoscaling Whisper deployments
The [GPU instance reference](/performance/instances) lists all available options.
### GPU-level optimizations
Our first goal is to get the best possible performance out of a single GPU or GPU cluster.
*Benefit: Latency/Throughput/Cost*
You can just use `transformers` and `pytorch` out of the box to serve your model. But best-in-class performance comes from using a dedicated inference engine, like:
1. [TensorRT](https://developer.nvidia.com/tensorrt)/[TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM), maintained by NVIDIA
2. [vLLM](https://github.com/vllm-project/vllm), an independent open source project
3. [TGI](https://github.com/huggingface/text-generation-inference), maintained by Hugging Face
We [often recommend TensorRT/TensorRT-LLM](https://www.baseten.co/blog/high-performance-ml-inference-with-nvidia-tensorrt/) for best performance. The easiest way to get started with TensorRT-LLM is our [TRT-LLM engine builder](/performance/engine-builder-overview).
*Benefit: Latency/Throughput*
In addition to an optimized inference engine, you need an inference server to handle requests and supply features like in-flight batching.
Baseten runs a modified version of Triton for compatible model deployments. Other models use `TrussServer`, a capable general-purpose model inference server built into Truss.
*Tradeoff: Latency/Throughput/Cost vs Quality*
By default, model inference happens in `fp16`, meaning that model weights and other values are represented as 16-bit floating-point numbers.
Through a process called [post-training quantization](https://www.baseten.co/blog/fp8-efficient-model-inference-with-8-bit-floating-point-numbers/), you can instead run inference in a different format, like `fp8`, `int8`, or `int4`. This has massive benefits: more teraFLOPS at lower precision means lower latency, smaller numbers being retrieved from VRAM means higher throughput, and smaller model weights means saving on cost and potentially using fewer GPUs.
However, quantization can affect output quality. Thoroughly review quantized model outputs by hand and with standard checks like perplexity to ensure that the output of the quantized model matches the original.
We've had a lot of success with [fp8 for faster inference without quality loss](https://www.baseten.co/blog/33-faster-llm-inference-with-fp8-quantization/) and encourage experimenting with quantization, especially when using the TRT-LLM engine builder.
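As one concrete example of such a check, the sketch below computes perplexity for a causal LM with `transformers`. The model name and evaluation text are placeholders; in practice you would run the same evaluation set through both the baseline and the quantized model and compare the scores.
```python
# Minimal perplexity check (sketch). "gpt2" and the sample text are placeholders.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "gpt2"
text = "Baseten serves machine learning models in production."

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
model.eval()

inputs = tokenizer(text, return_tensors="pt")
with torch.no_grad():
    # Passing labels makes the model return the mean cross-entropy loss over the sequence
    outputs = model(**inputs, labels=inputs["input_ids"])

perplexity = torch.exp(outputs.loss).item()
print(f"Perplexity: {perplexity:.2f}")
```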
*Tradeoff: Latency/Throughput/Cost vs Quality*
There are a number of exciting cutting-edge techniques for model inference that can massively improve latency and/or throughput for a model. For example, LLMs can use speculative decoding or Medusa to generate multiple tokens per forward pass, improving tokens per second (TPS).
When using a new technique to improve model performance, always run real-world benchmarks and carefully validate output quality to ensure the performance improvements aren't undermining the model's usefulness.
*Tradeoff: Latency vs Throughput/Cost*
Batch size is the number of requests processed concurrently on the GPU. It creates a direct tradeoff between latency and throughput:
* Increase batch size to improve throughput and cost
* Reduce batch size to improve latency
### Infrastructure-level optimizations
Once we squeeze as much TPS as possible out of the GPU, we scale that out horizontally with infrastructure optimization.
*Tradeoff: Latency/Throughput vs Cost*
If traffic to a deployment is high enough, even an optimized model server won't be able to keep up. By creating replicas, you keep latency consistent for all users.
Learn more about [autoscaling model replicas](/deploy/autoscaling).
*Benefit: Latency*
A "cold start" is the time it takes to spin up a new instance of a model server. Fast cold starts are essential for useful autoscaling, especially scale to zero.
Read our [guide to improving cold start times](/performance/cold-starts) for options like caching model weights.
*Tradeoff: Latency vs Throughput/Cost*
Replica-level concurrency sets the number of requests that can be sent to the model server at one time. This is different from on-GPU concurrency, as your model server may perform pre- and post-processing tasks on the CPU.
Replica-level concurrency should always be greater than or equal to on-device concurrency (batch size).
*Tradeoff: Latency vs Cost*
If your GPU is in us-east-1 and your customer is in Australia, it doesn't matter how much you've optimized time to first token (TTFT): your real-world latency will be terrible.
Region-specific deployments are available on a per-customer basis. Contact us at [support@baseten.co](mailto:support@baseten.co) to discuss your needs.
### Application-level optimizations
There are also application-level steps that you can take to make sure you're getting the most value from your optimized endpoint.
*Benefits: Latency, Quality*
Every token an LLM doesn't have to process or generate is a token that you don't have to wait for or pay for.
Prompt engineering can be as simple as saying "be concise" or as complex as making sure your RAG system returns the minimum number of highly-relevant retrievals.
*Benefits: Latency, Throughput*
When using TensorRT-LLM, make sure that your input and output sequences are a consistent length. The inference engine is built for a specific number of tokens, and going outside of those sequence shapes will hurt performance.
*Benefits: Latency, Cost*
The only thing running on your GPU should be the AI model. Other tasks like retrievals, secondary models, and business logic should be deployed and scaled separately to avoid bottlenecks.
Use [Chains](/chains/overview) for performant multi-step and multi-model inference.
*Benefit: Latency*
Use sessions rather than individual requests to avoid unnecessary network latency. See [inference documentation](/invoke/quickstart) for details.
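For example, with the Python `requests` library, a `requests.Session` reuses the underlying connection across calls instead of opening a new one per request. A minimal sketch (the model ID and payload are placeholders):
```python
# Reuse one HTTP connection across multiple predictions with requests.Session (sketch).
import os
import requests

model_id = ""  # your model ID
session = requests.Session()
session.headers.update({"Authorization": f"Api-Key {os.environ['BASETEN_API_KEY']}"})

for prompt in ["First question", "Second question"]:
    # Each call reuses the session's existing TCP/TLS connection, skipping repeated handshakes
    resp = session.post(
        f"https://model-{model_id}.api.baseten.co/production/predict",
        json={"messages": [{"role": "user", "content": prompt}]},
    )
    print(resp.json())
```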
# Deploy your first model
From model weights to API endpoint
In this guide, you will package and deploy [Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct), a 3.8-billion-parameter large language model.
We'll cover:
1. Loading model weights from Hugging Face
2. Running model inference on a GPU
3. Configuring your infrastructure and Python environment
4. Iterating on your model server in a live reload development environment
5. Deploying your finished model serving instance for production use
By the end of this tutorial, you will have built a production-ready API endpoint for an open source LLM on autoscaling infrastructure.
This tutorial is a comprehensive introduction to deploying models from scratch. If you want to quickly deploy an off-the-shelf model, start with our [model library](https://www.baseten.co/library) and [Truss examples](https://github.com/basetenlabs/truss-examples).
## Setup
Before we dive into the code:
* [Sign up](https://app.baseten.co/signup) for or [sign in](https://app.baseten.co/login) to your Baseten account.
* Generate an [API key](https://app.baseten.co/settings/account/api_keys) and store it securely.
* Install [Truss](https://pypi.org/project/truss/), our open-source model packaging framework.
```sh
pip install --upgrade truss
```
New Baseten accounts come with free credits to experiment with model inference. Completing this tutorial should consume less than a dollar of GPU resources.
### What is Truss?
Truss is a framework for writing model serving code in Python and configuring the model's production environment without touching Docker. It also includes a CLI to power a robust developer experience that will be introduced shortly.
A Truss contains:
* A file `model.py` where the `Model` class is implemented as a serving interface for an AI model.
* A file `config.yaml` that specifies GPU resources, Python environment, metadata, and more.
* Optional folders for bundling model weights (`data/`) and custom dependencies (`packages/`).
Truss is designed to map directly from model development code to production-ready model serving code.
## Create a Truss
To get started, create a Truss with the following terminal command:
```sh
truss init phi-3-mini
```
When prompted, give your Truss a name like `Phi 3 Mini`.
Then, navigate to the newly created directory:
```sh
cd phi-3-mini
```
You should see the following file structure:
```
phi-3-mini/
data/
model/
__init__.py
model.py
packages/
config.yaml
```
For this tutorial, we will be editing `model/model.py` and `config.yaml`.
### Load model weights
[Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct) is an open source LLM available for download on Hugging Face. We'll access its model weights via the `transformers` library.
Two functions in the `Model` object, `__init__()` and `load()`, run exactly once when the model server is spun up or patched. Using these functions, we load model weights and anything else the model server needs for inference.
For Phi 3, we need to load the LLM and its tokenizer. After initializing the necessary instance attributes, we load the weights and tokenizer from Hugging Face:
```python model/model.py
# We'll bundle these packages with our Truss in a future step
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer
)
class Model:
def __init__(self, **kwargs):
self._model = None
self._tokenizer = None
def load(self):
self._model = AutoModelForCausalLM.from_pretrained(
"microsoft/Phi-3-mini-4k-instruct", # Loads model from Hugging Face
device_map="cuda",
torch_dtype="auto"
)
self._tokenizer = AutoTokenizer.from_pretrained(
"microsoft/Phi-3-mini-4k-instruct"
)
```
### Run model inference
The final required function in the `Model` class, `predict()`, runs each time the model endpoint is requested. The `predict()` function handles model inference.
The implementation for `predict()` determines what features your model endpoint supports. You can implement anything from streaming to support for specific input and output specs:
```python model/model.py
class Model:
...
def predict(self, request):
messages = request.pop("messages")
model_inputs = self._tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
inputs = self._tokenizer(model_inputs, return_tensors="pt")
input_ids = inputs["input_ids"].to("cuda")
with torch.no_grad():
outputs = self._model.generate(input_ids=input_ids, max_length=256)
output_text = self._tokenizer.decode(outputs[0], skip_special_tokens=True)
return {"output": output_text}
```
### Set Python environment
Now that the model server is implemented, we need to give it an environment to run in. In `model/model.py`, we imported `torch` and a couple of objects from `transformers`:
```python model/model.py
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer
)
```
To add `transformers`, `torch`, and other required packages to our Python environment, we move to `config.yaml`, the other essential file in every Truss. Here, you can set your Python requirements:
```yaml config.yaml
requirements:
- accelerate==0.30.1
- einops==0.8.0
- transformers==4.41.2
- torch==2.3.0
```
We strongly recommend pinning versions for every Python requirement. The AI/ML ecosystem moves fast, and breaking changes to unpinned dependencies can cause errors in production.
### Select a GPU
Picking the right GPU is a balance between performance and cost. First, consider the size of the model weights. A good rule of thumb is that for `float16` LLM inference, you need 2 GB of VRAM on your GPU for every billion parameters in the model, plus overhead for processing requests.
Phi 3 Mini has 3.8 billion parameters, meaning that it needs 7.6 GB of VRAM just to load model weights. An NVIDIA T4 GPU, the smallest and least expensive GPU available on Baseten, has 16 GiB of VRAM, which will be more than enough to run the model.
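As a quick back-of-the-envelope check, the rule of thumb works out as follows (a sketch, not an exact measurement):
```python
# Rough VRAM estimate for float16 inference: ~2 GB per billion parameters, plus overhead.
params_billions = 3.8   # Phi-3-mini-4k-instruct
bytes_per_param = 2     # float16
weights_gb = params_billions * bytes_per_param
print(f"~{weights_gb:.1f} GB of VRAM just for the weights")  # ~7.6 GB, well within a T4's 16 GiB
```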
To use a T4 in your Truss, update the `resources` in `config.yaml`:
```yaml config.yaml
resources:
accelerator: T4
use_gpu: true
```
Here's a [list of supported GPUs](/truss-reference/config#resources-accelerator).
## Create a development deployment
With the implementation finished, it's time to test the packaged model. With Baseten, you can spin up a development deployment, which replicates a production environment but with a live reload system that lets you patch your running model and test changes in seconds.
### Get your API key
Retrieve your Baseten API key or, if necessary, [create one from your workspace](https://app.baseten.co/settings/account/api_keys).
To use your API key for model inference, we recommend storing it as an environment variable:
```sh
export BASETEN_API_KEY=
```
Add this line to your `~/.zshrc` or similar shell config file.
The first time you run `truss push`, you'll be asked to paste in an API key.
### Run `truss push`
To create a development deployment for your model, run the following command in your `phi-3-mini` working directory:
```sh
truss push
```
You can monitor your model deployment from [your model dashboard on Baseten](https://app.baseten.co/models/).
### Call the development deployment
Your model deployment will go through three stages:
1. Building the model serving environment (creating a Docker container for model serving)
2. Deploying the model to the model serving environment (provisioning GPU resources and installing the image)
3. Loading the model onto the model server (running the `load()` function)
After deployment is complete, the model will show as "active" in your workspace. You can call the model with:
```python
import requests
import os
model_id = "" # Paste your model ID from your Baseten dashboard
baseten_api_key = os.environ["BASETEN_API_KEY"]
resp = requests.post(
f"https://model-{model_id}.api.baseten.co/development/predict",
headers={"Authorization": f"Api-Key {baseten_api_key}"},
json={"messages": [{"role": "user", "content": "What even is AGI?"}]}
)
print(resp.json())
```
## Live reload development environment
Even with Baseten's optimized infrastructure, deploying a model from scratch takes time. If you had to wait for the image to build, the GPU to be provisioned, and the model environment to be loaded every time you made a change, testing your code would be a frustrating and slow developer experience.
Instead, the development environment has live reload. This way, when you make changes to your model, you skip the first two steps of deployment and only need to wait for `load()` to run, cutting your dev loop from minutes to seconds.
To activate live reload, in your working directory, run:
```
truss watch
```
Now, when you make changes to your `model/model.py` or certain parts of your `config.yaml` (such as Python requirements), your changes will be patched onto your running model server.
### Implementation: generation configs
Let's implement a few more features into our model object to experience the live reload workflow.
Currently, we only support passing the messages to the model. But LLMs have a number of other parameters like `max_length` and `temperature` that matter during inference.
To set these appropriately, we'll use the `preprocess()` function in the `Model` object. Truss models have optional `preprocess()` and `postprocess()` functions, which run on the CPU on either side of `predict()`, which runs on the GPU.
Add the following function to your Truss:
```python model.py
class Model:
...
def preprocess(self, request):
terminators = [
self._tokenizer.eos_token_id,
self._tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]
generate_args = {
"max_length": request.get("max_tokens", 512),
"temperature": request.get("temperature", 1.0),
"top_p": request.get("top_p", 0.95),
"top_k": request.get("top_k", 40),
"repetition_penalty": request.get("repetition_penalty", 1.0),
"no_repeat_ngram_size": request.get("no_repeat_ngram_size", 0),
"do_sample": request.get("do_sample", True),
"use_cache": True,
"eos_token_id": terminators,
"pad_token_id": self._tokenizer.pad_token_id,
}
request["generate_args"] = generate_args
return request
```
To use the generation args, we'll modify our `predict()` function as follows:
```diff model.py
class Model:
...
def predict(self, request):
messages = request.pop("messages")
+ generation_args = request.pop("generate_args")
model_inputs = self._tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
inputs = self._tokenizer(model_inputs, return_tensors="pt")
input_ids = inputs["input_ids"].to("cuda")
with torch.no_grad():
- outputs = self._model.generate(input_ids=input_ids, max_length=256)
+ outputs = self._model.generate(input_ids=input_ids, **generation_args)
return self._tokenizer.decode(outputs[0], skip_special_tokens=True)
```
Save your `model/model.py` file and check your `truss watch` logs to see the patch being applied. Once the model status on your model dashboard shows as "active", you can call the API endpoint again with new parameters:
```python
import requests
import os
model_id = "" # Paste your model ID from your Baseten dashboard
baseten_api_key = os.environ["BASETEN_API_KEY"]
resp = requests.post(
f"https://model-{model_id}.api.baseten.co/development/predict",
headers={"Authorization": f"Api-Key {baseten_api_key}"},
json={
"messages": [{"role": "user", "content": "What even is AGI?"}],
"max_tokens": 512,
"temperature": 2.0
}
)
print(resp.json())
```
### Implementation: streaming output
Right now, the model works by returning the entire output at once. For many use cases, we'd rather stream model output, receiving the tokens as they are generated to reduce user-facing latency.
This requires updates to the imports at the top of `model/model.py`:
```diff model.py
+from threading import Thread
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
+ GenerationConfig,
+ TextIteratorStreamer,
)
```
We implement streaming in `model/model.py` by defining a function to handle it:
```python model.py
class Model:
...
def stream(self, input_ids: list, generation_args: dict):
streamer = TextIteratorStreamer(self._tokenizer)
generation_config = GenerationConfig(**generation_args)
generation_kwargs = {
"input_ids": input_ids,
"generation_config": generation_config,
"return_dict_in_generate": True,
"output_scores": True,
"max_new_tokens": generation_args["max_length"],
"streamer": streamer,
}
with torch.no_grad():
# Begin generation in a separate thread
thread = Thread(target=self._model.generate, kwargs=generation_kwargs)
thread.start()
# Yield generated text as it becomes available
def inner():
for text in streamer:
yield text
thread.join()
return inner()
```
Then in `predict()`, we enable streaming:
```diff model.py
class Model:
...
def predict(self, request):
messages = request.pop("messages")
generation_args = request.pop("generate_args")
+ stream = request.pop("stream", True)
model_inputs = self._tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
inputs = self._tokenizer(model_inputs, return_tensors="pt")
input_ids = inputs["input_ids"].to("cuda")
+ if stream:
+ return self.stream(input_ids, generation_args)
with torch.no_grad():
outputs = self._model.generate(input_ids=input_ids, **generation_args)
return self._tokenizer.decode(outputs[0], skip_special_tokens=True)
```
To call the streaming endpoint, update your API call to process the streaming output:
```python
import requests
import os
# Replace the empty string with your model id below
model_id = ""
baseten_api_key = os.environ["BASETEN_API_KEY"]
# Call model endpoint
resp = requests.post(
f"https://model-{model_id}.api.baseten.co/development/predict",
headers={"Authorization": f"Api-Key {baseten_api_key}"},
json={
"messages": [{"role": "user", "content": "What even is AGI?"}],
"stream": True,
"max_tokens": 256
},
stream=True
)
# Print the generated tokens as they get streamed
for content in resp.iter_content():
print(content.decode("utf-8"), end="", flush=True)
```
## Promote to production
Now that we're happy with how our model is implemented, we can promote our deployment to production. Production deployments don't have live reload, but are suitable for real traffic as they have access to full autoscaling settings and can't be interrupted by patches or other deployment activities.
You can promote your deployment to production through the Baseten UI or by running:
```sh
truss push --publish
```
When a development deployment is promoted to production, it gets rebuilt and deployed.
## Call the production endpoint
When the deployment is running in production, the API endpoint for calling it changes from `/development/predict` to `/production/predict`. All other inference code remains unchanged:
```python
import requests
import os
# Replace the empty string with your model id below
model_id = ""
baseten_api_key = os.environ["BASETEN_API_KEY"]
# Call model endpoint
resp = requests.post(
f"https://model-{model_id}.api.baseten.co/production/predict",
headers={"Authorization": f"Api-Key {baseten_api_key}"},
json={
"messages": [{"role": "user", "content": "What even is AGI?"}],
"stream": True,
"max_tokens": 256
},
stream=True
)
# Print the generated tokens as they get streamed
for content in resp.iter_content():
print(content.decode("utf-8"), end="", flush=True)
```
Both your development and production deployments will scale to zero when not in use.
## Learn more
You've completed the quickstart by packaging, deploying, and invoking an AI model with Truss!
From here, you may be interested in:
* Learning more about [model serving with Truss](/truss/overview).
* [Example implementations](https://github.com/basetenlabs/truss-examples) for dozens of open source models.
* [Inference examples](/invoke/quickstart) and [Baseten integrations](/invoke/integrations).
* Using [autoscaling settings](/deploy/autoscaling) to spin up and down multiple GPU replicas.
# truss
The simplest way to serve models in production
```Usage
truss [OPTIONS] COMMAND [ARGS]...
```
### Options
Show the version and exit.
Show help message and exit.
### Main usage
Authenticate with Baseten.
Create a new Truss.
Pushes a Truss to a TrussRemote.
Seamless remote development with Truss.
Invokes the packaged model.
### Advanced usage
Runs a Python script in the same environment as your Truss.
Subcommands for `truss image`.
Subcommands for `truss container`.
Clean up Truss data.
# truss cleanup
Clean up truss data.
```
truss cleanup [OPTIONS]
```
Truss creates temporary directories for various operations, such as building Docker images. This command clears that data to free up disk space.
### Options
Show help message and exit.
# truss container
Subcommands for truss container.
```
truss container [OPTIONS] COMMAND [ARGS]...
```
### Options
Show help message and exit.
## truss container kill
Kills containers related to Truss.
```
truss container kill [OPTIONS] [TARGET_DIRECTORY]
```
### Options
Show help message and exit.
### Arguments
A Truss directory. If none, use current directory.
## truss container kill-all
Kills all Truss containers that are not manually persisted.
```
truss container kill-all [OPTIONS]
```
### Options
Show help message and exit.
## truss container logs
Get logs from the container that is running for a Truss.
```
truss container logs [OPTIONS] [TARGET_DIRECTORY]
```
### Options
Show help message and exit.
### Arguments
A Truss directory. If none, use current directory.
# truss image
Subcommands for truss image.
```
truss image [OPTIONS] COMMAND [ARGS]...
```
### Options
Show help message and exit.
## truss image build
Builds the docker image for a Truss.
```
truss image build [OPTIONS] [TARGET_DIRECTORY] [BUILD_DIR]
```
### Options
Docker image tag.
Show help message and exit.
### Arguments
A Truss directory. If none, use current directory.
Image context. If none, a temp directory is created.
## truss image build-context
Create a docker build context for a Truss.
```
truss image build-context [OPTIONS] BUILD_DIR [TARGET_DIRECTORY]
```
### Options
Show help message and exit.
### Arguments
Folder where image context is built for Truss.
A Truss directory. If none, use current directory.
## truss image run
Runs the docker image for a Truss.
```
truss image run [OPTIONS] [TARGET_DIRECTORY] [BUILD_DIR]
```
### Options
Docker build image tag.
Local port used to run image.
Flag for attaching the process.
Show help message and exit.
### Arguments
A Truss directory. If none, use current directory.
Image context. If none, a temp directory is created.
# truss init
Create a new Truss.
```
truss init [OPTIONS] TARGET_DIRECTORY
```
## Options
What type of server to create. Default: `TrussServer`.
The value assigned to `model_name` in `config.yaml`.
Show help message and exit.
## Arguments
A Truss is created in this directory.
## Example
```
truss init my_truss_directory
```
```
truss init --name "My Truss" my_truss_directory
```
# truss login
Authenticate with Baseten.
```
truss login [OPTIONS]
```
Authenticates with Baseten, storing the API key in the local configuration file.
If used with no options, runs in interactive mode. Otherwise, the API key can be passed as an option.
### Options
Baseten API Key. If this is passed, the command runs in non-interactive mode.
# truss predict
Invokes the packaged model.
```
truss predict [OPTIONS]
```
## Options
Name of the remote in .trussrc to patch changes to.
A JSON-formatted string representing the request.
Path to a JSON file containing the request.
ID of model version to invoke.
ID of model to invoke.
Show help message and exit.
## Arguments
A Truss directory. If none, use current directory.
## Examples
```
truss predict -d '{"prompt": "What is the meaning of life?"}'
```
```
truss predict --published -f my-prompt.json
```
# truss push
Pushes a truss to a TrussRemote.
```Usage
truss push [OPTIONS] [TARGET_DIRECTORY]
```
## Options
Name of the remote in .trussrc to patch changes to.
Push the truss as a published deployment. If no production deployment exists, promote the truss to production after deploy completes.
Push the truss as a published deployment. Even if a production deployment exists, promote the truss to production after deploy completes.
Push the truss as a published deployment. Promote the truss into the environment after deploy completes.
Preserve the previous production deployment's autoscaling setting. When not specified, the previous production deployment will be updated to allow it to scale to zero. Can only be used in combination with the `--promote` option.
Give Truss access to secrets on remote host.
Name of the model.
Name of the deployment created by the push. Can only be used in combination with `--publish` or `--environment`. Deployment name must only contain alphanumeric, '.', '-' or '\_' characters.
Whether to wait for the deployment to complete before returning. If the deploy or build fails, the command will return with a non-zero exit code.
Maximum time to wait for the deployment to complete, in seconds. If not specified, the command will not return until the deployment is complete.
Show help message and exit.
## Arguments
A Truss directory. If none, use current directory.
## Examples
```
truss push
```
```
truss push --publish /path/to/my-truss
```
```
truss push --remote baseten --publish --trusted
```
```
truss push --remote baseten --publish --deployment-name my-truss_1.0
```
# truss run-python
Run a Python script in the same environment as your Truss.
```
truss run-python [OPTIONS] SCRIPT [TARGET_DIRECTORY]
```
Runs the selected script in the same environment as your Truss. It builds a Docker image matching your Truss environment, mounts the script you supply, and then runs the script.
### Options
Show help message and exit.
### Arguments
Path to Python script to run.
A Truss directory. If none, use current directory.
# truss watch
Seamless remote development with Truss.
```
truss watch [OPTIONS] [TARGET_DIRECTORY]
```
### Options
Name of the remote in .trussrc to patch changes to.
Automatically open remote logs tab.
Show help message and exit.
### Arguments
A Truss directory. If none, use current directory.
### Examples
```
truss watch
```
```
truss watch /path/to/my-truss
```
```
truss watch --remote baseten
```
# Config options
Set your model resources, dependencies, and more
Truss is configurable to its core. Every Truss must include a `config.yaml` file in its root directory, which is automatically generated when the Truss is created. However, configuration is optional. Every configurable value has a sensible default, and a completely empty config file is valid.
YAML syntax can be a bit non-obvious when dealing with empty lists and dictionaries. You may notice the following in the default Truss config file:
```yaml
requirements: []
secrets: {}
```
When you fill them in with values, lists and dictionaries should look like this:
```yaml
requirements:
- dep1
- dep2
secrets:
key1: default_value1
key2: default_value2
```
## Example
Here's an example config file for a Truss that uses the [WizardLM](https://huggingface.co/WizardLM/WizardLM-7B-V1.0) model:
```yaml WizardLM config
description: An instruction-following LLM Using Evol-Instruct.
environment_variables: {}
model_name: WizardLM
requirements:
- accelerate==0.20.3
- bitsandbytes==0.39.1
- peft==0.3.0
- protobuf==4.23.3
- sentencepiece==0.1.99
- torch==2.0.1
- transformers==4.30.2
resources:
cpu: "3"
memory: 14Gi
use_gpu: true
accelerator: A10G
secrets: {}
system_packages: []
```
## Full config reference
# Truss reference
Details on Truss CLI and configuration options
[Truss](/deploy/) is an open source framework for writing model server code in Python. It also contains a CLI to power your developer workflows.
We recommend always using the latest version of Truss, which you can update or install with:
```sh
pip install --upgrade truss
```
## Truss configuration options
Every model packaged with Truss uses a `config.yaml` to specify hardware requirements, software dependencies, and other model serving setup.
The [Truss config reference](/truss-reference/config) details the dozens of supported configuration options.
## Truss CLI
The Truss CLI runs the core developer loop of deploying models to Baseten. See the Truss CLI reference for a complete list of commands and options or run `truss --help`.
# Truss Python SDK Reference
Python SDK Reference for Truss
### `truss.login`
Authenticates with Baseten.
**Parameters:**
| Name | Type | Description |
| --------- | ----- | ---------------- |
| `api_key` | *str* | Baseten API Key. |
### `truss.push`
Pushes a Truss to Baseten.
**Parameters:**
| Name | Type | Description |
| ----------------------------------------- | ---------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `target_directory` | *str* | Directory of Truss to push. |
| `remote` | *Optional\[str]* | The name of the remote in .trussrc to patch changes to. |
| `model_name` | *Optional\[str]* | The name of the model, if different from the one in the config.yaml. |
| `publish` | *bool* | Push the truss as a published deployment. If no production deployment exists, promote the truss to production after deploy completes. |
| `promote` | *bool* | Push the truss as a published deployment. Even if a production deployment exists, promote the truss to production after deploy completes. |
| `preserve_previous_production_deployment` | *bool* | Preserve the previous production deployment's autoscaling setting. When not specified, the previous production deployment will be updated to allow it to scale to zero. Can only be used in combination with the `promote` option. |
| `trusted` | *bool* | Give Truss access to secrets on remote host. |
| `deployment_name` | *Optional\[str]* | Name of the deployment created by the push. Can only be used in combination with `publish` or `promote`. Deployment name must only contain alphanumeric, ’.’, ’-’ or ’\_’ characters. |
* **Return type:**
[*ModelDeployment*](#class-truss-api-definitions-modeldeployment)
### *class* `truss.api.definitions.ModelDeployment`
Represents a deployed model. Not to be instantiated directly,
but returned by `truss.push`.
**Fields:**
* `model_id`: ID of the deployed model
* `model_deployment_id`: ID of the model deployment
#### wait\_for\_active()
Waits for the deployment to be in an active state. Returns `True` when complete,
and raises if there is an error in the deployment.
* **Return type:**
bool
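Putting these pieces together, a minimal sketch of the SDK flow might look like this (the Truss directory and environment variable are placeholders):
```python
# Sketch of the SDK flow using the parameters documented above.
import os
import truss

truss.login(api_key=os.environ["BASETEN_API_KEY"])

# Push the Truss in ./phi-3-mini as a published deployment
deployment = truss.push("./phi-3-mini", publish=True)
print(deployment.model_id, deployment.model_deployment_id)

# Block until the deployment is active; raises if the deployment fails
deployment.wait_for_active()
```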
# Getting Started
Building your first Truss
In this example, we go through building your first Truss model. We'll be using the HuggingFace transformers
library to build a text classification model that can detect the sentiment of text.
# Step 1: Implementing the model
Set up imports for this model. In this example, we simply use the HuggingFace transformers library.
```python model/model.py
from transformers import pipeline
```
Every Truss model must implement a `Model` class. This class must have:
* an `__init__` function
* a `load` function
* a `predict` function
In the `__init__` function, set up any variables that will be used in the `load` and `predict` functions.
```python model/model.py
class Model:
def __init__(self, **kwargs):
self._model = None
```
In the `load` function of the Truss, we implement logic
involved in downloading the model and loading it into memory.
For this Truss example, we define a HuggingFace pipeline, and choose
the `text-classification` task, which uses BERT for text classification under the hood.
Note that the load function runs once when the model starts.
```python model/model.py
def load(self):
self._model = pipeline("text-classification")
```
In the `predict` function of the Truss, we implement logic related
to actual inference. For this example, we just call the HuggingFace pipeline
that we set up in the `load` function.
```python model/model.py
def predict(self, model_input):
return self._model(model_input)
```
# Step 2: Writing the config.yaml
Each Truss has a config.yaml file where we can configure
options related to the deployment. It's in this file where
we can define requirements, resources, and runtime options like
secrets and environment variables.
### Basic Options
In this section, we can define basic metadata about the model,
such as the name, and the Python version to build with.
```yaml config.yaml
model_name: bert
python_version: py310
model_metadata:
example_model_input: { "text": "Hello my name is {MASK}" }
```
### Set up python requirements
In this section, we define any pip requirements that
we need to run the model. To run this, we need PyTorch
and Transformers.
```yaml config.yaml
requirements:
- torch==2.0.1
- transformers==4.33.2
```
### Configure the resources needed
In this section, we can configure resources
needed to deploy this model. Here, we have no need for a GPU
so we leave the accelerator section blank.
```yaml config.yaml
resources:
accelerator: null
cpu: '1'
memory: 2Gi
use_gpu: false
```
### Other config options
Truss also has provisions for adding other runtime options, like secrets and system packages. In this example, we don't need these, so we leave them empty for now.
```yaml config.yaml
secrets: {}
system_packages: []
environment_variables: {}
external_package_dirs: []
```
# Step 3: Deploying & running inference
Deploy the model with the following command:
```bash
$ truss push
```
And then you can perform inference with:
```
$ truss predict -d '"Truss is awesome!"'
```
```python model/model.py
from transformers import pipeline
class Model:
def __init__(self, **kwargs):
self._model = None
def load(self):
self._model = pipeline("text-classification")
def predict(self, model_input):
return self._model(model_input)
```
```yaml config.yaml
model_name: bert
python_version: py310
model_metadata:
example_model_input: { "text": "Hello my name is {MASK}" }
requirements:
- torch==2.0.1
- transformers==4.33.2
resources:
accelerator: null
cpu: '1'
memory: 2Gi
use_gpu: false
secrets: {}
system_packages: []
environment_variables: {}
external_package_dirs: []
```
# LLM
Building an LLM
In this example, we go through a Truss that serves an LLM. We
use the model Mistral-7B, which is a general-purpose LLM that
can be used for a variety of tasks, like summarization, question-answering,
translation, and others.
# Set up the imports and key constants
In this example, we use the Hugging Face transformers library to build a text generation model.
```python model/model.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
```
We use the 7B version of the Mistral model.
```python model/model.py
CHECKPOINT = "mistralai/Mistral-7B-v0.1"
```
# Define the `Model` class and load function
In the `load` function of the Truss, we implement logic involved in
downloading and setting up the model. For this LLM, we use the `Auto`
classes in `transformers` to instantiate our Mistral model.
```python model/model.py
class Model:
def __init__(self, **kwargs) -> None:
self.tokenizer = None
self.model = None
def load(self):
self.model = AutoModelForCausalLM.from_pretrained(
CHECKPOINT, torch_dtype=torch.float16, device_map="auto"
)
self.tokenizer = AutoTokenizer.from_pretrained(
CHECKPOINT,
)
```
# Define the `predict` function
In the predict function, we implement the actual inference logic. The steps
here are:
* Set up the generation params. We have defaults for these, but adjusting the values will have an impact on the model output.
* Tokenize the input
* Generate the output
* Use tokenizer to decode the output
```python model/model.py
def predict(self, request: dict):
prompt = request.pop("prompt")
generate_args = {
"max_new_tokens": request.get("max_new_tokens", 128),
"temperature": request.get("temperature", 1.0),
"top_p": request.get("top_p", 0.95),
"top_k": request.get("top_p", 50),
"repetition_penalty": 1.0,
"no_repeat_ngram_size": 0,
"use_cache": True,
"do_sample": True,
"eos_token_id": self.tokenizer.eos_token_id,
"pad_token_id": self.tokenizer.pad_token_id,
}
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.cuda()
with torch.no_grad():
output = self.model.generate(inputs=input_ids, **generate_args)
return self.tokenizer.decode(output[0])
```
# Setting up the config.yaml
Running Mistral 7B requires a few libraries, such as
`torch`, `transformers` and a couple others.
```yaml config.yaml
environment_variables: {}
external_package_dirs: []
model_metadata:
example_model_input: {"prompt": "What is the meaning of life?"}
model_name: Mistral 7B
python_version: py311
requirements:
- transformers==4.34.0
- sentencepiece==0.1.99
- accelerate==0.23.0
- torch==2.0.1
```
## Configure resources for Mistral
Note that we need an A10G to run this model.
```yaml config.yaml
resources:
accelerator: A10G
use_gpu: true
secrets: {}
system_packages: []
```
# Deploy the model
Deploy the model like you would other Trusses, with:
```bash
$ truss push
```
You can then invoke the model with:
```bash
$ truss predict -d '{"prompt": "What is a large language model?"}'
```
```python model/model.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
CHECKPOINT = "mistralai/Mistral-7B-v0.1"
class Model:
def __init__(self, **kwargs) -> None:
self.tokenizer = None
self.model = None
def load(self):
self.model = AutoModelForCausalLM.from_pretrained(
CHECKPOINT, torch_dtype=torch.float16, device_map="auto"
)
self.tokenizer = AutoTokenizer.from_pretrained(
CHECKPOINT,
)
def predict(self, request: dict):
prompt = request.pop("prompt")
generate_args = {
"max_new_tokens": request.get("max_new_tokens", 128),
"temperature": request.get("temperature", 1.0),
"top_p": request.get("top_p", 0.95),
"top_k": request.get("top_p", 50),
"repetition_penalty": 1.0,
"no_repeat_ngram_size": 0,
"use_cache": True,
"do_sample": True,
"eos_token_id": self.tokenizer.eos_token_id,
"pad_token_id": self.tokenizer.pad_token_id,
}
input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.cuda()
with torch.no_grad():
output = self.model.generate(inputs=input_ids, **generate_args)
return self.tokenizer.decode(output[0])
```
```yaml config.yaml
environment_variables: {}
external_package_dirs: []
model_metadata:
example_model_input: {"prompt": "What is the meaning of life?"}
model_name: Mistral 7B
python_version: py311
requirements:
- transformers==4.34.0
- sentencepiece==0.1.99
- accelerate==0.23.0
- torch==2.0.1
resources:
accelerator: A10G
use_gpu: true
secrets: {}
system_packages: []
```
# LLM with Streaming
Building an LLM with streaming output
In this example, we go through a Truss that serves an LLM, and streams the output to the client.
# Why Streaming?
For certain ML models, generations can take a long time. Especially with LLMs, a long output could take
10-20 seconds to generate. However, because LLMs generate tokens in sequence, useful output can be
made available to users sooner. To enable this, Truss supports streaming output. In this example,
we build a Truss that streams the output of the Falcon-7B model.
# Set up the imports and key constants
In this example, we use the HuggingFace transformers library to build a text generation model.
```python model/model.py
from threading import Thread
from typing import Dict
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
GenerationConfig,
TextIteratorStreamer,
)
```
We use the instruct version of the Falcon-7B model, and have some defaults
for inference parameters.
```python model/model.py
CHECKPOINT = "tiiuae/falcon-7b-instruct"
DEFAULT_MAX_NEW_TOKENS = 150
DEFAULT_TOP_P = 0.95
```
# Define the load function
In the `load` function of the Truss, we implement logic
involved in downloading the model and loading it into memory.
```python model/model.py
class Model:
def __init__(self, **kwargs) -> None:
self.tokenizer = None
self.model = None
def load(self):
self.tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
```
```python model/model.py
self.tokenizer.pad_token = self.tokenizer.eos_token
self.model = AutoModelForCausalLM.from_pretrained(
CHECKPOINT,
torch_dtype=torch.bfloat16,
trust_remote_code=True,
device_map="auto",
)
```
# Define the predict function
In the `predict` function of the Truss, we implement the actual
inference logic. The two main steps are:
* Tokenize the input
* Call the model's `generate` function, making sure to pass a
`TextIteratorStreamer`. This is what gives us streaming output. We also run
generation in a Thread, so that it does not block the main
invocation.
* Return a generator that iterates over the `TextIteratorStreamer` object
```python model/model.py
def predict(self, request: Dict) -> Dict:
prompt = request.pop("prompt")
inputs = self.tokenizer(
prompt, return_tensors="pt", max_length=512, truncation=True, padding=True
)
input_ids = inputs["input_ids"].to("cuda")
```
Instantiate the Streamer object, which we'll later use for
returning the output to users.
```python model/model.py
streamer = TextIteratorStreamer(self.tokenizer)
generation_config = GenerationConfig(
temperature=1,
top_p=DEFAULT_TOP_P,
top_k=40,
)
```
When creating the generation parameters, be sure to pass the `streamer` object
that we created previously.
```python model/model.py
with torch.no_grad():
generation_kwargs = {
"input_ids": input_ids,
"generation_config": generation_config,
"return_dict_in_generate": True,
"output_scores": True,
"pad_token_id": self.tokenizer.eos_token_id,
"max_new_tokens": DEFAULT_MAX_NEW_TOKENS,
"streamer": streamer,
}
```
Spawn a thread to run the generation, so that it does not block the main
thread.
```python model/model.py
thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
thread.start()
```
In Truss, the way to achieve streaming output is to return a generator
that yields content. In this example, we yield the output of the `streamer`,
which produces output and yields it until the generation is complete.
We define this `inner` function to create our generator.
```python model/model.py
def inner():
for text in streamer:
yield text
thread.join()
return inner()
```
# Setting up the config.yaml
Running Falcon 7B requires torch, transformers,
and a few other related libraries.
```yaml config.yaml
model_name: "LLM with Streaming"
model_metadata:
example_model_input: {"prompt": "what is the meaning of life"}
requirements:
- torch==2.0.1
- peft==0.4.0
- scipy==1.11.1
- sentencepiece==0.1.99
- accelerate==0.21.0
- bitsandbytes==0.41.1
- einops==0.6.1
- transformers==4.31.0
```
## Configure resources for Falcon
Note that we need an A10G to run this model.
```yaml config.yaml
resources:
cpu: "3"
memory: 14Gi
use_gpu: true
accelerator: A10G
```
```python model/model.py
from threading import Thread
from typing import Dict
import torch
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
GenerationConfig,
TextIteratorStreamer,
)
CHECKPOINT = "tiiuae/falcon-7b-instruct"
DEFAULT_MAX_NEW_TOKENS = 150
DEFAULT_TOP_P = 0.95
class Model:
def __init__(self, **kwargs) -> None:
self.tokenizer = None
self.model = None
def load(self):
self.tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
self.tokenizer.pad_token = self.tokenizer.eos_token
self.model = AutoModelForCausalLM.from_pretrained(
CHECKPOINT,
torch_dtype=torch.bfloat16,
trust_remote_code=True,
device_map="auto",
)
def predict(self, request: Dict) -> Dict:
prompt = request.pop("prompt")
inputs = self.tokenizer(
prompt, return_tensors="pt", max_length=512, truncation=True, padding=True
)
input_ids = inputs["input_ids"].to("cuda")
streamer = TextIteratorStreamer(self.tokenizer)
generation_config = GenerationConfig(
temperature=1,
top_p=DEFAULT_TOP_P,
top_k=40,
)
with torch.no_grad():
generation_kwargs = {
"input_ids": input_ids,
"generation_config": generation_config,
"return_dict_in_generate": True,
"output_scores": True,
"pad_token_id": self.tokenizer.eos_token_id,
"max_new_tokens": DEFAULT_MAX_NEW_TOKENS,
"streamer": streamer,
}
thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
thread.start()
def inner():
for text in streamer:
yield text
thread.join()
return inner()
```
```yaml config.yaml
model_name: "LLM with Streaming"
model_metadata:
example_model_input: {"prompt": "what is the meaning of life"}
requirements:
- torch==2.0.1
- peft==0.4.0
- scipy==1.11.1
- sentencepiece==0.1.99
- accelerate==0.21.0
- bitsandbytes==0.41.1
- einops==0.6.1
- transformers==4.31.0
resources:
cpu: "3"
memory: 14Gi
use_gpu: true
accelerator: A10G
```
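Once deployed, clients can consume the streamed output as it is generated. Below is a minimal client sketch using `requests` with `stream=True`; it assumes the model is deployed on Baseten and that the production deployment is reachable at the standard `production/predict` endpoint (adjust the URL and payload to your deployment).
```python
import os

import requests

model_id = ""  # your model ID

# Stream the model output as it is generated. The URL below assumes the
# standard Baseten predict endpoint for the production deployment.
resp = requests.post(
    f"https://model-{model_id}.api.baseten.co/production/predict",
    headers={"Authorization": f"Api-Key {os.environ['BASETEN_API_KEY']}"},
    json={"prompt": "What is the meaning of life?"},
    stream=True,
)

# Print each chunk of text as soon as it arrives.
for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
    print(chunk, end="", flush=True)
```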
# Text-to-image
Building a text-to-image model with SDXL
In this example, we go through a Truss that serves a text-to-image model. We
use SDXL 1.0, which is one of the highest performing text-to-image models out
there today.
# Set up imports and torch settings
In this example, we use the Huggingface diffusers library to build our text-to-image model.
```python model/model.py
import base64
import time
from io import BytesIO
from typing import Any
import torch
from diffusers import AutoencoderKL, DiffusionPipeline, DPMSolverMultistepScheduler
from PIL import Image
```
The following line is needed to enable TF32 on NVIDIA GPUs.
```python model/model.py
torch.backends.cuda.matmul.allow_tf32 = True
```
# Define the `Model` class and load function
In the `load` function of the Truss, we implement logic involved in
downloading and setting up the model. For this model, we use the
`DiffusionPipeline` class in `diffusers` to instantiate our SDXL pipeline,
and configure a number of relevant parameters.
See the [diffusers docs](https://huggingface.co/docs/diffusers/index) for details
on all of these parameters.
```python model/model.py
class Model:
def __init__(self, **kwargs):
self._model = None
def load(self):
vae = AutoencoderKL.from_pretrained(
"madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
)
self.pipe = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
vae=vae,
torch_dtype=torch.float16,
variant="fp16",
use_safetensors=True,
)
self.pipe.unet.to(memory_format=torch.channels_last)
self.pipe.to("cuda")
self.pipe.enable_xformers_memory_efficient_attention()
self.refiner = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-refiner-1.0",
text_encoder_2=self.pipe.text_encoder_2,
vae=self.pipe.vae,
torch_dtype=torch.float16,
use_safetensors=True,
variant="fp16",
)
self.refiner.to("cuda")
self.refiner.enable_xformers_memory_efficient_attention()
```
This is a utility function for converting a PIL image to base64.
```python model/model.py
def convert_to_b64(self, image: Image) -> str:
buffered = BytesIO()
image.save(buffered, format="JPEG")
img_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
return img_b64
```
# Define the predict function
The `predict` function contains the actual inference logic. The steps here are:
* Setting up the generation params. We have defaults for these, and some, such
as the `scheduler`, are somewhat complicated
* Running the Diffusion Pipeline
* If `use_refiner` is set to `True`, we run the refiner model on the output
* Convert the resulting image to base64 and return it
```python model/model.py
def predict(self, model_input: Any) -> Any:
prompt = model_input.pop("prompt")
negative_prompt = model_input.pop("negative_prompt", None)
use_refiner = model_input.pop("use_refiner", True)
num_inference_steps = model_input.pop("num_inference_steps", 30)
denoising_frac = model_input.pop("denoising_frac", 0.8)
end_cfg_frac = model_input.pop("end_cfg_frac", 0.4)
guidance_scale = model_input.pop("guidance_scale", 7.5)
seed = model_input.pop("seed", None)
scheduler = model_input.pop(
"scheduler", None
) # Default: EulerDiscreteScheduler (works pretty well)
```
Set the scheduler based on the user's input.
See the possible schedulers and their tradeoffs at [https://huggingface.co/docs/diffusers/api/schedulers/overview](https://huggingface.co/docs/diffusers/api/schedulers/overview).
```python model/model.py
if scheduler == "DPM++ 2M":
self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
self.pipe.scheduler.config
)
elif scheduler == "DPM++ 2M Karras":
self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
self.pipe.scheduler.config, use_karras_sigmas=True
)
elif scheduler == "DPM++ 2M SDE Karras":
self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
self.pipe.scheduler.config,
algorithm_type="sde-dpmsolver++",
use_karras_sigmas=True,
)
generator = None
if seed is not None:
torch.manual_seed(seed)
generator = [torch.Generator(device="cuda").manual_seed(seed)]
if not use_refiner:
denoising_frac = 1.0
start_time = time.time()
image = self.pipe(
prompt=prompt,
negative_prompt=negative_prompt,
generator=generator,
end_cfg=end_cfg_frac,
num_inference_steps=num_inference_steps,
denoising_end=denoising_frac,
guidance_scale=guidance_scale,
output_type="latent" if use_refiner else "pil",
).images[0]
scheduler = self.pipe.scheduler
if use_refiner:
self.refiner.scheduler = scheduler
image = self.refiner(
prompt=prompt,
negative_prompt=negative_prompt,
generator=generator,
end_cfg=end_cfg_frac,
num_inference_steps=num_inference_steps,
denoising_start=denoising_frac,
guidance_scale=guidance_scale,
image=image[None, :],
).images[0]
```
Convert the results to base64, and return them.
```python model/model.py
b64_results = self.convert_to_b64(image)
end_time = time.time() - start_time
print(f"Time: {end_time:.2f} seconds")
return {"status": "success", "data": b64_results, "time": end_time}
```
# Setting up the config.yaml
Running SDXL requires a handful of Python libraries, including
diffusers, transformers, and others.
```yaml config.yaml
environment_variables: {}
external_package_dirs: []
model_metadata:
example_model_input: {"prompt": "A tree in a field under the night sky", "use_refiner": true}
model_name: Stable Diffusion XL
python_version: py39
requirements:
- transformers==4.34.0
- accelerate==0.23.0
- safetensors==0.4.0
- git+https://github.com/basetenlabs/diffusers.git@9a353290b1497023d4745a719ec02c50f680499a
- invisible-watermark>=0.2.0
- xformers==0.0.22
```
## Configuring resources for SDXL 1.0
Note that we need an A10G to run this model.
```yaml config.yaml
resources:
accelerator: A10G
cpu: 3500m
memory: 20Gi
use_gpu: true
secrets: {}
```
## System Packages
Running diffusers requires `ffmpeg` and a couple other system
packages.
```yaml config.yaml
system_packages:
- ffmpeg
- libsm6
- libxext6
```
## Enabling Caching
SDXL is a very large model, and downloading it could take up to 10 minutes. This means
that the cold start time for this model is long. We can solve that by using our build
caching feature, which moves the model download to the build stage of your model.
Caching the model will take about 10 minutes initially, but you will get \~9s cold starts
subsequently.
To enable caching, add the following to the config:
```yaml
model_cache:
- repo_id: madebyollin/sdxl-vae-fp16-fix
allow_patterns:
- config.json
- diffusion_pytorch_model.safetensors
- repo_id: stabilityai/stable-diffusion-xl-base-1.0
allow_patterns:
- "*.json"
- "*.fp16.safetensors"
- sd_xl_base_1.0.safetensors
- repo_id: stabilityai/stable-diffusion-xl-refiner-1.0
allow_patterns:
- "*.json"
- "*.fp16.safetensors"
- sd_xl_refiner_1.0.safetensors
```
# Deploy the model
Deploy the model like you would other Trusses, with:
```bash
$ truss push
```
You can then invoke the model with:
```bash
$ truss predict -d '{"prompt": "A tree in a field under the night sky", "use_refiner": true}'
```
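The response returns the generated image under the `data` key as a base64-encoded JPEG. Here's a minimal sketch for decoding and saving it, assuming you've captured the JSON response to a file (the filenames are placeholders):
```python
import base64
import json

# Assumes the model's JSON response has been saved to a file,
# however you captured it (e.g. from the API or the CLI output).
with open("response.json") as f:
    response_json = json.load(f)

# The image is returned under the "data" key as a base64-encoded JPEG.
image_bytes = base64.b64decode(response_json["data"])
with open("output.jpg", "wb") as f:
    f.write(image_bytes)
```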
```python model/model.py
import base64
import time
from io import BytesIO
from typing import Any
import torch
from diffusers import AutoencoderKL, DiffusionPipeline, DPMSolverMultistepScheduler
from PIL import Image
torch.backends.cuda.matmul.allow_tf32 = True
class Model:
def __init__(self, **kwargs):
self._model = None
def load(self):
vae = AutoencoderKL.from_pretrained(
"madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
)
self.pipe = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
vae=vae,
torch_dtype=torch.float16,
variant="fp16",
use_safetensors=True,
)
self.pipe.unet.to(memory_format=torch.channels_last)
self.pipe.to("cuda")
self.pipe.enable_xformers_memory_efficient_attention()
self.refiner = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-refiner-1.0",
text_encoder_2=self.pipe.text_encoder_2,
vae=self.pipe.vae,
torch_dtype=torch.float16,
use_safetensors=True,
variant="fp16",
)
self.refiner.to("cuda")
self.refiner.enable_xformers_memory_efficient_attention()
def convert_to_b64(self, image: Image) -> str:
buffered = BytesIO()
image.save(buffered, format="JPEG")
img_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
return img_b64
def predict(self, model_input: Any) -> Any:
prompt = model_input.pop("prompt")
negative_prompt = model_input.pop("negative_prompt", None)
use_refiner = model_input.pop("use_refiner", True)
num_inference_steps = model_input.pop("num_inference_steps", 30)
denoising_frac = model_input.pop("denoising_frac", 0.8)
end_cfg_frac = model_input.pop("end_cfg_frac", 0.4)
guidance_scale = model_input.pop("guidance_scale", 7.5)
seed = model_input.pop("seed", None)
scheduler = model_input.pop(
"scheduler", None
) # Default: EulerDiscreteScheduler (works pretty well)
if scheduler == "DPM++ 2M":
self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
self.pipe.scheduler.config
)
elif scheduler == "DPM++ 2M Karras":
self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
self.pipe.scheduler.config, use_karras_sigmas=True
)
elif scheduler == "DPM++ 2M SDE Karras":
self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
self.pipe.scheduler.config,
algorithm_type="sde-dpmsolver++",
use_karras_sigmas=True,
)
generator = None
if seed is not None:
torch.manual_seed(seed)
generator = [torch.Generator(device="cuda").manual_seed(seed)]
if not use_refiner:
denoising_frac = 1.0
start_time = time.time()
image = self.pipe(
prompt=prompt,
negative_prompt=negative_prompt,
generator=generator,
end_cfg=end_cfg_frac,
num_inference_steps=num_inference_steps,
denoising_end=denoising_frac,
guidance_scale=guidance_scale,
output_type="latent" if use_refiner else "pil",
).images[0]
scheduler = self.pipe.scheduler
if use_refiner:
self.refiner.scheduler = scheduler
image = self.refiner(
prompt=prompt,
negative_prompt=negative_prompt,
generator=generator,
end_cfg=end_cfg_frac,
num_inference_steps=num_inference_steps,
denoising_start=denoising_frac,
guidance_scale=guidance_scale,
image=image[None, :],
).images[0]
b64_results = self.convert_to_b64(image)
end_time = time.time() - start_time
print(f"Time: {end_time:.2f} seconds")
return {"status": "success", "data": b64_results, "time": end_time}
```
```yaml config.yaml
environment_variables: {}
external_package_dirs: []
model_metadata:
example_model_input: {"prompt": "A tree in a field under the night sky", "use_refiner": true}
model_name: Stable Diffusion XL
python_version: py39
requirements:
- transformers==4.34.0
- accelerate==0.23.0
- safetensors==0.4.0
- git+https://github.com/basetenlabs/diffusers.git@9a353290b1497023d4745a719ec02c50f680499a
- invisible-watermark>=0.2.0
- xformers==0.0.22
resources:
accelerator: A10G
cpu: 3500m
memory: 20Gi
use_gpu: true
secrets: {}
system_packages:
- ffmpeg
- libsm6
- libxext6
```
# Fast Cold Starts with Cached Weights
Deploy a language model, with the model weights cached at build time
In this example, we go through a Truss that serves an LLM, and *caches* the weights
at build time. Loading model weights is often the most time-consuming
part of starting a model. Caching the weights at build time means that the weights
will be baked into the Truss image, and will be available *immediately* when your model
replica starts. This means that **cold starts** will be *significantly faster* with this approach.
# Implementing the `Model` class
With weight caching, you don't have to change anything about how the `Model` class
is implemented.
```python model/model.py
from typing import Dict, List
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer
DEFAULT_SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant."
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
CHECKPOINT = "NousResearch/Llama-2-7b-chat-hf"
def format_prompt(prompt: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT) -> str:
return f"{B_INST} {B_SYS} {system_prompt} {E_SYS} {prompt} {E_INST}"
class Model:
def __init__(self, **kwargs) -> None:
self.model = None
self.tokenizer = None
def load(self):
self.model = LlamaForCausalLM.from_pretrained(
CHECKPOINT, torch_dtype=torch.float16, device_map="auto"
)
self.tokenizer = LlamaTokenizer.from_pretrained(CHECKPOINT)
def predict(self, request: Dict) -> Dict[str, List]:
prompt = request.pop("prompt")
input_ids = self.tokenizer(
format_prompt(prompt), return_tensors="pt"
).input_ids.cuda()
outputs = self.model.generate(
inputs=input_ids, do_sample=True, num_beams=1, max_new_tokens=100
)
response = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
return {"response": response}
```
# Setting up the config.yaml
The `config.yaml` file is where you need to include the changes to
actually cache the weights at build time.
```yaml config.yaml
environment_variables: {}
external_package_dirs: []
model_metadata:
example_model_input: {"prompt": "What is the meaning of life?"}
model_name: Llama with Cached Weights
python_version: py39
requirements:
- accelerate==0.21.0
- safetensors==0.3.2
- torch==2.0.1
- transformers==4.34.0
- sentencepiece==0.1.99
- protobuf==4.24.4
```
# Configuring the model\_cache
To cache model weights, set the `model_cache` key.
The `repo_id` field allows you to specify a Hugging Face
repo to pull down and cache at build time, and the `ignore_patterns`
field allows you to specify files to skip. With this configured,
the repo won't have to be pulled at runtime.
Check out the [guide](/truss/guides/model-cache) for more info.
```yaml config.yaml
model_cache:
- repo_id: "NousResearch/Llama-2-7b-chat-hf"
ignore_patterns:
- "*.bin"
```
The remaining config options are, again, similar to what you would
configure for the model without weight caching.
```yaml config.yaml
resources:
cpu: "4"
memory: 30Gi
use_gpu: True
accelerator: A10G
secrets: {}
```
# Deploy the model
Deploy the model like you would other Trusses, with:
```bash
$ truss push
```
The build step will take longer than with the normal
Llama Truss, since bundling the model weights now happens during the build.
The deploy step and scale-ups will happen much faster with this approach.
You can then invoke the model with:
```bash
$ truss predict -d '{"inputs": "What is a large language model?"}'
```
```python model/model.py
from typing import Dict, List
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer
DEFAULT_SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant."
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
CHECKPOINT = "NousResearch/Llama-2-7b-chat-hf"
def format_prompt(prompt: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT) -> str:
return f"{B_INST} {B_SYS} {system_prompt} {E_SYS} {prompt} {E_INST}"
class Model:
def __init__(self, **kwargs) -> None:
self.model = None
self.tokenizer = None
def load(self):
self.model = LlamaForCausalLM.from_pretrained(
CHECKPOINT, torch_dtype=torch.float16, device_map="auto"
)
self.tokenizer = LlamaTokenizer.from_pretrained(CHECKPOINT)
def predict(self, request: Dict) -> Dict[str, List]:
prompt = request.pop("prompt")
input_ids = self.tokenizer(
format_prompt(prompt), return_tensors="pt"
).input_ids.cuda()
outputs = self.model.generate(
inputs=input_ids, do_sample=True, num_beams=1, max_new_tokens=100
)
response = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
return {"response": response}
```
```yaml config.yaml
environment_variables: {}
external_package_dirs: []
model_metadata:
example_model_input: {"prompt": "What is the meaning of life?"}
model_name: Llama with Cached Weights
python_version: py39
requirements:
- accelerate==0.21.0
- safetensors==0.3.2
- torch==2.0.1
- transformers==4.34.0
- sentencepiece==0.1.99
- protobuf==4.24.4
model_cache:
- repo_id: "NousResearch/Llama-2-7b-chat-hf"
ignore_patterns:
- "*.bin"
resources:
cpu: "4"
memory: 30Gi
use_gpu: True
accelerator: A10G
secrets: {}
```
# Private Hugging Face Model
Load a model that requires authentication with Hugging Face
In this example, we build a Truss that uses a model that
requires Hugging Face authentication. The steps for loading a model
from Hugging Face are:
1. Create an [access token](https://huggingface.co/settings/tokens) on your Hugging Face account.
2. Add the `hf_access_token` key to your config.yaml secrets and its value to your [Baseten account](https://app.baseten.co/settings/secrets).
3. Add `use_auth_token` when creating the actual model.
# Setting up the model
In this example, we use a private version of the [BERT base model](https://huggingface.co/bert-base-uncased).
The model is publicly available, but for the purposes of our example, we copied it into a private
model repository, with the path "baseten/docs-example-gated-model".
First, like with other Hugging Face models, start by importing the `pipeline` function from the
transformers library, and defining the `Model` class.
```python model/model.py
from transformers import pipeline
class Model:
```
An important step in loading a model that requires authentication is to
have access to the secrets defined for this model. We pull these out of
the keyword args in the `__init__` function.
```python model/model.py
def __init__(self, **kwargs) -> None:
self._secrets = kwargs["secrets"]
self._model = None
def load(self):
```
When you define the `pipeline`, use the `use_auth_token` parameter to
pass the `hf_access_token` secret that is stored on your Baseten account.
```python model/model.py
self._model = pipeline(
"fill-mask",
model="baseten/docs-example-gated-model",
use_auth_token=self._secrets["hf_access_token"],
)
def predict(self, model_input):
return self._model(model_input)
```
# Setting up the config.yaml
The main things that need to be set up in the config are
`requirements`, which need to include Hugging Face transformers,
and the secrets.
```yaml config.yaml
environment_variables: {}
model_name: private-model
python_version: py39
requirements:
- torch==2.0.1
- transformers==4.30.2
resources:
cpu: "1"
memory: 2Gi
use_gpu: false
accelerator: null
```
To make the `hf_access_token` available in the Truss, we need to include
it in the config. Setting the value to `null` here means that the value
will be set by the Baseten secrets manager.
```yaml config.yaml
secrets:
hf_access_token: null
system_packages: []
```
# Deploying the model
An important note for deploying models with secrets is that
you must use the `--trusted` flag to give the model access to
secrets stored on the remote secrets manager.
```bash
$ truss push --trusted
```
After the model finishes deploying, you can invoke it with:
```bash
$ truss predict -d '"It is a [MASK] world"'
```
```python model/model.py
from transformers import pipeline
class Model:
def __init__(self, **kwargs) -> None:
self._secrets = kwargs["secrets"]
self._model = None
def load(self):
self._model = pipeline(
"fill-mask",
model="baseten/docs-example-gated-model",
use_auth_token=self._secrets["hf_access_token"],
)
def predict(self, model_input):
return self._model(model_input)
```
```yaml config.yaml
environment_variables: {}
model_name: private-model
python_version: py39
requirements:
- torch==2.0.1
- transformers==4.30.2
resources:
cpu: "1"
memory: 2Gi
use_gpu: false
accelerator: null
secrets:
hf_access_token: null
system_packages: []
```
# Model with system packages
Deploy a model with both Python and system dependencies
In this example, we build a Truss with a model that requires specific system packages.
To add system packages to your model serving environment, open config.yaml and
update the `system_packages` key with a list of apt-installable Debian packages, for instance:
```yaml config.yaml
system_packages:
- tesseract-ocr
```
For this example, we use the [LayoutLM Document QA](https://huggingface.co/impira/layoutlm-document-qa) model,
a multimodal model that answers questions about provided invoice documents. This model requires a system
package, tesseract-ocr, which needs to be included in the model serving environment.
# Setting up the model.py
For this model, we use the HuggingFace transformers library, and the document-question-answering task.
```python model/model.py
from transformers import pipeline
class Model:
def __init__(self, **kwargs) -> None:
self._model = None
def load(self):
self._model = pipeline(
"document-question-answering",
model="impira/layoutlm-document-qa",
)
def predict(self, model_input):
return self._model(model_input["url"], model_input["prompt"])
```
# Setting up the config.yaml file
The main items that need to be configured in the config.yaml file are the
`requirements` and `system_packages` sections.
```yaml config.yaml
environment_variables: {}
external_package_dirs: []
model_metadata:
example_model_input: {"url": "https://templates.invoicehome.com/invoice-template-us-neat-750px.png", "prompt": "What is the invoice number?"}
model_name: LayoutLM Document QA
python_version: py39
```
Specify the versions of the Python requirements that are needed.
Always pin exact versions for your Python dependencies. The ML/AI space moves fast, so you want to have an up-to-date version of each package while also being protected from breaking changes.
```yaml config.yaml
requirements:
- Pillow==10.0.0
- pytesseract==0.3.10
- torch==2.0.1
- transformers==4.30.2
resources:
cpu: "4"
memory: 16Gi
use_gpu: false
accelerator: null
secrets: {}
```
The system\_packages section is the other important bit here. You can
add any package that's available via `apt` on Debian.
```yaml config.yaml
system_packages:
- tesseract-ocr
```
# Deploy the model
```bash
$ truss push
```
You can then invoke the model with:
```bash
$ truss predict -d '{"url": "https://templates.invoicehome.com/invoice-template-us-neat-750px.png", "prompt": "What is the invoice number?"}'
```
```python model/model.py
from transformers import pipeline
class Model:
def __init__(self, **kwargs) -> None:
self._model = None
def load(self):
self._model = pipeline(
"document-question-answering",
model="impira/layoutlm-document-qa",
)
def predict(self, model_input):
return self._model(model_input["url"], model_input["prompt"])
```
```yaml config.yaml
environment_variables: {}
external_package_dirs: []
model_metadata:
example_model_input: {"url": "https://templates.invoicehome.com/invoice-template-us-neat-750px.png", "prompt": "What is the invoice number?"}
model_name: LayoutLM Document QA
python_version: py39
requirements:
- Pillow==10.0.0
- pytesseract==0.3.10
- torch==2.0.1
- transformers==4.30.2
resources:
cpu: "4"
memory: 16Gi
use_gpu: false
accelerator: null
secrets: {}
system_packages:
- tesseract-ocr
```
# Base Docker images
A guide to configuring a base image for your truss
Model serving environments are often standardized as container images to avoid wrangling Python, system, and other requirements needed to run your model on every deploy.
Leverage your existing container artifacts by bringing your own base image to Truss.
While the default image for the truss server is adequate for most use cases, there may come a time when you require a custom base image for your truss.
For example, maybe the Python packages required for your project are not compatible with the ones installed by default on Truss.
In a situation like this, you can create your own base image based on the default truss image.
## Setting a base image in config.yaml
To specify a base image to build your truss container image from, configure a `base_image` in your `config.yaml`.
```yaml config.yaml
base_image:
image:
python_executable_path:
```
where `python_executable_path` is a path to a Python executable with which to run your server.
## Example usage
One great use case for base images is to more easily package and deploy models from sources like the [NeMo Toolkit](https://github.com/NVIDIA/NeMo) without having to recreate all of the dependencies yourself.
This example demonstrates how to properly configure a base image for [NVIDIA NeMo TitaNet](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/titanet_large), a speaker recognition model from NVIDIA.
```yaml config.yaml
base_image:
image: nvcr.io/nvidia/nemo:23.03
python_executable_path: /usr/bin/python
apply_library_patches: true
requirements:
- PySoundFile
resources:
accelerator: T4
cpu: 2500m
memory: 4512Mi
use_gpu: true
secrets: {}
system_packages:
- python3.8-venv
```
## Configuring private base images with build time secrets
Secrets of the form `DOCKER_REGISTRY_` will be supplied to your model build to authenticate image pulls from private container registries.
For information on where to store secret values see the [secrets guide](/truss/guides/secrets#storing-secrets-on-your-remote).
For example, to configure Docker credentials for a private Docker Hub repository, your `config.yaml` should include the following secret and placeholder:
```yaml config.yaml
secrets:
DOCKER_REGISTRY_https://index.docker.io/v1/: null
```
along with a configured Baseten secret `DOCKER_REGISTRY_https://index.docker.io/v1/` with a base64 encoded `username:password` secret value:
```sh
echo -n 'username:password' | base64
```
To add docker credentials for Google Cloud Artifact Registry provide an [access token](https://cloud.google.com/artifact-registry/docs/docker/authentication#token) as the secret value.
For example, to configure authentication for a repository in `us-west2` your `config.yaml` should include the following secret and placeholder:
```yaml config.yaml
secrets:
DOCKER_REGISTRY_us-west2-docker.pkg.dev: null
```
Note that since access tokens are short-lived, you may want to leverage the Baseten [Secrets API](/api-reference/upserts-a-secret) to automatically check for expired access tokens before every deploy, and update them as necessary.
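As a rough sketch of that refresh flow, assuming you generate registry tokens with `gcloud auth print-access-token` and that the upsert endpoint is the one documented in the Secrets API reference (verify both against your setup):
```python
import os
import subprocess

import requests

# Generate a fresh, short-lived access token for Artifact Registry.
token = subprocess.run(
    ["gcloud", "auth", "print-access-token"],
    capture_output=True, text=True, check=True,
).stdout.strip()

# Upsert the secret on Baseten before deploying. The endpoint path and payload
# below are assumptions based on the Secrets API reference; check the API docs
# for the exact shape.
resp = requests.post(
    "https://api.baseten.co/v1/secrets",
    headers={"Authorization": f"Api-Key {os.environ['BASETEN_API_KEY']}"},
    json={
        "name": "DOCKER_REGISTRY_us-west2-docker.pkg.dev",
        "value": token,
    },
)
resp.raise_for_status()
```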
If you don't have a base image but want to create one, the following section will guide you through that process.
## Creating a custom base image
While it's possible to create your own base image from scratch, it may be easier to use a truss server image as a starting point.
All of the base images used by truss can be found by going to [docker hub](https://hub.docker.com/r/baseten/truss-server-base/tags).
Each image tag is tied to a specific python version and may support a GPU. For example, the image with the tag `3.11-gpu-v0.7.16` is for Python 3.11 and has GPU support.
On the other hand, image `3.9-v0.7.15` is for Python 3.9 and does not have GPU support. Based on your project requirements you can select the appropriate base image.
Next, we can write our `Dockerfile`.
```Dockerfile Dockerfile
FROM baseten/truss-server-base:3.11-gpu-v0.7.16
RUN pip uninstall cython -y
RUN pip install cython==0.29.30
```
In the example Dockerfile above, we use `truss-server-base` as our base image. This base image comes with some Python dependencies installed.
If you want to override the versions of these default Python dependencies, you can simply uninstall the package and reinstall it using the `RUN pip install` command.
You can even add your own files or directories to this Dockerfile if required. Once you have your Dockerfile set up, you can build, tag, and push the image to your own Docker registry.
## Build, tag, and push your custom base image
### Docker installation required
For this portion, Docker needs to be installed and running on your system, and the `docker` CLI must be available on your command line.
To build the image run the command:
```sh
docker build -t my-custom-base-image:0.1 .
```
You can replace `my-custom-base-image` with the name for your docker image.
Next, to tag the image you can run the command:
```sh
docker tag my-custom-base-image:0.1 your-docker-username/my-custom-base-image:0.1
```
Lastly, to push the image to docker hub run the following command:
```sh
docker push your-docker-username/my-custom-base-image:0.1
```
# Running custom docker commands
How to run your own docker commands during the build stage
Caching objects during the build stage is advantageous, especially where cold starts are concerned.
While `model_cache` and `external_data` work well in many scenarios, you may find yourself needing advanced features to control the type of data and the location where it gets cached.
The `build_commands` feature does just that. It allows you to run custom docker commands at build time.
To give a few examples, you can clone GitHub repositories, download models, and even create directories during the build stage!
## Using run commands during the docker build
Build commands are accessible in Truss via the `config.yaml` file.
```yaml
build_commands:
- git clone https://github.com/comfyanonymous/ComfyUI.git
- cd ComfyUI && git checkout b1fd26fe9e55163f780bf9e5f56bf9bf5f035c93 && pip install -r requirements.txt
model_name: Build Commands Demo
python_version: py310
resources:
accelerator: A100
use_gpu: true
```
In the example above, a Git repository is cloned and a set of Python requirements is installed.
All of this happens during the container build step, so the GitHub repository and the Python packages will be loaded from cache during deployment.
## Creating directories in your truss
Sometimes your truss relies on a large codebase. You can now add files or directories to this codebase directly through build commands.
```yaml
build_commands:
- git clone https://github.com/comfyanonymous/ComfyUI.git
- cd ComfyUI && mkdir ipadapter
- cd ComfyUI && mkdir instantid
model_name: Build Commands Demo
python_version: py310
resources:
accelerator: A100
use_gpu: true
```
## Yet another way to cache your model weights
### Best practices for caching model weights
While you can use the `build_commands` feature to cache model weights, it should be used to cache weights under 10 GB.
To cache larger model weights, the `model_cache` and `external_data` features offer more robust capabilities.
If you're familiar with the Linux/Unix OS, you may have used the `wget` tool to download files.
Build commands allow you to use `wget` to download model weights and store them wherever you like in the truss.
Here's an example:
```yaml
build_commands:
- git clone https://github.com/comfyanonymous/ComfyUI.git
- cd ComfyUI && pip install -r requirements.txt
- cd ComfyUI/custom_nodes && git clone https://github.com/Fannovel16/comfyui_controlnet_aux --recursive && cd comfyui_controlnet_aux && pip install -r requirements.txt
- cd ComfyUI/models/controlnet && wget -O control-lora-canny-rank256.safetensors https://huggingface.co/stabilityai/control-lora/resolve/main/control-LoRAs-rank256/control-lora-canny-rank256.safetensors
- cd ComfyUI/models/controlnet && wget -O control-lora-depth-rank256.safetensors https://huggingface.co/stabilityai/control-lora/resolve/main/control-LoRAs-rank256/control-lora-depth-rank256.safetensors
- cd ComfyUI/models/checkpoints && wget -O dreamshaperXL_v21TurboDPMSDE.safetensors https://civitai.com/api/download/models/351306
- cd ComfyUI/models/loras && wget -O StudioGhibli.Redmond-StdGBRRedmAF-StudioGhibli.safetensors https://huggingface.co/artificialguybr/StudioGhibli.Redmond-V2/resolve/main/StudioGhibli.Redmond-StdGBRRedmAF-StudioGhibli.safetensors
model_name: Build Commands Demo
python_version: py310
resources:
accelerator: A100
use_gpu: true
system_packages:
- wget
```
Using `build_commands` you can run any kind of shell command that you would normally run locally.
The main benefit you get is that everything you run gets cached, which helps reduce the cold-start time for your model.
# Deploy Llama 2 with Caching
Enable fast cold starts for a model with private Hugging Face weights
In this example, we will cover how you can use the `model_cache` key in your Truss's `config.yaml` to automatically bundle model weights from a private Hugging Face repo.
Bundling model weights can significantly reduce cold start times because your instance won't waste time downloading the model weights from Hugging Face's servers.
We use `Llama-2-7b`, a popular open-source large language model, as an example. In order to follow along with us, you need to request access to Llama 2.
1. First, [sign up for a Hugging Face account](https://huggingface.co/join) if you don't already have one.
2. Request access to Llama 2 from [Meta's website](https://ai.meta.com/resources/models-and-libraries/llama-downloads/).
3. Next, request access to Llama 2 on [Hugging Face](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) by clicking the "Request access" button on the model page.
If you want to deploy on Baseten, you also need to create a Hugging Face API token and add it to your organization's secrets.
1. [Create a Hugging Face API token](https://huggingface.co/settings/tokens) and copy it to your clipboard.
2. Add the token with the key `hf_access_token` to [your organization's secrets](https://app.baseten.co/settings/secrets) on Baseten.
### Step 0: Initialize Truss
Get started by creating a new Truss:
```sh
truss init llama-2-7b-chat
```
Select the `TrussServer` option then hit `y` to confirm Truss creation. Then navigate to the newly created directory:
```sh
cd llama-2-7b-chat
```
### Step 1: Implement Llama 2 7B in Truss
Next, we'll fill out the `model.py` file to implement Llama 2 7B in Truss.
In `model/model.py`, we write the class `Model` with three member functions:
* `__init__`, which creates an instance of the object with a `_model` property
* `load`, which runs once when the model server is spun up and loads the `pipeline` model
* `predict`, which runs each time the model is invoked and handles the inference. It can use any JSON-serializable type as input and output.
We will also create a helper function `format_prompt` outside of the `Model` class to appropriately format the incoming text according to the Llama 2 specification.
[Read the quickstart guide](/quickstart) for more details on `Model` class implementation.
```python model/model.py
from typing import Dict, List
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer
DEFAULT_SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant."
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
class Model:
def __init__(self, **kwargs) -> None:
self._data_dir = kwargs["data_dir"]
self._config = kwargs["config"]
self._secrets = kwargs["secrets"]
self.model = None
self.tokenizer = None
def load(self):
self.model = LlamaForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-chat-hf",
use_auth_token=self._secrets["hf_access_token"],
torch_dtype=torch.float16,
device_map="auto"
)
self.tokenizer = LlamaTokenizer.from_pretrained(
"meta-llama/Llama-2-7b-chat-hf",
use_auth_token=self._secrets["hf_access_token"]
)
def predict(self, request: Dict) -> Dict[str, List]:
prompt = request.pop("prompt")
prompt = format_prompt(prompt)
inputs = self.tokenizer(prompt, return_tensors="pt")
outputs = self.model.generate(**inputs, do_sample=True, num_beams=1, max_new_tokens=100)
response = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
return {"response": response}
def format_prompt(prompt: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT) -> str:
return f"{B_INST} {B_SYS} {system_prompt} {E_SYS} {prompt} {E_INST}"
```
### Step 2: Set Python dependencies
Now, we can turn our attention to configuring the model server in `config.yaml`.
In addition to `transformers`, Llama 2 has three other dependencies. We list them below as follows:
```yaml config.yaml
requirements:
- accelerate==0.21.0
- safetensors==0.3.2
- torch==2.0.1
- transformers==4.30.2
```
Always pin exact versions for your Python dependencies. The ML/AI space moves fast, so you want to have an up-to-date version of each package while also being protected from breaking changes.
### Step 3: Configure Hugging Face caching
Finally, we can configure Hugging Face caching in `config.yaml` by adding the `model_cache` key. When building the image for your Llama 2 deployment, the Llama 2 model weights will be downloaded and cached for future use.
```yaml config.yaml
model_cache:
- repo_id: "meta-llama/Llama-2-7b-chat-hf"
ignore_patterns:
- "*.bin"
```
In this configuration:
* `meta-llama/Llama-2-7b-chat-hf` is the `repo_id`, pointing to the exact model to cache.
* We use a wild card to ignore all `.bin` files in the model directory by providing a pattern under `ignore_patterns`. This is because the model weights are stored in `.bin` and `safetensors` format, and we only want to cache the `safetensors` files.
### Step 4: Deploy the model
You'll need a [Baseten API key](https://app.baseten.co/settings/account/api_keys) for this step. Make sure you added your Hugging Face access token to your organization's secrets under the key `hf_access_token`.
We have successfully packaged Llama 2 as a Truss. Let's deploy!
```sh
truss push --trusted
```
### Step 5: Invoke the model
You can invoke the model with:
```sh
truss predict -d '{"prompt": "What is a large language model?"}'
```
```yaml config.yaml
environment_variables: {}
external_package_dirs: []
model_metadata: {}
model_name: null
python_version: py39
requirements:
- accelerate==0.21.0
- safetensors==0.3.2
- torch==2.0.1
- transformers==4.30.2
model_cache:
- repo_id: "NousResearch/Llama-2-7b-chat-hf"
ignore_patterns:
- "*.bin"
resources:
cpu: "4"
memory: 30Gi
use_gpu: True
accelerator: A10G
secrets: {}
```
```python model/model.py
from typing import Dict, List
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer
DEFAULT_SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant."
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
class Model:
def __init__(self, **kwargs) -> None:
self._data_dir = kwargs["data_dir"]
self._config = kwargs["config"]
self._secrets = kwargs["secrets"]
self.model = None
self.tokenizer = None
def load(self):
self.model = LlamaForCausalLM.from_pretrained(
"meta-llama/Llama-2-7b-chat-hf",
use_auth_token=self._secrets["hf_access_token"],
torch_dtype=torch.float16,
device_map="auto"
)
self.tokenizer = LlamaTokenizer.from_pretrained(
"meta-llama/Llama-2-7b-chat-hf",
use_auth_token=self._secrets["hf_access_token"]
)
def predict(self, request: Dict) -> Dict[str, List]:
prompt = request.pop("prompt")
prompt = format_prompt(prompt)
inputs = self.tokenizer(prompt, return_tensors="pt")
outputs = self.model.generate(**inputs, do_sample=True, num_beams=1, max_new_tokens=100)
response = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
return {"response": response}
def format_prompt(prompt: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT) -> str:
return f"{B_INST} {B_SYS} {system_prompt} {E_SYS} {prompt} {E_INST}"
```
# Request concurrency
A guide to setting concurrency for your model
Concurrency is one of the major knobs available for getting the most performance
out of your model. In this doc, we'll cover the options that are available to you.
# What is concurrency, and why configure it?
At a very high level, "concurrency" in this context refers to how many requests a single replica can
process at the same time. There are no right answers to what this number ought to be -- the specifics
of your model and the metrics you are optimizing for (throughput? latency?) matter a lot for determining this.
In Baseten & Truss, there are two notions of concurrency:
* **Concurrency Target** -- the number of requests that will be sent to a model at the same time
* **Predict Concurrency** -- once requests have made it onto the model container, the "predict concurrency" governs how many
requests can go through the `predict` function on your Truss at once.
# Concurrency Target
The concurrency target is set in the Baseten UI, and, to reiterate, governs the maximum number of requests that will be sent
to a single model replica.
An important note about this setting is that it is also used as a part of the auto-scaling parameters. If all replicas have
hit their Concurrency Target, this triggers Baseten's autoscaling.
Let's dive into a concrete example.
Say that there is a single replica of a model, and the concurrency target is 2. If 5 requests come in, the first 2 will
be sent to the replica, and the other 3 get queued up. Once the requests on the container complete, the queued
requests will make it to the model container.
Remember that if all replicas have hit their concurrency target, this will trigger autoscaling. So in this specific example,
the queuing of requests 3-5 will trigger another replica to come up, if the model has not hit its max replicas yet.
# Predict Concurrency
Alright, so we've talked about the **Concurrency Target** feature that governs how many requests will be sent to a model at once.
Predict concurrency is a bit different -- it operates on the level of the model container and governs how many requests will go
through the `predict` function concurrently.
To get a sense for why this matters, let's recap the structure of a Truss:
```python model.py
class Model:
def __init__(self):
...
def preprocess(self, request):
...
def predict(self, request):
...
def postprocess(self, response):
...
```
In this Truss model, there are three functions that are called in order to serve a request:
* **preprocess** -- this function is used to perform any prework / modifications on the request before the `predict` function
runs. For instance, if you are running an image classification model, and need to download images from S3, this is a good place
to do it.
* **predict** -- this function is where the actual inference happens. It is likely where the logic that runs GPU code lives
* **postprocess** -- this function is used to perform any postwork / modifications on the response before it is returned to the
user. For instance, if you are running a text-to-image model, this is a good place to implement the logic for uploading an image
to S3.
Given these three functions and the work each one does, you might want a different
level of concurrency for the `predict` function than for the others. The most common need here is to limit access to the GPU, since multiple
requests running on the GPU at the same time could cause serious degradation in performance.
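To make this concrete, here's a hedged sketch of a Truss model where `preprocess` and `postprocess` do network I/O while `predict` holds the GPU; the `image_url` field and `run_on_gpu` helper are placeholders, not part of any real API:
```python
from typing import Dict

import requests


def run_on_gpu(model, image_bytes: bytes):
    # Placeholder for your actual GPU inference call.
    ...


class Model:
    def __init__(self, **kwargs):
        self._model = None

    def load(self):
        # Load your GPU model here.
        ...

    def preprocess(self, request: Dict) -> Dict:
        # I/O-bound work: many requests can run this step at the same time,
        # limited only by the Concurrency Target.
        image_bytes = requests.get(request["image_url"]).content
        return {"image_bytes": image_bytes}

    def predict(self, inputs: Dict) -> Dict:
        # GPU-bound work: predict_concurrency caps how many requests run here at once.
        return {"prediction": run_on_gpu(self._model, inputs["image_bytes"])}

    def postprocess(self, response: Dict) -> Dict:
        # More I/O-bound work (e.g. uploading results) can overlap freely.
        return response
```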
Unlike **Concurrency Target**, which is configured in the Baseten UI, the **Predict Concurrency** is configured as a part
of the Truss Config (in the `config.yaml` file).
```yaml config.yaml
model_name: "My model with concurrency limits"
...
runtime:
predict_concurrency: 2 # the default is 1
...
```
To better understand this, let's zoom in on the model pod:
Let's say predict concurrency is 1.
1. Two requests come in to the pod.
2. Both requests will begin preprocessing immediately (let's say,
downloading images from S3).
3. Once the first request finishes preprocessing, it will begin running on the GPU. The second request
will then remain queued until the first request finishes running on the GPU in predict.
4. After the first request finishes, the second request will begin being processed on the GPU
5. Once the second request finishes, it will begin postprocessing, even if the first request is not done postprocessing
To reiterate, predict concurrency is really great to use if you want to protect the GPU resource on your model pod,
while still allowing for high concurrency for the pre- and post-processing steps.
Remember that to actually achieve the predict concurrency you desire, the Concurrency Target must be at least that amount,
so that the requests make it to the model container.
# Deploy Custom Server from Docker image
A config.yaml is all you need
If you have a ready-to-use API server packaged in a Docker image, either an open source serving image like [vLLM](https://github.com/vllm-project/vllm) or a customized Docker image built in-house, it's very easy to deploy on Baseten -- all you need is a `config.yaml` file.
## Specifying a Docker image in config.yaml
To specify the Docker image for a Custom Server, add a `docker_server` field in your `config.yaml`:
```yaml config.yaml
base_image:
image: vllm/vllm-openai:latest
docker_server:
start_command: vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct --port 8000 --max-model-len 1024
readiness_endpoint: /health
liveness_endpoint: /health
predict_endpoint: /v1/chat/completions
server_port: 8000
```
where
* `start_command` is the command to start the server
* `predict_endpoint` is the endpoint that requests are forwarded to. Please note that deployed models can only support a single predict endpoint at the moment
* `server_port` is the port to run the server on
* `readiness_endpoint` (Optional) is the endpoint override for the [readiness probe](https://kubernetes.io/docs/tasks/configure-pod-container/configure-liveness-readiness-probes/) used by Kubernetes
* `liveness_endpoint` (Optional) is the endpoint override for the [liveness probe](https://kubernetes.io/docs/tasks/configure-pod-container/configure-liveness-readiness-probes/) used by Kubernetes
Even though both `readiness_endpoint` and `liveness_endpoint` are optional, we recommend specifying both of them if your Custom Server is exposing a health check endpoint. Otherwise we will use the default endpoints which might not accurately check the health of your Custom Server.
## Example usage: run vLLM server from Docker image
One great use case for Custom Server is to spin up a popular open source model server like [vLLM OpenAI Compatible Server](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html). Below is an example to deploy the [Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct) model using vLLM on 1 A10G GPU.
As you can see here, we pass the `/health` endpoint provided by the vLLM server as both `readiness_endpoint` and `liveness_endpoint`. This way, we use the vLLM server's internal health probe to decide whether the server is ready to accept requests or is unhealthy and needs to be restarted.
```yaml config.yaml
base_image:
image: vllm/vllm-openai:latest
docker_server:
start_command: sh -c "HF_TOKEN=$(cat /secrets/hf_access_token) vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct --port 8000 --max-model-len 1024"
readiness_endpoint: /health
liveness_endpoint: /health
predict_endpoint: /v1/chat/completions
server_port: 8000
resources:
accelerator: A10G
model_name: vllm-model-server
secrets:
hf_access_token: null
runtime:
predict_concurrency: 128
```
More usage examples of Custom Server can be found [here](https://github.com/basetenlabs/truss-examples/tree/main/custom-server).
## Installing custom python packages
If you need to install additional Python packages, you can do so via the `requirements` key in your `config.yaml`. The following example shows how to start the [Infinity Embedding Model Server](https://github.com/michaelfeil/infinity) from a Docker image with the Python package `infinity-emb` installed.
```yaml config.yaml
base_image:
image: python:3.11-slim
docker_server:
start_command: sh -c "infinity_emb v2 --model-id BAAI/bge-small-en-v1.5"
readiness_endpoint: /health
liveness_endpoint: /health
predict_endpoint: /embeddings
server_port: 7997
resources:
accelerator: L4
use_gpu: true
model_name: infinity-embedding-server
requirements:
- infinity-emb[all]
environment_variables:
hf_access_token: null
```
## Accessing secrets in Custom Server
As you might have noticed in the vLLM example above, you can access secrets in the Custom Server by reading them from the `/secrets` directory if you have [stored those secrets in Baseten](https://docs.baseten.co/truss/guides/secrets#storing-secrets-on-your-remote). This is useful if you need to pass in environment variables or other secrets to your server.
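If your Custom Server runs your own application code rather than an off-the-shelf image, the same mechanism applies: each stored secret is mounted as a file named after its key. A minimal sketch, assuming a secret stored under the key `hf_access_token`:
```python
from pathlib import Path


def read_secret(name: str) -> str:
    # Each secret stored on Baseten is mounted at /secrets/<name> inside the container.
    return Path("/secrets", name).read_text().strip()


hf_access_token = read_secret("hf_access_token")
```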
# Model weights
Load model weights without Hugging Face or S3
Serving a model requires access to model files, such as model weights. These files are often many gigabytes.
For many models, these files are loaded from Hugging Face. However, model files can come from other sources or be stored directly in the Truss. Model weights and other model data can be:
* Public on Hugging Face (default, [example here](/truss/examples/04-image-generation))
* [Private on Hugging Face](/truss/examples/09-private-huggingface)
* [Bundled directly with the Truss](#bundling-model-weights-in-truss)
* [Public cloud storage like S3](#loading-public-model-weights-from-s3)
* [Private cloud storage like S3](#loading-private-model-weights-from-s3)
## Bundling model weights in Truss
You can bundle model data directly with your model in Truss. To do so, use the Truss' `data` folder to store any necessary files.
Here's an example of the `data` folder for [a Truss of Stable Diffusion 2.1](https://github.com/basetenlabs/truss-examples/tree/main/stable-diffusion/stable-diffusion).
```
data/
scheduler/
scheduler_config.json
text_encoder/
config.json
diffusion_pytorch_model.bin
tokenizer/
merges.txt
special_tokens_map.json
tokenizer_config.json
vocab.json
unet/
config.json
diffusion_pytorch_model.bin
vae/
config.json
diffusion_pytorch_model.bin
model_index.json
```
To access the data in the model, use the `self._data_dir` variable in the `load()` function of `model/model.py`:
```python
class Model:
def __init__(self, **kwargs) -> None:
self._data_dir = kwargs["data_dir"]
def load(self):
self.model = StableDiffusionPipeline.from_pretrained(
str(self._data_dir), # Set to "data" by default from config.yaml
revision="fp16",
torch_dtype=torch.float16,
).to("cuda")
```
## Loading public model weights from S3
Bundling multi-gigabyte files with your Truss can be difficult if you have limited local storage and can make deployment slow. Instead, you can store your model weights and other files in cloud storage like S3.
Using files from S3 requires four steps:
1. Uploading the content of your data directory to S3 (see the sketch after this list)
2. Setting `external_data` in config.yaml
3. Removing unneeded files from the `data` directory
4. Accessing data correctly in the model
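For step 1, any S3 client works. Here's a minimal sketch using `boto3`; the bucket name and key prefix are placeholders:
```python
import os

import boto3

s3 = boto3.client("s3")

BUCKET = "my-model-weights-bucket"  # placeholder bucket name
PREFIX = "stable-diffusion-truss"   # placeholder key prefix

# Upload every file under data/, preserving relative paths so that the
# `local_data_path` values in config.yaml line up with the uploaded keys.
for root, _, files in os.walk("data"):
    for name in files:
        local_path = os.path.join(root, name)
        key = os.path.relpath(local_path, "data")
        s3.upload_file(local_path, BUCKET, f"{PREFIX}/{key}")
```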
Here's an example of that setup for Stable Diffusion, where we have already uploaded the content of our `data/` directory to S3.
First, add the URLs for hosted versions of the large files to `config.yaml`:
```yaml
external_data:
- url: https://baseten-public.s3.us-west-2.amazonaws.com/models/stable-diffusion-truss/unet/diffusion_pytorch_model.bin
local_data_path: unet/diffusion_pytorch_model.bin
- url: https://baseten-public.s3.us-west-2.amazonaws.com/models/stable-diffusion-truss/text_encoder/pytorch_model.bin
local_data_path: text_encoder/pytorch_model.bin
- url: https://baseten-public.s3.us-west-2.amazonaws.com/models/stable-diffusion-truss/vae/diffusion_pytorch_model.bin
local_data_path: vae/diffusion_pytorch_model.bin
```
Each URL matches with a local data path that represents where the model data would be stored if everything was bundled together locally. This is how your model code will know where to look for the data.
Then, get rid of the large files from your `data` folder. The Stable Diffusion Truss has the following directory structure after large files are removed:
```
data/
scheduler/
scheduler_config.json
text_encoder/
config.json
tokenizer/
merges.txt
special_tokens_map.json
tokenizer_config.json
vocab.json
unet/
config.json
vae/
config.json
model_index.json
```
The code in `model/model.py` does not need to be changed and will automatically pull the large files from the provided links.
## Loading private model weights from S3
If your model weights are proprietary, you'll be storing them in a private S3 bucket or similar access-restricted data store. Accessing these model files works exactly the same as above, but first uses [secrets](/truss/guides/secrets) to securely authenticate your model with the data store.
First, set the following secrets in `config.yaml`. Set the values to `null`; only the keys are needed here.
```yaml
secrets:
aws_access_key_id: null
aws_secret_access_key: null
aws_region: null # e.g. us-east-1
aws_bucket: null
```
Then, [add secrets to your Baseten account](https://docs.baseten.co/observability/secrets) for your AWS access key id, secret access key, region, and bucket. This time, use the actual values as they will be securely stored and provided to your model at runtime.
In your model code, authenticate with AWS in the `__init__()` function:
```python
def __init__(self, **kwargs) -> None:
    self._config = kwargs.get("config")
    secrets = kwargs.get("secrets")
    self.s3_config = {
        "aws_access_key_id": secrets["aws_access_key_id"],
        "aws_secret_access_key": secrets["aws_secret_access_key"],
        "aws_region": secrets["aws_region"],
    }
    self.s3_bucket = secrets["aws_bucket"]
```
You can then use the `boto3` package to access your model weights in `load()`.
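Here's a minimal sketch of what that could look like, assuming `boto3` is listed in your `requirements`; the object key and local filename shown are placeholders for your own:
```python model/model.py
import boto3

def load(self):
    # Authenticate with the credentials collected in __init__()
    s3 = boto3.client(
        "s3",
        aws_access_key_id=self.s3_config["aws_access_key_id"],
        aws_secret_access_key=self.s3_config["aws_secret_access_key"],
        region_name=self.s3_config["aws_region"],
    )
    # Download an illustrative weights file, then load the model from it as usual
    s3.download_file(
        self.s3_bucket,
        "unet/diffusion_pytorch_model.bin",  # object key in your bucket
        "diffusion_pytorch_model.bin",       # local destination
    )
```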
When you're ready to deploy your model, pass `--trusted` to `truss push` to give the model server access to the secrets stored on the remote host:
```sh
truss push --trusted
```
For further details, see [docs on using secrets in models](/truss/guides/secrets).
# Access model environments
A guide to leveraging environments in your models
A model's environment is passed to your `Model` class as a keyword argument in `init`. It can be accessed with:
```py model/model.py
def __init__(self, **kwargs):
self._environment = kwargs["environment"]
```
You can then use the `self._environment` dictionary in the `load` function:
```py model/model.py
def load(self):
# Configure monitoring and weights based on the deployment environment
if self._environment.get("name") == "production":
# Production setup
self.setup_sentry()
self.setup_logging(level="INFO")
self.load_production_weights()
else:
# Default setup for staging or development deployments
self.setup_logging(level="DEBUG")
self.load_default_weights()
```
[Learn more](/deploy/lifecycle#what-is-an-environment) about environments.
# External (source) packages
A guide on configuring your truss to use external packages
You might encounter a situation where you need to incorporate your own modules or a third-party package (not on PyPI) into your truss. Truss has two mechanisms to support this:
1. Using the packages directory
2. Using the external packages directory
Let's look at using the packages directory first.
## Using the packages directory
Each truss, when initialized, comes with a `packages` directory. This directory is at the same level as the `model` directory in the hierarchy.
Inside this directory, you can place any additional Python packages that you would like to reference inside your truss.
For example, your packages directory might look like this:
```
stable-diffusion/
packages/
package_1/
subpackage/
script.py
package_2/
utils.py
another_script.py
model/
model.py
__init__.py
config.yaml
```
You can import these packages inside your `model.py` like so:
```python model.py
from package_1.subpackage.script import run_script
from package_2.utils import RandomClass
class Model:
def __init__(self, **kwargs):
random_class = RandomClass()
def load(self):
run_script()
...
...
...
```
These packages get bundled with your truss at build time. Because of this, it's ideal to use this method when your packages are small.
But what if you have multiple trusses that want to reference the same package? This is where the `external_package_dirs` comes in handy.
The `external_package_dirs` allows you to import packages that are *outside* your truss and hence allows multiple trusses to reference the same package.
Let's look at an example.
## Using the external packages directory
Let's say you have the following setup:
```
stable-diffusion/
model/
model.py
__init__.py
config.yaml
super_cool_awesome_plugin/
plugin1/
script.py
plugin2/
run.py
```
In this case the package you want to import, `super_cool_awesome_plugin`, is outside the truss. You could move the `super_cool_awesome_plugin` directory inside the `packages` directory if you wanted to, but there is another option.
Inside the `config.yaml` you can specify the path to external packages by using the key `external_package_dirs`. Under this key, you can provide a list of external packages that you would like to use in your truss.
Here is what that would look like for the example above:
```yaml config.yaml
environment_variables: {}
external_package_dirs:
- ../super_cool_awesome_plugin/
model_name: Stable Diffusion
python_version: py39
...
...
...
```
### Configuring the `external_package_dirs` path
Paths to external packages must be relative to the `config.yaml` file.
Here, `super_cool_awesome_plugin/` sits alongside `stable-diffusion/`, one directory above `config.yaml`, so we use `../super_cool_awesome_plugin/`.
Here's how you can reference your packages inside `model.py`:
```python model.py
from plugin1.script import cool_constant
from plugin2.run import AwesomeRunner
class Model:
def __init__(self, **kwargs):
awesome_runner = AwesomeRunner()
def load(self):
awesome_runner.run(cool_constant)
...
...
...
```
Either technique works, depending on your use case. If you have a one-off package that a single truss needs, use the `packages` directory.
On the other hand, if you have a common package shared by multiple trusses, `external_package_dirs` is the better option.
# Caching model weights
Accelerate cold starts by caching your weights
Truss natively supports automatic caching for model weights. This is a simple yet effective strategy to enhance deployment speed and operational efficiency when it comes to cold starts and scaling beyond a single replica.
### What is a "cold start"?
"Cold start" is a term used to refer to the time taken by a model to boot up after being idle. This process can become a critical factor in serverless environments, as it can significantly influence the model response time, customer satisfaction, and cost.
Without caching our model's weights, we would need to download weights every time we scale up. Caching model weights circumvents this download process. When our new instance boots up, the server automatically finds the cached weights and can proceed with starting up the endpoint.
In practice, this reduces the cold start for large models to just a few seconds. For example, Stable Diffusion XL can take a few minutes to boot up without caching. With caching, it takes just under 10 seconds.
## Enabling Caching for a Model
To enable caching, add `model_cache` to your `config.yml` with a valid `repo_id`. The `model_cache` key has a few configuration options:
* `repo_id` (required): The Hugging Face repo ID or the path to your cloud bucket. Hugging Face, Google Cloud Storage, and AWS S3 are currently supported.
* `allow_patterns`: Only cache files that match the specified patterns. Use Unix shell-style wildcards to denote these patterns.
* `ignore_patterns`: Conversely, specify file patterns to ignore so they are excluded from the cache.
We recently renamed `hf_cache` to `model_cache`, but don't worry! If you're using `hf_cache` in any of your projects, it will automatically be aliased to `model_cache`.
Here is an example of a well-written `model_cache` for Stable Diffusion XL. Note how it only pulls the model weights that it needs using `allow_patterns`.
```yaml config.yml
...
model_cache:
- repo_id: madebyollin/sdxl-vae-fp16-fix
allow_patterns:
- config.json
- diffusion_pytorch_model.safetensors
- repo_id: stabilityai/stable-diffusion-xl-base-1.0
allow_patterns:
- "*.json"
- "*.fp16.safetensors"
- sd_xl_base_1.0.safetensors
- repo_id: stabilityai/stable-diffusion-xl-refiner-1.0
allow_patterns:
- "*.json"
- "*.fp16.safetensors"
- sd_xl_refiner_1.0.safetensors
...
```
Many Hugging Face repos have model weights in different formats (`.bin`, `.safetensors`, `.h5`, `.msgpack`, etc.). You only need one of these most of the time. To minimize cold starts, ensure that you only cache the weights you need.
### Cache invalidation
Cached model weights are bundled at build time. Thus, the only way to re-cache weights is to trigger a rebuild by creating a new deployment with `truss push`. There is not currently a mechanism for invalidating cached model weights on an existing model.
There are also some additional steps depending on the cloud bucket you want to query.
### Hugging Face 🤗
For any public Hugging Face repo, you don't need to do anything else. Adding the `model_cache` key with an appropriate `repo_id` should be enough.
However, if you want to deploy a model from a gated repo like [Llama 2](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf) to Baseten, there are a few steps you need to take:
[Grab an API key](https://huggingface.co/settings/tokens) from Hugging Face with `read` access. Make sure you have access to the model you want to serve.
Paste your API key in your [secrets manager in Baseten](https://app.baseten.co/settings/secrets) under the key `hf_access_token`. You can read more about secrets [here](/truss/guides/secrets).
In your Truss's `config.yml`, add the following code:
```yaml config.yml
...
secrets:
hf_access_token: null
...
```
Make sure that the key `secrets` only shows up once in your `config.yml`.
If you run into any issues, run through all the steps above again and make sure you did not misspell the name of the repo or paste an incorrect API key.
Weights will be cached in the default Hugging Face cache directory, `~/.cache/huggingface/hub/models--{your_model_name}/`. You can change this directory by setting the `HF_HOME` or `HUGGINGFACE_HUB_CACHE` environment variable in your `config.yml`.
[Read more here](https://huggingface.co/docs/huggingface_hub/guides/manage-cache).
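Your model code does not need to change to use the cache: `from_pretrained`-style loaders resolve the cached files automatically. A minimal sketch, using an illustrative repo ID that would match the `repo_id` in your `model_cache`:
```python model/model.py
from transformers import pipeline

class Model:
    def __init__(self, **kwargs) -> None:
        self._model = None

    def load(self):
        # Because this repo is listed under `model_cache`, the weights are already
        # in the Hugging Face cache directory and are not downloaded at startup.
        self._model = pipeline("fill-mask", model="bert-base-uncased")

    def predict(self, model_input):
        return self._model(model_input)
```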
### Google Cloud Storage
Google Cloud Storage is a great alternative to Hugging Face when you have a custom model or fine-tune you want to gate, especially if you are already using GCP and care about security and compliance.
Your `model_cache` should look something like this:
```yaml config.yml
...
model_cache:
- repo_id: gs://path-to-my-bucket
...
```
If you are accessing a public GCS bucket, you can ignore the following steps, but make sure you set appropriate permissions on your bucket. Users should be able to list and view all files. Otherwise, the model build will fail.
For a private GCS bucket, first export your service account key. Rename it to be `service_account.json` and add it to the `data` directory of your Truss.
Your file structure should look something like this:
```
your-truss
|--model
| └── model.py
|--data
|  └── service_account.json
```
If you are using version control, like git, for your Truss, make sure to add `service_account.json` to your `.gitignore` file. You don't want to accidentally expose your service account key.
Weights will be cached at `/app/model_cache/{your_bucket_name}`.
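In your model code, point your loader at that path. A minimal sketch, assuming a hypothetical bucket named `my-weights-bucket` containing a Transformers-compatible checkpoint:
```python model/model.py
from transformers import pipeline

def load(self):
    # Files from gs://my-weights-bucket are baked into the image at this path
    self._model = pipeline(
        "fill-mask",
        model="/app/model_cache/my-weights-bucket",
    )
```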
### Amazon Web Services S3
Another popular cloud storage option for hosting model weights is AWS S3, especially if you're already using AWS services.
Your `model_cache` should look something like this:
```yaml config.yml
...
model_cache:
- repo_id: s3://path-to-my-bucket
...
```
If you are accessing a public S3 bucket, you can ignore the subsequent steps, but make sure you set an appropriate policy on your bucket. Users should be able to list and view all files. Otherwise, the model build will fail.
However, for a private S3 bucket, you need to first find your `aws_access_key_id`, `aws_secret_access_key`, and `aws_region` in your AWS dashboard. Create a file named `s3_credentials.json`. Inside this file, add the credentials that you identified earlier as shown below. Place this file into the `data` directory of your Truss.
The key `aws_session_token` can be included, but is optional.
Here is an example of how your `s3_credentials.json` file should look:
```json
{
"aws_access_key_id": "YOUR-ACCESS-KEY",
"aws_secret_access_key": "YOUR-SECRET-ACCESS-KEY",
"aws_region": "YOUR-REGION"
}
```
Your overall file structure should now look something like this:
```
your-truss
|--model
| └── model.py
|--data
|  └── s3_credentials.json
```
When you are generating credentials, make sure that the resulting keys have at minimum the following IAM policy:
```json
{
  "Version": "2012-10-17",
  "Statement": [
    {
      "Action": [
        "s3:GetObject",
        "s3:ListObjects"
      ],
      "Effect": "Allow",
      "Resource": ["arn:aws:s3:::S3_BUCKET/PATH_TO_MODEL/*"]
    },
    {
      "Action": [
        "s3:ListBucket"
      ],
      "Effect": "Allow",
      "Resource": ["arn:aws:s3:::S3_BUCKET"]
    }
  ]
}
```
If you are using version control, like git, for your Truss, make sure to add `s3_credentials.json` to your `.gitignore` file. You don't want to accidentally expose your AWS credentials.
Weights will be cached at `/app/model_cache/{your_bucket_name}`.
### Other Buckets
We can work with you to support additional bucket types if needed. If you have any suggestions, please [leave an issue](https://github.com/basetenlabs/truss/issues) on our GitHub repo.
# Pre/post-processing
Deploy a model that makes use of pre- and post-processing
Out of the box, Truss limits the number of concurrent predictions that happen on a single container. This ensures that the CPU, and for many models the GPU, do not get overloaded, and that the model can continue to respond to requests in periods of high load.
However, many models, in addition to having compute components, also have
IO requirements. For example, a model that classifies images may need to download
the image from a URL before it can classify it.
Truss provides a way to separate the IO component from the compute component, ensuring that IO does not prevent full utilization of the compute on your pod.
To do this, you can use the `preprocess` and `postprocess` methods on a Truss. These methods can be defined like this:
```python
class Model:
    def __init__(self, **kwargs) -> None: ...
    def load(self) -> None: ...
def preprocess(self, request):
# Include any IO logic that happens _before_ predict here
...
def predict(self, request):
# Include the actual predict here
...
def postprocess(self, response):
# Include any IO logic that happens _after_ predict here
...
```
When the model is invoked, any logic defined in the `preprocess` or `postprocess` methods runs on a separate thread and is not subject to the same concurrency limits as `predict`. For example, say you have a model that can handle 5 concurrent requests:
```yaml config.yaml
...
runtime:
predict_concurrency: 5
...
```
If you hit it with 10 requests, they will *all* begin pre-processing, but when the 6th request is ready to begin the `predict` method, it will have to wait for one of the first 5 requests to finish. This keeps the GPU from being overloaded while ensuring that the compute logic is not blocked by IO, so you can achieve maximum throughput.
When `predict` returns a generator (e.g. for streaming LLM outputs), the model must not have a `postprocess` method.
Postprocessing can only be used when the prediction result is available as a whole. When streaming, move any postprocessing logic into `predict` or apply it client-side.
```python model/model.py
import requests
from typing import Dict
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
CHECKPOINT = "openai/clip-vit-base-patch32"
class Model:
"""
    This is a simple example of using CLIP to classify images.
It outputs the probability of the image being a cat or a dog.
"""
def __init__(self, **kwargs) -> None:
self._processor = None
self._model = None
def load(self):
"""
Loads the CLIP model and processor checkpoints.
"""
self._model = CLIPModel.from_pretrained(CHECKPOINT)
self._processor = CLIPProcessor.from_pretrained(CHECKPOINT)
    def preprocess(self, request: Dict) -> Dict:
        """
This method downloads the image from the url and preprocesses it.
The preprocess method is used for any logic that involves IO, in this
case downloading the image. It is called before the predict method
in a separate thread and is not subject to the same concurrency
limits as the predict method, so can be called many times in parallel.
"""
image = Image.open(requests.get(request.pop("url"), stream=True).raw)
request["inputs"] = self._processor(
text=["a photo of a cat", "a photo of a dog"],
images=image,
return_tensors="pt",
padding=True
)
return request
def predict(self, request: Dict) -> Dict:
"""
This performs the actual classification. The predict method is subject to
the predict concurrency constraints.
"""
outputs = self._model(**request["inputs"])
logits_per_image = outputs.logits_per_image
return logits_per_image.softmax(dim=1).tolist()
```
```yaml config.yaml
model_name: clip-example
requirements:
- transformers==4.32.0
- pillow==10.0.0
- torch==2.0.1
resources:
cpu: "3"
memory: 14Gi
use_gpu: true
accelerator: A10G
```
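The example above only needs a `preprocess` step, but a `postprocess` method follows the same pattern and also runs outside the `predict` concurrency limit. Here is a minimal sketch that labels the probabilities returned by `predict`; any slow IO, such as uploading results to storage, would also belong here:
```python model/model.py
def postprocess(self, response):
    # `response` is the list returned by predict, e.g. [[cat_prob, dog_prob]].
    # Like preprocess, this runs in a separate thread and is not subject to
    # the predict concurrency limit, so slow IO here does not block the GPU.
    cat_prob, dog_prob = response[0]
    return {"a photo of a cat": cat_prob, "a photo of a dog": dog_prob}
```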
# Private Hugging Face model
Load a model that requires authentication with Hugging Face
## Summary
To load a gated or private model from Hugging Face:
1. Create an [access token](https://huggingface.co/settings/tokens) on your Hugging Face account.
2. Add the `hf_access_token` key to the `secrets` in your `config.yaml` and add its value to your [Baseten account](https://app.baseten.co/settings/secrets).
3. Add `use_auth_token` to the appropriate line in `model.py`.
Example code:
```yaml config.yaml
secrets:
hf_access_token: null
```
```python model/model.py
self._model = pipeline(
"fill-mask",
model="baseten/docs-example-gated-model",
use_auth_token=self._secrets["hf_access_token"]
)
```
## Step-by-step example
[BERT base (uncased)](https://huggingface.co/bert-base-uncased) is a masked language model that can be used to infer missing words in a sentence.
While the model is publicly available on Hugging Face, we copied it into a gated model to use in this tutorial. The process is the same for using a gated model as it is for a private model.
You can see the code for the finished private model Truss on the right. Keep reading for step-by-step instructions on how to build it.
This example will cover:
1. Implementing a `transformers.pipeline` model in Truss
2. **Securely accessing secrets in your model server**
3. **Using a gated or private model with an access token**
### Step 0: Initialize Truss
Get started by creating a new Truss:
```sh
truss init private-bert
```
Give your model a name when prompted, like `Private Model Demo`. Then, navigate to the newly created directory:
```sh
cd private-bert
```
### Step 1: Implement the `Model` class
BERT base (uncased) is [a pipeline model](https://huggingface.co/docs/transformers/main_classes/pipelines), so it is straightforward to implement in Truss.
In `model/model.py`, we write the class `Model` with three member functions:
* `__init__`, which creates an instance of the object with a `_model` property
* `load`, which runs once when the model server is spun up and loads the `pipeline` model
* `predict`, which runs each time the model is invoked and handles the inference. It can use any JSON-serializable type as input and output.
[Read the quickstart guide](/quickstart) for more details on `Model` class implementation.
```python model/model.py
from transformers import pipeline
class Model:
def __init__(self, **kwargs) -> None:
self._secrets = kwargs["secrets"]
self._model = None
def load(self):
self._model = pipeline(
"fill-mask",
model="baseten/docs-example-gated-model"
)
def predict(self, model_input):
return self._model(model_input)
```
### Step 2: Set Python dependencies
Now, we can turn our attention to configuring the model server in `config.yaml`.
BERT base (uncased) has two dependencies:
```yaml config.yaml
requirements:
- torch==2.0.1
- transformers==4.30.2
```
Always pin exact versions for your Python dependencies. The ML/AI space moves fast, so you want to have an up-to-date version of each package while also being protected from breaking changes.
### Step 3: Set required secret
Now it's time to mix in access to the gated model:
1. Go to the [model page on Hugging Face](https://huggingface.co/baseten/docs-example-gated-model) and accept the terms to access the model.
2. Create an [access token](https://huggingface.co/settings/tokens) on your Hugging Face account.
3. Add the `hf_access_token` key and value to your [Baseten workspace secret manager](https://app.baseten.co/settings/secrets).
4. In your `config.yaml`, add the key `hf_access_token`:
```yaml config.yaml
secrets:
hf_access_token: null
```
Never set the actual value of a secret in the `config.yaml` file. Only put secret values in secure places, like the Baseten workspace secret manager.
### Step 4: Use access token in load
In `model/model.py`, you can give your model access to secrets in the init function:
```python model/model.py
def __init__(self, **kwargs) -> None:
self._secrets = kwargs["secrets"]
self._model = None
```
Then, update the load function with `use_auth_token`:
```python model/model.py
self._model = pipeline(
"fill-mask",
model="baseten/docs-example-gated-model",
use_auth_token=self._secrets["hf_access_token"]
)
```
This will allow the `pipeline` function to load the specified model from Hugging Face.
### Step 5: Deploy the model
You'll need a [Baseten API key](https://app.baseten.co/settings/account/api_keys) for this step.
We have successfully packaged a gated model as a Truss. Let's deploy!
Use `--trusted` with `truss push` to give the model server access to secrets stored on the remote host.
```sh
truss push --trusted
```
Wait for the model to finish deployment before invoking.
You can invoke the model with:
```sh
truss predict -d '"It is a [MASK] world"'
```
```yaml config.yaml
environment_variables: {}
model_name: private-model
python_version: py39
requirements:
- torch==2.0.1
- transformers==4.30.2
resources:
cpu: "1"
memory: 2Gi
use_gpu: false
accelerator: null
secrets:
hf_access_token: null
system_packages: []
```
```python model/model.py
from transformers import pipeline
class Model:
def __init__(self, **kwargs) -> None:
self._secrets = kwargs["secrets"]
self._model = None
def load(self):
self._model = pipeline(
"fill-mask",
model="baseten/docs-example-gated-model",
use_auth_token=self._secrets["hf_access_token"]
)
def predict(self, model_input):
return self._model(model_input)
```
# Using request objects / Cancellation
Get more control by directly using the request object.
Classically, the truss server accepts the client request and takes care of
extracting and optionally validating the payload (if you use `pydantic`
annotations for your input).
In advanced use cases you might want to access the raw request object.
Examples are:
* Custom payload deserialization, e.g. binary protocol buffers.
* Handling of closed connections and cancelling long-running predictions.
There is likewise support for [returning response objects](/truss/guides/responses).
You can flexibly mix and match using requests and the "classic" input argument.
E.g.
```python
import fastapi
class Model:
...
def preprocess(self, inputs):
def preprocess(self, inputs, request: fastapi.Request):
def preprocess(self, request: fastapi.Request):
...
def predict(self, inputs):
def predict(self, inputs, request: fastapi.Request):
def predict(self, request: fastapi.Request):
...
def postprocess(self, inputs):
def postprocess(self, inputs, request: fastapi.Request):
...
```
The request argument can have any name, but it **must** be type-annotated as a
subclass of `starlette.requests.Request`, e.g. `fastapi.Request`.
If you *only* use requests, the truss server skips extracting the payload
completely to improve performance.
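For instance, here is a minimal sketch of a request-only `predict` that deserializes the raw body itself (plain JSON here; a binary format like protocol buffers would be handled the same way):
```python
import json

import fastapi

class Model:
    async def predict(self, request: fastapi.Request):
        # The truss server did not parse the payload, so read and decode it here.
        raw_body = await request.body()
        payload = json.loads(raw_body)
        return {"echo": payload}
```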
When mixing requests and "classic" inputs, the following rules apply:
* When using both arguments, the request must be the second argument (and type-annotated).
* Argument names don't matter, but for clarity we recommend naming "classic" inputs `inputs` (or `data`) and not `request`.
* When none of the three methods `preprocess`, `predict`, and `postprocess` needs the classic inputs, payload parsing by the truss server is skipped.
* The request stays the same through the sequence of `preprocess`, `predict`, and `postprocess`. For example, `preprocess` only transforms the classic inputs: `predict` receives the transformed inputs but the identical request object.
* `postprocess` cannot use only the request, because that would mean the output of `predict` is discarded.
* Likewise, if `predict` only uses the request, you cannot have a `preprocess` method, because its outputs would be discarded.
## Cancelling predictions
If you make long-running predictions, such as LLM generation, and the client
drops the connection, it is useful to cancel the processing and free up the
server to handle other requests. You can use the `is_disconnected` method of
the request to check for dropped connections.
Here is a simple example:
```python
import fastapi, asyncio, logging
class Model:
async def predict(self, inputs, request: fastapi.Request):
await asyncio.sleep(1)
if await request.is_disconnected():
logging.warning("Cancelled (before gen).")
            # Cancel the request on the model engine here.
return
for i in range(5):
await asyncio.sleep(1.0)
logging.warning(i)
yield str(i)
if await request.is_disconnected():
logging.warning("Cancelled (during gen).")
# Cancel the request on the model engine here.
return
```
In a real implementation, you must also cancel the request on the model engine, which looks different depending on the framework used. Below are some examples.
The TRT-LLM example additionally showcases usage of an async polling
task for checking the connection, so you can cancel *between* yields from
the model and are not restricted by how long those yields take. The same logic
can also be used for other frameworks.
If you serve models with TRT LLM, you can use the `cancel` API of the response
generator.
```python
import asyncio
import json
import logging
from typing import AsyncGenerator, Awaitable, Callable
import tritonclient.grpc.aio as grpcclient
GRPC_SERVICE_PORT = 8001
logger = logging.getLogger(__name__)
class TritonClient:
def __init__(self, grpc_service_port: int = GRPC_SERVICE_PORT):
self.grpc_service_port = grpc_service_port
self._grpc_client = None
def start_grpc_stream(self) -> grpcclient.InferenceServerClient:
if self._grpc_client:
return self._grpc_client
self._grpc_client = grpcclient.InferenceServerClient(
url=f"localhost:{self.grpc_service_port}", verbose=False
)
return self._grpc_client
async def infer(
self,
model_input,
is_cancelled_fn: Callable[[], Awaitable[bool]],
model_name="ensemble"
) -> AsyncGenerator[str, None]:
grpc_client_instance = self.start_grpc_stream()
inputs = model_input.to_tensors()
async def input_generator():
yield {
"model_name": model_name,
"inputs" : inputs,
"request_id": model_input.request_id,
}
stats = await grpc_client_instance.get_inference_statistics()
logger.info(stats)
response_iterator = grpc_client_instance.stream_infer(
inputs_iterator=input_generator(),
)
if await is_cancelled_fn():
logging.info(
"Request cancelled before streaming. Cancelling Triton request.")
response_iterator.cancel()
return
# Below wraps the iteration of `response_iterator` into asyncio tasks, so that
# we can poll `is_cancelled_fn` 1/sec, even if the iterator is slower - i.e.
# we are not blocked / limited by the iterator when cancelling.
try:
gen_task = asyncio.ensure_future(response_iterator.__anext__())
while True:
done_task, _ = await asyncio.wait([gen_task], timeout=1)
if await is_cancelled_fn():
logging.info("Request cancelled. Cancelling Triton request.")
response_iterator.cancel()
gen_task.cancel()
return
if done_task:
try:
response = await gen_task
except StopAsyncIteration:
# response_iterator is exhausted, breaking `while True` loop.
return
result, error = response
if result:
result = result.as_numpy("text_output")
yield result[0].decode("utf-8")
else: # Error.
yield json.dumps(
{"status": "error", "message": error.message()})
return
gen_task = asyncio.ensure_future(response_iterator.__anext__())
except grpcclient.InferenceServerException as e:
logger.error(f"InferenceServerException: {e}")
```
If you serve models with vLLM, you can use the `abort` API of the engine.
Here is a minimal code snippet from
[their documentation](https://docs.vllm.ai/en/latest/dev/engine/async_llm_engine.html#vllm.AsyncLLMEngine.generate).
```python
# Please refer to entrypoints/api_server.py for
# the complete example.
# initialize the engine and the example input
engine = AsyncLLMEngine.from_engine_args(engine_args)
example_input = {
"prompt": "What is LLM?",
"stream": False, # assume the non-streaming case
"temperature": 0.0,
"request_id": 0,
}
# start the generation
results_generator = engine.generate(
example_input["prompt"],
SamplingParams(temperature=example_input["temperature"]),
example_input["request_id"])
# get the results
final_output = None
async for request_output in results_generator:
if await request.is_disconnected():
# Abort the request if the client disconnects.
await engine.abort(request_id)
# Return or raise an error
...
final_output = request_output
# Process and return the final output
...
```
The following features of requests are not supported:
* Streaming inputs ("file upload"). For large input data, provide URLs to
download source data instead of packing it into the request.
* Headers: most client-side headers are stripped from the requests reaching
the model. Add any information to control model behavior to the request payload.
# Returning response objects and SSEs
Get more control by directly creating the response object.
Classically, the truss server wraps the prediction results of your custom model
into a response object to be sent back via HTTP to the client.
In advanced use cases you might want to create these response objects yourself. Example use cases are:
* Control over the HTTP status codes.
* Streaming responses using server-sent events (SSEs).
There is likewise support for
[using request objects](/truss/guides/requests).
```python
import fastapi
class Model:
def predict(self, inputs) -> fastapi.Response:
return fastapi.Response(...)
```
You can return a response from either `predict` or `postprocess` and
any subclasses from `starlette.responses.Response` are supported.
If you return a response from `predict`, you cannot use a `postprocess` method.
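For example, here is a minimal sketch that controls the status code directly; the validation and payload are illustrative:
```python
import fastapi
from fastapi.responses import JSONResponse

class Model:
    def predict(self, inputs) -> fastapi.Response:
        # Return a 400 with a JSON body for invalid input, 200 otherwise.
        if "prompt" not in inputs:
            return JSONResponse(content={"error": "missing 'prompt'"}, status_code=400)
        return JSONResponse(content={"prompt": inputs["prompt"]}, status_code=200)
```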
## SSE / Streaming example
```python
import time

from starlette.responses import StreamingResponse

class Model:
    def predict(self, model_input):
        def event_stream():
            while True:
                time.sleep(1)
                yield ("data: Server Time: "
                       f"{time.strftime('%Y-%m-%d %H:%M:%S')}\n\n")

        return StreamingResponse(event_stream(), media_type="text/event-stream")
```
Response headers are not fully propagated. Include all information in the
response itself.
# Storing secrets in Baseten
A guide to using secrets securely in your ML models
Your model server may need to use access tokens, API keys, passwords, or other secret values. Truss gives you everything you need to use secrets securely.
## Setting secrets in `config.yaml`
If your model needs a secret, first add its name in `config.yaml` with a placeholder value:
```yaml config.yaml
secrets:
hf_access_token: null
```
Never set the actual value of a secret in the `config.yaml` file. Only put secret values in secure places, like the Baseten workspace secret manager.
## Using secrets in `model.py`
Secrets are passed to your `Model` class as a keyword argument in `init`. They can be accessed with:
```py model/model.py
def __init__(self, **kwargs):
self._secrets = kwargs["secrets"]
```
You can then use the `self._secrets` dictionary in the `load` and `predict` functions:
```py model/model.py
def load(self):
self._model = pipeline(
"fill-mask",
model="baseten/docs-example-gated-model",
use_auth_token=self._secrets["hf_access_token"]
)
```
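Secrets can be read the same way in `predict`. A minimal sketch, assuming a hypothetical `other_api_key` secret and an illustrative external endpoint:
```py model/model.py
import requests

def predict(self, model_input):
    # Hypothetical example: call an external service with a stored API key
    response = requests.post(
        "https://api.example.com/v1/score",  # illustrative URL
        headers={"Authorization": f"Bearer {self._secrets['other_api_key']}"},
        json=model_input,
    )
    return response.json()
```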
## Storing secrets on your remote
On your remote host, such as your Baseten account, store both the secret name and value before deploying your model. On Baseten, you can add secrets to your workspace on the [secrets workspace settings page](https://app.baseten.co/settings/secrets).
Make sure to use the same name (case sensitive) as used in the Truss on the remote.
## Deploying with secrets
For additional security, models don't have access to secrets by default. To deploy a model and give it access to secrets, pass the `--trusted` flag during `truss push` as follows:
```sh
truss push --trusted
```
Your model will be deployed with access to secrets stored on your remote.
# Streaming output with an LLM
Deploy an LLM and stream the output
The worst part of using generative AI tools is the long wait time during model inference. For some types of generative models, including large language models (LLMs), you can start getting results 10X faster by streaming model output as it is generated.
LLMs have two properties that make streaming output particularly useful:
1. Generating a complete response takes time, easily 10 seconds or more for longer outputs
2. Partial outputs are often useful!
When you host your LLMs with Baseten, you can stream responses. Instead of having to wait for the entire output to be generated, you can immediately start returning results to users with a sub-one-second time-to-first-token.
In this example, we will show you how to deploy [Falcon 7B](https://huggingface.co/tiiuae/falcon-7b), an LLM, and stream the output as it is generated.
You can see the code for the finished Falcon 7B Truss on the right. Keep reading for step-by-step instructions on how to build it.
### Step 0: Initialize Truss
Get started by creating a new Truss:
```sh
truss init falcon-7b
```
Give your model a name when prompted, like `falcon-streaming`. Then, navigate to the newly created directory:
```sh
cd falcon-7b
```
### Step 1: Set up the `Model` class without streaming
As mentioned before, Falcon 7B is an LLM. We will use the Hugging Face Transformers library to load and run the model. In this first step, we will generate output normally and return it without streaming.
In `model/model.py`, we write the class `Model` with three member functions:
* `__init__`, which creates an instance of the object with a `_model` property
* `load`, which runs once when the model server is spun up and loads the `pipeline` model
* `predict`, which runs each time the model is invoked and handles the inference. It can use any JSON-serializable type as input and output for non-streaming outputs.
[Read the quickstart guide](/quickstart) for more details on `Model` class implementation.
```python model/model.py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, TextIteratorStreamer
from typing import Dict
from threading import Thread
CHECKPOINT = "tiiuae/falcon-7b-instruct"
DEFAULT_MAX_NEW_TOKENS = 150
DEFAULT_TOP_P = 0.95
class Model:
def __init__(self, **kwargs) -> None:
self.tokenizer = None
self.model = None
def load(self):
self.tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
self.tokenizer.pad_token = self.tokenizer.eos_token_id
self.model = AutoModelForCausalLM.from_pretrained(
CHECKPOINT,
torch_dtype=torch.bfloat16,
trust_remote_code=True,
device_map="auto",
)
def predict(self, request: Dict) -> Dict:
prompt = request.pop("prompt")
# The steps in producing an output are to:
# 1. Tokenize the input
# 2. Set up generation parameters
# 3. Call the model.generate function
inputs = self.tokenizer(
prompt,
return_tensors="pt",
max_length=512,
truncation=True,
padding=True
)
input_ids = inputs["input_ids"].to("cuda")
# These generation parameters can be tuned
# to better produce the output that you are looking for.
generation_config = GenerationConfig(
temperature=1,
top_p=DEFAULT_TOP_P,
top_k=40,
)
with torch.no_grad():
generation_kwargs = {
"input_ids": input_ids,
"generation_config": generation_config,
"return_dict_in_generate": True,
"output_scores": True,
"pad_token_id": self.tokenizer.eos_token_id,
"max_new_tokens": DEFAULT_MAX_NEW_TOKENS,
}
return self.model.generate(
**generation_kwargs
)
```
### Step 2: Add streaming support
Once we have a model that can produce LLM outputs using the Hugging Face Transformers library, we can adapt it to support streaming. The key change is in the `predict` function.
In the example above, the `predict` function returns a `Dict` containing the model output. To stream results, we instead return a Python generator from the `predict` function, which allows us to return partial results to the user as they are generated.
To produce outputs incrementally for the LLM, we will pass a `TextIteratorStreamer` object to the `generate` function. This object will return the model output as it is generated. We will then kick off the generation on a separate thread.
What we return from the `predict` function is a generator that will yield the model output from the streamer object as it is generated.
```python model/model.py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, TextIteratorStreamer
from typing import Dict
from threading import Thread
CHECKPOINT = "tiiuae/falcon-7b-instruct"
DEFAULT_MAX_NEW_TOKENS = 150
DEFAULT_TOP_P = 0.95
class Model:
def __init__(self, **kwargs) -> None:
self.tokenizer = None
self.model = None
def load(self):
self.tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
self.tokenizer.pad_token = self.tokenizer.eos_token_id
self.model = AutoModelForCausalLM.from_pretrained(
CHECKPOINT,
torch_dtype=torch.bfloat16,
trust_remote_code=True,
device_map="auto",
)
def predict(self, request: Dict) -> Dict:
prompt = request.pop("prompt")
inputs = self.tokenizer(
prompt,
return_tensors="pt",
max_length=512,
truncation=True,
padding=True
)
input_ids = inputs["input_ids"].to("cuda")
streamer = TextIteratorStreamer(self.tokenizer)
generation_config = GenerationConfig(
temperature=1,
top_p=DEFAULT_TOP_P,
top_k=40,
)
with torch.no_grad():
generation_kwargs = {
"input_ids": input_ids,
"generation_config": generation_config,
"return_dict_in_generate": True,
"output_scores": True,
"pad_token_id": self.tokenizer.eos_token_id,
"max_new_tokens": DEFAULT_MAX_NEW_TOKENS,
"streamer": streamer
}
thread = Thread(
target=self.model.generate,
kwargs=generation_kwargs
)
thread.start()
def inner():
for text in streamer:
yield text
thread.join()
return inner()
```
### Step 3: Add remainder of Truss configuration
Once the model code is written, the last thing to do before deploying is to make sure the rest of the Truss configuration is in place.
The only things we need to add to `config.yaml` are the Python and hardware requirements for the model.
```yaml config.yaml
model_name: falcon-streaming
requirements:
- torch==2.0.1
- peft==0.4.0
- scipy==1.11.1
- sentencepiece==0.1.99
- accelerate==0.21.0
- bitsandbytes==0.41.1
- einops==0.6.1
- transformers==4.31.0
resources:
cpu: "3"
memory: 14Gi
use_gpu: true
accelerator: A10G
```
### Step 4: Deploy the model
You'll need a [Baseten API key](https://app.baseten.co/settings/account/api_keys) for this step.
We have successfully packaged Falcon as a Truss. Let's deploy! Run:
```sh
truss push
```
### Step 5: Invoke the model
You can invoke the model with:
```sh
truss predict -d '{"prompt": "Tell me about falcons", "do_sample": true}'
```
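Outside the CLI, you can read the streamed output incrementally over HTTP. A minimal sketch with `requests`, assuming the standard development deployment endpoint and substituting your own model ID and API key:
```python
import requests

model_id = ""  # your model ID
baseten_api_key = ""  # your Baseten API key

# Stream tokens as they are generated instead of waiting for the full response
resp = requests.post(
    f"https://model-{model_id}.api.baseten.co/development/predict",
    headers={"Authorization": f"Api-Key {baseten_api_key}"},
    json={"prompt": "Tell me about falcons"},
    stream=True,
)
for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
    print(chunk, end="", flush=True)
```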
```python model/model.py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, TextIteratorStreamer
from typing import Dict
from threading import Thread
CHECKPOINT = "tiiuae/falcon-7b-instruct"
DEFAULT_MAX_NEW_TOKENS = 150
DEFAULT_TOP_P = 0.95
class Model:
def __init__(self, **kwargs) -> None:
self.tokenizer = None
self.model = None
def load(self):
self.tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
self.tokenizer.pad_token = self.tokenizer.eos_token_id
self.model = AutoModelForCausalLM.from_pretrained(
CHECKPOINT,
torch_dtype=torch.bfloat16,
trust_remote_code=True,
device_map="auto",
)
def predict(self, request: Dict) -> Dict:
prompt = request.pop("prompt")
inputs = self.tokenizer(
prompt,
return_tensors="pt",
max_length=512,
truncation=True,
padding=True
)
input_ids = inputs["input_ids"].to("cuda")
streamer = TextIteratorStreamer(self.tokenizer)
generation_config = GenerationConfig(
temperature=1,
top_p=DEFAULT_TOP_P,
top_k=40,
)
with torch.no_grad():
generation_kwargs = {
"input_ids": input_ids,
"generation_config": generation_config,
"return_dict_in_generate": True,
"output_scores": True,
"pad_token_id": self.tokenizer.eos_token_id,
"max_new_tokens": DEFAULT_MAX_NEW_TOKENS,
"streamer": streamer
}
# Kick off a new thread to execute the model generation.
# As the model generates outputs, they will be readable
# from the Streamer object.
thread = Thread(
target=self.model.generate,
kwargs=generation_kwargs
)
thread.start()
# We return a generator that iterates over content in the
# streamer object.
def inner():
for text in streamer:
yield text
thread.join()
return inner()
```
```yaml config.yaml
model_name: falcon-streaming
requirements:
  - torch==2.0.1
  - peft==0.4.0
  - scipy==1.11.1
  - sentencepiece==0.1.99
  - accelerate==0.21.0
  - bitsandbytes==0.41.1
  - einops==0.6.1
  - transformers==4.31.0
resources:
cpu: "3"
memory: 14Gi
use_gpu: true
accelerator: A10G
```
# Model with system packages
Deploy a model with both Python and system dependencies
## Summary
To add system packages to your model serving environment, open `config.yaml` and update the `system_packages` key with a list of apt-installable Debian packages.
Example code:
```yaml config.yaml
system_packages:
- tesseract-ocr
```
## Step-by-step example
[LayoutLM Document QA](https://huggingface.co/impira/layoutlm-document-qa) is a multimodal model that answers questions about provided invoice documents.
The model requires a system package, `tesseract-ocr`, which we need to include in the model serving environment.
You can see the code for the finished LayoutLM Document QA Truss on the right. Keep reading for step-by-step instructions on how to build it.
This example will cover:
1. Implementing a `transformers.pipeline` model in Truss
2. Adding Python requirements to the Truss config
3. **Adding system requirements to the Truss config**
4. Setting sufficient model resources for inference
### Step 0: Initialize Truss
Get started by creating a new Truss:
```sh
truss init layoutlm-document-qa
```
Give your model a name when prompted, like `LayoutLM Document QA`. Then, navigate to the newly created directory:
```sh
cd layoutlm-document-qa
```
### Step 1: Implement the `Model` class
LayoutLM Document QA is [a pipeline model](https://huggingface.co/docs/transformers/main_classes/pipelines), so it is straightforward to implement in Truss.
In `model/model.py`, we write the class `Model` with three member functions:
* `__init__`, which creates an instance of the object with a `_model` property
* `load`, which runs once when the model server is spun up and loads the `pipeline` model
* `predict`, which runs each time the model is invoked and handles the inference. It can use any JSON-serializable type as input and output.
[Read the quickstart guide](/quickstart) for more details on `Model` class implementation.
```python model/model.py
from transformers import pipeline
class Model:
def __init__(self, **kwargs) -> None:
self._model = None
def load(self):
# Load the model from Hugging Face
self._model = pipeline(
"document-question-answering",
model="impira/layoutlm-document-qa",
)
def predict(self, model_input):
# Invoke the model and return the results
return self._model(
model_input["url"],
model_input["prompt"]
)
```
### Step 2: Set Python dependencies
Now, we can turn our attention to configuring the model server in `config.yaml`.
In addition to `transformers`, LayoutLM Document QA has three other dependencies, listed below:
```yaml config.yaml
requirements:
- Pillow==10.0.0
- pytesseract==0.3.10
- torch==2.0.1
- transformers==4.30.2
```
Always pin exact versions for your Python dependencies. The ML/AI space moves fast, so you want to have an up-to-date version of each package while also being protected from breaking changes.
### Step 3: Install system packages
One of the Python dependencies, `pytesseract`, also requires a system package to operate.
Adding system packages works just like adding Python requirements. You can specify any package that's available via `apt` on Debian.
```yaml config.yaml
system_packages:
- tesseract-ocr
```
### Step 4: Configure model resources
LayoutLM Document QA doesn't require a GPU, but you'll need a midrange CPU instance if you want reasonably fast invocation times. 4 CPU cores and 16 GiB of RAM is sufficient for the model.
Model resources are also set in `config.yaml` and must be specified before you deploy the model.
```yaml config.yaml
resources:
cpu: "4"
memory: 16Gi
use_gpu: false
accelerator: null
```
### Step 5: Deploy the model
You'll need a [Baseten API key](https://app.baseten.co/settings/account/api_keys) for this step.
We have successfully packaged LayoutLM Document QA as a Truss. Let's deploy!
```sh
truss push
```
You can invoke the model with:
```sh
truss predict -d '{"url": "https://templates.invoicehome.com/invoice-template-us-neat-750px.png", "prompt": "What is the invoice number?"}'
```
```yaml config.yaml
environment_variables: {}
external_package_dirs: []
model_metadata: {}
model_name: LayoutLM Document QA
python_version: py39
requirements:
- Pillow==10.0.0
- pytesseract==0.3.10
- torch==2.0.1
- transformers==4.30.2
resources:
cpu: "4"
memory: 16Gi
use_gpu: false
accelerator: null
secrets: {}
system_packages:
- tesseract-ocr
```
```python model/model.py
from transformers import pipeline
class Model:
def __init__(self, **kwargs) -> None:
self._model = None
def load(self):
self._model = pipeline(
"document-question-answering",
model="impira/layoutlm-document-qa",
)
def predict(self, model_input):
return self._model(
model_input["url"],
model_input["prompt"]
)
```
# Overview
Truss: Package and deploy AI models on Baseten
At a high level, model deployment has three phases:
1. Get model weights for an open source, fine-tuned, or custom-built AI/ML model.
2. Implement a model server.
3. Run that model server in a container on the cloud behind a secure API endpoint.
It's easy to download weights for an open source model for the first step, and Baseten entirely automates the third step (cloud deployment). But that second step, implementing the model server, is more complex.
To make it easier for AI engineers to write model serving code, we built [Truss](https://github.com/basetenlabs/truss).
## Truss: a model server abstraction
[Truss](https://github.com/basetenlabs/truss) is an open source framework for writing model server code in Python.
Truss gives you:
* The ability to create a containerized model server without learning Docker.
* An enjoyable and productive dev loop where you can test changes live in a remote development environment that closely mirrors production.
* Compatibility across model frameworks like `torch`, `transformers`, and `diffusers`; engines like `TensorRT`/`TensorRT-LLM`, `vLLM`, and `TGI`; serving technologies like `Triton`; and any package you can install with `pip` or `apt`.
We built Truss because containerization technologies like Docker are incredibly powerful, but their abstractions are too general for the model serving problems faced by AI and ML engineers. Building model-specific optimizations at the infrastructure layer is a distinct skill set from developing AI models, so Truss brings familiar Python-based tooling to model packaging, empowering all developers to build production-ready AI model servers.
## Using the Truss CLI
To get started with Truss, install the Truss CLI. We recommend always using the latest version:
```sh
pip install --upgrade truss
```
Start by creating a new Truss for your model:
```sh
truss init
```
After implementing your model server in `model.py` and `config.yaml`, you can push your model to Baseten:
```sh
truss push
```
This creates a development deployment, which you can patch by saving changes to your Truss code while running:
```sh
truss watch
```
When your model is ready for production, you can promote your deployment with:
```sh
truss push --publish
```
See the [Truss CLI reference](/truss-reference/cli) for more commands and options.
## Live reload developer workflow
Waiting for your model server to build and deploy every time you make a change would be a painful developer experience. Instead, work on your model as a development deployment and your changes will be live in seconds.
When you run `truss push`, your model is automatically deployed as a development deployment.
When you make a change to a development deployment, your code update is patched onto the running server. This patching process skips building and deploying an image and just runs the `load()` command to reload model weights after making any necessary environment updates.
Development deployments are great for rapid iteration, but aren't suitable for production use. When you're ready to use your model in your application, [promote your deployment to production](/deploy/lifecycle).
## Example model implementations
Step-by-step examples present core concepts for model packaging and deployment.
Source code for dozens of models with various engines, quantizations, and implementations.
Production-ready models with usage documentation, source code, and one-click deployments.
## Model deployment guides
With Truss, you get all of the power and flexibility of Python. You can completely customize your model server's behavior and environment to suit your needs. To get you started, we've written guides to common steps in model server implementation:
Loading the model:
* [Using private models from HF](/truss/guides/private-model)
* [Bundling model weights in Truss](/truss/guides/data-directory)
* [Caching model weights](/truss/guides/cached-weights)
Running the model:
* [Working with secrets](/truss/guides/secrets)
* [Pre- and post-processing on CPU](/truss/guides/pre-process)
* [Streaming model output](/truss/guides/streaming)
* [Setting concurrency](/truss/guides/concurrency)
Setting the environment:
* [Adding system packages](/truss/guides/system-packages)
* [Bundling custom Python code as dependencies](/truss/guides/external-packages)
* [Using custom Docker base images](/truss/guides/base-images)
# Welcome to Baseten!
Fast, scalable inference in our cloud or yours
Baseten provides the infrastructure to deploy and serve AI models performantly, scalably, and cost-efficiently. With Baseten, you can:
* **Deploy** any open source, fine-tuned, or custom AI/ML model as an API endpoint with [Truss](/deploy)
* **Optimize** model performance with cutting-edge engines like [TensorRT-LLM](/performance/engine-builder-overview)
* **Orchestrate** model inference and build multi-model [Chains](/chains/overview)
* **Scale** from zero to peak traffic with fast cold starts and [autoscaling](/deploy/autoscaling)
* **Manage** your deployed models with API access, logs, and metrics
Go from model weights to API endpoint
One-click deploys for popular models