Skip to content

Commit

Permalink
use http_types::content::AcceptEncoding
Browse files Browse the repository at this point in the history
  • Loading branch information
Fishrock123 committed Aug 24, 2020
1 parent 7487280 commit dff0747
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 65 deletions.
7 changes: 0 additions & 7 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,3 @@
mod middleware;

pub use middleware::CompressMiddleware;

#[derive(PartialEq)]
pub enum Encoding {
BROTLI,
GZIP,
DEFLATE,
}
69 changes: 19 additions & 50 deletions src/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@ use async_compression::futures::bufread::DeflateEncoder;
use async_compression::futures::bufread::GzipEncoder;
use futures_util::io::BufReader;
use regex::Regex;
use tide::http::content::{AcceptEncoding, ContentEncoding, Encoding};
use tide::http::{headers, Body, Method};
use tide::{Middleware, Next, Request, Response};

use crate::Encoding;

const THRESHOLD: usize = 1024;

/// A middleware for compressing response body data.
Expand Down Expand Up @@ -51,16 +50,17 @@ impl<State: Clone + Send + Sync + 'static> Middleware<State> for CompressMiddlew
// Incoming Request data
// Need to grab these things before the request is consumed by `next.run()`.
let is_head = req.method() == Method::Head;
let accepts_encoding = accepts_encoding(&req);
let accepts = AcceptEncoding::from_headers(&req)?;

// Propagate to route
let mut res: Response = next.run(req).await;

// Head requests should have no body to compress.
// Can't tell if we can compress if there is no Accepts-Encoding header.
if is_head || accepts_encoding.is_none() {
if is_head || accepts.is_none() {
return Ok(res);
}
let mut accepts = accepts.unwrap();

// Should we transform?
if let Some(cache_control) = res.header(headers::CACHE_CONTROL) {
Expand All @@ -74,8 +74,8 @@ impl<State: Clone + Send + Sync + 'static> Middleware<State> for CompressMiddlew

// Check if an encoding may already exist.
// Can't tell if we should compress if an encoding set.
if let Some(previous_encoding) = res.header(headers::CONTENT_ENCODING) {
if previous_encoding.iter().any(|v| v.as_str() != "identity") {
if let Some(previous_encoding) = ContentEncoding::from_headers(&res)? {
if previous_encoding != Encoding::Identity {
return Ok(res);
}
}
Expand All @@ -88,11 +88,18 @@ impl<State: Clone + Send + Sync + 'static> Middleware<State> for CompressMiddlew
}

let body = res.take_body();
let encoding = accepts_encoding.unwrap();
let encoding = accepts.negotiate(&[
#[cfg(feature = "brotli")]
Encoding::Brotli,
#[cfg(feature = "gzip")]
Encoding::Gzip,
#[cfg(feature = "deflate")]
Encoding::Deflate,
])?;

// Get a new Body backed by an appropriate encoder, if one is available.
res.set_body(get_encoder(body, &encoding));
res.insert_header(headers::CONTENT_ENCODING, get_encoding_name(&encoding));
encoding.apply(&mut res);

// End size no longer matches body size, so any existing Content-Length is useless.
res.remove_header(headers::CONTENT_LENGTH);
Expand All @@ -101,66 +108,28 @@ impl<State: Clone + Send + Sync + 'static> Middleware<State> for CompressMiddlew
}
}

/// Gets an `Encoding` that matches up to the Accept-Encoding value.
fn accepts_encoding<State: Send + Sync + 'static>(req: &Request<State>) -> Option<Encoding> {
let header = req.header(headers::ACCEPT_ENCODING)?;

#[cfg(feature = "brotli")]
{
if header.iter().any(|v| v.as_str().contains("br")) {
return Some(Encoding::BROTLI);
}
}

#[cfg(feature = "gzip")]
{
if header.iter().any(|v| v.as_str().contains("gzip")) {
return Some(Encoding::GZIP);
}
}

#[cfg(feature = "deflate")]
{
if header.iter().any(|v| v.as_str().contains("deflate")) {
return Some(Encoding::DEFLATE);
}
}

None
}

