Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inference streaming support #1750

Conversation

RobertSamoilescu
Copy link
Contributor

@RobertSamoilescu RobertSamoilescu commented May 9, 2024

This PR includes streaming support for MLServer by allowing the user to implement in the runtime the predict_stream method which expects as input a async generator of request an outputs a async generator of response.

class MyModel(MLModel):

    async def predict(self, payload: InferenceRequest) -> InferenceResponse:
	    pass

    async def predict_stream(
        self, payloads: AsyncIterator[InferenceRequest]
    ) -> AsyncIterator[InferenceResponse]:
	    pass

While the input-output types for the predict remain the same, for the predict_stream the implementation can handle a stream of inputs and a stream of outputs. This design choice is quite general and can cover many input-output scenarios:

  • unary input - unary output (handled by predict)
  • unary input - stream output (handled by predict_stream)
  • stream input - unary output (handled by predict_stream)
  • stream input - stream output (handled by predict_stream )

Although for REST, streamed input might not be a thing and currently not supported, for gRPC it is quite natural to have. In the case that a user will like to use streamed inputs, then they will have to use gRPC.

Exposed endpoints

We expose the following endpoints (+ the ones including the version) to the user:

  • /v2/models/{model_name}/infer
  • /v2/models/{model_name}/infer_stream
  • /v2/models/{model_name}/generate
  • /v2/models/{model_name}/generate_stream

The first two are general purpose endpoints while the later two are LLM specific (see open inference protocol here). Note that the infer and generate endpoints will point to the infer implementation while infer_stream and generate_stream will point to infer_stream implementation defined above.

Client calls

REST non-streaming

import os
import requests
from mlserver import types
from mlserver.codecs import StringCodec

TESTDATA_PATH = "../tests/testdata/"
payload_path = os.path.join(TESTDATA_PATH, "generate-request.json")
inference_request = types.InferenceRequest.parse_file(payload_path)


api_url = "http://localhost:8080/v2/models/text-model/generate"
response = requests.post(api_url, json=inference_request.dict())
response = types.InferenceResponse.parse_raw(response.text)
print(StringCodec.decode_output(response.outputs[0]))

REST streaming

import os
import httpx
from httpx_sse import connect_sse
from mlserver import types
from mlserver.codecs import StringCodec

TESTDATA_PATH = "../tests/testdata/"
payload_path = os.path.join(TESTDATA_PATH, "generate-request.json")
inference_request = types.InferenceRequest.parse_file(payload_path)

with httpx.Client() as client:
    with connect_sse(client, "POST", "http://localhost:8080/v2/models/text-model/generate_stream", json=inference_request.dict()) as event_source:
        for sse in event_source.iter_sse():
            response = types.InferenceResponse.parse_raw(sse.data)
            print(StringCodec.decode_output(response.outputs[0]))

gRPC non-streaming

import os
import grpc
import mlserver.grpc.converters as converters
import mlserver.grpc.dataplane_pb2_grpc as dataplane
import mlserver.types as types
from mlserver.codecs import StringCodec
from mlserver.grpc.converters import ModelInferResponseConverter

TESTDATA_PATH = "../tests/testdata/"
payload_path = os.path.join(TESTDATA_PATH, "generate-request.json")
inference_request = types.InferenceRequest.parse_file(payload_path)

# need to convert from string to bytes for grpc
inference_request.inputs[0] = StringCodec.encode_input("prompt", inference_request.inputs[0].data.__root__)
inference_request_g = converters.ModelInferRequestConverter.from_types(
    inference_request, model_name="text-model", model_version=None
)
grpc_channel = grpc.insecure_channel("localhost:8081")
grpc_stub = dataplane.GRPCInferenceServiceStub(grpc_channel)
response = grpc_stub.ModelInfer(inference_request_g)

response = ModelInferResponseConverter.to_types(response)
print(StringCodec.decode_output(response.outputs[0]))

gRPC streaming

import os
import grpc
import mlserver.grpc.converters as converters
import mlserver.grpc.dataplane_pb2_grpc as dataplane
import mlserver.types as types
from mlserver.codecs import StringCodec
from mlserver.grpc.converters import ModelInferResponseConverter


TESTDATA_PATH = "../tests/testdata/"
payload_path = os.path.join(TESTDATA_PATH, "generate-request.json")
inference_request = types.InferenceRequest.parse_file(payload_path)

# need to convert from string to bytes for grpc
inference_request.inputs[0] = StringCodec.encode_input("prompt", inference_request.inputs[0].data.__root__)
inference_request_g = converters.ModelInferRequestConverter.from_types(
    inference_request, model_name="text-model", model_version=None
)

async def get_inference_request_stream(inference_request):
    yield inference_request

async with grpc.aio.insecure_channel("localhost:8081") as grpc_channel:
    grpc_stub = dataplane.GRPCInferenceServiceStub(grpc_channel)
    inference_request_stream = get_inference_request_stream(inference_request_g)
    
    async for response in grpc_stub.ModelStreamInfer(inference_request_stream):
        response = ModelInferResponseConverter.to_types(response)
        print(StringCodec.decode_output(response.outputs[0]))

