Inference streaming support #1750
Conversation
Force-pushed from 32c3d7a to c2cf03c
Force-pushed from fb3eed6 to 3637aef
Force-pushed from a519933 to 63af613
In general it looks great. I left some comments though, and I will look at tests next.
@@ -0,0 +1,45 @@
import asyncio
should we add an example `infer.py` that uses this model? It is part of the PR description but probably better to have it here as well. Happy for it to be part of a follow-up docs and examples PR.
mlserver/batching/hooks.py (Outdated)
    payload: AsyncIterator[InferenceRequest],
) -> AsyncIterator[InferenceResponse]:
    model = _get_model(f)
    logger.warning(
is this going to be logged on mlserver for every request? I think this might pollute the logs? I guess if the user doesn't set adaptive batching then this code path will not be hit anyway?
Moved it outside.
mlserver/handlers/dataplane.py (Outdated)
    break

payload = self._prepare_payload(payload, model)
payloads_decorated = self._payloads_decorator(payload, payloads, model)
is this really decorator logic?
I renamed it.
payload = self._prepare_payload(payload, model)
payloads_decorated = self._payloads_decorator(payload, payloads, model)

async for prediction in model.predict_stream(payloads_decorated):
what happens if one element in the stream fails? do we still keep going or should we break?
""" | ||
async for inference_response in infer_stream: | ||
# TODO: How should we send headers back? | ||
# response_headers = extract_headers(inference_response) |
what are the kinds of headers we usually send back in the response of `infer`?
See link here
comments on testing. I think we should add cases for:
- errors on infer_stream
- input streaming
@pytest.mark.parametrize(
    "sum_model_settings", [lazy_fixture("text_stream_model_settings")]
)
@pytest.mark.parametrize("sum_model", [lazy_fixture("text_stream_model")])
what is `sum_model`?
expected = pb.InferTensorContents(int64_contents=[6])

assert len(prediction.outputs) == 1
assert prediction.outputs[0].contents == expected


@pytest.mark.parametrize("settings", [lazy_fixture("settings_stream")])
@pytest.mark.parametrize(
what is `sum_model_settings`?
It is a fixture (see here)
@pytest.mark.parametrize(
    "model_name,model_version", [("text-model", "v1.2.3"), ("text-model", None)]
)
async def test_generate(
can the `generate` test be an extra parametrized item in the `infer` test?
It is a bit tricky to parameterise the model loading through `lazy_fixture` due to a recursive dependency involving fixtures. I will leave it like this since I don't want to refactor the tests.
@@ -147,15 +207,19 @@ async def test_infer_headers(
)


-async def test_infer_error(rest_client, inference_request):
+async def test_infer_error(
should we test errors for the stream case as well?
    yield generate_request


async def test_predict_stream_fallback(
as explained earlier, I am not sure if we should fall back to `predict` or raise a not-implemented error.
Force-pushed from 63af613 to 499e693
lgtm - great work! This should be followed by a docs PR to describe streaming and the current limitations more explicitly.
Force-pushed from 195bcde to 0ec59c7
This PR includes streaming support for MLServer by allowing the user to implement in the runtime the `predict_stream` method, which expects as input an async generator of requests and outputs an async generator of responses. While the input-output types for `predict` remain the same, for `predict_stream` the implementation can handle a stream of inputs and a stream of outputs. This design choice is quite general and can cover many input-output scenarios:

- unary input, unary output (covered by `predict`)
- unary input, streamed output (covered by `predict_stream`)
- streamed input, unary output (covered by `predict_stream`)
- streamed input, streamed output (covered by `predict_stream`)

Although for REST, streamed input might not be a thing and is currently not supported, for gRPC it is quite natural to have. In case a user would like to use streamed inputs, they will have to use gRPC.
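For illustration, here is a minimal sketch of a custom runtime implementing `predict_stream`. The class name, codec choices, and the word-by-word "generation" below are assumptions made for the example, not the exact model shipped with this PR:

```python
from typing import AsyncIterator

from mlserver import MLModel
from mlserver.codecs import StringCodec
from mlserver.types import InferenceRequest, InferenceResponse


class TextStreamModel(MLModel):
    """Hypothetical runtime that streams back one word at a time."""

    async def predict_stream(
        self, payloads: AsyncIterator[InferenceRequest]
    ) -> AsyncIterator[InferenceResponse]:
        # Consume the (possibly streamed) requests one by one...
        async for payload in payloads:
            text = StringCodec.decode_input(payload.inputs[0])[0]
            # ...and emit a separate response per generated chunk.
            for word in text.split():
                yield InferenceResponse(
                    model_name=self.name,
                    outputs=[StringCodec.encode_output("output", [word])],
                )
```

A regular `predict` method can still be implemented alongside it to serve the unary endpoints.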
Exposed endpoints
We expose the following endpoints (+ the ones including the version) to the user:
/v2/models/{model_name}/infer
/v2/models/{model_name}/infer_stream
/v2/models/{model_name}/generate
/v2/models/{model_name}/generate_stream
The first two are general-purpose endpoints while the latter two are LLM specific (see open inference protocol here). Note that the `infer` and `generate` endpoints will point to the `infer` implementation while `infer_stream` and `generate_stream` will point to the `infer_stream` implementation defined above.

Client calls
REST non-streaming
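As an illustration for this case, a minimal sketch of a unary REST call against `/infer`; the model name, input name, and port are assumptions:

```python
import requests

inference_request = {
    "inputs": [
        {
            "name": "prompt",  # hypothetical input name
            "shape": [1],
            "datatype": "BYTES",
            "data": ["What is the capital of France?"],
        }
    ]
}

# Assumes MLServer is listening on the default HTTP port (8080)
# and a model called "text-model" is loaded.
response = requests.post(
    "http://localhost:8080/v2/models/text-model/infer",
    json=inference_request,
)
print(response.json())
```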
REST streaming
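Similarly, a sketch of a streaming call against `/infer_stream`; the exact framing of the streamed chunks (SSE vs. newline-delimited JSON) is an assumption here, so adjust the parsing to the actual server output:

```python
import httpx

inference_request = {
    "inputs": [
        {
            "name": "prompt",  # hypothetical input name
            "shape": [1],
            "datatype": "BYTES",
            "data": ["What is the capital of France?"],
        }
    ]
}

# Stream the response body instead of waiting for a single JSON payload.
with httpx.stream(
    "POST",
    "http://localhost:8080/v2/models/text-model/infer_stream",
    json=inference_request,
    timeout=None,
) as response:
    for line in response.iter_lines():
        if line:
            print(line)
```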
gRPC non-streaming
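A sketch of a unary gRPC call using the stubs generated from `dataplane.proto`; the module paths and the default gRPC port (8081) are assumptions:

```python
import grpc

from mlserver.grpc import dataplane_pb2 as pb
from mlserver.grpc import dataplane_pb2_grpc as pb_grpc

# Build a V2 inference request with a single BYTES input.
request = pb.ModelInferRequest(
    model_name="text-model",
    inputs=[
        pb.ModelInferRequest.InferInputTensor(
            name="prompt",
            datatype="BYTES",
            shape=[1],
            contents=pb.InferTensorContents(
                bytes_contents=[b"What is the capital of France?"]
            ),
        )
    ],
)

with grpc.insecure_channel("localhost:8081") as channel:
    stub = pb_grpc.GRPCInferenceServiceStub(channel)
    response = stub.ModelInfer(request)
    print(response)
```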
gRPC streaming
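And a sketch of the bidirectional streaming variant. The RPC name used here (`ModelStreamInfer`) is an assumption based on the open inference protocol proposal, so check `dataplane.proto` in this PR for the exact name:

```python
import grpc

from mlserver.grpc import dataplane_pb2 as pb
from mlserver.grpc import dataplane_pb2_grpc as pb_grpc


def request_stream():
    # A single request is yielded here, but several requests can be
    # streamed on the same call when the runtime supports streamed inputs.
    yield pb.ModelInferRequest(
        model_name="text-model",
        inputs=[
            pb.ModelInferRequest.InferInputTensor(
                name="prompt",
                datatype="BYTES",
                shape=[1],
                contents=pb.InferTensorContents(
                    bytes_contents=[b"What is the capital of France?"]
                ),
            )
        ],
    )


with grpc.insecure_channel("localhost:8081") as channel:
    stub = pb_grpc.GRPCInferenceServiceStub(channel)
    # Iterate over the streamed responses as they arrive.
    for response in stub.ModelStreamInfer(request_stream()):
        print(response)
```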
Limitations
- gzip compression must be disabled (`"gzip_enabled": false`)
- the metrics endpoint must be disabled (`"metrics_endpoint": null`)
- parallel workers must be disabled (`"parallel_workers": 0`)
- cancelling a streaming request raises an error of `asyncio.exceptions.CancelledError` type; `CancelledError` inherits from `BaseException` and the `starlette` middleware for error handling checks for the type `Exception`, so such errors are not caught there.
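For reference, a minimal sketch of server settings matching the bullets above; the field names come from the snippets in this list, but treat the exact `Settings` construction as an assumption:

```python
from mlserver.settings import Settings

# Settings required for streaming, per the limitations above.
settings = Settings(
    gzip_enabled=False,     # no gzip middleware on the streamed responses
    metrics_endpoint=None,  # metrics endpoint disabled
    parallel_workers=0,     # run inference in the main process
)
```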