diff --git a/Cargo.toml b/Cargo.toml index 9d20b70e..43a5c777 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,6 +2,7 @@ members = [ "crates/jiter", "crates/jiter-python", + "crates/batson", "crates/fuzz", ] resolver = "2" @@ -29,5 +30,15 @@ inherits = "release" debug = true [workspace.dependencies] +jiter = { path = "crates/jiter", version = "0.5.0" } +batson = { path = "crates/batson", version = "0.5.0" } +bencher = "0.1.5" +codspeed-bencher-compat = "2.7.1" +num-bigint = "0.4.4" +num-traits = "0.2.16" +paste = "1.0.7" pyo3 = { version = "0.22.0" } pyo3-build-config = { version = "0.22.0" } +smallvec = "2.0.0-alpha.7" +serde = "1.0.210" +serde_json = "1.0.128" diff --git a/crates/batson/Cargo.toml b/crates/batson/Cargo.toml new file mode 100644 index 00000000..1fc5e9df --- /dev/null +++ b/crates/batson/Cargo.toml @@ -0,0 +1,43 @@ +[package] +name = "batson" +description = "Binary Alternative To (J)SON. Designed to be very fast to query." +readme = "../../README.md" +version = {workspace = true} +edition = {workspace = true} +authors = {workspace = true} +license = {workspace = true} +keywords = {workspace = true} +categories = {workspace = true} +homepage = {workspace = true} +repository = {workspace = true} + +[dependencies] +bytemuck = { version = "1.17.1", features = ["aarch64_simd", "derive", "align_offset"] } +jiter = { workspace = true } +num-bigint = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +simdutf8 = { version = "0.1.4", features = ["aarch64_neon"] } +smallvec = { workspace = true } + +[dev-dependencies] +bencher = { workspace = true } +paste = { workspace = true } +codspeed-bencher-compat = { workspace = true } + +[[bench]] +name = "main" +harness = false + +[lints.clippy] +dbg_macro = "deny" +print_stdout = "deny" +print_stderr = "deny" +# in general we lint against the pedantic group, but we will whitelist +# certain lints which we don't want to enforce (for now) +pedantic = { level = "deny", priority = -1 } 
+missing_errors_doc = "allow" +cast_possible_truncation = "allow" # TODO remove +cast_sign_loss = "allow" # TODO remove +cast_possible_wrap = "allow" # TODO remove +checked_conversions = "allow" # TODO remove diff --git a/crates/batson/README.md b/crates/batson/README.md new file mode 100644 index 00000000..92bdd242 --- /dev/null +++ b/crates/batson/README.md @@ -0,0 +1,16 @@ +# batson + +Binary Alternative To (J)SON. Designed to be very fast to query. + +Inspired by Postgres' [JSONB type](https://github.com/postgres/postgres/commit/d9134d0a355cfa447adc80db4505d5931084278a?diff=unified&w=0) and Snowflake's [VARIANT type](https://www.youtube.com/watch?v=jtjOfggD4YY). + +For a relatively small JSON document (3KB), batson is 14 to 126x faster than Jiter, and 106 to 588x faster than Serde. + +``` +test medium_get_str_found_batson ... bench: 51 ns/iter (+/- 1) +test medium_get_str_found_jiter ... bench: 755 ns/iter (+/- 66) +test medium_get_str_found_serde ... bench: 5,420 ns/iter (+/- 93) +test medium_get_str_missing_batson ... bench: 9 ns/iter (+/- 0) +test medium_get_str_missing_jiter ... bench: 1,135 ns/iter (+/- 46) +test medium_get_str_missing_serde ... 
bench: 5,292 ns/iter (+/- 324) +``` diff --git a/crates/batson/benches/main.rs b/crates/batson/benches/main.rs new file mode 100644 index 00000000..e280f0b6 --- /dev/null +++ b/crates/batson/benches/main.rs @@ -0,0 +1,213 @@ +use codspeed_bencher_compat::{benchmark_group, benchmark_main, Bencher}; +use std::hint::black_box; + +use std::fs::File; +use std::io::Read; + +use batson::get::{get_str, BatsonPath}; +use batson::{batson_to_json_string, encode_from_json}; +use jiter::JsonValue; + +fn read_file(path: &str) -> String { + let mut file = File::open(path).unwrap(); + let mut contents = String::new(); + file.read_to_string(&mut contents).unwrap(); + contents +} + +/// taken from +mod jiter_find { + use jiter::{Jiter, Peek}; + + #[derive(Debug)] + pub enum JsonPath<'s> { + Key(&'s str), + Index(usize), + None, + } + + impl From for JsonPath<'_> { + fn from(index: u64) -> Self { + JsonPath::Index(usize::try_from(index).unwrap()) + } + } + + impl From for JsonPath<'_> { + fn from(index: i32) -> Self { + match usize::try_from(index) { + Ok(i) => Self::Index(i), + Err(_) => Self::None, + } + } + } + + impl<'s> From<&'s str> for JsonPath<'s> { + fn from(key: &'s str) -> Self { + JsonPath::Key(key) + } + } + + pub fn jiter_json_find<'j>(opt_json: Option<&'j str>, path: &[JsonPath]) -> Option<(Jiter<'j>, Peek)> { + let json_str = opt_json?; + let mut jiter = Jiter::new(json_str.as_bytes()); + let mut peek = jiter.peek().ok()?; + for element in path { + match element { + JsonPath::Key(key) if peek == Peek::Object => { + let mut next_key = jiter.known_object().ok()??; + + while next_key != *key { + jiter.next_skip().ok()?; + next_key = jiter.next_key().ok()??; + } + + peek = jiter.peek().ok()?; + } + JsonPath::Index(index) if peek == Peek::Array => { + let mut array_item = jiter.known_array().ok()??; + + for _ in 0..*index { + jiter.known_skip(array_item).ok()?; + array_item = jiter.array_step().ok()??; + } + + peek = array_item; + } + _ => { + return None; + } + } + } + 
Some((jiter, peek)) + } + + pub fn get_str(json_data: Option<&str>, path: &[JsonPath]) -> Option { + if let Some((mut jiter, peek)) = jiter_json_find(json_data, path) { + match peek { + Peek::String => Some(jiter.known_str().ok()?.to_owned()), + _ => None, + } + } else { + None + } + } +} + +mod serde_find { + use batson::get::BatsonPath; + use serde_json::Value; + + pub fn get_str(json_data: &[u8], path: &[BatsonPath]) -> Option { + let json_value: Value = serde_json::from_slice(json_data).ok()?; + let mut current = &json_value; + for key in path { + current = match (key, current) { + (BatsonPath::Key(k), Value::Object(map)) => map.get(*k)?, + (BatsonPath::Index(i), Value::Array(vec)) => vec.get(*i)?, + _ => return None, + } + } + match current { + Value::String(s) => Some(s.clone()), + _ => None, + } + } +} + +fn json_to_batson(json: &[u8]) -> Vec { + let json_value = JsonValue::parse(json, false).unwrap(); + encode_from_json(&json_value).unwrap() +} + +fn medium_get_str_found_batson(bench: &mut Bencher) { + let json = read_file("../jiter/benches/medium_response.json"); + let json_data = json.as_bytes(); + let batson_data = json_to_batson(json_data); + let path: Vec = vec!["person".into(), "linkedin".into(), "handle".into()]; + bench.iter(|| { + let v = get_str(black_box(&batson_data), &path); + black_box(v) + }); +} + +fn medium_get_str_found_jiter(bench: &mut Bencher) { + let json = read_file("../jiter/benches/medium_response.json"); + let path: Vec = vec!["person".into(), "linkedin".into(), "handle".into()]; + bench.iter(|| { + let v = jiter_find::get_str(black_box(Some(&json)), &path); + black_box(v) + }); +} + +fn medium_get_str_found_serde(bench: &mut Bencher) { + let json = read_file("../jiter/benches/medium_response.json"); + let json_data = json.as_bytes(); + let path: Vec = vec!["person".into(), "linkedin".into(), "handle".into()]; + bench.iter(|| { + let v = serde_find::get_str(black_box(json_data), &path).unwrap(); + black_box(v) + }); +} + +fn 
medium_get_str_missing_batson(bench: &mut Bencher) { + let json = read_file("../jiter/benches/medium_response.json"); + let json_data = json.as_bytes(); + let batson_data = json_to_batson(json_data); + let path: Vec = vec!["squid".into(), "linkedin".into(), "handle".into()]; + bench.iter(|| { + let v = get_str(black_box(&batson_data), &path); + black_box(v) + }); +} + +fn medium_get_str_missing_jiter(bench: &mut Bencher) { + let json = read_file("../jiter/benches/medium_response.json"); + let path: Vec = vec!["squid".into(), "linkedin".into(), "handle".into()]; + bench.iter(|| { + let v = jiter_find::get_str(black_box(Some(&json)), &path); + black_box(v) + }); +} + +fn medium_get_str_missing_serde(bench: &mut Bencher) { + let json = read_file("../jiter/benches/medium_response.json"); + let json_data = json.as_bytes(); + let path: Vec = vec!["squid".into(), "linkedin".into(), "handle".into()]; + bench.iter(|| { + let v = serde_find::get_str(black_box(json_data), &path); + black_box(v) + }); +} + +fn medium_convert_batson_to_json(bench: &mut Bencher) { + let json = read_file("../jiter/benches/medium_response.json"); + let json_data = json.as_bytes(); + let batson_data = json_to_batson(json_data); + bench.iter(|| { + let v = batson_to_json_string(black_box(&batson_data)).unwrap(); + black_box(v) + }); +} + +fn medium_convert_json_to_batson(bench: &mut Bencher) { + let json = read_file("../jiter/benches/medium_response.json"); + let json = json.as_bytes(); + bench.iter(|| { + let json_value = JsonValue::parse(json, false).unwrap(); + let b = encode_from_json(&json_value).unwrap(); + black_box(b) + }); +} + +benchmark_group!( + benches, + medium_get_str_found_batson, + medium_get_str_found_jiter, + medium_get_str_found_serde, + medium_get_str_missing_batson, + medium_get_str_missing_jiter, + medium_get_str_missing_serde, + medium_convert_batson_to_json, + medium_convert_json_to_batson +); +benchmark_main!(benches); diff --git a/crates/batson/examples/read_file.rs 
b/crates/batson/examples/read_file.rs new file mode 100644 index 00000000..120bd24e --- /dev/null +++ b/crates/batson/examples/read_file.rs @@ -0,0 +1,57 @@ +use batson::get::BatsonPath; +use batson::{batson_to_json_string, encode_from_json}; +use jiter::JsonValue; +use std::fs::File; +use std::io::Read; + +fn main() { + let filename = std::env::args().nth(1).expect( + r#" +No arguments provided! + +Usage: +cargo run --example read_file file.json [path] +"#, + ); + + let mut file = File::open(&filename).expect("failed to open file"); + let mut json = Vec::new(); + file.read_to_end(&mut json).expect("failed to read file"); + + let json_value = JsonValue::parse(&json, false).expect("invalid JSON"); + let batson = encode_from_json(&json_value).expect("failed to construct batson data"); + println!("json length: {}", json.len()); + println!("batson length: {}", batson.len()); + + let output_json = batson_to_json_string(&batson).expect("failed to convert batson to JSON"); + println!("output json length: {}", output_json.len()); + + if let Some(path) = std::env::args().nth(2) { + let path: Vec = path.split('.').map(to_batson_path).collect(); + let start = std::time::Instant::now(); + let value = batson::get::get_str(&batson, &path).expect("failed to get value"); + let elapsed = start.elapsed(); + println!("Found value: {value:?} (time taken: {elapsed:?})"); + } + + println!("reloading to check round-trip"); + let json_value = JsonValue::parse(output_json.as_bytes(), false).expect("invalid JSON"); + let batson = encode_from_json(&json_value).expect("failed to construct batson data"); + let output_json2 = batson_to_json_string(&batson).expect("failed to convert batson to JSON"); + println!("JSON unchanged after re-encoding: {:?}", output_json == output_json2); + + if output_json.len() < 2000 { + println!("\n\noutput json:\n{}", output_json); + } else { + println!("\n\noutput json is too long to display"); + } +} + +fn to_batson_path(s: &str) -> BatsonPath { + if 
s.chars().all(char::is_numeric) { + let index: usize = s.parse().unwrap(); + index.into() + } else { + s.into() + } +} diff --git a/crates/batson/src/array.rs b/crates/batson/src/array.rs new file mode 100644 index 00000000..cd05ee2f --- /dev/null +++ b/crates/batson/src/array.rs @@ -0,0 +1,553 @@ +use bytemuck::NoUninit; +use jiter::{JsonArray, JsonValue}; +use smallvec::SmallVec; +use std::mem::size_of; +use std::sync::Arc; + +use crate::decoder::Decoder; +use crate::encoder::Encoder; +use crate::errors::{DecodeResult, EncodeResult, ToJsonResult}; +use crate::header::{Category, Header, Length, NumberHint, Primitive}; +use crate::json_writer::JsonWriter; +use crate::object::minimum_value_size_estimate; +use crate::EncodeError; + +#[cfg(target_endian = "big")] +compile_error!("big-endian architectures are not yet supported as we use `bytemuck` for zero-copy header decoding."); + +/// Batson heterogeneous array representation +#[derive(Debug)] +pub(crate) struct HetArray<'b> { + offsets: HetArrayOffsets<'b>, +} + +impl<'b> HetArray<'b> { + pub fn decode_header(d: &mut Decoder<'b>, length: Length) -> DecodeResult { + let offsets = match length { + Length::Empty => HetArrayOffsets::U8(&[]), + Length::U32 => HetArrayOffsets::U32(take_slice_as(d, length)?), + Length::U16 => HetArrayOffsets::U16(take_slice_as(d, length)?), + _ => HetArrayOffsets::U8(take_slice_as(d, length)?), + }; + Ok(Self { offsets }) + } + + pub fn len(&self) -> usize { + match self.offsets { + HetArrayOffsets::U8(v) => v.len(), + HetArrayOffsets::U16(v) => v.len(), + HetArrayOffsets::U32(v) => v.len(), + } + } + + pub fn get(&self, d: &mut Decoder<'b>, index: usize) -> bool { + let opt_offset = match &self.offsets { + HetArrayOffsets::U8(v) => v.get(index).map(|&o| o as usize), + HetArrayOffsets::U16(v) => v.get(index).map(|&o| o as usize), + HetArrayOffsets::U32(v) => v.get(index).map(|&o| o as usize), + }; + if let Some(offset) = opt_offset { + d.index += offset; + true + } else { + false + } + } 
+ + pub fn to_value(&self, d: &mut Decoder<'b>) -> DecodeResult> { + (0..self.len()) + .map(|_| d.take_value()) + .collect::>>() + .map(Arc::new) + } + + pub fn write_json(&self, d: &mut Decoder<'b>, writer: &mut JsonWriter) -> ToJsonResult<()> { + let mut steps = 0..self.len(); + writer.start_array(); + if steps.next().is_some() { + d.write_json(writer)?; + for _ in steps { + writer.comma(); + d.write_json(writer)?; + } + } + writer.end_array(); + Ok(()) + } + + pub fn move_to_end(&self, d: &mut Decoder<'b>) -> DecodeResult<()> { + d.index += match &self.offsets { + HetArrayOffsets::U8(v) => v.last().copied().unwrap() as usize, + HetArrayOffsets::U16(v) => v.last().copied().unwrap() as usize, + HetArrayOffsets::U32(v) => v.last().copied().unwrap() as usize, + }; + let header = d.take_header()?; + d.move_to_end(header) + } +} + +fn take_slice_as<'b, T: bytemuck::Pod>(d: &mut Decoder<'b>, length: Length) -> DecodeResult<&'b [T]> { + let length = length.decode(d)?; + d.take_slice_as(length) +} + +#[derive(Debug)] +enum HetArrayOffsets<'b> { + U8(&'b [u8]), + U16(&'b [u16]), + U32(&'b [u32]), +} + +pub(crate) fn header_array_get(d: &mut Decoder, length: Length, index: usize) -> DecodeResult> { + u8_array_get(d, length, index)? + .map(|b| Header::decode(b, d)) + .transpose() +} + +pub(crate) fn header_array_to_json<'b>(d: &mut Decoder<'b>, length: Length) -> DecodeResult> { + let length = length.decode(d)?; + d.take_slice(length)? 
+ .iter() + .map(|b| Header::decode(*b, d).map(|h| h.header_as_value(d))) + .collect::>() + .map(Arc::new) +} + +pub(crate) fn header_array_write_to_json(d: &mut Decoder, length: Length, writer: &mut JsonWriter) -> ToJsonResult<()> { + let length = length.decode(d)?; + let s = d.take_slice(length)?; + let mut iter = s.iter(); + + writer.start_array(); + if let Some(b) = iter.next() { + let h = Header::decode(*b, d)?; + h.write_json_header_only(writer)?; + for b in iter { + writer.comma(); + let h = Header::decode(*b, d)?; + h.write_json_header_only(writer)?; + } + } + writer.end_array(); + Ok(()) +} + +pub(crate) fn u8_array_get(d: &mut Decoder, length: Length, index: usize) -> DecodeResult> { + let length = length.decode(d)?; + let v = d.take_slice(length)?; + Ok(v.get(index).copied()) +} + +pub(crate) fn u8_array_to_json<'b>(d: &mut Decoder<'b>, length: Length) -> DecodeResult> { + let v = u8_array_slice(d, length)? + .iter() + .map(|b| JsonValue::Int(i64::from(*b))) + .collect(); + Ok(Arc::new(v)) +} + +pub(crate) fn u8_array_slice<'b>(d: &mut Decoder<'b>, length: Length) -> DecodeResult<&'b [u8]> { + let length = length.decode(d)?; + d.take_slice(length) +} + +pub(crate) fn i64_array_get(d: &mut Decoder, length: Length, index: usize) -> DecodeResult> { + let length = length.decode(d)?; + d.align::(); + let s: &[i64] = d.take_slice_as(length)?; + Ok(s.get(index).copied()) +} + +pub(crate) fn i64_array_to_json<'b>(d: &mut Decoder<'b>, length: Length) -> DecodeResult> { + let s = i64_array_slice(d, length)?; + let v = s.iter().copied().map(JsonValue::Int).collect(); + Ok(Arc::new(v)) +} + +pub(crate) fn i64_array_slice<'b>(d: &mut Decoder<'b>, length: Length) -> DecodeResult<&'b [i64]> { + let length = length.decode(d)?; + d.take_slice_as(length) +} + +pub(crate) fn encode_array(encoder: &mut Encoder, array: &JsonArray) -> EncodeResult<()> { + if array.is_empty() { + // shortcut but also no alignment! 
+ encoder.encode_length(Category::HetArray, 0) + } else if let Some(packed_array) = PackedArray::new(array) { + match packed_array { + PackedArray::Header(array) => { + encoder.encode_length(Category::HeaderArray, array.len())?; + // no alignment necessary, it's a vec of u8 + encoder.extend(&array); + } + PackedArray::I64(array) => { + encoder.encode_length(Category::I64Array, array.len())?; + encoder.align::(); + encoder.extend(bytemuck::cast_slice(&array)); + } + PackedArray::U8(array) => { + encoder.encode_length(Category::U8Array, array.len())?; + // no alignment necessary, it's a vec of u8 + encoder.extend(&array); + } + } + Ok(()) + } else { + let min_size = minimum_array_size_estimate(array); + let encoder_position = encoder.position(); + + if min_size <= u8::MAX as usize { + encoder.encode_length(Category::HetArray, array.len())?; + if encode_array_sized::(encoder, array)? { + return Ok(()); + } + encoder.reset_position(encoder_position); + } + + if min_size <= u16::MAX as usize { + encoder.encode_len_u16(Category::HetArray, u16::try_from(array.len()).unwrap()); + if encode_array_sized::(encoder, array)? { + return Ok(()); + } + encoder.reset_position(encoder_position); + } + + encoder.encode_len_u32(Category::HetArray, array.len())?; + if encode_array_sized::(encoder, array)? 
{ + Ok(()) + } else { + Err(EncodeError::ArrayTooLarge) + } + } +} + +fn encode_array_sized + NoUninit>(encoder: &mut Encoder, array: &JsonArray) -> EncodeResult { + let mut offsets: Vec = Vec::with_capacity(array.len()); + encoder.align::(); + let positions_start = encoder.ring_fence(array.len() * size_of::()); + + let offset_start = encoder.position(); + for value in array.iter() { + let Ok(offset) = T::try_from(encoder.position() - offset_start) else { + return Ok(false); + }; + offsets.push(offset); + encoder.encode_value(value)?; + } + encoder.set_range(positions_start, bytemuck::cast_slice(&offsets)); + Ok(true) +} + +/// Estimate the minimize amount of space needed to encode the object. +/// +/// This is NOT recursive, instead it makes very optimistic guesses about how long arrays and objects might be. +fn minimum_array_size_estimate(array: &JsonArray) -> usize { + array.iter().map(minimum_value_size_estimate).sum() +} + +#[derive(Debug)] +enum PackedArray { + Header(Vec), + U8(Vec), + I64(Vec), +} + +impl PackedArray { + fn new(array: &JsonArray) -> Option { + let mut header_only: Option> = Some(Vec::with_capacity(array.len())); + let mut u8_only: Option> = Some(Vec::with_capacity(array.len())); + let mut i64_only: Option> = Some(Vec::with_capacity(array.len())); + + macro_rules! push_len { + ($cat: expr, $is_empty: expr) => {{ + u8_only = None; + i64_only = None; + if $is_empty { + header_only.as_mut()?.push($cat.encode_with(Length::Empty as u8)); + } else { + header_only = None; + } + }}; + } + + for element in array.iter() { + match element { + JsonValue::Null => { + u8_only = None; + i64_only = None; + header_only + .as_mut()? 
+ .push(Category::Primitive.encode_with(Primitive::Null as u8)); + } + JsonValue::Bool(b) => { + u8_only = None; + i64_only = None; + let right: Primitive = (*b).into(); + header_only.as_mut()?.push(Category::Primitive.encode_with(right as u8)); + } + JsonValue::Int(i) => { + if let Some(i64_only) = &mut i64_only { + i64_only.push(*i); + } + // if u8_only is still alive, push to it if we can + if let Some(u8_only_) = &mut u8_only { + if let Ok(u8) = u8::try_from(*i) { + u8_only_.push(u8); + } else { + u8_only = None; + } + } + // if header_only is still alive, push to it if we can + if let Some(h) = &mut header_only { + if let Some(n) = NumberHint::header_only_i64(*i) { + h.push(Category::Int.encode_with(n as u8)); + } else { + header_only = None; + } + } + } + JsonValue::BigInt(_) => return None, + JsonValue::Float(f) => { + u8_only = None; + i64_only = None; + if let Some(n) = NumberHint::header_only_f64(*f) { + header_only.as_mut()?.push(Category::Float.encode_with(n as u8)); + } else { + header_only = None; + } + } + JsonValue::Str(s) => push_len!(Category::Str, s.is_empty()), + // TODO could use a header only array if it's empty + JsonValue::Array(a) => push_len!(Category::HetArray, a.is_empty()), + JsonValue::Object(o) => push_len!(Category::Object, o.is_empty()), + } + if header_only.is_none() && i64_only.is_none() { + // stop early if neither work + return None; + } + } + // u8 array is preferable to header array as it's the pure binary representation + if let Some(u8_array) = u8_only { + Some(Self::U8(u8_array)) + } else if let Some(header_only) = header_only { + Some(Self::Header(header_only)) + } else { + i64_only.map(Self::I64) + } + } +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use smallvec::smallvec; + + use crate::compare_json_values; + use crate::decoder::Decoder; + use crate::encoder::Encoder; + use crate::header::Header; + + use super::*; + + /// hack while waiting for + macro_rules! 
assert_arrays_eq { + ($a: expr, $b: expr) => {{ + assert_eq!($a.len(), $b.len()); + for (a, b) in $a.iter().zip($b.iter()) { + assert!(compare_json_values(a, b)); + } + }}; + } + + #[test] + fn array_round_trip() { + let array = Arc::new(smallvec![JsonValue::Null, JsonValue::Int(123), JsonValue::Bool(false),]); + let min_size = minimum_array_size_estimate(&array); + assert_eq!(min_size, 4); + + let mut encoder = Encoder::new(); + encode_array(&mut encoder, &array).unwrap(); + let bytes: Vec = encoder.into(); + + let mut decoder = Decoder::new(&bytes); + let header = decoder.take_header().unwrap(); + assert_eq!(header, Header::HetArray(3.into())); + + let het_array = HetArray::decode_header(&mut decoder, 3.into()).unwrap(); + assert_eq!(het_array.len(), 3); + + let offsets = match het_array.offsets { + HetArrayOffsets::U8(v) => v, + _ => panic!("expected u8 offsets"), + }; + + assert_eq!(offsets, &[0, 1, 3]); + let decode_array = het_array.to_value(&mut decoder).unwrap(); + assert_arrays_eq!(decode_array, array); + } + + #[test] + fn array_round_trip_empty() { + let array = Arc::new(smallvec![]); + let mut encoder = Encoder::new(); + encode_array(&mut encoder, &array).unwrap(); + let bytes: Vec = encoder.into(); + assert_eq!(bytes.len(), 1); + + let mut decoder = Decoder::new(&bytes); + let header = decoder.take_header().unwrap(); + assert_eq!(header, Header::HetArray(0.into())); + + let het_array = HetArray::decode_header(&mut decoder, 0.into()).unwrap(); + assert_eq!(het_array.len(), 0); + let decode_array = het_array.to_value(&mut decoder).unwrap(); + assert_arrays_eq!(decode_array, array); + } + + #[test] + fn header_array_round_trip() { + let array = Arc::new(smallvec![ + JsonValue::Null, + JsonValue::Bool(false), + JsonValue::Bool(true), + JsonValue::Int(7), + JsonValue::Float(4.0), + ]); + let mut encoder = Encoder::new(); + encode_array(&mut encoder, &array).unwrap(); + let bytes: Vec = encoder.into(); + assert_eq!(bytes.len(), 6); + + let mut decoder = 
Decoder::new(&bytes); + let header = decoder.take_header().unwrap(); + assert_eq!(header, Header::HeaderArray(5.into())); + + let header_array = header_array_to_json(&mut decoder, 5.into()).unwrap(); + assert_arrays_eq!(header_array, array); + } + + #[test] + fn u8_array_round_trip() { + let array = Arc::new(smallvec![JsonValue::Int(7), JsonValue::Int(4), JsonValue::Int(123),]); + let mut encoder = Encoder::new(); + encode_array(&mut encoder, &array).unwrap(); + let bytes: Vec = encoder.into(); + assert_eq!(bytes.len(), 4); + + let mut decoder = Decoder::new(&bytes); + let header = decoder.take_header().unwrap(); + assert_eq!(header, Header::U8Array(3.into())); + + let mut decoder = Decoder::new(&bytes); + let v = decoder.take_value().unwrap(); + assert!(compare_json_values(&v, &JsonValue::Array(array))); + } + + #[test] + fn i64_array_round_trip() { + let array = Arc::new(smallvec![ + JsonValue::Int(7), + JsonValue::Int(i64::MAX), + JsonValue::Int(i64::MIN), + JsonValue::Int(1234), + JsonValue::Int(1_234_567_890), + ]); + let mut encoder = Encoder::new(); + encode_array(&mut encoder, &array).unwrap(); + let bytes: Vec = encoder.into(); + assert_eq!(bytes.len(), 6 * 8); + + let mut decoder = Decoder::new(&bytes); + let header = decoder.take_header().unwrap(); + assert_eq!(header, Header::I64Array(5.into())); + + let i64_array = i64_array_to_json(&mut decoder, 5.into()).unwrap(); + assert_arrays_eq!(i64_array, array); + } + + #[test] + fn test_u16_array() { + let mut array = vec![JsonValue::Bool(true); 100]; + array.extend(vec![JsonValue::Int(i64::MAX); 100]); + let array = Arc::new(array.into()); + + let mut encoder = Encoder::new(); + encode_array(&mut encoder, &array).unwrap(); + let bytes: Vec = encoder.into(); + + let mut decoder = Decoder::new(&bytes); + let header = decoder.take_header().unwrap(); + assert_eq!(header, Header::HetArray(Length::U16)); + + let het_array = HetArray::decode_header(&mut decoder, Length::U16).unwrap(); + assert_eq!(het_array.len(), 
200); + + let offsets = match het_array.offsets { + HetArrayOffsets::U16(v) => v, + _ => panic!("expected U16 offsets"), + }; + assert_eq!(offsets.len(), 200); + assert_eq!(offsets[0], 0); + assert_eq!(offsets[1], 1); + + let mut d = decoder.clone(); + assert!(het_array.get(&mut d, 0)); + assert!(compare_json_values(&d.take_value().unwrap(), &JsonValue::Bool(true))); + + let mut d = decoder.clone(); + assert!(het_array.get(&mut d, 99)); + assert!(compare_json_values(&d.take_value().unwrap(), &JsonValue::Bool(true))); + + let mut d = decoder.clone(); + assert!(het_array.get(&mut d, 100)); + assert!(compare_json_values(&d.take_value().unwrap(), &JsonValue::Int(i64::MAX))); + + let mut d = decoder.clone(); + assert!(het_array.get(&mut d, 199)); + assert!(compare_json_values(&d.take_value().unwrap(), &JsonValue::Int(i64::MAX))); + + let mut d = decoder.clone(); + assert!(!het_array.get(&mut d, 200)); + + let decode_array = het_array.to_value(&mut decoder).unwrap(); + assert_arrays_eq!(decode_array, array); + } + + #[test] + fn test_u32_array() { + let long_string = "a".repeat(u16::MAX as usize); + let array = Arc::new(smallvec![ + JsonValue::Str(long_string.clone().into()), + JsonValue::Int(42), + ]); + + let mut encoder = Encoder::new(); + encode_array(&mut encoder, &array).unwrap(); + let bytes: Vec = encoder.into(); + + let mut decoder = Decoder::new(&bytes); + let header = decoder.take_header().unwrap(); + assert_eq!(header, Header::HetArray(Length::U32)); + + let het_array = HetArray::decode_header(&mut decoder, Length::U32).unwrap(); + assert_eq!(het_array.len(), 2); + + let offsets = match het_array.offsets { + HetArrayOffsets::U32(v) => v, + _ => panic!("expected U32 offsets"), + }; + assert_eq!(offsets, [0, 65538]); + + let mut d = decoder.clone(); + assert!(het_array.get(&mut d, 0)); + assert!(compare_json_values( + &d.take_value().unwrap(), + &JsonValue::Str(long_string.into()) + )); + + let mut d = decoder.clone(); + assert!(het_array.get(&mut d, 1)); + 
assert!(compare_json_values(&d.take_value().unwrap(), &JsonValue::Int(42))); + } +} diff --git a/crates/batson/src/decoder.rs b/crates/batson/src/decoder.rs new file mode 100644 index 00000000..0a0c1207 --- /dev/null +++ b/crates/batson/src/decoder.rs @@ -0,0 +1,247 @@ +use jiter::JsonValue; +use num_bigint::{BigInt, Sign}; +use std::fmt; +use std::mem::{align_of, size_of}; + +use crate::array::{ + header_array_to_json, header_array_write_to_json, i64_array_slice, i64_array_to_json, u8_array_slice, + u8_array_to_json, HetArray, +}; +use crate::errors::{DecodeError, DecodeErrorType, DecodeResult, ToJsonResult}; +use crate::header::{Header, Length}; +use crate::json_writer::JsonWriter; +use crate::object::Object; + +#[cfg(target_endian = "big")] +compile_error!("big-endian architectures are not yet supported as we use `bytemuck` for zero-copy header decoding."); +// see `decode_slice_as` for more information + +#[derive(Clone)] +pub(crate) struct Decoder<'b> { + bytes: &'b [u8], + pub index: usize, +} + +impl fmt::Debug for Decoder<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let upcoming = self.bytes.get(self.index..).unwrap_or_default(); + f.debug_struct("Decoder") + .field("total_length", &self.bytes.len()) + .field("upcoming_length", &upcoming.len()) + .field("index", &self.index) + .field("upcoming", &upcoming) + .finish() + } +} + +impl<'b> Decoder<'b> { + pub fn new(bytes: &'b [u8]) -> Self { + Self { bytes, index: 0 } + } + + pub fn get_range(&self, start: usize, end: usize) -> DecodeResult<&'b [u8]> { + self.bytes + .get(start..end) + .ok_or_else(|| self.error(DecodeErrorType::EOF)) + } + + /// Get the length of the data that follows a header + pub fn move_to_end(&mut self, header: Header) -> DecodeResult<()> { + match header { + Header::Null | Header::Bool(_) => (), + Header::Int(n) | Header::Float(n) => { + self.index += n.data_length(); + } + Header::Object(l) => { + let obj = Object::decode_header(self, l)?; + 
obj.move_to_end(self)?; + } + Header::I64Array(l) => { + let length = l.decode(self)?; + self.index += length * size_of::(); + } + Header::HetArray(l) => { + let het = HetArray::decode_header(self, l)?; + het.move_to_end(self)?; + } + Header::IntBig(_, l) | Header::Str(l) | Header::HeaderArray(l) | Header::U8Array(l) => { + self.index += l.decode(self)?; + } + }; + Ok(()) + } + + pub fn take_header(&mut self) -> DecodeResult
{
        let byte = self.next().ok_or_else(|| self.eof())?;
        Header::decode(byte, self)
    }

    /// Advance `index` forward to the next multiple of `align_of::<T>()`.
    pub fn align<T>(&mut self) {
        let align = align_of::<T>();
        // Round up with a mask. NOTE(review): the original comment claimed this equals
        // `self.index + align - (self.index % align)`, but that form over-shoots by a
        // full `align` when the index is already aligned; the mask form is a no-op then.
        self.index = (self.index + align - 1) & !(align - 1);
    }

    /// Decode the complete value starting at the current index into a `JsonValue`
    /// borrowing from the underlying buffer.
    pub fn take_value(&mut self) -> DecodeResult<JsonValue<'b>> {
        match self.take_header()? {
            Header::Null => Ok(JsonValue::Null),
            Header::Bool(b) => Ok(JsonValue::Bool(b)),
            Header::Int(n) => n.decode_i64(self).map(JsonValue::Int),
            Header::IntBig(s, l) => self.take_big_int(s, l).map(JsonValue::BigInt),
            Header::Float(n) => n.decode_f64(self).map(JsonValue::Float),
            Header::Str(l) => self.take_str_len(l).map(|s| JsonValue::Str(s.into())),
            Header::Object(length) => {
                let obj = Object::decode_header(self, length)?;
                obj.to_value(self).map(JsonValue::Object)
            }
            Header::HetArray(length) => {
                let het = HetArray::decode_header(self, length)?;
                het.to_value(self).map(JsonValue::Array)
            }
            Header::U8Array(length) => u8_array_to_json(self, length).map(JsonValue::Array),
            Header::HeaderArray(length) => header_array_to_json(self, length).map(JsonValue::Array),
            Header::I64Array(length) => i64_array_to_json(self, length).map(JsonValue::Array),
        }
    }

    /// Decode the value starting at the current index, streaming it to `writer` as JSON text.
    pub fn write_json(&mut self, writer: &mut JsonWriter) -> ToJsonResult<()> {
        match self.take_header()? {
            Header::Null => writer.write_null(),
            Header::Bool(b) => writer.write_value(b)?,
            Header::Int(n) => {
                let i = n.decode_i64(self)?;
                writer.write_value(i)?;
            }
            Header::IntBig(s, l) => {
                let int = self.take_big_int(s, l)?;
                writer.write_value(int)?;
            }
            Header::Float(n) => {
                let f = n.decode_f64(self)?;
                writer.write_value(f)?;
            }
            Header::Str(l) => {
                let s = self.take_str_len(l)?;
                writer.write_value(s)?;
            }
            Header::Object(length) => {
                let obj = Object::decode_header(self, length)?;
                obj.write_json(self, writer)?;
            }
            Header::HetArray(length) => {
                let het = HetArray::decode_header(self, length)?;
                het.write_json(self, writer)?;
            }
            Header::U8Array(length) => {
                let a = u8_array_slice(self, length)?;
                writer.write_seq(a.iter())?;
            }
            Header::HeaderArray(length) => header_array_write_to_json(self, length, writer)?,
            Header::I64Array(length) => {
                let a = i64_array_slice(self, length)?;
                writer.write_seq(a.iter())?;
            }
        };
        Ok(())
    }

    /// Take `size` raw bytes and advance the index past them.
    pub fn take_slice(&mut self, size: usize) -> DecodeResult<&'b [u8]> {
        let end = self.index + size;
        let s = self.bytes.get(self.index..end).ok_or_else(|| self.eof())?;
        self.index = end;
        Ok(s)
    }

    /// Zero-copy view of the next `length` values of `T`, aligning the index first
    /// so the cast cannot fail on alignment.
    pub fn take_slice_as<T: bytemuck::Pod>(&mut self, length: usize) -> DecodeResult<&'b [T]> {
        self.align::<T>();
        let size = length * size_of::<T>();
        let end = self.index + size;
        let s = self.bytes.get(self.index..end).ok_or_else(|| self.eof())?;

        let t: &[T] = bytemuck::try_cast_slice(s).map_err(|e| self.error(DecodeErrorType::PodCastError(e)))?;

        self.index = end;
        Ok(t)
    }

    fn take_str_len(&mut self, length: Length) -> DecodeResult<&'b str> {
        let len = length.decode(self)?;
        self.take_str(len)
    }

    /// Take `length` bytes and validate them as UTF-8 (SIMD-accelerated).
    pub fn take_str(&mut self, length: usize) -> DecodeResult<&'b str> {
        if length == 0 {
            Ok("")
        } else {
            let end = self.index + length;
            let slice = self.bytes.get(self.index..end).ok_or_else(|| self.eof())?;
            let s = simdutf8::basic::from_utf8(slice).map_err(|e| DecodeError::from_utf8_error(self.index, e))?;
            self.index = end;
            Ok(s)
        }
    }

    pub fn take_u8(&mut self) -> DecodeResult<u8> {
        self.next().ok_or_else(|| self.eof())
    }

    pub fn take_u16(&mut self) -> DecodeResult<u16> {
        let slice = self.take_slice(2)?;
        Ok(u16::from_le_bytes(slice.try_into().unwrap()))
    }

    pub fn take_u32(&mut self) -> DecodeResult<u32> {
        let slice = self.take_slice(4)?;
        Ok(u32::from_le_bytes(slice.try_into().unwrap()))
    }

    pub fn take_i8(&mut self) -> DecodeResult<i8> {
        match self.next() {
            // reinterpret the byte's bits as a signed value
            Some(byte) => Ok(byte as i8),
            None => Err(self.eof()),
        }
    }

    pub fn take_i32(&mut self) -> DecodeResult<i32> {
        let slice = self.take_slice(4)?;
        Ok(i32::from_le_bytes(slice.try_into().unwrap()))
    }

    pub fn take_i64(&mut self) -> DecodeResult<i64> {
        let slice = self.take_slice(8)?;
        Ok(i64::from_le_bytes(slice.try_into().unwrap()))
    }

    /// Take a big int: a length-prefixed little-endian magnitude with the given sign.
    pub fn take_big_int(&mut self, sign: Sign, length: Length) -> DecodeResult<BigInt> {
        let size = length.decode(self)?;
        let slice = self.take_slice(size)?;
        Ok(BigInt::from_bytes_le(sign, slice))
    }

    pub fn take_f64(&mut self) -> DecodeResult<f64> {
        let slice = self.take_slice(8)?;
        Ok(f64::from_le_bytes(slice.try_into().unwrap()))
    }

    pub fn eof(&self) -> DecodeError {
        self.error(DecodeErrorType::EOF)
    }

    pub fn error(&self, error_type: DecodeErrorType) -> DecodeError {
        DecodeError::new(self.index, error_type)
    }
}

impl<'b> Iterator for Decoder<'b> {
    type Item = u8;

    fn next(&mut self) -> Option<u8> {
        if let Some(byte) = self.bytes.get(self.index) {
            self.index += 1;
            Some(*byte)
        } else {
            None
        }
    }
}
diff --git a/crates/batson/src/encoder.rs b/crates/batson/src/encoder.rs
new file mode 100644
index 00000000..c322503f
--- /dev/null
+++ b/crates/batson/src/encoder.rs
@@ -0,0 +1,212 @@
use jiter::{JsonArray, JsonObject, JsonValue};
use num_bigint::{BigInt, Sign};
use std::mem::align_of;

use crate::array::encode_array;
use crate::errors::{EncodeError, EncodeResult};
use
crate::header::{Category, Length, NumberHint, Primitive}; +use crate::object::encode_object; + +#[derive(Debug)] +pub(crate) struct Encoder { + data: Vec, +} + +impl Encoder { + pub fn new() -> Self { + Self { data: Vec::new() } + } + + pub fn with_capacity(capacity: usize) -> Self { + Self { + data: Vec::with_capacity(capacity), + } + } + + pub fn align(&mut self) { + let align = align_of::(); + // same calculation as in `Decoder::align` + let new_len = (self.data.len() + align - 1) & !(align - 1); + self.data.resize(new_len, 0); + } + + pub fn ring_fence(&mut self, size: usize) -> usize { + let start = self.data.len(); + self.data.resize(start + size, 0); + start + } + + pub fn encode_value(&mut self, value: &JsonValue<'_>) -> EncodeResult<()> { + match value { + JsonValue::Null => self.encode_null(), + JsonValue::Bool(b) => self.encode_bool(*b), + JsonValue::Int(int) => self.encode_i64(*int), + JsonValue::BigInt(big_int) => self.encode_big_int(big_int)?, + JsonValue::Float(f) => self.encode_f64(*f), + JsonValue::Str(s) => self.encode_str(s.as_ref())?, + JsonValue::Array(array) => self.encode_array(array)?, + JsonValue::Object(obj) => self.encode_object(obj)?, + }; + Ok(()) + } + + pub fn position(&self) -> usize { + self.data.len() + } + + pub fn reset_position(&mut self, position: usize) { + self.data.truncate(position); + } + + pub fn encode_i64(&mut self, value: i64) { + if (0..=10).contains(&value) { + self.push(Category::Int.encode_with(value as u8)); + } else if let Ok(size_8) = i8::try_from(value) { + self.push(Category::Int.encode_with(NumberHint::Size8 as u8)); + self.extend(&size_8.to_le_bytes()); + } else if let Ok(size_32) = i32::try_from(value) { + self.push(Category::Int.encode_with(NumberHint::Size32 as u8)); + self.extend(&size_32.to_le_bytes()); + } else { + self.push(Category::Int.encode_with(NumberHint::Size64 as u8)); + self.extend(&value.to_le_bytes()); + } + } + + pub fn extend(&mut self, s: &[u8]) { + self.data.extend_from_slice(s); + } + 
+ pub fn set_range(&mut self, start: usize, s: &[u8]) { + self.data[start..start + s.len()].as_mut().copy_from_slice(s); + } + + pub fn encode_length(&mut self, cat: Category, len: usize) -> EncodeResult<()> { + match len { + 0 => self.push(cat.encode_with(Length::Empty as u8)), + 1 => self.push(cat.encode_with(Length::One as u8)), + 2 => self.push(cat.encode_with(Length::Two as u8)), + 3 => self.push(cat.encode_with(Length::Three as u8)), + 4 => self.push(cat.encode_with(Length::Four as u8)), + 5 => self.push(cat.encode_with(Length::Five as u8)), + 6 => self.push(cat.encode_with(Length::Six as u8)), + 7 => self.push(cat.encode_with(Length::Seven as u8)), + 8 => self.push(cat.encode_with(Length::Eight as u8)), + 9 => self.push(cat.encode_with(Length::Nine as u8)), + 10 => self.push(cat.encode_with(Length::Ten as u8)), + _ => { + if let Ok(s) = u8::try_from(len) { + self.push(cat.encode_with(Length::U8 as u8)); + self.push(s); + } else if let Ok(int) = u16::try_from(len) { + self.encode_len_u16(cat, int); + } else { + self.encode_len_u32(cat, len)?; + } + } + } + Ok(()) + } + + pub fn encode_len_u16(&mut self, cat: Category, int: u16) { + self.push(cat.encode_with(Length::U16 as u8)); + self.extend(&int.to_le_bytes()); + } + + pub fn encode_len_u32(&mut self, cat: Category, len: usize) -> EncodeResult<()> { + let int = u32::try_from(len).map_err(|_| match cat { + Category::Str => EncodeError::StrTooLong, + Category::HetArray => EncodeError::ObjectTooLarge, + _ => EncodeError::ArrayTooLarge, + })?; + self.push(cat.encode_with(Length::U32 as u8)); + self.extend(&int.to_le_bytes()); + Ok(()) + } + + fn encode_null(&mut self) { + let h = Category::Primitive.encode_with(Primitive::Null as u8); + self.push(h); + } + + fn encode_bool(&mut self, bool: bool) { + let right: Primitive = bool.into(); + let h = Category::Primitive.encode_with(right as u8); + self.push(h); + } + + fn encode_f64(&mut self, value: f64) { + match value { + 0.0 => 
self.push(Category::Float.encode_with(NumberHint::Zero as u8)), + 1.0 => self.push(Category::Float.encode_with(NumberHint::One as u8)), + 2.0 => self.push(Category::Float.encode_with(NumberHint::Two as u8)), + 3.0 => self.push(Category::Float.encode_with(NumberHint::Three as u8)), + 4.0 => self.push(Category::Float.encode_with(NumberHint::Four as u8)), + 5.0 => self.push(Category::Float.encode_with(NumberHint::Five as u8)), + 6.0 => self.push(Category::Float.encode_with(NumberHint::Six as u8)), + 7.0 => self.push(Category::Float.encode_with(NumberHint::Seven as u8)), + 8.0 => self.push(Category::Float.encode_with(NumberHint::Eight as u8)), + 9.0 => self.push(Category::Float.encode_with(NumberHint::Nine as u8)), + 10.0 => self.push(Category::Float.encode_with(NumberHint::Ten as u8)), + _ => { + // should we do something with f32 here? + self.push(Category::Float.encode_with(NumberHint::Size64 as u8)); + self.extend(&value.to_le_bytes()); + } + } + } + + fn encode_big_int(&mut self, int: &BigInt) -> EncodeResult<()> { + let (sign, bytes) = int.to_bytes_le(); + match sign { + Sign::Minus => self.encode_length(Category::BigIntNeg, bytes.len())?, + _ => self.encode_length(Category::BigIntPos, bytes.len())?, + } + self.extend(&bytes); + Ok(()) + } + + fn encode_str(&mut self, s: &str) -> EncodeResult<()> { + self.encode_length(Category::Str, s.len())?; + self.extend(s.as_bytes()); + Ok(()) + } + + fn encode_object(&mut self, object: &JsonObject) -> EncodeResult<()> { + encode_object(self, object) + } + + fn encode_array(&mut self, array: &JsonArray) -> EncodeResult<()> { + encode_array(self, array) + } + + fn push(&mut self, h: u8) { + self.data.push(h); + } +} + +impl From for Vec { + fn from(encoder: Encoder) -> Self { + encoder.data + } +} + +#[cfg(test)] +mod test { + use super::*; + use crate::decoder::Decoder; + use crate::header::Header; + + #[test] + fn encode_int() { + let mut enc = Encoder::new(); + enc.encode_i64(0); + let h = 
Decoder::new(&enc.data).take_header().unwrap();
        assert_eq!(h, Header::Int(NumberHint::Zero));

        let mut enc = Encoder::new();
        enc.encode_i64(7);
        let h = Decoder::new(&enc.data).take_header().unwrap();
        assert_eq!(h, Header::Int(NumberHint::Seven));
    }
}
diff --git a/crates/batson/src/errors.rs b/crates/batson/src/errors.rs
new file mode 100644
index 00000000..257ab805
--- /dev/null
+++ b/crates/batson/src/errors.rs
@@ -0,0 +1,94 @@
use std::fmt;

use bytemuck::PodCastError;
use serde::ser::Error;
use simdutf8::basic::Utf8Error;

/// Result alias for encoding operations.
pub type EncodeResult<T> = Result<T, EncodeError>;

#[derive(Debug, Copy, Clone)]
pub enum EncodeError {
    StrTooLong,
    ObjectTooLarge,
    ArrayTooLarge,
}

/// Result alias for decoding operations.
pub type DecodeResult<T> = Result<T, DecodeError>;

/// A decoding failure, tagged with the byte index where it occurred.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct DecodeError {
    pub index: usize,
    pub error_type: DecodeErrorType,
}

impl fmt::Display for DecodeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "Error at index {}: {}", self.index, self.error_type)
    }
}

impl From<DecodeError> for serde_json::Error {
    fn from(e: DecodeError) -> Self {
        serde_json::Error::custom(e.to_string())
    }
}

impl DecodeError {
    pub fn new(index: usize, error_type: DecodeErrorType) -> Self {
        Self { index, error_type }
    }

    pub fn from_utf8_error(index: usize, error: Utf8Error) -> Self {
        Self::new(index, DecodeErrorType::Utf8Error(error))
    }
}

#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum DecodeErrorType {
    EOF,
    ObjectBodyIndexInvalid,
    HeaderInvalid { value: u8, ty: &'static str },
    Utf8Error(Utf8Error),
    PodCastError(PodCastError),
}

impl fmt::Display for DecodeErrorType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            DecodeErrorType::EOF => write!(f, "Unexpected end of file"),
            DecodeErrorType::ObjectBodyIndexInvalid => write!(f, "Object body index is invalid"),
            DecodeErrorType::HeaderInvalid { value, ty } => {
                write!(f, "Header value {value} is invalid for type {ty}")
            }
            DecodeErrorType::Utf8Error(e) => write!(f, "UTF-8 error: {e}"),
            DecodeErrorType::PodCastError(e) => write!(f, "Pod cast error: {e}"),
        }
    }
}

/// Result alias for batson→JSON conversion.
pub type ToJsonResult<T> = Result<T, ToJsonError>;

#[derive(Debug)]
pub enum ToJsonError {
    Str(&'static str),
    DecodeError(DecodeError),
    JsonError(serde_json::Error),
}

impl From<&'static str> for ToJsonError {
    fn from(e: &'static str) -> Self {
        Self::Str(e)
    }
}

impl From<DecodeError> for ToJsonError {
    fn from(e: DecodeError) -> Self {
        Self::DecodeError(e)
    }
}

impl From<serde_json::Error> for ToJsonError {
    fn from(e: serde_json::Error) -> Self {
        Self::JsonError(e)
    }
}
diff --git a/crates/batson/src/get.rs b/crates/batson/src/get.rs
new file mode 100644
index 00000000..800efd2e
--- /dev/null
+++ b/crates/batson/src/get.rs
@@ -0,0 +1,304 @@
#![allow(clippy::module_name_repetitions)]

use crate::array::{header_array_get, i64_array_get, u8_array_get, HetArray};
use crate::decoder::Decoder;
use crate::encoder::Encoder;
use crate::errors::{DecodeError, DecodeResult};
use crate::header::Header;
use crate::object::Object;
use std::borrow::Cow;

/// One step of a lookup path: an object key or an array index.
#[derive(Debug)]
pub enum BatsonPath<'s> {
    Key(&'s str),
    Index(usize),
}

impl From<usize> for BatsonPath<'_> {
    fn from(index: usize) -> Self {
        Self::Index(index)
    }
}

impl<'s> From<&'s str> for BatsonPath<'s> {
    fn from(key: &'s str) -> Self {
        Self::Key(key)
    }
}

/// Look up a bool at `path`; `None` if missing or not a bool.
pub fn get_bool(bytes: &[u8], path: &[BatsonPath]) -> DecodeResult<Option<bool>> {
    GetValue::get(bytes, path).map(|v| v.and_then(Into::into))
}

/// Look up a string at `path`, borrowing from `bytes`.
pub fn get_str<'b>(bytes: &'b [u8], path: &[BatsonPath]) -> DecodeResult<Option<&'b str>> {
    get_try_into(bytes, path)
}

/// Look up an integer at `path`.
pub fn get_int(bytes: &[u8], path: &[BatsonPath]) -> DecodeResult<Option<i64>> {
    get_try_into(bytes, path)
}

/// Extract the raw batson sub-document at `path` — borrowed when it exists as a
/// contiguous range, owned when it must be re-encoded from a packed array element.
pub fn get_batson<'b>(bytes: &'b [u8], path: &[BatsonPath]) -> DecodeResult<Option<Cow<'b, [u8]>>> {
    if let Some(v) = GetValue::get(bytes, path)?
{
        v.into_batson().map(Some)
    } else {
        Ok(None)
    }
}

/// Does any value exist at `path`?
pub fn contains(bytes: &[u8], path: &[BatsonPath]) -> DecodeResult<bool> {
    GetValue::get(bytes, path).map(|v| v.is_some())
}

/// Length of the string/object/array at `path`; `None` for scalars or a missing path.
pub fn get_length(bytes: &[u8], path: &[BatsonPath]) -> DecodeResult<Option<usize>> {
    if let Some(v) = GetValue::get(bytes, path)? {
        v.into_length()
    } else {
        Ok(None)
    }
}

fn get_try_into<'b, T>(bytes: &'b [u8], path: &[BatsonPath]) -> DecodeResult<Option<T>>
where
    Option<T>: TryFrom<GetValue<'b>, Error = DecodeError>,
{
    if let Some(v) = GetValue::get(bytes, path)? {
        v.try_into()
    } else {
        Ok(None)
    }
}

/// Result of a path lookup: either a decoder positioned just past a header,
/// or a scalar pulled out of a packed (u8/i64) array.
#[derive(Debug)]
enum GetValue<'b> {
    Header(Decoder<'b>, Header),
    U8(u8),
    I64(i64),
}

impl<'b> GetValue<'b> {
    /// Walk `path` from the document root, returning `None` as soon as a step
    /// cannot be resolved.
    fn get(bytes: &'b [u8], path: &[BatsonPath]) -> DecodeResult<Option<Self>> {
        let mut decoder = Decoder::new(bytes);
        // header of the value currently being traversed
        let mut opt_header: Option<Header> = Some(decoder.take_header()?);
        // scalar result, set when the final step lands inside a packed array
        let mut value: Option<GetValue> = None;
        for element in path {
            let Some(header) = opt_header.take() else {
                return Ok(None);
            };
            match element {
                BatsonPath::Key(key) => {
                    if let Header::Object(length) = header {
                        let object = Object::decode_header(&mut decoder, length)?;
                        if object.get(&mut decoder, key)? {
                            opt_header = Some(decoder.take_header()?);
                        }
                    }
                }
                BatsonPath::Index(index) => match header {
                    Header::HeaderArray(length) => {
                        opt_header = header_array_get(&mut decoder, length, *index)?;
                    }
                    Header::U8Array(length) => {
                        if let Some(u8_value) = u8_array_get(&mut decoder, length, *index)? {
                            value = Some(GetValue::U8(u8_value));
                        }
                    }
                    Header::I64Array(length) => {
                        if let Some(i64_value) = i64_array_get(&mut decoder, length, *index)? {
                            value = Some(GetValue::I64(i64_value));
                        }
                    }
                    Header::HetArray(length) => {
                        let a = HetArray::decode_header(&mut decoder, length)?;
                        if a.get(&mut decoder, *index) {
                            opt_header = Some(decoder.take_header()?);
                        }
                    }
                    _ => {}
                },
            }
        }
        if let Some(header) = opt_header {
            Ok(Some(Self::Header(decoder, header)))
        } else if let Some(value) = value {
            Ok(Some(value))
        } else {
            Ok(None)
        }
    }

    fn header(self) -> Option<Header> {
        match self {
            Self::Header(_, header) => Some(header),
            _ => None,
        }
    }

    fn into_length(self) -> DecodeResult<Option<usize>> {
        let Self::Header(mut decoder, header) = self else {
            return Ok(None);
        };
        match header {
            Header::Str(length)
            | Header::Object(length)
            | Header::HeaderArray(length)
            | Header::U8Array(length)
            | Header::I64Array(length)
            | Header::HetArray(length) => length.decode(&mut decoder).map(Some),
            _ => Ok(None),
        }
    }

    /// Return the value as standalone batson bytes: a borrowed slice for header
    /// values, or a freshly encoded buffer for packed-array scalars.
    fn into_batson(self) -> DecodeResult<Cow<'b, [u8]>> {
        match self {
            Self::Header(mut decoder, header) => {
                // include the header byte that was already consumed
                let start = decoder.index - 1;
                decoder.move_to_end(header)?;
                let end = decoder.index;
                decoder.get_range(start, end).map(Cow::Borrowed)
            }
            Self::U8(int) => {
                let mut encoder = Encoder::with_capacity(2);
                encoder.encode_i64(int.into());
                Ok(Cow::Owned(encoder.into()))
            }
            Self::I64(int) => {
                let mut encoder = Encoder::with_capacity(9);
                encoder.encode_i64(int);
                Ok(Cow::Owned(encoder.into()))
            }
        }
    }
}

impl From<GetValue<'_>> for Option<bool> {
    fn from(v: GetValue) -> Self {
        v.header().and_then(Header::into_bool)
    }
}

impl<'b> TryFrom<GetValue<'b>> for Option<&'b str> {
    type Error = DecodeError;

    fn try_from(v: GetValue<'b>) -> DecodeResult<Self> {
        match v {
            GetValue::Header(mut decoder, Header::Str(length)) => {
                let length = length.decode(&mut decoder)?;
                decoder.take_str(length).map(Some)
            }
            _ => Ok(None),
        }
    }
}

impl TryFrom<GetValue<'_>> for Option<i64> {
    type Error = DecodeError;

    fn try_from(v: GetValue) -> DecodeResult<Self> {
        match v {
            GetValue::Header(mut decoder, Header::Int(n)) => n.decode_i64(&mut decoder).map(Some),
            GetValue::I64(i64) => Ok(Some(i64)),
            GetValue::U8(u8) => Ok(Some(i64::from(u8))),
            GetValue::Header(..) => Ok(None),
        }
    }
}

#[cfg(test)]
mod test {
    use crate::encode_from_json;
    use crate::header::{Header, NumberHint};
    use jiter::{JsonValue, LazyIndexMap};
    use smallvec::smallvec;
    use std::sync::Arc;

    use super::*;

    #[test]
    fn get_object() {
        let v: JsonValue<'static> = JsonValue::Object(Arc::new(LazyIndexMap::from(vec![
            ("null".into(), JsonValue::Null),
            ("true".into(), JsonValue::Bool(true)),
        ])));
        let bytes = encode_from_json(&v).unwrap();

        let v = GetValue::get(&bytes, &["null".into()]).unwrap().unwrap();
        assert_eq!(v.header(), Some(Header::Null));
        let v = GetValue::get(&bytes, &["true".into()]).unwrap().unwrap();
        assert_eq!(v.header(), Some(Header::Bool(true)));

        assert!(GetValue::get(&bytes, &["foo".into()]).unwrap().is_none());
        assert!(GetValue::get(&bytes, &[1.into()]).unwrap().is_none());
        assert!(GetValue::get(&bytes, &["null".into(), 1.into()]).unwrap().is_none());
    }

    #[test]
    fn get_header_array() {
        let v: JsonValue<'static> = JsonValue::Array(Arc::new(smallvec![JsonValue::Null, JsonValue::Bool(true),]));
        let bytes = encode_from_json(&v).unwrap();

        let v = GetValue::get(&bytes, &[0.into()]).unwrap().unwrap();
        assert_eq!(v.header(), Some(Header::Null));

        let v = GetValue::get(&bytes, &[1.into()]).unwrap().unwrap();
        assert_eq!(v.header(), Some(Header::Bool(true)));

        assert!(GetValue::get(&bytes, &["foo".into()]).unwrap().is_none());
        assert!(GetValue::get(&bytes, &[2.into()]).unwrap().is_none());
    }

    #[test]
    fn get_het_array() {
        let v: JsonValue<'static> =
            JsonValue::Array(Arc::new(
                smallvec![JsonValue::Int(42), JsonValue::Str("foobar".into()),],
            ));
        let bytes = encode_from_json(&v).unwrap();

        let v = GetValue::get(&bytes, &[0.into()]).unwrap().unwrap();
        assert_eq!(v.header(), Some(Header::Int(NumberHint::Size8)));
    }

    fn value_u8(v: &GetValue) -> Option<u8> {
        match v {
            GetValue::U8(u8) => Some(*u8),
            _ => None,
        }
    }

    fn value_i64(v: &GetValue) -> Option<i64> {
match v {
            GetValue::I64(i64) => Some(*i64),
            _ => None,
        }
    }

    #[test]
    fn get_u8_array() {
        let v: JsonValue<'static> = JsonValue::Array(Arc::new(smallvec![JsonValue::Int(42), JsonValue::Int(255),]));
        let bytes = encode_from_json(&v).unwrap();

        let v = GetValue::get(&bytes, &[0.into()]).unwrap().unwrap();
        assert_eq!(value_u8(&v), Some(42));

        let v = GetValue::get(&bytes, &[1.into()]).unwrap().unwrap();
        assert_eq!(value_u8(&v), Some(255));

        assert!(GetValue::get(&bytes, &[2.into()]).unwrap().is_none());
    }

    #[test]
    fn get_i64_array() {
        let v: JsonValue<'static> =
            JsonValue::Array(Arc::new(smallvec![JsonValue::Int(42), JsonValue::Int(i64::MAX),]));
        let bytes = encode_from_json(&v).unwrap();

        let v = GetValue::get(&bytes, &[0.into()]).unwrap().unwrap();
        assert_eq!(value_i64(&v), Some(42));

        let v = GetValue::get(&bytes, &[1.into()]).unwrap().unwrap();
        assert_eq!(value_i64(&v), Some(i64::MAX));

        assert!(GetValue::get(&bytes, &[2.into()]).unwrap().is_none());
    }
}
diff --git a/crates/batson/src/header.rs b/crates/batson/src/header.rs
new file mode 100644
index 00000000..d9d1e8fb
--- /dev/null
+++ b/crates/batson/src/header.rs
@@ -0,0 +1,366 @@
use std::sync::Arc;

use jiter::{JsonValue, LazyIndexMap};
use num_bigint::Sign;
use smallvec::smallvec;

use crate::decoder::Decoder;
use crate::errors::{DecodeErrorType, DecodeResult};
use crate::json_writer::JsonWriter;
use crate::ToJsonResult;

/// Decoded form of a value's first byte: category in the high nibble,
/// payload hint (number size or length) in the low nibble.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub(crate) enum Header {
    Null,
    Bool(bool),
    Int(NumberHint),
    IntBig(Sign, Length),
    Float(NumberHint),
    Str(Length),
    Object(Length),
    // array types in order of their complexity
    #[allow(clippy::enum_variant_names)]
    HeaderArray(Length),
    U8Array(Length),
    I64Array(Length),
    HetArray(Length),
}

impl Header {
    /// Decode the next byte from a decoder into a header value
    pub fn decode(byte: u8, d: &Decoder) -> DecodeResult<Self> {
        let (left, right) = split_byte(byte);
        let cat = Category::from_u8(left, d)?;
        match cat {
            Category::Primitive => Primitive::from_u8(right, d).map(Primitive::header_value),
            Category::Int => NumberHint::from_u8(right, d).map(Self::Int),
            Category::BigIntPos => Length::from_u8(right, d).map(|l| Self::IntBig(Sign::Plus, l)),
            Category::BigIntNeg => Length::from_u8(right, d).map(|l| Self::IntBig(Sign::Minus, l)),
            Category::Float => NumberHint::from_u8(right, d).map(Self::Float),
            Category::Str => Length::from_u8(right, d).map(Self::Str),
            Category::Object => Length::from_u8(right, d).map(Self::Object),
            Category::HeaderArray => Length::from_u8(right, d).map(Self::HeaderArray),
            Category::U8Array => Length::from_u8(right, d).map(Self::U8Array),
            Category::I64Array => Length::from_u8(right, d).map(Self::I64Array),
            Category::HetArray => Length::from_u8(right, d).map(Self::HetArray),
        }
    }

    /// TODO `'static` should be okay as return lifetime, I don't know why it's not
    pub fn header_as_value<'b>(self, _: &Decoder<'b>) -> JsonValue<'b> {
        match self {
            Header::Null => JsonValue::Null,
            Header::Bool(b) => JsonValue::Bool(b),
            Header::Int(n) => JsonValue::Int(n.decode_i64_header()),
            Header::IntBig(..) => unreachable!("Big ints are not supported as header only values"),
            Header::Float(n) => JsonValue::Float(n.decode_f64_header()),
            Header::Str(_) => JsonValue::Str("".into()),
            Header::Object(_) => JsonValue::Object(Arc::new(LazyIndexMap::default())),
            _ => JsonValue::Array(Arc::new(smallvec![])),
        }
    }

    /// Write this value as JSON using only the header byte; lengths must be
    /// `Empty` since no body follows.
    pub fn write_json_header_only(self, writer: &mut JsonWriter) -> ToJsonResult<()> {
        match self {
            Header::Null => writer.write_null(),
            Header::Bool(b) => writer.write_value(b)?,
            Header::Int(n) => writer.write_value(n.decode_i64_header())?,
            Header::IntBig(..) => return Err("Big ints are not supported as header only values".into()),
            Header::Float(n) => writer.write_value(n.decode_f64_header())?,
            // TODO check the … — NOTE(review): original comment was truncated; presumably
            // "check the length", which `check_empty` enforces below
            Header::Str(len) => {
                len.check_empty()?;
                writer.write_value("")?;
            }
            Header::Object(len) => {
                len.check_empty()?;
                writer.write_empty_object();
            }
            Self::HeaderArray(len) | Self::U8Array(len) | Self::I64Array(len) | Self::HetArray(len) => {
                len.check_empty()?;
                writer.write_empty_array();
            }
        }
        Ok(())
    }

    pub fn into_bool(self) -> Option<bool> {
        match self {
            Header::Bool(b) => Some(b),
            _ => None,
        }
    }
}

