import asyncio
import truss_chains as chains
class Preprocess2x(chains.ChainletBase):
    """Chainlet that doubles an integer."""

    async def run_remote(self, number: int) -> int:
        """Return ``number`` multiplied by two.

        Args:
            number: The integer to double.

        Returns:
            Twice the value of ``number``.
        """
        doubled = number * 2
        return doubled
class MyBaseChainlet(chains.ChainletBase):
    """Chainlet that returns the reciprocal of its preprocessed input.

    The input is first passed to the injected ``preprocess`` dependency
    (``Preprocess2x`` by default) and the reciprocal of that result is
    returned.
    """

    # Resource request (1 CPU, 100Mi memory) with b10 tracing switched on.
    remote_config = chains.RemoteConfig(
        compute=chains.Compute(memory="100Mi", cpu_count=1),
        options=chains.ChainletOptions(enable_b10_tracing=True),
    )

    def __init__(self, preprocess=chains.depends(Preprocess2x)):
        """Store the preprocessing dependency.

        Args:
            preprocess: Dependency used to preprocess the input; declared via
                ``chains.depends(Preprocess2x)``.
        """
        # Kept private: only ``run_remote`` below talks to the dependency.
        self._preprocess = preprocess

    async def run_remote(self, number: int) -> float:
        """Return ``1.0`` divided by the preprocessed ``number``.

        Args:
            number: Integer forwarded to the preprocessing dependency.

        Returns:
            The reciprocal of the dependency's result.

        Raises:
            ZeroDivisionError: If the dependency's result is zero (with the
                default ``Preprocess2x`` that is ``number == 0``).
        """
        preprocessed = await self._preprocess.run_remote(number)
        return 1.0 / preprocessed
# Check the unmodified chain: input 4 is doubled to 8, then inverted to 1/8.
with chains.run_local():
    chainlet = MyBaseChainlet()
    # NOTE(review): ``asyncio.get_event_loop()`` without a running loop is
    # deprecated in recent Python versions; ``asyncio.run`` is the usual
    # replacement, but it closes the loop, so confirm no later code in this
    # file relies on the same loop before switching.
    _loop = asyncio.get_event_loop()
    _pending_call = chainlet.run_remote(4)
    result = _loop.run_until_complete(_pending_call)
    _expected = 1 / (4 * 2)
    assert result == _expected