Out of the box, Truss limits the number of concurrent predicts that happen on a
single container. This ensures that the CPU, and for many models the GPU, do not get
overloaded, and that the model can continue to respond to requests in periods of high load.
However, many models, in addition to having compute components, also have
IO requirements. For example, a model that classifies images may need to download
the image from a URL before it can classify it.
Truss provides a way to separate the IO component from the compute component, to
ensure that any IO does not prevent utilization of the compute on your pod.
To do this, you can use the pre- and post-process methods on a Truss. These methods
can be defined like this:
```python
class Model:
    def __init__(self, **kwargs) -> None:
        ...

    def load(self) -> None:
        ...

    def preprocess(self, request):
        # Include any IO logic that happens _before_ predict here
        ...

    def predict(self, request):
        # Include the actual predict here
        ...

    def postprocess(self, response):
        # Include any IO logic that happens _after_ predict here
        ...
```
When the model is invoked, any logic defined in the pre- or post-process
methods runs on a separate thread and is not subject to the same concurrency limits as
predict. So let's say you have a model that can handle 5 concurrent requests:
```yaml
...
runtime:
  predict_concurrency: 5
...
```
If you hit it with 10 requests, they will all begin pre-processing, but when the
6th request is ready to begin the predict method, it will have to wait for one of the
first 5 requests to finish. This keeps the GPU from being overloaded while the compute
logic is never blocked by IO, so you can achieve maximum throughput.
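To observe this from the client side, you can fire all 10 requests at once and watch them complete. Below is a minimal sketch, assuming a deployed model; the endpoint URL and API key are hypothetical placeholders, not values from this guide:

```python
# Illustrative client-side sketch. ENDPOINT and API_KEY are hypothetical
# placeholders; substitute the values for your own deployment.
import concurrent.futures

import requests

ENDPOINT = "https://model-abcdef.api.baseten.co/production/predict"  # hypothetical
API_KEY = "YOUR_API_KEY"  # hypothetical placeholder


def call_model(i: int):
    resp = requests.post(
        ENDPOINT,
        headers={"Authorization": f"Api-Key {API_KEY}"},
        json={"url": "https://example.com/cat.jpg"},  # hypothetical image URL
    )
    return i, resp.status_code


# All 10 requests begin pre-processing immediately, but at most 5 at a
# time occupy a predict slot on the container.
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as pool:
    for i, status in pool.map(call_model, range(10)):
        print(f"request {i}: HTTP {status}")
```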
When predict returns a generator (e.g. for streaming LLM outputs),
the model must not have a postprocess method.
Postprocessing can only be used when the prediction result is available as a
whole. In the case of streaming, move any postprocessing logic into predict or
apply it client-side.
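For instance, if you stream tokens from an LLM, any per-token transformation has to live inside predict itself. A minimal sketch, where `_generate_tokens` is a hypothetical stand-in for real token generation:

```python
from typing import Any, Dict, Generator


class Model:
    def load(self) -> None:
        ...  # load your LLM here

    def _generate_tokens(self, prompt: str) -> Generator[str, None, None]:
        # Hypothetical helper standing in for real token generation.
        for token in prompt.split():
            yield token

    def predict(self, request: Dict[str, Any]) -> Generator[str, None, None]:
        # Note: no postprocess method is defined on this class. Any per-token
        # transformation (here: uppercasing, purely for illustration) is
        # applied inside predict, because the response is streamed rather
        # than returned as a whole.
        for token in self._generate_tokens(request["prompt"]):
            yield token.upper()
```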
Here is a full example that uses preprocessing to download an image before classifying it with CLIP:

```python
import requests
from typing import Dict

from PIL import Image
from transformers import CLIPModel, CLIPProcessor

CHECKPOINT = "openai/clip-vit-base-patch32"


class Model:
    """
    This is a simple example of using CLIP to classify images.
    It outputs the probability of the image being a cat or a dog.
    """

    def __init__(self, **kwargs) -> None:
        self._processor = None
        self._model = None

    def load(self):
        """
        Loads the CLIP model and processor checkpoints.
        """
        self._model = CLIPModel.from_pretrained(CHECKPOINT)
        self._processor = CLIPProcessor.from_pretrained(CHECKPOINT)

    def preprocess(self, request: Dict) -> Dict:
        """
        This method downloads the image from the URL and preprocesses it.
        The preprocess method is used for any logic that involves IO, in this
        case downloading the image. It is called before the predict method
        in a separate thread and is not subject to the same concurrency
        limits as the predict method, so it can be called many times in
        parallel.
        """
        image = Image.open(requests.get(request.pop("url"), stream=True).raw)
        request["inputs"] = self._processor(
            text=["a photo of a cat", "a photo of a dog"],
            images=image,
            return_tensors="pt",
            padding=True,
        )
        return request

    def predict(self, request: Dict) -> list:
        """
        This performs the actual classification. The predict method is
        subject to the predict concurrency constraints.
        """
        outputs = self._model(**request["inputs"])
        logits_per_image = outputs.logits_per_image
        return logits_per_image.softmax(dim=1).tolist()
```
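To sanity-check this lifecycle locally, you can drive the class by hand in the same order the server would call it. A minimal sketch, assuming the dependencies are installed and the (hypothetical) image URL is reachable:

```python
# Manual walkthrough of the per-request lifecycle:
# load once, then preprocess -> predict for each request.
model = Model()
model.load()

request = {"url": "https://example.com/cat.jpg"}  # hypothetical image URL
request = model.preprocess(request)  # IO: download + tokenize (off-thread in production)
probs = model.predict(request)       # compute: subject to predict_concurrency
print(probs)  # e.g. [[0.95, 0.05]] -> [P(cat), P(dog)] (illustrative values)
```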