Chains is in beta mode. Read our launch blog post.

Chainlet

A Chainlet is the basic building block of Chains. A Chainlet is a Python class that specifies:

  • A set of compute resources.
  • A Python environment with software dependencies.
  • A typed interface run_remote() for other Chainlets to call.

This is the simplest possible Chainlet (only the run_remote() method is required); we can layer in other concepts to create a more capable Chainlet.

import truss_chains as chains


class SayHello(chains.ChainletBase):

    def run_remote(self, name: str) -> str:
        return f"Hello, {name}"

Remote configuration

Chainlets are meant for deployment as remote services. Each Chainlet specifies its own requirements for compute hardware (CPU count, GPU type and count, etc.) and software dependencies (Python libraries or system packages). This configuration is built into a Docker image automatically as part of the deployment process.

When no configuration is provided, the Chainlet will be deployed on a basic instance with one vCPU, 2GB of RAM, no GPU, and a standard set of Python and system packages.

Configuration is set using the remote_config class variable within the Chainlet:

import truss_chains as chains


class MyChainlet(chains.ChainletBase):
    remote_config = chains.RemoteConfig(
        docker_image=chains.DockerImage(
            pip_requirements=["torch==2.3.0", ...]
        ),
        compute=chains.Compute(gpu="H100", ...),
        assets=chains.Assets(secret_keys=["hf_access_token"], ...),
    )

See the remote configuration reference for a complete list of options.

Initialization

Chainlets are implemented as classes because we often want to set up expensive static resources once at startup and then re-use them with each invocation of the Chainlet. For example, we only want to initialize an AI model and download its weights once, then re-use the model every time we run inference.

We do this setup in __init__(), which is run exactly once when the Chainlet is deployed or scaled up.

import truss_chains as chains


class PhiLLM(chains.ChainletBase):
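    def __init__(self) -> None:
        # Heavy imports and weight loading happen once per replica, not per call.
        import torch
        import transformers

        # PHI_HF_MODEL is a module-level constant holding the Hugging Face model ID.
        self._model = transformers.AutoModelForCausalLM.from_pretrained(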
            PHI_HF_MODEL,
            torch_dtype=torch.float16,
            device_map="auto",
        )

        self._tokenizer = transformers.AutoTokenizer.from_pretrained(
            PHI_HF_MODEL,
        )
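To make the "set up once, re-use per call" pattern concrete without a GPU or a model download, here is a minimal plain-Python sketch. FakeModel and UppercaseService are hypothetical stand-ins, not part of truss_chains. A class-level counter shows that the expensive setup in __init__() runs once while run_remote() is called repeatedly.

```python
import time


class FakeModel:
    """Hypothetical stand-in for an expensive model (e.g. downloaded weights)."""

    load_count = 0  # class-level counter, only here to show how often loading runs

    def __init__(self) -> None:
        FakeModel.load_count += 1
        time.sleep(0.05)  # simulate a slow weight download / GPU transfer

    def generate(self, prompt: str) -> str:
        return prompt.upper()


class UppercaseService:
    """Plain-Python analogue of a Chainlet: heavy setup in __init__, cheap calls."""

    def __init__(self) -> None:
        # Runs once per instance (in Chains: once per deployed replica).
        self._model = FakeModel()

    def run_remote(self, prompt: str) -> str:
        # Every call reuses the already-loaded model.
        return self._model.generate(prompt)


service = UppercaseService()
outputs = [service.run_remote(p) for p in ["hello", "chains", "again"]]
print(FakeModel.load_count, outputs)  # 1 ['HELLO', 'CHAINS', 'AGAIN']
```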

Chainlet initialization also has two important features: context and dependency injection of other Chainlets, explained below.

Context (access information)

You can add a DeploymentContext object as an optional argument to the __init__() method of a Chainlet. This allows you to use secrets within your Chainlet, such as an hf_access_token to access a gated model on Hugging Face. Note that any secrets you use must also be added to the assets.

import truss_chains as chains


class MistralLLM(chains.ChainletBase):
    remote_config = chains.RemoteConfig(
        ...
        assets=chains.Assets(secret_keys=["hf_access_token"], ...),
    )
  
    def __init__(
        self,
        # Adding the `context` argument allows us to access secrets
        context: chains.DeploymentContext = chains.depends_context(),
    ) -> None:
        import transformers

        # Using the secret from context to access a gated model on HF
        self._model = transformers.AutoModelForCausalLM.from_pretrained(
            "mistralai/Mistral-7B-Instruct-v0.2",
            use_auth_token=context.secrets["hf_access_token"],
        )

Depends (call other Chainlets)

The Chains framework uses the chains.depends() function in Chainlets' __init__() method to track the dependency relationship between different Chainlets within a Chain.

This syntax, inspired by dependency injection, is used to translate local Python function calls into calls to the remote Chainlets in production.

Once a dependency Chainlet is added with chains.depends(), the depending Chainlet's run_remote() method can call it. For example, below, HelloAll makes calls to SayHello:

import truss_chains as chains


class HelloAll(chains.ChainletBase):

    def __init__(self, say_hello_chainlet=chains.depends(SayHello)) -> None:
        self._say_hello = say_hello_chainlet

    def run_remote(self, names: list[str]) -> str:
        output = []
        for name in names:
            output.append(self._say_hello.run_remote(name))
        return "\n".join(output)
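The dependency-injection idea behind chains.depends() can be illustrated with a toy resolver. This is a sketch under stated assumptions, not how truss_chains is implemented. Here, depends() only returns a marker, and the hypothetical build() function inspects __init__() defaults and instantiates each marked dependency locally. Chains would instead wire in a client for the remotely deployed Chainlet.

```python
import inspect
from typing import Any


class _Depends:
    """Toy marker recording which class should be injected."""

    def __init__(self, cls: type) -> None:
        self.cls = cls


def depends(cls: type) -> Any:
    # Hypothetical stand-in for chains.depends(): it only returns a marker.
    return _Depends(cls)


def build(cls: type) -> Any:
    """Instantiate cls after recursively building every depends() default."""
    kwargs = {}
    for name, param in inspect.signature(cls.__init__).parameters.items():
        if isinstance(param.default, _Depends):
            # A real framework could substitute a remote client here instead.
            kwargs[name] = build(param.default.cls)
    return cls(**kwargs)


class SayHello:
    def run_remote(self, name: str) -> str:
        return f"Hello, {name}"


class HelloAll:
    def __init__(self, say_hello_chainlet=depends(SayHello)) -> None:
        self._say_hello = say_hello_chainlet

    def run_remote(self, names: list[str]) -> str:
        return "\n".join(self._say_hello.run_remote(name) for name in names)


hello_all = build(HelloAll)
print(hello_all.run_remote(["Ada", "Grace"]))  # "Hello, Ada" and "Hello, Grace" on two lines
```

Declaring dependencies as default arguments keeps each class's constructor readable as a plain signature, while the resolver alone decides what concrete object gets passed in.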

Run remote (chaining Chainlets)

The run_remote() method is run each time the Chainlet is called. It is the sole public interface of the Chainlet (though you can define as many private helper methods as you want), and its inputs and outputs must have type annotations.

In run_remote() you implement the actual work of the Chainlet, such as model inference or data chunking:

import truss_chains as chains


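class PhiLLM(chains.ChainletBase):
    # `Messages` is a type defined elsewhere (a list of chat messages);
    # self._model, self._tokenizer and self._generate_args are assumed to be
    # set up in __init__().
    def run_remote(self, messages: Messages) -> str: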
        import torch

        model_inputs = self._tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = self._tokenizer(model_inputs, return_tensors="pt")
        input_ids = inputs["input_ids"].to("cuda")
        with torch.no_grad():
            outputs = self._model.generate(input_ids=input_ids,
                                           **self._generate_args)
            output_text = self._tokenizer.decode(outputs[0],
                                                 skip_special_tokens=True)
        return output_text

If run_remote() makes calls to other Chainlets, e.g. invoking a dependency Chainlet for each element in a list, you can benefit from concurrent execution by making run_remote() an async method, starting the calls as concurrent tasks with asyncio.ensure_future(self._dep_chainlet.run_remote(...)), and then awaiting the results (e.g. with asyncio.gather()).
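A minimal sketch of this pattern in plain asyncio follows. SayHelloStandIn and HelloAllConcurrent are hypothetical names. The stand-in's async run_remote() sleeps to simulate network latency. With truss_chains, the dependency would be a Chainlet injected via chains.depends(), and this sketch assumes its run_remote() is awaitable, as in the call above.

```python
import asyncio
import time


class SayHelloStandIn:
    """Hypothetical dependency with an awaitable run_remote()."""

    async def run_remote(self, name: str) -> str:
        await asyncio.sleep(0.1)  # simulate the latency of a remote call
        return f"Hello, {name}"


class HelloAllConcurrent:
    def __init__(self, say_hello: SayHelloStandIn) -> None:
        self._say_hello = say_hello

    async def run_remote(self, names: list[str]) -> str:
        # Start one task per name; none of them is awaited individually yet.
        tasks = [asyncio.ensure_future(self._say_hello.run_remote(n)) for n in names]
        # gather() waits for all tasks and returns results in input order.
        results = await asyncio.gather(*tasks)
        return "\n".join(results)


if __name__ == "__main__":
    start = time.perf_counter()
    chainlet = HelloAllConcurrent(SayHelloStandIn())
    print(asyncio.run(chainlet.run_remote(["Ada", "Grace", "Linus"])))
    # About 0.1 s total, versus about 0.3 s if each call were awaited in turn.
    print(f"elapsed: {time.perf_counter() - start:.2f} s")
```

Because asyncio.gather() preserves input order, the joined output matches the order of names even though the calls finish in arbitrary order.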

Entrypoint

The entrypoint is called directly from the deployed Chain's API endpoint and kicks off the entire chain. The entrypoint is also responsible for returning the final result back to the client.

Using the @chains.mark_entrypoint decorator, one Chainlet within a file is set as the entrypoint to the chain.

@chains.mark_entrypoint
class HelloAll(chains.ChainletBase):
    ...

Stub

Chains can be combined with existing Truss models using Stubs.

A Stub acts as a substitute (client-side proxy) for a remotely deployed dependency, either a Chainlet or a Truss model. The Stub performs the remote invocations as if it were local by taking care of the transport layer, authentication, data serialization and retries.

Stubs can be integrated into Chainlets by passing in the URL of the deployed model. They also require a DeploymentContext to be initialized (for authentication).

import truss_chains as chains


class LLMClient(chains.StubBase):

    async def run_remote(
        self,
        prompt: str
    ) -> str:
        # Call the deployed model
        resp = await self._remote.predict_async(json_payload={
            "messages": [{"role": "user", "content": prompt}],
            "stream": False,
        })
        # Return a string with the model output
        return resp["output"]


LLM_PREDICT_URL = ...
    
    
class MyChainlet(chains.ChainletBase):

    def __init__(
        self,
        context: chains.DeploymentContext = chains.depends_context(),
    ):
        self._llm = LLMClient.from_url(LLM_PREDICT_URL, context)

See the StubBase reference for details on the StubBase implementation.
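To illustrate what a client-side proxy takes care of, here is a toy stub. It serializes the payload to JSON, sends it through a transport, and retries transient failures with exponential backoff. ToyStub and make_flaky_transport are hypothetical names. This is not StubBase: it skips authentication, and the transport is an in-memory function standing in for an HTTP call.

```python
import json
import time
from typing import Any, Callable


class ToyStub:
    """Toy client-side proxy: serialize, send, deserialize, retry.

    Not truss_chains.StubBase: no authentication, and the transport is an
    injected function instead of an HTTP client.
    """

    def __init__(self, transport: Callable[[str], str], max_retries: int = 3,
                 backoff_s: float = 0.01) -> None:
        self._transport = transport
        self._max_retries = max_retries
        self._backoff_s = backoff_s

    def predict(self, payload: dict[str, Any]) -> dict[str, Any]:
        body = json.dumps(payload)  # serialize the request to JSON text
        for attempt in range(self._max_retries):
            try:
                return json.loads(self._transport(body))  # parse the JSON reply
            except ConnectionError:
                if attempt == self._max_retries - 1:
                    raise  # out of retries: surface the last error
                time.sleep(self._backoff_s * 2 ** attempt)  # exponential backoff
        raise ValueError("max_retries must be at least 1")


def make_flaky_transport(failures: int) -> Callable[[str], str]:
    """Hypothetical in-memory 'model' that fails `failures` times, then echoes."""
    calls = 0

    def transport(body: str) -> str:
        nonlocal calls
        calls += 1
        if calls <= failures:
            raise ConnectionError("simulated transient failure")
        prompt = json.loads(body)["messages"][0]["content"]
        return json.dumps({"output": f"echo: {prompt}"})

    return transport


stub = ToyStub(make_flaky_transport(failures=2))
resp = stub.predict({"messages": [{"role": "user", "content": "hi"}], "stream": False})
print(resp["output"])  # echo: hi (succeeds on the third attempt)
```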

Pydantic data types

To make orchestrating multiple remotely deployed services possible, Chains relies heavily on typed inputs and outputs. Values must be serialized to a safe exchange format to be sent over the network.

The Chains framework uses the type annotations to infer how data should be serialized and is currently restricted to JSON-compatible types. Types can be:

  • Direct type annotations for simple types such as int, float, or list[str].
  • Pydantic models to define a schema for nested data structures or multiple arguments.

An example of pydantic input and output types for a Chainlet is given below:

import enum
import pydantic

class Modes(enum.Enum):
    MODE_0 = "MODE_0"
    MODE_1 = "MODE_1"


class SplitTextInput(pydantic.BaseModel):
    data: str
    num_partitions: int
    mode: Modes


class SplitTextOutput(pydantic.BaseModel):
    parts: list[str]
    part_lens: list[int]
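As a sketch of how these types behave at a Chainlet boundary, the hypothetical split_text() function below stands in for a run_remote() body. It ignores mode and simply cuts the text into near-equal chunks. The sketch shows that JSON-compatible input, such as the string "MODE_1", is validated and converted into the declared types, while invalid input raises pydantic.ValidationError. It uses only pydantic features common to v1 and v2.

```python
import enum

import pydantic


class Modes(enum.Enum):
    MODE_0 = "MODE_0"
    MODE_1 = "MODE_1"


class SplitTextInput(pydantic.BaseModel):
    data: str
    num_partitions: int
    mode: Modes


class SplitTextOutput(pydantic.BaseModel):
    parts: list[str]
    part_lens: list[int]


def split_text(inputs: SplitTextInput) -> SplitTextOutput:
    """Hypothetical run_remote() body: split data into num_partitions
    contiguous chunks of near-equal length (mode is ignored here)."""
    n, k = len(inputs.data), inputs.num_partitions
    bounds = [i * n // k for i in range(k + 1)]
    parts = [inputs.data[bounds[i]:bounds[i + 1]] for i in range(k)]
    return SplitTextOutput(parts=parts, part_lens=[len(p) for p in parts])


# A JSON-compatible dict, as it might arrive over the network: the string
# "MODE_1" is validated and converted to the Modes.MODE_1 enum member.
request = SplitTextInput(**{"data": "abcdefghij", "num_partitions": 3, "mode": "MODE_1"})
result = split_text(request)
print(result.parts, result.part_lens)  # ['abc', 'def', 'ghij'] [3, 3, 4]
```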

Refer to the pydantic docs for more details on how to define custom pydantic data models.

We are working on more efficient support for numeric data and bytes. For the time being, a workaround for these types is to base64-encode them and add them as string-valued fields to a pydantic model.
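A minimal sketch of the base64 workaround follows, assuming float64 values and raw bytes as the payload. EmbeddingPayload, encode_floats() and decode_floats() are hypothetical names. The values are packed with the standard-library array module, base64-encoded into a str field, and decoded back on the receiving side.

```python
import array
import base64

import pydantic


class EmbeddingPayload(pydantic.BaseModel):
    # Hypothetical schema: float64 values packed into bytes, then base64-encoded
    # so they travel inside JSON as an ordinary string.
    values_b64: str


def bytes_to_b64(raw: bytes) -> str:
    return base64.b64encode(raw).decode("ascii")


def b64_to_bytes(text: str) -> bytes:
    return base64.b64decode(text)


def encode_floats(values: list[float]) -> EmbeddingPayload:
    # "d" = C double (float64); tobytes() uses the machine's native byte order.
    packed = array.array("d", values).tobytes()
    return EmbeddingPayload(values_b64=bytes_to_b64(packed))


def decode_floats(payload: EmbeddingPayload) -> list[float]:
    values = array.array("d")
    values.frombytes(b64_to_bytes(payload.values_b64))
    return values.tolist()


payload = encode_floats([0.5, -1.25, 3.0])
print(decode_floats(payload))  # [0.5, -1.25, 3.0]
```

Since array uses native byte order, this sketch assumes sender and receiver share it (otherwise call array.byteswap() on one side). For plain bytes, bytes_to_b64() and b64_to_bytes() alone are enough.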

Chains compared to Truss