Limitations

  • GZipMiddleware must be disabled since it is not compatible with starlette streaming ("gzip_enabled": false)
  • GRPC metrics endpoints must be disabled - further investigation in a following PR ("metrics_endpoint": null)
  • Parallel workers are not supported ("parallel_workers": 0)
  • Error handling for REST is not supported - this is because when the error raised is of asyncio.exceptions.CancelledError type. CancelledError inherits from BaseException and the starlette middleware for error handling checks for the type Exception.

@RobertSamoilescu RobertSamoilescu force-pushed the feature/inference-streaming-poc branch from 32c3d7a to c2cf03c Compare May 10, 2024 09:28
@RobertSamoilescu RobertSamoilescu requested review from sakoush and lc525 May 10, 2024 09:58
@RobertSamoilescu RobertSamoilescu force-pushed the feature/inference-streaming-poc branch from fb3eed6 to 3637aef Compare May 10, 2024 10:32
@RobertSamoilescu RobertSamoilescu force-pushed the feature/inference-streaming-poc branch from a519933 to 63af613 Compare May 15, 2024 13:00
Copy link
Member

@sakoush sakoush left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general it looks great, I left some comments though and I will look at tests next.

@@ -0,0 +1,45 @@
import asyncio
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we add an example infer.py that uses this model? it is part of the PR description but probably better to have it here as well. Happy for it to be part of a follow up docs and examples PR.

mlserver/batching/hooks.py Outdated Show resolved Hide resolved
payload: AsyncIterator[InferenceRequest],
) -> AsyncIterator[InferenceResponse]:
model = _get_model(f)
logger.warning(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this going to be logged on mlserver for every request? I think this might pollute the logs? I guess if the user doesnt set adaptive batching then this code path will not be hit anyway?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved it outside.

mlserver/grpc/dataplane_pb2.pyi Outdated Show resolved Hide resolved
break

payload = self._prepare_payload(payload, model)
payloads_decorated = self._payloads_decorator(payload, payloads, model)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this really a decorator logic?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I renamed it.

payload = self._prepare_payload(payload, model)
payloads_decorated = self._payloads_decorator(payload, payloads, model)

async for prediction in model.predict_stream(payloads_decorated):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what happens if one element in the stream fails? do we still keep going or should we break?

mlserver/model.py Show resolved Hide resolved
"""
async for inference_response in infer_stream:
# TODO: How should we send headers back?
# response_headers = extract_headers(inference_response)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what are the kind of headers we usually send back in the response of infer?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See link here

proto/dataplane.proto Outdated Show resolved Hide resolved
Copy link
Member

@sakoush sakoush left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

comments on testing. I think we should add cases for:

  • errors on infer_stream
  • input streaming

tests/fixtures.py Show resolved Hide resolved
@pytest.mark.parametrize(
"sum_model_settings", [lazy_fixture("text_stream_model_settings")]
)
@pytest.mark.parametrize("sum_model", [lazy_fixture("text_stream_model")])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is sum_model?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a fixture (see here). Also the definition is here.

expected = pb.InferTensorContents(int64_contents=[6])

assert len(prediction.outputs) == 1
assert prediction.outputs[0].contents == expected


@pytest.mark.parametrize("settings", [lazy_fixture("settings_stream")])
@pytest.mark.parametrize(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is sum_model_settings?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a fixture (see here)

tests/rest/test_endpoints.py Outdated Show resolved Hide resolved
tests/rest/test_endpoints.py Outdated Show resolved Hide resolved
tests/rest/test_endpoints.py Show resolved Hide resolved
@pytest.mark.parametrize(
"model_name,model_version", [("text-model", "v1.2.3"), ("text-model", None)]
)
async def test_generate(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can generate test be an extra parametrized item in infer test?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a bit tricky to parameterise the model loading through lazy_fixtures due to recursive dependency involving fixture. I will leave it like this since I don't want to refactor the tests.

@@ -147,15 +207,19 @@ async def test_infer_headers(
)


async def test_infer_error(rest_client, inference_request):
async def test_infer_error(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we tests errors for the stream case as well?

yield generate_request


async def test_predict_stream_fallback(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as explained earlier I am not sure if we should fallback to predict or raise not implemented.

@RobertSamoilescu RobertSamoilescu force-pushed the feature/inference-streaming-poc branch from 63af613 to 499e693 Compare May 20, 2024 13:53
Copy link
Member

@sakoush sakoush left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm - great work! This should be followed by a docs PR to describe streaming and the current limitations more explicitly.

@RobertSamoilescu RobertSamoilescu force-pushed the feature/inference-streaming-poc branch from 195bcde to 0ec59c7 Compare May 22, 2024 09:02
@RobertSamoilescu RobertSamoilescu merged commit 54cd47e into SeldonIO:master May 22, 2024
25 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants