Skip to content

Commit

Permalink
Implement rwf2#2871 by matching on outcome
Browse files Browse the repository at this point in the history
  • Loading branch information
jespersm committed Oct 2, 2024
1 parent 3bf9ef0 commit e70b79f
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 4 deletions.
7 changes: 4 additions & 3 deletions core/lib/src/request/from_request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -521,12 +521,13 @@ impl<'r, T: FromRequest<'r>> FromRequest<'r> for Result<T, T::Error> {

#[crate::async_trait]
impl<'r, T: FromRequest<'r>> FromRequest<'r> for Option<T> {
type Error = Infallible;
type Error = T::Error;

async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Infallible> {
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
match T::from_request(request).await {
Success(val) => Success(Some(val)),
Error(_) | Forward(_) => Success(None),
Forward(_) => Success(None),
Error((status, error)) => Error((status, error)),
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion core/lib/src/response/flash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ impl<'r> FromRequest<'r> for FlashMessage<'r> {
Ok(i) if i <= kv.len() => Ok(Flash::named(&kv[..i], &kv[i..], req)),
_ => Err(())
}
}).or_error(Status::BadRequest)
}).or_forward(Status::BadRequest)
}
}

Expand Down
83 changes: 83 additions & 0 deletions core/lib/tests/refined-from-request.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
#[macro_use]
extern crate rocket;

use std::num::ParseIntError;

use rocket::{outcome::IntoOutcome, request::{FromRequest, Outcome}, Request};
use rocket_http::{Header, Status};

pub struct SessionId {
session_id: u64,
}

#[rocket::async_trait]
impl<'r> FromRequest<'r> for SessionId {
type Error = ParseIntError;

async fn from_request(request: &'r Request<'_>) -> Outcome<Self, ParseIntError> {
let session_id_string = request.headers().get("Session-Id").next()
.or_forward(Status::BadRequest);
session_id_string.and_then(|v| v.parse()
.map(|id| SessionId { session_id: id })
.or_error(Status::BadRequest))
}
}

#[get("/mandatory")]
fn get_data_with_mandatory_header(header: SessionId) -> String {
format!("GET for session {:}", header.session_id)
}

#[get("/optional")]
fn get_data_with_opt_header(opt_header: Option<SessionId>) -> String {
if let Some(id) = opt_header {
format!("GET for session {:}", id.session_id)
} else {
format!("GET for new session")
}
}

#[test]
fn read_optional_header() {
let rocket = rocket::build().mount(
"/",
routes![get_data_with_opt_header, get_data_with_mandatory_header]);
let client = rocket::local::blocking::Client::debug(rocket).unwrap();

// If we supply the header, the handler sees it
let response = client.get("/optional")
.header(Header::new("session-id", "1234567")).dispatch();
assert_eq!(response.into_string().unwrap(), "GET for session 1234567".to_string());

// If no header, means that the handler sees a None
let response = client.get("/optional").dispatch();
assert_eq!(response.into_string().unwrap(), "GET for new session".to_string());

// If we supply a malformed header, the handler will not be called, but the request will fail
let response = client.get("/optional")
.header(Header::new("session-id", "Xw23")).dispatch();
assert_eq!(response.status(), Status::BadRequest);
}

#[test]
fn read_mandatory_header() {
let rocket = rocket::build().mount(
"/",
routes![get_data_with_opt_header, get_data_with_mandatory_header]);
let client = rocket::local::blocking::Client::debug(rocket).unwrap();

// If the header is missing, it's a bad request (extra info would be nice, though)
let response = client.get("/mandatory").dispatch();
assert_eq!(response.status(), Status::BadRequest);

// If the header is malformed, it's a bad request too (extra info would be nice, though)
let response = client.get("/mandatory")
.header(Header::new("session-id", "Xw23")).dispatch();
assert_eq!(response.status(), Status::BadRequest);

// If the header is fine, just do the stuff
let response = client.get("/mandatory")
.header(Header::new("session-id", "64535")).dispatch();
assert_eq!(response.status(), Status::Ok);
assert_eq!(response.into_string().unwrap(), "GET for session 64535".to_string());
}

0 comments on commit e70b79f

Please sign in to comment.