
In this example, we go through a Truss that serves the Qwen 7B Chat LLM and streams its output to the client.

Why Streaming?

For certain ML models, generation can take a long time. With LLMs especially, a long output can take 10-20 seconds to produce. However, because LLMs generate tokens in sequence, useful output can be made available to users sooner. To enable this, Truss supports streaming output.
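Under the hood, streaming in Truss means returning a generator from predict. Here is a minimal sketch of that pattern (not part of the Qwen example below), using a bare-bones model for illustration:

# Minimal sketch: a Truss model streams by returning a generator from predict,
# and each yielded chunk is sent to the client as it is produced.
class Model:
    def predict(self, request: dict):
        def generator():
            for word in ["Streaming", "output", "one", "chunk", "at", "a", "time"]:
                yield word + " "

        return generator()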

Set up the imports

In this example, we use the HuggingFace transformers library to build a text generation model.

model/model.py
from threading import Thread
from typing import Dict

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from transformers.generation import GenerationConfig

Define the load function

In the load function of the Truss, we implement the logic for downloading the chat version of the Qwen 7B model and loading it into memory.

model/model.py
class Model:
    def __init__(self, **kwargs):
        self.model = None
        self.tokenizer = None

    def load(self):
        self.tokenizer = AutoTokenizer.from_pretrained(
            "Qwen/Qwen-7B-Chat", trust_remote_code=True
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            "Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True
        ).eval()

Define the preprocess function

In the preprocess function of the Truss, we build a generate_args dictionary from the generation parameters in the inference request, to be used later in the predict function.

model/model.py
    def preprocess(self, request: dict) -> dict:
        generate_args = {
            "max_new_tokens": request.get("max_new_tokens", 512),
            "temperature": request.get("temperature", 0.5),
            "top_p": request.get("top_p", 0.95),
            "top_k": request.get("top_k", 40),
            "repetition_penalty": 1.0,
            "no_repeat_ngram_size": 0,
            "use_cache": True,
            "do_sample": True,
            "eos_token_id": self.tokenizer.eos_token_id,
            "pad_token_id": self.tokenizer.pad_token_id,
        }
        request["generate_args"] = generate_args
        return request
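For illustration, here is a hypothetical request before and after preprocess; the prompt and parameter values are made up:

# Hypothetical input request (values are illustrative only)
request = {"prompt": "What is the meaning of life?", "temperature": 0.7, "stream": True}
# After preprocess, request gains a "generate_args" entry, e.g.:
# {"max_new_tokens": 512, "temperature": 0.7, "top_p": 0.95, "top_k": 40, ...}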

Define the predict function

In the predict function of the Truss, we implement the actual inference logic.

The two main steps are:

  • Tokenize the input
  • Call the model’s generate function if we’re not streaming the output, otherwise call the stream helper function
model/model.py
    def predict(self, request: Dict):
        stream = request.pop("stream", False)
        prompt = request.pop("prompt")
        generation_args = request.pop("generate_args")
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.cuda()

        if stream:
            return self.stream(input_ids, generation_args)

        with torch.no_grad():
            output = self.model.generate(inputs=input_ids, **generation_args)
            return self.tokenizer.decode(output[0])

Define the stream helper function

In this helper function, we instantiate the TextIteratorStreamer object, which we'll later use to return the LLM output to users.

model/model.py
    def stream(self, input_ids: list, generation_args: dict):
        streamer = TextIteratorStreamer(self.tokenizer)

When creating the generation parameters, make sure to pass the streamer object that we created above.

model/model.py
        generation_config = GenerationConfig(**generation_args)
        generation_kwargs = {
            "input_ids": input_ids,
            "generation_config": generation_config,
            "return_dict_in_generate": True,
            "output_scores": True,
            "max_new_tokens": generation_args["max_new_tokens"],
            "streamer": streamer,
        }

Spawn a thread to run the generation, so that it does not block the main thread.

model/model.py
        with torch.no_grad():
            # Begin generation in a separate thread
            thread = Thread(target=self.model.generate, kwargs=generation_kwargs)
            thread.start()

In Truss, the way to achieve streaming output is to return a generator that yields content. In this example, we yield the output of the streamer, which yields text chunks as they are produced until generation is complete.

We define this inner function to create our generator.

model/model.py
            # Yield generated text as it becomes available
            def inner():
                for text in streamer:
                    yield text
                thread.join()
        return inner()

Setting up the config.yaml

Running Qwen 7B requires torch, transformers, and a few other related libraries.

config.yaml
model_name: qwen-7b-chat
model_metadata:
  example_model_input:
    prompt: What is the meaning of life?
requirements:
  - accelerate==0.23.0
  - tiktoken==0.5.1
  - einops==0.6.1
  - scipy==1.11.3
  - transformers_stream_generator==0.0.4
  - peft==0.5.0
  - deepspeed==0.11.1
  - torch==2.0.1
  - transformers==4.32.0

Configure resources for Qwen

Note that we need an A10G to run this model.

config.yaml
resources:
  accelerator: A10G
  cpu: "3"
  memory: 14Gi
  use_gpu: true

Deploy Qwen 7B Chat

Deploy the model like you would any other Truss, with:

truss push qwen-7b-chat --publish
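
Once the model is deployed, you can consume the stream over HTTP. The sketch below makes some assumptions: the endpoint URL format (https://model-<model-id>.api.baseten.co/production/predict) and the BASETEN_API_KEY environment variable reflect a typical Baseten deployment and may differ for yours.

# Sketch of a streaming client. The endpoint URL format and API key handling
# are assumptions based on a typical Baseten deployment; adjust for your setup.
import os

import requests

resp = requests.post(
    "https://model-<model-id>.api.baseten.co/production/predict",
    headers={"Authorization": f"Api-Key {os.environ['BASETEN_API_KEY']}"},
    json={"prompt": "What is the meaning of life?", "stream": True},
    stream=True,
)
for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
    print(chunk, end="", flush=True)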