Skip to content

Commit

Permalink
Add a CORS middleware (#523)
Browse files Browse the repository at this point in the history
  • Loading branch information
fabricedesre authored Nov 18, 2022
1 parent 22a1426 commit 2f0f4f4
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 1 deletion.
2 changes: 1 addition & 1 deletion iroh-gateway/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ tokio = { version = "1", features = ["macros", "rt-multi-thread", "process", "fs
tokio-util = { version = "0.7", features = ["io"] }
toml = "0.5.9"
tower = { version = "0.4", features = ["util", "timeout", "load-shed", "limit"] }
tower-http = { version = "0.3", features = ["trace", "compression-full"] }
tower-http = { version = "0.3", features = ["trace", "compression-full", "cors"] }
tower-layer = { version = "0.3" }
tracing = "0.1.33"
tracing-opentelemetry = "0.18"
Expand Down
59 changes: 59 additions & 0 deletions iroh-gateway/src/cors.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
use http::header::{HeaderMap, HeaderName, HeaderValue};
use std::str::FromStr;
use tower_http::cors::CorsLayer;

/// Convert a header value formatted as a csv to a list of a given type.
fn from_header_value<T: FromStr>(source: &HeaderValue) -> Option<Vec<T>> {
if let Ok(names) = source.to_str() {
Some(
names
.split(',')
.filter_map(|s| T::from_str(s.trim()).ok())
.collect(),
)
} else {
None
}
}

/// Creates a CORS middleware from the config headers.
/// Used headers are:
/// - access-control-allow-headers
/// - access-control-expose-headers (set to allow-headers when not present)
/// - access-control-allow-methods
/// - access-control-allow-origin
pub(crate) fn cors_from_headers(headers: &HeaderMap) -> CorsLayer {
let mut layer = CorsLayer::new();

// access-control-allow-methods
if let Some(methods) = headers.get("access-control-allow-methods") {
if let Some(list) = from_header_value(methods) {
layer = layer.allow_methods(list);
}
}

// access-control-allow-origin
if let Some(origin) = headers.get("access-control-allow-origin") {
layer = layer.allow_origin(origin.clone());
}

// access-control-allow-headers
let mut allowed_header_names: Vec<HeaderName> = vec![];
if let Some(allowed_headers) = headers.get("access-control-allow-headers") {
if let Some(list) = from_header_value(allowed_headers) {
allowed_header_names = list.clone();
layer = layer.allow_headers(list);
}
}

// access-control-expose-headers
if let Some(exposed_headers) = headers.get("access-control-expose-headers") {
if let Some(list) = from_header_value(exposed_headers) {
layer = layer.expose_headers(list);
}
} else if !allowed_header_names.is_empty() {
layer = layer.expose_headers(allowed_header_names);
}

layer
}
3 changes: 3 additions & 0 deletions iroh-gateway/src/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,16 @@ pub trait StateConfig: std::fmt::Debug + Sync + Send {
}

pub fn get_app_routes<T: ContentLoader + std::marker::Unpin>(state: &Arc<State<T>>) -> Router {
let cors = crate::cors::cors_from_headers(state.config.user_headers());

// todo(arqu): ?uri=... https://github.com/ipfs/go-ipfs/pull/7802
Router::new()
.route("/:scheme/:cid", get(get_handler::<T>))
.route("/:scheme/:cid/*cpath", get(get_handler::<T>))
.route("/health", get(health_check))
.route("/icons.css", get(stylesheet_icons))
.route("/style.css", get(stylesheet_main))
.layer(cors)
.layer(Extension(Arc::clone(state)))
.layer(
ServiceBuilder::new()
Expand Down
1 change: 1 addition & 0 deletions iroh-gateway/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pub mod client;
pub mod config;
pub mod constants;
pub mod core;
mod cors;
mod error;
pub mod handlers;
pub mod headers;
Expand Down

0 comments on commit 2f0f4f4

Please sign in to comment.