To use Chains, install a recent Truss version and ensure pydantic is v2:
pip install --upgrade truss 'pydantic>=2.0.0'
Truss requires python >=3.8,<3.13. To set up a fresh development environment,
you can use the following commands, creating an environment named chains_env
using pyenv:
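One possible setup is sketched below (it assumes pyenv and the pyenv-virtualenv plugin are already installed, and uses 3.11 as one version within the supported range):

pyenv install 3.11
pyenv virtualenv 3.11 chains_env
pyenv activate chains_env
pip install --upgrade truss 'pydantic>=2.0.0'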
To deploy Chains remotely, you also need a
Baseten account.
It is handy to export your API key to the current shell session or permanently in your .bashrc:
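For example, for the current shell session (replace the placeholder with your actual key):

export BASETEN_API_KEY="YOUR_API_KEY"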
Exactly one Chainlet must be marked as the entrypoint with
the @chains.mark_entrypoint
decorator. This Chainlet is responsible for
handling public-facing input and output for the whole Chain in response to an
API call.
A Chainlet class has a single public method,
run_remote(), which is
the API
endpoint for the entrypoint Chainlet and the function that other Chainlets can
use as a dependency. The
run_remote()
method must be fully type-annotated
with primitive Python types or pydantic models.
Chainlets cannot be instantiated directly; the only correct way to use one
Chainlet from another is to declare the dependency with the chains.depends()
directive as an __init__ argument, as shown above for the RandInt Chainlet.
Beyond that, you can structure your code as you like, with private methods,
imports from other files, and so forth.
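For orientation, here is a minimal sketch (with hypothetical Chainlet names and logic, not the RandInt example referenced above) that shows the entrypoint decorator, a fully type-annotated run_remote(), and a chains.depends() dependency:

import truss_chains as chains


class Shouter(chains.ChainletBase):
    # A helper Chainlet; other Chainlets can depend on it.
    async def run_remote(self, text: str) -> str:
        return text.upper() + "!"


@chains.mark_entrypoint
class HelloChain(chains.ChainletBase):
    # Declare the dependency as an __init__ argument via chains.depends().
    def __init__(self, shouter: Shouter = chains.depends(Shouter)) -> None:
        self._shouter = shouter

    # Fully type-annotated public method; this is the Chain's API endpoint.
    async def run_remote(self, name: str) -> str:
        greeting = f"Hello, {name}"
        return await self._shouter.run_remote(greeting)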
Keep in mind that Chainlets are intended for distributed, replicated, remote
execution, so using global variables, global state, and certain Python features
like importing modules dynamically at runtime should be avoided as they may not
work as intended.
The main difference between this Chain and the previous one is that we now have
an LLM that needs a GPU and more complex dependencies.
Copy the following code into poems.py:
poems.py
import asyncio
from typing import List

import pydantic
import truss_chains as chains
from truss import truss_config

PHI_HF_MODEL = "microsoft/Phi-3-mini-4k-instruct"
# This configures caching of the model weights from the Hugging Face repo
# in the docker image that is used for deploying the Chainlet.
PHI_CACHE = truss_config.ModelRepo(
    repo_id=PHI_HF_MODEL, allow_patterns=["*.json", "*.safetensors", ".model"]
)


class Messages(pydantic.BaseModel):
    messages: List[dict[str, str]]


class PhiLLM(chains.ChainletBase):
    # `remote_config` defines the resources required for this chainlet.
    remote_config = chains.RemoteConfig(
        docker_image=chains.DockerImage(
            # The phi model needs some extra python packages.
            pip_requirements=[
                "accelerate==0.30.1",
                "einops==0.8.0",
                "transformers==4.41.2",
                "torch==2.3.0",
            ]
        ),
        # The phi model needs a GPU and more CPUs.
        compute=chains.Compute(cpu_count=2, gpu="T4"),
        # Cache the model weights in the image.
        assets=chains.Assets(cached=[PHI_CACHE]),
    )

    def __init__(self) -> None:
        # Note the imports of the *specific* python requirements are
        # pushed down to here. This code will only be executed on the
        # remotely deployed Chainlet, not in the local environment,
        # so we don't need to install these packages in the local
        # dev environment.
        import torch
        import transformers

        self._model = transformers.AutoModelForCausalLM.from_pretrained(
            PHI_HF_MODEL,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        self._tokenizer = transformers.AutoTokenizer.from_pretrained(
            PHI_HF_MODEL,
        )
        self._generate_args = {
            "max_new_tokens": 512,
            "temperature": 1.0,
            "top_p": 0.95,
            "top_k": 50,
            "repetition_penalty": 1.0,
            "no_repeat_ngram_size": 0,
            "use_cache": True,
            "do_sample": True,
            "eos_token_id": self._tokenizer.eos_token_id,
            "pad_token_id": self._tokenizer.pad_token_id,
        }

    async def run_remote(self, messages: Messages) -> str:
        import torch

        # Pass the plain list of message dicts to the chat template.
        model_inputs = self._tokenizer.apply_chat_template(
            messages.messages, tokenize=False, add_generation_prompt=True
        )
        inputs = self._tokenizer(model_inputs, return_tensors="pt")
        input_ids = inputs["input_ids"].to("cuda")
        with torch.no_grad():
            outputs = self._model.generate(input_ids=input_ids, **self._generate_args)
        output_text = self._tokenizer.decode(outputs[0], skip_special_tokens=True)
        return output_text
Now that we have an LLM, we can use it in a poem generator Chainlet. Add the
following code to poems.py:
poems.py
@chains.mark_entrypoint
class PoemGenerator(chains.ChainletBase):
    def __init__(self, phi_llm: PhiLLM = chains.depends(PhiLLM)) -> None:
        self._phi_llm = phi_llm

    async def run_remote(self, words: list[str]) -> list[str]:
        tasks = []
        for word in words:
            messages = Messages(
                messages=[
                    {
                        "role": "system",
                        "content": (
                            "You are a poet who writes short, "
                            "lighthearted, amusing poetry."
                        ),
                    },
                    {"role": "user", "content": f"Write a poem about {word}"},
                ]
            )
            tasks.append(asyncio.ensure_future(self._phi_llm.run_remote(messages)))
        return list(await asyncio.gather(*tasks))
Note that we wrap each RPC to the LLM Chainlet in asyncio.ensure_future.
This makes the current Python process start these remote calls concurrently,
i.e. the next call is started before the previous one has finished, which
minimizes the overall runtime. To await the results of all calls,
asyncio.gather is used, which gives us back normal Python objects.
If the LLM is hit with many concurrent requests, it can auto-scale up (if
autoscaling is configured). More advanced LLM models have batching capabilities,
so for those even a single instance can serve concurrent requests.
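If the Chain is not deployed yet, push it with the Truss CLI (this assumes your Chain lives in poems.py and that your installed Truss version provides the truss chains push command):

truss chains push poems.py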
Wait for the status to turn to ACTIVE and test invoking your Chain (replace
$INVOCATION_URL in the command below):
curl -X POST $INVOCATION_URL \
  -H "Authorization: Api-Key $BASETEN_API_KEY" \
  -d '{"words": ["bird", "plane", "superman"]}'
# [[
#   "<s> [INST] Generate a poem about: bird [/INST] In the quiet hush of...</s>",
#   "<s> [INST] Generate a poem about: plane [/INST] In the vast, boudl...</s>",
#   "<s> [INST] Generate a poem about: superman [/INST] In the realm where...</s>"
# ]]