import requests
from typing import Dict
from PIL import Image
from transformers import CLIPProcessor, CLIPModel

CHECKPOINT = "openai/clip-vit-base-patch32"


class Model:
    """
    This is a simple example of using CLIP to classify images.
    It outputs the probability of the image being a cat or a dog.
    """
    def __init__(self, **kwargs) -> None:
        self._processor = None
        self._model = None

    def load(self):
        """
        Loads the CLIP model and processor checkpoints.
        """
        self._model = CLIPModel.from_pretrained(CHECKPOINT)
        self._processor = CLIPProcessor.from_pretrained(CHECKPOINT)

    def preprocess(self, request: Dict) -> Dict:
        """
        This method downloads the image from the URL and preprocesses it.
        The preprocess method is used for any logic that involves IO, in this
        case downloading the image. It is called before the predict method
        in a separate thread and is not subject to the same concurrency
        limits as the predict method, so can be called many times in parallel.
        """
        image = Image.open(requests.get(request.pop("url"), stream=True).raw)
        request["inputs"] = self._processor(
            text=["a photo of a cat", "a photo of a dog"],
            images=image,
            return_tensors="pt",
            padding=True
        )
        return request

    def predict(self, request: Dict) -> list:
        """
        This performs the actual classification. The predict method is subject to
        the predict concurrency constraints.
        """
        outputs = self._model(**request["inputs"])
        logits_per_image = outputs.logits_per_image
        return logits_per_image.softmax(dim=1).tolist()

Out of the box, Truss limits the number of concurrent predict calls that run on a single container. This ensures that the CPU, and for many models the GPU, does not get overloaded, and that the model can continue to respond to requests during periods of high load.

However, many models, in addition to having compute components, also have IO requirements. For example, a model that classifies images may need to download the image from a URL before it can classify it.

Truss provides a way to separate the IO component from the compute component, to ensure that any IO does not prevent utilization of the compute on your pod.

To do this, you can use the pre/post process methods on a Truss. These methods can be defined like this:

class Model:
    def __init__(self, **kwargs) -> None: ...
    def load(self): ...
    def preprocess(self, request):
        # Include any IO logic that happens _before_ predict here
        ...

    def predict(self, request):
        # Include the actual predict here
        ...

    def postprocess(self, response):
        # Include any IO logic that happens _after_ predict here
        ...
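
To make the postprocess hook concrete, here is a minimal sketch that turns a list of probability rows, like the one the CLIP example's predict returns, into labelled scores. The `LABELS` list and the shape of `response` are assumptions made for illustration; they are not defined by Truss.

```python
from typing import Dict, List

# Assumption: label order matches the order of the text prompts
# ("a photo of a cat", "a photo of a dog") used during preprocessing.
LABELS = ["cat", "dog"]


class Model:
    def postprocess(self, response: List[List[float]]) -> List[Dict[str, float]]:
        # Runs after predict and outside the predict concurrency limit,
        # so slower follow-up work (for example uploading results) could
        # also live here without holding a predict slot.
        return [dict(zip(LABELS, row)) for row in response]
```

Calling `Model().postprocess([[0.9, 0.1]])` on this sketch pairs each probability with its label.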

When the model is invoked, any logic defined in the preprocess or postprocess methods runs on a separate thread and is not subject to the same concurrency limits as predict. So, let's say you have a model that can handle 5 concurrent requests:

...
runtime:
    predict_concurrency: 5
...

If you hit it with 10 requests, they will all begin pre-processing, but when the 6th request is ready to begin the predict method, it will have to wait for one of the first 5 requests to finish. This ensures that the GPU is not overloaded, while also ensuring that the compute logic does not get blocked by IO, so that you can achieve maximum throughput.
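
The behaviour described above can be imitated with a toy simulation. This is not how Truss is implemented internally; it is only a sketch that assumes a semaphore guards the predict step while pre-processing runs freely on worker threads. All names here (`PREDICT_CONCURRENCY`, `handle`, the sleep durations) are hypothetical.

```python
import threading
import time
from concurrent.futures import ThreadPoolExecutor

PREDICT_CONCURRENCY = 5  # mirrors predict_concurrency: 5 in the config above
predict_slots = threading.Semaphore(PREDICT_CONCURRENCY)

lock = threading.Lock()
active_predicts = 0
max_active_predicts = 0


def preprocess(i: int) -> int:
    time.sleep(0.01)  # stand-in for IO, such as downloading an image
    return i


def predict(i: int) -> int:
    global active_predicts, max_active_predicts
    with predict_slots:  # at most PREDICT_CONCURRENCY threads get past here
        with lock:
            active_predicts += 1
            max_active_predicts = max(max_active_predicts, active_predicts)
        time.sleep(0.02)  # stand-in for compute
        with lock:
            active_predicts -= 1
    return i * 2


def handle(i: int) -> int:
    # All 10 requests pre-process in parallel; predict is gated.
    return predict(preprocess(i))


with ThreadPoolExecutor(max_workers=10) as pool:
    results = list(pool.map(handle, range(10)))
```

After the run, `max_active_predicts` records the highest number of predict calls that overlapped, which never exceeds the configured limit even though all ten requests were pre-processed concurrently.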

When predict returns a generator (e.g. for streaming LLM outputs), the model must not have a postprocess method. Postprocessing can only be used when the prediction result is available as a whole. When streaming, move any postprocessing logic into predict or apply it client-side.
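
As a rough sketch of the streaming case, the model below defines no postprocess method and applies its clean-up step to each chunk inside predict instead. The token source `fake_token_stream` and the helper `clean` are hypothetical stand-ins for a real streaming LLM and real post-processing logic.

```python
from typing import Dict, Iterator


def fake_token_stream(prompt: str) -> Iterator[str]:
    # Hypothetical stand-in for a streaming LLM: emits one chunk per word,
    # each with a trailing space.
    for word in prompt.split():
        yield word + " "


def clean(chunk: str) -> str:
    # Hypothetical post-processing step applied per chunk.
    return chunk.strip()


class Model:
    def predict(self, request: Dict) -> Iterator[str]:
        # No postprocess method exists on this model; the clean-up that
        # would have lived there is applied to each chunk as it is yielded.
        for chunk in fake_token_stream(request["prompt"]):
            yield clean(chunk)
```

A caller can then iterate over `Model().predict({"prompt": "hello streaming world"})` and receive cleaned chunks one at a time.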