Bump prost from 0.7 to 0.8 and stop the Google Pub/Sub segmenter from appending
API_BODY_SUFFIX twice when a batch is flushed because the next message does not fit.
take_batch() is now the only place that appends the suffix; its debug assertions
account for the suffix length and check that the finished body parses as JSON.
Two tests cover the double-suffix regression and the request size limit.

diff --git a/Cargo.toml b/Cargo.toml
index 377fb1a..db2cc60 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -56,11 +56,11 @@ base64 = { version = "^0.13", optional = true, default-features = false }
 http = { version = "^0.2", optional = true, default-features = false }
 hyper = { version = "^0.14.4", optional = true, features = ["client", "stream"], default-features = false }
 yup-oauth2 = { version = "5.1", optional = true, features = ["hyper-rustls"], default-features = false }
-prost = { version = "0.7", optional = true, features = ["std"], default-features = false }
+prost = { version = "0.8", optional = true, features = ["std"], default-features = false }
 
 [dev-dependencies]
 hyper-tls = "0.5.0"
-prost = { version = "0.7", features = ["std", "prost-derive"] }
+prost = { version = "0.8", features = ["std", "prost-derive"] }
 tokio = { version = "1", features = ["macros", "rt"] }
 serde = { version = "1", features = ["derive"] }
 
diff --git a/src/publish/publishers/googlepubsub.rs b/src/publish/publishers/googlepubsub.rs
index 930a71d..f047a17 100644
--- a/src/publish/publishers/googlepubsub.rs
+++ b/src/publish/publishers/googlepubsub.rs
@@ -411,13 +411,17 @@ impl<'a, I> GoogleMessageSegmenter<'a, I> {
 
     fn take_batch(&mut self) -> Option<SegmentationResult> {
         debug_assert!(self.messages_in_body <= API_MSG_COUNT_LIMIT);
-        debug_assert!(self.body_data.len() <= API_DATA_LENGTH_LIMIT);
+        debug_assert!(self.body_data.len() + API_BODY_SUFFIX.len() <= API_DATA_LENGTH_LIMIT);
         if self.messages_in_body == 0 {
             return None;
         }
         let mut body_data = std::mem::replace(&mut self.body_data, Vec::from(API_BODY_PREFIX));
         let messages_in_body = std::mem::replace(&mut self.messages_in_body, 0);
         body_data.extend(API_BODY_SUFFIX);
+        // Quite an expensive check but worth it given that we do our own json nonsense to ensure
+        // we get quotas right...
+        debug_assert!(body_data.len() <= API_DATA_LENGTH_LIMIT);
+        debug_assert!(serde_json::from_slice::<serde_json::Value>(&body_data[..]).is_ok());
         Some(SegmentationResult {
             request_body: hyper::Body::from(body_data),
             messages_in_body,
@@ -519,7 +523,6 @@ impl<'a, 'v, I: Iterator<Item = &'v ValidatedMessage>> Iterator for GoogleMessag
             let message_fits_in_current = self.messages_in_body < API_MSG_COUNT_LIMIT;
             if !data_fits_in_current || !message_fits_in_current {
                 // We need a new batch.
-                self.body_data.extend(API_BODY_SUFFIX);
                 let batch = self.take_batch();
                 self.append_message_data(&msg_json);
                 debug_assert!(batch.is_some());
@@ -697,6 +700,51 @@ mod tests {
         test_segmenter(msgs).await;
     }
 
+    #[cfg(feature = "json-schema")]
+    #[tokio::test]
+    async fn regression_for_double_suffix() {
+        let validator = validators::JsonSchemaValidator::new(SCHEMA).unwrap();
+        let small_message = JsonUserCreatedMessage::new_valid(
+            String::from_utf8(vec![b'a'; 512]).unwrap()
+        );
+        let oversized_message = JsonUserCreatedMessage::new_valid(
+            String::from_utf8(vec![b'a'; (10 * 1024 * 1024 - 512) * 3 / 4]).unwrap(),
+        );
+        let msgs = vec![
+            small_message.encode(&validator).unwrap(),
+            oversized_message.encode(&validator).unwrap(),
+        ];
+        let mut segmenter = GoogleMessageSegmenter::new("", msgs.iter());
+        let body1 = hyper::body::to_bytes(segmenter.next().unwrap().unwrap().request_body).await.unwrap();
+        serde_json::from_slice::<serde_json::Value>(&body1[..]).unwrap();
+        let body2 = hyper::body::to_bytes(segmenter.next().unwrap().unwrap().request_body).await.unwrap();
+        serde_json::from_slice::<serde_json::Value>(&body2[..]).unwrap();
+    }
+
+    #[cfg(feature = "json-schema")]
+    #[tokio::test]
+    async fn ensure_request_limits() {
+        let validator = validators::JsonSchemaValidator::new(SCHEMA).unwrap();
+        let small_message = JsonUserCreatedMessage::new_valid(
+            String::from_utf8(vec![b'a'; 512]).unwrap()
+        );
+        for i in 1225..1227 {
+            let oversized_message = JsonUserCreatedMessage::new_valid(
+                String::from_utf8(vec![b'a'; (10 * 1024 * 1024 - i) * 3 / 4]).unwrap(),
+            );
+            let msgs = vec![
+                small_message.encode(&validator).unwrap(),
+                oversized_message.encode(&validator).unwrap(),
+            ];
+            let segmenter = GoogleMessageSegmenter::new("", msgs.iter());
+            for segment in segmenter {
+                let body = hyper::body::to_bytes(segment.unwrap().request_body).await.unwrap();
+                assert!(body.len() < super::API_DATA_LENGTH_LIMIT);
+                serde_json::from_slice::<serde_json::Value>(&body[..]).unwrap();
+            }
+        }
+    }
+
     #[cfg(feature = "json-schema")]
     #[test]
     fn oversized_single_msg_segmenter() {