Caching model weights
Accelerate cold starts by caching your weights
Truss natively supports automatic caching for model weights. Caching shortens cold starts and speeds up scaling beyond a single replica.
What is a “cold start”?
“Cold start” is a term used to refer to the time taken by a model to boot up after being idle. This process can become a critical factor in serverless environments, as it can significantly influence the model response time, customer satisfaction, and cost.
Without caching our model’s weights, we would need to download weights every time we scale up. Caching model weights circumvents this download process. When our new instance boots up, the server automatically finds the cached weights and can proceed with starting up the endpoint.
In practice, this reduces the cold start for large models to just a few seconds. For example, Stable Diffusion XL can take a few minutes to boot up without caching. With caching, it takes just under 10 seconds.
Enabling Caching for a Model
To enable caching, simply add model_cache to your config.yml with a valid repo_id. The model_cache has a few key configurations:
repo_id (required): The endpoint for your cloud bucket. Currently, we support Hugging Face, Google Cloud Storage, and AWS S3.
allow_patterns: Only cache files that match the specified patterns. Use Unix shell-style wildcards to denote these patterns.
ignore_patterns: Conversely, file patterns to exclude from the cache, which keeps the cached set small.
The hf_cache key has been renamed to model_cache, but don’t worry! If you’re using hf_cache in any of your projects, it will automatically be aliased to model_cache.
Here is an example of a well-written model_cache for Stable Diffusion XL. Note how it only pulls the model weights that it needs using allow_patterns.
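A minimal sketch of such a config follows. It assumes model_cache takes a list of entries, each with a repo_id and optional patterns. The repo name and file patterns are illustrative assumptions, not values taken from this page, so check the repo's file listing before relying on them.

```yaml
# config.yml (illustrative; repo_id and patterns are assumptions)
model_cache:
  - repo_id: stabilityai/stable-diffusion-xl-base-1.0
    allow_patterns:
      # Quote glob patterns: a bare leading * is YAML alias syntax.
      - "*.json"
      - "*.fp16.safetensors"
      - sd_xl_base_1.0.safetensors
```

Only files matching one of the listed patterns would be cached, so full-precision duplicates and alternative weight formats in the repo are skipped.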
Many Hugging Face repos have model weights in different formats (.bin, .safetensors, .h5, .msgpack, etc.). You only need one of these most of the time. To minimize cold starts, ensure that you only cache the weights you need.
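As a sketch of the reverse approach, ignore_patterns can drop the formats you don't load. The repo name here is hypothetical, and the example assumes your model loads the .safetensors files, so the other formats are safe to skip.

```yaml
# config.yml (illustrative; the repo name is hypothetical)
model_cache:
  - repo_id: your-org/your-model
    ignore_patterns:
      # Skip weight formats the model does not load.
      - "*.bin"
      - "*.h5"
      - "*.msgpack"
```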
Cache invalidation
Cached model weights are bundled at build time. Thus, the only way to re-cache weights is to trigger a rebuild by creating a new deployment with truss push. There is not currently a mechanism for invalidating cached model weights on an existing model.
There are also some additional steps depending on the cloud bucket you want to query.
Hugging Face 🤗
For any public Hugging Face repo, you don’t need to do anything else. Adding the model_cache key with an appropriate repo_id should be enough.
However, if you want to deploy a model from a gated repo like Llama 2 to Baseten, there are a few steps you need to take:
Get Hugging Face API Key
Grab an API key from Hugging Face with read access. Make sure you have access to the model you want to serve.
Add it to Baseten Secrets Manager
Paste your API key in your secrets manager in Baseten under the key hf_access_token. You can read more about secrets here.
Update Config
In your Truss’s config.yml, add the following code:
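A minimal sketch of that addition is below. It assumes the secret is declared under the same hf_access_token name used in the secrets manager, with a null placeholder because the real value is stored in Baseten and not in the file.

```yaml
# config.yml
secrets:
  # Placeholder only; the real token lives in the Baseten secrets manager.
  hf_access_token: null
```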
Make sure that the key secrets only shows up once in your config.yml.
If you run into any issues, run through all the steps above again and make sure you did not misspell the name of the repo or paste an incorrect API key.
Weights will be cached in the default Hugging Face cache directory, ~/.cache/huggingface/hub/models--{your_model_name}/. You can change this directory by setting the HF_HOME or HUGGINGFACE_HUB_CACHE environment variable in your config.yml.
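One way this could look, assuming your Truss config accepts an environment_variables mapping and using a hypothetical cache path:

```yaml
# config.yml (sketch; the key name and path are assumptions)
environment_variables:
  # Hypothetical location; choose a path that exists in your image.
  HF_HOME: /app/hf_cache
```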
Google Cloud Storage
Google Cloud Storage is a great alternative to Hugging Face when you have a custom model or fine-tune you want to gate, especially if you are already using GCP and care about security and compliance.
Your model_cache should look something like this:
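A minimal sketch, assuming the bucket is referenced in repo_id with a gs:// URI. The bucket name is hypothetical.

```yaml
# config.yml
model_cache:
  - repo_id: gs://your-bucket-name   # hypothetical bucket
```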
If you are accessing a public GCS bucket, you can ignore the following steps, but make sure you set appropriate permissions on your bucket. Users should be able to list and view all files. Otherwise, the model build will fail.
For a private GCS bucket, first export your service account key. Rename it to be service_account.json and add it to the data directory of your Truss.
Your file structure should look something like this:
If you are using version control, like git, for your Truss, make sure to add service_account.json to your .gitignore file. You don’t want to accidentally expose your service account key.
Weights will be cached at /app/model_cache/{your_bucket_name}.
Amazon Web Services S3
Another popular cloud storage option for hosting model weights is AWS S3, especially if you’re already using AWS services.
Your model_cache should look something like this:
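A minimal sketch, assuming the bucket is referenced in repo_id with an s3:// URI. The bucket name is hypothetical.

```yaml
# config.yml
model_cache:
  - repo_id: s3://your-bucket-name   # hypothetical bucket
```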
If you are accessing a public S3 bucket, you can ignore the subsequent steps, but make sure you set an appropriate policy on your bucket. Users should be able to list and view all files. Otherwise, the model build will fail.
However, for a private S3 bucket, you need to first find your aws_access_key_id, aws_secret_access_key, and aws_region in your AWS dashboard. Create a file named s3_credentials.json and add these three values to it under those same key names. Place this file into the data directory of your Truss.
The key aws_session_token can be included, but is optional.
Here is an example of how your s3_credentials.json file should look:
Your overall file structure should now look something like this:
When you are generating credentials, make sure that the resulting keys have at minimum the following IAM policy:
If you are using version control, like git, for your Truss, make sure to add s3_credentials.json to your .gitignore file. You don’t want to accidentally expose your AWS credentials.
Weights will be cached at /app/model_cache/{your_bucket_name}.
Other Buckets
We can work with you to support additional bucket types if needed. If you have any suggestions, please leave an issue on our GitHub repo.