diff --git a/parquet-variant/src/builder.rs b/parquet-variant/src/builder.rs index b81cacc4bc40..6cde4ce91125 100644 --- a/parquet-variant/src/builder.rs +++ b/parquet-variant/src/builder.rs @@ -15,11 +15,10 @@ // specific language governing permissions and limitations // under the License. use crate::decoder::{VariantBasicType, VariantPrimitiveType}; -use crate::Variant; +use crate::{ShortString, Variant}; use std::collections::HashMap; const BASIC_TYPE_BITS: u8 = 2; -const MAX_SHORT_STRING_SIZE: usize = 0x3F; const UNIX_EPOCH_DATE: chrono::NaiveDate = chrono::NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); fn primitive_header(primitive_type: VariantPrimitiveType) -> u8 { @@ -114,11 +113,11 @@ fn make_room_for_header(buffer: &mut Vec, start_pos: usize, header_size: usi /// }; /// assert_eq!( /// variant_object.field_by_name("first_name").unwrap(), -/// Some(Variant::ShortString("Jiaying")) +/// Some(Variant::from("Jiaying")) /// ); /// assert_eq!( /// variant_object.field_by_name("last_name").unwrap(), -/// Some(Variant::ShortString("Li")) +/// Some(Variant::from("Li")) /// ); /// ``` /// @@ -281,17 +280,18 @@ impl VariantBuilder { self.buffer.extend_from_slice(value); } + fn append_short_string(&mut self, value: ShortString) { + let inner = value.0; + self.buffer.push(short_string_header(inner.len())); + self.buffer.extend_from_slice(inner.as_bytes()); + } + fn append_string(&mut self, value: &str) { - if value.len() <= MAX_SHORT_STRING_SIZE { - self.buffer.push(short_string_header(value.len())); - self.buffer.extend_from_slice(value.as_bytes()); - } else { - self.buffer - .push(primitive_header(VariantPrimitiveType::String)); - self.buffer - .extend_from_slice(&(value.len() as u32).to_le_bytes()); - self.buffer.extend_from_slice(value.as_bytes()); - } + self.buffer + .push(primitive_header(VariantPrimitiveType::String)); + self.buffer + .extend_from_slice(&(value.len() as u32).to_le_bytes()); + self.buffer.extend_from_slice(value.as_bytes()); } /// Add key to dictionary, return its ID @@ -390,7 +390,8 @@ impl VariantBuilder { Variant::Float(v) => self.append_float(v), Variant::Double(v) => self.append_double(v), Variant::Binary(v) => self.append_binary(v), - Variant::String(s) | Variant::ShortString(s) => self.append_string(s), + Variant::String(s) => self.append_string(s), + Variant::ShortString(s) => self.append_short_string(s), Variant::Object(_) | Variant::List(_) => { unreachable!("Object and List variants cannot be created through Into") } @@ -639,7 +640,7 @@ mod tests { builder.append_value("hello"); let (metadata, value) = builder.finish(); let variant = Variant::try_new(&metadata, &value).unwrap(); - assert_eq!(variant, Variant::ShortString("hello")); + assert_eq!(variant, Variant::ShortString(ShortString("hello"))); } { @@ -688,7 +689,7 @@ mod tests { assert_eq!(val1, Variant::Int8(2)); let val2 = list.get(2).unwrap(); - assert_eq!(val2, Variant::ShortString("test")); + assert_eq!(val2, Variant::ShortString(ShortString("test"))); } _ => panic!("Expected an array variant, got: {:?}", variant), } diff --git a/parquet-variant/src/decoder.rs b/parquet-variant/src/decoder.rs index 7fb41c7da202..7096b0a08631 100644 --- a/parquet-variant/src/decoder.rs +++ b/parquet-variant/src/decoder.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. use crate::utils::{array_from_slice, slice_from_slice, string_from_slice}; +use crate::ShortString; use arrow_schema::ArrowError; use chrono::{DateTime, Duration, NaiveDate, NaiveDateTime, Utc}; @@ -273,10 +274,10 @@ pub(crate) fn decode_long_string(data: &[u8]) -> Result<&str, ArrowError> { } /// Decodes a short string from the value section of a variant. -pub(crate) fn decode_short_string(metadata: u8, data: &[u8]) -> Result<&str, ArrowError> { +pub(crate) fn decode_short_string(metadata: u8, data: &[u8]) -> Result { let len = (metadata >> 2) as usize; let string = string_from_slice(data, 0..len)?; - Ok(string) + ShortString::try_new(string) } #[cfg(test)] @@ -420,7 +421,7 @@ mod tests { fn test_short_string() -> Result<(), ArrowError> { let data = [b'H', b'e', b'l', b'l', b'o', b'o']; let result = decode_short_string(1 | 5 << 2, &data)?; - assert_eq!(result, "Hello"); + assert_eq!(result.0, "Hello"); Ok(()) } diff --git a/parquet-variant/src/variant.rs b/parquet-variant/src/variant.rs index 843fe2048c72..2e042b6074cb 100644 --- a/parquet-variant/src/variant.rs +++ b/parquet-variant/src/variant.rs @@ -1,3 +1,5 @@ +use std::ops::Deref; + // Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information @@ -29,6 +31,65 @@ mod list; mod metadata; mod object; +const MAX_SHORT_STRING_BYTES: usize = 0x3F; + +/// A Variant [`ShortString`] +/// +/// This implementation is a zero cost wrapper over `&str` that ensures +/// the length of the underlying string is a valid Variant short string (63 bytes or less) +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct ShortString<'a>(pub(crate) &'a str); + +impl<'a> ShortString<'a> { + /// Attempts to interpret `value` as a variant short string value. + /// + /// # Validation + /// + /// This constructor verifies that `value` is shorter than or equal to `MAX_SHORT_STRING_BYTES` + pub fn try_new(value: &'a str) -> Result { + if value.len() > MAX_SHORT_STRING_BYTES { + return Err(ArrowError::InvalidArgumentError(format!( + "value is larger than {MAX_SHORT_STRING_BYTES} bytes" + ))); + } + + Ok(Self(value)) + } + + /// Returns the underlying Variant short string as a &str + pub fn as_str(&self) -> &'a str { + self.0 + } +} + +impl<'a> From> for &'a str { + fn from(value: ShortString<'a>) -> Self { + value.0 + } +} + +impl<'a> TryFrom<&'a str> for ShortString<'a> { + type Error = ArrowError; + + fn try_from(value: &'a str) -> Result { + Self::try_new(value) + } +} + +impl<'a> AsRef for ShortString<'a> { + fn as_ref(&self) -> &str { + self.0 + } +} + +impl<'a> Deref for ShortString<'a> { + type Target = str; + + fn deref(&self) -> &Self::Target { + self.0 + } +} + /// Represents a [Parquet Variant] /// /// The lifetimes `'m` and `'v` are for metadata and value buffers, respectively. @@ -85,7 +146,7 @@ mod object; /// /// ## Creating `Variant` from Rust Types /// ``` -/// # use parquet_variant::Variant; +/// use parquet_variant::Variant; /// // variants can be directly constructed /// let variant = Variant::Int32(123); /// // or constructed via `From` impls @@ -98,7 +159,7 @@ mod object; /// let value = [0x09, 0x48, 0x49]; /// // parse the header metadata /// assert_eq!( -/// Variant::ShortString("HI"), +/// Variant::from("HI"), /// Variant::try_new(&metadata, &value).unwrap() /// ); /// ``` @@ -152,7 +213,7 @@ pub enum Variant<'m, 'v> { /// Primitive (type_id=1): STRING String(&'v str), /// Short String (type_id=2): STRING - ShortString(&'v str), + ShortString(ShortString<'v>), // need both metadata & value /// Object (type_id=3): N/A Object(VariantObject<'m, 'v>), @@ -165,12 +226,12 @@ impl<'m, 'v> Variant<'m, 'v> { /// /// # Example /// ``` - /// # use parquet_variant::{Variant, VariantMetadata}; + /// use parquet_variant::{Variant, VariantMetadata}; /// let metadata = [0x01, 0x00, 0x00]; /// let value = [0x09, 0x48, 0x49]; /// // parse the header metadata /// assert_eq!( - /// Variant::ShortString("HI"), + /// Variant::from("HI"), /// Variant::try_new(&metadata, &value).unwrap() /// ); /// ``` @@ -189,7 +250,7 @@ impl<'m, 'v> Variant<'m, 'v> { /// // parse the header metadata first /// let metadata = VariantMetadata::try_new(&metadata).unwrap(); /// assert_eq!( - /// Variant::ShortString("HI"), + /// Variant::from("HI"), /// Variant::try_new_with_metadata(metadata, &value).unwrap() /// ); /// ``` @@ -432,7 +493,7 @@ impl<'m, 'v> Variant<'m, 'v> { /// /// // you can extract a string from string variants /// let s = "hello!"; - /// let v1 = Variant::ShortString(s); + /// let v1 = Variant::from(s); /// assert_eq!(v1.as_string(), Some(s)); /// /// // but not from other variants @@ -441,7 +502,7 @@ impl<'m, 'v> Variant<'m, 'v> { /// ``` pub fn as_string(&'v self) -> Option<&'v str> { match self { - Variant::String(s) | Variant::ShortString(s) => Some(s), + Variant::String(s) | Variant::ShortString(ShortString(s)) => Some(s), _ => None, } } @@ -861,10 +922,25 @@ impl<'v> From<&'v [u8]> for Variant<'_, 'v> { impl<'v> From<&'v str> for Variant<'_, 'v> { fn from(value: &'v str) -> Self { - if value.len() < 64 { - Variant::ShortString(value) - } else { + if value.len() > MAX_SHORT_STRING_BYTES { Variant::String(value) + } else { + Variant::ShortString(ShortString(value)) } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_construct_short_string() { + let short_string = ShortString::try_new("norm").expect("should fit in short string"); + assert_eq!(short_string.as_str(), "norm"); + + let long_string = "a".repeat(MAX_SHORT_STRING_BYTES + 1); + let res = ShortString::try_new(&long_string); + assert!(res.is_err()); + } +} diff --git a/parquet-variant/tests/variant_interop.rs b/parquet-variant/tests/variant_interop.rs index 82766a8fbea8..bfa2ab267c27 100644 --- a/parquet-variant/tests/variant_interop.rs +++ b/parquet-variant/tests/variant_interop.rs @@ -24,7 +24,7 @@ use std::fs; use std::path::{Path, PathBuf}; use chrono::NaiveDate; -use parquet_variant::{Variant, VariantBuilder}; +use parquet_variant::{ShortString, Variant, VariantBuilder}; fn cases_dir() -> PathBuf { Path::new(env!("CARGO_MANIFEST_DIR")) @@ -76,7 +76,7 @@ fn get_primitive_cases() -> Vec<(&'static str, Variant<'static, 'static>)> { ("primitive_string", Variant::String("This string is longer than 64 bytes and therefore does not fit in a short_string and it also includes several non ascii characters such as 🐢, 💖, ♥\u{fe0f}, 🎣 and 🤦!!")), ("primitive_timestamp", Variant::TimestampMicros(NaiveDate::from_ymd_opt(2025, 4, 16).unwrap().and_hms_milli_opt(16, 34, 56, 780).unwrap().and_utc())), ("primitive_timestampntz", Variant::TimestampNtzMicros(NaiveDate::from_ymd_opt(2025, 4, 16).unwrap().and_hms_milli_opt(12, 34, 56, 780).unwrap())), - ("short_string", Variant::ShortString("Less than 64 bytes (❤\u{fe0f} with utf8)")), + ("short_string", Variant::ShortString(ShortString::try_new("Less than 64 bytes (❤\u{fe0f} with utf8)").unwrap())), ] } #[test] @@ -130,11 +130,20 @@ fn variant_object_primitive() { ), ("int_field", Variant::Int8(1)), ("null_field", Variant::Null), - ("string_field", Variant::ShortString("Apache Parquet")), + ( + "string_field", + Variant::ShortString( + ShortString::try_new("Apache Parquet") + .expect("value should fit inside a short string"), + ), + ), ( // apparently spark wrote this as a string (not a timestamp) "timestamp_field", - Variant::ShortString("2025-04-16T12:34:56.78"), + Variant::ShortString( + ShortString::try_new("2025-04-16T12:34:56.78") + .expect("value should fit inside a short string"), + ), ), ]; let actual_fields: Vec<_> = variant_object.iter().collect();