Skip to content

Commit

Permalink
Enhance HTTP handling (#59)
Browse files Browse the repository at this point in the history
* Type hinting + invalid method handling

* Fix wrong import

* Add to http server error handling

* Fix likely bug

* Requirement tweaks

* Add strict http flag

* server logging tweaks

* Fix tests with new expected behavior

* Increment version number
  • Loading branch information
JoshCap20 authored Sep 26, 2024
1 parent 988d6dd commit e090933
Show file tree
Hide file tree
Showing 6 changed files with 127 additions and 95 deletions.
1 change: 1 addition & 0 deletions areion/core/response.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import orjson


HTTP_STATUS_CODES: dict[int, str] = {
100: "Continue",
101: "Switching Protocols",
Expand Down
47 changes: 30 additions & 17 deletions areion/core/server.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio

from .exceptions import HttpError, NotFoundError
from .response import HttpResponse
from .response import HttpResponse, HTTP_STATUS_CODES
from .request import HttpRequest


Expand Down Expand Up @@ -57,47 +57,59 @@ async def _handle_client(self, reader, writer):
async def _process_request(self, reader, writer):
try:
await self._handle_request_logic(reader, writer)
except asyncio.CancelledError:
self.log("debug", "Client connection cancelled.")
except ConnectionResetError:
self.log("debug", "Connection reset by peer.")
except Exception as e:
if isinstance(e, ConnectionResetError):
self.log("debug", f"Connection reset by peer: {e}")
else:
self.log("error", f"Error processing request: {e}")
response = HttpResponse(status_code=500, body="Internal Server Error")
await self._send_response(writer, response)
self.log("warning", f"Error processing request: {e}")

async def _handle_request_logic(self, reader, writer):
# HttpErrors are NOT handled outside of this method
while True:
# Handle request reading
try:
data = await asyncio.wait_for(
reader.readuntil(b'\r\n\r\n'), timeout=self.keep_alive_timeout
reader.readuntil(b"\r\n\r\n"), timeout=self.keep_alive_timeout
)
except asyncio.TimeoutError:
response = HttpResponse(status_code=408, body=HTTP_STATUS_CODES[408])
await self._send_response(writer, response)
break
except asyncio.IncompleteReadError:
response = HttpResponse(status_code=400, body=HTTP_STATUS_CODES[400])
await self._send_response(writer, response)
break
except asyncio.LimitOverrunError:
response = HttpResponse(status_code=413, body="Payload Too Large")
response = HttpResponse(status_code=413, body=HTTP_STATUS_CODES[413])
await self._send_response(writer, response)
break
except Exception as e:
response = HttpResponse(status_code=500, body=HTTP_STATUS_CODES[500])
await self._send_response(writer, response)
self.log("error", f"Error reading request: {e}")
break

if not data:
break

try:
headers_end = data.find(b'\r\n\r\n')
header_data = data[:headers_end].decode('utf-8')
lines = header_data.split('\r\n')
headers_end = data.find(b"\r\n\r\n")
header_data = data[:headers_end].decode("utf-8")
lines = header_data.split("\r\n")
request_line = lines[0]
header_lines = lines[1:]

method, path, _ = request_line.strip().split(" ")
headers = {}
for line in header_lines:
if ': ' in line:
if ": " in line:
header_name, header_value = line.strip().split(": ", 1)
headers[header_name] = header_value

request = self.request_factory.create(method, path, headers)
request: HttpRequest = self.request_factory.create(
method, path, headers
)

handler, path_params, is_async = self.router.get_handler(method, path)

Expand All @@ -111,12 +123,13 @@ async def _handle_request_logic(self, reader, writer):
except HttpError as e:
# Handles web exceptions raised by the handler
response = HttpResponse(status_code=e.status_code, body=str(e))
self.log("warning", f"[RESPONSE][HTTP-ERROR] {e}")
except Exception as e:
# Handles all other exceptions
response = HttpResponse(status_code=500, body="Internal Server Error")
self.log("error", f"Exception in request handling: {e}")
response = HttpResponse(status_code=500, body=HTTP_STATUS_CODES[500])
self.log("error", f"[RESPONSE][ERROR] {e}")

await self._send_response(writer, response)
await self._send_response(writer=writer, response=response)

if (
"Connection" in request.headers
Expand Down
143 changes: 84 additions & 59 deletions areion/default/router.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from asyncio import iscoroutinefunction
from ..core.exceptions import MethodNotAllowedError, NotFoundError


class Router:
Expand Down Expand Up @@ -45,8 +46,15 @@ def __init__(self):
self.global_middlewares = []
self.route_info = []
self.logger = None

def add_route(self, path, handler, methods=["GET"], middlewares=None):
self.strict_http: bool = False

def add_route(
self,
path: str,
handler: callable,
methods: list[str] = ["GET"],
middlewares: list[callable] = None,
) -> None:
"""
Adds a route to the router.
Expand All @@ -65,24 +73,37 @@ def my_handler(request):
router.add_route("/hello", my_handler, methods=["GET"])
"""
segments = self._split_path(path)
current_node = self.root
# Does not hurt perfomance since performed at startup
if self.strict_http and not all(
method in self.allowed_methods for method in methods
):
raise ValueError("Invalid HTTP method specified.")
# TODO: Investigate impact on route path
if self.strict_http and not path.startswith("/"):
raise ValueError("Path must start with a forward slash.")
if self._check_if_route_and_methods_exists(path, methods):
raise ValueError("A route already exists with one of these methods.")
if not callable(handler):
raise TypeError("Handler must be a callable function.")

segments: list = self._split_path(path)
current_node: TrieNode = self.root
for segment in segments:
if segment.startswith(":"): # Dynamic path segment
if segment.startswith(":"):
if current_node.dynamic_child is None:
current_node.dynamic_child = TrieNode()
current_node.dynamic_child.param_name = segment[1:]
current_node = current_node.dynamic_child
else: # Static path segment
else:
if segment not in current_node.children:
current_node.children[segment] = TrieNode()
current_node = current_node.children[segment]

for method in methods:
combined_middlewares = self.global_middlewares + (middlewares or [])
wrapped_handler = handler
combined_middlewares: list = self.global_middlewares + (middlewares or [])
wrapped_handler: callable = handler
for middleware in reversed(combined_middlewares):
wrapped_handler = middleware(wrapped_handler)
wrapped_handler: callable = middleware(wrapped_handler)

current_node.handler[method] = {
"handler": wrapped_handler,
Expand All @@ -91,6 +112,7 @@ def my_handler(request):
"doc": handler.__doc__,
}

# For generating openapi documentation
self.route_info.append(
{
"path": path,
Expand All @@ -101,7 +123,7 @@ def my_handler(request):
}
)

def group(self, base_path, middlewares=None) -> "Router":
def group(self, base_path: str, middlewares: list[callable] = None) -> "Router":
"""
Creates a sub-router (group) with a base path and optional group-specific middlewares.
Expand All @@ -113,7 +135,7 @@ def group(self, base_path, middlewares=None) -> "Router":
Router: A sub-router instance with the specified base path.
"""
sub_router = Router()
group_middlewares = middlewares or []
group_middlewares: list = middlewares or []

def add_sub_route(sub_path, handler, methods=["GET"], middlewares=None):
full_path = f"{base_path.rstrip('/')}/{sub_path.lstrip('/')}"
Expand All @@ -125,7 +147,9 @@ def add_sub_route(sub_path, handler, methods=["GET"], middlewares=None):
sub_router.add_route = add_sub_route
return sub_router

def route(self, path, methods=["GET"], middlewares=[]):
def route(
self, path: str, methods: list[str] = ["GET"], middlewares: list[callable] = []
):
"""
A decorator to define a route with optional middlewares.
Expand All @@ -144,12 +168,14 @@ def hello(request):
"""

def decorator(func):
self.add_route(path, func, methods=methods, middlewares=middlewares)
self.add_route(
path=path, handler=func, methods=methods, middlewares=middlewares
)
return func

return decorator

def get_handler(self, method, path):
def get_handler(self, method: str, path: str) -> tuple:
"""
Retrieve the handler for a given HTTP method and path.
Expand All @@ -171,73 +197,72 @@ def get_handler(self, method, path):
current_node = self.root
path_params = {}

# TODO: [DESIGN] PASS MORE INFO TO THESE EXCEPTIONS TO GLOBALLY LOG
for segment in segments:
if segment in current_node.children:
current_node = current_node.children[segment]
elif current_node.dynamic_child: # Match dynamic segments
elif current_node.dynamic_child:
param_node = current_node.dynamic_child
path_params[param_node.param_name] = segment # Store dynamic param
path_params[param_node.param_name] = segment
current_node = param_node
else:
return None, None, None
raise NotFoundError()

if not current_node.handler:
raise NotFoundError()

if method in current_node.handler:
handler_info = current_node.handler[method]
is_async = handler_info["is_async"]
return handler_info["handler"], path_params, is_async
else:
raise MethodNotAllowedError()

return None, None, None

def _split_path(self, path):
"""Splits a path into segments and normalizes it."""
return [segment for segment in path.strip("/").split("/") if segment]

### Middleware Handling ###

def add_global_middleware(self, middleware) -> None:
"""Adds a middleware that will be applied globally to all routes."""
self.global_middlewares.append(middleware)

def _apply_middlewares(self, handler_info, method, path) -> callable:
def add_global_middleware(self, middleware: callable) -> None:
"""
Applies global and route-specific middlewares to the given handler.
Adds a middleware that will be applied globally to all routes.
This method takes a handler and wraps it with the middlewares specified
both globally and for the specific route. If the handler is asynchronous,
it ensures that the returned handler is also asynchronous.
Args:
handler_info (dict): A dictionary containing handler information.
- "handler" (callable): The original handler function.
- "is_async" (bool): A flag indicating if the handler is asynchronous.
- "middlewares" (list, optional): A list of middlewares specific to the route.
method (str): The HTTP method for the route.
path (str): The path for the route.
Returns:
callable: The handler wrapped with the applied middlewares.
Parameters:
middleware (callable): A callable that represents the middleware to be added.
"""
handler = handler_info["handler"]
is_async = handler_info["is_async"]

middlewares = self.global_middlewares[:]
route_middlewares = handler_info.get("middlewares", [])
middlewares.extend(route_middlewares)
self.global_middlewares.append(middleware)

for middleware in reversed(middlewares):
handler = middleware(handler)
### Utility Methods ###

if is_async:
def _split_path(self, path: str) -> list:
"""
Splits a path into segments and normalizes it.
"""
return [segment for segment in path.strip("/").split("/") if segment]

async def async_wrapper(*args, **kwargs):
return await handler(*args, **kwargs)
def _check_if_route_and_methods_exists(self, path: str, methods: list[str]) -> bool:
"""
Checks if a route exists in the router.
"""

return async_wrapper
def _check_if_method_exists(path: str, method: str) -> bool:
"""
Checks if a method exists for a given path.
"""
segments = self._split_path(path)
current_node = self.root
for segment in segments:
if segment in current_node.children:
current_node = current_node.children[segment]
elif current_node.dynamic_child:
current_node = current_node.dynamic_child
else:
return False
return method in current_node.handler

return handler
for method in methods:
return _check_if_method_exists(path, method)

def log(self, level: str, message: str) -> None:
# Safe logging method (bug fix for scheduled tasks before server is ran)
"""
Safe logging method.
(Bug fix for scheduled tasks before server is ran)
"""
if self.logger:
log_method = getattr(self.logger, level, None)
if log_method:
Expand Down
20 changes: 7 additions & 13 deletions areion/tests/components/test_router.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import unittest
import asyncio
from unittest.mock import MagicMock, AsyncMock
from ... import DefaultRouter as Router
from ... import DefaultRouter as Router, NotFoundError, MethodNotAllowedError


class TestRouter(unittest.TestCase):
Expand Down Expand Up @@ -30,19 +30,15 @@ def test_get_handler_invalid_route(self):
# Test invalid route
handler = MagicMock()
self.router.add_route("/valid", handler)
route_handler, path_params, is_async = self.router.get_handler("GET", "/invalid")
self.assertIsNone(route_handler)
self.assertIsNone(path_params)
self.assertIsNone(is_async)
with self.assertRaises(NotFoundError):
self.router.get_handler("GET", "/invalid")

def test_get_handler_invalid_method(self):
# Test invalid method for a valid route
handler = MagicMock()
self.router.add_route("/test", handler, methods=["POST"])
route_handler, path_params, is_async = self.router.get_handler("GET", "/test")
self.assertIsNone(route_handler)
self.assertIsNone(path_params)
self.assertIsNone(is_async)
with self.assertRaises(MethodNotAllowedError):
self.router.get_handler("GET", "/test")

def test_add_global_middleware(self):
# Test global middleware application
Expand Down Expand Up @@ -151,10 +147,8 @@ def test_dynamic_route_no_param(self):
# Test dynamic route with missing parameter
handler = MagicMock()
self.router.add_route("/user/:id", handler)
route_handler, path_params, is_async = self.router.get_handler("GET", "/user/")
self.assertIsNone(route_handler)
self.assertIsNone(path_params)
self.assertIsNone(is_async)
with self.assertRaises(NotFoundError):
self.router.get_handler("GET", "/user/")

def test_no_middlewares(self):
# Test route without middleware
Expand Down
Loading

0 comments on commit e090933

Please sign in to comment.