Building a text-to-image model with SDXL
In this example, we go through a Truss that serves a text-to-image model. We use SDXL 1.0, one of the highest-performing text-to-image models available today.
Set up imports and torch settings
In this example, we use the Hugging Face `diffusers` library to build our text-to-image model.
The following line is needed to enable TF32 on NVIDIA GPUs.
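A minimal sketch of the imports and torch settings might look like this (the exact import list depends on what the rest of the model file uses):

```python
import base64
from io import BytesIO

import torch
from diffusers import DiffusionPipeline

# Enable TF32 matmuls on Ampere and newer NVIDIA GPUs for faster inference
torch.backends.cuda.matmul.allow_tf32 = True
```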
Define the `Model` class and `load` function
In the `load` function of the Truss, we implement the logic for downloading and setting up the model. For this model, we use the `DiffusionPipeline` class in `diffusers` to instantiate our SDXL pipeline and configure a number of relevant parameters. See the diffusers docs for details on all of these parameters.
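As a sketch, the `Model` class and its `load` function could look like the following, assuming the SDXL 1.0 base and refiner weights published as `stabilityai/stable-diffusion-xl-base-1.0` and `stabilityai/stable-diffusion-xl-refiner-1.0`:

```python
class Model:
    def __init__(self, **kwargs):
        self.pipe = None
        self.refiner = None

    def load(self):
        # Load the SDXL base pipeline in fp16 and move it to the GPU
        self.pipe = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            torch_dtype=torch.float16,
            use_safetensors=True,
            variant="fp16",
        ).to("cuda")

        # Load the refiner, sharing the base pipeline's second text encoder and VAE
        self.refiner = DiffusionPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-refiner-1.0",
            text_encoder_2=self.pipe.text_encoder_2,
            vae=self.pipe.vae,
            torch_dtype=torch.float16,
            use_safetensors=True,
            variant="fp16",
        ).to("cuda")
```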
This is a utility function for converting a PIL image to base64.
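A straightforward way to write it (the name `pil_to_b64` is illustrative):

```python
def pil_to_b64(pil_img):
    # Serialize the PIL image to an in-memory PNG, then base64-encode it
    buffered = BytesIO()
    pil_img.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
```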
Define the `predict` function
The `predict` function contains the actual inference logic. The steps here are:
- Setting up the generation params. We have defaults for these, and some, such as the `scheduler`, are somewhat complicated
- Running the Diffusion Pipeline
- If `use_refiner` is set to `True`, running the refiner model on the output
- Converting the resulting image to base64 and returning it
We set the scheduler based on the user's input; see the scheduler overview at https://huggingface.co/docs/diffusers/api/schedulers/overview for the tradeoffs between them. Finally, we convert the results to base64 and return them.
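Putting these steps together, a sketch of `predict` might look like this; the request fields (`prompt`, `scheduler`, `use_refiner`, and so on) and the abbreviated scheduler mapping are illustrative:

```python
from diffusers import DPMSolverMultistepScheduler, EulerDiscreteScheduler

class Model:
    # __init__ and load as above

    def predict(self, model_input):
        prompt = model_input["prompt"]
        negative_prompt = model_input.get("negative_prompt")
        num_inference_steps = model_input.get("num_inference_steps", 30)
        use_refiner = model_input.get("use_refiner", True)

        # Set the scheduler based on the user's input (mapping abbreviated)
        schedulers = {
            "EulerDiscrete": EulerDiscreteScheduler,
            "DPMSolverMultistep": DPMSolverMultistepScheduler,
        }
        scheduler_cls = schedulers.get(
            model_input.get("scheduler", "EulerDiscrete"), EulerDiscreteScheduler
        )
        self.pipe.scheduler = scheduler_cls.from_config(self.pipe.scheduler.config)

        if use_refiner:
            # Run the base pipeline to a latent, then hand it to the refiner
            latents = self.pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=num_inference_steps,
                output_type="latent",
            ).images
            image = self.refiner(
                prompt=prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=num_inference_steps,
                image=latents,
            ).images[0]
        else:
            image = self.pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=num_inference_steps,
            ).images[0]

        # Convert the result to base64 and return it
        return {"result": pil_to_b64(image)}
```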
Setting up the config.yaml
Running SDXL requires a handful of Python libraries, including diffusers, transformers, and others.
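A requirements block along these lines should work; the package list here is illustrative and you may want to pin versions for your setup:

```yaml
requirements:
  - accelerate
  - diffusers
  - safetensors
  - transformers
  - torch
  - Pillow
```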
Configuring resources for SDXL 1.0
Note that we need an A10G to run this model.
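In the Truss config, the resources section looks something like this:

```yaml
resources:
  accelerator: A10G
  use_gpu: true
```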
System Packages
Running diffusers requires `ffmpeg` and a couple of other system packages.
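In the config, that typically looks like the following; the extra packages listed here (`libsm6`, `libxext6`) are common companions for image workloads, so adjust as needed:

```yaml
system_packages:
  - ffmpeg
  - libsm6
  - libxext6
```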
Enabling Caching
SDXL is a very large model, and downloading it could take up to 10 minutes. This means that the cold start time for this model is long. We can solve that by using our build caching feature. This moves the model download to the build stage of your model: caching the model will take about 10 minutes initially, but you will get ~9s cold starts subsequently.
To enable caching, add the following to the config:
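The exact key depends on your Truss version (newer versions call it `model_cache`, older ones `hf_cache`), so check the Truss docs, but it looks roughly like this:

```yaml
model_cache:
  - repo_id: stabilityai/stable-diffusion-xl-base-1.0
  - repo_id: stabilityai/stable-diffusion-xl-refiner-1.0
```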
Deploy the model
Deploy the model like you would any other Truss, with:
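With the Truss CLI installed, that is typically:

```sh
truss push
```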
You can then invoke the model with:
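For example (the payload fields mirror the illustrative `predict` sketch above):

```sh
truss predict -d '{"prompt": "a photo of an astronaut riding a unicorn on mars", "use_refiner": true}'
```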