diff --git a/mosec/mixin/typed_worker.py b/mosec/mixin/typed_worker.py index 70ece89a..6e7fe649 100644 --- a/mosec/mixin/typed_worker.py +++ b/mosec/mixin/typed_worker.py @@ -36,18 +36,18 @@ class TypedMsgPackMixin(Worker): # pylint: disable=no-self-use resp_mime_type = "application/msgpack" - input_typ: Optional[type] = None + _input_typ: Optional[type] = None def deserialize(self, data: Any) -> Any: """Deserialize and validate request with msgspec.""" - if not self.input_typ: - self.input_typ = parse_func_type(self.forward, ParseTarget.INPUT) - if not issubclass(self.input_typ, msgspec.Struct): + if not self._input_typ: + self._input_typ = parse_func_type(self.forward, ParseTarget.INPUT) + if not issubclass(self._input_typ, msgspec.Struct): # skip other annotation type return super().deserialize(data) try: - return msgspec.msgpack.decode(data, type=self.input_typ) + return msgspec.msgpack.decode(data, type=self._input_typ) except msgspec.ValidationError as err: raise ValidationError(err) # pylint: disable=raise-missing-from @@ -62,7 +62,7 @@ def get_forward_json_schema( """Get the JSON schema of the forward function.""" schema: Dict[str, Any] comp_schema: Dict[str, Any] - (schema, comp_schema) = ({}, {}) + schema, comp_schema = {}, {} typ = parse_func_type(cls.forward, target) try: (schema,), comp_schema = msgspec.json.schema_components([typ], ref_template) @@ -70,4 +70,4 @@ def get_forward_json_schema( logger.warning( "Failed to generate JSON schema for %s: %s", cls.__name__, err ) - return (schema, comp_schema) + return schema, comp_schema diff --git a/mosec/worker.py b/mosec/worker.py index 19b7e8af..c71d652e 100644 --- a/mosec/worker.py +++ b/mosec/worker.py @@ -204,17 +204,20 @@ def get_forward_json_schema( Returns: A tuple containing the schema and the component schemas. - The `get_forward_json_schema` method is a class method that returns the - JSON schema for the `forward` method of the given class `cls`. It takes - a `target` argument specifying the target to parse the schema for. + The :py:meth:`get_forward_json_schema` method is a class method that returns the + JSON schema for the :py:meth:`forward` method of the :py:class:`cls` class. + It takes a :py:obj:`target` param specifying the target to parse the schema for. The returned value is a tuple containing the schema and the component schema. - Note: + .. note:: + Developer must implement this function to retrieve the JSON schema to enable openapi spec. - The `MOSEC_REF_TEMPLATE` constant should be used as a reference template - according to openapi standards. + .. note:: + + The :py:const:`MOSEC_REF_TEMPLATE` constant should be used as a reference + template according to openapi standards. """ - return ({}, {}) + return {}, {} diff --git a/src/apidoc.rs b/src/apidoc.rs index 16b02369..9f72da25 100644 --- a/src/apidoc.rs +++ b/src/apidoc.rs @@ -21,7 +21,7 @@ use utoipa::openapi::{ }; #[derive(Deserialize, Default)] -pub(crate) struct PythonApiDoc { +pub(crate) struct PythonAPIDoc { #[serde(skip_serializing_if = "Option::is_none", default)] request_body: Option, #[serde(skip_serializing_if = "Option::is_none", default)] @@ -30,10 +30,10 @@ pub(crate) struct PythonApiDoc { schemas: Option>>, } -impl FromStr for PythonApiDoc { +impl FromStr for PythonAPIDoc { type Err = serde_json::Error; fn from_str(s: &str) -> Result { - serde_json::from_str::(s) + serde_json::from_str::(s) } } @@ -75,8 +75,8 @@ impl MosecApiDoc { &mut op.responses } - pub fn merge(&self, route: &str, python_api: PythonApiDoc) -> Self { - // merge PythonApiDoc of target route to mosec api + pub fn merge(&self, route: &str, python_api: PythonAPIDoc) -> Self { + // merge PythonAPIDoc of target route to mosec api let mut api = self.api.clone(); if let Some(mut other_schemas) = python_api.schemas { @@ -105,8 +105,10 @@ impl MosecApiDoc { MosecApiDoc { api } } - pub fn move_path(&self, from: &str, to: &str) -> Self { - // move one path to another + pub fn replace_path_item(&self, from: &str, to: &str) -> Self { + // replace path item from path `from` to path `to` + // e.g. /inference -> /v1/inference + // because utoipa_gen::proc_macro::OpenApi can't handle variable path let mut api = self.api.clone(); if let Some(r) = api.paths.paths.remove(from) { api.paths.paths.insert(to.to_owned(), r); diff --git a/src/main.rs b/src/main.rs index 72250ec2..0b11874e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -108,7 +108,7 @@ async fn metrics(_: Request) -> Response { get, path = "/openapi", responses( - (status = StatusCode::OK, description = "Get Openapi Doc",body=String) + (status = StatusCode::OK, description = "Get OpenAPI doc",body=String) ) )] async fn openapi(_: Request, doc: openapi::OpenApi) -> Response { @@ -269,7 +269,7 @@ async fn run(opts: &Opts) { api: RustApiDoc::openapi(), } .merge("/inference", python_api.parse().unwrap_or_default()) - .move_path("/inference", &opts.endpoint); + .replace_path_item("/inference", &opts.endpoint); let state = AppState { mime: opts.mime.clone(), diff --git a/tests/openapi_service.py b/tests/openapi_service.py index c33fd401..5fc30f2b 100644 --- a/tests/openapi_service.py +++ b/tests/openapi_service.py @@ -76,6 +76,7 @@ def forward(self, data): } server = Server(endpoint="/v1/inference") - for w in sys.argv[1].split("/"): - server.append_worker(worker_mapping[w]) + preprocess_worker, inference_worker = sys.argv[1].split("/") + server.append_worker(worker_mapping[preprocess_worker]) + server.append_worker(worker_mapping[inference_worker], max_batch_size=16) server.run() diff --git a/tests/test_service.py b/tests/test_service.py index 03d1e1e4..de72c10f 100644 --- a/tests/test_service.py +++ b/tests/test_service.py @@ -45,7 +45,7 @@ def mosec_service(request): shlex.split(f"python -u tests/{name}.py {args} --port {TEST_PORT}"), ) assert wait_for_port_open(port=TEST_PORT), "service failed to start" - yield (service, args) + yield service service.terminate() time.sleep(2) # wait for service to stop @@ -181,34 +181,37 @@ def assert_empty_queue(http_client): @pytest.mark.parametrize( - "mosec_service, http_client", + "mosec_service, http_client, args", [ pytest.param( "openapi_service TypedPreprocess/TypedInference", "", + "TypedPreprocess/TypedInference", id="TypedPreprocess/TypedInference", ), pytest.param( "openapi_service UntypedPreprocess/TypedInference", "", + "UntypedPreprocess/TypedInference", id="UntypedPreprocess/TypedInference", ), pytest.param( "openapi_service TypedPreprocess/UntypedInference", "", + "TypedPreprocess/UntypedInference", id="TypedPreprocess/UntypedInference", ), pytest.param( "openapi_service UntypedPreprocess/UntypedInference", "", + "UntypedPreprocess/UntypedInference", id="UntypedPreprocess/UntypedInference", ), ], indirect=["mosec_service", "http_client"], ) -def test_openapi_service(mosec_service, http_client): +def test_openapi_service(mosec_service, http_client, args): spec = http_client.get("/openapi").json() - (_, args) = mosec_service input_cls, return_cls = args.split("/") path_item = spec["paths"]["/v1/inference"]["post"]