In this example, we go through a Truss that serves a text-to-image model. We
use Flux Schnell, one of the highest-performing text-to-image models available today.
Set up imports and torch settings
In this example, we use the Hugging Face diffusers library to build our text-to-image model.
import base64
import logging
import math
import random
from io import BytesIO
import numpy as np
import torch
from diffusers import FluxPipeline
from PIL import Image
logging.basicConfig(level=logging.INFO)
MAX_SEED = np.iinfo(np.int32).max
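This Truss keeps torch's default settings. If you do want to adjust them, one common optional tweak is enabling TF32 matmuls; the two lines below are an assumption for illustration, not part of the original example.
# Optional torch settings (an assumption, not part of this Truss):
# TF32 matmuls speed up float32 work on Ampere-or-newer GPUs with minimal precision loss.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True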
Define the Model class and load function
In the load function of the Truss, we implement the logic involved in downloading and setting up the model. For this model, we use the FluxPipeline class in diffusers to instantiate our Flux pipeline and configure a number of relevant parameters. See the diffusers docs for details on all of these parameters.
class Model:
    def __init__(self, **kwargs):
        self.pipe = None
        self.repo_id = "black-forest-labs/FLUX.1-schnell"

    def load(self):
        self.pipe = FluxPipeline.from_pretrained(self.repo_id, torch_dtype=torch.bfloat16).to("cuda")
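If you are running on a GPU with less memory than the H100 used here, diffusers also supports CPU offloading. The helper below is a minimal sketch under that assumption (the function name is ours, not part of the Truss); it trades some speed for a smaller GPU memory footprint.
# A minimal sketch (assumption, not part of this Truss) of loading Flux with CPU offload.
import torch
from diffusers import FluxPipeline

def load_with_offload(repo_id: str = "black-forest-labs/FLUX.1-schnell") -> FluxPipeline:
    pipe = FluxPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16)
    # Moves each sub-model to the GPU only while it is needed, instead of pipe.to("cuda").
    pipe.enable_model_cpu_offload()
    return pipe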
This is a utility function for converting a PIL image to base64.
    def convert_to_b64(self, image: Image.Image) -> str:
        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        img_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return img_b64
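To sanity-check this helper locally, you can round-trip a small image through the same encode/decode steps; a quick sketch, assuming only Pillow is installed:
import base64
from io import BytesIO
from PIL import Image

# Encode a dummy image the same way convert_to_b64 does, then decode it back.
img = Image.new("RGB", (64, 64), color="red")
buffered = BytesIO()
img.save(buffered, format="JPEG")
b64_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
restored = Image.open(BytesIO(base64.b64decode(b64_str)))
assert restored.size == (64, 64)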
Define the predict function
The predict function contains the actual inference logic. The steps here are:
- Setting up the generation params. These include things like the prompt, image width, image height, number of inference steps, etc.
- Running the diffusion pipeline
- Converting the resulting image to base64 and returning it
    def predict(self, model_input):
        seed = model_input.get("seed")
        prompt = model_input.get("prompt")
        prompt2 = model_input.get("prompt2")
        max_sequence_length = model_input.get(
            "max_sequence_length", 256
        )  # 256 is max for FLUX.1-schnell
        guidance_scale = model_input.get(
            "guidance_scale", 0.0
        )  # 0.0 is the only value for FLUX.1-schnell
        num_inference_steps = model_input.get(
            "num_inference_steps", 4
        )  # schnell is timestep-distilled
        width = model_input.get("width", 1024)
        height = model_input.get("height", 1024)

        if not math.isclose(guidance_scale, 0.0):
            logging.warning(
                "FLUX.1-schnell does not support guidance_scale other than 0.0"
            )
            guidance_scale = 0.0

        if not seed:
            seed = random.randint(0, MAX_SEED)

        if len(prompt.split()) > max_sequence_length:
            logging.warning(
                "FLUX.1-schnell does not support prompts longer than 256 tokens, truncating"
            )
            tokens = prompt.split()
            prompt = " ".join(tokens[: min(len(tokens), max_sequence_length)])

        generator = torch.Generator().manual_seed(seed)

        image = self.pipe(
            prompt=prompt,
            guidance_scale=guidance_scale,
            max_sequence_length=max_sequence_length,
            num_inference_steps=num_inference_steps,
            width=width,
            height=height,
            output_type="pil",
            generator=generator,
        ).images[0]

        b64_results = self.convert_to_b64(image)
        return {"data": b64_results}
Setting up the config.yaml
Running Flux Schnell requires a handful of Python libraries, including diffusers, transformers, and others.
external_package_dirs: []
model_cache:
  - repo_id: black-forest-labs/FLUX.1-schnell
    allow_patterns:
      - "*.json"
      - "*.safetensors"
    ignore_patterns:
      - "flux1-schnell.safetensors"
model_metadata:
  example_model_input: {"prompt": 'black forest gateau cake spelling out the words "FLUX SCHNELL", tasty, food photography, dynamic shot'}
model_name: Flux.1-schnell
python_version: py311
requirements:
  - git+https://github.com/huggingface/diffusers.git@v0.32.2
  - transformers
  - accelerate
  - sentencepiece
  - protobuf
resources:
  accelerator: H100_40GB
  use_gpu: true
secrets: {}
system_packages:
  - ffmpeg
  - libsm6
  - libxext6
Configuring resources for Flux Schnell
Note that we need an H100 40GB GPU to run this model.
resources:
  accelerator: H100_40GB
  use_gpu: true
secrets: {}
System Packages
Running diffusers requires ffmpeg and a couple of other system packages.
system_packages:
  - ffmpeg
  - libsm6
  - libxext6
Enabling Caching
Flux Schnell is a large model, and downloading it can take several minutes, which makes the cold start time for this model long. We can solve that with our build caching feature, which moves the model download to the build stage of your deployment: caching the model takes about 15 minutes the first time, but subsequent cold starts drop to roughly 20 seconds.
To enable caching, add the following to the config:
model_cache:
  - repo_id: black-forest-labs/FLUX.1-schnell
    allow_patterns:
      - "*.json"
      - "*.safetensors"
    ignore_patterns:
      - "flux1-schnell.safetensors"
Deploy the model
Deploy the model like you would any other Truss, with:
truss push flux/schnell --publish
Run an inference
Use a Python script to call the model once it's deployed and parse its response. We decode the resulting base64-encoded string into an actual image file, output_image.jpg.
import httpx
import os
import base64
from PIL import Image
from io import BytesIO
# Replace the empty string with your model id below
model_id = ""
baseten_api_key = os.environ["BASETEN_API_KEY"]
# Function used to convert a base64 string to a PIL image
def b64_to_pil(b64_str):
    return Image.open(BytesIO(base64.b64decode(b64_str)))
data = {
    "prompt": 'red velvet cake spelling out the words "FLUX SCHNELL", tasty, food photography, dynamic shot'
}
# Call model endpoint
res = httpx.post(
    f"https://model-{model_id}.api.baseten.co/production/predict",
    headers={"Authorization": f"Api-Key {baseten_api_key}"},
    json=data,
)
# Get output image
res = res.json()
output = res.get("data")
# Convert the base64 model output to an image
img = b64_to_pil(output)
img.save("output_image.jpg")