Skip to content

Commit

Permalink
feat: provide openapi doc (#370)
Browse files Browse the repository at this point in the history
* feat: provide openapi doc

Signed-off-by: hang lv <xlv20@fudan.edu.cn>

* feat: change file default location

Signed-off-by: hang lv <xlv20@fudan.edu.cn>

* fix: sequence

Signed-off-by: hang lv <xlv20@fudan.edu.cn>

* fix: set default to pass test

Signed-off-by: hang lv <xlv20@fudan.edu.cn>

* refactor: typo & remove unnecessary design

Signed-off-by: hang lv <xlv20@fudan.edu.cn>

* feat: typo

Signed-off-by: hang lv <xlv20@fudan.edu.cn>

* feat: Decoupling the schema generation method

Signed-off-by: hang lv <xlv20@fudan.edu.cn>

* refactor: abstract the implementation of JSON schema

Signed-off-by: Xing Lv <xlv20@fudan.edu.cn>

* deps: add deps

Signed-off-by: Xing Lv <xlv20@fudan.edu.cn>

* fix: resolve conflicts

Signed-off-by: Xing Lv <xlv20@fudan.edu.cn>

* refactor: reduce complexity

Signed-off-by: Xing Lv <xlv20@fudan.edu.cn>

* fix: typo

Signed-off-by: hang lv <xlv20@fudan.edu.cn>

* fix: json typo

Signed-off-by: hang lv <xlv20@fudan.edu.cn>

* fix: typo

Signed-off-by: hang lv <xlv20@fudan.edu.cn>

* fix: rust docstring

Signed-off-by: hang lv <xlv20@fudan.edu.cn>

---------

Signed-off-by: hang lv <xlv20@fudan.edu.cn>
Signed-off-by: Xing Lv <xlv20@fudan.edu.cn>
  • Loading branch information
n063h authored Jun 13, 2023
1 parent 080f6e7 commit 3dcac43
Show file tree
Hide file tree
Showing 14 changed files with 651 additions and 52 deletions.
96 changes: 92 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,6 @@ once_cell = "1.18"
prometheus-client = "0.21.1"
argh = "0.1"
axum = "0.6.18"
utoipa = "3.3.0"
serde_json = "1.0.96"
serde = "1.0.163"
4 changes: 2 additions & 2 deletions examples/jax_single_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,15 @@ def __init__(self):
else:
self.batch_forward = self._batch_forward

def _forward(self, x_single: jnp.ndarray) -> jnp.ndarray:
def _forward(self, x_single: jnp.ndarray) -> jnp.ndarray: # type: ignore
chex.assert_rank([x_single], [1])
h_1 = jnp.dot(self._layer1_w.T, x_single) + self._layer1_b
a_1 = jax.nn.relu(h_1)
h_2 = jnp.dot(self._layer2_w.T, a_1) + self._layer2_b
o_2 = jax.nn.softmax(h_2)
return jnp.argmax(o_2, axis=-1)

def _batch_forward(self, x_batch: jnp.ndarray) -> jnp.ndarray:
def _batch_forward(self, x_batch: jnp.ndarray) -> jnp.ndarray: # type: ignore
chex.assert_rank([x_batch], [2])
return jax.vmap(self._forward)(x_batch)

Expand Down
2 changes: 1 addition & 1 deletion examples/type_validation/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def forward(self, data: Request) -> Any:
class Inference(TypedMsgPackMixin, Worker):
"""Dummy batch inference."""

def forward(self, data: List[bytes]) -> Any:
def forward(self, data: List[bytes]) -> List[int]:
return [len(buf) for buf in data]


Expand Down
65 changes: 27 additions & 38 deletions mosec/mixin/typed_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,43 +14,20 @@

"""MOSEC type validation mixin."""

import inspect
import warnings
from typing import Any, List
from typing import Any, Dict, Optional, Tuple

from mosec import get_logger
from mosec.errors import ValidationError
from mosec.utils import ParseTarget, parse_func_type
from mosec.worker import Worker

try:
    import msgspec  # type: ignore
except ImportError:
    # Only warn so that importing this module does not fail outright when the
    # optional dependency is missing. The package to install is `msgspec`
    # (its `msgspec.msgpack` codec is what the mixin calls), not `msgpack`.
    warnings.warn("msgspec is required for TypedMsgPackMixin", ImportWarning)


def parse_forward_input_type(func):
    """Return the request type annotated on the first parameter of ``func``.

    - single request: the annotation itself is returned
    - batch request: for a list annotation, its single item type is returned

    Raises:
        TypeError: if ``func`` accepts no parameters, if a list annotation
            does not carry exactly one item type, or if the annotation is
            any other generic type.
    """
    parameters = inspect.signature(func).parameters
    if not parameters:
        raise TypeError("`forward` method doesn't have enough(1) parameters")

    annotation = next(iter(parameters.values())).annotation
    generic_origin = getattr(annotation, "__origin__", None)
    if generic_origin is None:
        # no `__origin__`: hand the annotation back untouched
        return annotation
    if generic_origin is not list and generic_origin is not List:
        raise TypeError(f"unsupported type {annotation}")

    # list alias: `func` could be batch inference, so exactly one item type
    # has to be declared
    item_types = getattr(annotation, "__args__", ())
    if len(item_types) != 1:
        raise TypeError(
            "`forward` with dynamic batch should use "
            "`List[Struct]` as the input annotation"
        )
    return item_types[0]
logger = get_logger()


class TypedMsgPackMixin(Worker):
Expand All @@ -59,26 +36,38 @@ class TypedMsgPackMixin(Worker):
# pylint: disable=no-self-use

resp_mime_type = "application/msgpack"
_input_type = None
_input_typ: Optional[type] = None

def _get_input_type(self):
    """Return the input type of ``forward``, parsing it on first use."""
    cached = self._input_type
    if cached is not None:
        return cached
    # first call: parse the annotations once and remember the result
    self._input_type = parse_forward_input_type(self.forward)
    return self._input_type

def deserialize(self, data: Any) -> bytes:
def deserialize(self, data: Any) -> Any:
"""Deserialize and validate request with msgspec."""
schema = self._get_input_type()
if not issubclass(schema, msgspec.Struct):
if not self._input_typ:
self._input_typ = parse_func_type(self.forward, ParseTarget.INPUT)
if not issubclass(self._input_typ, msgspec.Struct):
# skip other annotation type
return super().deserialize(data)

try:
return msgspec.msgpack.decode(data, type=schema)
return msgspec.msgpack.decode(data, type=self._input_typ)
except msgspec.ValidationError as err:
raise ValidationError(err) # pylint: disable=raise-missing-from

def serialize(self, data: Any) -> bytes:
    """Encode ``data`` with the ``msgspec.msgpack`` encoder."""
    encoded = msgspec.msgpack.encode(data)
    return encoded

@classmethod
def get_forward_json_schema(
    cls, target: ParseTarget, ref_template: str
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Get the JSON schema of the forward function.

    Args:
        target: which part of ``forward`` is parsed by ``parse_func_type``.
        ref_template: reference template forwarded to
            ``msgspec.json.schema_components``.

    Returns:
        A ``(schema, components)`` pair. Both are empty dicts when the
        schema generation raises ``TypeError``; a warning is logged then.
    """
    typ = parse_func_type(cls.forward, target)
    try:
        # Pass `ref_template` by keyword: a positional call would raise
        # `TypeError` if the installed msgspec declares it keyword-only, and
        # the handler below would silently turn that into an empty schema.
        (schema,), comp_schema = msgspec.json.schema_components(
            [typ], ref_template=ref_template
        )
    except TypeError as err:
        logger.warning(
            "Failed to generate JSON schema for %s: %s", cls.__name__, err
        )
        return {}, {}
    return schema, comp_schema
48 changes: 47 additions & 1 deletion mosec/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@
:py:meth:`append_worker(num) <Server.append_worker>`.
"""

import json
import multiprocessing as mp
import os
import pathlib
import shutil
import signal
import subprocess
Expand All @@ -47,12 +50,14 @@
from mosec.ipc import IPCWrapper
from mosec.log import get_internal_logger
from mosec.manager import PyRuntimeManager, RsRuntimeManager, Runtime
from mosec.worker import Worker
from mosec.utils import ParseTarget
from mosec.worker import MOSEC_REF_TEMPLATE, Worker

logger = get_internal_logger()


GUARD_CHECK_INTERVAL = 1
MOSEC_OPENAPI_PATH = "mosec_openapi.json"


class Server:
Expand Down Expand Up @@ -210,6 +215,46 @@ def append_worker(
runtime.validate()
self.coordinator_manager.append(runtime)

def _generate_openapi(self):
    """Dump the request/response schemas of the workers to a JSON file.

    The request schema comes from the first entry of
    ``coordinator_manager.workers`` and the response schema from the last
    one. The result is written to ``MOSEC_OPENAPI_PATH`` inside the
    directory given by ``self._configs["path"]``. Nothing is written when
    the worker count is not positive.
    """
    if self.coordinator_manager.worker_count <= 0:
        return

    worker_classes = self.coordinator_manager.workers
    first_cls, last_cls = worker_classes[0], worker_classes[-1]
    request_schema, request_components = first_cls.get_forward_json_schema(
        ParseTarget.INPUT, MOSEC_REF_TEMPLATE
    )
    response_schema, response_components = last_cls.get_forward_json_schema(
        ParseTarget.RETURN, MOSEC_REF_TEMPLATE
    )

    def describe(description, mime, schema):
        # a falsy schema means there is no body to describe
        if schema:
            return {
                "description": description,
                "content": {mime: {"schema": schema}},
            }
        return None

    request_body = describe(
        "Mosec Inference Request Body",
        first_cls.resp_mime_type,
        request_schema,
    )
    if response_schema:
        responses = {
            200: describe(
                "Mosec Inference Result",
                last_cls.resp_mime_type,
                response_schema,
            )
        }
    else:
        responses = None

    # response components override request components on a name clash
    components = dict(request_components)
    components.update(response_components)

    openapi = {
        "request_body": request_body,
        "responses": responses,
        "schemas": components,
    }
    target = pathlib.Path(os.path.join(self._configs["path"], MOSEC_OPENAPI_PATH))
    target.parent.mkdir(parents=True, exist_ok=True)
    with target.open("w", encoding="utf-8") as file:
        json.dump(openapi, file)

def run(self):
"""Start the mosec model server."""
self._validate_server()
Expand All @@ -218,6 +263,7 @@ def run(self):
return

self._handle_signal()
self._generate_openapi()
self._start_controller()
try:
self._manage_coordinators()
Expand Down
22 changes: 22 additions & 0 deletions mosec/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright 2023 MOSEC Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Provide common useful utils to develop MOSEC."""

from mosec.utils.types import ParseTarget, parse_func_type

__all__ = [
"parse_func_type",
"ParseTarget",
]
Loading

0 comments on commit 3dcac43

Please sign in to comment.