Skip to content

Commit

Permalink
fix(sse): skip sse incompatible chars of serde_json::RawValue
Browse files Browse the repository at this point in the history
  • Loading branch information
Threated committed Oct 16, 2024
1 parent 0ddc63f commit 7996e22
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 13 deletions.
24 changes: 12 additions & 12 deletions axum/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,36 +42,36 @@ __private_docs = ["tower/full", "dep:tower-http"]

[dependencies]
axum-core = { path = "../axum-core", version = "0.5.0-alpha.1" }

# optional dependencies
axum-macros = { path = "../axum-macros", version = "0.5.0-alpha.1", optional = true }
base64 = { version = "0.22.1", optional = true }
bytes = "1.0"
futures-util = { version = "0.3", default-features = false, features = ["alloc"] }
http = "1.0.0"
http-body = "1.0.0"
http-body-util = "0.1.0"
hyper = { version = "1.1.0", optional = true }
hyper-util = { version = "0.1.3", features = ["tokio", "server", "service"], optional = true }
itoa = "1.0.5"
matchit = "=0.8.0"
memchr = "2.4.1"
mime = "0.3.16"
multer = { version = "3.0.0", optional = true }
percent-encoding = "2.1"
pin-project-lite = "0.2.7"
rustversion = "1.0.9"
serde = "1.0"
sync_wrapper = "1.0.0"
tower = { version = "0.5.1", default-features = false, features = ["util"] }
tower-layer = "0.3.2"
tower-service = "0.3"

# optional dependencies
axum-macros = { path = "../axum-macros", version = "0.5.0-alpha.1", optional = true }
base64 = { version = "0.22.1", optional = true }
hyper = { version = "1.1.0", optional = true }
hyper-util = { version = "0.1.3", features = ["tokio", "server", "service"], optional = true }
multer = { version = "3.0.0", optional = true }
serde_json = { version = "1.0", features = ["raw_value"], optional = true }
serde_path_to_error = { version = "0.1.8", optional = true }
serde_urlencoded = { version = "0.7", optional = true }
sha1 = { version = "0.10", optional = true }
sync_wrapper = "1.0.0"
tokio = { package = "tokio", version = "1.25.0", features = ["time"], optional = true }
tokio-tungstenite = { version = "0.24.0", optional = true }
tower = { version = "0.5.1", default-features = false, features = ["util"] }
tower-layer = "0.3.2"
tower-service = "0.3"
tracing = { version = "0.1", default-features = false, optional = true }

[dependencies.tower-http]
Expand Down Expand Up @@ -117,7 +117,7 @@ quickcheck = "1.0"
quickcheck_macros = "1.0"
reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "multipart"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
serde_json = { version = "1.0", features = ["raw_value"] }
time = { version = "0.3", features = ["serde-human-readable"] }
tokio = { package = "tokio", version = "1.25.0", features = ["macros", "rt", "rt-multi-thread", "net", "test-util"] }
tokio-stream = "0.1"
Expand Down
32 changes: 31 additions & 1 deletion axum/src/response/sse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,24 @@ impl Event {
}

self.buffer.extend_from_slice(b"data: ");
serde_json::to_writer((&mut self.buffer).writer(), &data).map_err(axum_core::Error::new)?;
struct IgnoreNewLines<'a>(bytes::buf::Writer<&'a mut BytesMut>);
impl std::io::Write for IgnoreNewLines<'_> {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
let mut last_split = 0;
for delimiter in memchr::memchr2_iter(b'\n', b'\r', buf) {
self.0.write_all(&buf[last_split..delimiter])?;
last_split = delimiter + 1;
}
self.0.write_all(&buf[last_split..])?;
Ok(buf.len())
}

fn flush(&mut self) -> std::io::Result<()> {
self.0.flush()
}
}
serde_json::to_writer(IgnoreNewLines((&mut self.buffer).writer()), &data)
.map_err(axum_core::Error::new)?;
self.buffer.put_u8(b'\n');

self.flags.insert(EventFlags::HAS_DATA);
Expand Down Expand Up @@ -515,6 +532,7 @@ mod tests {
use super::*;
use crate::{routing::get, test_helpers::*, Router};
use futures_util::stream;
use serde_json::value::RawValue;
use std::{collections::HashMap, convert::Infallible};
use tokio_stream::StreamExt as _;

Expand All @@ -527,6 +545,18 @@ mod tests {
assert_eq!(&*leading_space.finalize(), b"data: foobar\n\n");
}

#[test]
fn valid_json_raw_value_chars_stripped() {
let json_string = "{\r\"foo\": \n\r\r \"bar\\n\"\n}";
let json_raw_value_event = Event::default()
.json_data(serde_json::from_str::<&RawValue>(&json_string).unwrap())
.unwrap();
assert_eq!(
&*json_raw_value_event.finalize(),
format!("data: {}\n\n", json_string.replace(['\n', '\r'], "")).as_bytes()
);
}

#[crate::test]
async fn basic() {
let app = Router::new().route(
Expand Down

0 comments on commit 7996e22

Please sign in to comment.