Skip to content

Commit

Permalink
Various refactors to model wrapper + truss server (#1379)
Browse files Browse the repository at this point in the history
  • Loading branch information
nnarayen authored Feb 10, 2025
1 parent a253680 commit 35e8d3e
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 100 deletions.
9 changes: 9 additions & 0 deletions truss/templates/server/common/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ class UserCodeError(Exception):
pass


class ModelMethodNotImplemented(Exception):
    """Raised when an optional model endpoint (e.g. `completions`) is invoked
    but the user's truss model does not define the corresponding method.

    The server's exception handler maps this to an HTTP 404 client error.
    """

    pass


class ModelDefinitionError(TypeError):
"""When the user-defined truss model does not meet the contract."""

Expand Down Expand Up @@ -96,6 +100,10 @@ async def exception_handler(_: fastapi.Request, exc: Exception) -> fastapi.Respo
"Internal Server Error",
_BASETEN_DOWNSTREAM_ERROR_CODE,
)
if isinstance(exc, ModelMethodNotImplemented):
return _make_baseten_response(
HTTPStatus.NOT_FOUND.value, exc, _BASETEN_CLIENT_ERROR_CODE
)
if isinstance(exc, fastapi.HTTPException):
# This is a pass through, but additionally adds our custom error headers.
return _make_baseten_response(
Expand All @@ -117,6 +125,7 @@ async def exception_handler(_: fastapi.Request, exc: Exception) -> fastapi.Respo
UserCodeError,
ModelDefinitionError,
fastapi.HTTPException,
ModelMethodNotImplemented,
}


Expand Down
84 changes: 47 additions & 37 deletions truss/templates/server/model_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def defer() -> Callable[[], None]:


