diff --git a/Cargo.lock b/Cargo.lock index 1ac5056be9..e4bdd79ee1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -179,6 +179,7 @@ dependencies = [ "lru", "maplit", "mediatype", + "memchr", "miette 5.3.0", "mime", "mockall", @@ -365,7 +366,8 @@ dependencies = [ [[package]] name = "async-compression" version = "0.3.14" -source = "git+https://github.com/geal/async-compression?tag=encoder-flush#9800cd0d36be7f3414fbb98b25f9f61900ec8c7c" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "345fd392ab01f746c717b1357165b76f0b67a60192007b234058c9045fdcf695" dependencies = [ "brotli", "flate2", @@ -2936,9 +2938,9 @@ checksum = "90da6e15720cff55898a02a2ed6e9a21b152f0283a5ad89465f8d8f80c9750ca" [[package]] name = "memchr" -version = "2.4.1" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "308cc39be01b73d0d18f82a0e7b2a3df85245f84af96fdddc5d202d27e47b86a" +checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" [[package]] name = "memoffset" diff --git a/Cargo.toml b/Cargo.toml index 06b12f91a4..609df9e60c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,11 +32,3 @@ members = [ # debug = false strip = "debuginfo" incremental = false - -# Temporary patch to async-compression -# It is used by tower-http's CompressionLayer. The compression code was not handling -# the Poll::Pending result from the underlying stream, so it was accumulating the -# entire compressed response in memory before sending it, which creates issues with -# deferred responses getting received too late -[patch.crates-io] -async-compression = { git = 'https://github.com/geal/async-compression', tag = 'encoder-flush' } diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index a5decd91f4..b06d9f203e 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -91,4 +91,25 @@ We changed `QueryPlannerResponse` to: By [@bnjjj](https://github.com/bnjjj) in https://github.com/apollographql/router/pull/1504 +### Disable compression of multipart HTTP responses ([Issue #1572](https://github.com/apollographql/router/issues/1572)) + +For features such a `@defer`, the Router may send a stream of multiple GraphQL responses +in a single HTTP response. +The body of the HTTP response is a single byte stream. +When HTTP compression is used, that byte stream is compressed as a whole. +Due to limitations in current versions of the `async-compression` crate, +[issue #1572](https://github.com/apollographql/router/issues/1572) was a bug where +some GraphQL responses might not be sent to the client until more of them became available. +This buffering yields better compression, but defeats the point of `@defer`. + +Our previous work-around involved a patched `async-compression`, +which was not trivial to apply when using the Router as a dependency +since [Cargo patching](https://doc.rust-lang.org/cargo/reference/overriding-dependencies.html) +is done in a project’s root `Cargo.toml`. + +The Router now reverts to using unpatched `async-compression`, +and instead disables compression of multipart responses. +We aim to re-enable compression soon, with a proper solution that is being designed in +. + ## 📚 Documentation diff --git a/apollo-router/Cargo.toml b/apollo-router/Cargo.toml index ec65c0706c..8c4ec3903e 100644 --- a/apollo-router/Cargo.toml +++ b/apollo-router/Cargo.toml @@ -182,6 +182,7 @@ uname = "0.1.1" insta = { version = "1.19.1", features = [ "json", "redactions" ] } jsonpath_lib = "0.3.0" maplit = "1.0.2" +memchr = { version = "2.5.0", default-features = false } mockall = "0.11.2" once_cell = "1.14.0" reqwest = { version = "0.11.11", default-features = false, features = [ diff --git a/apollo-router/src/axum_http_server_factory.rs b/apollo-router/src/axum_http_server_factory.rs index 34731dbe9c..9b60542a68 100644 --- a/apollo-router/src/axum_http_server_factory.rs +++ b/apollo-router/src/axum_http_server_factory.rs @@ -53,7 +53,10 @@ use tokio::sync::Notify; use tower::util::BoxService; use tower::BoxError; use tower::ServiceExt; +use tower_http::compression::predicate::NotForContentType; use tower_http::compression::CompressionLayer; +use tower_http::compression::DefaultPredicate; +use tower_http::compression::Predicate; use tower_http::trace::MakeSpan; use tower_http::trace::TraceLayer; use tower_service::Service; @@ -189,7 +192,11 @@ where .route(&configuration.server.health_check_path, get(health_check)) .layer(Extension(service_factory)) .layer(cors) - .layer(CompressionLayer::new()); // To compress response body + // Compress the response body, except for multipart responses such as with `@defer`. + // This is a work-around for https://github.com/apollographql/router/issues/1572 + .layer(CompressionLayer::new().compress_when( + DefaultPredicate::new().and(NotForContentType::const_new("multipart/")), + )); let listener = configuration.server.listen.clone(); Ok(ListenAddrAndRouter(listener, route)) @@ -865,6 +872,8 @@ impl MakeSpan for PropagatingMakeSpan { mod tests { use std::net::SocketAddr; use std::str::FromStr; + use std::sync::atomic::AtomicU32; + use std::sync::atomic::Ordering; use async_compression::tokio::write::GzipEncoder; use http::header::ACCEPT_ENCODING; @@ -884,6 +893,7 @@ mod tests { use reqwest::StatusCode; use serde_json::json; use test_log::test; + use tokio::io::BufReader; use tower::service_fn; use super::*; @@ -892,6 +902,9 @@ mod tests { use crate::services::new_service::NewService; use crate::services::transport; use crate::services::MULTIPART_DEFER_CONTENT_TYPE; + use crate::test_harness::http_client; + use crate::test_harness::http_client::MaybeMultipart; + use crate::TestHarness; macro_rules! assert_header { ($response:expr, $header:expr, $expected:expr $(, $msg:expr)?) => { @@ -1893,7 +1906,6 @@ mod tests { #[cfg(unix)] async fn send_to_unix_socket(addr: &ListenAddr, method: Method, body: &str) -> Vec { use tokio::io::AsyncBufReadExt; - use tokio::io::BufReader; use tokio::io::Interest; use tokio::net::UnixStream; @@ -2495,4 +2507,172 @@ Content-Type: application/json\r assert!(value == "one" || value == "two"); } } + + /// A counter of how many GraphQL responses have been sent by an Apollo Router + /// + /// When `@defer` is used, it should increment multiple times for a single HTTP request. + #[derive(Clone, Default)] + struct GraphQLResponseCounter(Arc); + + impl GraphQLResponseCounter { + fn increment(&self) { + self.0.fetch_add(1, Ordering::SeqCst); + } + + fn get(&self) -> u32 { + self.0.load(Ordering::SeqCst) + } + } + + async fn http_service() -> impl Service< + http::Request, + Response = http::Response>, + Error = BoxError, + > { + let counter = GraphQLResponseCounter::default(); + let service = TestHarness::builder() + .configuration_json(json!({ + "plugins": { + "experimental.include_subgraph_errors": { + "all": true + } + } + })) + .unwrap() + .supergraph_hook(move |service| { + let counter = counter.clone(); + service + .map_response(move |mut response| { + response.response.extensions_mut().insert(counter.clone()); + response.map_stream(move |graphql_response| { + counter.increment(); + graphql_response + }) + }) + .boxed() + }) + .build_http_service() + .await + .unwrap() + .map_err(Into::into); + let service = http_client::response_decompression(service); + let service = http_client::defer_spec_20220824_multipart(service); + http_client::json(service) + } + + /// Creates an Apollo Router as an HTTP-level Tower service and makes one request. + async fn make_request( + request_body: serde_json::Value, + ) -> http::Response> { + let request = http::Request::builder() + .method(http::Method::POST) + .header("host", "127.0.0.1") + .body(request_body) + .unwrap(); + http_service().await.oneshot(request).await.unwrap() + } + + fn assert_compressed(response: &http::Response, expected: bool) { + assert_eq!( + response + .extensions() + .get::() + .unwrap() + .0, + expected + ) + } + + #[tokio::test] + async fn test_compressed_response() { + let response = make_request(json!({ + "query": " + query TopProducts($first: Int) { + topProducts(first: $first) { + upc + name + reviews { + id + product { name } + author { id name } + } + } + } + ", + "variables": {"first": 2_u32}, + })) + .await; + assert_compressed(&response, true); + let status = response.status().as_u16(); + let graphql_response = response.into_body().expect_not_multipart(); + assert_eq!(graphql_response["errors"], json!(null)); + assert_eq!(status, 200); + } + + #[tokio::test] + async fn test_defer_is_not_buffered() { + let mut response = make_request(json!({ + "query": " + query TopProducts($first: Int) { + topProducts(first: $first) { + upc + name + reviews { + id + product { name } + ... @defer { author { id name } } + } + } + } + ", + "variables": {"first": 2_u32}, + })) + .await; + assert_compressed(&response, false); + let status = response.status().as_u16(); + assert_eq!(status, 200); + let counter: GraphQLResponseCounter = response.extensions_mut().remove().unwrap(); + let parts = response.into_body().expect_multipart(); + + let (parts, counts): (Vec<_>, Vec<_>) = + parts.map(|part| (part, counter.get())).unzip().await; + let parts = serde_json::Value::Array(parts); + assert_eq!( + parts, + json!([ + { + "data": { + "topProducts": [ + {"upc": "1", "name": "Table", "reviews": null}, + {"upc": "2", "name": "Couch", "reviews": null} + ] + }, + "errors": [ + { + "message": "invalid content: Missing key `_entities`!", + "path": ["topProducts", "@"], + "extensions": { + "type": "ExecutionInvalidContent", + "reason": "Missing key `_entities`!" + } + }], + "hasNext": true, + }, + {"hasNext": false} + ]), + "{}", + serde_json::to_string(&parts).unwrap() + ); + + // Non-regression test for https://github.com/apollographql/router/issues/1572 + // + // With unpatched async-compression 0.3.14 as used by tower-http 0.3.4, + // `counts` is `[2, 2]` since both parts have to be generated on the server side + // before the first one reaches the client. + // + // Conversly, observing the value `1` after receiving the first part + // means the didn’t wait for all parts to be in the compression buffer + // before sending any. + assert_eq!(counts, [1, 2]); + } } diff --git a/apollo-router/src/test_harness.rs b/apollo-router/src/test_harness.rs index 62541a4d7b..a645935c19 100644 --- a/apollo-router/src/test_harness.rs +++ b/apollo-router/src/test_harness.rs @@ -13,8 +13,12 @@ use crate::router_factory::YamlSupergraphServiceFactory; use crate::services::execution; use crate::services::subgraph; use crate::services::supergraph; +use crate::services::RouterCreator; use crate::Schema; +#[cfg(test)] +pub(crate) mod http_client; + /// Builder for the part of an Apollo Router that handles GraphQL requests, as a [`tower::Service`]. /// /// This allows tests, benchmarks, etc @@ -165,7 +169,7 @@ impl<'a> TestHarness<'a> { } /// Builds the GraphQL service - pub async fn build(self) -> Result { + async fn build_common(self) -> Result<(Arc, RouterCreator), BoxError> { let builder = if self.schema.is_none() { self.subgraph_hook(|subgraph_name, default| match subgraph_name { "products" => canned::products_subgraph().boxed(), @@ -195,16 +199,42 @@ impl<'a> TestHarness<'a> { let schema = builder.schema.unwrap_or(canned_schema); let schema = Arc::new(Schema::parse(schema, &config)?); let router_creator = YamlSupergraphServiceFactory - .create(config, schema, None, Some(builder.extra_plugins)) + .create(config.clone(), schema, None, Some(builder.extra_plugins)) .await?; + Ok((config, router_creator)) + } + + pub async fn build(self) -> Result { + let (_config, router_creator) = self.build_common().await?; Ok(tower::service_fn(move |request| { let service = router_creator.make(); async move { service.oneshot(request).await } }) .boxed_clone()) } + + #[cfg(test)] + pub(crate) async fn build_http_service(self) -> Result { + use crate::axum_http_server_factory::make_axum_router; + use crate::axum_http_server_factory::ListenAddrAndRouter; + use crate::router_factory::SupergraphServiceFactory; + + let (config, router_creator) = self.build_common().await?; + let web_endpoints = router_creator.web_endpoints(); + let routers = make_axum_router(router_creator, &config, web_endpoints)?; + let ListenAddrAndRouter(_listener, router) = routers.main; + Ok(router.boxed()) + } } +/// An HTTP-level service, as would be given to Hyper’s server +#[cfg(test)] +pub(crate) type HttpService = tower::util::BoxService< + http::Request, + http::Response, + std::convert::Infallible, +>; + struct SupergraphServicePlugin(F); struct ExecutionServicePlugin(F); struct SubgraphServicePlugin(F); diff --git a/apollo-router/src/test_harness/http_client.rs b/apollo-router/src/test_harness/http_client.rs new file mode 100644 index 0000000000..079a4ab78c --- /dev/null +++ b/apollo-router/src/test_harness/http_client.rs @@ -0,0 +1,278 @@ +use std::io; +use std::pin::Pin; +use std::task::Poll; + +use async_compression::tokio::bufread::BrotliDecoder; +use axum::body::BoxBody; +use futures::stream::poll_fn; +use futures::Future; +use futures::Stream; +use futures::StreamExt; +use http_body::Body; +use mediatype::MediaType; +use mediatype::ReadParams; +use tokio::io::AsyncRead; +use tokio::io::AsyncReadExt; +use tokio_util::io::StreamReader; +use tower::BoxError; +use tower::Service; +use tower::ServiceBuilder; + +/// Added by `response_decompression` to `http::Response::extensions` +pub(crate) struct ResponseBodyWasCompressed(pub(crate) bool); + +pub(crate) enum MaybeMultipart { + Multipart(Pin + Send>>), + NotMultipart(Part), +} + +impl MaybeMultipart { + pub(crate) fn expect_multipart(self) -> Pin + Send>> { + match self { + MaybeMultipart::Multipart(stream) => stream, + MaybeMultipart::NotMultipart(_) => panic!("expected a multipart response"), + } + } + + pub(crate) fn expect_not_multipart(self) -> Part { + match self { + MaybeMultipart::Multipart(_) => panic!("expected a non-multipart response"), + MaybeMultipart::NotMultipart(part) => part, + } + } +} + +pub(crate) fn response_decompression( + inner: InnerService, +) -> impl Service< + http::Request, + Response = http::Response>>, + Error = BoxError, +> +where + InnerService: + Service, Response = http::Response, Error = BoxError>, +{ + ServiceBuilder::new() + .map_request(|mut request: http::Request| { + request + .headers_mut() + .insert("accept-encoding", "br".try_into().unwrap()); + request + }) + .map_response(|response: http::Response| { + let mut response = response.map(|body| { + // Convert from axum’s BoxBody to AsyncBufRead + let mut body = Box::pin(body); + let stream = poll_fn(move |ctx| body.as_mut().poll_data(ctx)) + .map(|result| result.map_err(|e| io::Error::new(io::ErrorKind::Other, e))); + StreamReader::new(stream) + }); + let content_encoding = response.headers().get("content-encoding"); + if let Some(encoding) = content_encoding { + assert_eq!( + encoding.as_bytes(), + b"br", + "unexpected content-encoding: {:?}", + String::from_utf8_lossy(encoding.as_bytes()) + ); + } + let compressed = content_encoding.is_some(); + response + .extensions_mut() + .insert(ResponseBodyWasCompressed(compressed)); + if compressed { + response.map(|body| Box::pin(BrotliDecoder::new(body)) as _) + } else { + response.map(|body| Box::pin(body) as _) + } + }) + .service(inner) +} + +pub(crate) fn defer_spec_20220824_multipart( + inner: InnerService, +) -> impl Service< + http::Request, + Response = http::Response>>, + Error = BoxError, +> +where + InnerService: Service< + http::Request, + Response = http::Response>>, + Error = BoxError, + >, +{ + ServiceBuilder::new() + .map_request(|mut request: http::Request| { + request.headers_mut().insert( + "accept", + "multipart/mixed; deferSpec=20220824".try_into().unwrap(), + ); + request + }) + .map_future(|future| async { + let response: http::Response>> = future.await?; + let (parts, mut body) = response.into_parts(); + let content_type = parts.headers.get("content-type").unwrap(); + let media_type = MediaType::parse(content_type.to_str().unwrap()).unwrap(); + let body = if media_type.ty == "multipart" { + let defer_spec = mediatype::Name::new("deferSpec").unwrap(); + assert_eq!(media_type.subty, "mixed"); + assert_eq!(media_type.get_param(defer_spec).unwrap(), "20220824"); + let boundary = media_type.get_param(mediatype::names::BOUNDARY).unwrap(); + let boundary = format!("\r\n--{}", boundary.unquoted_str()); + MaybeMultipart::Multipart(parse_multipart(boundary, body).await) + } else { + let mut vec = Vec::new(); + body.read_to_end(&mut vec).await.unwrap(); + MaybeMultipart::NotMultipart(vec) + }; + Ok(http::Response::from_parts(parts, body)) + }) + .service(inner) +} + +async fn parse_multipart( + boundary: String, + mut body: Pin>, +) -> Pin> + Send>> { + let mut buffer = Vec::new(); + while buffer.len() < boundary.len() { + read_some_more(&mut body, &mut buffer).await; + } + assert_prefix(&buffer, &boundary); + buffer.drain(..boundary.len()); + + let mut future = Some(Box::pin(read_part(body, boundary, buffer))); + futures::stream::poll_fn(move |ctx| { + if let Some(f) = &mut future { + match f.as_mut().poll(ctx) { + Poll::Pending => Poll::Pending, + Poll::Ready(None) => { + future = None; + Poll::Ready(None) + } + // Juggle ownership of `boundary` and `next_buffer` + // across multiple instances of async-fn-returned futures. + Poll::Ready(Some((body, boundary, part, next_buffer))) => { + future = Some(Box::pin(read_part(body, boundary, next_buffer))); + Poll::Ready(Some(part)) + } + } + } else { + Poll::Ready(None) + } + }) + .boxed() +} + +/// Reads one part of `multipart/mixed` +/// +/// To be called when the position of `body` is just after a multipart boundary +/// +/// Returns `Some((body, boundary, part, next_buffer))`, +/// or `None` when there is no further part. +async fn read_part( + mut body: Pin>, + boundary: String, + mut buffer: Vec, +) -> Option<(Pin>, String, Vec, Vec)> { + const BOUNDARY_SUFFIX_LEN: usize = 2; + while buffer.len() < BOUNDARY_SUFFIX_LEN { + read_some_more(&mut body, &mut buffer).await; + } + let boundary_suffix = &buffer[..BOUNDARY_SUFFIX_LEN]; + match boundary_suffix { + b"--" => return None, // This boundary marked the end of multipart + b"\r\n" => {} // Another part follows + _ => panic!("unexpected boundary suffix"), + }; + buffer.drain(..BOUNDARY_SUFFIX_LEN); + + loop { + // Restarting the substring seach from the start of `part` at every iteration + // makes this overall loop take O(n²) time. + // This is good enough for tests with known-small responses, + // and makes it easier to account for multipart boundaries + // that might be split across multiple reads. + if let Some(before_boundary) = memchr::memmem::find(&buffer, boundary.as_bytes()) { + let part = buffer[..before_boundary].to_vec(); + let after_boundary = before_boundary + boundary.len(); + buffer.drain(..after_boundary); + return Some((body, boundary, part, buffer)); + } + read_some_more(&mut body, &mut buffer).await; + } +} + +// Similar to AsyncBufRead::fill_buf, but reads the stream even if the buffer is not empty. +// This allows searching for patterns more than one byte long. +async fn read_some_more(body: &mut Pin>, buffer: &mut Vec) { + const BUFFER_SIZE_INCREMENT: usize = 1024; + let previous_len = buffer.len(); + buffer.resize(previous_len + BUFFER_SIZE_INCREMENT, 0); + let read_len = body.read(&mut buffer[previous_len..]).await.unwrap(); + if read_len == 0 { + panic!("end of response body without a multipart end boundary") + } + buffer.truncate(previous_len + read_len); +} + +fn assert_prefix<'a>(bytes: &'a [u8], expected_prefix: &str) -> &'a [u8] { + let (prefix, rest) = bytes.split_at(expected_prefix.len().min(bytes.len())); + assert_eq!( + prefix, + expected_prefix.as_bytes(), + "{:?} != {:?}", + String::from_utf8_lossy(prefix), + expected_prefix + ); + rest +} + +pub(crate) fn json( + inner: InnerService, +) -> impl Service< + http::Request, + Response = http::Response>, + Error = BoxError, +> +where + InnerService: Service< + http::Request, + Response = http::Response>>, + Error = BoxError, + >, +{ + ServiceBuilder::new() + .map_request(|mut request: http::Request| { + request + .headers_mut() + .insert("content-type", "application/json".try_into().unwrap()); + request.map(|body| serde_json::to_vec(&body).unwrap().into()) + }) + .map_response(|response: http::Response>>| { + let (parts, body) = response.into_parts(); + let body = match body { + MaybeMultipart::NotMultipart(bytes) => { + assert_eq!( + parts.headers.get("content-type").unwrap(), + "application/json" + ); + MaybeMultipart::NotMultipart(serde_json::from_slice(&bytes).unwrap()) + } + MaybeMultipart::Multipart(stream) => MaybeMultipart::Multipart( + stream + .map(|part| { + let expected_headers = "content-type: application/json\r\n\r\n"; + serde_json::from_slice(assert_prefix(&part, expected_headers)).unwrap() + }) + .boxed(), + ), + }; + http::Response::from_parts(parts, body) + }) + .service(inner) +}