
Introduce truss server passthrough for OpenAI methods #1364

Merged: 8 commits into main from nikhil/introduce-openai-methods on Feb 6, 2025

Conversation

nnarayen (Contributor, Author) commented Feb 3, 2025

🚀 What

This PR adds Truss server compatibility for additional OpenAI endpoints: we support /v1/completions and /v1/chat/completions for now, and it should be straightforward to add more in the future if needed.

Notes / possible next steps:

  • Determine how we can expose this functionality easily via code_gen models / chains
  • It's unfortunate that the client library validates the model field, even though we ignore it on the backend (BE)

💻 How

🔬 Testing

  • Additional e2e tests
  • Deployed model to staging with context builder 0.9.60rc005 and confirmed with a sample OpenAI client script
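For reference, a minimal sketch of the kind of OpenAI client script used for that confirmation, assuming a deployment that exposes /v1/chat/completions (the base URL, API key, and model value below are placeholders, not real deployment values):

```python
from openai import OpenAI

client = OpenAI(
    base_url="https://example-model.api.baseten.co/v1",  # placeholder URL
    api_key="YOUR_API_KEY",  # placeholder key
)

# The client library requires a `model` value even though the truss server
# ignores it (see the note in the PR description above).
response = client.chat.completions.create(
    model="ignored",
    messages=[{"role": "user", "content": "Hello!"}],
)
print(response.choices[0].message.content)
```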

return None

@classmethod
def from_model(cls, model) -> "ModelDescriptor":
nnarayen (Author):

Changes to this method are intended to be no-ops except for the addition of completions / chats on L271/272. Everything else is a refactor

Resolved review thread on truss/templates/server/model_wrapper.py.
return await model_fn(*args)
return await to_thread.run_sync(model_fn, *args)

async def _trace_and_process_model_fn(
nnarayen (Author):

All future OpenAI-compatible endpoints (and honestly anything else we want to expose via the Truss server) will likely look very similar to these. predict is more opinionated since it orchestrates the pre -> predict -> post flow, and I didn't want to make changes there in this PR.

A dedicated followup might be able to reuse this type of helper in that flow, but it'll be easier to review in isolation.

@@ -617,7 +624,7 @@ async def _stream_with_background_task(
generator: Union[Generator[bytes, None, None], AsyncGenerator[bytes, None]],
span: trace.Span,
trace_ctx: trace.Context,
release_and_end: Callable[[], None],
release_and_end: Callable[[], None] = lambda: None,
nnarayen (Author):

We intentionally don't want to reuse the predict semaphore for these endpoints, and we want to encourage users to build concurrency controls into their model code if they need them.

A great future goal could be a @truss.concurrency(max_requests=2) decorator or something similar.
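In the meantime, a minimal sketch of user-side concurrency control inside a model (assuming an async chat_completions method; the limit of 2 and the _run_inference helper are illustrative, not existing truss APIs):

```python
import asyncio
from typing import Any, Dict

class Model:
    def __init__(self, **kwargs: Any) -> None:
        # Hypothetical user-managed limit; truss does not gate these endpoints.
        self._semaphore = asyncio.Semaphore(2)

    async def chat_completions(self, inputs: Dict) -> Dict:
        # At most two requests run inference concurrently; others wait here.
        async with self._semaphore:
            return await self._run_inference(inputs)

    async def _run_inference(self, inputs: Dict) -> Dict:
        ...  # placeholder for the actual model call
```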

Reviewer (Contributor):

Can you explicitly pass the argument where you want deviating behavior instead of setting a default?

I don't see a different/changed call-site of _write_response_to_queue anyway in this PR - is this something to come?

return await model.chat_completions(inputs, request)

return await self._execute_request(
model_name=MODEL_BASENAME,
nnarayen (Author):

I noticed that the current flow hardcodes the model name in beefeater and the truss server. I poked around and it seems to have roots in KServe, which expects a different serving model, but for now I figured it was OK to hardcode.

OpenAI clients are opinionated about the URL they hit, so we'd have to add some mapping logic in beefeater if we wanted to preserve this URL param.

marius-baseten (Contributor), Feb 4, 2025:

I think the whole ability to "name" a model is a relic from ancient times. Right now there is a 1:1 relation between a truss server deployment and a model (and the model "name" is completely irrelevant).

Can we nuke this whole "feature" and simplify code?

nnarayen (Author):

I'm sure we could! I think it might actually make this PR more complicated though, so I'm happy to take it on as a pure refactor followup if that works. For now, hardcoding this name so that the other code paths work as-is seems like the simplest path forward.

nnarayen changed the title from "Nikhil/introduce openai methods" to "Introduce truss server passthrough for OpenAI methods" on Feb 3, 2025
nnarayen force-pushed the nikhil/introduce-openai-methods branch 4 times, most recently from 255b9fa to 649806d on February 3, 2025 23:59
nnarayen force-pushed the nikhil/introduce-openai-methods branch 2 times, most recently from 0f6d468 to 21a88eb on February 4, 2025 00:04
pass

def chat_completions(self, input: Dict) -> str:
return "chat_completions"
Reviewer (Contributor):

Could you e.g. use non-streaming OpenAI completions and test OpenAI compatibility with their client directly?

Reviewer (Contributor):

Why is this in truss-chains?

https://github.com/basetenlabs/truss/blob/main/truss/tests/test_model_inference.py

Also, you can use `with _temp_truss` there so you don't need to add a lot of mini-files.

nnarayen (Author), Feb 4, 2025:

Good suggestion! It makes more sense there; I hadn't seen this file before.

Resolved review thread on truss/templates/server/model_wrapper.py.

@classmethod
def _safe_extract_descriptor(
cls, model: Any, method: str
Reviewer (Contributor):

Nit: method -> method_name.


def _gen_truss_schema(
cls,
model: Any,
Reviewer (Contributor):

Nit: it was already like this, but it might be better to make clear this is a class, not an instance, so model -> model_cls or similar.

@@ -49,6 +49,17 @@
TRT_LLM_EXTENSION_NAME = "trt_llm"
POLL_FOR_ENVIRONMENT_UPDATES_TIMEOUT_SECS = 30


class ModelMethod(Enum):
Reviewer (Contributor):

Better name ModelMethodKind.

descriptor, inputs, request, self._model.completions
)

return await self._trace_and_process_model_fn(
Reviewer (Contributor):

So we wrap self._model.completions into a nested function def exec_fn which calls _execute_async_model_fn, and then we also call _trace_and_process_model_fn on that.

I haven't fully wrapped my head around all the nuances here, but is this really necessary? Is there a more concise way? What would the stack traces look like if an exception is raised in self._model.completions - do we need to update any logging / stack-filtering logic for that?

nnarayen (Author):

Mentioned in #1364 (comment), but I reduced one layer of wrapping and I think you're right that it's cleaner!

async def _execute_request(
self,
model_name: str,
method: ModelMethod,
Reviewer (Contributor):

Nit: update naming to be exact.

self, model_name: str, request: Request, body_raw: bytes = Depends(parse_body)
async def _execute_request(
self,
model_name: str,
Reviewer (Contributor):

Consider nuking this whole model_name parameterization; it doesn't seem to be needed for anything anymore. But ask in the truss/core-product channel if someone knows a reason to keep it.

async def chat_completions(
self, request: Request, body_raw: bytes = Depends(parse_body)
) -> Response:
async def execution_fn(
Reviewer (Contributor):

We have all this method wrapping in ModelWrapper (see comment there) - I'm confused why there is even more here now.

nnarayen (Author):

Yeah, good question. I was also trying to trace the different responsibilities. TrussServer has lots of common logic at the HTTP layer: (1) parse the request body, (2) start the root span, (3) execute model code, and (4) generate the HTTP response. Everything except (3) is the same across predict / completions / chat completions, which is why I introduced the thin execution_fn wrapping layer.

Very similarly, ModelWrapper has a lot of shared logic for how it executes things: it starts spans, decides how to run code in threads, and then deals with return values / generators.

Inside ModelWrapper we can reduce one layer of wrapping by passing in more parameters; let me know if you like that more.

nnarayen force-pushed the nikhil/introduce-openai-methods branch 3 times, most recently from d70b857 to 719d760 on February 4, 2025 21:09
Resolved review thread on truss/templates/server/model_wrapper.py.
Comment on lines 584 to 588
descriptor,
inputs,
request,
self._model.preprocess,
supports_generators=False,
Reviewer (Contributor):

Now that you've introduced this generic _execute_async_model_fn, it would make sense to evolve MethodDescriptor accordingly: (descriptor, preprocess, supports_generators) all conceptually refer to the same thing.

This function signature is overly long, the static properties of the model and the dynamic inputs (inputs, request) are interleaved without an order, and most importantly: there is no built-in coherence for the static properties.

My suggestion is to update MethodDescriptor so that it works harmonically with _execute_async_model_fn via a concise API, by bundling (descriptor, preprocess, supports_generators) into it.

And you could also move the (repetitive) assertion into _execute_async_model_fn.

Reviewer (Contributor):

This is a pretty general pattern for better interfaces and higher cohesion, try to apply this wherever possible :)
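A rough sketch of what that bundling could look like (the field names and the _execute_async_model_fn signature below are assumptions about one possible shape, not the actual truss code):

```python
import dataclasses
from typing import Any, Callable

@dataclasses.dataclass(frozen=True)
class MethodDescriptor:
    method_name: str
    model_fn: Callable[..., Any]
    supports_generators: bool = True

# _execute_async_model_fn would then take only the descriptor plus the
# dynamic per-request values, e.g.:
#
#     async def _execute_async_model_fn(self, descriptor, inputs, request): ...
#
# instead of threading descriptor, model_fn, and supports_generators through
# the call separately.
```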

Comment on lines +720 to +728
if inspect.isgenerator(result) or inspect.isasyncgen(result):
if request.headers.get("accept") == "application/json":
return await _gather_generator(result)
else:
return await self._stream_with_background_task(
result,
fn_span,
detached_ctx,
# No semaphores needed for non-predict model functions.
release_and_end=lambda: None,
)
Reviewer (Contributor):

Are we OK with _stream_with_background_task being entered for any model_fn?
In particular, what if pre-process, predict, and post-process are all generators (or is this not allowed anyway)?

Reviewer (Contributor):

Maybe it's not clear that _trace_and_process_model_fn is only intended for some methods - what does "model_fn" mean? Can you constrain it with assertions?

nnarayen (Author):

All really good questions. I'll leave extra comments in the code but will try to clarify here as well. With this PR, I propose two ways of invoking code in ModelWrapper: (1) __call__, which is extremely specific to predict orchestration and has lots of validation on specific combinations, and (2) _trace_and_process_model_fn, ideally intended for any other standalone model function that doesn't need orchestration.

Therefore, I think _trace_and_process_model_fn should be as generic as possible for now, and support the cross product of (async, sync) x (generator, non-generator). For our use case now, both completions / chat completions should be allowed to be any of the above. If we find a use case where one of those combinations is invalid, we should add more metadata to MethodDescriptor and check that here.

Reviewer (Contributor):

OK, understood. Maybe add a one-liner docstring for the intended usage of _trace_and_process_model_fn (or _process_model_fn if you change the name).

In that case we could maybe also rename __call__ to a more specific name.
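For illustration, a minimal sketch of the kind of dispatch such a generic helper has to cover, assuming anyio is available (the run_model_fn name is a stand-in, not the actual ModelWrapper method):

```python
import inspect
from anyio import to_thread

async def run_model_fn(model_fn, *args):
    # Sync and async generator functions are returned un-awaited so the
    # caller can decide whether to gather the chunks or stream them.
    if inspect.isgeneratorfunction(model_fn) or inspect.isasyncgenfunction(model_fn):
        return model_fn(*args)
    # Plain coroutine functions are awaited directly.
    if inspect.iscoroutinefunction(model_fn):
        return await model_fn(*args)
    # Plain sync functions run in a worker thread so they don't block the loop.
    return await to_thread.run_sync(model_fn, *args)
```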

request: starlette.requests.Request,
method_name: MethodName,
descriptor: MethodDescriptor,
model_fn: Any,
Reviewer (Contributor):

Can you create a suitable type alias for the callable?

Comment on lines 21 to 26
MODEL_BASENAME,
InputType,
MethodDescriptor,
MethodName,
ModelWrapper,
OutputType,
marius-baseten (Contributor), Feb 5, 2025:

Related to the above: the fact that you need to depend on these "internals" of ModelWrapper means that the abstraction is not really good (yet) and can be improved.

nnarayen (Author):

Agreed on MODEL_BASENAME, but I still think it's worthwhile to rip that dependency out in a different PR. MethodDescriptor already has a couple of dependencies in the truss server, but now it's more explicit since we need it for type hints.

Overall I agree with the sentiment here, so let's see what we can do!

Reviewer (Contributor):

If a symbol is only used for type checking, you can make the import conditional.
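A minimal sketch of the suggested pattern (the imported name and module path here are illustrative):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by type checkers; avoids a runtime dependency on
    # model_wrapper internals when the symbol is used purely in annotations.
    from model_wrapper import MethodDescriptor

def handle(descriptor: "MethodDescriptor") -> None:
    ...
```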

Comment on lines 234 to 240
async def execution_fn(
model: ModelWrapper, inputs: InputType, request: Request
) -> OutputType:
self._raise_if_not_supported(
MethodName.COMPLETIONS, model.model_descriptor.completions
)
return await model.completions(inputs, request)
Reviewer (Contributor):

Could you move that extra code into model.completions instead of wrapping a function?

Besides creating convoluted code and stack traces, this is also confusing, because you shadow inputs and request and it might look like you capture them in the wrapped function.

nnarayen (Author):

We could, but it made sense to me to have truss_server deal with the HTTP layer and throw the 404 in the helper. It likely doesn't make sense for ModelWrapper code to throw an opinionated status code.

I'd vote we either throw a different exception that gets translated into a status code by the truss server, or find a different way to share code here.

marius-baseten (Contributor), Feb 6, 2025:

I understand your desire to separate the HTTP layer. Would throwing a non-HTTP exception and then adjusting the exception handler in errors.py work for this?

nnarayen (Author):

After the most recent refactor I feel less strongly here; we just have a couple of lines that are slightly duplicated. In the future we can consider pushing some of the error handling into ModelWrapper and then doing the status code translation as you suggested.
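A sketch of the non-HTTP-exception approach discussed above, assuming a FastAPI app; MethodNotImplementedError and the handler below are hypothetical names, not what errors.py currently defines:

```python
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

class MethodNotImplementedError(Exception):
    """Raised by ModelWrapper when the user's model lacks a given method."""

app = FastAPI()

@app.exception_handler(MethodNotImplementedError)
async def method_not_implemented_handler(
    request: Request, exc: MethodNotImplementedError
) -> JSONResponse:
    # The HTTP-specific status code stays at the server layer, while
    # ModelWrapper only raises a framework-agnostic exception.
    return JSONResponse(status_code=404, content={"error": str(exc)})
```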

@@ -68,6 +68,7 @@ class MethodName(str, enum.Enum):
"starlette.responses.Response",
pydantic.BaseModel,
]
ModelFn = Callable[..., Union[OutputType, Awaitable[OutputType]]]
nnarayen (Author):

From this comment, we can introduce a slight type constraint, but I'm not sure how much benefit it brings:

  • It seems non-trivial to get discriminated-union behavior that ties the ArgConfig and the underlying Callable together (i.e. if ArgConfig.NONE, then the type system knows it's a Callable[[], OutputType])
  • Similar to the above, it's tricky to have the same discrimination for async vs. sync

To avoid all that, I made it take variable arguments for now, given that ArgConfig will do validation on parse. We do have to do a couple of explicit casts to make the type system happy as a result, though.
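As a purely illustrative sketch of the discriminated-union idea (using string literals as a stand-in for ArgConfig and a hypothetical check_fn; not the real types), the overload-based version might look like this, and it still leaves the async vs. sync axis unhandled:

```python
from typing import Any, Callable, Literal, overload

@overload
def check_fn(arg_config: Literal["none"], fn: Callable[[], Any]) -> None: ...
@overload
def check_fn(arg_config: Literal["inputs"], fn: Callable[[Any], Any]) -> None: ...
def check_fn(arg_config: str, fn: Callable[..., Any]) -> None:
    # Runtime validation (as ArgConfig does on parse) still carries the load;
    # the overloads only help static checkers pair the config with a signature.
    ...
```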

nnarayen force-pushed the nikhil/introduce-openai-methods branch from abc11c1 to f5a22da on February 5, 2025 22:10
)

@classmethod
def _is_async(cls, method: Any):
nnarayen (Author):

Confirmed that await on a function that returns an AsyncGenerator is a type error, so I wonder if we should actually remove the inspect.isasyncgenfunction(method) clause from is_async.

Technically it doesn't matter, since we agreed to explicitly check for generators first, which avoids the await. However, it's likely confusing for future readers why we bucketed Coroutines and AsyncGenerators together.

marius-baseten (Contributor), Feb 6, 2025:

That reasoning makes sense to me. But it might need a comment on why we do not check isasyncgenfunction.
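For reference, a small sketch of why bucketing the two together can confuse readers (this is standard-library behavior, not truss code):

```python
import inspect

async def coro_fn():
    return 1

async def agen_fn():
    yield 1

# Coroutine functions and async generator functions are distinct to inspect:
assert inspect.iscoroutinefunction(coro_fn) is True
assert inspect.iscoroutinefunction(agen_fn) is False
assert inspect.isasyncgenfunction(agen_fn) is True

# Awaiting the result of an async generator function raises:
#     await agen_fn()  # TypeError: object async_generator can't be used
#                      # in 'await' expression
# so treating agen_fn as "async" is only safe if generators are checked first.
```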



@pytest.mark.integration
def test_postprocess_async_generator_streaming():
nnarayen (Author), Feb 5, 2025:

It feels very unlikely that a real user would ever be interested in either of the below test cases, but I wanted to show that we technically support them.

nnarayen force-pushed the nikhil/introduce-openai-methods branch 3 times, most recently from a220749 to 66e9da5 on February 6, 2025 14:53
Comment on lines +734 to +739
descriptor = self.model_descriptor.completions
assert descriptor, (
f"`{MethodName.COMPLETIONS}` must only be called if model has it."
)

Reviewer (Contributor):

These are 5 lines that would be repeated for each endpoint. You could move the assertion that the descriptor is not None into _process_model_fn.

nnarayen (Author):

I'm considering a future refactor where the descriptor is always present with the name of the function, but has a different indicator that the underlying function wasn't implemented by the user. As written now, I'd have to explicitly pass the MethodName as an additional argument purely for the error message, which didn't seem worthwhile. I'll explore this in a followup!

fn_span = self._tracer.start_span(f"call-{method.value}")
with tracing.section_as_event(
fn_span, method.value
), tracing.detach_context() as detached_ctx:
Reviewer (Contributor):

If you don't do this in this PR, could you tag all occurrences with a TODO note, please?

Comment on lines 201 to 205
model = self._safe_lookup_model(MODEL_BASENAME)
self._raise_if_not_supported(
MethodName.CHAT_COMPLETIONS, model.model_descriptor.chat_completions
)

Reviewer (Contributor):

I think all of this could be parameterized and moved into _execute_request.

Reviewer (Contributor):

Why do we need to pass a constant as an argument in self._safe_lookup_model(MODEL_BASENAME) here?

This could be a property, self._safe_model.

nnarayen (Author):

You're likely right, but as of now that would result in another parameter being passed to _execute_request. Let me explore adding more metadata to method_descriptor in a followup (as well as some of the other smaller improvements mentioned on this PR).

nnarayen force-pushed the nikhil/introduce-openai-methods branch from 97bab5e to a12f649 on February 6, 2025 19:02
nnarayen merged commit 87c83d9 into main on Feb 6, 2025 (5 checks passed)
nnarayen deleted the nikhil/introduce-openai-methods branch on February 6, 2025 19:17
3 participants