diff --git a/src/lib.rs b/src/lib.rs index 01af379..a091e31 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,10 +3,3 @@ mod middleware; pub use middleware::CompressMiddleware; - -#[derive(PartialEq)] -pub enum Encoding { - BROTLI, - GZIP, - DEFLATE, -} diff --git a/src/middleware.rs b/src/middleware.rs index 5f88cc3..b35cf9d 100644 --- a/src/middleware.rs +++ b/src/middleware.rs @@ -6,11 +6,10 @@ use async_compression::futures::bufread::DeflateEncoder; use async_compression::futures::bufread::GzipEncoder; use futures_util::io::BufReader; use regex::Regex; +use tide::http::content::{AcceptEncoding, ContentEncoding, Encoding}; use tide::http::{headers, Body, Method}; use tide::{Middleware, Next, Request, Response}; -use crate::Encoding; - const THRESHOLD: usize = 1024; /// A middleware for compressing response body data. @@ -51,16 +50,17 @@ impl Middleware for CompressMiddlew // Incoming Request data // Need to grab these things before the request is consumed by `next.run()`. let is_head = req.method() == Method::Head; - let accepts_encoding = accepts_encoding(&req); + let accepts = AcceptEncoding::from_headers(&req)?; // Propagate to route let mut res: Response = next.run(req).await; // Head requests should have no body to compress. // Can't tell if we can compress if there is no Accepts-Encoding header. - if is_head || accepts_encoding.is_none() { + if is_head || accepts.is_none() { return Ok(res); } + let mut accepts = accepts.unwrap(); // Should we transform? if let Some(cache_control) = res.header(headers::CACHE_CONTROL) { @@ -74,8 +74,8 @@ impl Middleware for CompressMiddlew // Check if an encoding may already exist. // Can't tell if we should compress if an encoding set. - if let Some(previous_encoding) = res.header(headers::CONTENT_ENCODING) { - if previous_encoding.iter().any(|v| v.as_str() != "identity") { + if let Some(previous_encoding) = ContentEncoding::from_headers(&res)? { + if previous_encoding != Encoding::Identity { return Ok(res); } } @@ -88,11 +88,18 @@ impl Middleware for CompressMiddlew } let body = res.take_body(); - let encoding = accepts_encoding.unwrap(); + let encoding = accepts.negotiate(&[ + #[cfg(feature = "brotli")] + Encoding::Brotli, + #[cfg(feature = "gzip")] + Encoding::Gzip, + #[cfg(feature = "deflate")] + Encoding::Deflate, + ])?; // Get a new Body backed by an appropriate encoder, if one is available. res.set_body(get_encoder(body, &encoding)); - res.insert_header(headers::CONTENT_ENCODING, get_encoding_name(&encoding)); + encoding.apply(&mut res); // End size no longer matches body size, so any existing Content-Length is useless. res.remove_header(headers::CONTENT_LENGTH); @@ -101,66 +108,28 @@ impl Middleware for CompressMiddlew } } -/// Gets an `Encoding` that matches up to the Accept-Encoding value. -fn accepts_encoding(req: &Request) -> Option { - let header = req.header(headers::ACCEPT_ENCODING)?; - - #[cfg(feature = "brotli")] - { - if header.iter().any(|v| v.as_str().contains("br")) { - return Some(Encoding::BROTLI); - } - } - - #[cfg(feature = "gzip")] - { - if header.iter().any(|v| v.as_str().contains("gzip")) { - return Some(Encoding::GZIP); - } - } - - #[cfg(feature = "deflate")] - { - if header.iter().any(|v| v.as_str().contains("deflate")) { - return Some(Encoding::DEFLATE); - } - } - - None -} - /// Returns a `Body` made from an encoder chosen from the `Encoding`. -fn get_encoder(body: Body, encoding: &Encoding) -> Body { +fn get_encoder(body: Body, encoding: &ContentEncoding) -> Body { #[cfg(feature = "brotli")] { - if *encoding == Encoding::BROTLI { + if *encoding == Encoding::Brotli { return Body::from_reader(BufReader::new(BrotliEncoder::new(body)), None); } } #[cfg(feature = "gzip")] { - if *encoding == Encoding::GZIP { + if *encoding == Encoding::Gzip { return Body::from_reader(BufReader::new(GzipEncoder::new(body)), None); } } #[cfg(feature = "deflate")] { - if *encoding == Encoding::DEFLATE { + if *encoding == Encoding::Deflate { return Body::from_reader(BufReader::new(DeflateEncoder::new(body)), None); } } body } - -/// Maps an `Encoding` to a Content-Encoding string. -fn get_encoding_name(encoding: &Encoding) -> String { - (match *encoding { - Encoding::BROTLI => "br", - Encoding::GZIP => "gzip", - Encoding::DEFLATE => "deflate", - }) - .to_string() -} diff --git a/tests/existing-encoding.rs b/tests/existing-encoding.rs index 6402071..d162622 100644 --- a/tests/existing-encoding.rs +++ b/tests/existing-encoding.rs @@ -19,7 +19,7 @@ async fn existing_encoding() { app.at("/").get(|_| async { let mut res = Response::new(StatusCode::Ok); res.set_body(TEXT.to_owned()); - res.insert_header(headers::CONTENT_ENCODING, "some-format"); + res.insert_header(headers::CONTENT_ENCODING, "deflate"); Ok(res) }); @@ -29,7 +29,7 @@ async fn existing_encoding() { assert_eq!(res.status(), 200); assert!(res.header(headers::CONTENT_LENGTH).is_none()); - assert_eq!(res[headers::CONTENT_ENCODING], "some-format"); + assert_eq!(res[headers::CONTENT_ENCODING], "deflate"); assert_eq!(res.body_string().await.unwrap(), TEXT); } @@ -40,8 +40,8 @@ async fn multi_existing_encoding() { app.at("/").get(|_| async { let mut res = Response::new(StatusCode::Ok); res.set_body(TEXT.to_owned()); - res.append_header(headers::CONTENT_ENCODING, "gzip"); res.append_header(headers::CONTENT_ENCODING, "identity"); + res.append_header(headers::CONTENT_ENCODING, "gzip"); Ok(res) }); @@ -51,7 +51,7 @@ async fn multi_existing_encoding() { assert_eq!(res.status(), 200); assert!(res.header(headers::CONTENT_LENGTH).is_none()); - assert_eq!(res[headers::CONTENT_ENCODING][0].as_str(), "gzip"); - assert_eq!(res[headers::CONTENT_ENCODING][1].as_str(), "identity"); + assert_eq!(res[headers::CONTENT_ENCODING][0].as_str(), "identity"); + assert_eq!(res[headers::CONTENT_ENCODING][1].as_str(), "gzip"); assert_eq!(res.body_string().await.unwrap(), TEXT); } diff --git a/tests/unencoded.rs b/tests/unencoded.rs index 76de580..a97cb6a 100644 --- a/tests/unencoded.rs +++ b/tests/unencoded.rs @@ -43,12 +43,11 @@ async fn invalid_accepts_encoding() { let mut req = Request::new(Method::Get, Url::parse("http://_/").unwrap()); req.insert_header(headers::ACCEPT_ENCODING, "not_an_encoding"); - let mut res: tide::http::Response = app.respond(req).await.unwrap(); + let res: tide::http::Response = app.respond(req).await.unwrap(); - assert_eq!(res.status(), 200); + assert_eq!(res.status(), StatusCode::NotAcceptable); assert!(res.header(headers::CONTENT_LENGTH).is_none()); assert!(res.header(headers::CONTENT_ENCODING).is_none()); - assert_eq!(res.body_string().await.unwrap(), TEXT); } #[async_std::test]