_ArgsType = Union[
Tuple[()],
Tuple[Any],
Tuple[Any, starlette.requests.Request],
Tuple[starlette.requests.Request],
Expand Down Expand Up @@ -132,11 +133,9 @@ class ArgConfig(enum.Enum):
INPUTS_AND_REQUEST = enum.auto()

@classmethod
def from_signature(
cls, signature: inspect.Signature, method_name: str
) -> "ArgConfig":
def from_method(cls, method: Any, method_name: MethodName) -> "ArgConfig":
signature = inspect.signature(method)
parameters = list(signature.parameters.values())

if len(parameters) == 0:
return cls.NONE
elif len(parameters) == 1:
Expand Down Expand Up @@ -173,6 +172,8 @@ def prepare_args(
descriptor: "MethodDescriptor",
) -> _ArgsType:
args: _ArgsType
if descriptor.arg_config == ArgConfig.NONE:
args = ()
if descriptor.arg_config == ArgConfig.INPUTS_ONLY:
args = (inputs,)
elif descriptor.arg_config == ArgConfig.REQUEST_ONLY:
Expand All @@ -197,7 +198,7 @@ def from_method(cls, method: Any, method_name: MethodName) -> "MethodDescriptor"
return cls(
is_async=cls._is_async(method),
is_generator=cls._is_generator(method),
arg_config=ArgConfig.from_signature(inspect.signature(method), method_name),
arg_config=ArgConfig.from_method(method, method_name),
method_name=method_name,
# ArgConfig ensures that the Callable has an appropriate signature.
method=cast(ModelFn, method),
Expand Down Expand Up @@ -234,22 +235,19 @@ def skip_input_parsing(self) -> bool:
@classmethod
def _gen_truss_schema(
    cls,
    predict: MethodDescriptor,
    preprocess: Optional[MethodDescriptor],
    postprocess: Optional[MethodDescriptor],
) -> TrussSchema:
    """Build a ``TrussSchema`` from the model's method signatures.

    Args:
        predict: descriptor of the required `predict` method.
        preprocess: descriptor of the optional `preprocess` method, or None.
        postprocess: descriptor of the optional `postprocess` method, or None.

    Returns:
        Schema derived from the input parameters of `preprocess` (falling
        back to `predict`) and the return annotation of `postprocess`
        (falling back to `predict`).
    """
    # NOTE: this span contained interleaved pre-/post-refactor diff lines
    # (`model_cls.*` vs. `<descriptor>.method`); resolved to the
    # post-refactor version that works on descriptors only.
    if preprocess:
        parameters = inspect.signature(preprocess.method).parameters
    else:
        parameters = inspect.signature(predict.method).parameters

    if postprocess:
        return_annotation = inspect.signature(postprocess.method).return_annotation
    else:
        return_annotation = inspect.signature(predict.method).return_annotation

    return TrussSchema.from_signature(parameters, return_annotation)

Expand All @@ -267,7 +265,7 @@ def _safe_extract_descriptor(
def from_model(cls, model_cls) -> "ModelDescriptor":
preprocess = cls._safe_extract_descriptor(model_cls, MethodName.PREPROCESS)
predict = cls._safe_extract_descriptor(model_cls, MethodName.PREDICT)
if predict is None:
if not predict:
raise errors.ModelDefinitionError(
f"Truss model must have a `{MethodName.PREDICT}` method."
)
Expand All @@ -294,11 +292,9 @@ def from_model(cls, model_cls) -> "ModelDescriptor":
)

truss_schema = cls._gen_truss_schema(
model_cls=model_cls,
predict=predict,
preprocess=preprocess,
postprocess=postprocess,
predict=predict, preprocess=preprocess, postprocess=postprocess
)

return cls(
preprocess=preprocess,
predict=predict,
Expand Down Expand Up @@ -377,6 +373,14 @@ def ready(self) -> bool:
def _model_file_name(self) -> str:
return self._config["model_class_filename"]

@property
def skip_input_parsing(self) -> bool:
    # Pass-through convenience accessor to the loaded model's descriptor.
    return self.model_descriptor.skip_input_parsing

@property
def truss_schema(self) -> Optional[TrussSchema]:
    # Pass-through to the descriptor's schema; may be None (presumably when
    # no schema could be derived from the model's signatures — verify
    # against `ModelDescriptor`).
    return self.model_descriptor.truss_schema

def start_load_thread(self):
# Don't retry failed loads.
if self._status == ModelWrapper.Status.NOT_READY:
Expand Down Expand Up @@ -568,7 +572,7 @@ async def is_healthy(self) -> Optional[bool]:
is_healthy: Optional[bool] = None
if not descriptor or self.load_failed:
# return early with None if model does not have is_healthy method or load failed
return is_healthy
return None
try:
if descriptor.is_async:
is_healthy = await self._model.is_healthy()
Expand All @@ -594,16 +598,16 @@ async def preprocess(
assert descriptor, (
f"`{MethodName.PREPROCESS}` must only be called if model has it."
)
return await self._execute_async_model_fn(inputs, request, descriptor)
return await self._execute_user_model_fn(inputs, request, descriptor)

async def _predict(
    self, inputs: Any, request: starlette.requests.Request
) -> Union[OutputType, Any]:
    """Invoke the user model's `predict` method.

    The result can be a serializable data structure, byte-generator, a
    request, or, if `postprocessing` is used, anything. In the last case
    postprocessing must convert the result to something serializable.
    """
    # NOTE: this span contained stale pre-refactor duplicate lines (old
    # `predict` def and `_execute_async_model_fn` call); resolved to the
    # renamed post-refactor version.
    descriptor = self.model_descriptor.predict
    return await self._execute_user_model_fn(inputs, request, descriptor)

async def postprocess(
self, result: Union[InputType, Any], request: starlette.requests.Request
Expand All @@ -616,7 +620,7 @@ async def postprocess(
assert descriptor, (
f"`{MethodName.POSTPROCESS}` must only be called if model has it."
)
return await self._execute_async_model_fn(result, request, descriptor)
return await self._execute_user_model_fn(result, request, descriptor)

async def _write_response_to_queue(
self,
Expand Down Expand Up @@ -684,7 +688,7 @@ async def _buffered_response_generator() -> AsyncGenerator[bytes, None]:

return _buffered_response_generator()

async def _execute_async_model_fn(
async def _execute_user_model_fn(
self,
inputs: Union[InputType, Any],
request: starlette.requests.Request,
Expand All @@ -699,7 +703,7 @@ async def _execute_async_model_fn(
return await cast(Awaitable[OutputType], descriptor.method(*args))
return await to_thread.run_sync(descriptor.method, *args)

async def _process_model_fn(
async def _execute_model_endpoint(
self,
inputs: InputType,
request: starlette.requests.Request,
Expand All @@ -713,7 +717,7 @@ async def _process_model_fn(
with tracing.section_as_event(
fn_span, descriptor.method_name
), tracing.detach_context() as detached_ctx:
result = await self._execute_async_model_fn(inputs, request, descriptor)
result = await self._execute_user_model_fn(inputs, request, descriptor)

if inspect.isgenerator(result) or inspect.isasyncgen(result):
return await self._handle_generator_response(
Expand Down Expand Up @@ -747,27 +751,33 @@ async def _handle_generator_response(
generator, span, trace_ctx, cleanup_fn=get_cleanup_fn()
)

def _get_descriptor_or_raise(
    self, descriptor: Optional[MethodDescriptor], method_name: MethodName
) -> MethodDescriptor:
    """Return `descriptor`, or raise if the model lacks `method_name`."""
    if descriptor:
        return descriptor
    # Surfaced to the client as a 404 by the server's exception handler.
    raise errors.ModelMethodNotImplemented(
        f"`{method_name}` must only be called if model has it."
    )

async def completions(
    self, inputs: InputType, request: starlette.requests.Request
) -> OutputType:
    """Serve the completions endpoint via the user model's method.

    Raises:
        errors.ModelMethodNotImplemented: if the model does not define
            `completions` (mapped to an HTTP 404 client error).
    """
    # NOTE: this span contained interleaved pre-refactor lines (bare
    # descriptor lookup + assert, `_process_model_fn` call); resolved to
    # the post-refactor version that raises a client-visible error.
    descriptor = self._get_descriptor_or_raise(
        self.model_descriptor.completions, MethodName.COMPLETIONS
    )
    return await self._execute_model_endpoint(inputs, request, descriptor)

async def chat_completions(
    self, inputs: InputType, request: starlette.requests.Request
) -> OutputType:
    """Serve the chat-completions endpoint via the user model's method.

    Raises:
        errors.ModelMethodNotImplemented: if the model does not define
            `chat_completions` (mapped to an HTTP 404 client error).
    """
    # NOTE: this span contained interleaved pre-refactor lines (bare
    # descriptor lookup + assert, `_process_model_fn` call); resolved to
    # the post-refactor version that raises a client-visible error.
    descriptor = self._get_descriptor_or_raise(
        self.model_descriptor.chat_completions, MethodName.CHAT_COMPLETIONS
    )
    return await self._execute_model_endpoint(inputs, request, descriptor)

async def __call__(
async def predict(
self, inputs: Optional[InputType], request: starlette.requests.Request
) -> OutputType:
"""
Expand Down Expand Up @@ -804,7 +814,7 @@ async def __call__(
# exactly handle that case we would need to apply `detach_context`
# around each `next`-invocation that consumes the generator, which is
# prohibitive.
predict_result = await self.predict(preprocess_result, request)
predict_result = await self._predict(preprocess_result, request)

if inspect.isgenerator(predict_result) or inspect.isasyncgen(
predict_result
Expand Down
Loading

0 comments on commit 35e8d3e

Please sign in to comment.