From 47986d8e0f159e307be314963a0f2e44f560a43c Mon Sep 17 00:00:00 2001 From: Pierre Krieger Date: Sun, 11 Dec 2016 11:23:50 +0100 Subject: [PATCH] Use strong typing for the headers --- examples/database.rs | 9 +- examples/websocket.rs | 4 +- src/assets.rs | 27 ++-- src/cgi.rs | 10 +- src/lib.rs | 25 ++-- src/log.rs | 12 +- src/proxy.rs | 16 +-- src/response.rs | 310 ++++++++++++++++++++++++++++++++++++++---- src/session.rs | 15 +- src/websocket/mod.rs | 32 +++-- 10 files changed, 366 insertions(+), 94 deletions(-) diff --git a/examples/database.rs b/examples/database.rs index 29c129aad..596e68a87 100644 --- a/examples/database.rs +++ b/examples/database.rs @@ -177,10 +177,11 @@ fn note_routes(request: &Request, db: &Transaction) -> Response { let id = id.unwrap(); - let mut response = Response::text("The note has been created"); - response.status_code = 201; - response.headers.push(("Location".to_owned(), format!("/note/{}", id))); - response + Response { + status_code: 201, + location: Some(format!("/note/{}", id).into()), + .. Response::text("The note has been created") + } }, (DELETE) (/note/{id: i32}) => { diff --git a/examples/websocket.rs b/examples/websocket.rs index a593b0df8..55a329f78 100644 --- a/examples/websocket.rs +++ b/examples/websocket.rs @@ -44,7 +44,7 @@ fn main() {

Received:

-

") +

").into() }, (GET) (/ws) => { @@ -73,7 +73,7 @@ fn main() { }, // Default 404 route as with all examples. - _ => rouille::Response::empty_404() + _ => rouille::Response::empty_404().into() ) }); } diff --git a/src/assets.rs b/src/assets.rs index aa6f74d86..14f3c6692 100644 --- a/src/assets.rs +++ b/src/assets.rs @@ -15,6 +15,7 @@ use time; use Request; use Response; +use ResponseCacheControl; use ResponseBody; /// Searches inside `path` for a file that matches the given request. If a file is found, @@ -128,24 +129,24 @@ pub fn match_assets(request: &Request, path: &P) -> Response if not_modified { return Response { status_code: 304, - headers: vec![ - ("Cache-Control".to_owned(), "public, max-age=3600".to_owned()), - ("ETag".to_owned(), etag.to_string()) - ], - data: ResponseBody::empty(), - upgrade: None, + cache_control: ResponseCacheControl::Public { + max_age: 3600, + must_revalidate: false, + }, + data: ResponseBody::from_file(file), + .. Response::empty_200() }; } Response { - status_code: 200, - headers: vec![ - ("Cache-Control".to_owned(), "public, max-age=3600".to_owned()), - ("Content-Type".to_owned(), extension_to_mime(extension).to_owned()), - ("ETag".to_owned(), etag.to_string()) - ], + content_type: Some(extension_to_mime(extension).into()), + cache_control: ResponseCacheControl::Public { + max_age: 3600, + must_revalidate: false, + }, + etag: Some(etag.to_owned().into()), data: ResponseBody::from_file(file), - upgrade: None, + .. Response::empty_200() } } diff --git a/src/cgi.rs b/src/cgi.rs index 8af48363c..c613fc10f 100644 --- a/src/cgi.rs +++ b/src/cgi.rs @@ -45,7 +45,7 @@ use std::process::Command; use std::process::Stdio; use Request; -use Response; +use RawResponse; use ResponseBody; /// Error that can happen when parsing the JSON input. @@ -75,11 +75,11 @@ pub trait CgiRun { /// The body of the returned `Response` will hold a handle to the child's stdout output. This /// means that the child can continue running in the background and send data to the client, /// even after you have finished handling the request. - fn start_cgi(self, request: &Request) -> Result; + fn start_cgi(self, request: &Request) -> Result; } impl CgiRun for Command { - fn start_cgi(mut self, request: &Request) -> Result { + fn start_cgi(mut self, request: &Request) -> Result { self.env("SERVER_SOFTWARE", "rouille") .env("SERVER_NAME", "localhost") // FIXME: .env("GATEWAY_INTERFACE", "CGI/1.1") @@ -125,11 +125,11 @@ impl CgiRun for Command { if header == "Status" { status = val[0..3].parse().expect("Status returned by CGI program is invalid"); } else { - headers.push((header.to_owned(), val.to_owned())); + headers.push((header.to_owned().into(), val.to_owned().into())); } } - Response { + RawResponse { status_code: status, headers: headers, data: ResponseBody::from_reader(stdout), diff --git a/src/lib.rs b/src/lib.rs index c4593740e..d4a9e52ab 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -64,7 +64,8 @@ extern crate url; pub use assets::match_assets; pub use log::log; -pub use response::{Response, ResponseBody}; +pub use response::{Response, ResponseBody, RawResponse}; +pub use response::{ResponseCookie, ResponseCacheControl}; pub use tiny_http::ReadWrite; use std::io::Cursor; @@ -119,7 +120,7 @@ macro_rules! try_or_400 { ($result:expr) => ( match $result { Ok(r) => r, - Err(_) => return $crate::Response::empty_400(), + Err(_) => return $crate::Response::empty_400().into(), } ); } @@ -131,7 +132,7 @@ macro_rules! try_or_404 { ($result:expr) => ( match $result { Ok(r) => r, - Err(_) => return $crate::Response::empty_404(), + Err(_) => return $crate::Response::empty_404().into(), } ); } @@ -162,7 +163,7 @@ macro_rules! try_or_404 { macro_rules! assert_or_400 { ($cond:expr) => ( if !$cond { - return $crate::Response::empty_400(); + return $crate::Response::empty_400().into(); } ); } @@ -202,7 +203,7 @@ macro_rules! assert_or_400 { /// *requests_counter.lock().unwrap() += 1; /// /// // rest of the handler -/// # panic!() +/// # rouille::Response::empty_404() /// }) /// ``` /// @@ -210,9 +211,10 @@ macro_rules! assert_or_400 { /// /// If your request handler panicks, a 500 error will automatically be sent to the client. /// -pub fn start_server(addr: A, handler: F) -> ! - where A: ToSocketAddrs, - F: Send + Sync + 'static + Fn(&Request) -> Response +pub fn start_server(addr: A, handler: F) -> ! + where A: ToSocketAddrs, + F: Send + Sync + 'static + Fn(&Request) -> R, + R: Into { let server = tiny_http::Server::http(addr).unwrap(); let handler = Arc::new(AssertUnwindSafe(handler)); // TODO: using AssertUnwindSafe here is wrong, but unwind safety has some usability problems in Rust in general @@ -259,7 +261,7 @@ pub fn start_server(addr: A, handler: F) -> ! let rouille_request = AssertUnwindSafe(rouille_request); let res = panic::catch_unwind(move || { let rouille_request = rouille_request; - handler(&rouille_request) + handler(&rouille_request).into() }); match res { @@ -268,6 +270,7 @@ pub fn start_server(addr: A, handler: F) -> ! Response::html("

Internal Server Error

\

An internal error has occurred on the server.

") .with_status_code(500) + .into() } } }; @@ -278,7 +281,7 @@ pub fn start_server(addr: A, handler: F) -> ! .with_data(res_data, res_len); for (key, value) in rouille_response.headers { - if let Ok(header) = tiny_http::Header::from_bytes(key, value) { + if let Ok(header) = tiny_http::Header::from_bytes(&*key, &*value) { response.add_header(header); } else { // TODO: ? @@ -430,7 +433,7 @@ impl Request { /// /// fn handle(request: &Request) -> Response { /// if !request.is_secure() { - /// return Response::redirect(&format!("https://example.com")); + /// return Response::redirect(format!("https://example.com")); /// } /// /// // ... diff --git a/src/log.rs b/src/log.rs index 4cb7b7c73..010c55ddc 100644 --- a/src/log.rs +++ b/src/log.rs @@ -15,7 +15,7 @@ use std::time::Instant; use chrono; use Request; -use Response; +use RawResponse; /// Adds a log entry to the given writer at each request. /// @@ -26,17 +26,18 @@ use Response; /// /// ``` /// use std::io; -/// use rouille::{Request, Response}; +/// use rouille::{Request, Response, RawResponse}; /// -/// fn handle(request: &Request) -> Response { +/// fn handle(request: &Request) -> RawResponse { /// rouille::log(request, io::stdout(), || { /// Response::text("hello world") /// }) /// } /// ``` -pub fn log(rq: &Request, mut output: W, f: F) -> Response +pub fn log(rq: &Request, mut output: W, f: F) -> RawResponse where W: Write, - F: FnOnce() -> Response + F: FnOnce() -> R, + R: Into { let start_instant = Instant::now(); let rq_line = format!("{} UTC - {} {}", chrono::UTC::now().format("%Y-%m-%d %H:%M:%S%.6f"), @@ -51,6 +52,7 @@ pub fn log(rq: &Request, mut output: W, f: F) -> Response match response { Ok(response) => { + let response: RawResponse = response.into(); let _ = writeln!(output, "{} - {} - {}", rq_line, elapsed_time, response.status_code); response }, diff --git a/src/proxy.rs b/src/proxy.rs index ccc2d6a53..f26bb6c3e 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -21,10 +21,10 @@ //! client. //! //! ``` -//! use rouille::{Request, Response}; +//! use rouille::{Request, Response, RawResponse}; //! use rouille::proxy; //! -//! fn handle_request(request: &Request) -> Response { +//! fn handle_request(request: &Request) -> RawResponse { //! let config = match request.header("Host") { //! Some(ref h) if h == "domain1.com" => { //! proxy::ProxyConfig { @@ -40,12 +40,12 @@ //! } //! }, //! -//! _ => return Response::empty_404() +//! _ => return Response::empty_404().into() //! }; //! //! match proxy::proxy(request, config) { //! Ok(r) => r, -//! Err(_) => Response::text("Bad gateway").with_status_code(500), +//! Err(_) => Response::text("Bad gateway").with_status_code(500).into(), //! } //! } //! ``` @@ -60,7 +60,7 @@ use std::net::TcpStream; use std::net::ToSocketAddrs; use Request; -use Response; +use RawResponse; use ResponseBody; /// Error that can happen when dispatching the request to another server. @@ -98,7 +98,7 @@ pub struct ProxyConfig { /// /// > **Note**: SSL is not supported. // TODO: ^ -pub fn proxy(request: &Request, config: ProxyConfig) -> Result +pub fn proxy(request: &Request, config: ProxyConfig) -> Result where A: ToSocketAddrs { let mut socket = try!(TcpStream::connect(config.addr)); @@ -164,11 +164,11 @@ pub fn proxy(request: &Request, config: ProxyConfig) -> Result, + pub headers: Vec<(Cow<'static, str>, Cow<'static, str>)>, /// An opaque type that contains the body of the response. pub data: ResponseBody, @@ -40,6 +46,189 @@ pub struct Response { pub upgrade: Option>, } +impl RawResponse { + /// Returns true if the status code of this `RawResponse` indicates success. + #[inline] + pub fn is_success(&self) -> bool { + self.status_code >= 200 && self.status_code < 400 + } + + /// Shortcut for `!response.is_success()`. + #[inline] + pub fn is_error(&self) -> bool { + !self.is_success() + } +} + +impl From for RawResponse { + fn from(mut response: Response) -> RawResponse { + // In order to allocate only what we need, we need to calculate the number of headers. + let num_headers = + if response.allow.is_some() || response.status_code == 405 { 1 } else { 0 } + + if response.content_type.is_some() { 1 } else { 0 } + + if response.location.is_some() { 1 } else { 0 } + + if response.content_language.is_some() { 1 } else { 0 } + + if response.etag.is_some() { 1 } else { 0 } + + response.www_authenticate.len() + + response.cookies.len(); + + let headers = { + let mut headers: Vec<(Cow<'static, str>, Cow<'static, str>)> = Vec::with_capacity(num_headers); + + if let Some(ref allow) = response.allow { + headers.push(("Allow".into(), allow.join(", ".into()).into())); + } else if response.status_code == 405 { + headers.push(("Allow".into(), "".into())); + } + + if let Some(content_type) = response.content_type { + headers.push(("Content-Type".into(), content_type)); + } + + if let Some(location) = response.location { + headers.push(("Location".into(), location)); + } + + for cookie in response.cookies.iter() { + // TODO: escape values + let mut value = format!("{}={}", cookie.name, cookie.value); + if cookie.http_only { value.push_str("; HttpOnly"); } + if cookie.secure { value.push_str("; Secure"); } + if let Some(ref path) = cookie.path { + value.push_str("; Path="); + value.push_str(path); + } + if let Some(ref domain) = cookie.domain { + value.push_str("; Domain="); + value.push_str(domain); + } + if let Some(max_age) = cookie.max_age { + value.push_str("; Max-Age="); + value.push_str(&max_age.to_string()); + } + + headers.push(("Set-Cookie".into(), value.into())); + } + + for challenge in response.www_authenticate.iter() { + let mut val: String = challenge.auth_scheme.to_string(); + let mut first = true; + for &(ref key, ref v) in challenge.params.iter() { + if !first { val.push_str(", "); } + else { val.push_str(" "); first = false; } + val.push_str(&key); + val.push_str("=\""); + val.push_str(&v); + val.push_str("\""); + } + headers.push(("WWW-Authenticate".into(), val.into())); + } + + if response.www_authenticate.is_empty() && response.status_code == 401 { + response.status_code = 403; + } + + if let Some(etag) = response.etag { + headers.push(("ETag".into(), etag)); + } + + headers + }; + + // Detects bugs with the number of headers calculated above. + debug_assert_eq!(headers.len(), headers.capacity()); + + RawResponse { + status_code: response.status_code, + headers: headers, + data: response.data, + upgrade: None, + } + } +} + +/// Contains a prototype of a response. Headers are strongly-typed. +/// +/// The response is only sent to the client when you return the `Response` object from your +/// request handler. This means that you are free to create as many `Response` objects as you want. +pub struct Response { + /// The status code to return to the user. + pub status_code: u16, + + /// Specifies the MIME type of the content. + /// + /// This corresponds to the `Content-Type` header. + /// + /// If you don't specify this, the browser may either interpret the data as + /// `application/octet-stream` or attempt to determine the type of data by analyzing the body. + /// When the body is not empty, it is strongly recommended that you always specify a + /// content-type. But in some situations it may not be possible to know what the content-type + /// is. + /// + /// Rouille doesn't check whether the MIME type is valid. + // TODO: ^ decide whether that's a good idea ; specs say it's strict, see https://tools.ietf.org/html/rfc2046 + pub content_type: Option>, + + /// List of cookies to send to the client. + pub cookies: Vec, + + pub etag: Option>, + + pub cache_control: ResponseCacheControl, + + /// If set, indicates that the same request may result in a different outcome if the request + /// supplies credentials or different credentials. + /// + /// This corresponds to the `WWW-Authenticate` header. + /// + /// When the status code is 401, it is mandatory for the server to return a `WWW-Authenticate` + /// header. In order to be compliant, rouille will automatically turn status code 401 into 403 + /// if you didn't supply this header. + pub www_authenticate: Vec, + + /// List of methods (`GET`, `POST`, etc.) supported for the target resource. + /// + /// A value of `None` indicates that no list will be returned to the client. A value of `Some` + /// with an empty `Vec` means that no method is allowed ; in other words, the resource is + /// disabled. + /// + /// With a 405 response code, the server must always return a `Allow` header. In this case, + /// rouille will return an empty list even if you put `None` here. + pub allow: Option>>, + + pub content_disposition: ResponseContentDisposition, + + /// + /// + /// This corresponds to the `Location` header. + pub location: Option>, + + /// Specifies the language of the content. + /// + /// The language must be a *language tag*, as defined by + /// [RFC 5646](https://tools.ietf.org/html/rfc5646). For example `en-US`. Rouille doesn't check + /// whether the language tag is valid. + /// + /// This corresponds to the `Content-Language` header. + pub content_language: Option>, + + /// Specifies whether any intermediate is allowed to transform the response in order to save + /// space or bandwidth. + /// + /// This corresponds to `Cache-Control: no-transform`. + /// + /// If the value is `true`, intermediate caches are not allowed to transform the response. + /// The default value is `false`. + /// + /// You are encouraged to only set this to `true` in very specific situations where the + /// response must match bit-by-bit. Do not set this to `true` just because you are worried + /// a cache may do something wrong. + pub no_transform: bool, + + /// An opaque type that contains the body of the response. + pub data: ResponseBody, +} + impl Response { /// Returns true if the status code of this `Response` indicates success. /// @@ -80,12 +269,13 @@ impl Response { /// let response = Response::redirect("/foo"); /// ``` #[inline] - pub fn redirect(target: &str) -> Response { + pub fn redirect(target: T) -> Response + where T: Into> + { Response { status_code: 303, - headers: vec![("Location".to_owned(), target.to_owned())], - data: ResponseBody::empty(), - upgrade: None, + location: Some(target.into()), + .. Response::empty_200() } } @@ -100,10 +290,9 @@ impl Response { #[inline] pub fn html(content: D) -> Response where D: Into> { Response { - status_code: 200, - headers: vec![("Content-Type".to_owned(), "text/html; charset=utf8".to_owned())], + content_type: Some("text/html; charset=utf8".into()), data: ResponseBody::from_data(content), - upgrade: None, + .. Response::empty_200() } } @@ -118,10 +307,9 @@ impl Response { #[inline] pub fn svg(content: D) -> Response where D: Into> { Response { - status_code: 200, - headers: vec![("Content-Type".to_owned(), "image/svg+xml; charset=utf8".to_owned())], + content_type: Some("image/svg+xml; charset=utf8".into()), data: ResponseBody::from_data(content), - upgrade: None, + .. Response::empty_200() } } @@ -136,10 +324,9 @@ impl Response { #[inline] pub fn text(text: S) -> Response where S: Into { Response { - status_code: 200, - headers: vec![("Content-Type".to_owned(), "text/plain; charset=utf8".to_owned())], + content_type: Some("text/plain; charset=utf8".into()), data: ResponseBody::from_string(text), - upgrade: None, + .. Response::empty_200() } } @@ -168,10 +355,9 @@ impl Response { let data = rustc_serialize::json::encode(content).unwrap(); Response { - status_code: 200, - headers: vec![("Content-Type".to_owned(), "application/json".to_owned())], + content_type: Some("application/json".into()), data: ResponseBody::from_data(data), - upgrade: None, + .. Response::empty_200() } } @@ -189,9 +375,39 @@ impl Response { // TODO: escape the realm Response { status_code: 401, - headers: vec![("WWW-Authenticate".to_owned(), format!("Basic realm=\"{}\"", realm))], + www_authenticate: vec![ + ResponseChallenge { + auth_scheme: "Basic".into(), + params: vec![("realm".into(), realm.to_owned().into())], // TODO: + } + ], + .. Response::empty_200() + } + } + + /// Builds an empty `Response` with a 200 status code. + /// + /// # Example + /// + /// ``` + /// use rouille::Response; + /// let response = Response::empty_200(); + /// ``` + #[inline] + pub fn empty_200() -> Response { + Response { + status_code: 200, data: ResponseBody::empty(), - upgrade: None, + no_transform: false, + content_language: None, + content_type: None, + location: None, + allow: None, + etag: None, + www_authenticate: Vec::new(), + cache_control: ResponseCacheControl::Unspecified, + content_disposition: ResponseContentDisposition::Inline, + cookies: Vec::new(), } } @@ -207,9 +423,7 @@ impl Response { pub fn empty_400() -> Response { Response { status_code: 400, - headers: vec![], - data: ResponseBody::empty(), - upgrade: None, + .. Response::empty_200() } } @@ -225,9 +439,7 @@ impl Response { pub fn empty_404() -> Response { Response { status_code: 404, - headers: vec![], - data: ResponseBody::empty(), - upgrade: None, + .. Response::empty_200() } } @@ -246,6 +458,44 @@ impl Response { } } +pub enum ResponseCacheControl { + Unspecified, + Public { + max_age: u64, + must_revalidate: bool, + }, + Private { + max_age: u64, + must_revalidate: bool, + }, + NoCache { + max_age: u64, + must_revalidate: bool, + }, + NoStore, +} + +pub struct ResponseCookie { + pub name: Cow<'static, str>, + pub value: Cow<'static, str>, + pub http_only: bool, + pub secure: bool, + pub path: Option>, + pub domain: Option>, + pub max_age: Option, + // TODO: pub expires: , +} + +pub struct ResponseChallenge { + pub auth_scheme: Cow<'static, str>, + pub params: Vec<(Cow<'static, str>, Cow<'static, str>)>, +} + +pub enum ResponseContentDisposition { + Inline, + Attachment, // FIXME: +} + /// An opaque type that represents the body of a response. /// /// You can't access the inside of this struct, but you can build one by using one of the provided diff --git a/src/session.rs b/src/session.rs index f439375e3..3095d21c9 100644 --- a/src/session.rs +++ b/src/session.rs @@ -40,6 +40,7 @@ use rand::Rng; use Request; use Response; +use ResponseCookie; use input; pub fn session(request: &Request, cookie_name: &str, timeout_s: u64, inner: F) -> Response @@ -66,11 +67,17 @@ pub fn session(request: &Request, cookie_name: &str, timeout_s: u64, inner: F let mut response = inner(&session); if session.key_was_retreived.load(Ordering::Relaxed) { // TODO: use `get_mut()` - // FIXME: correct interactions with existing headers + // FIXME: interaction with existing cookie // TODO: allow setting domain - let header_value = format!("{}={}; Max-Age={}; Path=/; HttpOnly", - cookie_name, session.key, timeout_s); - response.headers.push(("Set-Cookie".to_owned(), header_value)); + response.cookies.push(ResponseCookie { + name: cookie_name.to_owned().into(), // TODO: not zero-cost + value: session.key.into(), + http_only: true, + path: Some("/".into()), + domain: None, + max_age: Some(timeout_s), + secure: true, + }); } response diff --git a/src/websocket/mod.rs b/src/websocket/mod.rs index e9e91752f..31f59f30e 100644 --- a/src/websocket/mod.rs +++ b/src/websocket/mod.rs @@ -52,12 +52,12 @@ //! use std::sync::mpsc::Receiver; //! //! use rouille::Request; -//! use rouille::Response; +//! use rouille::RawResponse; //! use rouille::websocket; //! # fn main() {} //! //! fn handle_request(request: &Request, websockets: &Mutex>>) -//! -> Response +//! -> RawResponse //! { //! let (response, websocket) = try_or_400!(websocket::start(request, Some("my-subprotocol"))); //! websockets.lock().unwrap().push(websocket); @@ -80,7 +80,8 @@ use rustc_serialize::base64::ToBase64; use tiny_http::HTTPVersion; use Request; -use Response; +use ResponseBody; +use RawResponse; mod low_level; mod websocket; @@ -105,7 +106,7 @@ pub enum WebsocketError { /// Builds a `Response` that initiates the websocket protocol. pub fn start(request: &Request, subprotocol: Option<&str>) - -> Result<(Response, mpsc::Receiver), WebsocketError> + -> Result<(RawResponse, mpsc::Receiver), WebsocketError> { if request.method() != "GET" { return Err(WebsocketError::InvalidWebsocketRequest); @@ -150,14 +151,21 @@ pub fn start(request: &Request, subprotocol: Option<&str>) let (tx, rx) = mpsc::channel(); - let mut response = Response::text(""); - response.status_code = 101; - response.headers.push(("Upgrade".into(), "websocket".into())); - if let Some(sp) = subprotocol { - response.headers.push(("Sec-Websocket-Protocol".into(), sp.to_owned())); - } - response.headers.push(("Sec-Websocket-Accept".into(), key)); - response.upgrade = Some(Box::new(tx) as Box<_>); + let response = RawResponse { + status_code: 101, + headers: { + let mut headers = Vec::new(); + headers.push(("Upgrade".into(), "websocket".into())); + if let Some(sp) = subprotocol { + headers.push(("Sec-Websocket-Protocol".into(), sp.to_owned().into())); // TODO: meh alloc + } + headers.push(("Sec-Websocket-Accept".into(), key.into())); + headers + }, + data: ResponseBody::empty(), + upgrade: Some(Box::new(tx) as Box<_>), + }; + Ok((response, rx)) }