Building an LLM
In this example, we go through a Truss that serves an LLM. We use Mistral-7B, a general-purpose LLM that can be used for a variety of tasks, such as summarization, question answering, and translation.
Set up the imports and key constants
In this example, we use the Hugging Face transformers library to build a text generation model.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
We use the 7B version of the Mistral model.
CHECKPOINT = "mistralai/Mistral-7B-v0.1"
Define the Model class and load function
In the load function of the Truss, we implement the logic involved in downloading and setting up the model. For this LLM, we use the Auto classes in transformers to instantiate our Mistral model.
class Model:
    def __init__(self, **kwargs) -> None:
        self.tokenizer = None
        self.model = None

    def load(self):
        self.model = AutoModelForCausalLM.from_pretrained(
            CHECKPOINT, torch_dtype=torch.float16, device_map="auto"
        )
        self.tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
Define the predict function
In the predict function, we implement the actual inference logic. The steps here are:
- Set up the generation params. We provide defaults for these, but adjusting the values will have an impact on the model output
- Tokenize the input
- Generate the output
- Use tokenizer to decode the output
    def predict(self, request: dict):
        prompt = request.pop("prompt")
        generate_args = {
            "max_new_tokens": request.get("max_new_tokens", 128),
            "temperature": request.get("temperature", 1.0),
            "top_p": request.get("top_p", 0.95),
            "top_k": request.get("top_k", 50),
            "repetition_penalty": 1.0,
            "no_repeat_ngram_size": 0,
            "use_cache": True,
            "do_sample": True,
            "eos_token_id": self.tokenizer.eos_token_id,
            "pad_token_id": self.tokenizer.pad_token_id,
        }
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.cuda()
        with torch.no_grad():
            output = self.model.generate(inputs=input_ids, **generate_args)
        return self.tokenizer.decode(output[0])
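To sanity check the model before deploying, you can exercise the same class locally. This is a minimal sketch, not part of the Truss itself; it assumes you have a CUDA GPU with enough memory to hold the 7B weights in float16:

# Local smoke test for the Model class defined above (assumes a CUDA GPU).
model = Model()
model.load()
# "max_new_tokens" is optional; predict falls back to the defaults above.
print(model.predict({"prompt": "What is the meaning of life?", "max_new_tokens": 64}))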
Setting up the config.yaml
Running Mistral 7B requires a few libraries, such as torch, transformers, and a couple of others.
environment_variables: {}
external_package_dirs: []
model_metadata:
  example_model_input: {"prompt": "What is the meaning of life?"}
model_name: Mistral 7B
python_version: py311
requirements:
  - transformers==4.34.0
  - sentencepiece==0.1.99
  - accelerate==0.23.0
  - torch==2.0.1
Configure resources for Mistral
Note that we need an A10G to run this model.
resources:
  accelerator: A10G
  use_gpu: true
secrets: {}
system_packages: []
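Putting the two snippets together, the complete config.yaml for this Truss looks like:

environment_variables: {}
external_package_dirs: []
model_metadata:
  example_model_input: {"prompt": "What is the meaning of life?"}
model_name: Mistral 7B
python_version: py311
requirements:
  - transformers==4.34.0
  - sentencepiece==0.1.99
  - accelerate==0.23.0
  - torch==2.0.1
resources:
  accelerator: A10G
  use_gpu: true
secrets: {}
system_packages: []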
Deploy the model
Deploy the model like you would other Trusses, with:
$ truss push
You can then invoke the model with:
$ truss predict -d '{"prompt": "What is a large language model?"}'
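You can also call the deployed model over HTTP. The snippet below is a sketch that assumes a Baseten deployment and its standard prediction endpoint; MODEL_ID and API_KEY are placeholders for your own values:

import requests

MODEL_ID = "abc123"        # placeholder: your deployed model's ID
API_KEY = "YOUR_API_KEY"   # placeholder: your API key

# Send the same JSON payload the predict function expects.
resp = requests.post(
    f"https://model-{MODEL_ID}.api.baseten.co/production/predict",
    headers={"Authorization": f"Api-Key {API_KEY}"},
    json={"prompt": "What is a large language model?"},
)
print(resp.json())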