Text-to-image
Building a text-to-image model with SDXL
View on Github
In this example, we go through a Truss that serves a text-to-image model. We use SDXL 1.0, which is one of the highest performing text-to-image models out there today.
Set up imports and torch settings
In this example, we use the Huggingface diffusers library to build our text-to-image model.
import base64
import time
from io import BytesIO
from typing import Any
import torch
from diffusers import AutoencoderKL, DiffusionPipeline, DPMSolverMultistepScheduler
from PIL import Image
The following line is needed to enable TF32 on NVIDIA GPUs
torch.backends.cuda.matmul.allow_tf32 = True
Define the Model
class and load function
In the load
function of the Truss, we implement logic involved in
downloading and setting up the model. For this model, we use the
DiffusionPipeline
class in diffusers
to instantiate our SDXL pipeline,
and configure a number of relevant parameters.
See the diffusers docs for details on all of these parameters.
class Model:
def __init__(self, **kwargs):
self._model = None
def load(self):
vae = AutoencoderKL.from_pretrained(
"madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
)
self.pipe = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
vae=vae,
torch_dtype=torch.float16,
variant="fp16",
use_safetensors=True,
)
self.pipe.unet.to(memory_format=torch.channels_last)
self.pipe.to("cuda")
self.pipe.enable_xformers_memory_efficient_attention()
self.refiner = DiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-refiner-1.0",
text_encoder_2=self.pipe.text_encoder_2,
vae=self.pipe.vae,
torch_dtype=torch.float16,
use_safetensors=True,
variant="fp16",
)
self.refiner.to("cuda")
self.refiner.enable_xformers_memory_efficient_attention()
This is a utility function for converting PIL image to base64.
def convert_to_b64(self, image: Image) -> str:
buffered = BytesIO()
image.save(buffered, format="JPEG")
img_b64 = base64.b64encode(buffered.getvalue()).decode("utf-8")
return img_b64
Define the predict function
The predict
function contains the actual inference logic. The steps here are:
- Setting up the generation params. We have defaults for these, and some, such
as the
scheduler
, are somewhat complicated - Running the Diffusion Pipeline
- If
use_refiner
is set toTrue
, we run the refiner model on the output - Convert the resulting image to base64 and return it
def predict(self, model_input: Any) -> Any:
prompt = model_input.pop("prompt")
negative_prompt = model_input.pop("negative_prompt", None)
use_refiner = model_input.pop("use_refiner", True)
num_inference_steps = model_input.pop("num_inference_steps", 30)
denoising_frac = model_input.pop("denoising_frac", 0.8)
end_cfg_frac = model_input.pop("end_cfg_frac", 0.4)
guidance_scale = model_input.pop("guidance_scale", 7.5)
seed = model_input.pop("seed", None)
scheduler = model_input.pop(
"scheduler", None
) # Default: EulerDiscreteScheduler (works pretty well)
Set the scheduler based on the userβs input. See possible schedulers: https://huggingface.co/docs/diffusers/api/schedulers/overview for what the tradeoffs are.
if scheduler == "DPM++ 2M":
self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
self.pipe.scheduler.config
)
elif scheduler == "DPM++ 2M Karras":
self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
self.pipe.scheduler.config, use_karras_sigmas=True
)
elif scheduler == "DPM++ 2M SDE Karras":
self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
self.pipe.scheduler.config,
algorithm_type="sde-dpmsolver++",
use_karras_sigmas=True,
)
generator = None
if seed is not None:
torch.manual_seed(seed)
generator = [torch.Generator(device="cuda").manual_seed(seed)]
if not use_refiner:
denoising_frac = 1.0
start_time = time.time()
image = self.pipe(
prompt=prompt,
negative_prompt=negative_prompt,
generator=generator,
end_cfg=end_cfg_frac,
num_inference_steps=num_inference_steps,
denoising_end=denoising_frac,
guidance_scale=guidance_scale,
output_type="latent" if use_refiner else "pil",
).images[0]
scheduler = self.pipe.scheduler
if use_refiner:
self.refiner.scheduler = scheduler
image = self.refiner(
prompt=prompt,
negative_prompt=negative_prompt,
generator=generator,
end_cfg=end_cfg_frac,
num_inference_steps=num_inference_steps,
denoising_start=denoising_frac,
guidance_scale=guidance_scale,
image=image[None, :],
).images[0]
Convert the results to base64, and return them.
b64_results = self.convert_to_b64(image)
end_time = time.time() - start_time
print(f"Time: {end_time:.2f} seconds")
return {"status": "success", "data": b64_results, "time": end_time}
Setting up the config yaml
Running SDXL requires a handful of Python libraries, including diffusers, transformers, and others.
environment_variables: {}
external_package_dirs: []
model_metadata:
example_model_input: {"prompt": "A tree in a field under the night sky", "use_refiner": true}
model_name: Stable Diffusion XL
python_version: py39
requirements:
- transformers==4.34.0
- accelerate==0.23.0
- safetensors==0.4.0
- git+https://github.com/basetenlabs/diffusers.git@9a353290b1497023d4745a719ec02c50f680499a
- invisible-watermark>=0.2.0
- xformers==0.0.22
Configuring resources for SDXL 1.0
Note that we need an A10G to run this model.
resources:
accelerator: A10G
cpu: 3500m
memory: 20Gi
use_gpu: true
secrets: {}
System Packages
Running diffusers requires ffmpeg
and a couple other system
packages.
system_packages:
- ffmpeg
- libsm6
- libxext6
Enabling Caching
SDXL is a very large model, and downloading it could take up to 10 minutes. This means that the cold start time for this model is long. We can solve that by using our build caching feature. This moves the model download to the build stage of your modelβ caching the model will take about 10 minutes initially but you will get ~9s cold starts subsequently.
To enable caching, add the following to the config:
model_cache:
- repo_id: madebyollin/sdxl-vae-fp16-fix
allow_patterns:
- config.json
- diffusion_pytorch_model.safetensors
- repo_id: stabilityai/stable-diffusion-xl-base-1.0
allow_patterns:
- "*.json"
- "*.fp16.safetensors"
- sd_xl_base_1.0.safetensors
- repo_id: stabilityai/stable-diffusion-xl-refiner-1.0
allow_patterns:
- "*.json"
- "*.fp16.safetensors"
- sd_xl_refiner_1.0.safetensors
Deploy the model
Deploy the model like you would other Trusses, with:
$ truss push
You can then invoke the model with:
$ truss predict -d '{"prompt": "A tree in a field under the night sky", "use_refiner": true}'