macro_rules! impl_from_u8 {
    ($header_enum:ty, $max_value:literal) => {
        impl $header_enum {
            fn from_u8(value: u8, p: &Decoder) -> DecodeResult<Self> {
                if value <= $max_value {
                    // SAFETY: the enum is fieldless with contiguous discriminants
                    // 0..=$max_value, so any value in range is a valid bit pattern.
                    Ok(unsafe { std::mem::transmute::<u8, $header_enum>(value) })
                } else {
                    Err(p.error(DecodeErrorType::HeaderInvalid {
                        value,
                        ty: stringify!($header_enum),
                    }))
                }
            }
        }
    };
}

/// Left half of the first header byte determines the category of the value
/// Up to 16 categories are possible
#[derive(Debug, Copy, Clone)]
pub(crate) enum Category {
    Primitive = 0,
    Int = 1,
    BigIntPos = 2,
    BigIntNeg = 3,
    Float = 4,
    Str = 5,
    Object = 6,
    HeaderArray = 7,
    U8Array = 8,
    I64Array = 9,
    HetArray = 10,
}
impl_from_u8!(Category, 10);

impl Category {
    /// Pack this category (high nibble) with a 4-bit payload (low nibble).
    pub fn encode_with(self, right: u8) -> u8 {
        let left = self as u8;
        (left << 4) | right
    }
}

