Requires b10cache to be enabled

Explore Torch Compile

Many users rely on torch.compile, which can decrease inference time by up to 40%. However, it increases cold start time because the model must be compiled before the first inference. To reduce this overhead, caching artifacts from previous compilations is a must-implement strategy; read more here. This new API exposes that caching functionality to our users. In practice, loading the cache significantly reduces compilation latency, by up to 5x.
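
To see why the cache matters, note that torch.compile itself returns quickly; the expensive work is deferred to the first call with real inputs. A minimal sketch (assuming a CUDA device and a stand-in nn.Linear module rather than a real pipeline):

import torch

# Stand-in module for illustration; any nn.Module or function behaves the same way.
model = torch.nn.Linear(1024, 1024).to("cuda")

# torch.compile returns immediately; compilation happens on the first call with
# real inputs, which is what makes uncached cold starts slow.
compiled = torch.compile(model, mode="max-autotune-no-cudagraphs")
compiled(torch.randn(8, 1024, device="cuda"))  # first call triggers compilation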

Cache saving and loading

We expose two API calls.

1. save_compile_cache()

Save your model’s torch compilation cache for future use. This should be called after running prompts to warm up your model and trigger compilation. Returns True if successfully saved, else False.

2. load_compile_cache()

If you have previously saved a compilation cache for this model, load it to speed up compilation on this pod. Returns True if a cache file was found and loaded successfully, else False.
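
Put together, the two calls form a simple pattern: try to load a cache, compile and warm up the model, and save a cache only if none was loaded. A minimal sketch (the full Flux example below fills in the compilation and warmup steps):

from b10_tcache import load_compile_cache, save_compile_cache

cache_loaded = load_compile_cache()  # True if a saved cache was found and restored

# ... call torch.compile and run warmup prompts here ...

if not cache_loaded:
    # Persist the freshly built compilation artifacts for future pods
    save_compile_cache()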

Example

Here is an example of compile caching for Flux, an image generation model. Note how we save the result of load_compile_cache and use it to decide whether to call save_compile_cache afterwards. In other implementations, you can instead fall back to skipping compilation entirely if the cache fails to load; a variant along those lines is sketched at the end of this example. There are two files to change.

1. config.yaml

Under requirements, add b10-tcache
config.yaml
requirements:
 - b10-tcache

2. model.py

Import the library and use the two functions to speed up torch compilation time.
model.py
import logging
import random
import time

import torch
from diffusers import FluxPipeline

from b10_tcache import load_compile_cache, save_compile_cache

# MAX_SEED and the model configuration (model_name, hf_access_token) are defined
# in the full example linked below.

class Model:
  def load(self):
    self.pipe = FluxPipeline.from_pretrained(
        self.model_name, torch_dtype=torch.bfloat16, token=self.hf_access_token
    ).to("cuda")

    # Try to load compile cache
    cache_loaded = load_compile_cache()

    logging.info("Compiling the model for performance optimization")
    self.pipe.transformer = torch.compile(
        self.pipe.transformer, mode="max-autotune-no-cudagraphs", dynamic=False
    )

    self.pipe.vae.decode = torch.compile(
        self.pipe.vae.decode, mode="max-autotune-no-cudagraphs", dynamic=False
    )

    seed = random.randint(0, MAX_SEED)
    generator = torch.Generator().manual_seed(seed)
    start_time = time.time()

    # Warm up the model with dummy prompts, also triggering compilation
    self.pipe(
        prompt="dummy prompt",
        prompt_2=None,
        guidance_scale=0.0,
        max_sequence_length=256,
        num_inference_steps=4,
        width=1024,
        height=1024,
        output_type="pil",
        generator=generator,
    )
    self.pipe(
        prompt="extra dummy prompt",
        prompt_2=None,
        guidance_scale=0.0,
        max_sequence_length=256,
        num_inference_steps=4,
        width=1024,
        height=1024,
        output_type="pil",
        generator=generator,
    )
    
    end_time = time.time()

    logging.info(
        f"Warmup completed in {(end_time - start_time)} seconds. "
        "This is expected to take a few minutes on the first run."
    )

    if not cache_loaded:
        # Save compile cache for future runs
        save_compile_cache()
See the full example in the truss-examples repo.
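
As noted above, other implementations fall back to skipping compilation when no cache can be loaded, trading some inference speed for a predictable cold start. A minimal sketch of that variant inside load(), reusing the self.pipe attribute from the example above (this is not part of the linked example):

    cache_loaded = load_compile_cache()

    if cache_loaded:
        # Cached artifacts are available, so compilation finishes quickly
        self.pipe.transformer = torch.compile(
            self.pipe.transformer, mode="max-autotune-no-cudagraphs", dynamic=False
        )
    else:
        # No cache found: skip torch.compile and serve the eager model instead
        logging.info("No compile cache found; skipping torch.compile")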