Skip to content

Commit

Permalink
Hacky gzip compression impl
Browse files Browse the repository at this point in the history
  • Loading branch information
Ameobea committed Sep 14, 2020
1 parent 549c924 commit eab53cf
Show file tree
Hide file tree
Showing 6 changed files with 217 additions and 145 deletions.
4 changes: 3 additions & 1 deletion contrib/lib/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ tokio = { version = "0.2.0", optional = true }
rocket_contrib_codegen = { version = "0.5.0-dev", path = "../codegen", optional = true }
rocket = { version = "0.5.0-dev", path = "../../core/lib/", default-features = false }
log = "0.4"
lazy_static = "1.4"
futures = "0.3"

# Serialization and templating dependencies.
serde = { version = "1.0", optional = true }
Expand Down Expand Up @@ -75,7 +77,7 @@ time = { version = "0.2.9", optional = true }

# Compression dependencies
brotli = { version = "3.3", optional = true }
flate2 = { version = "1.0", optional = true }
flate2 = { version = "1.0", optional = true, features = ["tokio"] }

[package.metadata.docs.rs]
all-features = true
78 changes: 12 additions & 66 deletions contrib/lib/src/compression/fairing.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,16 @@
use rocket::config::{ConfigError, Value};
use rocket::fairing::{Fairing, Info, Kind};
use rocket::http::MediaType;
use rocket::Rocket;
use rocket::{Request, Response};

struct Context {
exclusions: Vec<MediaType>,
}

impl Default for Context {
fn default() -> Context {
Context {
exclusions: vec![
MediaType::parse_flexible("application/gzip").unwrap(),
MediaType::parse_flexible("application/zip").unwrap(),
MediaType::parse_flexible("image/*").unwrap(),
MediaType::parse_flexible("video/*").unwrap(),
MediaType::parse_flexible("application/wasm").unwrap(),
MediaType::parse_flexible("application/octet-stream").unwrap(),
],
}
}
lazy_static! {
static ref EXCLUSIONS: Vec<MediaType> = vec![
MediaType::parse_flexible("application/gzip").unwrap(),
MediaType::parse_flexible("application/zip").unwrap(),
MediaType::parse_flexible("image/*").unwrap(),
MediaType::parse_flexible("video/*").unwrap(),
MediaType::parse_flexible("application/wasm").unwrap(),
MediaType::parse_flexible("application/octet-stream").unwrap(),
];
}

/// Compresses all responses with Brotli or Gzip compression.
Expand Down Expand Up @@ -95,6 +85,7 @@ impl Compression {
}
}

#[async_trait]
impl Fairing for Compression {
fn info(&self) -> Info {
Info {
Expand All @@ -103,52 +94,7 @@ impl Fairing for Compression {
}
}

fn on_attach(&self, rocket: Rocket) -> Result<Rocket, Rocket> {
let mut ctxt = Context::default();

match rocket.config().get_table("compress").and_then(|t| {
t.get("exclude").ok_or_else(|| ConfigError::Missing(String::from("exclude")))
}) {
Ok(excls) => match excls.as_array() {
Some(excls) => {
ctxt.exclusions = excls.iter().flat_map(|ex| {
if let Value::String(s) = ex {
let mt = MediaType::parse_flexible(s);
if mt.is_none() {
warn_!("Ignoring invalid media type '{:?}'", s);
}
mt
} else {
warn_!("Ignoring non-string media type '{:?}'", ex);
None
}
}).collect();
}
None => {
warn_!(
"Exclusions is not an array; using default compression exclusions '{:?}'",
ctxt.exclusions
);
}
},
Err(ConfigError::Missing(_)) => { /* ignore missing */ }
Err(e) => {
e.pretty_print();
warn_!(
"Using default compression exclusions '{:?}'",
ctxt.exclusions
);
}
};

Ok(rocket.manage(ctxt))
}

fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) {
let context = request
.guard::<rocket::State<'_, Context>>()
.expect("Compression Context registered in on_attach");

super::CompressionUtils::compress_response(request, response, &context.exclusions);
async fn on_response<'r>(&self, request: &'r Request<'_>, response: &mut Response<'r>) {
super::CompressionUtils::compress_response(request, response, &EXCLUSIONS);
}
}
96 changes: 87 additions & 9 deletions contrib/lib/src/compression/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@ pub use self::responder::Compress;