/// Returns a `Body` made from an encoder chosen from the `Encoding`.
fn get_encoder(body: Body, encoding: &Encoding) -> Body {
fn get_encoder(body: Body, encoding: &ContentEncoding) -> Body {
#[cfg(feature = "brotli")]
{
if *encoding == Encoding::BROTLI {
if *encoding == Encoding::Brotli {
return Body::from_reader(BufReader::new(BrotliEncoder::new(body)), None);
}
}

#[cfg(feature = "gzip")]
{
if *encoding == Encoding::GZIP {
if *encoding == Encoding::Gzip {
return Body::from_reader(BufReader::new(GzipEncoder::new(body)), None);
}
}

#[cfg(feature = "deflate")]
{
if *encoding == Encoding::DEFLATE {
if *encoding == Encoding::Deflate {
return Body::from_reader(BufReader::new(DeflateEncoder::new(body)), None);
}
}

body
}

/// Maps an `Encoding` to a Content-Encoding string.
fn get_encoding_name(encoding: &Encoding) -> String {
(match *encoding {
Encoding::BROTLI => "br",
Encoding::GZIP => "gzip",
Encoding::DEFLATE => "deflate",
})
.to_string()
}
10 changes: 5 additions & 5 deletions tests/existing-encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ async fn existing_encoding() {
app.at("/").get(|_| async {
let mut res = Response::new(StatusCode::Ok);
res.set_body(TEXT.to_owned());
res.insert_header(headers::CONTENT_ENCODING, "some-format");
res.insert_header(headers::CONTENT_ENCODING, "deflate");
Ok(res)
});

Expand All @@ -29,7 +29,7 @@ async fn existing_encoding() {

assert_eq!(res.status(), 200);
assert!(res.header(headers::CONTENT_LENGTH).is_none());
assert_eq!(res[headers::CONTENT_ENCODING], "some-format");
assert_eq!(res[headers::CONTENT_ENCODING], "deflate");
assert_eq!(res.body_string().await.unwrap(), TEXT);
}

Expand All @@ -40,8 +40,8 @@ async fn multi_existing_encoding() {
app.at("/").get(|_| async {
let mut res = Response::new(StatusCode::Ok);
res.set_body(TEXT.to_owned());
res.append_header(headers::CONTENT_ENCODING, "gzip");
res.append_header(headers::CONTENT_ENCODING, "identity");
res.append_header(headers::CONTENT_ENCODING, "gzip");
Ok(res)
});

Expand All @@ -51,7 +51,7 @@ async fn multi_existing_encoding() {

assert_eq!(res.status(), 200);
assert!(res.header(headers::CONTENT_LENGTH).is_none());
assert_eq!(res[headers::CONTENT_ENCODING][0].as_str(), "gzip");
assert_eq!(res[headers::CONTENT_ENCODING][1].as_str(), "identity");
assert_eq!(res[headers::CONTENT_ENCODING][0].as_str(), "identity");
assert_eq!(res[headers::CONTENT_ENCODING][1].as_str(), "gzip");
assert_eq!(res.body_string().await.unwrap(), TEXT);
}
5 changes: 2 additions & 3 deletions tests/unencoded.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,11 @@ async fn invalid_accepts_encoding() {

let mut req = Request::new(Method::Get, Url::parse("http://_/").unwrap());
req.insert_header(headers::ACCEPT_ENCODING, "not_an_encoding");
let mut res: tide::http::Response = app.respond(req).await.unwrap();
let res: tide::http::Response = app.respond(req).await.unwrap();

assert_eq!(res.status(), 200);
assert_eq!(res.status(), StatusCode::NotAcceptable);
assert!(res.header(headers::CONTENT_LENGTH).is_none());
assert!(res.header(headers::CONTENT_ENCODING).is_none());
assert_eq!(res.body_string().await.unwrap(), TEXT);
}

#[async_std::test]
Expand Down

0 comments on commit dff0747

Please sign in to comment.