Introduce truss server passthrough for OpenAI methods #1364
Conversation
return None

@classmethod
def from_model(cls, model) -> "ModelDescriptor":
Changes to this method are intended to be no-ops except for the addition of completions / chats on L271/272. Everything else is a refactor
return await model_fn(*args)
return await to_thread.run_sync(model_fn, *args)

async def _trace_and_process_model_fn(
All future OpenAI-compatible endpoints (and honestly anything else we want to expose via Truss Server) will likely look very similar to these endpoints. `predict` is more opinionated since it orchestrates the pre -> predict -> post flow, and I didn't want to make changes there for this PR. A dedicated followup might be able to reuse this type of helper in that flow, but it'll be easier to review in isolation.
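A minimal sketch of that contrast, assuming async model methods (the real server also handles sync functions and generators; all names here are illustrative):

```python
async def predict_flow(model, inputs, request):
    # `predict` orchestrates a fixed pipeline:
    inputs = await model.preprocess(inputs)        # pre
    result = await model.predict(inputs, request)  # predict
    return await model.postprocess(result)         # post

async def standalone_flow(model_fn, inputs, request):
    # completions / chat_completions just run one function, no orchestration.
    return await model_fn(inputs, request)
```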
@@ -617,7 +624,7 @@ async def _stream_with_background_task(
     generator: Union[Generator[bytes, None, None], AsyncGenerator[bytes, None]],
     span: trace.Span,
     trace_ctx: trace.Context,
-    release_and_end: Callable[[], None],
+    release_and_end: Callable[[], None] = lambda: None,
We intentionally don't want to reuse the `predict` semaphore for these endpoints, and we want to encourage users to build concurrency controls into their model code if they need them. A great future goal could be a `@truss.concurrency(max_requests=2)` decorator or something similar.
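A hypothetical sketch of what such a decorator could look like (`truss.concurrency` does not exist today; the name and semantics here are purely illustrative):

```python
import asyncio
import functools

def concurrency(max_requests: int):
    """Hypothetical decorator limiting concurrent calls into a model method."""
    semaphore = asyncio.Semaphore(max_requests)

    def decorator(fn):
        @functools.wraps(fn)
        async def wrapper(*args, **kwargs):
            async with semaphore:  # blocks once `max_requests` calls are in flight
                return await fn(*args, **kwargs)
        return wrapper

    return decorator

# Usage (hypothetical):
# @concurrency(max_requests=2)
# async def chat_completions(self, inputs): ...
```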
Can you explicitly pass the argument where you want deviating behavior instead of setting a default? I don't see a different/changed call-site of `_write_response_to_queue` anyway in this PR - is this something to come?
    return await model.chat_completions(inputs, request)

return await self._execute_request(
    model_name=MODEL_BASENAME,
I noticed that the current flow hardcodes `model` in beefeater and truss server. I poked around and it seems to have roots in kserve, which expects a different serving model, but for now I figured it was OK to hardcode. OpenAI clients will be opinionated about the URL they hit, so we'd have to add some mapping logic in beefeater if we wanted to preserve this URL param.
I think the whole ability to "name" a model is a relic from ancient times. Right now there is a 1:1 relation between a truss server deployment and a model (and the model "name" is completely irrelevant).
Can we nuke this whole "feature" and simplify code?
I'm sure we could! I think it might actually make this PR more complicated though, but I'm happy to take that on as a pure refactor followup if that works? For now, I think hardcoding this name so that the other code paths work as-is seems like the simplest path forward.
Force-pushed from 255b9fa to 649806d
Force-pushed from 0f6d468 to 21a88eb
pass

def chat_completions(self, input: Dict) -> str:
    return "chat_completions"
Could you e.g. use non-streaming OpenAI completions and test OpenAI compatibility with their client directly?
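A sketch of what that could look like, assuming a truss server running locally on port 8080 (the URL, API key, and model name are placeholders, not taken from this PR):

```python
from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8080/v1",  # assumed local truss server address
    api_key="not-needed-locally",         # placeholder; no auth assumed locally
)

# Non-streaming completions, exercising /v1/completions end to end.
response = client.completions.create(
    model="model",  # placeholder; truss server currently hardcodes the model name
    prompt="Hello",
    stream=False,
)
print(response.choices[0].text)
```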
Why is this in truss-chains? There is https://github.com/basetenlabs/truss/blob/main/truss/tests/test_model_inference.py. Also, you can use `with _temp_truss` there so you don't need to add a lot of mini-files.
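For illustration, a sketch of that pattern, assuming `_temp_truss` is a context manager in that test module that takes model source code as a string (the real helper's signature may differ):

```python
MODEL_SRC = """
class Model:
    def chat_completions(self, inputs: dict) -> str:
        return "chat_completions"
"""

def test_chat_completions_passthrough():
    # Assumed usage of the `_temp_truss` helper; no separate mini-files needed.
    with _temp_truss(MODEL_SRC) as truss_handle:
        ...  # serve `truss_handle` and assert /v1/chat/completions responds
```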
Good suggestion! Makes more sense there, hadn't seen this file before
@classmethod
def _safe_extract_descriptor(
    cls, model: Any, method: str
Nit: `method` -> `method_name`.
def _gen_truss_schema(
    cls,
    model: Any,
Nit: it was already like this, but it might be better to make clear this is a class, not an instance, so `model` -> `model_cls` or so.
@@ -49,6 +49,17 @@
TRT_LLM_EXTENSION_NAME = "trt_llm"
POLL_FOR_ENVIRONMENT_UPDATES_TIMEOUT_SECS = 30


class ModelMethod(Enum):
Better name: `ModelMethodKind`.
    descriptor, inputs, request, self._model.completions
)

return await self._trace_and_process_model_fn(
So we wrap `self._model.completions` into a nested function `def exec_fn`, which calls `_execute_async_model_fn`, and then we also call `_trace_and_process_model_fn` on that.

I haven't fully wrapped my head around all the nuances here, but is this really necessary? Is there a more concise way? What would the stack traces look like if an exception is raised in `self._model.completions` - do we need to update any logging/stack-filtering logic for that?
Mentioned in #1364 (comment), but I reduced one layer of wrapping, and I think you're right that it's cleaner!
async def _execute_request(
    self,
    model_name: str,
    method: ModelMethod,
Nit: update naming to be exact.
    self, model_name: str, request: Request, body_raw: bytes = Depends(parse_body)
async def _execute_request(
    self,
    model_name: str,
Consider nuking this whole `model_name` parameterization; it doesn't seem to be needed for anything anymore. But ask in the truss/core-product channel if someone knows a reason to keep it.
async def chat_completions(
    self, request: Request, body_raw: bytes = Depends(parse_body)
) -> Response:
    async def execution_fn(
We have all this method wrapping in `ModelWrapper` (see comment there) - I'm confused why there is even more here now.
Yeah, good question, I was also trying to trace the different responsibilities. TrussServer has lots of common logic to deal with things at the HTTP layer: (1) parse the request body, (2) start the root span, (3) execute model code, and (4) generate the HTTP response. Everything except (3) is the same across predict/completions/chat completions, which is why I introduced the thin wrapping layer for `execution_fn`.

Very similarly, `ModelWrapper` has lots of similarities in how it executes things - it starts spans, decides how to run code in threads, and then deals with return values / generators.

Inside model wrapper we can reduce one layer of wrapping by passing in more parameters; let me know if you like that more.
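A compressed sketch of that separation (`execution_fn` follows the PR's naming, but the helpers here are placeholders, not its real code):

```python
import asyncio
from typing import Any, Awaitable, Callable

ExecutionFn = Callable[[dict], Awaitable[Any]]

async def execute_request(raw_body: bytes, execution_fn: ExecutionFn) -> str:
    inputs = {"raw": raw_body.decode()}  # (1) parse the request body
    # (2) the real server starts a root span here
    result = await execution_fn(inputs)  # (3) the only per-endpoint piece
    return f"HTTP 200: {result}"         # (4) generate the HTTP response

async def chat_completions_fn(inputs: dict) -> Any:
    return f"chat_completions({inputs['raw']})"

print(asyncio.run(execute_request(b"hi", chat_completions_fn)))
```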
Force-pushed from d70b857 to 719d760
    descriptor,
    inputs,
    request,
    self._model.preprocess,
    supports_generators=False,
Now that you've newly introduced this generic `_execute_async_model_fn`, it would make sense to evolve `MethodDescriptor` accordingly: `(descriptor, preprocess, supports_generators)` all conceptually refer to the same thing.

This function signature is overly long, the static properties of the model and the dynamic inputs (`inputs`, `request`) are interleaved without an order, and most importantly: there is no built-in coherence for the static properties.

My suggestion is to update `MethodDescriptor` so that it works harmonically with `_execute_async_model_fn` with a concise API, by bundling `(descriptor, preprocess, supports_generators)` into it.

And you could also move the (repetitive) assertion into `_execute_async_model_fn`.
This is a pretty general pattern for better interfaces and higher cohesion, try to apply this wherever possible :)
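A sketch of the suggested bundling, with illustrative field names (not the PR's actual definitions):

```python
import dataclasses
from typing import Any, Callable, Optional

@dataclasses.dataclass
class MethodDescriptor:
    # Static properties of one model method, bundled together for coherence.
    method_fn: Optional[Callable[..., Any]]  # e.g. model.preprocess, or None
    is_async: bool
    supports_generators: bool

async def _execute_async_model_fn(descriptor: MethodDescriptor, inputs, request):
    # The repetitive assertion lives here once, not at every call site.
    assert descriptor.method_fn is not None, "Method not implemented by model."
    ...  # dispatch on descriptor.is_async / descriptor.supports_generators
```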
if inspect.isgenerator(result) or inspect.isasyncgen(result):
    if request.headers.get("accept") == "application/json":
        return await _gather_generator(result)
    else:
        return await self._stream_with_background_task(
            result,
            fn_span,
            detached_ctx,
            # No semaphores needed for non-predict model functions.
            release_and_end=lambda: None,
        )
Are we OK with `_stream_with_background_task` being entered for any `model_fn`? In particular, what if all of pre-process, predict, and post-process are generators (or is this not allowed anyway)?
Maybe it's not clear that `_trace_and_process_model_fn` is only intended for some methods - what does "model_fn" mean? Can you constrain it with assertions?
All really good questions, I'll leave extra comments in the code but will try to clarify here as well. With this PR, I propose 2 ways of invoking code in model wrapper: (1) `__call__` - extremely specific to `predict` orchestration, with lots of validation on specific combinations; (2) `_trace_and_process_model_fn` - ideally intended for any other standalone model function that doesn't need orchestration.

Therefore, I think `_trace_and_process_model_fn` should be as generic as possible for now, and support the cross product of (async, sync) x (generator, non-generator). For our use case now, both completions / chat completions should be allowed to be any of the above. If we find a use case where one of those combinations is invalid, we should add more metadata to `MethodDescriptor` and check that here.
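A sketch of dispatching over that cross product (a simplified stand-in for the PR's actual helpers; `to_thread` is anyio's, as used elsewhere in the server):

```python
import inspect
from anyio import to_thread

async def run_model_fn(model_fn, *args):
    if inspect.isasyncgenfunction(model_fn):
        return model_fn(*args)  # async generator: return it, never await it
    if inspect.isgeneratorfunction(model_fn):
        # sync generator: create it off the event loop; consumed by streaming code
        return await to_thread.run_sync(model_fn, *args)
    if inspect.iscoroutinefunction(model_fn):
        return await model_fn(*args)  # plain async function
    return await to_thread.run_sync(model_fn, *args)  # plain sync function
```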
OK, understood. Maybe add a one-liner docstring for the intended usage of `_trace_and_process_model_fn` (or `_process_model_fn` if you change the name). In that case we could maybe also rename `__call__` to a more specific name.
    request: starlette.requests.Request,
    method_name: MethodName,
    descriptor: MethodDescriptor,
    model_fn: Any,
Can you create a suitable type alias for the callable?
    MODEL_BASENAME,
    InputType,
    MethodDescriptor,
    MethodName,
    ModelWrapper,
    OutputType,
Related to the above: the fact that you need to depend on these "internals" of model wrapper means that the abstraction is not really good (yet) - it can be improved.
Agreed on `MODEL_BASENAME`, but I still think it's worthwhile to rip that dependency out in a different PR. `MethodDescriptor` already has a couple of dependencies in truss server, but now it's more explicit since we need it for type hints. Overall I agree with the sentiment here, so let's see what we can do!
If a symbol is only used for type checking, you can make the import conditional.
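For example, the standard `TYPE_CHECKING` pattern (the import path here is illustrative):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by static type checkers, never executed at runtime.
    from model_wrapper import MethodDescriptor  # illustrative path

def describe(descriptor: "MethodDescriptor") -> str:
    # The annotation is a string (forward reference), so no runtime import needed.
    return repr(descriptor)
```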
async def execution_fn(
    model: ModelWrapper, inputs: InputType, request: Request
) -> OutputType:
    self._raise_if_not_supported(
        MethodName.COMPLETIONS, model.model_descriptor.completions
    )
    return await model.completions(inputs, request)
Could you move that extra code into `model.completions` instead of wrapping a function?

Besides creating convoluted code and stack traces, this is also confusing, because you shadow `inputs` and `request`, and it might look like you capture them in the wrapped function.
We could, but it made sense to me to have `truss_server` deal with the HTTP layer and throw the 404 in the helper. It likely doesn't make sense for `ModelWrapper` code to throw an opinionated status code.

I'd vote we either throw a different exception that gets translated into a status code via truss server, or find a different way to have shared code here.
I understand your desire to separate the HTTP layer. Would throwing a non-HTTP exception and then adjusting the exception handler in `errors.py` work for this?
After the most recent refactor I feel less strongly here; we just have a couple of lines that are slightly duplicated. In the future we can consider pushing some of the error handling to model wrapper, and then doing the status code translation as you suggested.
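A sketch of that suggested pattern, using hypothetical names (the model layer raises a framework-agnostic error; only the server layer maps it to a 404):

```python
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

class MethodNotImplementedError(Exception):
    """Raised by the model layer; carries no HTTP semantics."""

def register_handlers(app: FastAPI) -> None:
    @app.exception_handler(MethodNotImplementedError)
    async def handle(request: Request, exc: MethodNotImplementedError):
        # Only the server layer decides this maps to a 404.
        return JSONResponse(status_code=404, content={"error": str(exc)})
```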
@@ -68,6 +68,7 @@ class MethodName(str, enum.Enum):
    "starlette.responses.Response",
    pydantic.BaseModel,
]
ModelFn = Callable[..., Union[OutputType, Awaitable[OutputType]]]
From this comment, we can introduce a slight type constraint, but I'm not sure how much benefit it brings:
- It seems non-trivial to have discriminated-union functionality tie the `ArgConfig` and the underlying `Callable` together (i.e. if `ArgConfig.NONE`, then the type system knows it's a `Callable[[], OutputType]`).
- Similar to the above, it's tricky to have the same discrimination on async vs sync.

To avoid all that, I made it take variable arguments for now, given that `ArgConfig` will do validation on parse. We have to do a couple of explicit casts to make the type system happy as a result, though.
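A sketch of that trade-off (a permissive variable-argument alias plus explicit casts, instead of a per-`ArgConfig` discriminated union; `OutputType` is simplified here):

```python
from typing import Any, Awaitable, Callable, Union, cast

OutputType = Any  # stand-in for the server's real union of output types
ModelFn = Callable[..., Union[OutputType, Awaitable[OutputType]]]

def completions(inputs: dict) -> str:  # a concrete model function
    return "completions"

# The alias accepts any arity, so an explicit cast keeps the checker happy.
fn: ModelFn = cast(ModelFn, completions)
```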
Force-pushed from abc11c1 to f5a22da
)

@classmethod
def _is_async(cls, method: Any):
Confirmed that `await` on a function that returns an `AsyncGenerator` is a type error, so I wonder if we should actually remove the `inspect.isasyncgenfunction(method)` clause from `is_async`.

Technically it doesn't matter, since we agreed to explicitly check for generators first, which will avoid the `await`. However, it's likely confusing for future readers why we bucketed Coroutines and AsyncGenerators together.
That reasoning makes sense to me. But it might need some comment on why we do not check `isasyncgenfunction`.
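A quick illustration of the distinction being discussed (standard-library behavior, runnable as-is):

```python
import inspect

async def gen():
    yield 1

print(inspect.iscoroutinefunction(gen))  # False
print(inspect.isasyncgenfunction(gen))   # True
# `await gen()` would raise a TypeError at runtime (and is a static type error),
# which is why generators must be detected before anything is awaited.
```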
@pytest.mark.integration
def test_postprocess_async_generator_streaming():
It feels very unlikely that a real user would ever be interested in either of the below test cases, but wanted to show that we technically support it
Force-pushed from a220749 to 66e9da5
descriptor = self.model_descriptor.completions
assert descriptor, (
    f"`{MethodName.COMPLETIONS}` must only be called if model has it."
)
These are 5 lines to be repeated for each endpoint. You could move the assertion that the descriptor is not `None` into `_process_model_fn`.
I'm considering a future refactor where the descriptor is always present with the name of the function, but has a different indicator that the underlying function wasn't implemented by the user. As written now, I'd have to explicitly pass the `MethodName` as an additional argument purely for the error message, which didn't seem worthwhile. I'll explore this in a followup!
fn_span = self._tracer.start_span(f"call-{method.value}")
with tracing.section_as_event(
    fn_span, method.value
), tracing.detach_context() as detached_ctx:
If you don't do this in this PR, could you tag all occurrences with a TODO note, please?
model = self._safe_lookup_model(MODEL_BASENAME)
self._raise_if_not_supported(
    MethodName.CHAT_COMPLETIONS, model.model_descriptor.chat_completions
)
I think all of this could be parameterized and moved into `_execute_request`.
Why do we need to pass a constant as an argument in `self._safe_lookup_model(MODEL_BASENAME)` here? This could be a property, e.g. `self._safe_model`.
You're likely right, but as of now it would result in another parameter being passed to `_execute_request`. Let me explore adding more metadata to `method_descriptor` in a followup (as well as some of the other smaller improvements mentioned on this PR).
Force-pushed from 97bab5e to a12f649
🚀 What
This PR adds `truss` server compatibility for additional OpenAI endpoints - we support `/v1/completions` and `/v1/chat/completions` for now, but it should be straightforward to add more in the future if needed.

Notes / possible next steps:
- `code_gen` models / chains
- `model` existing, even though we ignore on BE

💻 How

🔬 Testing
Tested with `0.9.60rc005` and confirmed with a sample OpenAI client script.
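For reference, a sketch of what such a client script might look like (the deployment URL, API key handling, and model name are placeholders, not taken from this PR):

```python
from openai import OpenAI

client = OpenAI(
    base_url="https://<your-deployment-host>/v1",  # placeholder deployment URL
    api_key="<your-api-key>",                      # placeholder credentials
)

# Exercise /v1/chat/completions with streaming enabled.
stream = client.chat.completions.create(
    model="model",  # placeholder; truss server currently hardcodes the model name
    messages=[{"role": "user", "content": "Say hello."}],
    stream=True,
)
for chunk in stream:
    print(chunk.choices[0].delta.content or "", end="")
```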