Skip to content

Commit

Permalink
fix: typo
Browse files Browse the repository at this point in the history
Signed-off-by: hang lv <xlv20@fudan.edu.cn>
  • Loading branch information
n063h committed Jun 13, 2023
1 parent 9c421b8 commit 60f93ca
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 29 deletions.
14 changes: 7 additions & 7 deletions mosec/mixin/typed_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,18 @@ class TypedMsgPackMixin(Worker):
# pylint: disable=no-self-use

resp_mime_type = "application/msgpack"
input_typ: Optional[type] = None
_input_typ: Optional[type] = None

def deserialize(self, data: Any) -> Any:
"""Deserialize and validate request with msgspec."""
if not self.input_typ:
self.input_typ = parse_func_type(self.forward, ParseTarget.INPUT)
if not issubclass(self.input_typ, msgspec.Struct):
if not self._input_typ:
self._input_typ = parse_func_type(self.forward, ParseTarget.INPUT)
if not issubclass(self._input_typ, msgspec.Struct):
# skip other annotation type
return super().deserialize(data)

try:
return msgspec.msgpack.decode(data, type=self.input_typ)
return msgspec.msgpack.decode(data, type=self._input_typ)
except msgspec.ValidationError as err:
raise ValidationError(err) # pylint: disable=raise-missing-from

Expand All @@ -62,12 +62,12 @@ def get_forward_json_schema(
"""Get the JSON schema of the forward function."""
schema: Dict[str, Any]
comp_schema: Dict[str, Any]
(schema, comp_schema) = ({}, {})
schema, comp_schema = {}, {}
typ = parse_func_type(cls.forward, target)
try:
(schema,), comp_schema = msgspec.json.schema_components([typ], ref_template)
except TypeError as err:
logger.warning(
"Failed to generate JSON schema for %s: %s", cls.__name__, err
)
return (schema, comp_schema)
return schema, comp_schema
17 changes: 10 additions & 7 deletions mosec/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,17 +204,20 @@ def get_forward_json_schema(
Returns:
A tuple containing the schema and the component schemas.
The `get_forward_json_schema` method is a class method that returns the
JSON schema for the `forward` method of the given class `cls`. It takes
a `target` argument specifying the target to parse the schema for.
The :py:meth:`get_forward_json_schema` method is a class method that returns the
JSON schema for the :py:meth:`forward` method of the :py:class:`cls` class.
It takes a :py:obj:`target` param specifying the target to parse the schema for.
The returned value is a tuple containing the schema and the component schema.
Note:
.. note::
Developer must implement this function to retrieve the JSON schema
to enable openapi spec.
The `MOSEC_REF_TEMPLATE` constant should be used as a reference template
according to openapi standards.
.. note::
The :py:const:`MOSEC_REF_TEMPLATE` constant should be used as a reference
template according to openapi standards.
"""
return ({}, {})
return {}, {}
16 changes: 9 additions & 7 deletions src/apidoc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use utoipa::openapi::{
};

#[derive(Deserialize, Default)]
pub(crate) struct PythonApiDoc {
pub(crate) struct PythonAPIDoc {
#[serde(skip_serializing_if = "Option::is_none", default)]
request_body: Option<RequestBody>,
#[serde(skip_serializing_if = "Option::is_none", default)]
Expand All @@ -30,10 +30,10 @@ pub(crate) struct PythonApiDoc {
schemas: Option<BTreeMap<String, RefOr<Schema>>>,
}

impl FromStr for PythonApiDoc {
impl FromStr for PythonAPIDoc {
type Err = serde_json::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
serde_json::from_str::<PythonApiDoc>(s)
serde_json::from_str::<PythonAPIDoc>(s)
}
}

Expand Down Expand Up @@ -75,8 +75,8 @@ impl MosecApiDoc {
&mut op.responses
}

pub fn merge(&self, route: &str, python_api: PythonApiDoc) -> Self {
// merge PythonApiDoc of target route to mosec api
pub fn merge(&self, route: &str, python_api: PythonAPIDoc) -> Self {
// merge PythonAPIDoc of target route to mosec api
let mut api = self.api.clone();

if let Some(mut other_schemas) = python_api.schemas {
Expand Down Expand Up @@ -105,8 +105,10 @@ impl MosecApiDoc {
MosecApiDoc { api }
}

pub fn move_path(&self, from: &str, to: &str) -> Self {
// move one path to another
pub fn replace_path_item(&self, from: &str, to: &str) -> Self {
// replace path item from path `from` to path `to`
// e.g. /inference -> /v1/inference
// because utoipa_gen::proc_macro::OpenApi can't handle variable path
let mut api = self.api.clone();
if let Some(r) = api.paths.paths.remove(from) {
api.paths.paths.insert(to.to_owned(), r);
Expand Down
4 changes: 2 additions & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ async fn metrics(_: Request<Body>) -> Response<Body> {
get,
path = "/openapi",
responses(
(status = StatusCode::OK, description = "Get Openapi Doc",body=String)
(status = StatusCode::OK, description = "Get OpenAPI doc",body=String)
)
)]
async fn openapi(_: Request<Body>, doc: openapi::OpenApi) -> Response<Body> {
Expand Down Expand Up @@ -269,7 +269,7 @@ async fn run(opts: &Opts) {
api: RustApiDoc::openapi(),
}
.merge("/inference", python_api.parse().unwrap_or_default())
.move_path("/inference", &opts.endpoint);
.replace_path_item("/inference", &opts.endpoint);

let state = AppState {
mime: opts.mime.clone(),
Expand Down
5 changes: 3 additions & 2 deletions tests/openapi_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def forward(self, data):
}

server = Server(endpoint="/v1/inference")
for w in sys.argv[1].split("/"):
server.append_worker(worker_mapping[w])
preprocess_worker, inference_worker = sys.argv[1].split("/")
server.append_worker(worker_mapping[preprocess_worker])
server.append_worker(worker_mapping[inference_worker], max_batch_size=16)
server.run()
11 changes: 7 additions & 4 deletions tests/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def mosec_service(request):
shlex.split(f"python -u tests/{name}.py {args} --port {TEST_PORT}"),
)
assert wait_for_port_open(port=TEST_PORT), "service failed to start"
yield (service, args)
yield service
service.terminate()
time.sleep(2) # wait for service to stop

Expand Down Expand Up @@ -181,34 +181,37 @@ def assert_empty_queue(http_client):


@pytest.mark.parametrize(
"mosec_service, http_client",
"mosec_service, http_client, args",
[
pytest.param(
"openapi_service TypedPreprocess/TypedInference",
"",
"TypedPreprocess/TypedInference",
id="TypedPreprocess/TypedInference",
),
pytest.param(
"openapi_service UntypedPreprocess/TypedInference",
"",
"UntypedPreprocess/TypedInference",
id="UntypedPreprocess/TypedInference",
),
pytest.param(
"openapi_service TypedPreprocess/UntypedInference",
"",
"TypedPreprocess/UntypedInference",
id="TypedPreprocess/UntypedInference",
),
pytest.param(
"openapi_service UntypedPreprocess/UntypedInference",
"",
"UntypedPreprocess/UntypedInference",
id="UntypedPreprocess/UntypedInference",
),
],
indirect=["mosec_service", "http_client"],
)
def test_openapi_service(mosec_service, http_client):
def test_openapi_service(mosec_service, http_client, args):
spec = http_client.get("/openapi").json()
(_, args) = mosec_service
input_cls, return_cls = args.split("/")
path_item = spec["paths"]["/v1/inference"]["post"]

Expand Down

0 comments on commit 60f93ca

Please sign in to comment.