Using request objects / Cancellation
Get more control by directly using the request object.
Classically, the truss server accepts the client request and takes care of
extracting and optionally validating the payload (if you use pydantic
annotations for your input).
In advanced use cases you might want to access the raw request object. Examples are:
- Custom payload deserialization, e.g. binary protocol buffers.
- Handling of closed connections and cancelling long-running predictions.
You can flexibly mix and match using requests and the "classic" input argument, for example:
```python
import fastapi

class Model:
    # Each method can use any of the signatures listed for it below.
    def preprocess(self, inputs): ...
    def preprocess(self, inputs, request: fastapi.Request): ...
    def preprocess(self, request: fastapi.Request): ...

    def predict(self, inputs): ...
    def predict(self, inputs, request: fastapi.Request): ...
    def predict(self, request: fastapi.Request): ...

    def postprocess(self, inputs): ...
    def postprocess(self, inputs, request: fastapi.Request): ...
```
The request argument can have any name, but it must be type-annotated as a
subclass of `starlette.requests.Request`, e.g. `fastapi.Request`.
If you only use requests, the truss server skips extracting the payload completely to improve performance.
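For example, a request-only `predict` can implement custom payload deserialization, such as binary protocol buffers. The following is a minimal sketch: the `inference_pb2` module and `run_model` call are hypothetical placeholders for your generated protobuf code and model invocation.

```python
import fastapi

# Hypothetical generated protobuf module:
# from my_protos import inference_pb2

class Model:
    async def predict(self, request: fastapi.Request):
        # The truss server did not parse the payload, so read the raw bytes.
        raw_bytes = await request.body()
        # message = inference_pb2.InferenceRequest.FromString(raw_bytes)
        # return run_model(message)
        return {"num_bytes": len(raw_bytes)}
```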
When mixing requests and “classic” inputs, the following rules apply:
- When using both arguments, the request must be the second argument (and type-annotated).
- Argument names don't matter, but for clarity we recommend naming the "classic" inputs `inputs` (or `data`) and not `request`.
- When none of the three methods `preprocess`, `predict`, and `postprocess` need the classic inputs, payload parsing by the truss server is skipped.
- The request stays the same through the sequence of `preprocess`, `predict`, and `postprocess`. For example, `preprocess` only affects the classic inputs: `predict` receives the transformed inputs, but the identical request object (see the sketch after this list).
- `postprocess` cannot use only the request, because that would mean the output of `predict` is discarded.
- Likewise, if `predict` only uses the request, you cannot have a `preprocess` method, because its output would be discarded.
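Below is a minimal sketch of that pass-through behavior, assuming a JSON payload with a `prompt` field:

```python
import logging
import fastapi

class Model:
    async def preprocess(self, inputs, request: fastapi.Request):
        # Only the classic inputs are transformed here; the request object
        # is passed on to `predict` unchanged.
        return {"prompt": inputs["prompt"].strip().lower()}

    async def predict(self, inputs, request: fastapi.Request):
        # `inputs` is the output of `preprocess`, while `request` is the
        # identical object that `preprocess` received.
        if await request.is_disconnected():
            logging.warning("Client is gone, skipping work.")
            return {}
        return {"completion": f"echo: {inputs['prompt']}"}
```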
Cancelling predictions
If you make long-running predictions, such as LLM generation, and the client
drops the connection, it is useful to cancel the processing and free up the
server to handle other requests. You can use the `is_disconnected` method of
the request to check for dropped connections.
Here is a simple example:
```python
import fastapi, asyncio, logging

class Model:
    async def predict(self, inputs, request: fastapi.Request):
        await asyncio.sleep(1)
        if await request.is_disconnected():
            logging.warning("Cancelled (before gen).")
            # Cancel the request on the model engine here.
            return

        for i in range(5):
            await asyncio.sleep(1.0)
            logging.warning(i)
            yield str(i)
            if await request.is_disconnected():
                logging.warning("Cancelled (during gen).")
                # Cancel the request on the model engine here.
                return
```
In a real implementation you must additionally cancel the request on the model engine; how this is done differs between frameworks. Below are some examples.
The TRT-LLM example additionally showcases an async polling task that checks the connection in the background, so you can cancel between yields from the model and are not limited by how long each yield takes. The same logic can also be used for other frameworks.
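For illustration, here is a generic sketch of that polling pattern. The generation loop is a stand-in for your model, and the engine-cancellation step is a placeholder that you must replace with your framework's API:

```python
import asyncio, logging
import fastapi

class Model:
    async def predict(self, inputs, request: fastapi.Request):
        cancelled = asyncio.Event()

        async def poll_disconnect():
            # Check the connection in the background, independently of how
            # long the model takes between yields.
            while not await request.is_disconnected():
                await asyncio.sleep(0.5)
            logging.warning("Client disconnected, cancelling.")
            # Cancel the request on the model engine here (framework-specific).
            cancelled.set()

        poll_task = asyncio.create_task(poll_disconnect())
        try:
            for i in range(5):  # Stand-in for the model's generation loop.
                await asyncio.sleep(1.0)
                if cancelled.is_set():
                    return
                yield str(i)
        finally:
            poll_task.cancel()
```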
The following features of requests are not supported:
- Streaming inputs ("file upload"): for large input data, provide URLs to download the source data instead of packing it into the request (see the sketch below).
- Headers: most client-side headers are stripped from the requests reaching the model. Include any information that controls model behavior in the request payload instead.
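As a sketch of the URL-based pattern for the first point, assuming an `httpx` dependency and a payload with a hypothetical `data_url` field:

```python
import httpx

class Model:
    async def predict(self, inputs):
        # Download the large source data from a URL given in the payload,
        # instead of receiving it as a streamed upload.
        async with httpx.AsyncClient() as client:
            response = await client.get(inputs["data_url"])
            response.raise_for_status()
            data = response.content
        return {"num_bytes": len(data)}
```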