use std::io::Read;

use futures::future::FutureExt;
use futures::StreamExt;
use rocket::http::hyper::header::CONTENT_ENCODING;
use rocket::http::MediaType;
use rocket::http::hyper::header::{ContentEncoding, Encoding};
use rocket::{Request, Response};

#[cfg(feature = "brotli_compression")]
Expand All @@ -40,6 +42,57 @@ use brotli::enc::backward_references::BrotliEncoderMode;
#[cfg(feature = "gzip_compression")]
use flate2::read::GzEncoder;

pub enum Encoding {
/// The `chunked` encoding.
Chunked,
/// The `br` encoding.
Brotli,
/// The `gzip` encoding.
Gzip,
/// The `deflate` encoding.
Deflate,
/// The `compress` encoding.
Compress,
/// The `identity` encoding.
Identity,
/// The `trailers` encoding.
Trailers,
/// Some other encoding that is less common, can be any String.
EncodingExt(String),
}

impl std::fmt::Display for Encoding {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(match *self {
Encoding::Chunked => "chunked",
Encoding::Brotli => "br",
Encoding::Gzip => "gzip",
Encoding::Deflate => "deflate",
Encoding::Compress => "compress",
Encoding::Identity => "identity",
Encoding::Trailers => "trailers",
Encoding::EncodingExt(ref s) => s.as_ref(),
})
}
}

impl std::str::FromStr for Encoding {
type Err = std::convert::Infallible;

fn from_str(s: &str) -> Result<Encoding, std::convert::Infallible> {
match s {
"chunked" => Ok(Encoding::Chunked),
"br" => Ok(Encoding::Brotli),
"deflate" => Ok(Encoding::Deflate),
"gzip" => Ok(Encoding::Gzip),
"compress" => Ok(Encoding::Compress),
"identity" => Ok(Encoding::Identity),
"trailers" => Ok(Encoding::Trailers),
_ => Ok(Encoding::EncodingExt(s.to_owned())),
}
}
}

struct CompressionUtils;

