From 32bc99cbccbcfe76c803aeff2be5698cdb9d85ac Mon Sep 17 00:00:00 2001 From: Ryan Johnson Date: Thu, 2 Oct 2025 14:44:46 -0700 Subject: [PATCH 1/8] [Variant] Define and use VariantDecimalType trait --- .../src/arrow_to_variant.rs | 139 ++++---- .../src/type_conversion.rs | 19 - .../src/unshred_variant.rs | 125 +++---- parquet-variant-compute/src/variant_array.rs | 36 +- parquet-variant/src/variant.rs | 6 +- parquet-variant/src/variant/decimal.rs | 329 ++++++++++++------ 6 files changed, 333 insertions(+), 321 deletions(-) diff --git a/parquet-variant-compute/src/arrow_to_variant.rs b/parquet-variant-compute/src/arrow_to_variant.rs index fe0c52109052..72618533c57e 100644 --- a/parquet-variant-compute/src/arrow_to_variant.rs +++ b/parquet-variant-compute/src/arrow_to_variant.rs @@ -15,27 +15,25 @@ // specific language governing permissions and limitations // under the License. -use crate::type_conversion::{CastOptions, decimal_to_variant_decimal}; +use crate::type_conversion::CastOptions; use arrow::array::{ Array, AsArray, FixedSizeListArray, GenericBinaryArray, GenericListArray, GenericListViewArray, GenericStringArray, OffsetSizeTrait, PrimitiveArray, }; use arrow::compute::kernels::cast; use arrow::datatypes::{ - ArrowNativeType, ArrowPrimitiveType, ArrowTemporalType, ArrowTimestampType, Date32Type, - Date64Type, Float16Type, Float32Type, Float64Type, Int8Type, Int16Type, Int32Type, Int64Type, - RunEndIndexType, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, - Time64NanosecondType, TimestampMicrosecondType, TimestampMillisecondType, - TimestampNanosecondType, TimestampSecondType, UInt8Type, UInt16Type, UInt32Type, UInt64Type, + self as datatypes, ArrowNativeType, ArrowPrimitiveType, ArrowTemporalType, ArrowTimestampType, + DecimalType, RunEndIndexType, }; use arrow::temporal_conversions::{as_date, as_datetime, as_time}; use arrow_schema::{ArrowError, DataType, TimeUnit}; use chrono::{DateTime, TimeZone, Utc}; use parquet_variant::{ ObjectFieldBuilder, Variant, VariantBuilderExt, VariantDecimal4, VariantDecimal8, - VariantDecimal16, + VariantDecimal16, VariantDecimalType, }; use std::collections::HashMap; +use std::marker::PhantomData; use std::ops::Range; // ============================================================================ @@ -46,31 +44,31 @@ use std::ops::Range; pub(crate) enum ArrowToVariantRowBuilder<'a> { Null(NullArrowToVariantBuilder), Boolean(BooleanArrowToVariantBuilder<'a>), - PrimitiveInt8(PrimitiveArrowToVariantBuilder<'a, Int8Type>), - PrimitiveInt16(PrimitiveArrowToVariantBuilder<'a, Int16Type>), - PrimitiveInt32(PrimitiveArrowToVariantBuilder<'a, Int32Type>), - PrimitiveInt64(PrimitiveArrowToVariantBuilder<'a, Int64Type>), - PrimitiveUInt8(PrimitiveArrowToVariantBuilder<'a, UInt8Type>), - PrimitiveUInt16(PrimitiveArrowToVariantBuilder<'a, UInt16Type>), - PrimitiveUInt32(PrimitiveArrowToVariantBuilder<'a, UInt32Type>), - PrimitiveUInt64(PrimitiveArrowToVariantBuilder<'a, UInt64Type>), - PrimitiveFloat16(PrimitiveArrowToVariantBuilder<'a, Float16Type>), - PrimitiveFloat32(PrimitiveArrowToVariantBuilder<'a, Float32Type>), - PrimitiveFloat64(PrimitiveArrowToVariantBuilder<'a, Float64Type>), - Decimal32(Decimal32ArrowToVariantBuilder<'a>), - Decimal64(Decimal64ArrowToVariantBuilder<'a>), - Decimal128(Decimal128ArrowToVariantBuilder<'a>), + PrimitiveInt8(PrimitiveArrowToVariantBuilder<'a, datatypes::Int8Type>), + PrimitiveInt16(PrimitiveArrowToVariantBuilder<'a, datatypes::Int16Type>), + PrimitiveInt32(PrimitiveArrowToVariantBuilder<'a, datatypes::Int32Type>), + PrimitiveInt64(PrimitiveArrowToVariantBuilder<'a, datatypes::Int64Type>), + PrimitiveUInt8(PrimitiveArrowToVariantBuilder<'a, datatypes::UInt8Type>), + PrimitiveUInt16(PrimitiveArrowToVariantBuilder<'a, datatypes::UInt16Type>), + PrimitiveUInt32(PrimitiveArrowToVariantBuilder<'a, datatypes::UInt32Type>), + PrimitiveUInt64(PrimitiveArrowToVariantBuilder<'a, datatypes::UInt64Type>), + PrimitiveFloat16(PrimitiveArrowToVariantBuilder<'a, datatypes::Float16Type>), + PrimitiveFloat32(PrimitiveArrowToVariantBuilder<'a, datatypes::Float32Type>), + PrimitiveFloat64(PrimitiveArrowToVariantBuilder<'a, datatypes::Float64Type>), + Decimal32(DecimalArrowToVariantBuilder<'a, datatypes::Decimal32Type, VariantDecimal4>), + Decimal64(DecimalArrowToVariantBuilder<'a, datatypes::Decimal64Type, VariantDecimal8>), + Decimal128(DecimalArrowToVariantBuilder<'a, datatypes::Decimal128Type, VariantDecimal16>), Decimal256(Decimal256ArrowToVariantBuilder<'a>), - TimestampSecond(TimestampArrowToVariantBuilder<'a, TimestampSecondType>), - TimestampMillisecond(TimestampArrowToVariantBuilder<'a, TimestampMillisecondType>), - TimestampMicrosecond(TimestampArrowToVariantBuilder<'a, TimestampMicrosecondType>), - TimestampNanosecond(TimestampArrowToVariantBuilder<'a, TimestampNanosecondType>), - Date32(DateArrowToVariantBuilder<'a, Date32Type>), - Date64(DateArrowToVariantBuilder<'a, Date64Type>), - Time32Second(TimeArrowToVariantBuilder<'a, Time32SecondType>), - Time32Millisecond(TimeArrowToVariantBuilder<'a, Time32MillisecondType>), - Time64Microsecond(TimeArrowToVariantBuilder<'a, Time64MicrosecondType>), - Time64Nanosecond(TimeArrowToVariantBuilder<'a, Time64NanosecondType>), + TimestampSecond(TimestampArrowToVariantBuilder<'a, datatypes::TimestampSecondType>), + TimestampMillisecond(TimestampArrowToVariantBuilder<'a, datatypes::TimestampMillisecondType>), + TimestampMicrosecond(TimestampArrowToVariantBuilder<'a, datatypes::TimestampMicrosecondType>), + TimestampNanosecond(TimestampArrowToVariantBuilder<'a, datatypes::TimestampNanosecondType>), + Date32(DateArrowToVariantBuilder<'a, datatypes::Date32Type>), + Date64(DateArrowToVariantBuilder<'a, datatypes::Date64Type>), + Time32Second(TimeArrowToVariantBuilder<'a, datatypes::Time32SecondType>), + Time32Millisecond(TimeArrowToVariantBuilder<'a, datatypes::Time32MillisecondType>), + Time64Microsecond(TimeArrowToVariantBuilder<'a, datatypes::Time64MicrosecondType>), + Time64Nanosecond(TimeArrowToVariantBuilder<'a, datatypes::Time64NanosecondType>), Binary(BinaryArrowToVariantBuilder<'a, i32>), LargeBinary(BinaryArrowToVariantBuilder<'a, i64>), BinaryView(BinaryViewArrowToVariantBuilder<'a>), @@ -87,9 +85,9 @@ pub(crate) enum ArrowToVariantRowBuilder<'a> { Map(MapArrowToVariantBuilder<'a>), Union(UnionArrowToVariantBuilder<'a>), Dictionary(DictionaryArrowToVariantBuilder<'a>), - RunEndEncodedInt16(RunEndEncodedArrowToVariantBuilder<'a, Int16Type>), - RunEndEncodedInt32(RunEndEncodedArrowToVariantBuilder<'a, Int32Type>), - RunEndEncodedInt64(RunEndEncodedArrowToVariantBuilder<'a, Int64Type>), + RunEndEncodedInt16(RunEndEncodedArrowToVariantBuilder<'a, datatypes::Int16Type>), + RunEndEncodedInt32(RunEndEncodedArrowToVariantBuilder<'a, datatypes::Int32Type>), + RunEndEncodedInt64(RunEndEncodedArrowToVariantBuilder<'a, datatypes::Int64Type>), } impl<'a> ArrowToVariantRowBuilder<'a> { @@ -173,17 +171,26 @@ pub(crate) fn make_arrow_to_variant_row_builder<'a>( DataType::Float16 => PrimitiveFloat16(PrimitiveArrowToVariantBuilder::new(array)), DataType::Float32 => PrimitiveFloat32(PrimitiveArrowToVariantBuilder::new(array)), DataType::Float64 => PrimitiveFloat64(PrimitiveArrowToVariantBuilder::new(array)), - DataType::Decimal32(_, scale) => { - Decimal32(Decimal32ArrowToVariantBuilder::new(array, options, *scale)) - } - DataType::Decimal64(_, scale) => { - Decimal64(Decimal64ArrowToVariantBuilder::new(array, options, *scale)) - } - DataType::Decimal128(_, scale) => { - Decimal128(Decimal128ArrowToVariantBuilder::new(array, options, *scale)) - } - DataType::Decimal256(_, scale) => { - Decimal256(Decimal256ArrowToVariantBuilder::new(array, options, *scale)) + DataType::Decimal32(_, s) => Decimal32(DecimalArrowToVariantBuilder::new( + array, + options, + *s, + PhantomData, + )), + DataType::Decimal64(_, s) => Decimal64(DecimalArrowToVariantBuilder::new( + array, + options, + *s, + PhantomData, + )), + DataType::Decimal128(_, s) => Decimal128(DecimalArrowToVariantBuilder::new( + array, + options, + *s, + PhantomData, + )), + DataType::Decimal256(_, s) => { + Decimal256(Decimal256ArrowToVariantBuilder::new(array, options, *s)) } DataType::Timestamp(time_unit, time_zone) => { match time_unit { @@ -320,20 +327,20 @@ pub(crate) fn make_arrow_to_variant_row_builder<'a>( // worth the trouble, tho, because it makes for some pretty bulky and unwieldy macro expansions. macro_rules! define_row_builder { ( - struct $name:ident<$lifetime:lifetime $(, $generic:ident: $bound:path )?> + struct $name:ident<$lifetime:lifetime $(, $generic:ident: $bound:path )*> $( where $where_path:path: $where_bound:path $(,)? )? $({ $($field:ident: $field_type:ty),+ $(,)? })?, |$array_param:ident| -> $array_type:ty { $init_expr:expr } $(, |$value:ident| $(-> Option<$option_ty:ty>)? $value_transform:expr)? ) => { - pub(crate) struct $name<$lifetime $(, $generic: $bound )?> + pub(crate) struct $name<$lifetime $(, $generic: $bound )*> $( where $where_path: $where_bound )? { array: &$lifetime $array_type, $( $( $field: $field_type, )+ )? } - impl<$lifetime $(, $generic: $bound+ )?> $name<$lifetime $(, $generic)?> + impl<$lifetime $(, $generic: $bound )*> $name<$lifetime $(, $generic)*> $( where $where_path: $where_bound )? { pub(crate) fn new($array_param: &$lifetime dyn Array $(, $( $field: $field_type ),+ )?) -> Self { @@ -401,32 +408,19 @@ define_row_builder!( ); define_row_builder!( - struct Decimal32ArrowToVariantBuilder<'a> { - options: &'a CastOptions, - scale: i8, - }, - |array| -> arrow::array::Decimal32Array { array.as_primitive() }, - |value| -> Option<_> { decimal_to_variant_decimal!(value, scale, i32, VariantDecimal4) } -); - -define_row_builder!( - struct Decimal64ArrowToVariantBuilder<'a> { - options: &'a CastOptions, - scale: i8, - }, - |array| -> arrow::array::Decimal64Array { array.as_primitive() }, - |value| -> Option<_> { decimal_to_variant_decimal!(value, scale, i64, VariantDecimal8) } -); - -define_row_builder!( - struct Decimal128ArrowToVariantBuilder<'a> { + struct DecimalArrowToVariantBuilder<'a, A: DecimalType, V: VariantDecimalType> + where + V::Native: From, + { options: &'a CastOptions, scale: i8, + _phantom: PhantomData, }, - |array| -> arrow::array::Decimal128Array { array.as_primitive() }, - |value| -> Option<_> { decimal_to_variant_decimal!(value, scale, i128, VariantDecimal16) } + |array| -> PrimitiveArray { array.as_primitive() }, + |value| -> Option<_> { V::try_new_with_signed_scale(value.into(), *scale).ok() } ); +// Decimal256 needs a two-stage conversion via i128 define_row_builder!( struct Decimal256ArrowToVariantBuilder<'a> { options: &'a CastOptions, @@ -434,10 +428,8 @@ define_row_builder!( }, |array| -> arrow::array::Decimal256Array { array.as_primitive() }, |value| -> Option<_> { - // Decimal256 needs special handling - convert to i128 if possible - value.to_i128().and_then(|i128_val| { - decimal_to_variant_decimal!(i128_val, scale, i128, VariantDecimal16) - }) + let value = value.to_i128(); + value.and_then(|v| VariantDecimal16::try_new_with_signed_scale(v, *scale).ok()) } ); @@ -911,6 +903,7 @@ mod tests { use super::*; use crate::{VariantArray, VariantArrayBuilder}; use arrow::array::{ArrayRef, BooleanArray, Int32Array, StringArray}; + use arrow::datatypes::Int32Type; use std::sync::Arc; /// Builds a VariantArray from an Arrow array using the row builder. diff --git a/parquet-variant-compute/src/type_conversion.rs b/parquet-variant-compute/src/type_conversion.rs index 5dda1855297a..6d5a2c901f6d 100644 --- a/parquet-variant-compute/src/type_conversion.rs +++ b/parquet-variant-compute/src/type_conversion.rs @@ -102,22 +102,3 @@ macro_rules! primitive_conversion_single_value { }}; } pub(crate) use primitive_conversion_single_value; - -/// Convert a decimal value to a `VariantDecimal` -macro_rules! decimal_to_variant_decimal { - ($v:ident, $scale:expr, $value_type:ty, $variant_type:ty) => {{ - let (v, scale) = if *$scale < 0 { - // For negative scale, we need to multiply the value by 10^|scale| - // For example: 123 with scale -2 becomes 12300 with scale 0 - let multiplier = <$value_type>::pow(10, (-*$scale) as u32); - (<$value_type>::checked_mul($v, multiplier), 0u8) - } else { - (Some($v), *$scale as u8) - }; - - // Return an Option to allow callers to decide whether to error (strict) - // or append null (non-strict) on conversion failure - v.and_then(|v| <$variant_type>::try_new(v, scale).ok()) - }}; -} -pub(crate) use decimal_to_variant_decimal; diff --git a/parquet-variant-compute/src/unshred_variant.rs b/parquet-variant-compute/src/unshred_variant.rs index 64eaa46ed06b..44687af78480 100644 --- a/parquet-variant-compute/src/unshred_variant.rs +++ b/parquet-variant-compute/src/unshred_variant.rs @@ -35,8 +35,10 @@ use chrono::{DateTime, Utc}; use indexmap::IndexMap; use parquet_variant::{ ObjectFieldBuilder, Variant, VariantBuilderExt, VariantDecimal4, VariantDecimal8, - VariantDecimal16, VariantMetadata, + VariantDecimal16, VariantDecimalType, VariantMetadata, is_valid_variant_decimal4, + is_valid_variant_decimal8, is_valid_variant_decimal16, }; +use std::marker::PhantomData; use uuid::Uuid; /// Removes all (nested) typed_value columns from a VariantArray by converting them back to binary @@ -95,9 +97,9 @@ enum UnshredVariantRowBuilder<'a> { PrimitiveInt64(UnshredPrimitiveRowBuilder<'a, PrimitiveArray>), PrimitiveFloat32(UnshredPrimitiveRowBuilder<'a, PrimitiveArray>), PrimitiveFloat64(UnshredPrimitiveRowBuilder<'a, PrimitiveArray>), - Decimal32(DecimalUnshredRowBuilder<'a, Decimal32Spec>), - Decimal64(DecimalUnshredRowBuilder<'a, Decimal64Spec>), - Decimal128(DecimalUnshredRowBuilder<'a, Decimal128Spec>), + Decimal32(DecimalUnshredRowBuilder<'a, Decimal32Type, VariantDecimal4>), + Decimal64(DecimalUnshredRowBuilder<'a, Decimal64Type, VariantDecimal8>), + Decimal128(DecimalUnshredRowBuilder<'a, Decimal128Type, VariantDecimal16>), PrimitiveDate32(UnshredPrimitiveRowBuilder<'a, PrimitiveArray>), PrimitiveTime64(UnshredPrimitiveRowBuilder<'a, PrimitiveArray>), TimestampMicrosecond(TimestampUnshredRowBuilder<'a, TimestampMicrosecondType>), @@ -185,25 +187,23 @@ impl<'a> UnshredVariantRowBuilder<'a> { DataType::Int64 => primitive_builder!(PrimitiveInt64, as_primitive), DataType::Float32 => primitive_builder!(PrimitiveFloat32, as_primitive), DataType::Float64 => primitive_builder!(PrimitiveFloat64, as_primitive), - DataType::Decimal32(_, scale) => Self::Decimal32(DecimalUnshredRowBuilder::new( - value, - typed_value.as_primitive(), - *scale, - )), - DataType::Decimal64(_, scale) => Self::Decimal64(DecimalUnshredRowBuilder::new( - value, - typed_value.as_primitive(), - *scale, - )), - DataType::Decimal128(_, scale) => Self::Decimal128(DecimalUnshredRowBuilder::new( - value, - typed_value.as_primitive(), - *scale, - )), - DataType::Decimal256(_, _) => { - return Err(ArrowError::InvalidArgumentError( - "Decimal256 is not a valid variant shredding type".to_string(), - )); + DataType::Decimal32(p, s) if is_valid_variant_decimal4(p, s) => Self::Decimal32( + DecimalUnshredRowBuilder::new(value, typed_value.as_primitive(), *s as _), + ), + DataType::Decimal64(p, s) if is_valid_variant_decimal8(p, s) => Self::Decimal64( + DecimalUnshredRowBuilder::new(value, typed_value.as_primitive(), *s as _), + ), + DataType::Decimal128(p, s) if is_valid_variant_decimal16(p, s) => Self::Decimal128( + DecimalUnshredRowBuilder::new(value, typed_value.as_primitive(), *s as _), + ), + DataType::Decimal32(_, _) + | DataType::Decimal64(_, _) + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) => { + return Err(ArrowError::InvalidArgumentError(format!( + "{} is not a valid variant shredding type", + typed_value.data_type() + ))); } DataType::Date32 => primitive_builder!(PrimitiveDate32, as_primitive), DataType::Time64(TimeUnit::Microsecond) => { @@ -504,78 +504,33 @@ impl<'a, T: TimestampType> TimestampUnshredRowBuilder<'a, T> { } } -/// Trait to unify decimal unshredding across Decimal32/64/128 types -trait DecimalSpec { - type Arrow: ArrowPrimitiveType + DecimalType; - - fn into_variant( - raw: ::Native, - scale: i8, - ) -> Result>; -} - -/// Spec for Decimal32 -> VariantDecimal4 -struct Decimal32Spec; - -impl DecimalSpec for Decimal32Spec { - type Arrow = Decimal32Type; - - fn into_variant(raw: i32, scale: i8) -> Result> { - let scale = - u8::try_from(scale).map_err(|e| ArrowError::InvalidArgumentError(e.to_string()))?; - let value = VariantDecimal4::try_new(raw, scale) - .map_err(|e| ArrowError::InvalidArgumentError(e.to_string()))?; - Ok(value.into()) - } -} - -/// Spec for Decimal64 -> VariantDecimal8 -struct Decimal64Spec; - -impl DecimalSpec for Decimal64Spec { - type Arrow = Decimal64Type; - - fn into_variant(raw: i64, scale: i8) -> Result> { - let scale = - u8::try_from(scale).map_err(|e| ArrowError::InvalidArgumentError(e.to_string()))?; - let value = VariantDecimal8::try_new(raw, scale) - .map_err(|e| ArrowError::InvalidArgumentError(e.to_string()))?; - Ok(value.into()) - } -} - -/// Spec for Decimal128 -> VariantDecimal16 -struct Decimal128Spec; - -impl DecimalSpec for Decimal128Spec { - type Arrow = Decimal128Type; - - fn into_variant(raw: i128, scale: i8) -> Result> { - let scale = - u8::try_from(scale).map_err(|e| ArrowError::InvalidArgumentError(e.to_string()))?; - let value = VariantDecimal16::try_new(raw, scale) - .map_err(|e| ArrowError::InvalidArgumentError(e.to_string()))?; - Ok(value.into()) - } -} - -/// Generic builder for decimal unshredding that caches scale -struct DecimalUnshredRowBuilder<'a, S: DecimalSpec> { +/// Generic builder for decimal unshredding +struct DecimalUnshredRowBuilder<'a, A, V> +where + A: DecimalType, + V: VariantDecimalType, +{ value: Option<&'a BinaryViewArray>, - typed_value: &'a PrimitiveArray, + typed_value: &'a PrimitiveArray, scale: i8, + _phantom: PhantomData, } -impl<'a, S: DecimalSpec> DecimalUnshredRowBuilder<'a, S> { +impl<'a, A, V> DecimalUnshredRowBuilder<'a, A, V> +where + A: DecimalType, + V: VariantDecimalType, +{ fn new( value: Option<&'a BinaryViewArray>, - typed_value: &'a PrimitiveArray, + typed_value: &'a PrimitiveArray, scale: i8, ) -> Self { Self { value, typed_value, scale, + _phantom: PhantomData, } } @@ -588,8 +543,8 @@ impl<'a, S: DecimalSpec> DecimalUnshredRowBuilder<'a, S> { handle_unshredded_case!(self, builder, metadata, index, false); let raw = self.typed_value.value(index); - let variant = S::into_variant(raw, self.scale)?; - builder.append_value(variant); + let value = V::try_new_with_signed_scale(raw, self.scale)?; + builder.append_value(value); Ok(()) } } diff --git a/parquet-variant-compute/src/variant_array.rs b/parquet-variant-compute/src/variant_array.rs index 5686d102d3fd..8aa173886568 100644 --- a/parquet-variant-compute/src/variant_array.rs +++ b/parquet-variant-compute/src/variant_array.rs @@ -26,13 +26,11 @@ use arrow::datatypes::{ TimestampMicrosecondType, TimestampNanosecondType, }; use arrow_schema::extension::ExtensionType; -use arrow_schema::{ - ArrowError, DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION, DECIMAL128_MAX_PRECISION, - DataType, Field, FieldRef, Fields, TimeUnit, -}; +use arrow_schema::{ArrowError, DataType, Field, FieldRef, Fields, TimeUnit}; use chrono::DateTime; -use parquet_variant::Uuid; -use parquet_variant::Variant; +use parquet_variant::{ + Uuid, Variant, is_valid_variant_decimal4, is_valid_variant_decimal8, is_valid_variant_decimal16, +}; use std::borrow::Cow; use std::sync::Arc; @@ -937,18 +935,6 @@ fn cast_to_binary_view_arrays(array: &dyn Array) -> Result cast(array, new_type.as_ref()) } -/// Validates whether a given arrow decimal is a valid variant decimal -/// -/// NOTE: By a strict reading of the "decimal table" in the [shredding spec], each decimal type -/// should have a width-dependent lower bound on precision as well as an upper bound (i.e. Decimal16 -/// with precision 5 is invalid because Decimal4 "covers" it). But the variant shredding integration -/// tests specifically expect such cases to succeed, so we only enforce the upper bound here. -/// -/// [shredding spec]: https://github.com/apache/parquet-format/blob/master/VariantEncoding.md#encoding-types -fn is_valid_variant_decimal(p: &u8, s: &i8, max_precision: u8) -> bool { - (1..=max_precision).contains(p) && (0..=*p as i8).contains(s) -} - /// Recursively visits a data type, ensuring that it only contains data types that can legally /// appear in a (possibly shredded) variant array. It also replaces Binary fields with BinaryView, /// since that's what comes back from the parquet reader and what the variant code expects to find. @@ -983,17 +969,13 @@ fn canonicalize_and_verify_data_type( // // NOTE: arrow-parquet reads widens 32- and 64-bit decimals to 128-bit, but the variant spec // requires using the narrowest decimal type for a given precision. Fix those up first. - Decimal64(p, s) | Decimal128(p, s) - if is_valid_variant_decimal(p, s, DECIMAL32_MAX_PRECISION) => - { + Decimal64(p, s) | Decimal128(p, s) if is_valid_variant_decimal4(p, s) => { Cow::Owned(Decimal32(*p, *s)) } - Decimal128(p, s) if is_valid_variant_decimal(p, s, DECIMAL64_MAX_PRECISION) => { - Cow::Owned(Decimal64(*p, *s)) - } - Decimal32(p, s) if is_valid_variant_decimal(p, s, DECIMAL32_MAX_PRECISION) => borrow!(), - Decimal64(p, s) if is_valid_variant_decimal(p, s, DECIMAL64_MAX_PRECISION) => borrow!(), - Decimal128(p, s) if is_valid_variant_decimal(p, s, DECIMAL128_MAX_PRECISION) => borrow!(), + Decimal128(p, s) if is_valid_variant_decimal8(p, s) => Cow::Owned(Decimal64(*p, *s)), + Decimal32(p, s) if is_valid_variant_decimal4(p, s) => borrow!(), + Decimal64(p, s) if is_valid_variant_decimal8(p, s) => borrow!(), + Decimal128(p, s) if is_valid_variant_decimal16(p, s) => borrow!(), Decimal32(..) | Decimal64(..) | Decimal128(..) | Decimal256(..) => fail!(), // Only micro and nano timestamps are allowed diff --git a/parquet-variant/src/variant.rs b/parquet-variant/src/variant.rs index 849947675b13..8a05dd890b61 100644 --- a/parquet-variant/src/variant.rs +++ b/parquet-variant/src/variant.rs @@ -15,7 +15,11 @@ // specific language governing permissions and limitations // under the License. -pub use self::decimal::{VariantDecimal16, VariantDecimal4, VariantDecimal8}; +pub use self::decimal::{ + is_valid_variant_decimal, is_valid_variant_decimal16, is_valid_variant_decimal4, + is_valid_variant_decimal8, VariantDecimal16, VariantDecimal4, VariantDecimal8, + VariantDecimalType, +}; pub use self::list::VariantList; pub use self::metadata::{VariantMetadata, EMPTY_VARIANT_METADATA, EMPTY_VARIANT_METADATA_BYTES}; pub use self::object::VariantObject; diff --git a/parquet-variant/src/variant/decimal.rs b/parquet-variant/src/variant/decimal.rs index 4793d88569bf..ce90097ade05 100644 --- a/parquet-variant/src/variant/decimal.rs +++ b/parquet-variant/src/variant/decimal.rs @@ -17,52 +17,209 @@ use arrow_schema::ArrowError; use std::fmt; -// All decimal types use the same try_new implementation -macro_rules! decimal_try_new { - ($integer:ident, $scale:ident) => {{ - // Validate that scale doesn't exceed precision - if $scale > Self::MAX_PRECISION { - return Err(ArrowError::InvalidArgumentError(format!( - "Scale {} is larger than max precision {}", - $scale, - Self::MAX_PRECISION, - ))); - } +/// True if the given precision and scale are valid for a variant decimal type with the given +/// maximum precision. +/// +/// NOTE: By a strict reading of the "decimal table" in the [variant spec], one might conclude that +/// each decimal type has both lower and upper bounds on precision (i.e. Decimal16 with precision 5 +/// is invalid because Decimal4 "covers" it). But the variant shredding integration tests +/// specifically expect such cases to succeed, so we only enforce the upper bound here. +/// +/// [shredding spec]: https://github.com/apache/parquet-format/blob/master/VariantEncoding.md#encoding-types +/// +/// # Example +/// ``` +/// # use parquet_variant::{is_valid_variant_decimal, VariantDecimal4}; +/// # +/// assert!(is_valid_variant_decimal(&5, &2, VariantDecimal4::MAX_PRECISION)); +/// assert!(!is_valid_variant_decimal(&10, &2, VariantDecimal4::MAX_PRECISION)); // too wide +/// assert!(!is_valid_variant_decimal(&5, &-1, VariantDecimal4::MAX_PRECISION)); // negative scale +/// assert!(!is_valid_variant_decimal(&5, &7, VariantDecimal4::MAX_PRECISION)); // scale too big +/// ``` +pub fn is_valid_variant_decimal(precision: &u8, scale: &i8, max_precision: u8) -> bool { + (1..=max_precision).contains(precision) && (0..=*precision as i8).contains(scale) +} - // Validate that the integer value fits within the precision - if $integer.unsigned_abs() > Self::MAX_UNSCALED_VALUE { - return Err(ArrowError::InvalidArgumentError(format!( - "{} is wider than max precision {}", - $integer, - Self::MAX_PRECISION - ))); - } +/// True if the given precision and scale are valid for a variant Decimal4 (max precision 9). +/// +/// See [`is_valid_variant_decimal`] for details. +pub fn is_valid_variant_decimal4(precision: &u8, scale: &i8) -> bool { + is_valid_variant_decimal(precision, scale, VariantDecimal4::MAX_PRECISION) +} + +/// True if the given precision and scale are valid for a variant Decimal8 (max precision 18). +/// +/// See [`is_valid_variant_decimal`] for details. +pub fn is_valid_variant_decimal8(precision: &u8, scale: &i8) -> bool { + is_valid_variant_decimal(precision, scale, VariantDecimal8::MAX_PRECISION) +} - Ok(Self { $integer, $scale }) - }}; +/// True if the given precision and scale are valid for a variant Decimal16 (max precision 38). +/// +/// See [`is_valid_variant_decimal`] for details. +pub fn is_valid_variant_decimal16(precision: &u8, scale: &i8) -> bool { + is_valid_variant_decimal(precision, scale, VariantDecimal16::MAX_PRECISION) } -// All decimal values format the same way, using integer arithmetic to avoid floating point precision loss -macro_rules! format_decimal { - ($f:expr, $integer:expr, $scale:expr, $int_type:ty) => {{ - let integer = if $scale == 0 { - $integer - } else { - let divisor = <$int_type>::pow(10, $scale as u32); - let remainder = $integer % divisor; - if remainder != 0 { - // Track the sign explicitly, in case the quotient is zero - let sign = if $integer < 0 { "-" } else { "" }; - // Format an unsigned remainder with leading zeros and strip (unnecessary) trailing zeros. - let remainder = format!("{:0width$}", remainder.abs(), width = $scale as usize); - let remainder = remainder.trim_end_matches('0'); - let quotient = $integer / divisor; - return write!($f, "{}{}.{}", sign, quotient.abs(), remainder); +/// Trait for variant decimal types, enabling generic code across Decimal4/8/16 +/// +/// This trait provides a common interface for the three variant decimal types, +/// allowing generic functions and data structures to work with any decimal width. +/// It is modeled after Arrow's `DecimalType` trait but adapted for variant semantics. +/// +/// # Example +/// +/// ``` +/// # use parquet_variant::{VariantDecimal4, VariantDecimal8, VariantDecimalType}; +/// # +/// fn extract_scale(decimal: D) -> u8 { +/// decimal.scale() +/// } +/// +/// let dec4 = VariantDecimal4::try_new(12345, 2).unwrap(); +/// let dec8 = VariantDecimal8::try_new(67890, 3).unwrap(); +/// +/// assert_eq!(extract_scale(dec4), 2); +/// assert_eq!(extract_scale(dec8), 3); +/// ``` +pub trait VariantDecimalType: Into> { + /// The underlying signed integer type (i32, i64, or i128) + type Native; + + /// Maximum number of significant digits this decimal type can represent (9, 18, or 38) + const MAX_PRECISION: u8; + + /// Creates a new decimal value from the given unscaled integer and scale, failing if the + /// integer's width, or the requested scale, exceeds `MAX_PRECISION`. + /// + /// NOTE: For compatibility with arrow decimal types, negative scale is allowed as long + /// as the rescaled value fits in the available precision. + /// + /// # Example + /// + /// ``` + /// # use parquet_variant::{VariantDecimal4, VariantDecimalType}; + /// # + /// // Valid: 123.45 (5 digits, scale 2) + /// let d = VariantDecimal4::try_new(12345, 2).unwrap(); + /// assert_eq!(d.integer(), 12345); + /// assert_eq!(d.scale(), 2); + /// + /// VariantDecimal4::try_new(123, 10).expect_err("scale exceeds MAX_PRECISION"); + /// VariantDecimal4::try_new(1234567890, 10).expect_err("value's width exceeds MAX_PRECISION"); + /// ``` + fn try_new(integer: Self::Native, scale: u8) -> Result; + + /// Attempts to convert an unscaled arrow decimal value to the indicated variant decimal type. + /// + /// Unlike [`Self::try_new`], this function accepts a signed scale, and attempts to rescale + /// negative-scale values to their equivalent (larger) scale-0 values. For example, a decimal + /// value of 123 with scale -2 becomes 12300 with scale 0. + /// + /// Fails if rescaling fails, or for any of the reasons [`Self::try_new`] could fail. + fn try_new_with_signed_scale(integer: Self::Native, scale: i8) -> Result; + + /// Returns the unscaled integer value + fn integer(&self) -> Self::Native; + + /// Returns the scale (number of digits after the decimal point) + fn scale(&self) -> u8; +} + +/// Implements the complete variant decimal type: methods, Display, and VariantDecimalType trait +macro_rules! impl_variant_decimal { + ($struct_name:ident, $native:ty) => { + impl $struct_name { + /// Attempts to create a new instance of this decimal type, failing if the value or + /// scale is too large. + pub fn try_new(integer: $native, scale: u8) -> Result { + let max_precision = Self::MAX_PRECISION; + if scale > max_precision { + return Err(ArrowError::InvalidArgumentError(format!( + "Scale {scale} is larger than max precision {max_precision}", + ))); + } + + // Validate that the integer value fits within the decimal's maximum precision + if integer.unsigned_abs() > Self::MAX_UNSCALED_VALUE { + return Err(ArrowError::InvalidArgumentError(format!( + "{integer} is wider than max precision {max_precision}", + ))); + } + + Ok(Self { integer, scale }) } - $integer / divisor - }; - write!($f, "{}", integer) - }}; + + /// Returns the underlying value of the decimal. + /// + /// For example, if the decimal is `123.45`, this will return `12345`. + pub fn integer(&self) -> $native { + self.integer + } + + /// Returns the scale of the decimal (how many digits after the decimal point). + /// + /// For example, if the decimal is `123.45`, this will return `2`. + pub fn scale(&self) -> u8 { + self.scale + } + } + + impl VariantDecimalType for $struct_name { + type Native = $native; + const MAX_PRECISION: u8 = Self::MAX_PRECISION; + + fn try_new(integer: $native, scale: u8) -> Result { + Self::try_new(integer, scale) + } + + fn try_new_with_signed_scale(integer: $native, scale: i8) -> Result { + let (integer, scale) = if scale < 0 { + let multiplier = <$native>::checked_pow(10, (-scale) as u32); + let Some(rescaled) = multiplier.and_then(|m| integer.checked_mul(m)) else { + return Err(ArrowError::InvalidArgumentError(format!( + "Overflow when rescaling {integer} with scale {scale}" + ))); + }; + (rescaled, 0u8) + } else { + (integer, scale as u8) + }; + Self::try_new(integer, scale) + } + + fn integer(&self) -> $native { + self.integer() + } + + fn scale(&self) -> u8 { + self.scale() + } + } + + impl fmt::Display for $struct_name { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let integer = if self.scale == 0 { + self.integer + } else { + let divisor = <$native>::pow(10, self.scale as u32); + let remainder = self.integer % divisor; + if remainder != 0 { + // Track the sign explicitly, in case the quotient is zero + let sign = if self.integer < 0 { "-" } else { "" }; + // Format an unsigned remainder with leading zeros and strip trailing zeros + let remainder = + format!("{:0width$}", remainder.abs(), width = self.scale as usize); + let remainder = remainder.trim_end_matches('0'); + let quotient = (self.integer / divisor).abs(); + return write!(f, "{sign}{quotient}.{remainder}"); + } + self.integer / divisor + }; + write!(f, "{integer}") + } + } + }; } /// Represents a 4-byte decimal value in the Variant format. @@ -86,33 +243,13 @@ pub struct VariantDecimal4 { } impl VariantDecimal4 { - pub(crate) const MAX_PRECISION: u8 = 9; - pub(crate) const MAX_UNSCALED_VALUE: u32 = u32::pow(10, Self::MAX_PRECISION as u32) - 1; - - pub fn try_new(integer: i32, scale: u8) -> Result { - decimal_try_new!(integer, scale) - } - - /// Returns the underlying value of the decimal. - /// - /// For example, if the decimal is `123.4567`, this will return `1234567`. - pub fn integer(&self) -> i32 { - self.integer - } - - /// Returns the scale of the decimal (how many digits after the decimal point). - /// - /// For example, if the decimal is `123.4567`, this will return `4`. - pub fn scale(&self) -> u8 { - self.scale - } + /// Maximum number of significant digits (9 for 4-byte decimals) + pub const MAX_PRECISION: u8 = arrow_schema::DECIMAL32_MAX_PRECISION; + /// The largest unscaled value that fits in [`Self::MAX_PRECISION`] digits. + pub const MAX_UNSCALED_VALUE: u32 = u32::pow(10, Self::MAX_PRECISION as u32) - 1; } -impl fmt::Display for VariantDecimal4 { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - format_decimal!(f, self.integer, self.scale, i32) - } -} +impl_variant_decimal!(VariantDecimal4, i32); /// Represents an 8-byte decimal value in the Variant format. /// @@ -136,33 +273,13 @@ pub struct VariantDecimal8 { } impl VariantDecimal8 { - pub(crate) const MAX_PRECISION: u8 = 18; - pub(crate) const MAX_UNSCALED_VALUE: u64 = u64::pow(10, Self::MAX_PRECISION as u32) - 1; - - pub fn try_new(integer: i64, scale: u8) -> Result { - decimal_try_new!(integer, scale) - } - - /// Returns the underlying value of the decimal. - /// - /// For example, if the decimal is `123456.78`, this will return `12345678`. - pub fn integer(&self) -> i64 { - self.integer - } - - /// Returns the scale of the decimal (how many digits after the decimal point). - /// - /// For example, if the decimal is `123456.78`, this will return `2`. - pub fn scale(&self) -> u8 { - self.scale - } + /// Maximum number of significant digits (18 for 8-byte decimals) + pub const MAX_PRECISION: u8 = arrow_schema::DECIMAL64_MAX_PRECISION; + /// The largest unscaled value that fits in [`Self::MAX_PRECISION`] digits. + pub const MAX_UNSCALED_VALUE: u64 = u64::pow(10, Self::MAX_PRECISION as u32) - 1; } -impl fmt::Display for VariantDecimal8 { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - format_decimal!(f, self.integer, self.scale, i64) - } -} +impl_variant_decimal!(VariantDecimal8, i64); /// Represents an 16-byte decimal value in the Variant format. /// @@ -186,33 +303,13 @@ pub struct VariantDecimal16 { } impl VariantDecimal16 { - const MAX_PRECISION: u8 = 38; - const MAX_UNSCALED_VALUE: u128 = u128::pow(10, Self::MAX_PRECISION as u32) - 1; - - pub fn try_new(integer: i128, scale: u8) -> Result { - decimal_try_new!(integer, scale) - } - - /// Returns the underlying value of the decimal. - /// - /// For example, if the decimal is `12345678901234567.890`, this will return `12345678901234567890`. - pub fn integer(&self) -> i128 { - self.integer - } - - /// Returns the scale of the decimal (how many digits after the decimal point). - /// - /// For example, if the decimal is `12345678901234567.890`, this will return `3`. - pub fn scale(&self) -> u8 { - self.scale - } + /// Maximum number of significant digits (38 for 16-byte decimals) + pub const MAX_PRECISION: u8 = arrow_schema::DECIMAL128_MAX_PRECISION; + /// The largest unscaled value that fits in [`Self::MAX_PRECISION`] digits. + pub const MAX_UNSCALED_VALUE: u128 = u128::pow(10, Self::MAX_PRECISION as u32) - 1; } -impl fmt::Display for VariantDecimal16 { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - format_decimal!(f, self.integer, self.scale, i128) - } -} +impl_variant_decimal!(VariantDecimal16, i128); // Infallible conversion from a narrower decimal type to a wider one macro_rules! impl_from_decimal_for_decimal { From 9f83ab8e6e025b7f8bcbaa5f193a1b2f12e1aff5 Mon Sep 17 00:00:00 2001 From: Ryan Johnson Date: Mon, 6 Oct 2025 18:09:00 -0700 Subject: [PATCH 2/8] additional cleanup --- .../src/unshred_variant.rs | 53 ++++++--------- parquet-variant-compute/src/variant_array.rs | 16 +++-- parquet-variant/src/variant.rs | 6 +- parquet-variant/src/variant/decimal.rs | 66 +++++++------------ 4 files changed, 53 insertions(+), 88 deletions(-) diff --git a/parquet-variant-compute/src/unshred_variant.rs b/parquet-variant-compute/src/unshred_variant.rs index 44687af78480..6e310f493fc7 100644 --- a/parquet-variant-compute/src/unshred_variant.rs +++ b/parquet-variant-compute/src/unshred_variant.rs @@ -35,8 +35,7 @@ use chrono::{DateTime, Utc}; use indexmap::IndexMap; use parquet_variant::{ ObjectFieldBuilder, Variant, VariantBuilderExt, VariantDecimal4, VariantDecimal8, - VariantDecimal16, VariantDecimalType, VariantMetadata, is_valid_variant_decimal4, - is_valid_variant_decimal8, is_valid_variant_decimal16, + VariantDecimal16, VariantDecimalType, VariantMetadata, }; use std::marker::PhantomData; use uuid::Uuid; @@ -187,15 +186,15 @@ impl<'a> UnshredVariantRowBuilder<'a> { DataType::Int64 => primitive_builder!(PrimitiveInt64, as_primitive), DataType::Float32 => primitive_builder!(PrimitiveFloat32, as_primitive), DataType::Float64 => primitive_builder!(PrimitiveFloat64, as_primitive), - DataType::Decimal32(p, s) if is_valid_variant_decimal4(p, s) => Self::Decimal32( - DecimalUnshredRowBuilder::new(value, typed_value.as_primitive(), *s as _), - ), - DataType::Decimal64(p, s) if is_valid_variant_decimal8(p, s) => Self::Decimal64( - DecimalUnshredRowBuilder::new(value, typed_value.as_primitive(), *s as _), - ), - DataType::Decimal128(p, s) if is_valid_variant_decimal16(p, s) => Self::Decimal128( - DecimalUnshredRowBuilder::new(value, typed_value.as_primitive(), *s as _), - ), + DataType::Decimal32(p, s) if VariantDecimal4::is_valid_precision_and_scale(p, s) => { + Self::Decimal32(DecimalUnshredRowBuilder::new(value, typed_value, *s as _)) + } + DataType::Decimal64(p, s) if VariantDecimal8::is_valid_precision_and_scale(p, s) => { + Self::Decimal64(DecimalUnshredRowBuilder::new(value, typed_value, *s as _)) + } + DataType::Decimal128(p, s) if VariantDecimal16::is_valid_precision_and_scale(p, s) => { + Self::Decimal128(DecimalUnshredRowBuilder::new(value, typed_value, *s as _)) + } DataType::Decimal32(_, _) | DataType::Decimal64(_, _) | DataType::Decimal128(_, _) @@ -214,20 +213,12 @@ impl<'a> UnshredVariantRowBuilder<'a> { "Time64({time_unit}) is not a valid variant shredding type", ))); } - DataType::Timestamp(TimeUnit::Microsecond, timezone) => { - Self::TimestampMicrosecond(TimestampUnshredRowBuilder::new( - value, - typed_value.as_primitive(), - timezone.is_some(), - )) - } - DataType::Timestamp(TimeUnit::Nanosecond, timezone) => { - Self::TimestampNanosecond(TimestampUnshredRowBuilder::new( - value, - typed_value.as_primitive(), - timezone.is_some(), - )) - } + DataType::Timestamp(TimeUnit::Microsecond, timezone) => Self::TimestampMicrosecond( + TimestampUnshredRowBuilder::new(value, typed_value, timezone.is_some()), + ), + DataType::Timestamp(TimeUnit::Nanosecond, timezone) => Self::TimestampNanosecond( + TimestampUnshredRowBuilder::new(value, typed_value, timezone.is_some()), + ), DataType::Timestamp(time_unit, _) => { return Err(ArrowError::InvalidArgumentError(format!( "Timestamp({time_unit}) is not a valid variant shredding type", @@ -474,12 +465,12 @@ struct TimestampUnshredRowBuilder<'a, T: TimestampType> { impl<'a, T: TimestampType> TimestampUnshredRowBuilder<'a, T> { fn new( value: Option<&'a BinaryViewArray>, - typed_value: &'a PrimitiveArray, + typed_value: &'a dyn Array, has_timezone: bool, ) -> Self { Self { value, - typed_value, + typed_value: typed_value.as_primitive(), has_timezone, } } @@ -521,14 +512,10 @@ where A: DecimalType, V: VariantDecimalType, { - fn new( - value: Option<&'a BinaryViewArray>, - typed_value: &'a PrimitiveArray, - scale: i8, - ) -> Self { + fn new(value: Option<&'a BinaryViewArray>, typed_value: &'a dyn Array, scale: i8) -> Self { Self { value, - typed_value, + typed_value: typed_value.as_primitive(), scale, _phantom: PhantomData, } diff --git a/parquet-variant-compute/src/variant_array.rs b/parquet-variant-compute/src/variant_array.rs index 8aa173886568..522c5a7546b5 100644 --- a/parquet-variant-compute/src/variant_array.rs +++ b/parquet-variant-compute/src/variant_array.rs @@ -29,7 +29,7 @@ use arrow_schema::extension::ExtensionType; use arrow_schema::{ArrowError, DataType, Field, FieldRef, Fields, TimeUnit}; use chrono::DateTime; use parquet_variant::{ - Uuid, Variant, is_valid_variant_decimal4, is_valid_variant_decimal8, is_valid_variant_decimal16, + Uuid, Variant, VariantDecimal4, VariantDecimal8, VariantDecimal16, VariantDecimalType as _, }; use std::borrow::Cow; @@ -969,13 +969,17 @@ fn canonicalize_and_verify_data_type( // // NOTE: arrow-parquet reads widens 32- and 64-bit decimals to 128-bit, but the variant spec // requires using the narrowest decimal type for a given precision. Fix those up first. - Decimal64(p, s) | Decimal128(p, s) if is_valid_variant_decimal4(p, s) => { + Decimal64(p, s) | Decimal128(p, s) + if VariantDecimal4::is_valid_precision_and_scale(p, s) => + { Cow::Owned(Decimal32(*p, *s)) } - Decimal128(p, s) if is_valid_variant_decimal8(p, s) => Cow::Owned(Decimal64(*p, *s)), - Decimal32(p, s) if is_valid_variant_decimal4(p, s) => borrow!(), - Decimal64(p, s) if is_valid_variant_decimal8(p, s) => borrow!(), - Decimal128(p, s) if is_valid_variant_decimal16(p, s) => borrow!(), + Decimal128(p, s) if VariantDecimal8::is_valid_precision_and_scale(p, s) => { + Cow::Owned(Decimal64(*p, *s)) + } + Decimal32(p, s) if VariantDecimal4::is_valid_precision_and_scale(p, s) => borrow!(), + Decimal64(p, s) if VariantDecimal8::is_valid_precision_and_scale(p, s) => borrow!(), + Decimal128(p, s) if VariantDecimal16::is_valid_precision_and_scale(p, s) => borrow!(), Decimal32(..) | Decimal64(..) | Decimal128(..) | Decimal256(..) => fail!(), // Only micro and nano timestamps are allowed diff --git a/parquet-variant/src/variant.rs b/parquet-variant/src/variant.rs index 8a05dd890b61..c967ae4fec68 100644 --- a/parquet-variant/src/variant.rs +++ b/parquet-variant/src/variant.rs @@ -15,11 +15,7 @@ // specific language governing permissions and limitations // under the License. -pub use self::decimal::{ - is_valid_variant_decimal, is_valid_variant_decimal16, is_valid_variant_decimal4, - is_valid_variant_decimal8, VariantDecimal16, VariantDecimal4, VariantDecimal8, - VariantDecimalType, -}; +pub use self::decimal::{VariantDecimal16, VariantDecimal4, VariantDecimal8, VariantDecimalType}; pub use self::list::VariantList; pub use self::metadata::{VariantMetadata, EMPTY_VARIANT_METADATA, EMPTY_VARIANT_METADATA_BYTES}; pub use self::object::VariantObject; diff --git a/parquet-variant/src/variant/decimal.rs b/parquet-variant/src/variant/decimal.rs index ce90097ade05..c6fd9e07c905 100644 --- a/parquet-variant/src/variant/decimal.rs +++ b/parquet-variant/src/variant/decimal.rs @@ -17,50 +17,6 @@ use arrow_schema::ArrowError; use std::fmt; -/// True if the given precision and scale are valid for a variant decimal type with the given -/// maximum precision. -/// -/// NOTE: By a strict reading of the "decimal table" in the [variant spec], one might conclude that -/// each decimal type has both lower and upper bounds on precision (i.e. Decimal16 with precision 5 -/// is invalid because Decimal4 "covers" it). But the variant shredding integration tests -/// specifically expect such cases to succeed, so we only enforce the upper bound here. -/// -/// [shredding spec]: https://github.com/apache/parquet-format/blob/master/VariantEncoding.md#encoding-types -/// -/// # Example -/// ``` -/// # use parquet_variant::{is_valid_variant_decimal, VariantDecimal4}; -/// # -/// assert!(is_valid_variant_decimal(&5, &2, VariantDecimal4::MAX_PRECISION)); -/// assert!(!is_valid_variant_decimal(&10, &2, VariantDecimal4::MAX_PRECISION)); // too wide -/// assert!(!is_valid_variant_decimal(&5, &-1, VariantDecimal4::MAX_PRECISION)); // negative scale -/// assert!(!is_valid_variant_decimal(&5, &7, VariantDecimal4::MAX_PRECISION)); // scale too big -/// ``` -pub fn is_valid_variant_decimal(precision: &u8, scale: &i8, max_precision: u8) -> bool { - (1..=max_precision).contains(precision) && (0..=*precision as i8).contains(scale) -} - -/// True if the given precision and scale are valid for a variant Decimal4 (max precision 9). -/// -/// See [`is_valid_variant_decimal`] for details. -pub fn is_valid_variant_decimal4(precision: &u8, scale: &i8) -> bool { - is_valid_variant_decimal(precision, scale, VariantDecimal4::MAX_PRECISION) -} - -/// True if the given precision and scale are valid for a variant Decimal8 (max precision 18). -/// -/// See [`is_valid_variant_decimal`] for details. -pub fn is_valid_variant_decimal8(precision: &u8, scale: &i8) -> bool { - is_valid_variant_decimal(precision, scale, VariantDecimal8::MAX_PRECISION) -} - -/// True if the given precision and scale are valid for a variant Decimal16 (max precision 38). -/// -/// See [`is_valid_variant_decimal`] for details. -pub fn is_valid_variant_decimal16(precision: &u8, scale: &i8) -> bool { - is_valid_variant_decimal(precision, scale, VariantDecimal16::MAX_PRECISION) -} - /// Trait for variant decimal types, enabling generic code across Decimal4/8/16 /// /// This trait provides a common interface for the three variant decimal types, @@ -89,6 +45,28 @@ pub trait VariantDecimalType: Into> { /// Maximum number of significant digits this decimal type can represent (9, 18, or 38) const MAX_PRECISION: u8; + /// True if the given precision and scale are valid for this variant decimal type. + /// + /// NOTE: By a strict reading of the "decimal table" in the [variant spec], one might conclude that + /// each decimal type has both lower and upper bounds on precision (i.e. Decimal16 with precision 5 + /// is invalid because Decimal4 "covers" it). But the variant shredding integration tests + /// specifically expect such cases to succeed, so we only enforce the upper bound here. + /// + /// [shredding spec]: https://github.com/apache/parquet-format/blob/master/VariantEncoding.md#encoding-types + /// + /// # Example + /// ``` + /// # use parquet_variant::{VariantDecimal4, VariantDecimalType}; + /// # + /// assert!(VariantDecimal4::is_valid_precision_and_scale(&5, &2)); + /// assert!(!VariantDecimal4::is_valid_precision_and_scale(&10, &2)); // too wide + /// assert!(!VariantDecimal4::is_valid_precision_and_scale(&5, &-1)); // negative scale + /// assert!(!VariantDecimal4::is_valid_precision_and_scale(&5, &7)); // scale too big + /// ``` + fn is_valid_precision_and_scale(precision: &u8, scale: &i8) -> bool { + (1..=Self::MAX_PRECISION).contains(precision) && (0..=*precision as i8).contains(scale) + } + /// Creates a new decimal value from the given unscaled integer and scale, failing if the /// integer's width, or the requested scale, exceeds `MAX_PRECISION`. /// From bd7f3a728078ebc447def564608ee6a730c06c55 Mon Sep 17 00:00:00 2001 From: Ryan Johnson Date: Tue, 7 Oct 2025 12:13:45 -0700 Subject: [PATCH 3/8] better approach to MAX_UNSCALED_VALUE --- parquet-variant/src/variant/decimal.rs | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/parquet-variant/src/variant/decimal.rs b/parquet-variant/src/variant/decimal.rs index c6fd9e07c905..e7977b427c6b 100644 --- a/parquet-variant/src/variant/decimal.rs +++ b/parquet-variant/src/variant/decimal.rs @@ -44,6 +44,8 @@ pub trait VariantDecimalType: Into> { /// Maximum number of significant digits this decimal type can represent (9, 18, or 38) const MAX_PRECISION: u8; + /// The largest positive unscaled value that fits in [`Self::MAX_PRECISION`] digits. + const MAX_UNSCALED_VALUE: Self::Native; /// True if the given precision and scale are valid for this variant decimal type. /// @@ -108,8 +110,8 @@ pub trait VariantDecimalType: Into> { macro_rules! impl_variant_decimal { ($struct_name:ident, $native:ty) => { impl $struct_name { - /// Attempts to create a new instance of this decimal type, failing if the value or - /// scale is too large. + /// Attempts to create a new instance of this decimal type, failing if the value is too + /// wide or the scale is too large. pub fn try_new(integer: $native, scale: u8) -> Result { let max_precision = Self::MAX_PRECISION; if scale > max_precision { @@ -117,9 +119,7 @@ macro_rules! impl_variant_decimal { "Scale {scale} is larger than max precision {max_precision}", ))); } - - // Validate that the integer value fits within the decimal's maximum precision - if integer.unsigned_abs() > Self::MAX_UNSCALED_VALUE { + if !(-Self::MAX_UNSCALED_VALUE..=Self::MAX_UNSCALED_VALUE).contains(&integer) { return Err(ArrowError::InvalidArgumentError(format!( "{integer} is wider than max precision {max_precision}", ))); @@ -128,7 +128,7 @@ macro_rules! impl_variant_decimal { Ok(Self { integer, scale }) } - /// Returns the underlying value of the decimal. + /// Returns the unscaled integer value of the decimal. /// /// For example, if the decimal is `123.45`, this will return `12345`. pub fn integer(&self) -> $native { @@ -146,6 +146,7 @@ macro_rules! impl_variant_decimal { impl VariantDecimalType for $struct_name { type Native = $native; const MAX_PRECISION: u8 = Self::MAX_PRECISION; + const MAX_UNSCALED_VALUE: $native = <$native>::pow(10, Self::MAX_PRECISION as u32) - 1; fn try_new(integer: $native, scale: u8) -> Result { Self::try_new(integer, scale) @@ -223,8 +224,6 @@ pub struct VariantDecimal4 { impl VariantDecimal4 { /// Maximum number of significant digits (9 for 4-byte decimals) pub const MAX_PRECISION: u8 = arrow_schema::DECIMAL32_MAX_PRECISION; - /// The largest unscaled value that fits in [`Self::MAX_PRECISION`] digits. - pub const MAX_UNSCALED_VALUE: u32 = u32::pow(10, Self::MAX_PRECISION as u32) - 1; } impl_variant_decimal!(VariantDecimal4, i32); @@ -253,8 +252,6 @@ pub struct VariantDecimal8 { impl VariantDecimal8 { /// Maximum number of significant digits (18 for 8-byte decimals) pub const MAX_PRECISION: u8 = arrow_schema::DECIMAL64_MAX_PRECISION; - /// The largest unscaled value that fits in [`Self::MAX_PRECISION`] digits. - pub const MAX_UNSCALED_VALUE: u64 = u64::pow(10, Self::MAX_PRECISION as u32) - 1; } impl_variant_decimal!(VariantDecimal8, i64); @@ -283,8 +280,6 @@ pub struct VariantDecimal16 { impl VariantDecimal16 { /// Maximum number of significant digits (38 for 16-byte decimals) pub const MAX_PRECISION: u8 = arrow_schema::DECIMAL128_MAX_PRECISION; - /// The largest unscaled value that fits in [`Self::MAX_PRECISION`] digits. - pub const MAX_UNSCALED_VALUE: u128 = u128::pow(10, Self::MAX_PRECISION as u32) - 1; } impl_variant_decimal!(VariantDecimal16, i128); From e6c88b95012266986458c4d05e7d73dcbb1e80a1 Mon Sep 17 00:00:00 2001 From: Ryan Johnson Date: Fri, 10 Oct 2025 07:16:32 -0700 Subject: [PATCH 4/8] simpler macro --- parquet-variant-compute/src/arrow_to_variant.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/parquet-variant-compute/src/arrow_to_variant.rs b/parquet-variant-compute/src/arrow_to_variant.rs index 72618533c57e..fb2240458d54 100644 --- a/parquet-variant-compute/src/arrow_to_variant.rs +++ b/parquet-variant-compute/src/arrow_to_variant.rs @@ -327,20 +327,20 @@ pub(crate) fn make_arrow_to_variant_row_builder<'a>( // worth the trouble, tho, because it makes for some pretty bulky and unwieldy macro expansions. macro_rules! define_row_builder { ( - struct $name:ident<$lifetime:lifetime $(, $generic:ident: $bound:path )*> + struct $name:ident<$lifetime:lifetime $(, $generic:ident $( : $bound:path )? )*> $( where $where_path:path: $where_bound:path $(,)? )? $({ $($field:ident: $field_type:ty),+ $(,)? })?, |$array_param:ident| -> $array_type:ty { $init_expr:expr } $(, |$value:ident| $(-> Option<$option_ty:ty>)? $value_transform:expr)? ) => { - pub(crate) struct $name<$lifetime $(, $generic: $bound )*> + pub(crate) struct $name<$lifetime $(, $generic: $( $bound )? )*> $( where $where_path: $where_bound )? { array: &$lifetime $array_type, $( $( $field: $field_type, )+ )? } - impl<$lifetime $(, $generic: $bound )*> $name<$lifetime $(, $generic)*> + impl<$lifetime $(, $generic: $( $bound )? )*> $name<$lifetime $(, $generic)*> $( where $where_path: $where_bound )? { pub(crate) fn new($array_param: &$lifetime dyn Array $(, $( $field: $field_type ),+ )?) -> Self { @@ -408,16 +408,16 @@ define_row_builder!( ); define_row_builder!( - struct DecimalArrowToVariantBuilder<'a, A: DecimalType, V: VariantDecimalType> + struct DecimalArrowToVariantBuilder<'a, A: DecimalType, V> where - V::Native: From, + V: VariantDecimalType { options: &'a CastOptions, scale: i8, _phantom: PhantomData, }, |array| -> PrimitiveArray { array.as_primitive() }, - |value| -> Option<_> { V::try_new_with_signed_scale(value.into(), *scale).ok() } + |value| -> Option<_> { V::try_new_with_signed_scale(value, *scale).ok() } ); // Decimal256 needs a two-stage conversion via i128 From 3f66de4aa8902059f9ea20e417fb6768531a1e97 Mon Sep 17 00:00:00 2001 From: Ryan Johnson Date: Fri, 10 Oct 2025 07:48:09 -0700 Subject: [PATCH 5/8] simpler phantoms --- .../src/arrow_to_variant.rs | 39 +++++++------------ 1 file changed, 15 insertions(+), 24 deletions(-) diff --git a/parquet-variant-compute/src/arrow_to_variant.rs b/parquet-variant-compute/src/arrow_to_variant.rs index fb2240458d54..45622b678cef 100644 --- a/parquet-variant-compute/src/arrow_to_variant.rs +++ b/parquet-variant-compute/src/arrow_to_variant.rs @@ -33,7 +33,6 @@ use parquet_variant::{ VariantDecimal16, VariantDecimalType, }; use std::collections::HashMap; -use std::marker::PhantomData; use std::ops::Range; // ============================================================================ @@ -171,24 +170,15 @@ pub(crate) fn make_arrow_to_variant_row_builder<'a>( DataType::Float16 => PrimitiveFloat16(PrimitiveArrowToVariantBuilder::new(array)), DataType::Float32 => PrimitiveFloat32(PrimitiveArrowToVariantBuilder::new(array)), DataType::Float64 => PrimitiveFloat64(PrimitiveArrowToVariantBuilder::new(array)), - DataType::Decimal32(_, s) => Decimal32(DecimalArrowToVariantBuilder::new( - array, - options, - *s, - PhantomData, - )), - DataType::Decimal64(_, s) => Decimal64(DecimalArrowToVariantBuilder::new( - array, - options, - *s, - PhantomData, - )), - DataType::Decimal128(_, s) => Decimal128(DecimalArrowToVariantBuilder::new( - array, - options, - *s, - PhantomData, - )), + DataType::Decimal32(_, s) => { + Decimal32(DecimalArrowToVariantBuilder::new(array, options, *s)) + } + DataType::Decimal64(_, s) => { + Decimal64(DecimalArrowToVariantBuilder::new(array, options, *s)) + } + DataType::Decimal128(_, s) => { + Decimal128(DecimalArrowToVariantBuilder::new(array, options, *s)) + } DataType::Decimal256(_, s) => { Decimal256(Decimal256ArrowToVariantBuilder::new(array, options, *s)) } @@ -329,24 +319,26 @@ macro_rules! define_row_builder { ( struct $name:ident<$lifetime:lifetime $(, $generic:ident $( : $bound:path )? )*> $( where $where_path:path: $where_bound:path $(,)? )? - $({ $($field:ident: $field_type:ty),+ $(,)? })?, + $({ $( $field:ident: $field_type:ty ),+ $(,)? })?, |$array_param:ident| -> $array_type:ty { $init_expr:expr } - $(, |$value:ident| $(-> Option<$option_ty:ty>)? $value_transform:expr)? + $(, |$value:ident| $(-> Option<$option_ty:ty>)? $value_transform:expr )? ) => { pub(crate) struct $name<$lifetime $(, $generic: $( $bound )? )*> $( where $where_path: $where_bound )? { array: &$lifetime $array_type, $( $( $field: $field_type, )+ )? + _phantom: std::marker::PhantomData<($( $generic, )*)>, // capture all type params } impl<$lifetime $(, $generic: $( $bound )? )*> $name<$lifetime $(, $generic)*> $( where $where_path: $where_bound )? { - pub(crate) fn new($array_param: &$lifetime dyn Array $(, $( $field: $field_type ),+ )?) -> Self { + pub(crate) fn new($array_param: &$lifetime dyn Array $( $(, $field: $field_type )+ )?) -> Self { Self { array: $init_expr, $( $( $field, )+ )? + _phantom: std::marker::PhantomData, } } @@ -410,11 +402,10 @@ define_row_builder!( define_row_builder!( struct DecimalArrowToVariantBuilder<'a, A: DecimalType, V> where - V: VariantDecimalType + V: VariantDecimalType, { options: &'a CastOptions, scale: i8, - _phantom: PhantomData, }, |array| -> PrimitiveArray { array.as_primitive() }, |value| -> Option<_> { V::try_new_with_signed_scale(value, *scale).ok() } From be9e2c9d2f48b4ca05cf192d4c288fe787be667a Mon Sep 17 00:00:00 2001 From: Ryan Johnson Date: Fri, 10 Oct 2025 08:00:30 -0700 Subject: [PATCH 6/8] another tweak --- parquet-variant-compute/src/unshred_variant.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/parquet-variant-compute/src/unshred_variant.rs b/parquet-variant-compute/src/unshred_variant.rs index 6e310f493fc7..17247ae02dad 100644 --- a/parquet-variant-compute/src/unshred_variant.rs +++ b/parquet-variant-compute/src/unshred_variant.rs @@ -496,9 +496,8 @@ impl<'a, T: TimestampType> TimestampUnshredRowBuilder<'a, T> { } /// Generic builder for decimal unshredding -struct DecimalUnshredRowBuilder<'a, A, V> +struct DecimalUnshredRowBuilder<'a, A: DecimalType, V> where - A: DecimalType, V: VariantDecimalType, { value: Option<&'a BinaryViewArray>, @@ -507,9 +506,8 @@ where _phantom: PhantomData, } -impl<'a, A, V> DecimalUnshredRowBuilder<'a, A, V> +impl<'a, A: DecimalType, V> DecimalUnshredRowBuilder<'a, A, V> where - A: DecimalType, V: VariantDecimalType, { fn new(value: Option<&'a BinaryViewArray>, typed_value: &'a dyn Array, scale: i8) -> Self { From 0b9c5df16a9699722aa8397a5ab8402dc20f28d5 Mon Sep 17 00:00:00 2001 From: Ryan Johnson Date: Fri, 10 Oct 2025 09:15:40 -0700 Subject: [PATCH 7/8] minimize diff --- parquet-variant-compute/src/arrow_to_variant.rs | 16 ++++++++-------- parquet-variant-compute/src/unshred_variant.rs | 4 ++-- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/parquet-variant-compute/src/arrow_to_variant.rs b/parquet-variant-compute/src/arrow_to_variant.rs index 45622b678cef..5e01aba3c1a1 100644 --- a/parquet-variant-compute/src/arrow_to_variant.rs +++ b/parquet-variant-compute/src/arrow_to_variant.rs @@ -170,17 +170,17 @@ pub(crate) fn make_arrow_to_variant_row_builder<'a>( DataType::Float16 => PrimitiveFloat16(PrimitiveArrowToVariantBuilder::new(array)), DataType::Float32 => PrimitiveFloat32(PrimitiveArrowToVariantBuilder::new(array)), DataType::Float64 => PrimitiveFloat64(PrimitiveArrowToVariantBuilder::new(array)), - DataType::Decimal32(_, s) => { - Decimal32(DecimalArrowToVariantBuilder::new(array, options, *s)) + DataType::Decimal32(_, scale) => { + Decimal32(DecimalArrowToVariantBuilder::new(array, options, *scale)) } - DataType::Decimal64(_, s) => { - Decimal64(DecimalArrowToVariantBuilder::new(array, options, *s)) + DataType::Decimal64(_, scale) => { + Decimal64(DecimalArrowToVariantBuilder::new(array, options, *scale)) } - DataType::Decimal128(_, s) => { - Decimal128(DecimalArrowToVariantBuilder::new(array, options, *s)) + DataType::Decimal128(_, scale) => { + Decimal128(DecimalArrowToVariantBuilder::new(array, options, *scale)) } - DataType::Decimal256(_, s) => { - Decimal256(Decimal256ArrowToVariantBuilder::new(array, options, *s)) + DataType::Decimal256(_, scale) => { + Decimal256(Decimal256ArrowToVariantBuilder::new(array, options, *scale)) } DataType::Timestamp(time_unit, time_zone) => { match time_unit { diff --git a/parquet-variant-compute/src/unshred_variant.rs b/parquet-variant-compute/src/unshred_variant.rs index 17247ae02dad..c20bb697903c 100644 --- a/parquet-variant-compute/src/unshred_variant.rs +++ b/parquet-variant-compute/src/unshred_variant.rs @@ -528,8 +528,8 @@ where handle_unshredded_case!(self, builder, metadata, index, false); let raw = self.typed_value.value(index); - let value = V::try_new_with_signed_scale(raw, self.scale)?; - builder.append_value(value); + let variant = V::try_new_with_signed_scale(raw, self.scale)?; + builder.append_value(variant); Ok(()) } } From a731191d417aaabacaf75628cdd508fd011a551d Mon Sep 17 00:00:00 2001 From: Ryan Johnson Date: Fri, 10 Oct 2025 09:19:31 -0700 Subject: [PATCH 8/8] tiny tweak --- parquet-variant/src/variant/decimal.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parquet-variant/src/variant/decimal.rs b/parquet-variant/src/variant/decimal.rs index e7977b427c6b..0196cdc72aea 100644 --- a/parquet-variant/src/variant/decimal.rs +++ b/parquet-variant/src/variant/decimal.rs @@ -154,7 +154,7 @@ macro_rules! impl_variant_decimal { fn try_new_with_signed_scale(integer: $native, scale: i8) -> Result { let (integer, scale) = if scale < 0 { - let multiplier = <$native>::checked_pow(10, (-scale) as u32); + let multiplier = <$native>::checked_pow(10, -scale as u32); let Some(rescaled) = multiplier.and_then(|m| integer.checked_mul(m)) else { return Err(ArrowError::InvalidArgumentError(format!( "Overflow when rescaling {integer} with scale {scale}"