Streaming output reduces the perceived wait time for generative AI models by returning results as they are generated instead of holding them until the full response is ready.

Why Streaming?

  • Faster response time – The first results can arrive in under 1 second, instead of after the 10+ seconds a full response may take.
  • Improved user experience – Partial outputs are immediately usable.
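The difference between time to first chunk and time to full response can be illustrated without a model. The sketch below is only a simulation: `fake_generate` and `time_to_first_and_last` are hypothetical helpers, and the per-token delay is an arbitrary stand-in for real generation cost.

```python
import time
from typing import Iterator, List, Optional, Tuple


def fake_generate(tokens: List[str], delay: float) -> Iterator[str]:
    """Yield tokens one at a time, sleeping to mimic per-token generation cost."""
    for token in tokens:
        time.sleep(delay)
        yield token


def time_to_first_and_last(stream: Iterator[str]) -> Tuple[Optional[float], float, List[str]]:
    """Return (seconds until first chunk, seconds until the stream ends, chunks)."""
    start = time.perf_counter()
    first = None
    chunks = []
    for chunk in stream:
        if first is None:
            first = time.perf_counter() - start
        chunks.append(chunk)
    total = time.perf_counter() - start
    return first, total, chunks


tokens = ["Falcons ", "are ", "fast ", "birds ", "of ", "prey."]
first, total, chunks = time_to_first_and_last(fake_generate(tokens, delay=0.02))
print(f"first chunk after {first:.3f}s, full response after {total:.3f}s")
```

A streaming client can start rendering at the first timestamp, while a non-streaming client sees nothing until the second.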

This guide walks through deploying Falcon 7B with streaming enabled.

1. Initialize Truss

truss init falcon-7b && cd falcon-7b

2. Implement Model (Non-Streaming)

This first version loads the Falcon 7B model without streaming:

model/model.py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from typing import Dict

CHECKPOINT = "tiiuae/falcon-7b-instruct"
DEFAULT_MAX_NEW_TOKENS = 150
DEFAULT_TOP_P = 0.95

class Model:
    def __init__(self, **kwargs) -> None:
        self.tokenizer = None
        self.model = None

    def load(self):
        self.tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
        self.model = AutoModelForCausalLM.from_pretrained(
            CHECKPOINT, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto"
        )

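    def predict(self, request: Dict) -> Dict:
        prompt = request["prompt"]
        # Tokenize the prompt, truncating it to at most 512 tokens.
        inputs = self.tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True, padding=True)
        input_ids = inputs["input_ids"].to("cuda")
        # do_sample=True is needed for temperature, top_p, and top_k to take effect.
        generation_config = GenerationConfig(do_sample=True, temperature=1, top_p=DEFAULT_TOP_P, top_k=40)

        with torch.no_grad():
            output = self.model.generate(
                input_ids=input_ids,
                generation_config=generation_config,
                return_dict_in_generate=True,
                output_scores=True,
                pad_token_id=self.tokenizer.eos_token_id,
                max_new_tokens=DEFAULT_MAX_NEW_TOKENS,
            )

        # Decode the generated sequence (prompt plus completion) so the response is a serializable dict.
        text = self.tokenizer.decode(output.sequences[0], skip_special_tokens=True)
        return {"output": text}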

3. Add Streaming Support

To enable streaming, we:

  • Use TextIteratorStreamer to stream tokens as they are generated.
  • Run generate() in a separate thread to prevent blocking.
  • Return a generator that streams results.

model/model.py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, TextIteratorStreamer
from threading import Thread
from typing import Dict

CHECKPOINT = "tiiuae/falcon-7b-instruct"

class Model:
    def __init__(self, **kwargs) -> None:
        self.tokenizer = None
        self.model = None

    def load(self):
        self.tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
        self.model = AutoModelForCausalLM.from_pretrained(
            CHECKPOINT, torch_dtype=torch.bfloat16, trust_remote_code=True, device_map="auto"
        )

    def predict(self, request: Dict):
        prompt = request["prompt"]
        inputs = self.tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True, padding=True)
        input_ids = inputs["input_ids"].to("cuda")

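        # The streamer receives decoded text from generate() and exposes it as an iterator.
        streamer = TextIteratorStreamer(self.tokenizer)
        # do_sample=True is needed for temperature, top_p, and top_k to take effect.
        generation_config = GenerationConfig(do_sample=True, temperature=1, top_p=0.95, top_k=40)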

        def generate():
            self.model.generate(
                input_ids=input_ids,
                generation_config=generation_config,
                return_dict_in_generate=True,
                output_scores=True,
                pad_token_id=self.tokenizer.eos_token_id,
                max_new_tokens=150,
                streamer=streamer,
            )

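        # Run generation in a background thread so predict() can return the generator right away.
        thread = Thread(target=generate)
        thread.start()

        def stream_output():
            # Iterating the streamer blocks until the next chunk of text is available.
            for text in streamer:
                yield text
            # Generation has finished once the streamer is exhausted; clean up the thread.
            thread.join()

        return stream_output()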
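The thread-plus-generator pattern used in predict can be tried without a GPU or model weights. The following is a minimal sketch, not the transformers implementation: `QueueStreamer` is a hypothetical stand-in for TextIteratorStreamer built on a plain queue, and `fake_generate` stands in for `model.generate(..., streamer=streamer)`.

```python
from queue import Queue
from threading import Thread
from typing import Iterator, List

_END = object()  # sentinel marking the end of the stream


class QueueStreamer:
    """Hypothetical stand-in for TextIteratorStreamer: a queue that can be iterated."""

    def __init__(self) -> None:
        self._queue: Queue = Queue()

    def put(self, text: str) -> None:
        self._queue.put(text)

    def end(self) -> None:
        self._queue.put(_END)

    def __iter__(self) -> Iterator[str]:
        while True:
            item = self._queue.get()  # blocks until the producer adds something
            if item is _END:
                return
            yield item


def fake_generate(words: List[str], streamer: QueueStreamer) -> None:
    """Producer: pushes text into the streamer, then signals completion."""
    for word in words:
        streamer.put(word)
    streamer.end()


def predict(words: List[str]) -> Iterator[str]:
    streamer = QueueStreamer()
    # The producer runs in the background so the consumer can start immediately.
    thread = Thread(target=fake_generate, args=(words, streamer))
    thread.start()

    def stream_output() -> Iterator[str]:
        for text in streamer:
            yield text
        thread.join()

    return stream_output()


result = list(predict(["Falcons ", "dive ", "fast."]))
print(result)
```

The queue is what decouples the two sides: the producer never waits for the consumer, and the consumer blocks only when no text is ready yet.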

4. Configure config.yaml

config.yaml
model_name: falcon-streaming
requirements:
  - torch==2.0.1
  - peft==0.4.0
  - scipy==1.11.1
  - sentencepiece==0.1.99
  - accelerate==0.21.0
  - bitsandbytes==0.41.1
  - einops==0.6.1
  - transformers==4.31.0
resources:
  cpu: "3"
  memory: 14Gi
  use_gpu: true
  accelerator: A10G

5. Deploy & Invoke

Deploy the model:

truss push

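Invoke with (predict reads only the "prompt" key, so "do_sample" in the payload below has no effect unless you handle it in model.py):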

truss predict -d '{"prompt": "Tell me about falcons", "do_sample": true}'
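A client written in Python would consume the streamed response chunk by chunk. The sketch below is an illustration under assumptions, not part of the Truss CLI: `decode_stream` is a hypothetical helper, the byte chunks are simulated, and the commented-out `requests` call uses placeholder names (`MODEL_URL`, `API_KEY`) and an assumed header format that you would replace with your deployment's actual values.

```python
import codecs
from typing import Iterable, Iterator


def decode_stream(chunks: Iterable[bytes]) -> Iterator[str]:
    """Decode UTF-8 byte chunks incrementally, yielding text as soon as it is complete."""
    decoder = codecs.getincrementaldecoder("utf-8")()
    for chunk in chunks:
        text = decoder.decode(chunk)
        if text:
            yield text
    # Flush anything still buffered once the stream ends.
    tail = decoder.decode(b"", final=True)
    if tail:
        yield tail


# Simulated network chunks; the two-byte character "é" is split across chunks.
chunks = [b"Falcons hunt at caf", b"\xc3", b"\xa9s? No."]
pieces = list(decode_stream(chunks))
print("".join(pieces))

# With a real deployment the chunks could come from an HTTP response, for example
# (MODEL_URL, API_KEY, and the header format are placeholders, not values from this guide):
#   import requests
#   resp = requests.post(MODEL_URL, headers={"Authorization": f"Api-Key {API_KEY}"},
#                        json={"prompt": "Tell me about falcons"}, stream=True)
#   for text in decode_stream(resp.iter_content(chunk_size=None)):
#       print(text, end="", flush=True)
```

Decoding incrementally matters because a network chunk can end in the middle of a multi-byte character; decoding each chunk on its own would raise an error or garble the text.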