impl CompressionUtils {
Expand All @@ -56,12 +109,15 @@ impl CompressionUtils {
response.headers().get("Content-Encoding").next().is_some()
}

fn set_body_and_encoding<'r, B: Read + 'r>(
fn set_body_and_encoding<'r, B: rocket::tokio::io::AsyncRead + Send + 'r>(
response: &mut Response<'r>,
body: B,
encoding: Encoding,
) {
response.set_header(ContentEncoding(vec![encoding]));
response.set_header(::rocket::http::Header::new(
CONTENT_ENCODING.as_str(),
format!("{}", encoding),
));
response.set_streamed_body(body);
}

Expand All @@ -81,7 +137,11 @@ impl CompressionUtils {
}
}

fn compress_response(request: &Request<'_>, response: &mut Response<'_>, exclusions: &[MediaType]) {
fn compress_response(
request: &Request<'_>,
response: &mut Response<'_>,
exclusions: &[MediaType],
) {
if CompressionUtils::already_encoded(response) {
return;
}
Expand All @@ -94,7 +154,7 @@ impl CompressionUtils {

// Compression is done when the request accepts brotli or gzip encoding
// and the corresponding feature is enabled
if cfg!(feature = "brotli_compression") && CompressionUtils::accepts_encoding(request, "br")
/*if cfg!(feature = "brotli_compression") && CompressionUtils::accepts_encoding(request, "br")
{
#[cfg(feature = "brotli_compression")]
{
Expand All @@ -118,15 +178,33 @@ impl CompressionUtils {
);
}
}
} else if cfg!(feature = "gzip_compression")
&& CompressionUtils::accepts_encoding(request, "gzip")
} else */
if cfg!(feature = "gzip_compression") && CompressionUtils::accepts_encoding(request, "gzip")
{
#[cfg(feature = "gzip_compression")]
{
if let Some(plain) = response.take_body() {
let compressor = GzEncoder::new(plain.into_inner(), flate2::Compression::default());
let body = async {
let body = plain.into_bytes().await.unwrap_or_else(Vec::new);
let mut compressor =
GzEncoder::new(body.as_slice(), flate2::Compression::default());
let mut buf = Vec::new();
match compressor.read_to_end(&mut buf) {
Ok(_) => (),
Err(err) => {
error!("Error compressing response with gzip: {:?}", err);
return futures::stream::iter(vec![Err(err)]);
}
}

futures::stream::iter(vec![Ok(std::io::Cursor::new(buf))])
}
.into_stream()
.flatten();

let body = tokio::io::stream_reader(body);

CompressionUtils::set_body_and_encoding(response, compressor, Encoding::Gzip);
CompressionUtils::set_body_and_encoding(response, body, Encoding::Gzip);
}
}
}
Expand Down
1 change: 1 addition & 0 deletions contrib/lib/src/compression/responder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ impl<'r, 'o: 'r, R: Responder<'r, 'o>> Responder<'r, 'o> for Compress<R> {
.merge(self.0.respond_to(request)?)
.finalize();

println!("YOU SUCK");
CompressionUtils::compress_response(request, &mut response, &[]);
Ok(response)
}
Expand Down
41 changes: 28 additions & 13 deletions contrib/lib/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#![doc(html_root_url = "https://api.rocket.rs/v0.5")]
#![doc(html_favicon_url = "https://rocket.rs/images/favicon.ico")]
#![doc(html_logo_url = "https://rocket.rs/images/logo-boxed.png")]

#![warn(rust_2018_idioms)]
#![allow(unused_extern_crates)]

Expand Down Expand Up @@ -40,17 +39,33 @@
//! This crate is expected to grow with time, bringing in outside crates to be
//! officially supported by Rocket.

#[allow(unused_imports)] #[macro_use] extern crate log;
#[allow(unused_imports)] #[macro_use] extern crate rocket;
#[allow(unused_imports)]
#[macro_use]
extern crate log;
#[allow(unused_imports)]
#[macro_use]
extern crate rocket;
#[macro_use]
extern crate lazy_static;

#[cfg(feature="json")] #[macro_use] pub mod json;
#[cfg(feature="serve")] pub mod serve;
#[cfg(feature="msgpack")] pub mod msgpack;
#[cfg(feature="templates")] pub mod templates;
#[cfg(feature="uuid")] pub mod uuid;
#[cfg(feature="databases")] pub mod databases;
#[cfg(feature = "helmet")] pub mod helmet;
// TODO.async: Migrate compression, reenable this, tests, and add to docs.
//#[cfg(any(feature="brotli_compression", feature="gzip_compression"))] pub mod compression;
#[cfg(feature = "json")]
#[macro_use]
pub mod json;
#[cfg(any(feature = "brotli_compression", feature = "gzip_compression"))]
pub mod compression;
#[cfg(feature = "databases")]
pub mod databases;
#[cfg(feature = "helmet")]
pub mod helmet;
#[cfg(feature = "msgpack")]
pub mod msgpack;
#[cfg(feature = "serve")]
pub mod serve;
#[cfg(feature = "templates")]
pub mod templates;
#[cfg(feature = "uuid")]
pub mod uuid;

#[cfg(feature="databases")] #[doc(hidden)] pub use rocket_contrib_codegen::*;
#[cfg(feature = "databases")]
#[doc(hidden)]
pub use rocket_contrib_codegen::*;
Loading

0 comments on commit eab53cf

Please sign in to comment.