From 30261fef4b4e65c0b2e746777985a56dd240e181 Mon Sep 17 00:00:00 2001 From: Keming Date: Thu, 3 Aug 2023 20:11:40 +0800 Subject: [PATCH] feat: support multi-route with shared workers (#423) * fix dockerfile path Signed-off-by: Keming * refactor py args to rs Signed-off-by: Keming * finish the basic multiroute rust part Signed-off-by: Keming * fix the openapi Signed-off-by: Keming * add protocol state Signed-off-by: Keming * add runtime register Signed-off-by: Keming * Apply suggestions from code review Co-authored-by: zclzc <38581401+lkevinzc@users.noreply.github.com> Signed-off-by: Keming * fix a deadlock Signed-off-by: Keming * add test Signed-off-by: Keming * add doc Signed-off-by: Keming * bump version Signed-off-by: Keming * combine ingress & egress to state enum Signed-off-by: Keming --------- Signed-off-by: Keming Co-authored-by: zclzc <38581401+lkevinzc@users.noreply.github.com> --- .github/workflows/check.yml | 2 +- Cargo.lock | 59 ++---- Cargo.toml | 15 +- Dockerfile | 9 +- Makefile | 3 + README.md | 5 +- docs/source/examples/index.md | 1 + docs/source/examples/multi_route.md | 33 ++++ docs/source/reference/concept.md | 4 +- docs/source/reference/interface.md | 20 +-- examples/multi_route/client.py | 40 +++++ examples/multi_route/server.py | 77 ++++++++ examples/server_side_event/client.py | 4 +- mosec/__init__.py | 2 + mosec/args.py | 13 +- mosec/coordinator.py | 148 ++++++--------- mosec/dry_run.py | 117 ++++++------ mosec/ipc.py | 53 ------ mosec/plugins/__init__.py | 19 -- mosec/plugins/plasma_shm.py | 86 --------- mosec/protocol.py | 24 ++- mosec/runtime.py | 155 ++++++---------- mosec/server.py | 200 ++++++++++++--------- mosec/worker.py | 1 + src/apidoc.rs | 65 +++---- src/args.rs | 72 -------- src/config.rs | 98 ++++++++++ src/coordinator.rs | 121 ------------- src/main.rs | 82 +++++---- src/protocol.rs | 83 +++++---- src/routes.rs | 28 +-- src/tasks.rs | 248 +++++++++++++++++++------- tests/services/__init__.py | 13 ++ tests/services/multi_route_service.py | 79 ++++++++ tests/services/openapi_service.py | 8 +- tests/services/square_service.py | 2 +- tests/test_coordinator.py | 39 ++-- tests/test_protocol.py | 16 +- tests/test_service.py | 39 +++- tests/utils.py | 24 ++- 40 files changed, 1110 insertions(+), 997 deletions(-) create mode 100644 docs/source/examples/multi_route.md create mode 100644 examples/multi_route/client.py create mode 100644 examples/multi_route/server.py delete mode 100644 mosec/ipc.py delete mode 100644 mosec/plugins/__init__.py delete mode 100644 mosec/plugins/plasma_shm.py delete mode 100644 src/args.rs create mode 100644 src/config.rs delete mode 100644 src/coordinator.rs create mode 100644 tests/services/multi_route_service.py diff --git a/.github/workflows/check.yml b/.github/workflows/check.yml index e80f86e4..6b71d621 100644 --- a/.github/workflows/check.yml +++ b/.github/workflows/check.yml @@ -79,7 +79,7 @@ jobs: - name: Test run: | make semantic_lint test - - name: Test pyarrow in Linux + - name: Test shm in Linux if: ${{ startsWith(matrix.os, 'ubuntu') }} run: | sudo apt update && sudo apt install redis diff --git a/Cargo.lock b/Cargo.lock index 6b8361f0..1125bbd1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -26,39 +26,11 @@ dependencies = [ "memchr", ] -[[package]] -name = "argh" -version = "0.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab257697eb9496bf75526f0217b5ed64636a9cfafa78b8365c71bd283fcef93e" -dependencies = [ - "argh_derive", - "argh_shared", -] - -[[package]] -name = "argh_derive" 
-version = "0.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b382dbd3288e053331f03399e1db106c9fb0d8562ad62cb04859ae926f324fa6" -dependencies = [ - "argh_shared", - "proc-macro2", - "quote", - "syn 1.0.109", -] - -[[package]] -name = "argh_shared" -version = "0.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64cb94155d965e3d37ffbbe7cc5b82c3dd79dd33bd48e536f73d2cfb8d85506f" - [[package]] name = "async-channel" -version = "1.8.0" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf46fee83e5ccffc220104713af3292ff9bc7c64c7de289f66dae8e38d826833" +checksum = "81953c529336010edd6d8e358f886d9581267795c61b19475b71314bffa46d35" dependencies = [ "concurrent-queue", "event-listener", @@ -308,6 +280,12 @@ version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "65d09067bfacaa79114679b279d7f5885b53295b1e2cfb4e79c8e4bd3d633169" +[[package]] +name = "equivalent" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" + [[package]] name = "event-listener" version = "2.5.3" @@ -401,9 +379,9 @@ checksum = "b6c80984affa11d98d1b88b66ac8853f143217b399d3c74116778ff8fdb4ed2e" [[package]] name = "hashbrown" -version = "0.12.3" +version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" +checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a" [[package]] name = "hermit-abi" @@ -470,11 +448,11 @@ dependencies = [ [[package]] name = "indexmap" -version = "1.9.3" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" +checksum = "d5477fe2230a79769d8dc68e0eabf5437907c0457a5614a9e8dddb67f65eb65d" dependencies = [ - "autocfg", + "equivalent", "hashbrown", "serde", ] @@ -563,9 +541,8 @@ dependencies = [ [[package]] name = "mosec" -version = "0.7.2" +version = "0.8.0" dependencies = [ - "argh", "async-channel", "async-stream", "axum", @@ -1241,9 +1218,9 @@ checksum = "b15811caf2415fb889178633e7724bad2509101cde276048e013b9def5e51fa0" [[package]] name = "utoipa" -version = "3.3.0" +version = "3.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68ae74ef183fae36d650f063ae7bde1cacbe1cd7e72b617cbe1e985551878b98" +checksum = "8c624186f22e625eb8faa777cb33d34cd595aa16d1742aa1d8b6cf35d3e4dda9" dependencies = [ "indexmap", "serde", @@ -1253,9 +1230,9 @@ dependencies = [ [[package]] name = "utoipa-gen" -version = "3.3.0" +version = "3.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ea8ac818da7e746a63285594cce8a96f5e00ee31994e655bd827569cb8b137b" +checksum = "b9ce5f21ca77e010f5283fa791c6ab892c68b3668a1bdc6b7ac6cf978f5d5b30" dependencies = [ "proc-macro-error", "proc-macro2", diff --git a/Cargo.toml b/Cargo.toml index 1894e13c..fcfa59f7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "mosec" -version = "0.7.2" +version = "0.8.0" authors = ["Keming ", "Zichen "] edition = "2021" license = "Apache-2.0" @@ -19,13 +19,12 @@ tracing-subscriber = { version = "0.3", features = ["local-time", "json"] } tokio = { version = "1", features = ["rt", "rt-multi-thread", "time", "macros", "sync", "signal", "io-util"] } derive_more = { version = "0.99", features = 
["display", "error", "from"] } # MPMS that only one consumer sees each message & async -async-channel = { version = "1" } +async-channel = "1.9" once_cell = "1.18" -prometheus-client = "0.21.1" -argh = "0.1" -axum = "0.6.18" +prometheus-client = "0.21" +axum = "0.6" async-stream = "0.3.5" -utoipa = "3.3.0" -serde_json = "1.0.96" -serde = "1.0.163" +utoipa = "3.4" +serde_json = "1.0" +serde = "1.0" utoipa-swagger-ui = { version = "3", features = ["axum"] } diff --git a/Dockerfile b/Dockerfile index fb7af4b5..901a7638 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,9 +3,7 @@ ARG base=nvidia/cuda:11.6.2-cudnn8-runtime-ubuntu20.04 FROM ${base} ENV DEBIAN_FRONTEND=noninteractive LANG=en_US.UTF-8 LC_ALL=en_US.UTF-8 - -ARG MOSEC_PORT=8000 -ENV MOSEC_PORT=${MOSEC_PORT} +ENV PATH /opt/conda/bin:$PATH ARG CONDA_VERSION=py310_23.3.1-0 @@ -13,7 +11,6 @@ RUN apt update && \ apt install -y --no-install-recommends \ wget \ git \ - build-essential \ ca-certificates && \ rm -rf /var/lib/apt/lists/* @@ -45,9 +42,7 @@ RUN set -x && \ find /opt/conda/ -follow -type f -name '*.js.map' -delete && \ /opt/conda/bin/conda clean -afy -RUN /opt/conda/bin/conda create -n mosec python=3.10 - -ENV PYTHON_PREFIX=/opt/conda/envs/mosec/bin +ENV PYTHON_PREFIX=/opt/conda/bin RUN update-alternatives --install /usr/bin/python python ${PYTHON_PREFIX}/python 1 && \ update-alternatives --install /usr/bin/python3 python3 ${PYTHON_PREFIX}/python3 1 && \ diff --git a/Makefile b/Makefile index 8d0a8489..59ab1de4 100644 --- a/Makefile +++ b/Makefile @@ -32,6 +32,9 @@ test_all: dev pytest tests -vv -s RUST_BACKTRACE=1 cargo test -vv +test_chaos: dev + @python -m tests.bad_req + doc: @cd docs && make html && cd ../ @python -m http.server -d docs/build/html 7291 -b 127.0.0.1 diff --git a/README.md b/README.md index 1b582b89..02babe5a 100644 --- a/README.md +++ b/README.md @@ -152,6 +152,8 @@ Then let's start the server with debug logs: python examples/stable_diffusion/server.py --log-level debug --timeout 30000 ``` +Open `http://127.0.0.1:8000/openapi/swagger/` in your browser to get the OpenAPI doc. + And in another terminal, test it: ```shell @@ -173,8 +175,9 @@ That's it! You have just hosted your **_stable-diffusion model_** as a service! More ready-to-use examples can be found in the [Example](https://mosecorg.github.io/mosec/examples/index.html) section. It includes: -- [Multi-stage workflow demo](https://mosecorg.github.io/mosec/examples/echo.html): a simple echo demo even without any ML model. +- [Pipeline](https://mosecorg.github.io/mosec/examples/echo.html): a simple echo demo even without any ML model. - [Request validation](https://mosecorg.github.io/mosec/examples/validate.html): validate the request with type annotation. +- [Multiple route](https://mosecorg.github.io/mosec/examples/multi_route.html): serve multiple models in one service - [Shared memory IPC](https://mosecorg.github.io/mosec/examples/ipc.html): inter-process communication with shared memory. - [Customized GPU allocation](https://mosecorg.github.io/mosec/examples/env.html): deploy multiple replicas, each using different GPUs. - [Customized metrics](https://mosecorg.github.io/mosec/examples/metric.html): record your own metrics for monitoring. 
diff --git a/docs/source/examples/index.md b/docs/source/examples/index.md
index 828a6814..b2e18475 100644
--- a/docs/source/examples/index.md
+++ b/docs/source/examples/index.md
@@ -13,6 +13,7 @@ metric
 pytorch
 stable_diffusion
 validate
+multi_route
 ```
 
 We provide examples across different ML frameworks and for various tasks in this section.
diff --git a/docs/source/examples/multi_route.md b/docs/source/examples/multi_route.md
new file mode 100644
index 00000000..11e17efc
--- /dev/null
+++ b/docs/source/examples/multi_route.md
@@ -0,0 +1,33 @@
+# Multi-Route
+
+This example shows how to use the multi-route feature.
+
+You will need this feature if you want to:
+
+- Serve multiple models in one service on different endpoints.
+  - e.g. register `/embedding` & `/classify` with different models
+- Serve one model on multiple different endpoints in one service.
+  - e.g. register LLaMA with `/inference` and `/v1/chat/completions` to make it compatible with the OpenAI API
+- Share a worker across different routes.
+  - The shared worker will collect the dynamic batch from multiple previous stages.
+  - If you don't want to share the worker, you can declare multiple `Runtime` instances with the same worker class.
+
+The worker definition part is the same as for a single route. The only difference is how you register the worker with the server.
+
+Here we introduce a new [concept](../reference/concept.md) called [`Runtime`](mosec.runtime.Runtime).
+
+You can create `Runtime` instances and register them on the server with a `{endpoint: [Runtime]}` dictionary.
+
+See the complete demo code below.
+
+## Server
+
+```{include} ../../../examples/multi_route/server.py
+:code: python
+```
+
+## Client
+
+```{include} ../../../examples/multi_route/client.py
+:code: python
+```
diff --git a/docs/source/reference/concept.md b/docs/source/reference/concept.md
index 6c461437..cb622902 100644
--- a/docs/source/reference/concept.md
+++ b/docs/source/reference/concept.md
@@ -4,13 +4,15 @@ There are a few terms used in `mosec`.
 - `worker`: a Python process that executes the `forward` method (inherit from [`mosec.Worker`](mosec.worker.Worker))
 - `stage`: one processing unit in the pipeline, each stage contains several `worker` replicas
+  - also known as [`Runtime`](mosec.runtime.Runtime) in the code
   - each stage retrieves the data from the previous stage and passes the result to the next stage
   - retrieved data will be deserialized by the [`Worker.deserialize_ipc`](mosec.worker.Worker.deserialize_ipc) method
   - data to be passed will be serialized by the [`Worker.serialize_ipc`](mosec.worker.Worker.serialize_ipc) method
 - `ingress/egress`: the first/last stage in the pipeline
   - ingress gets data from the client, while egress sends data to the client
   - data will be deserialized by the ingress [`Worker.deserialize`](mosec.worker.Worker.deserialize) method and serialized by the egress [`Worker.serialize`](mosec.worker.Worker.serialize) method
-- `pipeline`: a chain of processing stages
+- `pipeline`: a chain of processing stages, which will be registered to an endpoint (default: `/inference`)
+  - a server can have multiple pipelines, check the [multi-route](../examples/multi_route.md) example
 - `dynamic batching`: batch requests until either the max batch size or the max wait time is reached
 - `controller`: a Rust tokio thread that works on:
   - read from the previous queue to get new tasks
diff --git a/docs/source/reference/interface.md b/docs/source/reference/interface.md
index d1535452..bc797cfa 100644
--- a/docs/source/reference/interface.md
+++ b/docs/source/reference/interface.md
@@ -14,31 +14,25 @@
     :members:
 ```
 
-## Errors
+## Runtime
 
 ```{eval-rst}
-.. automodule:: mosec.errors
-    :members:
-    :show-inheritance:
+.. automodule:: mosec.runtime
+    :members: Runtime
 ```
 
-## Mixins
+## Errors
 
 ```{eval-rst}
-.. automodule:: mosec.mixin
+.. automodule:: mosec.errors
     :members:
     :show-inheritance:
 ```
 
-## Plugins
-
-```{eval-rst}
-.. automodule:: mosec.ipc
-    :members:
-```
+## Mixins
 
 ```{eval-rst}
-.. automodule:: mosec.plugins
+.. automodule:: mosec.mixin
     :members:
     :show-inheritance:
 ```
diff --git a/examples/multi_route/client.py b/examples/multi_route/client.py
new file mode 100644
index 00000000..261d3a7c
--- /dev/null
+++ b/examples/multi_route/client.py
@@ -0,0 +1,40 @@
+# Copyright 2023 MOSEC Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from http import HTTPStatus
+
+import httpx
+import msgpack  # type: ignore
+
+typed_req = {
+    "bin": b"hello mosec with type check",
+    "name": "type check",
+}
+
+print(">> requesting for the typed route with msgpack serde")
+resp = httpx.post(
+    "http://127.0.0.1:8000/v1/inference", content=msgpack.packb(typed_req)
+)
+if resp.status_code == HTTPStatus.OK:
+    print(f"OK: {msgpack.unpackb(resp.content)}")
+else:
+    print(f"err[{resp.status_code}] {resp.text}")
+
+print(">> requesting for the untyped route with json serde")
+resp = httpx.post("http://127.0.0.1:8000/inference", content=b"hello mosec")
+if resp.status_code == HTTPStatus.OK:
+    print(f"OK: {json.loads(resp.content)}")
+else:
+    print(f"err[{resp.status_code}] {resp.text}")
diff --git a/examples/multi_route/server.py b/examples/multi_route/server.py
new file mode 100644
index 00000000..e91445b5
--- /dev/null
+++ b/examples/multi_route/server.py
@@ -0,0 +1,77 @@
+# Copyright 2023 MOSEC Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any
+
+from msgspec import Struct
+
+from mosec import Runtime, Server, Worker
+from mosec.mixin import TypedMsgPackMixin
+
+
+class Request(Struct):
+    """User request struct."""
+
+    # pylint: disable=too-few-public-methods
+
+    bin: bytes
+    name: str = "test"
+
+
+class TypedPreprocess(TypedMsgPackMixin, Worker):
+    """Dummy preprocess to exit early if the validation fails."""
+
+    def forward(self, data: Request) -> Any:
+        """Input will be parsed as the `Request`."""
+        print(f"received from {data.name} with {data.bin!r}")
+        return data.bin
+
+
+class Preprocess(Worker):
+    """Dummy preprocess worker."""
+
+    def deserialize(self, data: bytes) -> Any:
+        return data
+
+    def forward(self, data: Any) -> Any:
+        return data
+
+
+class Inference(Worker):
+    """Dummy inference worker."""
+
+    def forward(self, data: Any) -> Any:
+        return [{"length": len(datum)} for datum in data]
+
+
+class TypedPostprocess(TypedMsgPackMixin, Worker):
+    """Dummy postprocess with msgpack."""
+
+    def forward(self, data: Any) -> Any:
+        return data
+
+
+if __name__ == "__main__":
+    server = Server()
+    typed_pre = Runtime(TypedPreprocess)
+    pre = Runtime(Preprocess)
+    inf = Runtime(Inference, max_batch_size=16)
+    typed_post = Runtime(TypedPostprocess)
+    server.register_runtime(
+        {
+            "/v1/inference": [typed_pre, inf, typed_post],
+            "/inference": [pre, inf],
+        }
+    )
+    server.run()
diff --git a/examples/server_side_event/client.py b/examples/server_side_event/client.py
index ee7ce41b..0b2ab186 100644
--- a/examples/server_side_event/client.py
+++ b/examples/server_side_event/client.py
@@ -17,7 +17,7 @@
 
 with httpx.Client() as client:
     with connect_sse(
-        client, "POST", "http://127.0.0.1:8000/sse_inference", json={"text": "mosec"}
+        client, "POST", "http://127.0.0.1:8000/inference", json={"text": "mosec"}
     ) as event_source:
         for sse in event_source.iter_sse():
            print(f"Event({sse.event}): {sse.data}")
@@ -25,7 +25,7 @@
 # error handling
 with httpx.Client() as client:
     with connect_sse(
"http://127.0.0.1:8000/sse_inference", json={"error": "mosec"} + client, "POST", "http://127.0.0.1:8000/inference", json={"error": "mosec"} ) as event_source: for sse in event_source.iter_sse(): print(f"Event({sse.event}): {sse.data}") diff --git a/mosec/__init__.py b/mosec/__init__.py index cf91d78e..4ace9b9f 100644 --- a/mosec/__init__.py +++ b/mosec/__init__.py @@ -22,6 +22,7 @@ ValidationError, ) from mosec.log import get_logger +from mosec.runtime import Runtime from mosec.server import Server from mosec.worker import SSEWorker, Worker @@ -36,6 +37,7 @@ "Server", "Worker", "SSEWorker", + "Runtime", "ServerError", "ClientError", "ValidationError", diff --git a/mosec/args.py b/mosec/args.py index 35ca7c38..9a092afe 100644 --- a/mosec/args.py +++ b/mosec/args.py @@ -62,7 +62,10 @@ def build_arguments_parser() -> argparse.ArgumentParser: parser.add_argument( "--path", - help="Unix Domain Socket address for internal Inter-Process Communication", + help=( + "Unix Domain Socket address for internal Inter-Process Communication." + "If not set, a random path will be created under the temporary dir." + ), type=str, default=os.path.join( tempfile.gettempdir(), f"mosec_{random.randrange(2**32):x}" @@ -146,6 +149,14 @@ def parse_arguments() -> argparse.Namespace: DeprecationWarning, ) + if args.debug: + args.log_level = "debug" + warnings.warn( + "`--debug` is deprecated and will be removed in v1, please configure" + "`--log_level=debug`", + DeprecationWarning, + ) + if not is_port_available(args.address, args.port): raise RuntimeError( f"{args.address}:{args.port} is in use. " diff --git a/mosec/coordinator.py b/mosec/coordinator.py index 690be6f8..fbc3bbe9 100644 --- a/mosec/coordinator.py +++ b/mosec/coordinator.py @@ -14,6 +14,7 @@ """The Coordinator is used to control the data flow between `Worker` and `Server`.""" +import enum import os import queue import signal @@ -24,10 +25,9 @@ import traceback from contextlib import contextmanager from multiprocessing.synchronize import Event -from typing import Any, Callable, Optional, Sequence, Tuple, Type +from typing import Any, Optional, Sequence, Type from mosec.errors import MosecError, MosecTimeoutError -from mosec.ipc import IPCWrapper from mosec.log import get_internal_logger from mosec.protocol import HTTPStautsCode, Protocol from mosec.worker import SSEWorker, Worker @@ -38,8 +38,13 @@ CONN_MAX_RETRY = 10 CONN_CHECK_INTERVAL = 1 -STAGE_INGRESS = "ingress" -STAGE_EGRESS = "egress" + +class State(enum.IntFlag): + """Task state.""" + + INGRESS = 0b1 + EGRESS = 0b10 + PROTOCOL_TIMEOUT = 2.0 @@ -97,57 +102,42 @@ def __init__( self, worker: Type[Worker], max_batch_size: int, - stage: str, shutdown: Event, shutdown_notify: Event, socket_prefix: str, - stage_id: int, + stage_name: str, worker_id: int, - ipc_wrapper: Optional[Callable[..., IPCWrapper]], timeout: int, ): """Initialize the mosec coordinator. Args: - worker (Worker): subclass of `mosec.Worker` implemented by users. - max_batch_size (int): maximum batch size for this worker. - stage (str): identifier to distinguish the first and last stages. - shutdown (Event): `multiprocessing.synchronize.Event` object for shutdown + worker: subclass of `mosec.Worker` implemented by users. + max_batch_size: maximum batch size for this worker. + shutdown: `multiprocessing.synchronize.Event` object for shutdown IPC. - socket_prefix (str): prefix for the socket addresses. - stage_id (int): identification number for worker stages. 
- worker_id (int): identification number for worker processes at the same + socket_prefix: prefix for the socket addresses. + stage_name: identification name for this worker stage. + worker_id: identification number for worker processes at the same stage. - ipc_wrapper (IPCWrapper): IPC wrapper class to be initialized. - timeout (int): timeout for the `forward` function. - - Raises: - TypeError: ipc_wrapper should inherit from `IPCWrapper` + timeout: timeout for the `forward` function. """ - self.worker = worker() - self.worker.worker_id = worker_id - self.worker.max_batch_size = max_batch_size - self.worker.stage = stage + self.name = f"<{stage_name}|{worker_id}>" self.timeout = timeout - self.name = f"<{stage_id}|{worker.__name__}|{worker_id}>" self.current_ids: Sequence[bytes] = [] self.semaphore = FakeSemaphore() # type: ignore + self.worker = worker() + self.worker.worker_id = worker_id + self.worker.max_batch_size = max_batch_size + self.worker.stage = stage_name + self.protocol = Protocol( name=self.name, - addr=os.path.join(socket_prefix, f"ipc_{stage_id}.socket"), + addr=os.path.join(socket_prefix, f"ipc_{stage_name}.socket"), timeout=PROTOCOL_TIMEOUT, ) - # optional plugin features - ipc wrapper - self.ipc_wrapper: Optional[IPCWrapper] = None - if ipc_wrapper is not None: - self.ipc_wrapper = ipc_wrapper() - if not issubclass(type(self.ipc_wrapper), IPCWrapper): - raise TypeError( - "ipc_wrapper must be the subclass of mosec.plugins.IPCWrapper" - ) - # ignore termination & interruption signal signal.signal(signal.SIGTERM, signal.SIG_IGN) signal.signal(signal.SIGINT, signal.SIG_IGN) @@ -197,7 +187,9 @@ def streaming(self): # encode the text with UTF-8 payloads = (text.encode(),) ids = (self.current_ids[index],) - self.protocol.send(HTTPStautsCode.STREAM_EVENT, ids, payloads) + self.protocol.send( + HTTPStautsCode.STREAM_EVENT, ids, (0,) * len(ids), payloads + ) self.semaphore.release() except queue.Empty: continue @@ -243,71 +235,28 @@ def run(self): self.coordinate() - def get_decoder(self) -> Callable[[bytes], Any]: - """Get the decoder function for this stage. - - The first stage will use the worker's deserialize function. - """ - if STAGE_INGRESS in self.worker.stage: - return self.worker.deserialize - return self.worker.deserialize_ipc - - def get_encoder(self) -> Callable[[Any], bytes]: - """Get the encoder function for this stage. - - The last stage will use the worker's serialize function. - """ - if STAGE_EGRESS in self.worker.stage: - return self.worker.serialize - return self.worker.serialize_ipc - - def get_protocol_recv( - self, - ) -> Callable[[], Tuple[bytes, Sequence[bytes], Sequence[bytes]]]: - """Get the protocol receive function for this stage. - - IPC wrapper will be used if it's provided and the stage is not the first one. - """ - if STAGE_INGRESS in self.worker.stage or self.ipc_wrapper is None: - return self.protocol.receive - - # TODO(kemingy) find a better way - def wrapped_recv() -> Tuple[bytes, Sequence[bytes], Sequence[bytes]]: - flag, ids, payloads = self.protocol.receive() - payloads = self.ipc_wrapper.get( # type: ignore - [bytes(x) for x in payloads] - ) - return flag, ids, payloads - - return wrapped_recv - - def get_protocol_send( - self, - ) -> Callable[[int, Sequence[bytes], Sequence[bytes]], None]: - """Get the protocol send function for this stage. - - IPC wrapper will be used if it's provided and the stage is not the last one. 
- """ - if STAGE_EGRESS in self.worker.stage or self.ipc_wrapper is None: - return self.protocol.send - - # TODO(kemingy) find a better way - def wrapped_send(flag: int, ids: Sequence[bytes], payloads: Sequence[bytes]): - if flag == HTTPStautsCode.OK: - payloads = self.ipc_wrapper.put(payloads) # type: ignore - return self.protocol.send(flag, ids, payloads) + def decode(self, payload: bytes, state: int) -> Any: + """Decode the payload with the state.""" + return ( + self.worker.deserialize(payload) + if state & State.INGRESS + else self.worker.deserialize_ipc(payload) + ) - return wrapped_send + def encode(self, data: Any, state: int) -> bytes: + """Encode the data with the state.""" + return ( + self.worker.serialize(data) + if state & State.EGRESS + else self.worker.serialize_ipc(data) + ) def coordinate(self): """Start coordinating the protocol's communication and worker's forward pass.""" - decoder = self.get_decoder() - encoder = self.get_encoder() - protocol_recv = self.get_protocol_recv() - protocol_send = self.get_protocol_send() while not self.shutdown.is_set(): try: - _, ids, payloads = protocol_recv() + # flag received from the server is not used + _, ids, states, payloads = self.protocol.receive() # expose here to be used by stream event self.current_ids = ids except socket.timeout: @@ -320,7 +269,10 @@ def coordinate(self): # pylint: disable=broad-except length = len(payloads) try: - data = [decoder(item) for item in payloads] + data = [ + self.decode(payload, state) + for (payload, state) in zip(payloads, states) + ] with set_mosec_timeout(self.timeout): data = ( self.worker.forward(data) @@ -333,7 +285,9 @@ def coordinate(self): f"input({length})!=output({len(data)})" ) status = HTTPStautsCode.OK - payloads = [encoder(item) for item in data] + payloads = [ + self.encode(datum, state) for (datum, state) in zip(data, states) + ] except (MosecError, MosecTimeoutError) as err: err_msg = str(err).replace("\n", " - ") err_msg = err_msg if err_msg else err.msg @@ -348,7 +302,7 @@ def coordinate(self): try: # pylint: disable=consider-using-with self.semaphore.acquire(timeout=self.timeout) - protocol_send(status, ids, payloads) + self.protocol.send(status, ids, states, payloads) except OSError as err: logger.error("%s failed to send to socket: %s", self.name, err) break diff --git a/mosec/dry_run.py b/mosec/dry_run.py index d80b729a..e1334fb5 100644 --- a/mosec/dry_run.py +++ b/mosec/dry_run.py @@ -21,11 +21,11 @@ import sys import time from multiprocessing.context import SpawnContext, SpawnProcess -from typing import TYPE_CHECKING, List, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Tuple, Union from mosec.env import env_var_context from mosec.log import get_internal_logger -from mosec.runtime import PyRuntimeManager, Runtime +from mosec.runtime import Runtime from mosec.worker import Worker if TYPE_CHECKING: @@ -75,37 +75,37 @@ def __init__(self, process_context: SpawnContext, shutdown_notify: Event): process_context: server context of spawn process shutdown_notify: event of server will shutdown """ - self._process_context = process_context - self._shutdown_notify = shutdown_notify + self.process_context = process_context + self.shutdown_notify = shutdown_notify - self._pool: List[SpawnProcess] = [] - self._sender_pipes: List[PipeConnection] = [] - self._receiver_pipes: List[PipeConnection] = [] + self.processes: List[SpawnProcess] = [] + self.sender_pipes: List[PipeConnection] = [] + self.receiver_pipes: List[PipeConnection] = [] def new_pipe(self): """Create new 
pipe for dry run workers to communicate."""
-        receiver, sender = self._process_context.Pipe(duplex=False)
-        self._sender_pipes.append(sender)
-        self._receiver_pipes.append(receiver)
+        receiver, sender = self.process_context.Pipe(duplex=False)
+        self.sender_pipes.append(sender)
+        self.receiver_pipes.append(receiver)
 
-    def start_worker(self, worker_runtime: Runtime, first: bool):
+    def start_worker(self, worker_runtime: Runtime, init: bool):
         """Start the worker process for dry run.
 
         Args:
             worker_runtime: worker runtime to start
-            first: whether the worker is tried to start at first time
+            init: whether this is the first attempt to start the worker
         """
         self.new_pipe()
-        coordinator = self._process_context.Process(
+        coordinator = self.process_context.Process(
             target=dry_run_func,
             args=(
                 worker_runtime.worker,
                 worker_runtime.max_batch_size,
-                self._receiver_pipes[-2],
-                self._sender_pipes[-1],
-                first,
-                self._shutdown_notify,
+                self.receiver_pipes[-2],
+                self.sender_pipes[-1],
+                init,
+                self.shutdown_notify,
             ),
             daemon=True,
         )
@@ -113,7 +113,7 @@ def start_worker(self, worker_runtime: Runtime, first: bool):
 
         with env_var_context(worker_runtime.env, 0):
             coordinator.start()
-        self._pool.append(coordinator)
+        self.processes.append(coordinator)
 
     def probe_worker_liveness(self) -> Tuple[Union[int, None], Union[int, None]]:
         """Check every worker is running/alive.
 
        Returns:
 
            exitcode: exitcode of the first failed worker
 
        """
-        for i, process in enumerate(self._pool):
+        for i, process in enumerate(self.processes):
            if process.exitcode is not None:
                return i, process.exitcode
        return None, None
@@ -136,7 +136,7 @@ def wait_all(self) -> Tuple[Union[int, None], Union[int, None]]:
 
            exitcode: exitcode of the first failed worker
 
        """
-        for i, process in enumerate(self._pool):
+        for i, process in enumerate(self.processes):
            process.join()
            if process.exitcode != 0:
                return i, process.exitcode
@@ -144,7 +144,7 @@ def wait_all(self) -> Tuple[Union[int, None], Union[int, None]]:
 
     def first_last_pipe(self):
         """Get first sender and last receiver pipes."""
-        return self._sender_pipes[0], self._receiver_pipes[-1]
+        return self.sender_pipes[0], self.receiver_pipes[-1]
 
 
 class DryRunner:
@@ -158,14 +158,11 @@ class DryRunner:
         will be used.
""" - def __init__(self, manager: PyRuntimeManager): + def __init__(self, router: Dict[str, List[Runtime]]): """Init dry runner.""" - logger.info("init dry runner for %s", manager.workers) - - self._manager = manager - self._process_context: SpawnContext = SpawnContext() - self._shutdown_notify: Event = self._process_context.Event() - self._pool = Pool(self._process_context, self._shutdown_notify) + self.router = router + self.process_context: SpawnContext = SpawnContext() + self.shutdown_notify: Event = self.process_context.Event() signal.signal(signal.SIGTERM, self.terminate) signal.signal(signal.SIGINT, self.terminate) @@ -173,38 +170,48 @@ def __init__(self, manager: PyRuntimeManager): def terminate(self, signum, framestack): """Terminate the dry run.""" logger.info("received terminate signal [%s] %s", signum, framestack) - self._shutdown_notify.set() + self.shutdown_notify.set() def run(self): """Execute thr dry run process.""" - self._pool.new_pipe() - for i, worker_runtime in enumerate(self._manager): - self._pool.start_worker(worker_runtime, i == 0) - - logger.info("dry run init successful") - self.warmup() - - logger.info("wait for worker init done") - if not self._shutdown_notify.is_set(): - self._shutdown_notify.set() - - failed, exitcode = self._pool.wait_all() - if failed is not None: - logger.warning( - "detect %s with abnormal exit code %d", - self._manager.workers[failed], - exitcode, + for endpoint, runtimes in self.router.items(): + logger.info( + "init dry run for endpoint %s with %s", + endpoint, + [runtime.name for runtime in runtimes], ) - sys.exit(exitcode) + + pool = Pool(self.process_context, self.shutdown_notify) + pool.new_pipe() + for i, worker_runtime in enumerate(runtimes): + pool.start_worker(worker_runtime, i == 0) + + logger.info("dry run init successful") + self.warmup(runtimes, pool) + + logger.info("wait for worker init done") + if not self.shutdown_notify.is_set(): + self.shutdown_notify.set() + + failed, exitcode = pool.wait_all() + if failed is not None: + logger.warning( + "detect %s with abnormal exit code %d", + runtimes[failed].name, + exitcode, + ) + sys.exit(exitcode) + + self.shutdown_notify.clear() logger.info("dry run exit") - def warmup(self): + def warmup(self, runtimes: List[Runtime], pool: Pool): """Warmup the service. If neither `example` nor `multi_examples` is provided, it will only init the worker class. 
""" - ingress = self._manager.workers[0] + ingress = runtimes[0].worker example = None if ingress.example: example = ingress.example @@ -218,25 +225,25 @@ def warmup(self): logger.info("cannot find the example in the 1st stage worker, skip warmup") return - sender, receiver = self._pool.first_last_pipe() + sender, receiver = pool.first_last_pipe() start_time = time.perf_counter() sender.send(example) - while not self._shutdown_notify.is_set(): + while not self.shutdown_notify.is_set(): if receiver.poll(0.1): break # liveness probe - failed, exitcode = self._pool.probe_worker_liveness() + failed, exitcode = pool.probe_worker_liveness() if failed is not None: logger.warning( "worker %s exit with code %d", - self._manager.workers[failed], + runtimes[failed].name, exitcode, ) - self._shutdown_notify.set() + self.shutdown_notify.set() break - if self._shutdown_notify.is_set(): + if self.shutdown_notify.is_set(): sys.exit(1) res = receiver.recv_bytes() diff --git a/mosec/ipc.py b/mosec/ipc.py deleted file mode 100644 index 9343fc81..00000000 --- a/mosec/ipc.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright 2022 MOSEC Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Wrapper layer for IPC between workers. - -This will be called before sending data or after receiving data through the Protocol. - - -.. warning:: - This implementation is deprecated. Please use - :py:mod:`PlasmaShmIPCMixin ` -""" - -import abc -from typing import List - - -class IPCWrapper(abc.ABC): - """This public class defines the mosec IPC wrapper plugin interface. - - The wrapper has to implement at least ``put`` and ``get`` methods. - """ - - @abc.abstractmethod - def put(self, data: List[bytes]) -> List[bytes]: - """Put bytes to somewhere to get ids, which are sent via protocol. - - Args: - data: List of bytes data. - - Returns: List of bytes ID. - """ - - @abc.abstractmethod - def get(self, ids: List[bytes]) -> List[bytes]: - """Get bytes from somewhere by ids, which are received via protocol. - - Args: - ids: List of bytes ID. - - Returns: List of bytes data. - """ diff --git a/mosec/plugins/__init__.py b/mosec/plugins/__init__.py deleted file mode 100644 index 78fb3187..00000000 --- a/mosec/plugins/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright 2022 MOSEC Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Provide useful tools to extend MOSEC.""" - -from mosec.plugins.plasma_shm import PlasmaShmWrapper - -__all__ = ["PlasmaShmWrapper"] diff --git a/mosec/plugins/plasma_shm.py b/mosec/plugins/plasma_shm.py deleted file mode 100644 index e93bf8c4..00000000 --- a/mosec/plugins/plasma_shm.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright 2022 MOSEC Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Provide another data transfer way between workers. - -The data will be stored in plasma shared memory, while the object ID will be -sent via the original way. - - use case: large image tensors - benefits: more stable P99 latency - -```{warning} -This implementation is deprecated. Please use -:py:mod:`PlasmaShmIPCMixin ` -``` -""" - -import warnings -from typing import TYPE_CHECKING, List - -from mosec.ipc import IPCWrapper - -# We do not enforce the installation of third party libraries for -# plugins, because users may not enable them. -try: - from pyarrow import plasma # type: ignore -except ImportError: - warnings.warn( - "pyarrow is not installed. PlasmaShmWrapper is not available.", ImportWarning - ) - -if TYPE_CHECKING: - from pyarrow import plasma - - -class PlasmaShmWrapper(IPCWrapper): - """Shared memory wrapper using ``pyarrow`` Plasma. - - This public class is an example implementation of the - :py:mod:`IPCWrapper `. - It utilizes ``pyarrow.plasma`` as the in-memory object store for - potentially more efficient data transfer. - """ - - def __init__(self, shm_path: str) -> None: - """Initialize the IPC Wrapper as a plasma client. - - Args: - shm_path (str): path of the plasma server. 
- """ - warnings.warn( - "PlasmaShmWrapper is deprecated, please use RedisShmIPCMixin.", - DeprecationWarning, - ) - self.client = plasma.connect(shm_path) - - def _put_plasma(self, data: List[bytes]) -> List[plasma.ObjectID]: - """Batch put into plasma memory store.""" - return [self.client.put(x) for x in data] - - def _get_plasma(self, object_ids: List[plasma.ObjectID]) -> List[bytes]: - """Batch get from plasma memory store.""" - objects = self.client.get(object_ids) - self.client.delete(object_ids) - return objects - - def put(self, data: List[bytes]) -> List[bytes]: - """Save data to the plasma memory store and return the ID.""" - object_ids = self._put_plasma(data) - return [id.binary() for id in object_ids] - - def get(self, ids: List[bytes]) -> List[bytes]: - """Get data from the plasma memory store by ID.""" - object_ids = [plasma.ObjectID(id) for id in ids] - return self._get_plasma(object_ids) diff --git a/mosec/protocol.py b/mosec/protocol.py index 21411808..bd695dcc 100644 --- a/mosec/protocol.py +++ b/mosec/protocol.py @@ -54,11 +54,13 @@ class Protocol: FORMAT_BATCH = "!H" FORMAT_ID = "!I" FORMAT_LENGTH = "!I" + FORMAT_STATE = "!H" # lengths LENGTH_TASK_FLAG = 2 LENGTH_TASK_BATCH = 2 LENGTH_TASK_ID = 4 + LENGTH_TASK_STATE = 2 LENGTH_TASK_BODY_LEN = 4 def __init__( @@ -82,21 +84,23 @@ def __init__( self.name = name self.addr = addr - def receive(self) -> Tuple[bytes, Sequence[bytes], Sequence[bytes]]: + def receive(self) -> Tuple[bytes, Sequence[bytes], Sequence[int], Sequence[bytes]]: """Receive tasks from the server.""" flag = self.socket.recv(self.LENGTH_TASK_FLAG) batch_size_bytes = self.socket.recv(self.LENGTH_TASK_BATCH) batch_size = struct.unpack(self.FORMAT_BATCH, batch_size_bytes)[0] - ids, payloads = [], [] + ids, states, payloads = [], [], [] total_bytes = 0 while batch_size > 0: batch_size -= 1 id_bytes = self.socket.recv(self.LENGTH_TASK_ID) + state_bytes = self.socket.recv(self.LENGTH_TASK_STATE) length_bytes = self.socket.recv(self.LENGTH_TASK_BODY_LEN) length = struct.unpack(self.FORMAT_LENGTH, length_bytes)[0] payload = _recv_all(self.socket, length) ids.append(id_bytes) + states.append(struct.unpack(self.FORMAT_STATE, state_bytes)[0]) payloads.append(payload) total_bytes += length @@ -114,9 +118,15 @@ def receive(self) -> Tuple[bytes, Sequence[bytes], Sequence[bytes]]: "which may affect performance", RuntimeWarning, ) - return flag, ids, payloads + return flag, ids, states, payloads - def send(self, flag: int, ids: Sequence[bytes], payloads: Sequence[bytes]): + def send( + self, + flag: int, + ids: Sequence[bytes], + states: Sequence[int], + payloads: Sequence[bytes], + ): """Send results to the server.""" data = BytesIO() data.write(struct.pack(self.FORMAT_FLAG, flag)) @@ -125,10 +135,10 @@ def send(self, flag: int, ids: Sequence[bytes], payloads: Sequence[bytes]): batch_size = len(ids) data.write(struct.pack(self.FORMAT_BATCH, batch_size)) if batch_size > 0: - for task_id, payload in zip(ids, payloads): - length = struct.pack(self.FORMAT_LENGTH, len(payload)) + for task_id, state, payload in zip(ids, states, payloads): data.write(task_id) - data.write(length) + data.write(struct.pack(self.FORMAT_STATE, state)) + data.write(struct.pack(self.FORMAT_LENGTH, len(payload))) data.write(payload) self.socket.sendall(data.getbuffer()) if logger.isEnabledFor(logging.DEBUG): diff --git a/mosec/runtime.py b/mosec/runtime.py index 0e620010..ac4f87f4 100644 --- a/mosec/runtime.py +++ b/mosec/runtime.py @@ -16,19 +16,17 @@ import multiprocessing as mp import subprocess 
-from functools import partial
 from multiprocessing.context import ForkContext, SpawnContext
 from multiprocessing.process import BaseProcess
 from multiprocessing.synchronize import Event
 from pathlib import Path
 from time import monotonic, sleep
-from typing import Any, Callable, Dict, Iterable, List, Optional, Type, Union, cast
+from typing import Callable, Dict, Iterable, List, Optional, Type, Union, cast
 
 import pkg_resources
 
-from mosec.coordinator import STAGE_EGRESS, STAGE_INGRESS, Coordinator
+from mosec.coordinator import Coordinator
 from mosec.env import env_var_context, validate_env, validate_int_ge
-from mosec.ipc import IPCWrapper
 from mosec.log import get_internal_logger
 from mosec.worker import Worker
 
@@ -39,20 +37,22 @@
 
 # pylint: disable=too-many-instance-attributes
 # pylint: disable=too-many-arguments
+# pylint: disable=too-few-public-methods
 
 
 class Runtime:
     """The wrapper with one worker and its arguments."""
 
+    # count how many runtime instances have been created
+    _stage_id: int = 0
+
     def __init__(
         self,
         worker: Type[Worker],
-        num: int,
-        max_batch_size: int,
-        max_wait_time: int,
-        stage_id: int,
-        timeout: int,
-        start_method: str,
-        env: Union[None, List[Dict[str, str]]],
-        ipc_wrapper: Optional[Union[Type[IPCWrapper], partial]],
+        num: int = 1,
+        max_batch_size: int = 1,
+        max_wait_time: int = 10,
+        timeout: int = 3,
+        start_method: str = "spawn",
+        env: Union[None, List[Dict[str, str]]] = None,
    ):
        """Initialize the mosec coordinator.

        Args:
@@ -65,27 +65,25 @@ def __init__(
            for dynamic batching, needs to be used with `max_batch_size` to enable
                the feature. If not configured, will use the CLI argument `--wait`
                (default=10ms)
-            stage_id (int): identification number for worker stages.
-            timeout (int): timeout for the `forward` function.
+            timeout (int): timeout (seconds) for the `forward` function.
             start_method: the process starting method ("spawn" or "fork")
             env: the environment variables to set before starting the process
-            ipc_wrapper (IPCWrapper): IPC wrapper class to be initialized.
-
-        Raises:
-            TypeError: ipc_wrapper should inherit from `IPCWrapper`
         """
         self.worker = worker
         self.num = num
         self.max_batch_size = max_batch_size
         self.max_wait_time = max_wait_time
-        self.stage_id = stage_id
         self.timeout = timeout
         self.start_method = start_method
         self.env = env
-        self.ipc_wrapper = ipc_wrapper
+        Runtime._stage_id += 1
+        # adding the stage id in case the worker class is added to multiple stages
+        self.name = f"{self.worker.__name__}_{self._stage_id}"
 
         self._pool: List[Union[BaseProcess, None]] = [None for _ in range(self.num)]
 
+        self._validate()
+
     @staticmethod
     def _process_healthy(process: Union[BaseProcess, None]) -> bool:
         """Check if the child process is healthy.
@@ -97,16 +95,17 @@ def _process_healthy(process: Union[BaseProcess, None]) -> bool:
         return process is not None and process.exitcode is None
 
     def _healthy(self, method: Callable[[Iterable[object]], bool]) -> bool:
+        """Check if all/any of the child processes are healthy."""
         return method(self._pool)
 
     def _start_process(
         self,
         worker_id: int,
-        stage_label: str,
         work_path: str,
         shutdown: Event,
         shutdown_notify: Event,
     ):
+        """Start a worker process in the context."""
         context = mp.get_context(self.start_method)
         context = cast(Union[SpawnContext, ForkContext], context)
         coordinator_process = context.Process(
@@ -114,13 +113,11 @@
             args=(
                 self.worker,
                 self.max_batch_size,
-                stage_label,
                 shutdown,
                 shutdown_notify,
                 work_path,
-                self.stage_id,
+                self.name,
                 worker_id + 1,
-                self.ipc_wrapper,
                 self.timeout,
             ),
             daemon=True,
@@ -131,22 +128,20 @@
 
         self._pool[worker_id] = coordinator_process
 
-    def check(
+    def _check(
         self,
-        first: bool,
-        stage_label: str,
         work_path: str,
         shutdown: Event,
         shutdown_notify: Event,
+        init: bool,
     ) -> bool:
         """Check and start the worker process if it has not started yet.
 
         Args:
-            first: whether the worker is tried to start at first time
-            stage_label: label of worker ingress and egress
             work_path: path of working directory
             shutdown: Event of server shutdown
             shutdown_notify: Event of server will shutdown
+            init: whether this is the first attempt to start the worker
 
         Returns:
             Whether the worker is started successfully
 worker processes and try to start failed ones.
 
         Args:
-            first: whether the worker is tried to start at first time
+            init: whether this is the first attempt to start the worker
         """
-        for stage_id, worker_runtime in enumerate(self._runtimes):
-            label = self._label_stage(stage_id)
-            success = worker_runtime.check(
-                first, label, self._work_path, self._shutdown, self._shutdown_notify
-            )
-            if not success:
+        for worker_runtime in self.runtimes:
+            if not worker_runtime._check(  # pylint: disable=protected-access
+                self._work_path, self.shutdown, self.shutdown_notify, init
+            ):
                 return worker_runtime
         return None
 
@@ -251,34 +226,28 @@ class RsRuntimeManager:
     """The manager to control Mosec process."""
 
-    def __init__(
-        self, py_manager: PyRuntimeManager, endpoint: str, configs: Dict[str, Any]
-    ):
+    def __init__(self, timeout: int):
         """Initialize a Mosec manager.
 
         Args:
-            manager: manager of python coordinator
-            endpoint: event of server shutdown
-            configs: config
+            timeout: service timeout in milliseconds
         """
-        self._process: Optional[subprocess.Popen] = None
+        self.process: Optional[subprocess.Popen] = None
 
-        self._py_manager: PyRuntimeManager = py_manager
-        self._endpoint = endpoint
-        self._server_path = Path(
+        self.server_path = Path(
             pkg_resources.resource_filename("mosec", "bin"), "mosec"
         )
-        self._configs: Dict[str, Any] = configs
+        self.timeout = timeout
 
     def halt(self):
         """Graceful shutdown."""
         # terminate controller first and wait for a graceful period
-        if self._process is None:
+        if self.process is None:
             return
-        self._process.terminate()
-        graceful_period = monotonic() + self._configs["timeout"] / 1000
+        self.process.terminate()
+        graceful_period = monotonic() + self.timeout / 1000
         while monotonic() < graceful_period:
-            ctr_exitcode = self._process.poll()
+            ctr_exitcode = self.process.poll()
             if ctr_exitcode is None:
                 sleep(0.1)
                 continue
@@ -291,24 +260,14 @@
 
         if monotonic() > graceful_period:
             logger.error("failed to terminate mosec service, will try to kill it")
-            self._process.kill()
+            self.process.kill()
 
-    def start(self) -> subprocess.Popen:
-        """Start the Mosec process."""
+    def start(self, config_path: Path) -> subprocess.Popen:
+        """Start the Mosec process.
+
+        Args:
+            config_path: configuration path of mosec
+        """
         # pylint: disable=consider-using-with
-        self._process = subprocess.Popen([self._server_path] + self._controller_args())
-        return self._process
-
-    def _controller_args(self):
-        args = []
-        self._configs.pop("dry_run")
-        for key, value in self._configs.items():
-            args.extend([f"--{key.replace('_', '-')}", str(value).lower()])
-        for worker_runtime in self._py_manager:
-            args.extend(["--batches", str(worker_runtime.max_batch_size)])
-            args.extend(["--waits", str(worker_runtime.max_wait_time)])
-        mime_type = self._py_manager.egress_mime()
-        args.extend(["--mime", mime_type])
-        args.extend(["--endpoint", self._endpoint])
-        logger.info("mosec received arguments: %s", args)
-        return args
+        self.process = subprocess.Popen([self.server_path, config_path])
+        return self.process
diff --git a/mosec/server.py b/mosec/server.py
index 6dae3d93..63a8b005 100644
--- a/mosec/server.py
+++ b/mosec/server.py
@@ -24,8 +24,8 @@
 corresponding worker is appended, by setting the
 :py:meth:`append_worker(max_batch_size) <mosec.server.Server.append_worker>`.
-Multiprocess ------------- +Multiprocessing +--------------- The user may spawn multiple processes for any stage when the corresponding worker is appended, by setting the @@ -34,30 +34,28 @@ import json import multiprocessing as mp -import os import pathlib import shutil import signal import subprocess import traceback -from functools import partial +from collections import defaultdict from multiprocessing.synchronize import Event from time import sleep -from typing import Dict, List, Optional, Type, Union +from typing import Dict, List, Type, Union from mosec.args import parse_arguments from mosec.dry_run import DryRunner -from mosec.ipc import IPCWrapper from mosec.log import get_internal_logger from mosec.runtime import PyRuntimeManager, RsRuntimeManager, Runtime from mosec.utils import ParseTarget -from mosec.worker import MOSEC_REF_TEMPLATE, Worker +from mosec.worker import MOSEC_REF_TEMPLATE, SSEWorker, Worker logger = get_internal_logger() GUARD_CHECK_INTERVAL = 1 -MOSEC_OPENAPI_PATH = "mosec_openapi.json" +MOSEC_RESERVED_ENDPOINTS = {"/", "/metrics", "/openapi"} class Server: @@ -68,30 +66,17 @@ class Server: """ # pylint: disable=too-many-instance-attributes - def __init__( - self, - ipc_wrapper: Optional[Union[IPCWrapper, partial]] = None, - endpoint: str = "/inference", - ): - """Initialize a MOSEC Server. - - Args: - ipc_wrapper: (deprecated) wrapper function (before and after) IPC - endpoint: path to route inference - """ - self.ipc_wrapper = ipc_wrapper - self.endpoint = endpoint - + def __init__(self): + """Initialize a MOSEC Server.""" self._shutdown: Event = mp.get_context("spawn").Event() self._shutdown_notify: Event = mp.get_context("spawn").Event() self._configs: dict = vars(parse_arguments()) - self.py_runtime_manager: PyRuntimeManager = PyRuntimeManager( + self._py_runtime_manager: PyRuntimeManager = PyRuntimeManager( self._configs["path"], self._shutdown, self._shutdown_notify ) - self.rs_runtime_manager = RsRuntimeManager( - self.py_runtime_manager, endpoint, self._configs - ) + self._rs_runtime_manager = RsRuntimeManager(self._configs["timeout"]) + self._router: Dict[str, List[Runtime]] = defaultdict(list) self._daemon: Dict[str, Union[subprocess.Popen, mp.Process]] = {} @@ -102,7 +87,7 @@ def _handle_signal(self): signal.signal(signal.SIGINT, self._terminate) def _validate_server(self): - assert self.py_runtime_manager.worker_count > 0, ( + assert self._py_runtime_manager.worker_count > 0, ( "no worker registered\n" "help: use `.append_worker(...)` to register at least one worker" ) @@ -127,7 +112,36 @@ def _start_rs_runtime(self): """Subprocess to start the rust runtime manager program.""" if self._server_shutdown: return - process = self.rs_runtime_manager.start() + + # dump the config to a JSON file + config_path = pathlib.Path(self._configs["path"]) / "config.json" + configs = {"runtimes": [], "routes": []} + for key, value in self._configs.items(): + if key in ("dry_run", "debug", "wait"): + continue + configs[key] = value + for runtime in self._py_runtime_manager.runtimes: + configs["runtimes"].append( + { + "worker": runtime.name, + "max_batch_size": runtime.max_batch_size, + "max_wait_time": runtime.max_wait_time, + } + ) + for endpoint, pipeline in self._router.items(): + configs["routes"].append( + { + "endpoint": endpoint, + "workers": [runtime.name for runtime in pipeline], + "is_sse": issubclass(pipeline[-1].worker, SSEWorker), + **generate_openapi([runtime.worker for runtime in pipeline]), + } + ) + config_path.parent.mkdir(parents=True, 
exist_ok=True) + with open(config_path, "w", encoding="utf-8") as file: + json.dump(configs, file, indent=2) + + process = self._rs_runtime_manager.start(config_path) self.register_daemon("rs_runtime", process) def _terminate(self, signum, framestack): @@ -135,17 +149,16 @@ def _terminate(self, signum, framestack): self._server_shutdown = True def _manage_py_runtime(self): - first = True + init = True while not self._server_shutdown: - failed_worker = self.py_runtime_manager.check_and_start(first) - if failed_worker is not None: + failed_runtime = self._py_runtime_manager.check_and_start(init) + if failed_runtime is not None: self._terminate( 1, - f"all the {failed_worker.worker.__name__} workers" - f" at stage {failed_worker.stage_id} exited;" + f"all the {failed_runtime.name} workers exited;" " please check for bugs or socket connection issues", ) - first = False + init = False self._check_daemon() sleep(GUARD_CHECK_INTERVAL) @@ -153,7 +166,7 @@ def _halt(self): """Graceful shutdown.""" # notify the rs runtime to shutdown self._shutdown_notify.set() - self.rs_runtime_manager.halt() + self._rs_runtime_manager.halt() # shutdown py runtime manager self._shutdown.set() shutil.rmtree(self._configs["path"], ignore_errors=True) @@ -182,6 +195,7 @@ def append_worker( start_method: str = "spawn", env: Union[None, List[Dict[str, str]]] = None, timeout: int = 0, + route: Union[str, List[str]] = "/inference", ): """Sequentially appends workers to the workflow pipeline. @@ -198,73 +212,58 @@ def append_worker( change this unless you understand the difference between them) env: the environment variables to set before starting the process timeout: the timeout (second) for each worker forward processing (>=1) + route: the route path for this worker. If not configured, will use the + default route path `/inference`. If a list is provided, different + route paths will share the same worker. 
""" timeout = timeout if timeout >= 1 else self._configs["timeout"] // 1000 max_wait_time = max_wait_time if max_wait_time >= 1 else self._configs["wait"] - stage_id = self.py_runtime_manager.worker_count runtime = Runtime( worker, num, max_batch_size, max_wait_time, - stage_id + 1, timeout, start_method, env, - None, ) - runtime.validate() - self.py_runtime_manager.append(runtime) - - def _generate_openapi(self): - """Generate the OpenAPI specification.""" - if self.py_runtime_manager.worker_count <= 0: - return - workers = self.py_runtime_manager.workers - request_worker_cls, response_worker_cls = workers[0], workers[-1] - input_schema, input_components = request_worker_cls.get_forward_json_schema( - ParseTarget.INPUT, MOSEC_REF_TEMPLATE - ) - return_schema, return_components = response_worker_cls.get_forward_json_schema( - ParseTarget.RETURN, MOSEC_REF_TEMPLATE - ) - - def make_body(description, mime, schema): - if not schema: - return None - return {"description": description, "content": {mime: {"schema": schema}}} - - schema = { - "request_body": make_body( - "Mosec Inference Request Body", - request_worker_cls.resp_mime_type, - input_schema, - ), - "responses": None - if not return_schema - else { - 200: make_body( - "Mosec Inference Result", - response_worker_cls.resp_mime_type, - return_schema, - ) - }, - "schemas": {**input_components, **return_components}, - } - tmp_path = pathlib.Path(os.path.join(self._configs["path"], MOSEC_OPENAPI_PATH)) - tmp_path.parent.mkdir(parents=True, exist_ok=True) - with open(tmp_path, "w", encoding="utf-8") as file: - json.dump(schema, file) + self._register_route(runtime, route) + self._py_runtime_manager.append(runtime) + + def register_runtime(self, routes: Dict[str, List[Runtime]]): + """Register the runtime to the routes.""" + if self._py_runtime_manager.worker_count > 0: + raise RuntimeError( + "`register_runtime` can only be registered to an empty mosec server" + ) + unique_runtimes = set() + for endpoint, runtimes in routes.items(): + for runtime in runtimes: + self._register_route(runtime, endpoint) + unique_runtimes.add(runtime) + for runtime in unique_runtimes: + self._py_runtime_manager.append(runtime) + + def _register_route(self, runtime: Runtime, route: Union[str, List[str]]): + """Register the route path for the worker.""" + if isinstance(route, str): + if route in MOSEC_RESERVED_ENDPOINTS: + raise ValueError(f"'{route}' is reserved, try another one") + self._router[route].append(runtime) + elif isinstance(route, list): + for endpoint in route: + if endpoint in MOSEC_RESERVED_ENDPOINTS: + raise ValueError(f"'{endpoint}' is reserved, try another one") + self._router[endpoint].append(runtime) def run(self): """Start the mosec model server.""" self._validate_server() if self._configs["dry_run"]: - DryRunner(self.py_runtime_manager).run() + DryRunner(self._router).run() return self._handle_signal() - self._generate_openapi() self._start_rs_runtime() try: self._manage_py_runtime() @@ -272,3 +271,40 @@ def run(self): except Exception: logger.error(traceback.format_exc().replace("\n", " ")) self._halt() + + +def generate_openapi(workers: List[Type[Worker]]): + """Generate the OpenAPI specification for one pipeline.""" + if not workers: + return {} + request_worker_cls, response_worker_cls = workers[0], workers[-1] + input_schema, input_components = request_worker_cls.get_forward_json_schema( + ParseTarget.INPUT, MOSEC_REF_TEMPLATE + ) + return_schema, return_components = response_worker_cls.get_forward_json_schema( + ParseTarget.RETURN, 
MOSEC_REF_TEMPLATE + ) + + def make_body(description, mime, schema): + if not schema: + return None + return {"description": description, "content": {mime: {"schema": schema}}} + + return { + "request_body": make_body( + "Mosec Inference Request Body", + request_worker_cls.resp_mime_type, + input_schema, + ), + "responses": None + if not return_schema + else { + "200": make_body( + "Mosec Inference Result", + response_worker_cls.resp_mime_type, + return_schema, + ) + }, + "schemas": {**input_components, **return_components}, + "mime": response_worker_cls.resp_mime_type, + } diff --git a/mosec/worker.py b/mosec/worker.py index aa37a8af..a3a89909 100644 --- a/mosec/worker.py +++ b/mosec/worker.py @@ -234,6 +234,7 @@ class SSEWorker(Worker): _stream_queue: SimpleQueue _stream_semaphore: Semaphore + resp_mime_type = "text/event-stream" def send_stream_event(self, text: str, index: int = 0): """Send a stream event to the client. diff --git a/src/apidoc.rs b/src/apidoc.rs index 1ff97a92..8428d4a3 100644 --- a/src/apidoc.rs +++ b/src/apidoc.rs @@ -12,32 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::BTreeMap; -use std::str::FromStr; +use utoipa::openapi::{Components, OpenApi, PathItemType}; -use serde::Deserialize; -use utoipa::openapi::request_body::RequestBody; -use utoipa::openapi::{Components, OpenApi, PathItemType, RefOr, Response, Schema}; - -#[derive(Deserialize, Default)] -pub(crate) struct PythonAPIDoc { - #[serde(skip_serializing_if = "Option::is_none", default)] - request_body: Option, - - #[serde(skip_serializing_if = "Option::is_none", default)] - responses: Option>>, - - #[serde(skip_serializing_if = "Option::is_none", default)] - schemas: Option>>, -} - -impl FromStr for PythonAPIDoc { - type Err = serde_json::Error; - - fn from_str(s: &str) -> Result { - serde_json::from_str::(s) - } -} +use crate::config::Route; #[derive(Default, Clone)] pub(crate) struct MosecOpenAPI { @@ -45,12 +22,16 @@ pub(crate) struct MosecOpenAPI { } impl MosecOpenAPI { - /// merge PythonAPIDoc of target route to mosec api - pub fn merge(&mut self, route: &str, python_api: PythonAPIDoc) -> &mut Self { - let path = self.api.paths.paths.get_mut(route).unwrap(); + /// Merge the route request_body/response/schemas into the OpenAPI. + pub fn merge_route(&mut self, route: &Route) -> &mut Self { + let reserved = match route.is_sse { + true => "/openapi/reserved/inference", + false => "/openapi/reserved/inference_sse", + }; + let mut path = self.api.paths.paths.get(reserved).unwrap().clone(); let op = path.operations.get_mut(&PathItemType::Post).unwrap(); - if let Some(mut other_schemas) = python_api.schemas { + if let Some(mut user_schemas) = route.schemas.clone() { if self.api.components.is_none() { self.api.components = Some(Components::default()); } @@ -59,31 +40,27 @@ impl MosecOpenAPI { .as_mut() .unwrap() .schemas - .append(&mut other_schemas); + .append(&mut user_schemas); }; - if let Some(req) = python_api.request_body { + if let Some(req) = route.request_body.clone() { op.request_body = Some(req); }; - if let Some(mut responses) = python_api.responses { + if let Some(mut responses) = route.responses.clone() { op.responses.responses.append(&mut responses); }; + self.api.paths.paths.insert(route.endpoint.clone(), path); self } - /// This function replaces a [OpenAPI Path Item Object][path_item] from path `from` to path `to`. - /// - /// e.g. /inference -> /v1/inference. 
- /// - /// It is used to handle cases where variable paths are not supported by the [utoipa-gen][utoipa-gen] library. - /// - /// [path_item]: https://spec.openapis.org/oas/latest.html#path-item-object - /// [utoipa-gen]: https://crates.io/crates/utoipa-gen - pub fn replace_path_item(&mut self, from: &str, to: &str) -> &mut Self { - if let Some(r) = self.api.paths.paths.remove(from) { - self.api.paths.paths.insert(to.to_owned(), r); - } + /// Removes the reserved paths from the OpenAPI spec. + pub fn clean(&mut self) -> &mut Self { + self.api.paths.paths.remove("/openapi/reserved/inference"); + self.api + .paths + .paths + .remove("/openapi/reserved/inference_sse"); self } } diff --git a/src/args.rs b/src/args.rs deleted file mode 100644 index 9fe13c4e..00000000 --- a/src/args.rs +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright 2022 MOSEC Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use argh::FromArgs; - -#[derive(FromArgs, Debug, PartialEq)] -/// MOSEC arguments -pub(crate) struct Opts { - /// the Endpoint of inference - #[argh(option, default = "String::from(\"/inference\")")] - pub(crate) endpoint: String, - - /// the Unix domain socket directory path - #[argh(option, default = "String::from(\"\")")] - pub(crate) path: String, - - /// max batch size for each stage - #[argh(option)] - pub(crate) batches: Vec, - - /// capacity for the channel - /// (when the channel is full, the new requests will be dropped with 429 Too Many Requests) - #[argh(option, short = 'c', default = "1024")] - pub(crate) capacity: usize, - - /// timeout for one request (milliseconds) - #[argh(option, short = 't', default = "3000")] - pub(crate) timeout: u64, - - /// wait time for each batch (milliseconds), use `waits` instead [deprecated] - #[argh(option, short = 'w', default = "10")] - pub(crate) wait: u64, - - /// max wait time for each stage - #[argh(option)] - pub(crate) waits: Vec, - - /// service host - #[argh(option, short = 'a', default = "String::from(\"0.0.0.0\")")] - pub(crate) address: String, - - /// service port - #[argh(option, short = 'p', default = "8000")] - pub(crate) port: u16, - - /// metrics namespace - #[argh(option, short = 'n', default = "String::from(\"mosec_service\")")] - pub(crate) namespace: String, - - /// enable debug log - #[argh(option, short = 'd', default = "false")] - pub(crate) debug: bool, - - /// set the log level - #[argh(option, default = "String::from(\"info\")")] - pub(crate) log_level: String, - - /// response mime type - #[argh(option, default = "String::from(\"application/json\")")] - pub(crate) mime: String, -} diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 00000000..1a1cc739 --- /dev/null +++ b/src/config.rs @@ -0,0 +1,98 @@ +// Copyright 2023 MOSEC Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::collections::BTreeMap;
+use std::fmt;
+
+use serde::Deserialize;
+use utoipa::openapi::request_body::RequestBody;
+use utoipa::openapi::{RefOr, Response, Schema};
+
+#[derive(Deserialize, Debug)]
+pub(crate) struct Runtime {
+    pub max_batch_size: usize,
+    pub max_wait_time: u64,
+    pub worker: String,
+}
+
+#[derive(Deserialize)]
+pub(crate) struct Route {
+    pub endpoint: String,
+    pub workers: Vec<String>,
+    pub mime: String,
+    pub is_sse: bool,
+    pub request_body: Option<RequestBody>,
+    pub responses: Option<BTreeMap<String, RefOr<Response>>>,
+    pub schemas: Option<BTreeMap<String, RefOr<Schema>>>,
+}
+
+impl fmt::Debug for Route {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(
+            f,
+            "({}: [{}], resp({}))",
+            self.endpoint,
+            self.workers.join(", "),
+            self.mime
+        )
+    }
+}
+
+#[derive(Deserialize, Debug)]
+pub(crate) struct Config {
+    // socket dir
+    pub path: String,
+    // channel capacity
+    pub capacity: usize,
+    // service timeout (ms)
+    pub timeout: u64,
+    // service address
+    pub address: String,
+    // service port
+    pub port: u16,
+    // metrics namespace
+    pub namespace: String,
+    // log level: (debug, info, warning, error)
+    pub log_level: String,
+    pub runtimes: Vec<Runtime>,
+    pub routes: Vec<Route>,
+}
+
+impl Default for Config {
+    fn default() -> Self {
+        Self {
+            path: String::from("/tmp/mosec"),
+            capacity: 1024,
+            timeout: 3000,
+            address: String::from("0.0.0.0"),
+            port: 8000,
+            namespace: String::from("mosec_service"),
+            log_level: String::from("info"),
+            runtimes: vec![Runtime {
+                max_batch_size: 64,
+                max_wait_time: 3000,
+                worker: String::from("Inference_1"),
+            }],
+            routes: vec![Route {
+                endpoint: String::from("/inference"),
+                workers: vec![String::from("Inference_1")],
+                mime: String::from("application/json"),
+                is_sse: false,
+                request_body: None,
+                responses: None,
+                schemas: None,
+            }],
+        }
+    }
+}
diff --git a/src/coordinator.rs b/src/coordinator.rs
deleted file mode 100644
index 26b963ba..00000000
--- a/src/coordinator.rs
+++ /dev/null
@@ -1,121 +0,0 @@
-// Copyright 2022 MOSEC Authors
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
- -use std::fs; -use std::path::Path; -use std::sync::Arc; -use std::time::Duration; - -use async_channel::{bounded, Receiver, Sender}; -use tokio::sync::Barrier; -use tracing::{error, info}; - -use crate::args::Opts; -use crate::metrics::{Metrics, METRICS}; -use crate::protocol::communicate; -use crate::tasks::{TaskManager, TASK_MANAGER}; - -#[derive(Debug)] -pub(crate) struct Coordinator { - capacity: usize, - path: String, - batches: Vec, - wait_time: Vec, - receiver: Receiver, - sender: Sender, -} - -impl Coordinator { - pub(crate) fn init_from_opts(opts: &Opts) -> Self { - // init the global task manager - let (sender, receiver) = bounded(opts.capacity); - let timeout = Duration::from_millis(opts.timeout); - let wait_time = opts - .waits - .iter() - .map(|x| Duration::from_millis(*x)) - .collect(); - let path = if !opts.path.is_empty() { - opts.path.to_string() - } else { - // default IPC path - std::env::temp_dir() - .join(env!("CARGO_PKG_NAME")) - .into_os_string() - .into_string() - .unwrap() - }; - let task_manager = TaskManager::new(timeout, sender.clone()); - TASK_MANAGER.set(task_manager).unwrap(); - let metrics = Metrics::init_with_namespace(&opts.namespace, opts.timeout); - METRICS.set(metrics).unwrap(); - - Self { - capacity: opts.capacity, - path, - batches: opts.batches.clone(), - wait_time, - receiver, - sender, - } - } - - pub(crate) fn run(&self) -> Arc { - let barrier = Arc::new(Barrier::new(self.batches.len() + 1)); - let mut last_receiver = self.receiver.clone(); - let mut last_sender = self.sender.clone(); - let folder = Path::new(&self.path); - if folder.is_dir() { - info!(path=?folder, "socket path already exist, try to remove it"); - fs::remove_dir_all(folder).unwrap(); - } - fs::create_dir(folder).unwrap(); - - for i in 0..self.batches.len() { - let (sender, receiver) = bounded::(self.capacity); - let path = folder.join(format!("ipc_{:?}.socket", i + 1)); - - let batch_size = self.batches[i]; - let wait = self.wait_time[i]; - tokio::spawn(communicate( - path, - batch_size as usize, - wait, - (i + 1).to_string(), - last_receiver.clone(), - sender.clone(), - last_sender.clone(), - barrier.clone(), - )); - last_receiver = receiver; - last_sender = sender; - } - tokio::spawn(finish_task(last_receiver)); - barrier - } -} - -async fn finish_task(receiver: Receiver) { - let task_manager = TaskManager::global(); - loop { - match receiver.recv().await { - Ok(id) => { - task_manager.notify_task_done(id); - } - Err(err) => { - error!(%err, "failed to get the task id when trying to mark it as done"); - } - } - } -} diff --git a/src/main.rs b/src/main.rs index 432ed902..1921d5f0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -13,22 +13,21 @@ // limitations under the License. 
 mod apidoc;
-mod args;
-mod coordinator;
+mod config;
 mod errors;
 mod metrics;
 mod protocol;
 mod routes;
 mod tasks;
 
+use std::env;
 use std::fs::read_to_string;
 use std::net::SocketAddr;
-use std::path::Path;
 
 use axum::routing::{get, post};
 use axum::Router;
 use tokio::signal::unix::{signal, SignalKind};
-use tracing::info;
+use tracing::{debug, info};
 use tracing_subscriber::fmt::time::OffsetTime;
 use tracing_subscriber::prelude::*;
 use tracing_subscriber::{filter, Layer};
@@ -36,12 +35,10 @@
 use utoipa::OpenApi;
 use utoipa_swagger_ui::SwaggerUi;
 
 use crate::apidoc::MosecOpenAPI;
-use crate::args::Opts;
-use crate::coordinator::Coordinator;
-use crate::routes::{index, inference, metrics, sse_inference, AppState, RustAPIDoc};
-use crate::tasks::TaskManager;
-
-const MOSEC_OPENAPI_PATH: &str = "mosec_openapi.json";
+use crate::config::Config;
+use crate::metrics::{Metrics, METRICS};
+use crate::routes::{index, inference, metrics, sse_inference, RustAPIDoc};
+use crate::tasks::{TaskManager, TASK_MANAGER};
 
 async fn shutdown_signal() {
     let mut interrupt = signal(SignalKind::interrupt()).unwrap();
@@ -65,44 +62,60 @@
 }
 
 #[tokio::main]
-async fn run(opts: &Opts) {
-    let python_api =
-        read_to_string(Path::new(&opts.path).join(MOSEC_OPENAPI_PATH)).unwrap_or_default();
+async fn run(conf: &Config) {
     let mut doc = MosecOpenAPI {
         api: RustAPIDoc::openapi(),
     };
-    doc.merge("/inference", python_api.parse().unwrap_or_default())
-        .replace_path_item("/inference", &opts.endpoint);
+    for route in &conf.routes {
+        doc.merge_route(route);
+    }
+    doc.clean();
 
-    let state = AppState {
-        mime: opts.mime.clone(),
-    };
-    let coordinator = Coordinator::init_from_opts(opts);
-    let barrier = coordinator.run();
-    barrier.wait().await;
-    let app = Router::new()
-        .merge(SwaggerUi::new("/api/swagger").url("/api/openapi.json", doc.api))
+    let metrics_instance = Metrics::init_with_namespace(&conf.namespace, conf.timeout);
+    METRICS.set(metrics_instance).unwrap();
+    let mut task_manager = TaskManager::new(conf.timeout);
+    let barrier = task_manager.init_from_config(conf);
+    TASK_MANAGER.set(task_manager).unwrap();
+
+    let mut router = Router::new()
+        .merge(SwaggerUi::new("/openapi/swagger").url("/openapi/metadata.json", doc.api))
         .route("/", get(index))
-        .route("/metrics", get(metrics))
-        .route(&opts.endpoint, post(inference))
-        .route("/sse_inference", post(sse_inference))
-        .with_state(state);
+        .route("/metrics", get(metrics));
 
-    let addr: SocketAddr = format!("{}:{}", opts.address, opts.port).parse().unwrap();
+    for route in &conf.routes {
+        if route.is_sse {
+            router = router.route(&route.endpoint, post(sse_inference));
+        } else {
+            router = router.route(&route.endpoint, post(inference));
+        }
+    }
+
+    // wait until each stage has at least one worker alive
+    barrier.wait().await;
+    let addr: SocketAddr = format!("{}:{}", conf.address, conf.port).parse().unwrap();
     info!(?addr, "http service is running");
     axum::Server::bind(&addr)
-        .serve(app.into_make_service())
+        .serve(router.into_make_service())
         .with_graceful_shutdown(shutdown_signal())
         .await
         .unwrap();
 }
 
 fn main() {
-    let opts: Opts = argh::from_env();
+    let cmd_args: Vec<String> = env::args().collect();
+    if cmd_args.len() != 2 {
+        println!(
+            "expect one argument as the config path but got {:?}",
+            cmd_args
+        );
+        return;
+    }
+    let config_str = read_to_string(&cmd_args[1]).expect("read config file failure");
+    let conf: Config = serde_json::from_str(&config_str).expect("parse config failure");
 
     // this has to be defined before tokio multi-threads
     let timer = OffsetTime::local_rfc_3339().expect("local time offset");
-    if opts.debug || opts.log_level == "debug" {
+    if conf.log_level == "debug" {
         // use colorful log for debug
         let output = tracing_subscriber::fmt::layer().compact().with_timer(timer);
         tracing_subscriber::registry()
@@ -116,7 +129,7 @@
             .init();
     } else {
         // use JSON format for production
-        let level = match opts.log_level.as_str() {
+        let level = match conf.log_level.as_str() {
             "error" => tracing::Level::ERROR,
             "warning" => tracing::Level::WARN,
             _ => tracing::Level::INFO,
@@ -128,6 +141,6 @@
             .init();
     }
 
-    info!(?opts, "parse service arguments");
-    run(&opts);
+    debug!(?conf, "parse service configs");
+    run(&conf);
 }
diff --git a/src/protocol.rs b/src/protocol.rs
index f5e68ab4..8cdca3b0 100644
--- a/src/protocol.rs
+++ b/src/protocol.rs
@@ -16,7 +16,7 @@ use std::path::PathBuf;
 use std::sync::Arc;
 use std::time::{Duration, Instant};
 
-use async_channel::{Receiver, Sender};
+use async_channel::Receiver;
 use bytes::{BufMut, Bytes, BytesMut};
 use tokio::io::{self, AsyncReadExt, AsyncWriteExt};
 use tokio::net::{UnixListener, UnixStream};
@@ -28,6 +28,7 @@ use crate::tasks::{TaskCode, TaskManager};
 
 const FLAG_U8_SIZE: usize = 2;
 const NUM_U8_SIZE: usize = 2;
+const STATE_U8_SIZE: usize = 2;
 const TASK_ID_U8_SIZE: usize = 4;
 const LENGTH_U8_SIZE: usize = 4;
 
@@ -43,20 +44,16 @@ pub(crate) async fn communicate(
     path: PathBuf,
     batch_size: usize,
     wait_time: Duration,
-    stage_id: String,
+    stage_name: String,
     receiver: Receiver<u32>,
-    sender: Sender<u32>,
-    last_sender: Sender<u32>,
     barrier: Arc<Barrier>,
 ) {
     let listener = UnixListener::bind(&path).expect("failed to bind to the socket");
     let mut connection_id: u32 = 0;
     loop {
         connection_id += 1;
-        let sender_clone = sender.clone();
-        let last_sender_clone = last_sender.clone();
         let receiver_clone = receiver.clone();
-        let stage_id_label = stage_id.clone();
+        let stage_name_label = stage_name.clone();
         let connection_id_label = connection_id.to_string();
         info!(?path, "begin listening to socket");
         match listener.accept().await {
@@ -66,15 +63,17 @@
                     let mut code: TaskCode = TaskCode::InternalError;
                     let mut ids: Vec<u32> = Vec::with_capacity(batch_size);
                     let mut data: Vec<Bytes> = Vec::with_capacity(batch_size);
+                    let mut states: Vec<u16> = Vec::with_capacity(batch_size);
                     let task_manager = TaskManager::global();
                     let metrics = Metrics::global();
                     let metric_label = StageConnectionLabel {
-                        stage: stage_id_label.clone(),
+                        stage: stage_name_label.clone(),
                        connection: connection_id_label,
                     };
                     loop {
                         ids.clear();
                         data.clear();
+                        states.clear();
                         let batch_timer =
                             get_batch(&receiver_clone, batch_size, &mut ids, wait_time).await;
                         if let Some(timer) = batch_timer {
@@ -86,7 +85,7 @@
                         // start record the duration metrics here because receiving the first task
                         // depends on when the request comes in.
                         let start_timer = Instant::now();
-                        task_manager.get_multi_tasks_data(&mut ids, &mut data);
+                        task_manager.get_multi_tasks_data(&mut ids, &mut data, &mut states);
                         if data.is_empty() {
                             continue;
                         }
@@ -97,48 +96,54 @@
                                 .get_or_create(&metric_label)
                                 .observe(data.len() as f64);
                         }
-                        if let Err(err) = send_message(&mut stream, &ids, &data).await {
-                            error!(%err, %stage_id_label, %connection_id, "socket send message error");
+                        if let Err(err) = send_message(&mut stream, &ids, &data, &states).await {
+                            error!(%err, %stage_name_label, %connection_id, "socket send message error");
                             info!(
                                 "service failed to write data to stream, will try to send task \
                                 back to see if other thread can handle it"
                             );
                             for id in &ids {
-                                last_sender_clone.send(*id).await.expect("sender is closed");
+                                task_manager.send_task(id).await;
                             }
                             break;
                         }
-                        debug!(%stage_id_label, %connection_id, "socket finished to send message");
+                        debug!(%stage_name_label, %connection_id, "socket finished to send message");
 
                         ids.clear();
                         data.clear();
+                        states.clear();
                         if let Err(err) =
-                            read_message(&mut stream, &mut code, &mut ids, &mut data).await
+                            read_message(&mut stream, &mut code, &mut ids, &mut data, &mut states)
+                                .await
                         {
-                            error!(%err, %stage_id_label, %connection_id, "socket receive message error");
+                            error!(%err, %stage_name_label, %connection_id, "socket receive message error");
                             break;
                         }
-                        debug!(%stage_id_label, %connection_id, "socket finished to read message");
+                        debug!(%stage_name_label, %connection_id, "socket finished to read message");
 
                         while code == TaskCode::StreamEvent {
                             send_stream_event(&ids, &data).await;
                             ids.clear();
                             data.clear();
-                            if let Err(err) =
-                                read_message(&mut stream, &mut code, &mut ids, &mut data).await
+                            states.clear();
+                            if let Err(err) = read_message(
+                                &mut stream,
+                                &mut code,
+                                &mut ids,
+                                &mut data,
+                                &mut states,
+                            )
+                            .await
                             {
-                                error!(%err, %stage_id_label, %connection_id, "socket receive message error");
+                                error!(%err, %stage_name_label, %connection_id, "socket receive message error");
                                 break;
                             }
-                            debug!(%stage_id_label, %connection_id, "socket finished to read message");
+                            debug!(%stage_name_label, %connection_id, "socket finished to read message");
                         }
 
                         task_manager.update_multi_tasks(code, &ids, &data).await;
                         match code {
                             TaskCode::Normal => {
                                 for id in &ids {
-                                    sender_clone
-                                        .send(*id)
-                                        .await
-                                        .expect("next channel is closed");
+                                    task_manager.send_task(id).await;
                                 }
 
                                 // only the normal tasks will be recorded metrics
@@ -150,7 +155,7 @@
                                 warn!(
                                     ?ids,
                                     ?code,
-                                    ?stage_id_label,
+                                    ?stage_name_label,
                                     ?connection_id,
                                     "abnormal tasks, check Python log for more details"
                                 );
@@ -164,7 +169,7 @@
                 }
             }
             Err(err) => {
-                error!(%err, %stage_id, %connection_id, "socket failed to accept the connection");
+                error!(%err, %stage_name, %connection_id, "socket failed to accept the connection");
                 break;
             }
         }
@@ -193,6 +198,7 @@ async fn read_message(
     code: &mut TaskCode,
     ids: &mut Vec<u32>,
     data: &mut Vec<Bytes>,
+    states: &mut Vec<u16>,
 ) -> Result<(), io::Error> {
     stream.readable().await?;
     let mut flag_buf = [0u8; FLAG_U8_SIZE];
@@ -219,14 +225,18 @@
     let mut id_buf = [0u8; TASK_ID_U8_SIZE];
     let mut length_buf = [0u8; LENGTH_U8_SIZE];
+    let mut state_buf = [0u8; STATE_U8_SIZE];
     for _ in 0..num {
         stream.read_exact(&mut id_buf).await?;
+        stream.read_exact(&mut state_buf).await?;
         stream.read_exact(&mut length_buf).await?;
         let id = u32::from_be_bytes(id_buf);
+        let state = u16::from_be_bytes(state_buf);
         let length =
u32::from_be_bytes(length_buf); let mut data_buf = vec![0u8; length as usize]; stream.read_exact(&mut data_buf).await?; ids.push(id); + states.push(state); data.push(data_buf.into()); } let byte_size = data.iter().fold(0, |acc, x| acc + x.len()); @@ -279,6 +289,7 @@ async fn send_message( stream: &mut UnixStream, ids: &[u32], data: &[Bytes], + states: &[u16], ) -> Result<(), io::Error> { stream.writable().await?; let mut buffer = BytesMut::new(); @@ -286,6 +297,7 @@ async fn send_message( buffer.put_u16(ids.len() as u16); for i in 0..ids.len() { buffer.put_u32(ids[i]); + buffer.put_u16(states[i]); buffer.put_u32(data[i].len() as u32); buffer.put(data[i].clone()); } @@ -345,13 +357,15 @@ mod tests { let listener = UnixListener::bind(&path).expect("bind error"); let ids = vec![0u32, 1]; let data = vec![Bytes::from_static(b"hello"), Bytes::from_static(b"world")]; + let states = vec![1u16, 2]; // setup the server in another tokio thread let ids_clone = ids.clone(); let data_clone = data.clone(); + let states_clone = states.clone(); tokio::spawn(async move { let (mut stream, _addr) = listener.accept().await.unwrap(); - send_message(&mut stream, &ids_clone, &data_clone) + send_message(&mut stream, &ids_clone, &data_clone, &states_clone) .await .expect("send message error"); tokio::time::sleep(Duration::from_millis(1)).await; @@ -359,13 +373,22 @@ mod tests { let mut stream = UnixStream::connect(&path).await.unwrap(); let mut recv_ids = Vec::new(); + let mut recv_states = Vec::new(); let mut recv_data = Vec::new(); let mut code = TaskCode::InternalError; - read_message(&mut stream, &mut code, &mut recv_ids, &mut recv_data) - .await - .expect("read message error"); + read_message( + &mut stream, + &mut code, + &mut recv_ids, + &mut recv_data, + &mut recv_states, + ) + .await + .expect("read message error"); assert_eq!(recv_ids, ids); assert_eq!(recv_data, data); + assert_eq!(recv_states, states); + std::fs::remove_file(&path).expect("failed to remove the test socket file"); } } diff --git a/src/routes.rs b/src/routes.rs index 694fe8af..465b4394 100644 --- a/src/routes.rs +++ b/src/routes.rs @@ -15,7 +15,7 @@ use std::time::Duration; use axum::body::BoxBody; -use axum::extract::State; +use axum::http::Uri; use axum::response::sse::{Event, KeepAlive, Sse}; use axum::response::IntoResponse; use bytes::Bytes; @@ -34,11 +34,7 @@ const SERVER_INFO: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_V const RESPONSE_DEFAULT: &[u8] = b"MOSEC service"; const RESPONSE_EMPTY: &[u8] = b"no data provided"; const RESPONSE_SHUTDOWN: &[u8] = b"gracefully shutting down"; - -#[derive(Clone)] -pub(crate) struct AppState { - pub mime: String, -} +const DEFAULT_RESPONSE_MIME: &str = "application/json"; fn build_response(status: StatusCode, content: Bytes) -> Response { Response::builder() @@ -92,7 +88,7 @@ pub(crate) async fn metrics(_: Request) -> Response { #[utoipa::path( post, - path = "/inference", + path = "/openapi/reserved/inference", responses( (status = StatusCode::OK, description = "Inference"), (status = StatusCode::BAD_REQUEST, description = "BAD_REQUEST"), @@ -103,8 +99,13 @@ pub(crate) async fn metrics(_: Request) -> Response { (status = StatusCode::TOO_MANY_REQUESTS, description = "TOO_MANY_REQUESTS"), ), )] -pub(crate) async fn inference(State(state): State, req: Request) -> Response { +pub(crate) async fn inference(uri: Uri, req: Request) -> Response { let task_manager = TaskManager::global(); + let endpoint = uri.path(); + let mime = match task_manager.get_mime_type(endpoint) { + Some(mime) 
=> mime.as_str(), + None => DEFAULT_RESPONSE_MIME, + }; let data = to_bytes(req.into_body()).await.unwrap(); if task_manager.is_shutdown() { @@ -121,7 +122,7 @@ pub(crate) async fn inference(State(state): State, req: Request) let (status, content); let metrics = Metrics::global(); metrics.remaining_task.inc(); - match task_manager.submit_task(data).await { + match task_manager.submit_task(data, endpoint).await { Ok(task) => { content = task.data; status = match task.code { @@ -164,14 +165,14 @@ pub(crate) async fn inference(State(state): State, req: Request) let mut resp = build_response(status, content); if status == StatusCode::OK { resp.headers_mut() - .insert(CONTENT_TYPE, HeaderValue::from_str(&state.mime).unwrap()); + .insert(CONTENT_TYPE, HeaderValue::from_str(mime).unwrap()); } resp } #[utoipa::path( post, - path = "/sse_inference", + path = "/openapi/reserved/inference_sse", responses( (status = StatusCode::OK, description = "Inference"), (status = StatusCode::BAD_REQUEST, description = "BAD_REQUEST"), @@ -182,8 +183,9 @@ pub(crate) async fn inference(State(state): State, req: Request) (status = StatusCode::TOO_MANY_REQUESTS, description = "TOO_MANY_REQUESTS"), ), )] -pub(crate) async fn sse_inference(req: Request) -> Response { +pub(crate) async fn sse_inference(uri: Uri, req: Request) -> Response { let task_manager = TaskManager::global(); + let endpoint = uri.path(); let data = to_bytes(req.into_body()).await.unwrap(); if task_manager.is_shutdown() { @@ -199,7 +201,7 @@ pub(crate) async fn sse_inference(req: Request) -> Response { } let metrics = Metrics::global(); - match task_manager.submit_sse_task(data).await { + match task_manager.submit_sse_task(data, endpoint).await { Ok(mut rx) => { let stream = async_stream::stream! { while let Some((msg, code)) = rx.recv().await { diff --git a/src/tasks.rs b/src/tasks.rs index 9f4fa58f..afa1ae01 100644 --- a/src/tasks.rs +++ b/src/tasks.rs @@ -13,19 +13,22 @@ // limitations under the License. use std::collections::HashMap; +use std::path::Path; use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::{Mutex, RwLock}; +use std::sync::{Arc, Mutex}; use std::time::{Duration, Instant}; use bytes::Bytes; use hyper::StatusCode; use once_cell::sync::OnceCell; -use tokio::sync::{mpsc, oneshot}; +use tokio::sync::{mpsc, oneshot, Barrier}; use tokio::time; use tracing::{debug, error, info, warn}; +use crate::config::Config; use crate::errors::ServiceError; use crate::metrics::{CodeLabel, Metrics, DURATION_LABEL}; +use crate::protocol::communicate; #[derive(Debug, Clone, Copy, PartialEq, Eq, derive_more::Display, derive_more::Error)] pub(crate) enum TaskCode { @@ -40,7 +43,7 @@ pub(crate) enum TaskCode { #[display(fmt = "500: Internal Server Error")] InternalError, // special case - #[display(fmt = "500: Internal Server Error")] + #[display(fmt = "200: Server Sent Event")] StreamEvent, } @@ -48,14 +51,18 @@ pub(crate) enum TaskCode { pub(crate) struct Task { pub(crate) code: TaskCode, pub(crate) data: Bytes, + pub(crate) stage: usize, + pub(crate) route: String, pub(crate) create_at: Instant, } impl Task { - fn new(data: Bytes) -> Self { + fn new(data: Bytes, route: String) -> Self { Self { code: TaskCode::InternalError, data, + stage: 0, + route, create_at: Instant::now(), } } @@ -63,17 +70,30 @@ impl Task { fn update(&mut self, code: TaskCode, data: &Bytes) { self.code = code; self.data = data.clone(); + self.stage += 1; + } + + /// Encode the current state of the task into a 16-bit integer. 
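+    /// For example, the only stage of a single-stage route is both the ingress
+    /// and the egress, so its state encodes to 0b11; a middle stage of a longer
+    /// route encodes to 0b00.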
+    /// 0000 0000 0000 00yx
+    /// x: is ingress
+    /// y: is egress
+    fn encode_state(&self, total: usize) -> u16 {
+        let mut state = 0;
+        state |= (self.stage == 0) as u16;
+        state |= ((total - 1 == self.stage) as u16) << 1;
+        state
+    }
 }
 
 #[derive(Debug)]
 pub(crate) struct TaskManager {
-    table: RwLock<HashMap<u32, Task>>,
+    table: Mutex<HashMap<u32, Task>>,
     notifiers: Mutex<HashMap<u32, oneshot::Sender<()>>>,
     stream_senders: Mutex<HashMap<u32, mpsc::Sender<(String, TaskCode)>>>,
     timeout: Duration,
     current_id: Mutex<u32>,
-    channel: async_channel::Sender<u32>,
+    senders: HashMap<String, Vec<async_channel::Sender<u32>>>,
+    mime_types: HashMap<String, String>,
     shutdown: AtomicBool,
 }
 
@@ -84,18 +104,84 @@ impl TaskManager {
         TASK_MANAGER.get().expect("task manager is not initialized")
     }
 
-    pub(crate) fn new(timeout: Duration, channel: async_channel::Sender<u32>) -> Self {
+    pub(crate) fn new(timeout: u64) -> Self {
         Self {
-            table: RwLock::new(HashMap::new()),
+            table: Mutex::new(HashMap::new()),
             notifiers: Mutex::new(HashMap::new()),
             stream_senders: Mutex::new(HashMap::new()),
-            timeout,
+            timeout: Duration::from_millis(timeout),
             current_id: Mutex::new(0),
-            channel,
+            senders: HashMap::new(),
+            mime_types: HashMap::new(),
             shutdown: AtomicBool::new(false),
         }
     }
 
+    pub(crate) fn init_from_config(&mut self, conf: &Config) -> Arc<Barrier> {
+        let barrier = Arc::new(Barrier::new(conf.runtimes.len() + 1));
+
+        let mut worker_channel =
+            HashMap::<String, (async_channel::Receiver<u32>, async_channel::Sender<u32>)>::new();
+        let dir = Path::new(&conf.path);
+
+        // run the coordinator in different threads
+        for runtime in &conf.runtimes {
+            let (tx, rx) = async_channel::bounded::<u32>(conf.capacity);
+            worker_channel.insert(runtime.worker.clone(), (rx.clone(), tx));
+            let path = dir.join(format!("ipc_{}.socket", runtime.worker));
+            tokio::spawn(communicate(
+                path,
+                runtime.max_batch_size,
+                Duration::from_millis(runtime.max_wait_time),
+                runtime.worker.clone(),
+                rx,
+                barrier.clone(),
+            ));
+        }
+
+        for route in &conf.routes {
+            self.mime_types
+                .insert(route.endpoint.clone(), route.mime.clone());
+            let worker_senders = route
+                .workers
+                .iter()
+                .map(|w| worker_channel[w].1.clone())
+                .collect();
+            self.senders.insert(route.endpoint.clone(), worker_senders);
+        }
+
+        barrier
+    }
+
+    pub(crate) fn get_mime_type(&self, endpoint: &str) -> Option<&String> {
+        self.mime_types.get(endpoint)
+    }
+
+    pub(crate) async fn send_task(&self, id: &u32) {
+        let stage: usize;
+        let route: &Vec<async_channel::Sender<u32>>;
+        {
+            let table = self.table.lock().unwrap();
+            match table.get(id) {
+                Some(task) => {
+                    stage = task.stage;
+                    route = &self.senders[&task.route];
+                }
+                None => {
+                    warn!(%id, "failed to get the task when trying to send it");
+                    return;
+                }
+            };
+        }
+        if stage >= route.len() {
+            self.notify_task_done(id);
+            return;
+        }
+        if route[stage].send(*id).await.is_err() {
+            warn!(%id, "failed to send this task, the sender might be closed");
+        }
+    }
+
     pub(crate) async fn shutdown(&self) {
         self.shutdown.store(true, Ordering::Release);
         let fut = time::timeout(self.timeout, async {
@@ -103,7 +189,7 @@
             let mut retry = 0;
             loop {
                 interval.tick().await;
-                let remaining_task_num = self.table.read().unwrap().len();
+                let remaining_task_num = self.table.lock().unwrap().len();
                 if remaining_task_num == 0 {
                     break;
                 }
@@ -118,15 +204,15 @@
         }
     }
 
-    pub(crate) async fn submit_task(&self, data: Bytes) -> Result<Task, ServiceError> {
-        let (id, rx) = self.add_new_task(data)?;
+    pub(crate) async fn submit_task(&self, data: Bytes, key: &str) -> Result<Task, ServiceError> {
+        let (id, rx) = self.add_new_task(data, key)?;
         if let Err(err) = time::timeout(self.timeout, rx).await {
             warn!(%id, %err, "task was not completed in the expected time, if this happens a lot, \
                 you might want to increase the service timeout");
             self.delete_task(id, false);
             return Err(ServiceError::Timeout);
         }
-        let mut table = self.table.write().unwrap();
+        let mut table = self.table.lock().unwrap();
         match table.remove(&id) {
             Some(task) => Ok(task),
             None => {
@@ -139,8 +225,9 @@
     pub(crate) async fn submit_sse_task(
         &self,
         data: Bytes,
+        key: &str,
     ) -> Result<mpsc::Receiver<(String, TaskCode)>, ServiceError> {
-        let (id, rx) = self.add_new_task(data)?;
+        let (id, rx) = self.add_new_task(data, key)?;
         let (sender, receiver) = mpsc::channel(16);
 
         {
@@ -166,7 +253,7 @@
             notifiers.remove(&id);
         }
         {
-            let mut table = self.table.write().unwrap();
+            let mut table = self.table.lock().unwrap();
             task = table.remove(&id);
         }
         if has_stream {
@@ -180,7 +267,11 @@
         self.shutdown.load(Ordering::Acquire)
     }
 
-    fn add_new_task(&self, data: Bytes) -> Result<(u32, oneshot::Receiver<()>), ServiceError> {
+    fn add_new_task(
+        &self,
+        data: Bytes,
+        key: &str,
+    ) -> Result<(u32, oneshot::Receiver<()>), ServiceError> {
         let (tx, rx) = oneshot::channel();
         let id: u32;
         {
@@ -193,12 +284,12 @@
             notifiers.insert(id, tx);
         }
         {
-            let mut table = self.table.write().unwrap();
-            table.insert(id, Task::new(data));
+            let mut table = self.table.lock().unwrap();
+            table.insert(id, Task::new(data, key.to_string()));
         }
         debug!(%id, "add a new task");
 
-        if self.channel.try_send(id).is_err() {
+        if self.senders[key][0].try_send(id).is_err() {
             warn!(%id, "reach the capacity limit, will delete this task");
             self.delete_task(id, false);
             return Err(ServiceError::TooManyRequests);
@@ -206,11 +297,11 @@
         Ok((id, rx))
     }
 
-    pub(crate) fn notify_task_done(&self, id: u32) {
+    pub(crate) fn notify_task_done(&self, id: &u32) {
         let res;
         {
             let mut notifiers = self.notifiers.lock().unwrap();
-            res = notifiers.remove(&id);
+            res = notifiers.remove(id);
         }
         if let Some(sender) = res {
             if !sender.is_closed() {
@@ -219,8 +310,8 @@
                 warn!(%id, "the task notifier is already closed, will delete it \
                     (this is usually because the client side has closed the connection)");
                 {
-                    let mut table = self.table.write().unwrap();
-                    table.remove(&id);
+                    let mut table = self.table.lock().unwrap();
+                    table.remove(id);
                 }
                 let metrics = Metrics::global();
                 metrics.remaining_task.dec();
@@ -231,12 +322,18 @@
         }
     }
 
-    pub(crate) fn get_multi_tasks_data(&self, ids: &mut Vec<u32>, data: &mut Vec<Bytes>) {
-        let table = self.table.read().unwrap();
+    pub(crate) fn get_multi_tasks_data(
+        &self,
+        ids: &mut Vec<u32>,
+        data: &mut Vec<Bytes>,
+        states: &mut Vec<u16>,
+    ) {
+        let table = self.table.lock().unwrap();
         // delete the task_id if the task_id doesn't exist in the table
         ids.retain(|&id| match table.get(&id) {
             Some(task) => {
                 data.push(task.data.clone());
+                states.push(task.encode_state(self.senders[&task.route].len()));
                 true
             }
             None => false,
         });
@@ -248,17 +345,14 @@
         {
             // make sure the table lock is released since the next func call may need to acquire
             // the notifiers lock, we'd better only hold one lock at a time
-            let mut table = self.table.write().unwrap();
+            let mut table = self.table.lock().unwrap();
             for i in 0..ids.len() {
                 let task = table.get_mut(&ids[i]);
                 match task {
                     Some(task) => {
                         task.update(code, &data[i]);
-                        match code {
-                            TaskCode::Normal => {}
-                            _ => {
-                                abnormal_tasks.push(ids[i]);
-                            }
+                        if code != TaskCode::Normal {
+                            abnormal_tasks.push(ids[i]);
                         }
                     }
                     None => {
@@ -279,7 +373,7 @@
                 }
             }
         }
-        for task_id in abnormal_tasks {
+        for task_id in &abnormal_tasks {
            self.notify_task_done(task_id);
        }
    }
@@ -312,10
+406,12 @@ async fn wait_sse_finish(id: u32, timeout: Duration, notifier: oneshot::Receiver mod tests { use super::*; + const DEFAULT_ENDPOINT: &str = "/inference"; + #[test] fn create_and_update_task() { let now = Instant::now(); - let mut task = Task::new(Bytes::from_static(b"hello")); + let mut task = Task::new(Bytes::from_static(b"hello"), "".to_string()); assert!(task.create_at > now); assert!(task.create_at < Instant::now()); assert!(matches!(task.code, TaskCode::InternalError)); @@ -328,70 +424,82 @@ mod tests { #[tokio::test] async fn task_manager_add_new_task() { - let (tx, rx) = async_channel::bounded(1); - let task_manager = TaskManager::new(Duration::from_secs(1), tx); + let mut task_manager = TaskManager::new(1000); + task_manager.init_from_config(&Config::default()); let (id, _rx) = task_manager - .add_new_task(Bytes::from_static(b"hello")) + .add_new_task(Bytes::from_static(b"hello"), DEFAULT_ENDPOINT) .unwrap(); assert_eq!(id, 0); { - let table = task_manager.table.read().unwrap(); + let table = task_manager.table.lock().unwrap(); let task = table.get(&id).unwrap(); assert_eq!(task.data, Bytes::from_static(b"hello")); } - let recv_id = rx.recv().await.unwrap(); - assert_eq!(recv_id, id); // add a new task let (id, _rx) = task_manager - .add_new_task(Bytes::from_static(b"world")) + .add_new_task(Bytes::from_static(b"world"), DEFAULT_ENDPOINT) .unwrap(); assert_eq!(id, 1); { - let table = task_manager.table.read().unwrap(); + let table = task_manager.table.lock().unwrap(); let task = table.get(&id).unwrap(); assert_eq!(task.data, Bytes::from_static(b"world")); } - let recv_id = rx.recv().await.unwrap(); - assert_eq!(recv_id, id); } #[tokio::test] async fn task_manager_timeout() { - let (tx, _rx) = async_channel::bounded(1); - let task_manager = TaskManager::new(Duration::from_millis(1), tx); + let mut task_manager = TaskManager::new(1); + task_manager.init_from_config(&Config::default()); // wait until this task timeout - let res = task_manager.submit_task(Bytes::from_static(b"hello")).await; + let res = task_manager + .submit_task(Bytes::from_static(b"hello"), DEFAULT_ENDPOINT) + .await; assert!(matches!(res.unwrap_err(), ServiceError::Timeout)); } #[tokio::test] async fn task_manager_too_many_request() { - let (tx, _rx) = async_channel::bounded(1); - // push one task into the channel to make the channel full - let _ = tx.send(0u32).await; - let task_manager = TaskManager::new(Duration::from_millis(1), tx); + let mut task_manager = TaskManager::new(1); + let mut config = Config::default(); + // capacity > 0 + config.capacity = 1; + task_manager.init_from_config(&config); + // send one task id to block the channel + task_manager.senders[DEFAULT_ENDPOINT][0] + .send(0) + .await + .unwrap(); - // trigger too many request since the capacity is 0 - let res = task_manager.submit_task(Bytes::from_static(b"hello")).await; + // trigger too many request since the capacity is 1 + let res = task_manager + .submit_task(Bytes::from_static(b"hello"), DEFAULT_ENDPOINT) + .await; assert!(matches!(res.unwrap_err(), ServiceError::TooManyRequests)); } #[tokio::test] async fn task_manager_graceful_shutdown() { - let (tx, _rx) = async_channel::bounded(1); - let task_manager = TaskManager::new(Duration::from_millis(1), tx); + let mut task_manager = TaskManager::new(1); + task_manager.init_from_config(&Config::default()); assert!(!task_manager.is_shutdown()); task_manager.shutdown().await; assert!(task_manager.is_shutdown()); + } - let (tx, _rx) = async_channel::bounded(1); - let task_manager = 
TaskManager::new(Duration::from_millis(10), tx); + #[tokio::test] + async fn task_manager_graceful_shutdown_after_timeout() { + let mut task_manager = TaskManager::new(10); + task_manager.init_from_config(&Config::default()); { // block with one task in the channel - let mut table = task_manager.table.write().unwrap(); - table.insert(0u32, Task::new(Bytes::from_static(b"hello"))); + let mut table = task_manager.table.lock().unwrap(); + table.insert( + 0u32, + Task::new(Bytes::from_static(b"hello"), DEFAULT_ENDPOINT.to_string()), + ); } assert!(!task_manager.is_shutdown()); let now = Instant::now(); @@ -403,24 +511,32 @@ mod tests { #[tokio::test] async fn task_manager_get_and_update_task() { - let (tx, _rx) = async_channel::bounded(1); - let task_manager = TaskManager::new(Duration::from_millis(1), tx); + let mut task_manager = TaskManager::new(1); + task_manager.init_from_config(&Config::default()); // add some tasks to the table { - let mut table = task_manager.table.write().unwrap(); - table.insert(0, Task::new(Bytes::from_static(b"hello"))); - table.insert(1, Task::new(Bytes::from_static(b"world"))); + let mut table = task_manager.table.lock().unwrap(); + table.insert( + 0, + Task::new(Bytes::from_static(b"hello"), DEFAULT_ENDPOINT.to_string()), + ); + table.insert( + 1, + Task::new(Bytes::from_static(b"world"), DEFAULT_ENDPOINT.to_string()), + ); } let mut task_ids = vec![0, 1, 2]; let mut data = Vec::new(); - task_manager.get_multi_tasks_data(&mut task_ids, &mut data); + let mut states = Vec::new(); + task_manager.get_multi_tasks_data(&mut task_ids, &mut data, &mut states); assert_eq!(task_ids, vec![0, 1]); assert_eq!( data, vec![Bytes::from_static(b"hello"), Bytes::from_static(b"world")] ); + assert_eq!(states, vec![3 as u16, 3 as u16]); // update tasks data = vec![Bytes::from_static(b"rust"), Bytes::from_static(b"tokio")]; @@ -428,7 +544,7 @@ mod tests { .update_multi_tasks(TaskCode::Normal, &task_ids, &data) .await; let mut new_data = Vec::new(); - task_manager.get_multi_tasks_data(&mut task_ids, &mut new_data); + task_manager.get_multi_tasks_data(&mut task_ids, &mut new_data, &mut states); assert_eq!(task_ids, vec![0, 1]); assert_eq!( new_data, diff --git a/tests/services/__init__.py b/tests/services/__init__.py index e69de29b..c5e6244f 100644 --- a/tests/services/__init__.py +++ b/tests/services/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 MOSEC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/services/multi_route_service.py b/tests/services/multi_route_service.py new file mode 100644 index 00000000..ff6f7723 --- /dev/null +++ b/tests/services/multi_route_service.py @@ -0,0 +1,79 @@ +# Copyright 2023 MOSEC Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Test multi-route service."""
+
+from typing import Any
+
+from msgspec import Struct
+
+from mosec import Runtime, Server, Worker
+from mosec.mixin import TypedMsgPackMixin
+
+
+class Request(Struct):
+    """User request struct."""
+
+    # pylint: disable=too-few-public-methods
+
+    bin: bytes
+    name: str = "test"
+
+
+class TypedPreprocess(TypedMsgPackMixin, Worker):
+    """Dummy preprocess to exit early if the validation fails."""
+
+    def forward(self, data: Request) -> Any:
+        """Input will be parsed as the `Request`."""
+        print(f"received from {data.name} with {data.bin!r}")
+        return data.bin
+
+
+class Preprocess(Worker):
+    """Dummy preprocess worker."""
+
+    def deserialize(self, data: bytes) -> Any:
+        return data
+
+    def forward(self, data: Any) -> Any:
+        return data
+
+
+class Inference(Worker):
+    """Dummy inference worker."""
+
+    def forward(self, data: Any) -> Any:
+        return [{"length": len(datum)} for datum in data]
+
+
+class TypedPostprocess(TypedMsgPackMixin, Worker):
+    """Dummy postprocess with msgpack."""
+
+    def forward(self, data: Any) -> Any:
+        return data
+
+
+if __name__ == "__main__":
+    server = Server()
+    typed_pre = Runtime(TypedPreprocess)
+    pre = Runtime(Preprocess)
+    inf = Runtime(Inference, max_batch_size=16)
+    typed_post = Runtime(TypedPostprocess)
+    server.register_runtime(
+        {
+            "/v1/inference": [typed_pre, inf, typed_post],
+            "/inference": [pre, inf],
+        }
+    )
+    server.run()
diff --git a/tests/services/openapi_service.py b/tests/services/openapi_service.py
index 164e9ba0..7480840e 100644
--- a/tests/services/openapi_service.py
+++ b/tests/services/openapi_service.py
@@ -76,8 +76,10 @@ def forward(self, data):
         "UntypedInference": UntypedInference,
     }
 
-    server = Server(endpoint="/v1/inference")
+    server = Server()
     preprocess_worker, inference_worker = sys.argv[1].split("/")
-    server.append_worker(worker_mapping[preprocess_worker])
-    server.append_worker(worker_mapping[inference_worker], max_batch_size=16)
+    server.append_worker(worker_mapping[preprocess_worker], route="/v1/inference")
+    server.append_worker(
+        worker_mapping[inference_worker], max_batch_size=16, route="/v1/inference"
+    )
     server.run()
diff --git a/tests/services/square_service.py b/tests/services/square_service.py
index 6f044d79..d1992eaf 100644
--- a/tests/services/square_service.py
+++ b/tests/services/square_service.py
@@ -30,6 +30,6 @@ def forward(self, data: List[dict]) -> List[dict]:
 
 
 if __name__ == "__main__":
-    server = Server(endpoint="/v1/inference")
+    server = Server()
     server.append_worker(SquareService, max_batch_size=8)
     server.run()
diff --git a/tests/test_coordinator.py b/tests/test_coordinator.py
index e05e7bf6..50d99e48 100644
--- a/tests/test_coordinator.py
+++ b/tests/test_coordinator.py
@@ -32,14 +32,13 @@
 import msgpack  # type: ignore
 import pytest
 
-from mosec.coordinator import PROTOCOL_TIMEOUT, STAGE_EGRESS, STAGE_INGRESS, Coordinator
+from mosec.coordinator import PROTOCOL_TIMEOUT, Coordinator, State
 from mosec.mixin import MsgpackMixin
 from mosec.protocol import HTTPStautsCode, _recv_all
 from mosec.worker import Worker
 from tests.utils import imitate_controller_send
SOCKET_PREFIX = join(tempfile.gettempdir(), "test-mosec") -STAGE = STAGE_INGRESS + STAGE_EGRESS logger = logging.getLogger() @@ -81,21 +80,21 @@ def base_test_config(): } -def test_coordinator_worker_property(): +def test_coordinator_worker_property(mocker): + mocker.patch("mosec.coordinator.CONN_MAX_RETRY", 5) + mocker.patch("mosec.coordinator.CONN_CHECK_INTERVAL", 0.01) ctx = "spawn" coordinator = Coordinator( EchoWorkerJSON, max_batch_size=16, - stage=STAGE_EGRESS, shutdown=mp.get_context(ctx).Event(), shutdown_notify=mp.get_context(ctx).Event(), - socket_prefix="", - stage_id=2, + socket_prefix=SOCKET_PREFIX, + stage_name=EchoWorkerJSON.__name__, worker_id=3, - ipc_wrapper=None, timeout=3, ) - assert coordinator.worker.stage == STAGE_EGRESS + assert coordinator.worker.stage == EchoWorkerJSON.__name__ assert coordinator.worker.worker_id == 3 assert coordinator.worker.max_batch_size == 16 @@ -103,15 +102,13 @@ def test_coordinator_worker_property(): def make_coordinator(w_cls, shutdown, shutdown_notify, config): return Coordinator( w_cls, - config["max_batch_size"], - STAGE, - shutdown, - shutdown_notify, - SOCKET_PREFIX, - config["stage_id"], - config["worker_id"], - None, - config["timeout"], + max_batch_size=config["max_batch_size"], + shutdown=shutdown, + shutdown_notify=shutdown_notify, + socket_prefix=SOCKET_PREFIX, + stage_name=f"{w_cls.__name__}_{config['stage_id']}", + worker_id=config["worker_id"], + timeout=config["timeout"], ) @@ -146,7 +143,7 @@ def test_incorrect_socket_file(mocker, base_test_config, caplog): mocker.patch("mosec.coordinator.CONN_MAX_RETRY", 5) mocker.patch("mosec.coordinator.CONN_CHECK_INTERVAL", 0.01) - sock_addr = join(SOCKET_PREFIX, f"ipc_{base_test_config.get('stage_id')}.socket") + sock_addr = join(SOCKET_PREFIX, f"ipc_{EchoWorkerJSON.__name__}_1.socket") c_ctx = base_test_config.pop("c_ctx") shutdown = mp.get_context(c_ctx).Event() shutdown_notify = mp.get_context(c_ctx).Event() @@ -170,7 +167,6 @@ def test_incorrect_socket_file(mocker, base_test_config, caplog): sock.bind(sock_addr) with caplog.at_level(logging.ERROR): - # with pytest.raises(RuntimeError, match=r".*Connection refused.*"): _ = make_coordinator( EchoWorkerJSON, shutdown, shutdown_notify, base_test_config ) @@ -216,7 +212,7 @@ def test_echo_batch(base_test_config, test_data, worker, deserializer): # knows this stage enables batching base_test_config["max_batch_size"] = 8 - sock_addr = join(SOCKET_PREFIX, f"ipc_{base_test_config.get('stage_id')}.socket") + sock_addr = join(SOCKET_PREFIX, f"ipc_{worker.__name__}_1.socket") shutdown = mp.get_context(c_ctx).Event() shutdown_notify = mp.get_context(c_ctx).Event() @@ -246,14 +242,17 @@ def test_echo_batch(base_test_config, test_data, worker, deserializer): got_flag = struct.unpack("!H", conn.recv(2))[0] got_batch_size = struct.unpack("!H", conn.recv(2))[0] got_ids = [] + got_states = [] got_payloads = [] while got_batch_size > 0: got_batch_size -= 1 got_ids.append(conn.recv(4)) + got_states.append(struct.unpack("!H", conn.recv(2))[0]) got_length = struct.unpack("!I", conn.recv(4))[0] got_payloads.append(_recv_all(conn, got_length)) assert got_flag == HTTPStautsCode.OK assert got_ids == sent_ids + assert got_states == [State.INGRESS | State.EGRESS] * len(sent_ids) assert all( deserializer(x) == deserializer(y) for x, y in zip(got_payloads, sent_payloads) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 87a9358e..844c4c16 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -22,17 +22,18 @@ import pytest +from 
mosec.coordinator import State from mosec.protocol import Protocol from tests.mock_socket import Socket from tests.utils import imitate_controller_send -def echo(protocol: Protocol, datum: List[bytes]): - sent_status = random.choice([1, 2, 4, 8]) +def echo(protocol: Protocol, data: List[bytes]): + sent_flag = random.choice([1, 2, 4, 8]) - sent_ids, sent_payloads = imitate_controller_send(protocol.socket, datum) + sent_ids, sent_payloads = imitate_controller_send(protocol.socket, data) - _, got_ids, got_payloads = protocol.receive() # client recv + _, got_ids, got_states, got_payloads = protocol.receive() # client recv assert len(protocol.socket.buffer) == 0 # type: ignore assert got_ids == sent_ids assert all( @@ -40,12 +41,13 @@ def echo(protocol: Protocol, datum: List[bytes]): ) got_payload_bytes = [bytes(x) for x in got_payloads] # client echo - protocol.send(sent_status, got_ids, got_payload_bytes) + protocol.send(sent_flag, got_ids, got_states, got_payload_bytes) # server recv (symmetric protocol) - got_status, got_ids, got_payloads = protocol.receive() + got_flag, got_ids, got_states, got_payloads = protocol.receive() assert len(protocol.socket.buffer) == 0 # type: ignore - assert struct.unpack("!H", got_status)[0] == sent_status + assert struct.unpack("!H", got_flag)[0] == sent_flag + assert got_states == [State.INGRESS | State.EGRESS] * len(sent_ids) assert got_ids == sent_ids assert all( bytes(got_payloads[i]) == sent_payloads[i] for i in range(len(sent_payloads)) diff --git a/tests/test_service.py b/tests/test_service.py index d06fe145..2f6169f8 100644 --- a/tests/test_service.py +++ b/tests/test_service.py @@ -71,11 +71,11 @@ def test_square_service(mosec_service, http_client): resp = http_client.get("/metrics") assert resp.status_code == HTTPStatus.OK - resp = http_client.post("/v1/inference", json={"msg": 2}) + resp = http_client.post("/inference", json={"msg": 2}) assert resp.status_code == HTTPStatus.UNPROCESSABLE_ENTITY assert resp.text == "request validation error: 'x'" - resp = http_client.post("/v1/inference", content=b"bad-binary-request") + resp = http_client.post("/inference", content=b"bad-binary-request") assert resp.status_code == HTTPStatus.BAD_REQUEST validate_square_service(http_client, 2) @@ -184,7 +184,7 @@ def test_mixin_typed_service(mosec_service, http_client): def test_sse_service(mosec_service, http_client): count = 0 with connect_sse( - http_client, "POST", "/sse_inference", json={"text": "mosec"} + http_client, "POST", "/inference", json={"text": "mosec"} ) as event_source: for sse in event_source.iter_sse(): count += 1 @@ -194,7 +194,7 @@ def test_sse_service(mosec_service, http_client): count = 0 with connect_sse( - http_client, "POST", "/sse_inference", json={"bad": "req"} + http_client, "POST", "/inference", json={"bad": "req"} ) as event_source: for sse in event_source.iter_sse(): count += 1 @@ -229,7 +229,7 @@ def test_square_service_mp(mosec_service, http_client): def validate_square_service(http_client, num): - resp = http_client.post("/v1/inference", json={"x": num}) + resp = http_client.post("/inference", json={"x": num}) assert resp.status_code == HTTPStatus.OK assert resp.json()["x"] == num**2 @@ -280,7 +280,7 @@ def assert_empty_queue(http_client): indirect=["mosec_service", "http_client"], ) def test_openapi_service(mosec_service, http_client, args): - spec = http_client.get("/api/openapi.json").json() + spec = http_client.get("/openapi/metadata.json").json() input_cls, return_cls = args.split("/") path_item = 
spec["paths"]["/v1/inference"]["post"] @@ -298,3 +298,30 @@ def test_openapi_service(mosec_service, http_client, args): assert path_item["responses"]["200"]["content"] == want else: assert "content" not in path_item["responses"]["200"] + + +@pytest.mark.parametrize( + "mosec_service, http_client", + [ + pytest.param("multi_route_service", "", id="multi-route"), + ], + indirect=["mosec_service", "http_client"], +) +def test_multi_route_service(mosec_service, http_client): + data = b"mosec" + req = { + "name": "mosec-test", + "bin": data, + } + + # test /inference + resp = http_client.post("/inference", content=data) + assert resp.status_code == HTTPStatus.OK, resp + assert resp.headers["content-type"] == "application/json" + assert resp.json() == {"length": len(data)} + + # test /v1/inference + resp = http_client.post("/v1/inference", content=msgpack.packb(req)) + assert resp.status_code == HTTPStatus.OK, resp + assert resp.headers["content-type"] == "application/msgpack" + assert msgpack.unpackb(resp.content) == {"length": len(data)} diff --git a/tests/utils.py b/tests/utils.py index 59ab558b..134c563a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -22,29 +22,37 @@ import socket import struct import time +from http import HTTPStatus +from io import BytesIO from typing import TYPE_CHECKING, List, Tuple, Union +from mosec.coordinator import State + if TYPE_CHECKING: from tests.mock_socket import Socket as mock_socket def imitate_controller_send( - sock: Union[mock_socket, socket.socket], l_data: List[bytes] + sock: Union[mock_socket, socket.socket], data: List[bytes] ) -> Tuple[List[bytes], List[bytes]]: # explicit byte format here for sanity check # placeholder flag, should be discarded by receiver - header = struct.pack("!H", 0) + struct.pack("!H", len(l_data)) - body = b"" + header = struct.pack("!HH", HTTPStatus.OK, len(data)) + buf = BytesIO() + buf.write(header) sent_ids = [] sent_payloads = [] - for data in l_data: + for datum in data: tid = struct.pack("!I", random.randint(1, 100)) sent_ids.append(tid) - sent_payloads.append(data) - length = struct.pack("!I", len(data)) - body += tid + length + data + sent_payloads.append(datum) + length = struct.pack("!I", len(datum)) + buf.write(tid) + buf.write(struct.pack("!H", State.INGRESS | State.EGRESS)) # task state + buf.write(length) + buf.write(datum) - sock.sendall(header + body) # type: ignore + sock.sendall(buf.getbuffer()) # type: ignore return sent_ids, sent_payloads