#[derive(Debug, Copy, Clone)]
pub(crate) enum Primitive {
    Null = 0,
    True = 1,
    False = 2,
}
impl_from_u8!(Primitive, 2);

impl From<bool> for Primitive {
    fn from(value: bool) -> Self {
        if value {
            Self::True
        } else {
            Self::False
        }
    }
}

impl Primitive {
    fn header_value(self) -> Header {
        match self {
            Primitive::Null => Header::Null,
            Primitive::True => Header::Bool(true),
            Primitive::False => Header::Bool(false),
        }
    }
}

/// Low-nibble hint for numbers: values 0..=10 inline, larger sizes follow the header.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub(crate) enum NumberHint {
    Zero = 0,
    One = 1,
    Two = 2,
    Three = 3,
    Four = 4,
    Five = 5,
    Six = 6,
    Seven = 7,
    Eight = 8,
    Nine = 9,
    Ten = 10,
    // larger numbers
    Size8 = 11,
    Size32 = 12,
    Size64 = 13,
}
impl_from_u8!(NumberHint, 13);

impl NumberHint {
    pub fn decode_i64(self, d: &mut Decoder) -> DecodeResult<i64> {
        match self {
            NumberHint::Size8 => d.take_i8().map(i64::from),
            NumberHint::Size32 => d.take_i32().map(i64::from),
            NumberHint::Size64 => d.take_i64(),
            // TODO check this has same performance as inline match
            _ => Ok(self.decode_i64_header()),
        }
    }

