Skip to content
This repository has been archived by the owner on Dec 24, 2022. It is now read-only.

Add allow credentials support #12

Closed
wants to merge 5 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 67 additions & 14 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@
//! for a full usage example.

extern crate iron;
#[macro_use] extern crate log;
#[macro_use]
extern crate log;

use std::collections::HashSet;

Expand All @@ -58,16 +59,45 @@ use iron::method::Method;
use iron::status;
use iron::headers;

/// The struct that builds a CorsMiddleware
pub struct CorsMiddlewareBuilder {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use #[derive(Debug, Clone, PartialEq)].

allowed_hosts: Option<HashSet<String>>,
allow_credentials: bool
}

impl CorsMiddlewareBuilder {
pub fn new() -> Self {
CorsMiddlewareBuilder { allow_credentials: false, allowed_hosts: None }
}

/// Specify which origin hosts are allowed to access the resource.
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add another line here?

/// If you don't specify any allowed hosts, then any host will be allowed to access the resource.

pub fn allowed_hosts(&mut self, allowed_hosts: HashSet<String>) -> &mut Self {
self.allowed_hosts = Some(allowed_hosts);
self
}
/// Specify Access-Control-Allow-Credentials
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you document the default value?

/// By default, the `AccessControlAllowCredentials` header will be set to `false`.

pub fn allow_credentials(&mut self, allow_credentials: bool) -> &mut Self {
self.allow_credentials = allow_credentials;
self
}

pub fn build(&self) -> CorsMiddleware {
CorsMiddleware { allowed_hosts: self.allowed_hosts.clone(), allow_credentials: self.allow_credentials }
}
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you change the build method to consume self? Then you can avoid the clone().

}

/// The struct that holds the CORS configuration.
pub struct CorsMiddleware {
allowed_hosts: Option<HashSet<String>>,
allow_credentials: bool,
}

impl CorsMiddleware {
/// Specify which origin hosts are allowed to access the resource.
pub fn with_whitelist(allowed_hosts: HashSet<String>) -> Self {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add the deprecation warning here?

#[deprecated(since="0.7.0", note="please use the `CorsMiddlewareBuilder` instead")]

CorsMiddleware {
allowed_hosts: Some(allowed_hosts),
allow_credentials: false,
}
}

Expand All @@ -77,6 +107,7 @@ impl CorsMiddleware {
pub fn with_allow_any() -> Self {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And also here.

#[deprecated(since="0.7.0", note="please use the `CorsMiddlewareBuilder` instead")]

CorsMiddleware {
allowed_hosts: None,
allow_credentials: false,
}
}
}
Expand All @@ -85,11 +116,13 @@ impl AroundMiddleware for CorsMiddleware {
fn around(self, handler: Box<Handler>) -> Box<Handler> {
match self.allowed_hosts {
Some(allowed_hosts) => Box::new(CorsHandlerWhitelist {
handler: handler,
allowed_hosts: allowed_hosts,
handler,
allowed_hosts,
allow_credentials: self.allow_credentials,
}),
None => Box::new(CorsHandlerAllowAny {
handler: handler,
handler,
allow_credentials: self.allow_credentials,
}),
}
}
Expand All @@ -99,25 +132,30 @@ impl AroundMiddleware for CorsMiddleware {
struct CorsHandlerWhitelist {
handler: Box<Handler>,
allowed_hosts: HashSet<String>,
allow_credentials: bool,
}

/// Handler if allowing any origin.
struct CorsHandlerAllowAny {
handler: Box<Handler>,
allow_credentials: bool,
}

impl CorsHandlerWhitelist {
fn add_cors_header(&self, headers: &mut headers::Headers, origin: &headers::Origin) {
let header = format_cors_origin(origin);
headers.set(headers::AccessControlAllowOrigin::Value(header));

if self.allow_credentials {
headers.set(headers::AccessControlAllowCredentials)
}
}

fn add_cors_preflight_headers(&self,
headers: &mut headers::Headers,
origin: &headers::Origin,
acrm: &headers::AccessControlRequestMethod,
acrh: Option<&headers::AccessControlRequestHeaders>) {

self.add_cors_header(headers, origin);

// Copy the method requested by the browser in the allowed methods header
Expand Down Expand Up @@ -155,7 +193,7 @@ impl CorsHandlerWhitelist {
}

// If we don't have an Access-Control-Request-Method header, treat as a possible OPTION CORS call
return self.process_possible_cors_request(req, origin)
return self.process_possible_cors_request(req, origin);
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you're at it, can you change this to self.process_possible_cors_request(req, origin) (without the return)?

}

fn process_possible_cors_request(&self, req: &mut Request, origin: headers::Origin) -> IronResult<Response> {
Expand All @@ -165,8 +203,14 @@ impl CorsHandlerWhitelist {
if may_process {
// Everything OK, process request and add CORS header to response
self.handler.handle(req)
.map(|mut res| { self.add_cors_header(&mut res.headers, &origin); res })
.map_err(|mut err| { self.add_cors_header(&mut err.response.headers, &origin); err })
.map(|mut res| {
self.add_cors_header(&mut res.headers, &origin);
res
})
.map_err(|mut err| {
self.add_cors_header(&mut err.response.headers, &origin);
err
})
} else {
// Not adding headers
warn!("Got disallowed CORS request from {}", &origin.host.hostname);
Expand Down Expand Up @@ -203,13 +247,16 @@ impl Handler for CorsHandlerWhitelist {
impl CorsHandlerAllowAny {
fn add_cors_header(&self, headers: &mut headers::Headers) {
headers.set(headers::AccessControlAllowOrigin::Any);

if self.allow_credentials {
headers.set(headers::AccessControlAllowCredentials)
}
}

fn add_cors_preflight_headers(&self,
headers: &mut headers::Headers,
acrm: &headers::AccessControlRequestMethod,
acrh: Option<&headers::AccessControlRequestHeaders>) {

self.add_cors_header(headers);

// Copy the method requested by the browser into the allowed methods header
Expand Down Expand Up @@ -239,13 +286,19 @@ impl CorsHandlerAllowAny {
}

// If we don't have an Access-Control-Request-Method header, treat as a possible OPTION CORS call
return self.process_possible_cors_request(req)
return self.process_possible_cors_request(req);
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And here too: Remove the return.

}

fn process_possible_cors_request(&self, req: &mut Request) -> IronResult<Response> {
self.handler.handle(req)
.map(|mut res| { self.add_cors_header(&mut res.headers); res })
.map_err(|mut err| { self.add_cors_header(&mut err.response.headers); err })
.map(|mut res| {
self.add_cors_header(&mut res.headers);
res
})
.map_err(|mut err| {
self.add_cors_header(&mut err.response.headers);
err
})
}
}

Expand All @@ -259,15 +312,15 @@ impl Handler for CorsHandlerAllowAny {
match req.headers.get::<headers::Origin>() {
None => {
self.handler.handle(req)
},
}
dbrgn marked this conversation as resolved.
Show resolved Hide resolved
Some(_) => {
match req.method {
//If is an OPTION request, check for preflight
Method::Options => self.process_possible_preflight(req),
// If is not an OPTION request, we assume a normal CORS (no preflight)
_ => self.process_possible_cors_request(req),
}
},
}
}
}
}
Expand Down