Get more control by directly using the request object.
By default, the truss server accepts the client request and takes care of
extracting and optionally validating the payload (if you use pydantic
annotations for your input).
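For reference, here is a minimal sketch of that default flow with a pydantic-annotated input (the field names are made up for illustration):

```python
import pydantic


class Input(pydantic.BaseModel):
    prompt: str
    max_tokens: int = 64  # Hypothetical fields, for illustration only.


class Model:
    def predict(self, inputs: Input) -> dict:
        # The truss server has already parsed and validated the payload here.
        return {"echo": inputs.prompt, "max_tokens": inputs.max_tokens}
```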
In advanced use cases you might want to access the raw request object.
Examples are:
- Custom payload deserialization, e.g. binary protocol buffers.
- Handling closed connections and cancelling long-running predictions.
The request argument can have any name, but it must be type-annotated as a
subclass of starlette.requests.Request, e.g. fastapi.Request.
If your methods use only the request (and no classic inputs), the truss server
skips extracting the payload entirely, which improves performance.
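As a minimal sketch of request-only usage (here we parse the raw body ourselves; the byte-count output is just for illustration):

```python
import fastapi


class Model:
    async def predict(self, request: fastapi.Request) -> dict:
        # Read the raw bytes and deserialize them however you need,
        # e.g. with a protocol buffer parser instead of JSON.
        raw_bytes = await request.body()
        return {"num_bytes": len(raw_bytes)}
```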
When mixing requests and “classic” inputs, the following rules apply (see the
sketch after this list):
- When using both arguments, request must be the second argument (and type-annotated).
- Argument names don’t matter, but for clarity we recommend naming the
“classic” input inputs (or data) and not request.
- If none of the three methods preprocess, predict, and postprocess needs
the classic inputs, payload parsing by the truss server is skipped.
- The request stays the same through the sequence of preprocess,
predict, and postprocess. For example, preprocessing only affects
the classic inputs: predict receives the transformed inputs, but the
identical request object.
- postprocess cannot use only the request, because that would mean the
output of predict is discarded.
- Likewise, if predict uses only the request, you cannot have a
preprocessing method, because its output would be discarded.
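Here is a hedged sketch of mixing both kinds of arguments (the payload fields and transformations are invented for illustration):

```python
import fastapi


class Model:
    def preprocess(self, inputs: dict) -> dict:
        # Only the classic inputs are transformed; the request is untouched.
        return {"text": inputs["text"].lower()}

    async def predict(self, inputs: dict, request: fastapi.Request) -> dict:
        # `inputs` is the output of `preprocess`; `request` is the original
        # request object, e.g. usable for disconnect checks.
        cancelled = await request.is_disconnected()
        return {"text": inputs["text"], "cancelled": cancelled}

    def postprocess(self, outputs: dict, request: fastapi.Request) -> dict:
        # `postprocess` receives predict's output plus the same request object.
        return {**outputs, "method": request.method}
```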
If you make long-running predictions, such as LLM generation, and the client
drops the connection, it is useful to cancel the processing and free up the
server to handle other requests. You can use the is_disconnected method of
the request to check for dropped connections.
Here is a simple example:
```python
import fastapi, asyncio, logging


class Model:
    async def predict(self, inputs, request: fastapi.Request):
        await asyncio.sleep(1)
        if await request.is_disconnected():
            logging.warning("Cancelled (before gen).")
            # Cancel the request on the model engine here.
            return

        for i in range(5):
            await asyncio.sleep(1.0)
            logging.warning(i)
            yield str(i)
            if await request.is_disconnected():
                logging.warning("Cancelled (during gen).")
                # Cancel the request on the model engine here.
                return
```
In a real example you must also cancel the request on the model engine, which
looks different depending on the framework used. Below are some examples.
The TRT-LLM example additionally shows how to use an async polling
task to check the connection, so you can cancel between yields from
the model and are not limited by how long those yields take. The same logic
can also be used for other frameworks.
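As a rough, framework-agnostic sketch of that polling pattern (assuming your framework exposes an async output iterator and some cancellation hook; both names below are placeholders):

```python
import asyncio
import fastapi


async def poll_and_stream(engine_stream, cancel_engine_fn, request: fastapi.Request):
    # Wrap iteration in a task so we can check the connection roughly once per
    # second, even if the engine takes longer between chunks.
    gen_task = asyncio.ensure_future(engine_stream.__anext__())
    while True:
        done, _ = await asyncio.wait([gen_task], timeout=1)
        if await request.is_disconnected():
            cancel_engine_fn()  # Placeholder for your framework's cancel call.
            gen_task.cancel()
            return
        if done:
            try:
                chunk = await gen_task
            except StopAsyncIteration:
                return  # `engine_stream` is exhausted.
            yield chunk
            gen_task = asyncio.ensure_future(engine_stream.__anext__())
```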
If you serve models with TRT LLM, you can use the cancel API of the response
generator.
```python
import asyncio
import json
import logging
from typing import AsyncGenerator, Awaitable, Callable

import tritonclient.grpc.aio as grpcclient

GRPC_SERVICE_PORT = 8001
logger = logging.getLogger(__name__)


class TritonClient:
    def __init__(self, grpc_service_port: int = GRPC_SERVICE_PORT):
        self.grpc_service_port = grpc_service_port
        self._grpc_client = None

    def start_grpc_stream(self) -> grpcclient.InferenceServerClient:
        if self._grpc_client:
            return self._grpc_client

        self._grpc_client = grpcclient.InferenceServerClient(
            url=f"localhost:{self.grpc_service_port}", verbose=False
        )
        return self._grpc_client

    async def infer(
        self,
        model_input,
        is_cancelled_fn: Callable[[], Awaitable[bool]],
        model_name="ensemble",
    ) -> AsyncGenerator[str, None]:
        grpc_client_instance = self.start_grpc_stream()
        inputs = model_input.to_tensors()

        async def input_generator():
            yield {
                "model_name": model_name,
                "inputs": inputs,
                "request_id": model_input.request_id,
            }

        stats = await grpc_client_instance.get_inference_statistics()
        logger.info(stats)

        response_iterator = grpc_client_instance.stream_infer(
            inputs_iterator=input_generator(),
        )

        if await is_cancelled_fn():
            logging.info("Request cancelled before streaming. Cancelling Triton request.")
            response_iterator.cancel()
            return

        # Below wraps the iteration of `response_iterator` into asyncio tasks, so that
        # we can poll `is_cancelled_fn` 1/sec, even if the iterator is slower - i.e.
        # we are not blocked / limited by the iterator when cancelling.
        try:
            gen_task = asyncio.ensure_future(response_iterator.__anext__())
            while True:
                done_task, _ = await asyncio.wait([gen_task], timeout=1)
                if await is_cancelled_fn():
                    logging.info("Request cancelled. Cancelling Triton request.")
                    response_iterator.cancel()
                    gen_task.cancel()
                    return
                if done_task:
                    try:
                        response = await gen_task
                    except StopAsyncIteration:
                        # `response_iterator` is exhausted, breaking the `while True` loop.
                        return

                    result, error = response
                    if result:
                        result = result.as_numpy("text_output")
                        yield result[0].decode("utf-8")
                    else:
                        # Error.
                        yield json.dumps({"status": "error", "message": error.message()})
                        return
                    gen_task = asyncio.ensure_future(response_iterator.__anext__())
        except grpcclient.InferenceServerException as e:
            logger.error(f"InferenceServerException: {e}")
```
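To connect this to the request object, you can pass `request.is_disconnected` as the cancellation check. Here is a hedged sketch (the `ModelInput` helper with `to_tensors()` and `request_id` is an assumption and not shown here):

```python
import fastapi


class Model:
    def __init__(self, **kwargs):
        self._triton_client = TritonClient()

    async def predict(self, inputs: dict, request: fastapi.Request):
        model_input = ModelInput(**inputs)  # Hypothetical helper, see note above.
        # The bound method `request.is_disconnected` matches the
        # `Callable[[], Awaitable[bool]]` signature expected by `infer`.
        async for chunk in self._triton_client.infer(
            model_input, is_cancelled_fn=request.is_disconnected
        ):
            yield chunk
```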
If you serve models with vLLM, you can use the abort API of the engine.
Here is a minimal code snippet from
their documentation.
```python
# Please refer to entrypoints/api_server.py for
# the complete example.

# initialize the engine and the example input
engine = AsyncLLMEngine.from_engine_args(engine_args)
example_input = {
    "prompt": "What is LLM?",
    "stream": False,  # assume the non-streaming case
    "temperature": 0.0,
    "request_id": 0,
}

# start the generation
results_generator = engine.generate(
    example_input["prompt"],
    SamplingParams(temperature=example_input["temperature"]),
    example_input["request_id"],
)

# get the results
final_output = None
async for request_output in results_generator:
    if await request.is_disconnected():
        # Abort the request if the client disconnects.
        await engine.abort(request_id)
        # Return or raise an error...
    final_output = request_output

# Process and return the final output...
```
The following features of requests are not supported:
- Streaming inputs (“file upload”). For large input data, provide URLs to
download the source data instead of packing it into the request (see the sketch
after this list).
- Headers: most client-side headers are stripped from the requests reaching
the model. Add any information that controls model behavior to the request payload.
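For the large-input case, here is a minimal sketch of the URL approach (the payload field name and the use of httpx are assumptions; any HTTP client works):

```python
import httpx  # Assumed to be listed in the truss requirements.


class Model:
    async def predict(self, inputs: dict) -> dict:
        # Instead of receiving the file content in the payload, receive a URL
        # and download the data server-side.
        async with httpx.AsyncClient() as client:
            response = await client.get(inputs["data_url"])  # Hypothetical field.
            response.raise_for_status()
            data = response.content
        return {"num_bytes": len(data)}
```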