Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Misc rust fixes/changes #312

Merged
merged 8 commits into from
Jul 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion extensions/rust/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions extensions/rust/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
[package]
name = "route_map_rs"
name = "route_map"
version = "0.1.0"
edition = "2021"

[lib]
name = "route_map_rs"
name = "route_map"
crate-type = ["cdylib"]

[dependencies]
Expand Down
2 changes: 1 addition & 1 deletion extensions/rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::route_map::RouteMap;
use pyo3::prelude::*;

#[pymodule]
fn route_map_rs(_py: Python, m: &PyModule) -> PyResult<()> {
fn route_map(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<RouteMap>()?;
Ok(())
}
65 changes: 33 additions & 32 deletions extensions/rust/src/route_map.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::util::{build_route_middleware_stack, get_base_components, path_parameters_eq};

use std::collections::{hash_map, HashMap, HashSet};
use std::collections::{HashMap, HashSet};

use pyo3::{
prelude::*,
Expand Down Expand Up @@ -103,28 +103,30 @@ pub struct RouteMap {
impl RouteMap {
/// Creates an empty `RouteMap`
#[new]
#[args(debug = false)]
pub fn new(py: Python, debug: bool) -> PyResult<Self> {
macro_rules! get_attr_and_downcast {
($module:ident, $attr:expr, $downcast_ty:ty) => {{
$module.getattr($attr)?.downcast::<$downcast_ty>()?.into()
}};
fn get_attr_and_downcast<T>(module: &PyAny, attr: &str) -> PyResult<Py<T>>
where
for<'py> T: PyTryFrom<'py>,
for<'py> &'py T: Into<Py<T>>,
{
Ok(module.getattr(attr)?.downcast::<T>()?.into())
}

let parsers = py.import("starlite.parsers")?;
let parse_path_params = get_attr_and_downcast!(parsers, "parse_path_params", PyFunction);
let parse_path_params = get_attr_and_downcast(parsers, "parse_path_params")?;

let routes = py.import("starlite.routes")?;
let http_route = get_attr_and_downcast!(routes, "HTTPRoute", PyType);
let web_socket_route = get_attr_and_downcast!(routes, "WebSocketRoute", PyType);
let asgi_route = get_attr_and_downcast!(routes, "ASGIRoute", PyType);
let http_route = get_attr_and_downcast(routes, "HTTPRoute")?;
let web_socket_route = get_attr_and_downcast(routes, "WebSocketRoute")?;
let asgi_route = get_attr_and_downcast(routes, "ASGIRoute")?;

let middleware = py.import("starlite.middleware")?;
let exception_handler_middleware =
get_attr_and_downcast!(middleware, "ExceptionHandlerMiddleware", PyType);
get_attr_and_downcast(middleware, "ExceptionHandlerMiddleware")?;

let starlette_middleware = py.import("starlette.middleware")?;
let starlette_middleware =
get_attr_and_downcast!(starlette_middleware, "Middleware", PyType);
let starlette_middleware = get_attr_and_downcast(starlette_middleware, "Middleware")?;

Ok(RouteMap {
map: Node::new(),
Expand Down Expand Up @@ -180,7 +182,7 @@ impl RouteMap {
Ok(())
}

// Given a scope, retrieves the correct ASGI App for the route
/// Given a scope, retrieves the correct ASGI App for the route
pub fn resolve_asgi_app(&self, scope: &PyAny) -> PyResult<Py<PyAny>> {
let (asgi_handlers, is_asgi) = self.parse_scope_to_route(scope)?;

Expand Down Expand Up @@ -290,14 +292,12 @@ impl<'rm> ConfigureNodeView<'rm> {

let asgi_handlers = cur_node.asgi_handlers.as_mut().unwrap();

macro_rules! generate_single_route_handler_stack {
($handler_type:expr) => {
let route_handler = route.getattr("route_handler")?;
let middleware_stack =
build_route_middleware_stack(py, &ctx, route, route_handler)?;
asgi_handlers.insert($handler_type.to_string(), middleware_stack.to_object(py));
};
}
let mut generate_single_route_handler_stack = |handler_type: &str| -> PyResult<()> {
let route_handler = route.getattr("route_handler")?;
let middleware_stack = build_route_middleware_stack(py, &ctx, route, route_handler)?;
asgi_handlers.insert(handler_type.to_string(), middleware_stack);
Ok(())
};

if route.is_instance(http_route.as_ref(py))? {
let route_handler_map: HashMap<String, &PyAny> =
Expand All @@ -308,12 +308,12 @@ impl<'rm> ConfigureNodeView<'rm> {
let route_handler = handler_mapping.get_item(0)?;
let middleware_stack =
build_route_middleware_stack(py, &ctx, route, route_handler)?;
asgi_handlers.insert(method, middleware_stack.to_object(py));
asgi_handlers.insert(method, middleware_stack);
}
} else if route.is_instance(web_socket_route.as_ref(py))? {
generate_single_route_handler_stack!("websocket");
generate_single_route_handler_stack("websocket")?;
} else if route.is_instance(asgi_route.as_ref(py))? {
generate_single_route_handler_stack!("asgi");
generate_single_route_handler_stack("asgi")?;
cur_node.is_asgi = true;
}

Expand Down Expand Up @@ -354,17 +354,18 @@ impl RouteMap {
let component_set = &mut cur_node.components;
component_set.insert(component.to_string());

if let hash_map::Entry::Vacant(e) = cur_node.children.entry(component.to_string()) {
e.insert(Node::new());
}
cur_node = cur_node.children.get_mut(component).unwrap();
cur_node = cur_node
.children
.entry(component.to_string())
.or_insert_with(Node::new);
}
} else {
if let hash_map::Entry::Vacant(e) = self.map.children.entry(path.clone()) {
e.insert(Node::new());
}
self.add_plain_route(&path[..]);
cur_node = self.map.children.get_mut(&path[..]).unwrap();
cur_node = self
.map
.children
.entry(path.clone())
.or_insert_with(Node::new);
}

ConfigureNodeView {
Expand Down
2 changes: 1 addition & 1 deletion poetry_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def build(setup_kwargs: Dict[str, Any]) -> None:
Add rust_extensions to the setup dict
"""
setup_kwargs["rust_extensions"] = [
RustExtension("starlite.route_map_rs", path="extensions/rust/Cargo.toml", binding=Binding.PyO3)
RustExtension("starlite.route_map", path="extensions/rust/Cargo.toml", binding=Binding.PyO3)
]
setup_kwargs["zip_safe"] = False

Expand Down
5 changes: 2 additions & 3 deletions starlite/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from starlite.plugins.base import PluginProtocol
from starlite.provide import Provide
from starlite.response import Response
from starlite.route_map_rs import RouteMap as RouteMapInit
from starlite.route_map import RouteMap
from starlite.router import Router
from starlite.routes import ASGIRoute, BaseRoute, HTTPRoute, WebSocketRoute
from starlite.signature import SignatureModelFactory
Expand All @@ -47,7 +47,6 @@
from starlette.types import ASGIApp, Receive, Scope, Send

from starlite.handlers.base import BaseRouteHandler
from starlite.route_map import RouteMap

DEFAULT_OPENAPI_CONFIG = OpenAPIConfig(title="Starlite API", version="1.0.0")
DEFAULT_CACHE_CONFIG = CacheConfig()
Expand Down Expand Up @@ -105,7 +104,7 @@ def __init__(
self.plain_routes: Set[str] = set()
self.plugins = plugins or []
self.routes: List[BaseRoute] = []
self.route_map: "RouteMap" = RouteMapInit(self.debug)
self.route_map = RouteMap(self.debug)
self.state = State()

super().__init__(
Expand Down
Loading