Requires b10cache to be enabled

Explore Torch Compile

Many users rely on torch.compile, which can decrease inference time by up to 40%. However, it increases your cold start time, because the model must be compiled before the first inference. To reduce this overhead, caching previous compilation artifacts is a must-implement strategy (read more here). This API exposes that caching functionality to our users. In practice, having the cache significantly reduces compilation latency, by up to 5x.
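
To see the trade-off in isolation, here is a minimal, standalone sketch that is independent of b10-transfer (it assumes a CUDA-capable GPU; the model and tensor shapes are arbitrary). The first call to a compiled model pays the compilation cost, while later calls reuse the compiled graph.

import time
import torch

model = torch.nn.Linear(4096, 4096).cuda()
compiled_model = torch.compile(model)  # compilation is lazy; nothing is compiled yet

x = torch.randn(8, 4096, device="cuda")

start = time.time()
compiled_model(x)  # first call triggers compilation (the cold start cost)
print(f"First call: {time.time() - start:.2f}s")

start = time.time()
compiled_model(x)  # later calls reuse the compiled artifacts
print(f"Second call: {time.time() - start:.4f}s")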

Cache saving and loading

We expose two API calls. Each returns an OperationStatus value that lets you control the flow of your program based on the result.

1. load_compile_cache()

If you have previously saved a compilation cache for this model, load it to speed up compilation on this pod. Returns:
  • OperationStatus.SUCCESS → successful load
  • OperationStatus.SKIPPED → skipped because the cache already exists in b10fs
  • OperationStatus.ERROR → general catch-all for errors
  • OperationStatus.DOES_NOT_EXIST → no cache file was found
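
For example, you can branch on the returned status to decide how to proceed. This is a minimal sketch mirroring the control flow of the full example below; should_compile is just an illustrative variable name.

from b10_transfer import load_compile_cache, OperationStatus

status = load_compile_cache()

if status == OperationStatus.ERROR:
    # Loading failed; fall back to eager mode and skip torch.compile
    should_compile = False
else:
    # SUCCESS: cached artifacts were loaded, so compilation finishes quickly
    # DOES_NOT_EXIST: no cache yet, so compile from scratch and save it afterwards
    # SKIPPED: nothing to do for the load step
    should_compile = True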

2. save_compile_cache()

Save your model’s torch compilation cache for future use. Call this after running warmup prompts that trigger compilation. Returns:
  • OperationStatus.SUCCESS → successful save
  • OperationStatus.SKIPPED → skipped because a compile cache already exists in the shared directory
  • OperationStatus.ERROR → general catch-all for errors
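
For example, after warmup you can save the cache and log the outcome. This is a minimal sketch; the log messages are illustrative.

import logging

from b10_transfer import save_compile_cache, OperationStatus

# ... run one or two warmup prompts here to trigger compilation ...

status = save_compile_cache()

if status == OperationStatus.SUCCESS:
    logging.info("Compile cache saved for future cold starts")
elif status == OperationStatus.SKIPPED:
    logging.info("Compile cache already exists in the shared directory")
else:
    logging.warning("Could not save compile cache; the next cold start will recompile")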

Example

Here is an example of compile caching for Flux, an image generation model. Note how we store the result of load_compile_cache to decide whether to call save_compile_cache later. In other implementations, you could fall back to skipping compilation altogether in the unlikely case that the cache fails to load. There are two files to change.

1. config.yaml

Under requirements, add b10-transfer:
requirements:
  - b10-transfer

2. model.py

Import the library and use the two functions to speed up torch compilation time:
import logging
import random
import time

import torch
from diffusers import FluxPipeline

from b10_transfer import load_compile_cache, save_compile_cache, OperationStatus

MAX_SEED = 2**32 - 1  # upper bound for random seeds (illustrative value)

class Model:
    def load(self):
        # self.model_name and self.hf_access_token are set in __init__ (not shown here)
        self.pipe = FluxPipeline.from_pretrained(
            self.model_name, torch_dtype=torch.bfloat16, token=self.hf_access_token
        ).to("cuda")

        # Try to load compile cache
        cache_loaded: OperationStatus = load_compile_cache()

        if cache_loaded == OperationStatus.ERROR:
            logging.info("Run in eager mode, skipping torch compile")
        else:
            logging.info("Compiling the model for performance optimization")
            self.pipe.transformer = torch.compile(
                self.pipe.transformer, mode="max-autotune-no-cudagraphs", dynamic=False
            )

            self.pipe.vae.decode = torch.compile(
                self.pipe.vae.decode, mode="max-autotune-no-cudagraphs", dynamic=False
            )

            seed = random.randint(0, MAX_SEED)
            generator = torch.Generator().manual_seed(seed)
            start_time = time.time()
            # Warm up the model with dummy prompts; this also triggers compilation
            self.pipe(
                prompt="dummy prompt",
                prompt_2=None,
                guidance_scale=0.0,
                max_sequence_length=256,
                num_inference_steps=4,
                width=1024,
                height=1024,
                output_type="pil",
                generator=generator
            )
            self.pipe(
                prompt="extra dummy prompt",
                prompt_2=None,
                guidance_scale=0.0,
                max_sequence_length=256,
                num_inference_steps=4,
                width=1024,
                height=1024,
                output_type="pil",
                generator=generator
            )

            end_time = time.time()

            logging.info(
                f"Warmup completed in {(end_time - start_time)} seconds. "
                "This is expected to take a few minutes on the first run."
            )

            if cache_loaded != OperationStatus.SUCCESS:
                # Save compile cache for future runs
                outcome: OperationStatus = save_compile_cache()

See the full example in the truss-examples repo.