    #[inline]
    pub fn decode_i64_header(self) -> i64 {
        match self {
            NumberHint::Zero => 0,
            NumberHint::One => 1,
            NumberHint::Two => 2,
            NumberHint::Three => 3,
            NumberHint::Four => 4,
            NumberHint::Five => 5,
            NumberHint::Six => 6,
            NumberHint::Seven => 7,
            NumberHint::Eight => 8,
            NumberHint::Nine => 9,
            NumberHint::Ten => 10,
            _ => unreachable!("Expected concrete value, got {self:?}"),
        }
    }

    pub fn decode_f64(self, d: &mut Decoder) -> DecodeResult<f64> {
        match self {
            // f8 doesn't exist, and currently we don't use f32 anywhere
            NumberHint::Size8 | NumberHint::Size32 => Err(d.error(DecodeErrorType::HeaderInvalid {
                value: self as u8,
                ty: "f64",
            })),
            NumberHint::Size64 => d.take_f64(),
            // TODO check this has same performance as inline match
            _ => Ok(self.decode_f64_header()),
        }
    }

    #[inline]
    fn decode_f64_header(self) -> f64 {
        match self {
            NumberHint::Zero => 0.0,
            NumberHint::One => 1.0,
            NumberHint::Two => 2.0,
            NumberHint::Three => 3.0,
            NumberHint::Four => 4.0,
            NumberHint::Five => 5.0,
            NumberHint::Six => 6.0,
            NumberHint::Seven => 7.0,
            NumberHint::Eight => 8.0,
            NumberHint::Nine => 9.0,
            NumberHint::Ten => 10.0,
            _ => unreachable!("Expected concrete value, got {self:?}"),
        }
    }

