Deploy Llama 2 with Caching
Enable fast cold starts for a model with private Hugging Face weights
In this example, we cover how to use the model_cache key in your Truss's config.yaml to automatically bundle model weights from a private Hugging Face repo. Bundling model weights can significantly reduce cold start times because your instance won't spend time downloading the weights from Hugging Face's servers at startup.
We use Llama-2-7b-chat, a popular open-source large language model, as an example. To follow along, you need to request access to Llama 2:
- First, sign up for a Hugging Face account if you don’t already have one.
- Request access to Llama 2 from Meta’s website.
- Next, request access to Llama 2 on Hugging Face by clicking the “Request access” button on the model page.
If you want to deploy on Baseten, you also need to create a Hugging Face API token and add it to your organization's secrets.
- Create a Hugging Face API token and copy it to your clipboard.
- Add the token with the key hf_access_token to your organization's secrets on Baseten (see the note after this list).
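The token typically also needs to be declared under the secrets key of your Truss's config.yaml so that the model code can read it at runtime. A minimal sketch, assuming the standard Truss secrets format (the value here is just a placeholder; the real token is injected from Baseten's secret store):

secrets:
  hf_access_token: null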
Step 0: Initialize Truss
Get started by creating a new Truss:
truss init llama-2-7b-chat
Select the TrussServer option, then hit y to confirm Truss creation. Then navigate to the newly created directory:
cd llama-2-7b-chat
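The scaffolded Truss contains, among other files, the two we edit in this guide (the exact layout may vary slightly between Truss versions):

llama-2-7b-chat/
  config.yaml      # model server configuration
  model/
    model.py       # model implementation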
Step 1: Implement Llama 2 7B in Truss
Next, we'll fill out the model/model.py file to implement Llama 2 7B in Truss.
In model/model.py, we write the class Model with three member functions:
- __init__, which creates an instance of the object and sets up the model and tokenizer properties
- load, which runs once when the model server is spun up and loads the model and tokenizer
- predict, which runs each time the model is invoked and handles the inference. It can use any JSON-serializable type as input and output.
We will also create a helper function format_prompt outside of the Model class to format the incoming text according to the Llama 2 prompt specification.
Read the quickstart guide for more details on Model class implementation.
from typing import Dict, List

import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

DEFAULT_SYSTEM_PROMPT = "You are a helpful, respectful and honest assistant."

B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"


class Model:
    def __init__(self, **kwargs) -> None:
        self._data_dir = kwargs["data_dir"]
        self._config = kwargs["config"]
        self._secrets = kwargs["secrets"]
        self.model = None
        self.tokenizer = None

    def load(self):
        self.model = LlamaForCausalLM.from_pretrained(
            "meta-llama/Llama-2-7b-chat-hf",
            use_auth_token=self._secrets["hf_access_token"],
            torch_dtype=torch.float16,
            device_map="auto",
        )
        self.tokenizer = LlamaTokenizer.from_pretrained(
            "meta-llama/Llama-2-7b-chat-hf",
            use_auth_token=self._secrets["hf_access_token"],
        )

    def predict(self, request: Dict) -> Dict[str, List]:
        prompt = request.pop("prompt")
        prompt = format_prompt(prompt)
        # Tokenize the prompt and move the tensors to the same device as the model
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        outputs = self.model.generate(**inputs, do_sample=True, num_beams=1, max_new_tokens=100)
        response = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
        return {"response": response}


def format_prompt(prompt: str, system_prompt: str = DEFAULT_SYSTEM_PROMPT) -> str:
    return f"{B_INST} {B_SYS} {system_prompt} {E_SYS} {prompt} {E_INST}"
Step 2: Set Python dependencies
Now, we can turn our attention to configuring the model server in config.yaml.
In addition to transformers, Llama 2 has three other dependencies, which we list under the requirements key:
requirements:
- accelerate==0.21.0
- safetensors==0.3.2
- torch==2.0.1
- transformers==4.30.2
Always pin exact versions for your Python dependencies. The ML/AI space moves fast, so you want to have an up-to-date version of each package while also being protected from breaking changes.
Step 3: Configure Hugging Face caching
Finally, we can configure Hugging Face caching in config.yaml by adding the model_cache key. When building the image for your Llama 2 deployment, the Llama 2 model weights will be downloaded and cached for future use.
model_cache:
- repo_id: "meta-llama/Llama-2-7b-chat-hf"
ignore_patterns:
- "*.bin"
In this configuration:
- meta-llama/Llama-2-7b-chat-hf is the repo_id, pointing to the exact model to cache.
- We use a wildcard to ignore all .bin files in the model directory by providing a pattern under ignore_patterns. This is because the model weights are stored in both .bin and safetensors format, and we only want to cache the safetensors files (see the sketch after this list).
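If you want to double-check which repo files a given ignore_patterns list would exclude, here is a minimal local sanity check using huggingface_hub and Python's fnmatch. It assumes your Hugging Face token is exported as the environment variable HF_ACCESS_TOKEN, and Truss's actual pattern matching may differ slightly; this only illustrates the intent:

import os
from fnmatch import fnmatch

from huggingface_hub import list_repo_files

IGNORE_PATTERNS = ["*.bin"]

# List every file in the gated repo (requires your Hugging Face token)
files = list_repo_files(
    "meta-llama/Llama-2-7b-chat-hf",
    token=os.environ["HF_ACCESS_TOKEN"],
)

for filename in files:
    ignored = any(fnmatch(filename, pattern) for pattern in IGNORE_PATTERNS)
    print(("ignored:" if ignored else "cached: "), filename)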
Step 4: Deploy the model
You'll need a Baseten API key for this step. Make sure you added the hf_access_token secret to your organization's secrets on Baseten.
We have successfully packaged Llama 2 as a Truss. Let's deploy! The --trusted flag gives the deployment access to the secrets stored in your Baseten workspace.
truss push --trusted
Step 5: Invoke the model
You can invoke the model with:
truss predict -d '{"prompt": "What is a large language model?"}'
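If you'd rather call the deployed model over HTTP, here is a rough sketch using Python's requests library. Treat MODEL_ID and BASETEN_API_KEY as placeholders; the exact endpoint URL for your deployment is shown in the Baseten UI:

import requests

# Placeholder values: take the model ID and API key from your Baseten workspace
resp = requests.post(
    "https://model-MODEL_ID.api.baseten.co/production/predict",
    headers={"Authorization": "Api-Key BASETEN_API_KEY"},
    json={"prompt": "What is a large language model?"},
)
print(resp.json())  # e.g. {"response": "..."}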