From e98cf3cc12c81d1d4d7f8450941e8c7adbf8bce2 Mon Sep 17 00:00:00 2001 From: Josh Caponigro <97563979+JoshCap20@users.noreply.github.com> Date: Sat, 12 Oct 2024 15:12:10 -0500 Subject: [PATCH] Update request body handling and OpenAPI spec parsing (#115) * Improve request body handling * Improve openAPI parsing * Increment version number * Update body tests --- .gitignore | 1 + README.md | 12 +- areion/__init__.py | 2 +- areion/core/request.py | 34 +++-- areion/dev/swagger.py | 180 ++++++++++++++++++++----- areion/main.py | 2 +- areion/tests/components/test_router.py | 6 +- areion/tests/core/test_request.py | 28 ++-- pyproject.toml | 2 +- 9 files changed, 203 insertions(+), 64 deletions(-) diff --git a/.gitignore b/.gitignore index 89a732c..ba4f9f0 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ *.pyc +.venv Areion.egg-info/ build/ diff --git a/README.md b/README.md index ad47504..4457d31 100644 --- a/README.md +++ b/README.md @@ -202,14 +202,14 @@ Below is a simple example to get you started with Areion. ```python -from areion import AreionServerBuilder, DefaultRouter +from areion import AreionServerBuilder, DefaultRouter, HttpRequest # Initialize the router router = DefaultRouter() # Define a simple route @router.route("/hello") -def hello_world(request): +def hello_world(request: HttpRequest): return "Hello, World!" # Build and run the server @@ -358,6 +358,9 @@ def get_all_users(request): @users_router.route("/:user_id", methods=["GET"]) def get_user(request, user_id): + body = request.get_parsed_body() + if not body.get("token"): + return HttpResponse(status_code=401, body="Unauthorized", content_type="text/plain") return HttpResponse(status_code=200, body={"user_id": user_id}, content_type="application/json") ``` @@ -645,8 +648,11 @@ Represents an HTTP request. These are injected into each route handler as the fi - `add_header(key, value)`: Adds a header. - `get_header(key)`: Retrieves a header value. -- `get_body()`: Retrieves the request body. +- `get_parsed_body()`: Retrieves the request body as a dictionary. +- `get_raw_body()`: Retrieves the raw request body. - `get_query_param(key)`: Retrieves a query parameter. +- `get_raw_query_params()`: Retrieves the raw query parameters. +- `get_parsed_query_params()`: Retrieves the query parameters as a dictionary. - `add_metadata(key, value)`: Adds metadata. - `get_metadata(key)`: Retrieves metadata. - `render_template(template_name, context)`: Renders a template. diff --git a/areion/__init__.py b/areion/__init__.py index 347fc84..eba2fa9 100644 --- a/areion/__init__.py +++ b/areion/__init__.py @@ -35,7 +35,7 @@ create_xml_response, ) -__version__ = "v1.1.10" +__version__ = "v1.1.11" __all__ = [ # Main classes diff --git a/areion/core/request.py b/areion/core/request.py index 48bf851..5eff2d6 100644 --- a/areion/core/request.py +++ b/areion/core/request.py @@ -3,7 +3,8 @@ """ from .response import HttpResponse -from urllib.parse import urlparse, parse_qs +from urllib.parse import urlparse, parse_qsl +import orjson class HttpRequest: @@ -25,8 +26,10 @@ class HttpRequest: Adds a header to the request. get_header(key: str) -> str | None: Retrieve the value of a specified header. - get_body() -> str | None: - Retrieve the body of the request, if available. + get_raw_body() -> str | None: + Retrieve the raw request body as a str if available. + get_parsed_body() -> dict: + Retrieve the parsed request body as a dictionary. add_metadata(key: str, value: any) -> None: Adds a metadata entry to the request. get_metadata(key: str) -> any: @@ -86,14 +89,27 @@ def get_header(self, key) -> str | None: str or None: The value of the specified header if it exists, otherwise None. """ return self.headers.get(key) + + def get_parsed_body(self) -> dict | str | None: + """ + Parse the body of the request and return it as a dictionary. + Returns: + dict or str or None: The parsed body of the request if it exists, otherwise None. + """ + if not self.body: + return None + try: + return orjson.loads(self.body) + except orjson.JSONDecodeError: + return self.body.decode("utf-8") - def get_body(self) -> str | None: + def get_raw_body(self) -> str | None: """ Retrieve the body of the request. Returns: str or None: The body of the request if it exists, otherwise None. """ - return self.body + return self.body if self.body else None def add_metadata(self, key: str, value: any) -> None: """ @@ -138,7 +154,7 @@ def get_parsed_query_params(self) -> dict: Returns: dict: A dictionary containing the parsed query parameters. """ - return parse_qs(self.query_params) + return dict(parse_qsl(self.query_params)) def render_template(self, template_name: str, context: dict = None) -> str: """ @@ -199,7 +215,7 @@ def as_dict(self, show_components: bool = False): "query_params": self.get_parsed_query_params(), "headers": self.headers, "metadata": self.metadata, - "body": self.body, + "body": self.get_parsed_body(), "logger": self.logger, "engine": self.engine, "orchestrator": self.orchestrator, @@ -210,7 +226,7 @@ def as_dict(self, show_components: bool = False): "query_params": self.get_parsed_query_params(), "headers": self.headers, "metadata": self.metadata, - "body": self.body, + "body": self.get_parsed_body(), } def __repr__(self) -> str: @@ -230,7 +246,7 @@ def __init__(self, logger=None, engine=None, orchestrator=None): self.engine = engine self.orchestrator = orchestrator - def create(self, method, path, headers, body=None): + def create(self, method, path, headers, body: bytes = b"") -> HttpRequest: """ Creates an HttpRequest with injected logger, engine, and orchestrator. """ diff --git a/areion/dev/swagger.py b/areion/dev/swagger.py index 1b77d3b..e12887f 100644 --- a/areion/dev/swagger.py +++ b/areion/dev/swagger.py @@ -1,7 +1,8 @@ import json import os import inspect -from ..core import HttpResponse +import re +from ..core import HttpResponse, HttpRequest ENV = os.getenv("ENV", "development") @@ -58,7 +59,7 @@ def swagger_ui(request): def generate_openapi_spec(self): openapi_spec = { "openapi": "3.0.0", - "info": {"title": "Areion API", "version": "1.0.0"}, + "info": {"title": "Areion Swagger UI", "version": "v1.1.11"}, "paths": {}, } @@ -66,66 +67,168 @@ def generate_openapi_spec(self): path = route["path"] method = route["method"].lower() handler = route["handler"] - doc = route["doc"] or "" + doc = ( + inspect.getdoc(handler) + or "No documentation available for this endpoint." + ) - # Parse docstring for summary and description - doc_lines = doc.strip().split("\n") - summary = doc_lines[0] if doc_lines else "" - description = "\n".join(doc_lines[1:]).strip() if len(doc_lines) > 1 else "" + # Parse docstring for summary, description, parameters, and responses + summary, description, doc_params, response_description = ( + self._parse_docstring(doc) + ) parameters = [] + request_body_content = None # Get dynamic segments from path - path_segments = self._split_path(path) + path_segments = self.split_path(path) path_params = [ segment[1:] for segment in path_segments if segment.startswith(":") ] - # Get handler signature - sig = inspect.signature(handler) - params = sig.parameters - - for param_name, param in params.items(): - if param_name == "request": - continue - - # Determine if parameter is in path or query + for param_name, param_info in doc_params.items(): + # Determine if parameter is in path, query, or request body if param_name in path_params: param_in = "path" required = True + elif method in ["post", "put", "patch"]: + # For request body parameters + param_in = "body" + required = param_info.get("required", True) else: param_in = "query" - required = param.default == inspect.Parameter.empty - - # Get parameter type from annotation - annotation = param.annotation - if annotation != inspect.Parameter.empty: - openapi_type = SwaggerHandler.map_python_type_to_openapi(annotation) + required = param_info.get("required", False) + + openapi_type = param_info.get("type", "string") + description = param_info.get("description", "") + + if param_in == "body": + if request_body_content is None: + request_body_content = { + "application/json": { + "schema": { + "type": "object", + "properties": {}, + "required": [], + } + } + } + request_body_content["application/json"]["schema"]["properties"][ + param_name + ] = {"type": openapi_type, "description": description} + if required: + request_body_content["application/json"]["schema"][ + "required" + ].append(param_name) else: - openapi_type = "string" - - parameter_spec = { - "name": param_name, - "in": param_in, - "required": required, - "schema": {"type": openapi_type}, - } - parameters.append(parameter_spec) + parameter_spec = { + "name": param_name, + "in": param_in, + "required": required, + "schema": {"type": openapi_type}, + "description": description, + } + parameters.append(parameter_spec) # Build the path item if path not in openapi_spec["paths"]: openapi_spec["paths"][path] = {} - openapi_spec["paths"][path][method] = { + operation = { "summary": summary, "description": description, "parameters": parameters, - "responses": {"200": {"description": "Successful Response"}}, + "responses": { + "200": { + "description": response_description or "Successful Response" + } + }, } + if request_body_content: + operation["requestBody"] = { + "content": request_body_content, + "required": True, + } + + openapi_spec["paths"][path][method] = operation + return openapi_spec - def _split_path(self, path): + def _parse_docstring(self, doc): + """ + Parse the docstring to extract summary, description, parameters, and response details. + + Args: + doc (str): The docstring to parse. + + Returns: + tuple: summary, description, parameters, response_description + """ + lines = doc.strip().split("\n") + summary = lines[0].strip() if lines else "" + description_lines = [] + parameters = {} + response_description = "" + + current_section = None + param_pattern = re.compile(r"^(\s*)([\w_]+)\s*\(([\w\[\]]+)\):\s*(.+)") + for line in lines[1:]: + stripped_line = line.strip() + if not stripped_line: + continue # skip empty lines + + if stripped_line.lower() == "parameters:": + current_section = "parameters" + continue + elif stripped_line.lower() == "returns:": + current_section = "returns" + continue + elif current_section == "parameters": + # Try to match a parameter definition + match = param_pattern.match(line) + if match: + indent, param_name, param_type, param_desc = match.groups() + required = True + default_match = re.search( + r"\(default\s*is\s*([^)]+)\)", param_desc, re.IGNORECASE + ) + if default_match: + required = False + param_desc = param_desc.replace( + default_match.group(0), "" + ).strip() + parameters[param_name] = { + "type": SwaggerHandler.map_python_type_to_openapi(param_type), + "description": param_desc, + "required": required, + } + else: + # Maybe a continuation of the previous parameter's description + if parameters and line.startswith(" " * 4): + last_param = list(parameters.keys())[-1] + parameters[last_param]["description"] += " " + stripped_line + else: + # Unindented line, exit parameters section + current_section = None + elif current_section == "returns": + # Collect return description + if not response_description: + response_description = stripped_line + else: + response_description += " " + stripped_line + else: + # Accumulate description + description_lines.append(line) + + description = ( + "\n".join(description_lines).strip() + if description_lines + else "No detailed description provided." + ) + return summary, description, parameters, response_description + + def split_path(self, path): """Splits a path into segments and normalizes it.""" return [segment for segment in path.strip("/").split("/") if segment] @@ -133,11 +236,18 @@ def _split_path(self, path): def map_python_type_to_openapi(python_type): type_mapping = { int: "integer", + "int": "integer", str: "string", + "str": "string", bool: "boolean", + "bool": "boolean", float: "number", + "float": "number", dict: "object", + "dict": "object", list: "array", + "list": "array", type(None): "null", + "None": "null", } return type_mapping.get(python_type, "any") diff --git a/areion/main.py b/areion/main.py index fca955f..9dc8238 100644 --- a/areion/main.py +++ b/areion/main.py @@ -18,7 +18,7 @@ / / /-----| | | | / / / | | | | __/___/__/_______|_|__\\_\\___ - // v1.1.10 \\ + // v1.1.11 \\ // A R E I O N \\ // joshcap20/areion \\ //________________________________\\ diff --git a/areion/tests/components/test_router.py b/areion/tests/components/test_router.py index 1d2bf83..d51c6b0 100644 --- a/areion/tests/components/test_router.py +++ b/areion/tests/components/test_router.py @@ -209,7 +209,7 @@ def test__remove_query_params(self): self.assertEqual(request.headers, {}) self.assertEqual( request.get_parsed_query_params(), - {"param1": ["value1"], "param2": ["value2"]}, + {"param1": "value1", "param2": "value2"}, ) def test__remove_query_params_parsed_no_query(self): @@ -261,7 +261,7 @@ def test__remove_query_params_no_path_or_query(self): self.assertEqual(request.headers, {}) self.assertEqual( request.get_parsed_query_params(), - {"param1": ["value1"], "param2": ["value2"]}, + {"param1": "value1", "param2": "value2"}, ) def test__remove_query_params_no_query_params(self): @@ -365,7 +365,7 @@ def test__split_path_and_query_params(self): self.assertEqual(request.path, "/test") self.assertEqual( request.get_parsed_query_params(), - {"param1": ["value1"], "param2": ["value2"]}, + {"param1": "value1", "param2": "value2"}, ) self.assertEqual(request.get_raw_query_params(), "param1=value1¶m2=value2") diff --git a/areion/tests/core/test_request.py b/areion/tests/core/test_request.py index 874384e..8752d95 100644 --- a/areion/tests/core/test_request.py +++ b/areion/tests/core/test_request.py @@ -1,6 +1,7 @@ import unittest from unittest.mock import Mock from ... import HttpRequest, HttpRequestFactory, HttpResponse +import orjson class TestHttpRequest(unittest.TestCase): @@ -9,7 +10,7 @@ def setUp(self): self.method = "GET" self.path = "/test" self.headers = {"Content-Type": "application/json"} - self.body = "Request Body" + self.body = b"Request Body" self.query_params = {} self.metadata = {} self.request = HttpRequest(self.method, self.path, self.headers, self.body) @@ -42,7 +43,7 @@ def test_get_query_params(self): request = HttpRequest("GET", "/test?param1=value1¶m2=value2", {}) self.assertEqual( request.get_parsed_query_params(), - {"param1": ["value1"], "param2": ["value2"]}, + {"param1": "value1", "param2": "value2"}, ) self.assertEqual(request.get_raw_query_params(), "param1=value1¶m2=value2") self.assertEqual(request.query_params, "param1=value1¶m2=value2") @@ -58,26 +59,30 @@ def test_get_query_params_multiple_values(self): self.assertEqual(self.request.query_params, "") request = HttpRequest("GET", "/test?param1=value1¶m1=value2", {}) self.assertEqual( - request.get_parsed_query_params(), {"param1": ["value1", "value2"]} + request.get_parsed_query_params(), {"param1": "value2"} ) self.assertEqual(request.get_raw_query_params(), "param1=value1¶m1=value2") self.assertEqual(request.query_params, "param1=value1¶m1=value2") def test_get_body(self): - self.request.body = "New Body" - self.assertEqual(self.request.get_body(), "New Body") + self.request.body = orjson.dumps({"message": "New Body"}) + self.assertEqual(self.request.get_parsed_body(), {"message": "New Body"}) def test_get_body_none(self): self.request.body = None - self.assertIsNone(self.request.get_body()) + self.assertEqual(self.request.get_parsed_body(), None) def test_get_body_empty(self): self.request.body = "" - self.assertEqual(self.request.get_body(), "") + self.assertEqual(self.request.get_parsed_body(), None) def test_get_body_bytes(self): self.request.body = b"Binary Body" - self.assertEqual(self.request.get_body(), b"Binary Body") + self.assertEqual(self.request.get_raw_body(), b"Binary Body") + + def test_get_body_json(self): + self.request.body = orjson.dumps({"message": "New Body"}) + self.assertEqual(self.request.get_parsed_body(), {"message": "New Body"}) def test_repr(self): expected_repr = f"" @@ -88,13 +93,14 @@ def test_str(self): self.assertEqual(str(self.request), expected_str) def test_as_dict_default(self): + self.request.body = b"{\"message\": \"New Body\"}" expected_dict = { "method": self.method, "path": self.path, "query_params": {}, "headers": self.headers, "metadata": {}, - "body": self.body, + "body": {"message": "New Body"}, } self.assertEqual(self.request.as_dict(), expected_dict) @@ -108,7 +114,7 @@ def test_as_dict_with_components(self): "query_params": {}, "headers": self.headers, "metadata": self.request.metadata, - "body": self.body, + "body": self.body.decode("utf-8"), "logger": self.request.logger, "engine": self.request.engine, "orchestrator": self.request.orchestrator, @@ -122,7 +128,7 @@ def test_as_dict_without_components(self): "query_params": {}, "headers": self.headers, "metadata": {}, - "body": self.body, + "body": self.body.decode("utf-8"), } self.assertEqual(self.request.as_dict(show_components=False), expected_dict) diff --git a/pyproject.toml b/pyproject.toml index 5657880..f7a361d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "pdm.backend" [project] name = "areion" -version = "1.1.10" +version = "1.1.11" authors = [ { name="Josh Caponigro", email="joshcaponigro@gmail.com" }, ]