    /// Hint for an int that fits entirely in the header, if any.
    pub fn header_only_i64(int: i64) -> Option<Self> {
        match int {
            0 => Some(NumberHint::Zero),
            1 => Some(NumberHint::One),
            2 => Some(NumberHint::Two),
            3 => Some(NumberHint::Three),
            4 => Some(NumberHint::Four),
            5 => Some(NumberHint::Five),
            6 => Some(NumberHint::Six),
            7 => Some(NumberHint::Seven),
            8 => Some(NumberHint::Eight),
            9 => Some(NumberHint::Nine),
            10 => Some(NumberHint::Ten),
            _ => None,
        }
    }

    /// Hint for a float that fits entirely in the header, if any.
    pub fn header_only_f64(float: f64) -> Option<Self> {
        match float {
            0.0 => Some(NumberHint::Zero),
            1.0 => Some(NumberHint::One),
            2.0 => Some(NumberHint::Two),
            3.0 => Some(NumberHint::Three),
            4.0 => Some(NumberHint::Four),
            5.0 => Some(NumberHint::Five),
            6.0 => Some(NumberHint::Six),
            7.0 => Some(NumberHint::Seven),
            8.0 => Some(NumberHint::Eight),
            9.0 => Some(NumberHint::Nine),
            10.0 => Some(NumberHint::Ten),
            _ => None,
        }
    }

    /// Get the length of the data that follows the header
    pub fn data_length(self) -> usize {
        match self {
            Self::Size8 => 1,
            Self::Size32 => 4,
            Self::Size64 => 8,
            _ => 0,
        }
    }
}

/// String, object, and array lengths
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub(crate) enum Length {
    Empty = 0,
    One = 1,
    Two = 2,
    Three = 3,
    Four = 4,
    Five = 5,
    Six = 6,
    Seven = 7,
    Eight = 8,
    Nine = 9,
    Ten = 10,
    // larger numbers
    U8 = 11,
    U16 = 12,
    U32 = 13,
}
impl_from_u8!(Length, 13);

impl From<u64> for Length {
    fn from(len: u64) -> Self {
        match len {
            0 => Self::Empty,
            1 => Self::One,
            2 => Self::Two,
            3 => Self::Three,
            4 => Self::Four,
            5 => Self::Five,
            6 => Self::Six,
            7 => Self::Seven,
            8 => Self::Eight,
            9 => Self::Nine,
            10 => Self::Ten,
            len if len <= u64::from(u8::MAX) => Self::U8,
            len if len <= u64::from(u16::MAX) => Self::U16,
            _ => Self::U32,
        }
    }
}

impl Length {
    /// Decode the concrete length: inline for 0..=10, otherwise read the
    /// trailing u8/u16/u32 little-endian value.
    pub fn decode(self, d: &mut Decoder) -> DecodeResult<usize> {
        match self {
            Self::Empty => Ok(0),
            Self::One => Ok(1),
            Self::Two => Ok(2),
            Self::Three => Ok(3),
            Self::Four => Ok(4),
            Self::Five => Ok(5),
            Self::Six => Ok(6),
            Self::Seven => Ok(7),
            Self::Eight => Ok(8),
            Self::Nine => Ok(9),
            Self::Ten => Ok(10),
Self::U8 => d.take_u8().map(|s| s as usize),
            Self::U16 => d.take_u16().map(|s| s as usize),
            // u32 -> usize is lossless on the 64-bit targets this crate targets
            Self::U32 => d.take_u32().map(|s| s as usize),
        }
    }

    pub fn check_empty(self) -> ToJsonResult<()> {
        if matches!(self, Self::Empty) {
            Ok(())
        } else {
            Err("Expected empty length, got non-empty".into())
        }
    }
}

/// Split a byte into two 4-bit halves - u8 numbers with a range of 0-15
fn split_byte(byte: u8) -> (u8, u8) {
    let left = byte >> 4; // Shift the byte right by 4 bits
    let right = byte & 0b0000_1111; // Mask the byte with 00001111
    (left, right)
}
diff --git a/crates/batson/src/json_writer.rs b/crates/batson/src/json_writer.rs
new file mode 100644
index 00000000..ebaa508e
--- /dev/null
+++ b/crates/batson/src/json_writer.rs
@@ -0,0 +1,126 @@
use num_bigint::BigInt;
use serde::ser::Serializer as _;
use serde_json::ser::Serializer;

use crate::errors::ToJsonResult;

/// Minimal JSON text writer over a byte buffer; serde is used only for values
/// that need escaping/formatting (strings, numbers).
pub(crate) struct JsonWriter {
    vec: Vec<u8>,
}

impl JsonWriter {
    pub fn new() -> Self {
        Self {
            vec: Vec::with_capacity(128),
        }
    }

    pub fn write_null(&mut self) {
        self.vec.extend_from_slice(b"null");
    }

    #[allow(clippy::needless_pass_by_value)]
    pub fn write_value(&mut self, v: impl WriteJson) -> ToJsonResult<()> {
        v.write_json(self)
    }

    /// Write an iterator of values as a JSON array.
    /// NOTE(review): the exact original item bounds were lost in extraction;
    /// callers pass `slice::iter()`, i.e. `&T` items — confirm against upstream.
    pub fn write_seq<'a>(&mut self, mut v: impl Iterator<Item = &'a (impl WriteJson + 'a)>) -> ToJsonResult<()> {
        self.start_array();

        if let Some(first) = v.next() {
            first.write_json(self)?;
            for value in v {
                self.comma();
                value.write_json(self)?;
            }
        }
        self.end_array();
        Ok(())
    }

    pub fn write_empty_array(&mut self) {
        self.vec.extend_from_slice(b"[]");
    }

    pub fn start_array(&mut self) {
        self.vec.push(b'[');
    }

    pub fn end_array(&mut self) {
        self.vec.push(b']');
    }

    /// Write an escaped object key followed by `:`.
    pub fn write_key(&mut self, key: &str) -> ToJsonResult<()> {
        self.write_value(key)?;
        self.vec.push(b':');
        Ok(())
    }

    pub fn write_empty_object(&mut self) {
        self.vec.extend_from_slice(b"{}");
    }

    pub fn start_object(&mut self) {
        self.vec.push(b'{');
    }

    pub fn end_object(&mut self) {
        self.vec.push(b'}');
    }

    pub fn comma(&mut self) {
        self.vec.push(b',');
    }
}

impl From<JsonWriter> for Vec<u8> {
    fn from(writer: JsonWriter) -> Self {
        writer.vec
    }
}

pub(crate) trait WriteJson {
    fn write_json(&self, writer: &mut JsonWriter) -> ToJsonResult<()>;
}

impl WriteJson for &str {
    fn write_json(&self, writer: &mut JsonWriter) -> ToJsonResult<()> {
        let mut ser = Serializer::new(&mut writer.vec);
        ser.serialize_str(self).map_err(Into::into)
    }
}

impl WriteJson for bool {
    fn write_json(&self, writer: &mut JsonWriter) -> ToJsonResult<()> {
        writer.vec.extend_from_slice(if *self { b"true" } else { b"false" });
        Ok(())
    }
}

impl WriteJson for u8 {
    fn write_json(&self, writer: &mut JsonWriter) -> ToJsonResult<()> {
        let mut ser = Serializer::new(&mut writer.vec);
        ser.serialize_u8(*self).map_err(Into::into)
    }
}

impl WriteJson for i64 {
    fn write_json(&self, writer: &mut JsonWriter) -> ToJsonResult<()> {
        let mut ser = Serializer::new(&mut writer.vec);
        ser.serialize_i64(*self).map_err(Into::into)
    }
}

impl WriteJson for f64 {
    fn write_json(&self, writer: &mut JsonWriter) -> ToJsonResult<()> {
        let mut ser = Serializer::new(&mut writer.vec);
        ser.serialize_f64(*self).map_err(Into::into)
    }
}

impl WriteJson for BigInt {
    fn write_json(&self, writer: &mut JsonWriter) -> ToJsonResult<()> {
        writer.vec.extend_from_slice(self.to_str_radix(10).as_bytes());
        Ok(())
    }
}
diff --git a/crates/batson/src/lib.rs b/crates/batson/src/lib.rs
new file mode 100644
index 00000000..de05848f
--- /dev/null
+++ b/crates/batson/src/lib.rs
@@ -0,0 +1,76 @@
mod array;
mod decoder;
mod encoder;
mod errors;
pub mod get;
mod header;
mod json_writer;
mod object;

use jiter::JsonValue;

use crate::json_writer::JsonWriter;
use decoder::Decoder;
use encoder::Encoder;
pub use errors::{DecodeErrorType, DecodeResult, EncodeError, EncodeResult, ToJsonError, ToJsonResult};

/// Encode binary data from a JSON value.
///
/// # Errors
///
/// Returns an error if the data is not valid.
pub fn encode_from_json(value: &JsonValue<'_>) -> EncodeResult<Vec<u8>> {
    let mut encoder = Encoder::new();
    encoder.encode_value(value)?;
    // NOTE(review): the turbofish type was lost in extraction; `u64` pads to the
    // widest alignment used by the format (i64/f64 payloads) — confirm upstream.
    encoder.align::<u64>();
    Ok(encoder.into())
}

/// Decode binary data to a JSON value.
///
/// # Errors
///
/// Returns an error if the data is not valid.
pub fn decode_to_json_value(bytes: &[u8]) -> DecodeResult<JsonValue<'_>> {
    Decoder::new(bytes).take_value()
}

/// Render a batson document as JSON text bytes.
pub fn batson_to_json_vec(batson_bytes: &[u8]) -> ToJsonResult<Vec<u8>> {
    let mut writer = JsonWriter::new();
    Decoder::new(batson_bytes).write_json(&mut writer)?;
    Ok(writer.into())
}

pub fn batson_to_json_string(batson_bytes: &[u8]) -> ToJsonResult<String> {
    let v = batson_to_json_vec(batson_bytes)?;
    // SAFETY: safe since we're guaranteed to have written valid UTF-8
    unsafe { Ok(String::from_utf8_unchecked(v)) }
}

