Skip to content

Commit

Permalink
Fix routing issues when loading a Router via a dynamic library
Browse files Browse the repository at this point in the history
  • Loading branch information
davidpdrsn committed Mar 3, 2023
1 parent 67befbc commit 5087278
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 15 deletions.
3 changes: 3 additions & 0 deletions axum/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
# Unreleased

- **fixed:** Add `#[must_use]` to `WebSocketUpgrade::on_upgrade` ([#1801])
- **fixed:** Fix routing issues when loading a `Router` via a dynamic library ([#1806])

[#1806]: https://github.com/tokio-rs/axum/pull/1806

[#1801]: https://github.com/tokio-rs/axum/pull/1801

Expand Down
35 changes: 20 additions & 15 deletions axum/src/routing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,24 +47,12 @@ pub use self::method_routing::{
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub(crate) struct RouteId(u32);

impl RouteId {
fn next() -> Self {
use std::sync::atomic::{AtomicU32, Ordering};
// `AtomicU64` isn't supported on all platforms
static ID: AtomicU32 = AtomicU32::new(0);
let id = ID.fetch_add(1, Ordering::Relaxed);
if id == u32::MAX {
panic!("Over `u32::MAX` routes created. If you need this, please file an issue.");
}
Self(id)
}
}

/// The router type for composing handlers and services.
pub struct Router<S = (), B = Body> {
routes: HashMap<RouteId, Endpoint<S, B>>,
node: Arc<Node>,
fallback: Fallback<S, B>,
prev_route_id: Option<RouteId>,
}

impl<S, B> Clone for Router<S, B> {
Expand All @@ -73,6 +61,7 @@ impl<S, B> Clone for Router<S, B> {
routes: self.routes.clone(),
node: Arc::clone(&self.node),
fallback: self.fallback.clone(),
prev_route_id: self.prev_route_id,
}
}
}
Expand Down Expand Up @@ -117,6 +106,7 @@ where
routes: Default::default(),
node: Default::default(),
fallback: Fallback::Default(Route::new(NotFound)),
prev_route_id: Default::default(),
}
}

Expand All @@ -134,7 +124,7 @@ where

validate_path(path);

let id = RouteId::next();
let id = self.next_route_id();

let endpoint = if let Some((route_id, Endpoint::MethodRouter(prev_method_router))) = self
.node
Expand Down Expand Up @@ -189,7 +179,7 @@ where
panic!("Paths must start with a `/`");
}

let id = RouteId::next();
let id = self.next_route_id();
self.set_node(path, id);
self.routes.insert(id, endpoint);
self
Expand Down Expand Up @@ -286,6 +276,7 @@ where
routes,
node,
fallback,
prev_route_id: _,
} = other.into();

for (id, route) in routes {
Expand Down Expand Up @@ -335,6 +326,7 @@ where
routes,
node: self.node,
fallback,
prev_route_id: self.prev_route_id,
}
}

Expand Down Expand Up @@ -368,6 +360,7 @@ where
routes,
node: self.node,
fallback: self.fallback,
prev_route_id: self.prev_route_id,
}
}

Expand Down Expand Up @@ -419,6 +412,7 @@ where
routes,
node: self.node,
fallback,
prev_route_id: self.prev_route_id,
}
}

Expand Down Expand Up @@ -506,6 +500,17 @@ where
Endpoint::NestedRouter(router) => router.call_with_state(req, state),
}
}

fn next_route_id(&mut self) -> RouteId {
let next_id = self.prev_route_id.map_or(RouteId(0), |RouteId(id)| {
let next = id
.checked_add(1)
.expect("Over `u32::MAX` routes created. If you need this, please file an issue.");
RouteId(next)
});
self.prev_route_id = Some(next_id);
next_id
}
}

impl<B> Router<(), B>
Expand Down

0 comments on commit 5087278

Please sign in to comment.