/// Hack while waiting for an upstream equality impl — NOTE(review): the original
/// angle-bracketed issue link was lost in extraction; restore it.
#[must_use]
pub fn compare_json_values(a: &JsonValue<'_>, b: &JsonValue<'_>) -> bool {
    match (a, b) {
        (JsonValue::Null, JsonValue::Null) => true,
        (JsonValue::Bool(a), JsonValue::Bool(b)) => a == b,
        (JsonValue::Int(a), JsonValue::Int(b)) => a == b,
        (JsonValue::BigInt(a), JsonValue::BigInt(b)) => a == b,
        (JsonValue::Float(a), JsonValue::Float(b)) => (a - b).abs() <= f64::EPSILON,
        (JsonValue::Str(a), JsonValue::Str(b)) => a == b,
        (JsonValue::Array(a), JsonValue::Array(b)) => {
            if a.len() != b.len() {
                return false;
            }
            a.iter().zip(b.iter()).all(|(a, b)| compare_json_values(a, b))
        }
        (JsonValue::Object(a), JsonValue::Object(b)) => {
            if a.len() != b.len() {
                return false;
            }
            a.iter()
                .zip(b.iter())
                .all(|((ak, av), (bk, bv))| ak == bk && compare_json_values(av, bv))
        }
        _ => false,
    }
}
diff --git a/crates/batson/src/object.rs b/crates/batson/src/object.rs
new file mode 100644
index 00000000..a08156cd
--- /dev/null
+++ b/crates/batson/src/object.rs
@@ -0,0 +1,618 @@
use
std::borrow::Cow; +use std::cmp::Ordering; +use std::fmt; +use std::mem::size_of; +use std::num::TryFromIntError; +use std::sync::Arc; + +use bytemuck::{Pod, Zeroable}; +use jiter::{JsonObject, JsonValue, LazyIndexMap}; + +use crate::decoder::Decoder; +use crate::encoder::Encoder; +use crate::errors::{DecodeErrorType, DecodeResult, EncodeResult}; +use crate::header::{Category, Length}; +use crate::json_writer::JsonWriter; +use crate::{EncodeError, ToJsonResult}; + +#[derive(Debug)] +pub(crate) struct Object<'b>(ObjectChoice<'b>); + +impl<'b> Object<'b> { + pub fn decode_header(d: &mut Decoder<'b>, length: Length) -> DecodeResult { + match length { + Length::Empty => Ok(Self(ObjectChoice::U8(ObjectSized { super_header: &[] }))), + Length::U32 => Ok(Self(ObjectChoice::U32(ObjectSized::new(d, length)?))), + Length::U16 => Ok(Self(ObjectChoice::U16(ObjectSized::new(d, length)?))), + _ => Ok(Self(ObjectChoice::U8(ObjectSized::new(d, length)?))), + } + } + + pub fn get(&self, d: &mut Decoder<'b>, key: &str) -> DecodeResult { + match &self.0 { + ObjectChoice::U8(o) => o.get(d, key), + ObjectChoice::U16(o) => o.get(d, key), + ObjectChoice::U32(o) => o.get(d, key), + } + } + + pub fn to_value(&self, d: &mut Decoder<'b>) -> DecodeResult> { + match &self.0 { + ObjectChoice::U8(o) => o.to_value(d), + ObjectChoice::U16(o) => o.to_value(d), + ObjectChoice::U32(o) => o.to_value(d), + } + } + + pub fn write_json(&self, d: &mut Decoder<'b>, writer: &mut JsonWriter) -> ToJsonResult<()> { + match &self.0 { + ObjectChoice::U8(o) => o.write_json(d, writer), + ObjectChoice::U16(o) => o.write_json(d, writer), + ObjectChoice::U32(o) => o.write_json(d, writer), + } + } + + /// Get the length of the data that follows the header + pub fn move_to_end(self, d: &mut Decoder<'b>) -> DecodeResult<()> { + match self.0 { + ObjectChoice::U8(o) => o.move_to_end(d), + ObjectChoice::U16(o) => o.move_to_end(d), + ObjectChoice::U32(o) => o.move_to_end(d), + } + } +} + +#[derive(Debug)] +enum 
ObjectChoice<'b> { + U8(ObjectSized<'b, SuperHeaderItem8>), + U16(ObjectSized<'b, SuperHeaderItem16>), + U32(ObjectSized<'b, SuperHeaderItem32>), +} + +#[derive(Debug)] +struct ObjectSized<'b, S: SuperHeaderItem> { + super_header: &'b [S], +} + +impl<'b, S: SuperHeaderItem> ObjectSized<'b, S> { + fn new(d: &mut Decoder<'b>, length: Length) -> DecodeResult { + let length = length.decode(d)?; + let super_header: &[S] = d.take_slice_as(length)?; + Ok(Self { super_header }) + } + + fn len(&self) -> usize { + self.super_header.len() + } + + fn get(&self, d: &mut Decoder<'b>, key: &str) -> DecodeResult { + // "item" for comparison only, so offset doesn't matter. if this errors it's because the key is too long + // to encode + let Ok(key_item) = S::new(key, 0) else { + return Ok(false); + }; + + let Some(header_iter) = binary_search(self.super_header, |h| h.sort_order(&key_item)) else { + return Ok(false); + }; + let start_index = d.index; + + for h in header_iter { + d.index = start_index + h.offset(); + let possible_key = d.take_slice(h.key_length())?; + if possible_key == key.as_bytes() { + return Ok(true); + } + } + // reset the index + d.index = start_index; + Ok(false) + } + + fn to_value(&self, d: &mut Decoder<'b>) -> DecodeResult> { + self.super_header + .iter() + .map(|_| { + let key = self.take_next_key(d)?; + let value = d.take_value()?; + Ok((key.into(), value)) + }) + .collect::>>() + .map(Arc::new) + } + + fn write_json(&self, d: &mut Decoder<'b>, writer: &mut JsonWriter) -> ToJsonResult<()> { + let mut steps = 0..self.len(); + writer.start_object(); + if steps.next().is_some() { + let key = self.take_next_key(d)?; + writer.write_key(key)?; + d.write_json(writer)?; + for _ in steps { + writer.comma(); + let key = self.take_next_key(d)?; + writer.write_key(key)?; + d.write_json(writer)?; + } + } + writer.end_object(); + Ok(()) + } + + fn take_next_key(&self, d: &mut Decoder<'b>) -> DecodeResult<&'b str> { + let header_index = S::take_header_index(d)?; + match 
self.super_header.get(header_index) { + Some(h) => d.take_str(h.key_length()), + None => Err(d.error(DecodeErrorType::ObjectBodyIndexInvalid)), + } + } + + /// the offset of the end of the last value + pub fn move_to_end(self, d: &mut Decoder<'b>) -> DecodeResult<()> { + let h = self.super_header.last().unwrap(); + d.index += h.offset() + h.key_length(); + let header = d.take_header()?; + d.move_to_end(header) + } +} + +trait SuperHeaderItem: fmt::Debug + Copy + Clone + Pod + Zeroable + Eq + PartialEq { + fn new(key: &str, offset: usize) -> Result; + + fn sort_order(&self, other: &Self) -> Ordering; + + fn offset(&self) -> usize; + + fn key_length(&self) -> usize; + + fn header_index_le_bytes(index: usize) -> impl AsRef<[u8]>; + + fn take_header_index(d: &mut Decoder) -> DecodeResult; +} + +/// `SuperHeader` Represents an item in the header +/// +/// # Warning +/// +/// **Member order matters here** since it decides the layout of the struct when serialized. +macro_rules! super_header_item { + ($name:ident, $int_type:ty, $int_size:literal) => { + #[derive(Debug, Copy, Clone, Pod, Zeroable, Eq, PartialEq)] + #[repr(C)] + struct $name { + key_length: $int_type, + key_hash: $int_type, + offset: $int_type, + } + + impl SuperHeaderItem for $name { + fn new(key: &str, offset: usize) -> Result { + Ok(Self { + key_length: <$int_type>::try_from(key.len())?, + // note we really do want key_hash to wrap around on the cast here! 
+ key_hash: djb2_hash(key) as $int_type, + offset: <$int_type>::try_from(offset)?, + }) + } + + fn sort_order(&self, other: &Self) -> Ordering { + match self.key_length.cmp(&other.key_length) { + Ordering::Equal => self.key_hash.cmp(&other.key_hash), + x => x, + } + } + + fn offset(&self) -> usize { + self.offset as usize + } + + fn key_length(&self) -> usize { + self.key_length as usize + } + + fn header_index_le_bytes(index: usize) -> impl AsRef<[u8]> { + let index_size = index as $int_type; + index_size.to_le_bytes() + } + + fn take_header_index(d: &mut Decoder) -> DecodeResult { + // same logic as `take_` + let slice = d.take_slice($int_size)?; + let v = <$int_type>::from_le_bytes(slice.try_into().unwrap()); + Ok(v as usize) + } + } + }; +} + +super_header_item!(SuperHeaderItem8, u8, 1); +super_header_item!(SuperHeaderItem16, u16, 2); +super_header_item!(SuperHeaderItem32, u32, 4); + +/// Search a sorted slice and return a sub-slice of values that match a given predicate. +fn binary_search<'b, S>( + haystack: &'b [S], + compare: impl Fn(&S) -> Ordering + 'b, +) -> Option> { + let mut low = 0; + let mut high = haystack.len(); + + // Perform binary search to find one occurrence of the value + loop { + let mid = low + (high - low) / 2; + match compare(&haystack[mid]) { + Ordering::Less => low = mid + 1, + Ordering::Greater => high = mid, + Ordering::Equal => { + // Finding the start of the sub-slice with the target value + let start = haystack[..mid] + .iter() + .rposition(|x| compare(x).is_ne()) + .map_or(0, |pos| pos + 1); + return Some(haystack[start..].iter().take_while(move |x| compare(x).is_eq())); + } + } + if low >= high { + return None; + } + } +} + +pub(crate) fn encode_object(encoder: &mut Encoder, object: &JsonObject) -> EncodeResult<()> { + if object.is_empty() { + // shortcut but also no alignment! 
+ return encoder.encode_length(Category::Object, 0); + } + + let items: Vec = object.iter_unique().collect(); + + let min_size = minimum_object_size_estimate(&items); + let encoder_position = encoder.position(); + if min_size <= u8::MAX as usize { + encoder.encode_length(Category::Object, items.len())?; + if encode_object_sized::(encoder, &items)? { + return Ok(()); + } + encoder.reset_position(encoder_position); + } + + if min_size <= u16::MAX as usize { + encoder.encode_len_u16(Category::Object, u16::try_from(items.len()).unwrap()); + if encode_object_sized::(encoder, &items)? { + return Ok(()); + } + encoder.reset_position(encoder_position); + } + + encoder.encode_len_u32(Category::Object, items.len())?; + if encode_object_sized::(encoder, &items)? { + Ok(()) + } else { + Err(EncodeError::ObjectTooLarge) + } +} + +type ObjectItems<'a, 'b> = (&'a Cow<'b, str>, &'a JsonValue<'b>); + +fn encode_object_sized(encoder: &mut Encoder, items: &[ObjectItems]) -> EncodeResult { + let mut super_header = Vec::with_capacity(items.len()); + encoder.align::(); + let super_header_start = encoder.ring_fence(items.len() * size_of::()); + + let offset_start = encoder.position(); + for (key, value) in items { + let key_str = key.as_ref(); + // add space for the header index, to be set correctly later + encoder.extend(S::header_index_le_bytes(0).as_ref()); + // push to the super header, with the position at this stage + let Ok(h) = S::new(key_str, encoder.position() - offset_start) else { + return Ok(false); + }; + super_header.push(h); + // now we've recorded the offset in the header, write the key and value to the encoder + encoder.extend(key_str.as_bytes()); + encoder.encode_value(value)?; + } + super_header.sort_by(S::sort_order); + + // iterate over the super header and set the header index for each item in the body + for (header_index, h) in super_header.iter().enumerate() { + let header_index_bytes = S::header_index_le_bytes(header_index); + let header_index_ref = 
header_index_bytes.as_ref(); + encoder.set_range(offset_start + h.offset() - header_index_ref.len(), header_index_ref); + } + encoder.set_range(super_header_start, bytemuck::cast_slice(&super_header)); + Ok(true) +} + +/// Estimate the minimize amount of space needed to encode the object. +/// +/// This is NOT recursive, instead it makes very optimistic guesses about how long arrays and objects might be. +fn minimum_object_size_estimate(items: &[ObjectItems]) -> usize { + let mut size = 0; + for (key, value) in items { + size += 1 + key.len(); // one byte header index and key + size += minimum_value_size_estimate(value); + } + size +} + +pub(crate) fn minimum_value_size_estimate(value: &JsonValue) -> usize { + match value { + // we could try harder for floats, but this is a good enough for now + JsonValue::Null | JsonValue::Bool(_) | JsonValue::Float(_) => 1, + JsonValue::Int(i) if (0..=10).contains(i) => 1, + // we could try harder here, but this is a good enough for now + JsonValue::Int(_) => 2, + JsonValue::BigInt(int) => 1 + int.to_bytes_le().1.len(), + JsonValue::Str(s) => 1 + s.len(), + JsonValue::Array(a) => 1 + a.len(), + JsonValue::Object(o) => 1 + o.len(), + } +} + +/// Very simple and fast hashing algorithm that nonetheless gives good distribution. +/// +/// See and +/// and for more information. 
+fn djb2_hash(s: &str) -> u32 { + let mut hash_value: u32 = 5381; + for i in s.bytes() { + // hash_value * 33 + char + hash_value = hash_value + .wrapping_shl(5) + .wrapping_add(hash_value) + .wrapping_add(u32::from(i)); + } + hash_value +} + +#[cfg(test)] +mod test { + use jiter::{JsonValue, LazyIndexMap}; + use smallvec::smallvec; + + use crate::header::Header; + use crate::{compare_json_values, encode_from_json}; + + use super::*; + + #[test] + fn super_header_sizes() { + assert_eq!(size_of::(), 3); + assert_eq!(size_of::(), 6); + assert_eq!(size_of::(), 12); + } + + #[test] + fn decode_get() { + let v = JsonValue::Object(Arc::new(LazyIndexMap::from(vec![ + ("aa".into(), JsonValue::Str("hello, world!".into())), + ("bat".into(), JsonValue::Int(42)), + ("c".into(), JsonValue::Bool(true)), + ]))); + let b = encode_from_json(&v).unwrap(); + let mut d = Decoder::new(&b); + let header = d.take_header().unwrap(); + assert_eq!(header, Header::Object(3.into())); + + let obj = Object::decode_header(&mut d, 3.into()).unwrap(); + + let obj_u8 = match obj.0 { + ObjectChoice::U8(ref o) => o, + _ => panic!("expected U8"), + }; + assert_eq!(obj_u8.len(), 3); + + let mut d2 = d.clone(); + assert!(obj.get(&mut d2, "aa").unwrap()); + assert_eq!(d2.take_value().unwrap(), JsonValue::Str("hello, world!".into())); + + let mut d3 = d.clone(); + assert!(obj.get(&mut d3, "bat").unwrap()); + assert_eq!(d3.take_value().unwrap(), JsonValue::Int(42)); + + let mut d4 = d.clone(); + assert!(obj.get(&mut d4, "c").unwrap()); + assert_eq!(d4.take_value().unwrap(), JsonValue::Bool(true)); + + assert!(!obj.get(&mut d, "x").unwrap()); + } + + #[test] + fn offsets() { + let v = JsonValue::Object(Arc::new(LazyIndexMap::from(vec![ + ("a".into(), JsonValue::Bool(true)), + ( + "bb".into(), + JsonValue::Object(Arc::new(LazyIndexMap::from(vec![("ccc".into(), JsonValue::Int(42))]))), + ), + ]))); + let b = encode_from_json(&v).unwrap(); + let mut d = Decoder::new(&b); + let header = 
d.take_header().unwrap(); + assert_eq!(header, Header::Object(2.into())); + + let obj = Object::decode_header(&mut d, 2.into()).unwrap(); + + let obj = match obj.0 { + ObjectChoice::U8(o) => o, + _ => panic!("expected U8"), + }; + + assert_eq!( + obj.super_header, + vec![ + SuperHeaderItem8 { + key_length: 1, + key_hash: 6, + offset: 1 + }, + SuperHeaderItem8 { + key_length: 2, + key_hash: 73, + offset: 4 + } + ] + ); + + assert!(obj.get(&mut d, "bb").unwrap()); + let header = d.take_header().unwrap(); + assert_eq!(header, Header::Object(1.into())); + + let obj = Object::decode_header(&mut d, 1.into()).unwrap(); + + let obj = match obj.0 { + ObjectChoice::U8(o) => o, + _ => panic!("expected U8"), + }; + + assert_eq!( + obj.super_header, + vec![SuperHeaderItem8 { + key_length: 3, + key_hash: 46, + // note the offset here is relative to the start of the object + offset: 1, + },] + ); + } + + #[test] + fn decode_empty() { + let v = JsonValue::Object(Arc::new(LazyIndexMap::default())); + let b = encode_from_json(&v).unwrap(); + assert_eq!(b.len(), 4); + let mut d = Decoder::new(&b); + let header = d.take_header().unwrap(); + assert_eq!(header, Header::Object(0.into())); + + let obj = Object::decode_header(&mut d, 0.into()).unwrap(); + let obj = match obj.0 { + ObjectChoice::U8(o) => o, + _ => panic!("expected U8"), + }; + assert_eq!(obj.len(), 0); + } + + #[test] + fn binary_search_direct() { + let slice = &["", "b", "ba", "fo", "spam"]; + let mut count = 0; + for i in binary_search(slice, |x| x.len().cmp(&1)).unwrap() { + assert_eq!(*i, "b"); + count += 1; + } + assert_eq!(count, 1); + } + + fn binary_search_vec(haystack: &[S], compare: impl Fn(&S) -> Ordering) -> Option> { + binary_search(haystack, compare).map(|i| i.cloned().collect()) + } + + #[test] + fn binary_search_ints() { + let slice = &[1, 2, 2, 2, 3, 4, 5, 6, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8]; + assert_eq!(binary_search_vec(slice, |x| x.cmp(&1)), Some(vec![1])); + assert_eq!(binary_search_vec(slice, |x| 
x.cmp(&2)), Some(vec![2, 2, 2])); + assert_eq!(binary_search_vec(slice, |x| x.cmp(&3)), Some(vec![3])); + assert_eq!(binary_search_vec(slice, |x| x.cmp(&4)), Some(vec![4])); + assert_eq!( + binary_search_vec(slice, |x| x.cmp(&7)), + Some(vec![7, 7, 7, 7, 7, 7, 7, 7]) + ); + assert_eq!(binary_search_vec(slice, |x| x.cmp(&8)), Some(vec![8, 8])); + assert_eq!(binary_search_vec(slice, |x| x.cmp(&12)), None); + } + + #[test] + fn binary_search_strings() { + let slice = &["", "b", "ba", "fo", "spam"]; + assert_eq!(binary_search_vec(slice, |x| x.len().cmp(&0)), Some(vec![""])); + assert_eq!(binary_search_vec(slice, |x| x.len().cmp(&1)), Some(vec!["b"])); + assert_eq!(binary_search_vec(slice, |x| x.len().cmp(&2)), Some(vec!["ba", "fo"])); + assert_eq!(binary_search_vec(slice, |x| x.len().cmp(&4)), Some(vec!["spam"])); + assert_eq!(binary_search_vec(slice, |x| x.len().cmp(&5)), None); + } + + #[test] + fn binary_search_take_while() { + // in valid input to test take_while isn't iterating further + let slice = &[1, 2, 2, 1, 3]; + assert_eq!(binary_search_vec(slice, |x| x.cmp(&1)), Some(vec![1])); + } + + #[test] + fn exceed_size() { + let array = JsonValue::Array(Arc::new(smallvec![JsonValue::Int(1_234); 100])); + let v = Arc::new(LazyIndexMap::from(vec![ + ( + "a".into(), + // 240 * i64 is longer than a u8 can encode + array.clone(), + ), + // need another key to encounter the error + ("b".into(), JsonValue::Null), + ])); + + // less than 255, so encode_from_json will try to encode with SuperHeaderItem8 + let items: Vec<_> = v.iter_unique().collect(); + assert_eq!(minimum_object_size_estimate(&items), 106); + let b = encode_from_json(&JsonValue::Object(v)).unwrap(); + + let mut d = Decoder::new(&b); + let header = d.take_header().unwrap(); + assert_eq!(header, Header::Object(Length::U16)); + + let obj = Object::decode_header(&mut d, Length::U16).unwrap(); + let obj = match obj.0 { + ObjectChoice::U16(o) => o, + _ => panic!("expected U16"), + }; + + assert_eq!(obj.len(), 2); 
+ + let mut d2 = d.clone(); + assert!(obj.get(&mut d2, "a").unwrap()); + assert!(compare_json_values(&d2.take_value().unwrap(), &array)); + + let mut d3 = d.clone(); + assert!(obj.get(&mut d3, "b").unwrap()); + assert_eq!(d3.take_value().unwrap(), JsonValue::Null); + } + + #[test] + fn test_u32() { + let long_string = "a".repeat(u16::MAX as usize); + let v = Arc::new(LazyIndexMap::from(vec![ + ("".into(), JsonValue::Str(long_string.clone().into())), + // need another key to encounter the error + ("£".into(), JsonValue::Null), + ])); + + let b = encode_from_json(&JsonValue::Object(v)).unwrap(); + + let mut decoder = Decoder::new(&b); + let header = decoder.take_header().unwrap(); + assert_eq!(header, Header::Object(Length::U32)); + + let obj = Object::decode_header(&mut decoder, Length::U32).unwrap(); + let obj = match obj.0 { + ObjectChoice::U32(o) => o, + _ => panic!("expected U32"), + }; + + assert_eq!(obj.len(), 2); + + let mut d = decoder.clone(); + assert!(obj.get(&mut d, "").unwrap()); + assert!(compare_json_values( + &d.take_value().unwrap(), + &JsonValue::Str(long_string.into()) + )); + + let mut d = decoder.clone(); + assert!(obj.get(&mut d, "£").unwrap()); + assert!(compare_json_values(&d.take_value().unwrap(), &JsonValue::Null)); + } +} diff --git a/crates/batson/tests/batson_example.bin b/crates/batson/tests/batson_example.bin new file mode 100644 index 00000000..544d93f7 Binary files /dev/null and b/crates/batson/tests/batson_example.bin differ diff --git a/crates/batson/tests/main.rs b/crates/batson/tests/main.rs new file mode 100644 index 00000000..de3d3c62 --- /dev/null +++ b/crates/batson/tests/main.rs @@ -0,0 +1,391 @@ +use std::fs::File; +use std::io::Read; +use std::path::PathBuf; +use std::sync::Arc; + +use jiter::{JsonValue, LazyIndexMap}; +use smallvec::smallvec; + +use batson::get::{contains, get_batson, get_bool, get_int, get_length, get_str}; +use batson::{batson_to_json_string, compare_json_values, decode_to_json_value, encode_from_json}; 
+ +#[test] +fn round_trip_all() { + let v: JsonValue<'static> = JsonValue::Object(Arc::new(LazyIndexMap::from(vec![ + // primitives + ("null".into(), JsonValue::Null), + ("false".into(), JsonValue::Bool(false)), + ("true".into(), JsonValue::Bool(true)), + // int + ("int-zero".into(), JsonValue::Int(0)), + ("int-in-header".into(), JsonValue::Int(9)), + ("int-8".into(), JsonValue::Int(123)), + ("int-32".into(), JsonValue::Int(1_000)), + ("int-64".into(), JsonValue::Int(i64::from(i32::MAX) + 5)), + ("int-max".into(), JsonValue::Int(i64::MAX)), + ("int-neg-in-header".into(), JsonValue::Int(-9)), + ("int-neg-8".into(), JsonValue::Int(-123)), + ("int-neg-32".into(), JsonValue::Int(-1_000)), + ("int-gex-64".into(), JsonValue::Int(-(i64::from(i32::MAX) + 5))), + ("int-min".into(), JsonValue::Int(i64::MIN)), + // floats + ("float-zero".into(), JsonValue::Float(0.0)), + ("float-in-header".into(), JsonValue::Float(9.0)), + ("float-pos".into(), JsonValue::Float(123.45)), + ("float-pos2".into(), JsonValue::Float(123_456_789.0)), + ("float-neg".into(), JsonValue::Float(-123.45)), + ("float-neg2".into(), JsonValue::Float(-123_456_789.0)), + // strings + ("str-empty".into(), JsonValue::Str("".into())), + ("str-short".into(), JsonValue::Str("foo".into())), + ("str-larger".into(), JsonValue::Str("foo bat spam".into())), + // het array + ( + "het-array".into(), + JsonValue::Array(Arc::new(smallvec![ + JsonValue::Int(42), + JsonValue::Str("foobar".into()), + JsonValue::Bool(true), + ])), + ), + // header array + ( + "header-array".into(), + JsonValue::Array(Arc::new(smallvec![JsonValue::Int(6), JsonValue::Bool(true),])), + ), + // i64 array + ( + "i64-array".into(), + JsonValue::Array(Arc::new(smallvec![JsonValue::Int(42), JsonValue::Int(i64::MAX),])), + ), + // u8 array + ( + "u8-array".into(), + JsonValue::Array(Arc::new(smallvec![JsonValue::Int(42), JsonValue::Int(255),])), + ), + ]))); + let b = encode_from_json(&v).unwrap(); + + let v2 = decode_to_json_value(&b).unwrap(); + 
assert!(compare_json_values(&v2, &v)); +} + +fn json_to_batson(json: &[u8]) -> Vec { + let json_value = JsonValue::parse(json, false).unwrap(); + encode_from_json(&json_value).unwrap() +} + +#[test] +fn test_get_bool() { + let bytes = json_to_batson(br#"{"foo": true}"#); + + assert!(get_bool(&bytes, &["foo".into()]).unwrap().unwrap()); + assert!(get_bool(&bytes, &["bar".into()]).unwrap().is_none()); +} + +#[test] +fn test_contains() { + let bytes = json_to_batson(br#"{"foo": true, "bar": [1, 2], "ham": "foo"}"#); + + assert!(contains(&bytes, &["foo".into()]).unwrap()); + assert!(contains(&bytes, &["bar".into()]).unwrap()); + assert!(contains(&bytes, &["ham".into()]).unwrap()); + assert!(contains(&bytes, &["bar".into(), 0.into()]).unwrap()); + assert!(contains(&bytes, &["bar".into(), 1.into()]).unwrap()); + + assert!(!contains(&bytes, &["spam".into()]).unwrap()); + assert!(!contains(&bytes, &["bar".into(), 2.into()]).unwrap()); + assert!(!contains(&bytes, &["ham".into(), 0.into()]).unwrap()); +} + +#[test] +fn test_contains_case() { + let bytes = json_to_batson(br#"{"host.id":0,"rpc.grpc.status_code":1,"rpc.grpc.status_code":4}"#); + + assert!(contains(&bytes, &["host.id".into()]).unwrap()); +} + +#[test] +fn test_get_str_object() { + let bytes = json_to_batson(br#"{"foo": "bar", "spam": true}"#); + + assert_eq!(get_str(&bytes, &["foo".into()]).unwrap().unwrap(), "bar"); + assert!(get_str(&bytes, &["bar".into()]).unwrap().is_none()); + assert!(get_str(&bytes, &["spam".into()]).unwrap().is_none()); +} + +#[test] +fn test_get_str_array() { + let bytes = json_to_batson(br#"["foo", 123, "bar"]"#); + + assert_eq!(get_str(&bytes, &[0.into()]).unwrap().unwrap(), "foo"); + assert_eq!(get_str(&bytes, &[2.into()]).unwrap().unwrap(), "bar"); + + assert!(get_str(&bytes, &["bar".into()]).unwrap().is_none()); + assert!(get_str(&bytes, &[3.into()]).unwrap().is_none()); +} + +#[test] +fn test_get_str_nested() { + let bytes = json_to_batson(br#"{"foo": [null, {"bar": "baz"}]}"#); + 
+ assert_eq!( + get_str(&bytes, &["foo".into(), 1.into(), "bar".into()]) + .unwrap() + .unwrap(), + "baz" + ); + + assert!(get_str(&bytes, &["foo".into()]).unwrap().is_none()); + assert!(get_str(&bytes, &["spam".into(), 1.into()]).unwrap().is_none()); + assert!(get_str(&bytes, &["spam".into(), 1.into(), "bar".into(), 6.into()]) + .unwrap() + .is_none()); +} + +#[test] +fn test_get_int_object() { + let bytes = json_to_batson(br#"{"foo": 42, "spam": true}"#); + + assert_eq!(get_int(&bytes, &["foo".into()]).unwrap().unwrap(), 42); + assert!(get_int(&bytes, &["bar".into()]).unwrap().is_none()); + assert!(get_int(&bytes, &["spam".into()]).unwrap().is_none()); +} + +#[test] +fn test_get_int_het_array() { + let bytes = json_to_batson(br#"[-42, "foo", 922337203685477580]"#); + + assert_eq!(get_int(&bytes, &[0.into()]).unwrap().unwrap(), -42); + assert_eq!(get_int(&bytes, &[2.into()]).unwrap().unwrap(), 922_337_203_685_477_580); + + assert!(get_int(&bytes, &["bar".into()]).unwrap().is_none()); + assert!(get_int(&bytes, &[3.into()]).unwrap().is_none()); +} + +#[test] +fn test_get_int_u8_array() { + let bytes = json_to_batson(br"[42, 123]"); + + assert_eq!(get_int(&bytes, &[0.into()]).unwrap().unwrap(), 42); + assert_eq!(get_int(&bytes, &[1.into()]).unwrap().unwrap(), 123); + + assert!(get_int(&bytes, &["bar".into()]).unwrap().is_none()); + assert!(get_int(&bytes, &[2.into()]).unwrap().is_none()); +} + +#[test] +fn test_get_int_i64_array() { + let bytes = json_to_batson(br"[-123, 922337203685477580]"); + + assert_eq!(get_int(&bytes, &[0.into()]).unwrap().unwrap(), -123); + assert_eq!(get_int(&bytes, &[1.into()]).unwrap().unwrap(), 922_337_203_685_477_580); + + assert!(get_int(&bytes, &["bar".into()]).unwrap().is_none()); + assert!(get_int(&bytes, &[2.into()]).unwrap().is_none()); +} + +#[test] +fn test_get_length() { + let bytes = json_to_batson(br#"{"foo": [null, {"a": 1, "b": 2}, 1]}"#); + + assert_eq!(get_length(&bytes, &[]).unwrap().unwrap(), 1); + 
assert_eq!(get_length(&bytes, &["foo".into()]).unwrap().unwrap(), 3); + assert_eq!(get_length(&bytes, &["foo".into(), 1.into()]).unwrap().unwrap(), 2); +} + +#[test] +fn test_get_batson() { + let bytes = json_to_batson(br#"{"foo": [null, {"a": 1, "b": 22}, 4294967299]}"#); + + assert_eq!(get_batson(&bytes, &[]).unwrap().unwrap(), bytes); + + let null_bytes = get_batson(&bytes, &["foo".into(), 0.into()]).unwrap().unwrap(); + assert_eq!(null_bytes, [0u8].as_ref()); + assert_eq!(batson_to_json_string(&null_bytes).unwrap(), "null"); + + let foo_bytes = get_batson(&bytes, &["foo".into()]).unwrap().unwrap(); + assert_eq!( + batson_to_json_string(&foo_bytes).unwrap(), + r#"[null,{"a":1,"b":22},4294967299]"# + ); + + let missing = get_batson(&bytes, &["bar".into()]).unwrap(); + assert!(missing.is_none()); + + let missing = get_batson(&bytes, &["foo".into(), "bar".into()]).unwrap(); + assert!(missing.is_none()); + + let obj_bytes = get_batson(&bytes, &["foo".into(), 1.into()]).unwrap().unwrap(); + assert_eq!(batson_to_json_string(&obj_bytes).unwrap(), r#"{"a":1,"b":22}"#); + + let a_bytes = get_batson(&bytes, &["foo".into(), 1.into(), "a".into()]) + .unwrap() + .unwrap(); + assert_eq!(batson_to_json_string(&a_bytes).unwrap(), "1"); + + let b_bytes = get_batson(&bytes, &["foo".into(), 1.into(), "b".into()]) + .unwrap() + .unwrap(); + assert_eq!(batson_to_json_string(&b_bytes).unwrap(), "22"); + + let int_bytes = get_batson(&bytes, &["foo".into(), 2.into()]).unwrap().unwrap(); + assert_eq!(batson_to_json_string(&int_bytes).unwrap(), "4294967299"); +} + +#[test] +fn test_get_batson_u8array() { + let bytes = json_to_batson(br#"[1, 2, 0, 255, 128]"#); + + // not last two bytes because of alignment + assert_eq!(get_batson(&bytes, &[]).unwrap().unwrap(), &bytes[..bytes.len() - 2]); + + let zeroth_bytes = get_batson(&bytes, &[0.into()]).unwrap().unwrap(); + assert_eq!(batson_to_json_string(&zeroth_bytes).unwrap(), "1"); + + let first_bytes = get_batson(&bytes, 
&[1.into()]).unwrap().unwrap(); + assert_eq!(batson_to_json_string(&first_bytes).unwrap(), "2"); + + let second_bytes = get_batson(&bytes, &[2.into()]).unwrap().unwrap(); + assert_eq!(batson_to_json_string(&second_bytes).unwrap(), "0"); + + let third_bytes = get_batson(&bytes, &[3.into()]).unwrap().unwrap(); + assert_eq!(batson_to_json_string(&third_bytes).unwrap(), "255"); + + let fourth_bytes = get_batson(&bytes, &[4.into()]).unwrap().unwrap(); + assert_eq!(batson_to_json_string(&fourth_bytes).unwrap(), "128"); + + let missing = get_batson(&bytes, &[5.into()]).unwrap(); + assert!(missing.is_none()); + + let missing = get_batson(&bytes, &[4.into(), 0.into()]).unwrap(); + assert!(missing.is_none()); +} + +#[test] +fn test_to_json() { + let bytes = json_to_batson(br" [true, 123] "); + let s = batson_to_json_string(&bytes).unwrap(); + assert_eq!(s, r"[true,123]"); +} + +fn json_round_trip(input_json: &str) { + let bytes = json_to_batson(input_json.as_bytes()); + let output_json = batson_to_json_string(&bytes).unwrap(); + assert_eq!(&output_json, input_json); +} + +macro_rules! json_round_trip_tests { + ($($name:ident => $json:literal;)*) => { + $( + paste::item! 
{ + #[test] + fn [< json_round_trip_ $name >]() { + json_round_trip($json); + } + } + )* + } +} + +json_round_trip_tests!( + array_empty => "[]"; + array_bool => "[true,false]"; + array_bool_int => "[true,123]"; + array_u8 => "[1,2,44,255]"; + array_i64 => "[-1,2,44,255,1234]"; + array_header => r#"[6,true,false,null,0,[],{},""]"#; + array_het => r#"[true,123,"foo",null]"#; + string_empty => r#""""#; + string_hello => r#""hello""#; + string_escape => r#""\"he\nllo\"""#; + string_unicode => r#"{"£":"🤪"}"#; + object_empty => r#"{}"#; + object_bool => r#"{"foo":true}"#; + object_two => r#"{"foo":1,"bar":2}"#; + object_three => r#"{"foo":1,"bar":2,"baz":3}"#; + object_int => r#"{"foo":123}"#; + object_string => r#"{"foo":"bar"}"#; + object_array => r#"{"foo":[1,2]}"#; + object_nested => r#"{"foo":{"bar":true}}"#; + object_nested_array => r#"{"foo":{"bar":[1,2]}}"#; + object_nested_array_nested => r#"{"foo":{"bar":[{"baz":true}]}}"#; + float_zero => "0.0"; + float_neg => "-123.45"; + float_pos => "123.456789"; + float_zero_to_10 => "[0.0,1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0]"; // header only + float_zero_to_12 => "[0.0,1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0,11.0,12.0]"; // het array + int_zero_to_10 => "[0,1,2,3,4,5,6,7,8,9,10]"; + int_zero_to_10_neg => "[0,1,2,3,4,5,6,7,8,9,10,-1,-2,-3,-4,-5,-6,-7,-8,-9,-10,-11,-12]"; + array_len_0_to_10 => "[[],[0],[0,1],[0,1,2],[0,1,2,3],[0,1,2,3,4],[0,1,2,3,4,5],[0,1,2,3,4,5,6],[0,1,2,3,4,5,6,7],[0,1,2,3,4,5,6,7,8],[0,1,2,3,4,5,6,7,8,9],[0,1,2,3,4,5,6,7,8,9,10]]"; + // bigger than i64::MAX (9223372036854775807) + big_int_pos => "92233720368547758070"; + // less than i64::MIN (-9223372036854775808) + big_int_neg => "-92233720368547758080"; + big_int_array => "[92233720368547758070,-92233720368547758080,1,2,3,-42]"; + big_int_object => r#"{"foo":92233720368547758070,"bar":-92233720368547758080}"#; +); + +#[test] +fn batson_file() { + // check the binary format doesn't change + let json = r#" + { + "header_only": [6, true, false, 
null, 0, [], {}, ""], + "u8_array": [0, 1, 2, 42, 255], + "i64_array": [-1, 2, 44, 255, 1234, 922337203685477], + "het_array": [true, 123, "foo", "£100", null], + "big_int": [92233720368547758070,-92233720368547758080], + "true": true, + "false": false, + "null": null, + "" : "empty-key" + } + "#; + let bytes = json_to_batson(json.as_bytes()); + + let s = batson_to_json_string(&bytes).unwrap(); + assert_eq!(s, json.replace(" ", "").replace("\n", "")); + + let file_path = "tests/batson_example.bin"; + // std::fs::write("tests/batson_example.bin", &bytes).unwrap(); + + // read the file and compare + let mut file = File::open(file_path).unwrap(); + let mut contents = Vec::new(); + file.read_to_end(&mut contents).unwrap(); + + assert_eq!(contents, bytes); + // dbg!(contents.len()); + // dbg!(json.replace(" ", "").replace("\n", "").len()); +} + +fn read_file(path: &PathBuf) -> Vec { + let mut file = File::open(path).unwrap(); + let mut contents = Vec::new(); + file.read_to_end(&mut contents).unwrap(); + contents +} + +/// Round trip test all the JSON files in the jiter benches directory +#[test] +fn round_trip_json_files() { + let dir = std::fs::read_dir("../jiter/benches").unwrap(); + for file in dir.map(|r| r.unwrap()) { + let path = file.path(); + if !path.extension().map(|e| e == "json").unwrap_or(false) { + continue; + } + println!("Testing: {path:?}"); + + let json = read_file(&path); + let value_from_json = JsonValue::parse(&json, false).unwrap(); + + let bytes = json_to_batson(&json); + let value_from_batson = decode_to_json_value(&bytes).unwrap(); + assert!( + compare_json_values(&value_from_json, &value_from_batson), + "Failed for {path:?}" + ); + } +} diff --git a/crates/fuzz/Cargo.toml b/crates/fuzz/Cargo.toml index cde9a3bc..da0d33e0 100644 --- a/crates/fuzz/Cargo.toml +++ b/crates/fuzz/Cargo.toml @@ -15,7 +15,8 @@ serde = "1.0.190" indexmap = "2.0.0" num-bigint = "0.4.4" num-traits = "0.2.17" -jiter = {path = "../jiter"} +jiter = {workspace = true} +batson 
= {workspace = true} [[bin]] name = "compare_to_serde" @@ -28,3 +29,9 @@ name = "compare_skip" path = "fuzz_targets/compare_skip.rs" test = false doc = false + +[[bin]] +name = "batson_round_trip" +path = "fuzz_targets/batson_round_trip.rs" +test = false +doc = false diff --git a/crates/fuzz/fuzz_targets/batson_round_trip.rs b/crates/fuzz/fuzz_targets/batson_round_trip.rs new file mode 100644 index 00000000..4dcce955 --- /dev/null +++ b/crates/fuzz/fuzz_targets/batson_round_trip.rs @@ -0,0 +1,24 @@ +#![no_main] + +use batson::{batson_to_json_string, encode_from_json}; +use jiter::JsonValue; + +use libfuzzer_sys::fuzz_target; + +fn round_trip(json: String) { + let Ok(jiter_value1) = JsonValue::parse(json.as_bytes(), false) else { + return; + }; + let bytes1 = encode_from_json(&jiter_value1).unwrap(); + let json1 = batson_to_json_string(&bytes1).unwrap(); + + let jiter_value2 = JsonValue::parse(json1.as_bytes(), false).unwrap(); + let bytes2 = encode_from_json(&jiter_value2).unwrap(); + let json2 = batson_to_json_string(&bytes2).unwrap(); + + assert_eq!(json1, json2); +} + +fuzz_target!(|json: String| { + round_trip(json); +}); diff --git a/crates/jiter-python/Cargo.toml b/crates/jiter-python/Cargo.toml index 71ed7c98..b38f4634 100644 --- a/crates/jiter-python/Cargo.toml +++ b/crates/jiter-python/Cargo.toml @@ -11,7 +11,7 @@ repository = {workspace = true} [dependencies] pyo3 = { workspace = true, features = ["num-bigint"] } -jiter = { path = "../jiter", features = ["python", "num-bigint"] } +jiter = { workspace = true, features = ["python", "num-bigint"] } [features] # must be enabled when building with `cargo build`, maturin enables this automatically diff --git a/crates/jiter/Cargo.toml b/crates/jiter/Cargo.toml index 93c1dc2e..a3acd730 100644 --- a/crates/jiter/Cargo.toml +++ b/crates/jiter/Cargo.toml @@ -12,12 +12,12 @@ homepage = { workspace = true } repository = { workspace = true } [dependencies] -num-bigint = { version = "0.4.4", optional = true } 
-num-traits = "0.2.16" +num-bigint = { workspace = true, optional = true } +num-traits = { workspace = true } ahash = "0.8.0" -smallvec = "1.11.0" -pyo3 = { workspace = true, optional = true } -lexical-parse-float = { version = "0.8.5", features = ["format"] } +smallvec = { workspace = true } +pyo3 = { workspace = true, optional = true, features = ["num-bigint"] } +lexical-parse-float = { version = "0.8.5", features = ["format"] } bitvec = "1.0.1" [features] @@ -26,16 +26,12 @@ python = ["dep:pyo3", "dep:pyo3-build-config"] num-bigint = ["dep:num-bigint", "pyo3/num-bigint"] [dev-dependencies] -bencher = "0.1.5" -paste = "1.0.7" -serde_json = { version = "1.0.87", features = [ - "preserve_order", - "arbitrary_precision", - "float_roundtrip", -] } -serde = "1.0.147" +bencher = { workspace = true } +paste = { workspace = true } +codspeed-bencher-compat = { workspace = true } +serde_json = { workspace = true, features = ["preserve_order", "arbitrary_precision", "float_roundtrip"]} +serde = { workspace = true } pyo3 = { workspace = true, features = ["auto-initialize"] } -codspeed-bencher-compat = "2.7.1" [build-dependencies] pyo3-build-config = { workspace = true, optional = true } diff --git a/crates/jiter/src/lazy_index_map.rs b/crates/jiter/src/lazy_index_map.rs index b57bdcb0..dc032239 100644 --- a/crates/jiter/src/lazy_index_map.rs +++ b/crates/jiter/src/lazy_index_map.rs @@ -10,7 +10,7 @@ use smallvec::SmallVec; /// Like [IndexMap](https://docs.rs/indexmap/latest/indexmap/) but only builds the lookup map when it's needed. 
pub struct LazyIndexMap<K, V> { - vec: SmallVec<[(K, V); 8]>, + vec: SmallVec<(K, V), 8>, map: OnceLock<AHashMap<K, usize>>, last_find: AtomicUsize, } @@ -149,8 +149,29 @@ impl<K: PartialEq, V: PartialEq> PartialEq for LazyIndexMap<K, V> { } } +impl<K, V> FromIterator<(K, V)> for LazyIndexMap<K, V> { + fn from_iter<T: IntoIterator<Item = (K, V)>>(iter: T) -> Self { + let vec = iter.into_iter().collect(); + Self { + vec, + map: OnceLock::new(), + last_find: AtomicUsize::new(0), + } + } +} + +impl<K, V> From<Vec<(K, V)>> for LazyIndexMap<K, V> { + fn from(vec: Vec<(K, V)>) -> Self { + Self { + vec: vec.into(), + map: OnceLock::new(), + last_find: AtomicUsize::new(0), + } + } +} + struct IterUnique<'a, K, V> { - vec: &'a SmallVec<[(K, V); 8]>, + vec: &'a SmallVec<(K, V), 8>, map: &'a AHashMap<K, usize>, index: usize, } @@ -172,3 +193,32 @@ impl<'a, K: Hash + Eq, V> Iterator for IterUnique<'a, K, V> { } None } } + +#[cfg(test)] +mod test { + use super::*; + use crate::JsonValue; + + #[test] + fn test_lazy_index_map_collect() { + let pairs: Vec<(Cow<'static, str>, JsonValue<'static>)> = vec![ + ("foo".into(), JsonValue::Null), + ("bar".into(), JsonValue::Bool(true)), + ("spam".into(), JsonValue::Int(123)), + ]; + let map: LazyIndexMap<_, _> = pairs.into_iter().collect(); + assert_eq!(map.len(), 3); + assert_eq!(map.keys().collect::<Vec<_>>(), ["foo", "bar", "spam"]); + assert_eq!(map.get("bar"), Some(&JsonValue::Bool(true))); + } + + #[test] + fn test_lazy_index_map_from_vec() { + let pairs: Vec<(Cow<'static, str>, JsonValue<'static>)> = + vec![("foo".into(), JsonValue::Null), ("bar".into(), JsonValue::Bool(true))]; + let map: LazyIndexMap<_, _> = pairs.into(); + assert_eq!(map.len(), 2); + assert_eq!(map.keys().collect::<Vec<_>>(), ["foo", "bar"]); + assert_eq!(map.get("bar"), Some(&JsonValue::Bool(true))); + } +} diff --git a/crates/jiter/src/python.rs b/crates/jiter/src/python.rs index b811d1c3..6cf7d95b 100644 --- a/crates/jiter/src/python.rs +++ b/crates/jiter/src/python.rs @@ -148,7 +148,7 @@ impl<'j, StringCache: StringMaybeCache, KeyCheck: MaybeKeyCheck, ParseNumber: Ma Ok(None) | Err(_) => return
Ok(PyList::empty_bound(py).into_any()), }; - let mut vec: SmallVec<[Bound<'_, PyAny>; 8]> = SmallVec::with_capacity(8); + let mut vec: SmallVec<Bound<'_, PyAny>, 8> = SmallVec::with_capacity(8); if let Err(e) = self._parse_array(py, peek_first, &mut vec) { if !self._allow_partial_err(&e) { return Err(e); @@ -174,7 +174,7 @@ impl<'j, StringCache: StringMaybeCache, KeyCheck: MaybeKeyCheck, ParseNumber: Ma &mut self, py: Python<'py>, peek_first: Peek, - vec: &mut SmallVec<[Bound<'py, PyAny>; 8]>, + vec: &mut SmallVec<Bound<'py, PyAny>, 8>, ) -> JsonResult<()> { let v = self._check_take_value(py, peek_first)?; vec.push(v); diff --git a/crates/jiter/src/value.rs b/crates/jiter/src/value.rs index 2a13bcaf..e63ee4b6 100644 --- a/crates/jiter/src/value.rs +++ b/crates/jiter/src/value.rs @@ -25,7 +25,7 @@ pub enum JsonValue<'s> { Object(JsonObject<'s>), } -pub type JsonArray<'s> = Arc<SmallVec<[JsonValue<'s>; 8]>>; +pub type JsonArray<'s> = Arc<SmallVec<JsonValue<'s>, 8>>; pub type JsonObject<'s> = Arc<LazyIndexMap<Cow<'s, str>, JsonValue<'s>>>; #[cfg(feature = "python")] @@ -85,7 +85,9 @@ fn value_static(v: JsonValue<'_>) -> JsonValue<'static> { JsonValue::BigInt(b) => JsonValue::BigInt(b), JsonValue::Float(f) => JsonValue::Float(f), JsonValue::Str(s) => JsonValue::Str(s.into_owned().into()), - JsonValue::Array(v) => JsonValue::Array(Arc::new(v.iter().map(JsonValue::to_static).collect::<SmallVec<_>>())), + JsonValue::Array(v) => { + JsonValue::Array(Arc::new(v.iter().map(JsonValue::to_static).collect::<SmallVec<_, 8>>())) + } JsonValue::Object(o) => JsonValue::Object(Arc::new(o.to_static())), } } @@ -240,7 +242,7 @@ fn take_value_recursive<'j, 's>( ) -> JsonResult<JsonValue<'j>> { let recursion_limit: usize = recursion_limit.into(); - let mut recursion_stack: SmallVec<[RecursedValue; 8]> = SmallVec::new(); + let mut recursion_stack: SmallVec<RecursedValue, 8> = SmallVec::new(); macro_rules! push_recursion { ($next_peek:expr, $value:expr) => {