diff --git a/parquet-variant-compute/src/arrow_to_variant.rs b/parquet-variant-compute/src/arrow_to_variant.rs index fe0c52109052..5e01aba3c1a1 100644 --- a/parquet-variant-compute/src/arrow_to_variant.rs +++ b/parquet-variant-compute/src/arrow_to_variant.rs @@ -15,25 +15,22 @@ // specific language governing permissions and limitations // under the License. -use crate::type_conversion::{CastOptions, decimal_to_variant_decimal}; +use crate::type_conversion::CastOptions; use arrow::array::{ Array, AsArray, FixedSizeListArray, GenericBinaryArray, GenericListArray, GenericListViewArray, GenericStringArray, OffsetSizeTrait, PrimitiveArray, }; use arrow::compute::kernels::cast; use arrow::datatypes::{ - ArrowNativeType, ArrowPrimitiveType, ArrowTemporalType, ArrowTimestampType, Date32Type, - Date64Type, Float16Type, Float32Type, Float64Type, Int8Type, Int16Type, Int32Type, Int64Type, - RunEndIndexType, Time32MillisecondType, Time32SecondType, Time64MicrosecondType, - Time64NanosecondType, TimestampMicrosecondType, TimestampMillisecondType, - TimestampNanosecondType, TimestampSecondType, UInt8Type, UInt16Type, UInt32Type, UInt64Type, + self as datatypes, ArrowNativeType, ArrowPrimitiveType, ArrowTemporalType, ArrowTimestampType, + DecimalType, RunEndIndexType, }; use arrow::temporal_conversions::{as_date, as_datetime, as_time}; use arrow_schema::{ArrowError, DataType, TimeUnit}; use chrono::{DateTime, TimeZone, Utc}; use parquet_variant::{ ObjectFieldBuilder, Variant, VariantBuilderExt, VariantDecimal4, VariantDecimal8, - VariantDecimal16, + VariantDecimal16, VariantDecimalType, }; use std::collections::HashMap; use std::ops::Range; @@ -46,31 +43,31 @@ use std::ops::Range; pub(crate) enum ArrowToVariantRowBuilder<'a> { Null(NullArrowToVariantBuilder), Boolean(BooleanArrowToVariantBuilder<'a>), - PrimitiveInt8(PrimitiveArrowToVariantBuilder<'a, Int8Type>), - PrimitiveInt16(PrimitiveArrowToVariantBuilder<'a, Int16Type>), - PrimitiveInt32(PrimitiveArrowToVariantBuilder<'a, 
Int32Type>), - PrimitiveInt64(PrimitiveArrowToVariantBuilder<'a, Int64Type>), - PrimitiveUInt8(PrimitiveArrowToVariantBuilder<'a, UInt8Type>), - PrimitiveUInt16(PrimitiveArrowToVariantBuilder<'a, UInt16Type>), - PrimitiveUInt32(PrimitiveArrowToVariantBuilder<'a, UInt32Type>), - PrimitiveUInt64(PrimitiveArrowToVariantBuilder<'a, UInt64Type>), - PrimitiveFloat16(PrimitiveArrowToVariantBuilder<'a, Float16Type>), - PrimitiveFloat32(PrimitiveArrowToVariantBuilder<'a, Float32Type>), - PrimitiveFloat64(PrimitiveArrowToVariantBuilder<'a, Float64Type>), - Decimal32(Decimal32ArrowToVariantBuilder<'a>), - Decimal64(Decimal64ArrowToVariantBuilder<'a>), - Decimal128(Decimal128ArrowToVariantBuilder<'a>), + PrimitiveInt8(PrimitiveArrowToVariantBuilder<'a, datatypes::Int8Type>), + PrimitiveInt16(PrimitiveArrowToVariantBuilder<'a, datatypes::Int16Type>), + PrimitiveInt32(PrimitiveArrowToVariantBuilder<'a, datatypes::Int32Type>), + PrimitiveInt64(PrimitiveArrowToVariantBuilder<'a, datatypes::Int64Type>), + PrimitiveUInt8(PrimitiveArrowToVariantBuilder<'a, datatypes::UInt8Type>), + PrimitiveUInt16(PrimitiveArrowToVariantBuilder<'a, datatypes::UInt16Type>), + PrimitiveUInt32(PrimitiveArrowToVariantBuilder<'a, datatypes::UInt32Type>), + PrimitiveUInt64(PrimitiveArrowToVariantBuilder<'a, datatypes::UInt64Type>), + PrimitiveFloat16(PrimitiveArrowToVariantBuilder<'a, datatypes::Float16Type>), + PrimitiveFloat32(PrimitiveArrowToVariantBuilder<'a, datatypes::Float32Type>), + PrimitiveFloat64(PrimitiveArrowToVariantBuilder<'a, datatypes::Float64Type>), + Decimal32(DecimalArrowToVariantBuilder<'a, datatypes::Decimal32Type, VariantDecimal4>), + Decimal64(DecimalArrowToVariantBuilder<'a, datatypes::Decimal64Type, VariantDecimal8>), + Decimal128(DecimalArrowToVariantBuilder<'a, datatypes::Decimal128Type, VariantDecimal16>), Decimal256(Decimal256ArrowToVariantBuilder<'a>), - TimestampSecond(TimestampArrowToVariantBuilder<'a, TimestampSecondType>), - 
TimestampMillisecond(TimestampArrowToVariantBuilder<'a, TimestampMillisecondType>), - TimestampMicrosecond(TimestampArrowToVariantBuilder<'a, TimestampMicrosecondType>), - TimestampNanosecond(TimestampArrowToVariantBuilder<'a, TimestampNanosecondType>), - Date32(DateArrowToVariantBuilder<'a, Date32Type>), - Date64(DateArrowToVariantBuilder<'a, Date64Type>), - Time32Second(TimeArrowToVariantBuilder<'a, Time32SecondType>), - Time32Millisecond(TimeArrowToVariantBuilder<'a, Time32MillisecondType>), - Time64Microsecond(TimeArrowToVariantBuilder<'a, Time64MicrosecondType>), - Time64Nanosecond(TimeArrowToVariantBuilder<'a, Time64NanosecondType>), + TimestampSecond(TimestampArrowToVariantBuilder<'a, datatypes::TimestampSecondType>), + TimestampMillisecond(TimestampArrowToVariantBuilder<'a, datatypes::TimestampMillisecondType>), + TimestampMicrosecond(TimestampArrowToVariantBuilder<'a, datatypes::TimestampMicrosecondType>), + TimestampNanosecond(TimestampArrowToVariantBuilder<'a, datatypes::TimestampNanosecondType>), + Date32(DateArrowToVariantBuilder<'a, datatypes::Date32Type>), + Date64(DateArrowToVariantBuilder<'a, datatypes::Date64Type>), + Time32Second(TimeArrowToVariantBuilder<'a, datatypes::Time32SecondType>), + Time32Millisecond(TimeArrowToVariantBuilder<'a, datatypes::Time32MillisecondType>), + Time64Microsecond(TimeArrowToVariantBuilder<'a, datatypes::Time64MicrosecondType>), + Time64Nanosecond(TimeArrowToVariantBuilder<'a, datatypes::Time64NanosecondType>), Binary(BinaryArrowToVariantBuilder<'a, i32>), LargeBinary(BinaryArrowToVariantBuilder<'a, i64>), BinaryView(BinaryViewArrowToVariantBuilder<'a>), @@ -87,9 +84,9 @@ pub(crate) enum ArrowToVariantRowBuilder<'a> { Map(MapArrowToVariantBuilder<'a>), Union(UnionArrowToVariantBuilder<'a>), Dictionary(DictionaryArrowToVariantBuilder<'a>), - RunEndEncodedInt16(RunEndEncodedArrowToVariantBuilder<'a, Int16Type>), - RunEndEncodedInt32(RunEndEncodedArrowToVariantBuilder<'a, Int32Type>), - 
RunEndEncodedInt64(RunEndEncodedArrowToVariantBuilder<'a, Int64Type>), + RunEndEncodedInt16(RunEndEncodedArrowToVariantBuilder<'a, datatypes::Int16Type>), + RunEndEncodedInt32(RunEndEncodedArrowToVariantBuilder<'a, datatypes::Int32Type>), + RunEndEncodedInt64(RunEndEncodedArrowToVariantBuilder<'a, datatypes::Int64Type>), } impl<'a> ArrowToVariantRowBuilder<'a> { @@ -174,13 +171,13 @@ pub(crate) fn make_arrow_to_variant_row_builder<'a>( DataType::Float32 => PrimitiveFloat32(PrimitiveArrowToVariantBuilder::new(array)), DataType::Float64 => PrimitiveFloat64(PrimitiveArrowToVariantBuilder::new(array)), DataType::Decimal32(_, scale) => { - Decimal32(Decimal32ArrowToVariantBuilder::new(array, options, *scale)) + Decimal32(DecimalArrowToVariantBuilder::new(array, options, *scale)) } DataType::Decimal64(_, scale) => { - Decimal64(Decimal64ArrowToVariantBuilder::new(array, options, *scale)) + Decimal64(DecimalArrowToVariantBuilder::new(array, options, *scale)) } DataType::Decimal128(_, scale) => { - Decimal128(Decimal128ArrowToVariantBuilder::new(array, options, *scale)) + Decimal128(DecimalArrowToVariantBuilder::new(array, options, *scale)) } DataType::Decimal256(_, scale) => { Decimal256(Decimal256ArrowToVariantBuilder::new(array, options, *scale)) @@ -320,26 +317,28 @@ pub(crate) fn make_arrow_to_variant_row_builder<'a>( // worth the trouble, tho, because it makes for some pretty bulky and unwieldy macro expansions. macro_rules! define_row_builder { ( - struct $name:ident<$lifetime:lifetime $(, $generic:ident: $bound:path )?> + struct $name:ident<$lifetime:lifetime $(, $generic:ident $( : $bound:path )? )*> $( where $where_path:path: $where_bound:path $(,)? )? - $({ $($field:ident: $field_type:ty),+ $(,)? })?, + $({ $( $field:ident: $field_type:ty ),+ $(,)? })?, |$array_param:ident| -> $array_type:ty { $init_expr:expr } - $(, |$value:ident| $(-> Option<$option_ty:ty>)? $value_transform:expr)? + $(, |$value:ident| $(-> Option<$option_ty:ty>)? $value_transform:expr )? 
) => { - pub(crate) struct $name<$lifetime $(, $generic: $bound )?> + pub(crate) struct $name<$lifetime $(, $generic: $( $bound )? )*> $( where $where_path: $where_bound )? { array: &$lifetime $array_type, $( $( $field: $field_type, )+ )? + _phantom: std::marker::PhantomData<($( $generic, )*)>, // capture all type params } - impl<$lifetime $(, $generic: $bound+ )?> $name<$lifetime $(, $generic)?> + impl<$lifetime $(, $generic: $( $bound )? )*> $name<$lifetime $(, $generic)*> $( where $where_path: $where_bound )? { - pub(crate) fn new($array_param: &$lifetime dyn Array $(, $( $field: $field_type ),+ )?) -> Self { + pub(crate) fn new($array_param: &$lifetime dyn Array $( $(, $field: $field_type )+ )?) -> Self { Self { array: $init_expr, $( $( $field, )+ )? + _phantom: std::marker::PhantomData, } } @@ -401,32 +400,18 @@ define_row_builder!( ); define_row_builder!( - struct Decimal32ArrowToVariantBuilder<'a> { - options: &'a CastOptions, - scale: i8, - }, - |array| -> arrow::array::Decimal32Array { array.as_primitive() }, - |value| -> Option<_> { decimal_to_variant_decimal!(value, scale, i32, VariantDecimal4) } -); - -define_row_builder!( - struct Decimal64ArrowToVariantBuilder<'a> { - options: &'a CastOptions, - scale: i8, - }, - |array| -> arrow::array::Decimal64Array { array.as_primitive() }, - |value| -> Option<_> { decimal_to_variant_decimal!(value, scale, i64, VariantDecimal8) } -); - -define_row_builder!( - struct Decimal128ArrowToVariantBuilder<'a> { + struct DecimalArrowToVariantBuilder<'a, A: DecimalType, V> + where + V: VariantDecimalType, + { options: &'a CastOptions, scale: i8, }, - |array| -> arrow::array::Decimal128Array { array.as_primitive() }, - |value| -> Option<_> { decimal_to_variant_decimal!(value, scale, i128, VariantDecimal16) } + |array| -> PrimitiveArray { array.as_primitive() }, + |value| -> Option<_> { V::try_new_with_signed_scale(value, *scale).ok() } ); +// Decimal256 needs a two-stage conversion via i128 define_row_builder!( struct 
Decimal256ArrowToVariantBuilder<'a> { options: &'a CastOptions, @@ -434,10 +419,8 @@ define_row_builder!( }, |array| -> arrow::array::Decimal256Array { array.as_primitive() }, |value| -> Option<_> { - // Decimal256 needs special handling - convert to i128 if possible - value.to_i128().and_then(|i128_val| { - decimal_to_variant_decimal!(i128_val, scale, i128, VariantDecimal16) - }) + let value = value.to_i128(); + value.and_then(|v| VariantDecimal16::try_new_with_signed_scale(v, *scale).ok()) } ); @@ -911,6 +894,7 @@ mod tests { use super::*; use crate::{VariantArray, VariantArrayBuilder}; use arrow::array::{ArrayRef, BooleanArray, Int32Array, StringArray}; + use arrow::datatypes::Int32Type; use std::sync::Arc; /// Builds a VariantArray from an Arrow array using the row builder. diff --git a/parquet-variant-compute/src/type_conversion.rs b/parquet-variant-compute/src/type_conversion.rs index 5afebb1bfa6b..7851ccc735db 100644 --- a/parquet-variant-compute/src/type_conversion.rs +++ b/parquet-variant-compute/src/type_conversion.rs @@ -150,22 +150,3 @@ macro_rules! primitive_conversion_single_value { }}; } pub(crate) use primitive_conversion_single_value; - -/// Convert a decimal value to a `VariantDecimal` -macro_rules! 
decimal_to_variant_decimal { - ($v:ident, $scale:expr, $value_type:ty, $variant_type:ty) => {{ - let (v, scale) = if *$scale < 0 { - // For negative scale, we need to multiply the value by 10^|scale| - // For example: 123 with scale -2 becomes 12300 with scale 0 - let multiplier = <$value_type>::pow(10, (-*$scale) as u32); - (<$value_type>::checked_mul($v, multiplier), 0u8) - } else { - (Some($v), *$scale as u8) - }; - - // Return an Option to allow callers to decide whether to error (strict) - // or append null (non-strict) on conversion failure - v.and_then(|v| <$variant_type>::try_new(v, scale).ok()) - }}; -} -pub(crate) use decimal_to_variant_decimal; diff --git a/parquet-variant-compute/src/unshred_variant.rs b/parquet-variant-compute/src/unshred_variant.rs index 64eaa46ed06b..c20bb697903c 100644 --- a/parquet-variant-compute/src/unshred_variant.rs +++ b/parquet-variant-compute/src/unshred_variant.rs @@ -35,8 +35,9 @@ use chrono::{DateTime, Utc}; use indexmap::IndexMap; use parquet_variant::{ ObjectFieldBuilder, Variant, VariantBuilderExt, VariantDecimal4, VariantDecimal8, - VariantDecimal16, VariantMetadata, + VariantDecimal16, VariantDecimalType, VariantMetadata, }; +use std::marker::PhantomData; use uuid::Uuid; /// Removes all (nested) typed_value columns from a VariantArray by converting them back to binary @@ -95,9 +96,9 @@ enum UnshredVariantRowBuilder<'a> { PrimitiveInt64(UnshredPrimitiveRowBuilder<'a, PrimitiveArray>), PrimitiveFloat32(UnshredPrimitiveRowBuilder<'a, PrimitiveArray>), PrimitiveFloat64(UnshredPrimitiveRowBuilder<'a, PrimitiveArray>), - Decimal32(DecimalUnshredRowBuilder<'a, Decimal32Spec>), - Decimal64(DecimalUnshredRowBuilder<'a, Decimal64Spec>), - Decimal128(DecimalUnshredRowBuilder<'a, Decimal128Spec>), + Decimal32(DecimalUnshredRowBuilder<'a, Decimal32Type, VariantDecimal4>), + Decimal64(DecimalUnshredRowBuilder<'a, Decimal64Type, VariantDecimal8>), + Decimal128(DecimalUnshredRowBuilder<'a, Decimal128Type, VariantDecimal16>), 
PrimitiveDate32(UnshredPrimitiveRowBuilder<'a, PrimitiveArray>), PrimitiveTime64(UnshredPrimitiveRowBuilder<'a, PrimitiveArray>), TimestampMicrosecond(TimestampUnshredRowBuilder<'a, TimestampMicrosecondType>), @@ -185,25 +186,23 @@ impl<'a> UnshredVariantRowBuilder<'a> { DataType::Int64 => primitive_builder!(PrimitiveInt64, as_primitive), DataType::Float32 => primitive_builder!(PrimitiveFloat32, as_primitive), DataType::Float64 => primitive_builder!(PrimitiveFloat64, as_primitive), - DataType::Decimal32(_, scale) => Self::Decimal32(DecimalUnshredRowBuilder::new( - value, - typed_value.as_primitive(), - *scale, - )), - DataType::Decimal64(_, scale) => Self::Decimal64(DecimalUnshredRowBuilder::new( - value, - typed_value.as_primitive(), - *scale, - )), - DataType::Decimal128(_, scale) => Self::Decimal128(DecimalUnshredRowBuilder::new( - value, - typed_value.as_primitive(), - *scale, - )), - DataType::Decimal256(_, _) => { - return Err(ArrowError::InvalidArgumentError( - "Decimal256 is not a valid variant shredding type".to_string(), - )); + DataType::Decimal32(p, s) if VariantDecimal4::is_valid_precision_and_scale(p, s) => { + Self::Decimal32(DecimalUnshredRowBuilder::new(value, typed_value, *s as _)) + } + DataType::Decimal64(p, s) if VariantDecimal8::is_valid_precision_and_scale(p, s) => { + Self::Decimal64(DecimalUnshredRowBuilder::new(value, typed_value, *s as _)) + } + DataType::Decimal128(p, s) if VariantDecimal16::is_valid_precision_and_scale(p, s) => { + Self::Decimal128(DecimalUnshredRowBuilder::new(value, typed_value, *s as _)) + } + DataType::Decimal32(_, _) + | DataType::Decimal64(_, _) + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) => { + return Err(ArrowError::InvalidArgumentError(format!( + "{} is not a valid variant shredding type", + typed_value.data_type() + ))); } DataType::Date32 => primitive_builder!(PrimitiveDate32, as_primitive), DataType::Time64(TimeUnit::Microsecond) => { @@ -214,20 +213,12 @@ impl<'a> 
UnshredVariantRowBuilder<'a> { "Time64({time_unit}) is not a valid variant shredding type", ))); } - DataType::Timestamp(TimeUnit::Microsecond, timezone) => { - Self::TimestampMicrosecond(TimestampUnshredRowBuilder::new( - value, - typed_value.as_primitive(), - timezone.is_some(), - )) - } - DataType::Timestamp(TimeUnit::Nanosecond, timezone) => { - Self::TimestampNanosecond(TimestampUnshredRowBuilder::new( - value, - typed_value.as_primitive(), - timezone.is_some(), - )) - } + DataType::Timestamp(TimeUnit::Microsecond, timezone) => Self::TimestampMicrosecond( + TimestampUnshredRowBuilder::new(value, typed_value, timezone.is_some()), + ), + DataType::Timestamp(TimeUnit::Nanosecond, timezone) => Self::TimestampNanosecond( + TimestampUnshredRowBuilder::new(value, typed_value, timezone.is_some()), + ), DataType::Timestamp(time_unit, _) => { return Err(ArrowError::InvalidArgumentError(format!( "Timestamp({time_unit}) is not a valid variant shredding type", @@ -474,12 +465,12 @@ struct TimestampUnshredRowBuilder<'a, T: TimestampType> { impl<'a, T: TimestampType> TimestampUnshredRowBuilder<'a, T> { fn new( value: Option<&'a BinaryViewArray>, - typed_value: &'a PrimitiveArray, + typed_value: &'a dyn Array, has_timezone: bool, ) -> Self { Self { value, - typed_value, + typed_value: typed_value.as_primitive(), has_timezone, } } @@ -504,78 +495,27 @@ impl<'a, T: TimestampType> TimestampUnshredRowBuilder<'a, T> { } } -/// Trait to unify decimal unshredding across Decimal32/64/128 types -trait DecimalSpec { - type Arrow: ArrowPrimitiveType + DecimalType; - - fn into_variant( - raw: ::Native, - scale: i8, - ) -> Result>; -} - -/// Spec for Decimal32 -> VariantDecimal4 -struct Decimal32Spec; - -impl DecimalSpec for Decimal32Spec { - type Arrow = Decimal32Type; - - fn into_variant(raw: i32, scale: i8) -> Result> { - let scale = - u8::try_from(scale).map_err(|e| ArrowError::InvalidArgumentError(e.to_string()))?; - let value = VariantDecimal4::try_new(raw, scale) - .map_err(|e| 
ArrowError::InvalidArgumentError(e.to_string()))?; - Ok(value.into()) - } -} - -/// Spec for Decimal64 -> VariantDecimal8 -struct Decimal64Spec; - -impl DecimalSpec for Decimal64Spec { - type Arrow = Decimal64Type; - - fn into_variant(raw: i64, scale: i8) -> Result> { - let scale = - u8::try_from(scale).map_err(|e| ArrowError::InvalidArgumentError(e.to_string()))?; - let value = VariantDecimal8::try_new(raw, scale) - .map_err(|e| ArrowError::InvalidArgumentError(e.to_string()))?; - Ok(value.into()) - } -} - -/// Spec for Decimal128 -> VariantDecimal16 -struct Decimal128Spec; - -impl DecimalSpec for Decimal128Spec { - type Arrow = Decimal128Type; - - fn into_variant(raw: i128, scale: i8) -> Result> { - let scale = - u8::try_from(scale).map_err(|e| ArrowError::InvalidArgumentError(e.to_string()))?; - let value = VariantDecimal16::try_new(raw, scale) - .map_err(|e| ArrowError::InvalidArgumentError(e.to_string()))?; - Ok(value.into()) - } -} - -/// Generic builder for decimal unshredding that caches scale -struct DecimalUnshredRowBuilder<'a, S: DecimalSpec> { +/// Generic builder for decimal unshredding +struct DecimalUnshredRowBuilder<'a, A: DecimalType, V> +where + V: VariantDecimalType, +{ value: Option<&'a BinaryViewArray>, - typed_value: &'a PrimitiveArray, + typed_value: &'a PrimitiveArray, scale: i8, + _phantom: PhantomData, } -impl<'a, S: DecimalSpec> DecimalUnshredRowBuilder<'a, S> { - fn new( - value: Option<&'a BinaryViewArray>, - typed_value: &'a PrimitiveArray, - scale: i8, - ) -> Self { +impl<'a, A: DecimalType, V> DecimalUnshredRowBuilder<'a, A, V> +where + V: VariantDecimalType, +{ + fn new(value: Option<&'a BinaryViewArray>, typed_value: &'a dyn Array, scale: i8) -> Self { Self { value, - typed_value, + typed_value: typed_value.as_primitive(), scale, + _phantom: PhantomData, } } @@ -588,7 +528,7 @@ impl<'a, S: DecimalSpec> DecimalUnshredRowBuilder<'a, S> { handle_unshredded_case!(self, builder, metadata, index, false); let raw = 
self.typed_value.value(index); - let variant = S::into_variant(raw, self.scale)?; + let variant = V::try_new_with_signed_scale(raw, self.scale)?; builder.append_value(variant); Ok(()) } diff --git a/parquet-variant-compute/src/variant_array.rs b/parquet-variant-compute/src/variant_array.rs index 5686d102d3fd..522c5a7546b5 100644 --- a/parquet-variant-compute/src/variant_array.rs +++ b/parquet-variant-compute/src/variant_array.rs @@ -26,13 +26,11 @@ use arrow::datatypes::{ TimestampMicrosecondType, TimestampNanosecondType, }; use arrow_schema::extension::ExtensionType; -use arrow_schema::{ - ArrowError, DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION, DECIMAL128_MAX_PRECISION, - DataType, Field, FieldRef, Fields, TimeUnit, -}; +use arrow_schema::{ArrowError, DataType, Field, FieldRef, Fields, TimeUnit}; use chrono::DateTime; -use parquet_variant::Uuid; -use parquet_variant::Variant; +use parquet_variant::{ + Uuid, Variant, VariantDecimal4, VariantDecimal8, VariantDecimal16, VariantDecimalType as _, +}; use std::borrow::Cow; use std::sync::Arc; @@ -937,18 +935,6 @@ fn cast_to_binary_view_arrays(array: &dyn Array) -> Result cast(array, new_type.as_ref()) } -/// Validates whether a given arrow decimal is a valid variant decimal -/// -/// NOTE: By a strict reading of the "decimal table" in the [shredding spec], each decimal type -/// should have a width-dependent lower bound on precision as well as an upper bound (i.e. Decimal16 -/// with precision 5 is invalid because Decimal4 "covers" it). But the variant shredding integration -/// tests specifically expect such cases to succeed, so we only enforce the upper bound here. 
-/// -/// [shredding spec]: https://github.com/apache/parquet-format/blob/master/VariantEncoding.md#encoding-types -fn is_valid_variant_decimal(p: &u8, s: &i8, max_precision: u8) -> bool { - (1..=max_precision).contains(p) && (0..=*p as i8).contains(s) -} - /// Recursively visits a data type, ensuring that it only contains data types that can legally /// appear in a (possibly shredded) variant array. It also replaces Binary fields with BinaryView, /// since that's what comes back from the parquet reader and what the variant code expects to find. @@ -984,16 +970,16 @@ fn canonicalize_and_verify_data_type( // NOTE: arrow-parquet reads widens 32- and 64-bit decimals to 128-bit, but the variant spec // requires using the narrowest decimal type for a given precision. Fix those up first. Decimal64(p, s) | Decimal128(p, s) - if is_valid_variant_decimal(p, s, DECIMAL32_MAX_PRECISION) => + if VariantDecimal4::is_valid_precision_and_scale(p, s) => { Cow::Owned(Decimal32(*p, *s)) } - Decimal128(p, s) if is_valid_variant_decimal(p, s, DECIMAL64_MAX_PRECISION) => { + Decimal128(p, s) if VariantDecimal8::is_valid_precision_and_scale(p, s) => { Cow::Owned(Decimal64(*p, *s)) } - Decimal32(p, s) if is_valid_variant_decimal(p, s, DECIMAL32_MAX_PRECISION) => borrow!(), - Decimal64(p, s) if is_valid_variant_decimal(p, s, DECIMAL64_MAX_PRECISION) => borrow!(), - Decimal128(p, s) if is_valid_variant_decimal(p, s, DECIMAL128_MAX_PRECISION) => borrow!(), + Decimal32(p, s) if VariantDecimal4::is_valid_precision_and_scale(p, s) => borrow!(), + Decimal64(p, s) if VariantDecimal8::is_valid_precision_and_scale(p, s) => borrow!(), + Decimal128(p, s) if VariantDecimal16::is_valid_precision_and_scale(p, s) => borrow!(), Decimal32(..) | Decimal64(..) | Decimal128(..) | Decimal256(..) 
=> fail!(), // Only micro and nano timestamps are allowed diff --git a/parquet-variant/src/variant.rs b/parquet-variant/src/variant.rs index aa3eb51ed32b..819c20d554ce 100644 --- a/parquet-variant/src/variant.rs +++ b/parquet-variant/src/variant.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -pub use self::decimal::{VariantDecimal4, VariantDecimal8, VariantDecimal16}; +pub use self::decimal::{VariantDecimal4, VariantDecimal8, VariantDecimal16, VariantDecimalType}; pub use self::list::VariantList; pub use self::metadata::{EMPTY_VARIANT_METADATA, EMPTY_VARIANT_METADATA_BYTES, VariantMetadata}; pub use self::object::VariantObject; diff --git a/parquet-variant/src/variant/decimal.rs b/parquet-variant/src/variant/decimal.rs index b0b7d36ed161..c7849a381af9 100644 --- a/parquet-variant/src/variant/decimal.rs +++ b/parquet-variant/src/variant/decimal.rs @@ -17,52 +17,188 @@ use arrow_schema::ArrowError; use std::fmt; -// All decimal types use the same try_new implementation -macro_rules! decimal_try_new { - ($integer:ident, $scale:ident) => {{ - // Validate that scale doesn't exceed precision - if $scale > Self::MAX_PRECISION { - return Err(ArrowError::InvalidArgumentError(format!( - "Scale {} is larger than max precision {}", - $scale, - Self::MAX_PRECISION, - ))); - } +/// Trait for variant decimal types, enabling generic code across Decimal4/8/16 +/// +/// This trait provides a common interface for the three variant decimal types, +/// allowing generic functions and data structures to work with any decimal width. +/// It is modeled after Arrow's `DecimalType` trait but adapted for variant semantics. 
+/// +/// # Example +/// +/// ``` +/// # use parquet_variant::{VariantDecimal4, VariantDecimal8, VariantDecimalType}; +/// # +/// fn extract_scale(decimal: D) -> u8 { +/// decimal.scale() +/// } +/// +/// let dec4 = VariantDecimal4::try_new(12345, 2).unwrap(); +/// let dec8 = VariantDecimal8::try_new(67890, 3).unwrap(); +/// +/// assert_eq!(extract_scale(dec4), 2); +/// assert_eq!(extract_scale(dec8), 3); +/// ``` +pub trait VariantDecimalType: Into> { + /// The underlying signed integer type (i32, i64, or i128) + type Native; - // Validate that the integer value fits within the precision - if $integer.unsigned_abs() > Self::MAX_UNSCALED_VALUE { - return Err(ArrowError::InvalidArgumentError(format!( - "{} is wider than max precision {}", - $integer, - Self::MAX_PRECISION - ))); - } + /// Maximum number of significant digits this decimal type can represent (9, 18, or 38) + const MAX_PRECISION: u8; + /// The largest positive unscaled value that fits in [`Self::MAX_PRECISION`] digits. + const MAX_UNSCALED_VALUE: Self::Native; - Ok(Self { $integer, $scale }) - }}; + /// True if the given precision and scale are valid for this variant decimal type. + /// + /// NOTE: By a strict reading of the "decimal table" in the [variant spec], one might conclude that + /// each decimal type has both lower and upper bounds on precision (i.e. Decimal16 with precision 5 + /// is invalid because Decimal4 "covers" it). But the variant shredding integration tests + /// specifically expect such cases to succeed, so we only enforce the upper bound here. 
+ /// + /// [variant spec]: https://github.com/apache/parquet-format/blob/master/VariantEncoding.md#encoding-types + /// + /// # Example + /// ``` + /// # use parquet_variant::{VariantDecimal4, VariantDecimalType}; + /// # + /// assert!(VariantDecimal4::is_valid_precision_and_scale(&5, &2)); + /// assert!(!VariantDecimal4::is_valid_precision_and_scale(&10, &2)); // too wide + /// assert!(!VariantDecimal4::is_valid_precision_and_scale(&5, &-1)); // negative scale + /// assert!(!VariantDecimal4::is_valid_precision_and_scale(&5, &7)); // scale too big + /// ``` + fn is_valid_precision_and_scale(precision: &u8, scale: &i8) -> bool { + (1..=Self::MAX_PRECISION).contains(precision) && (0..=*precision as i8).contains(scale) + } + + /// Creates a new decimal value from the given unscaled integer and scale, failing if the + /// integer's width, or the requested scale, exceeds `MAX_PRECISION`. + /// + /// NOTE: `scale` is unsigned here. For compatibility with arrow decimal types, negative scale + /// is accepted by [`Self::try_new_with_signed_scale`] if the rescaled value fits the precision. + /// + /// # Example + /// + /// ``` + /// # use parquet_variant::{VariantDecimal4, VariantDecimalType}; + /// # + /// // Valid: 123.45 (5 digits, scale 2) + /// let d = VariantDecimal4::try_new(12345, 2).unwrap(); + /// assert_eq!(d.integer(), 12345); + /// assert_eq!(d.scale(), 2); + /// + /// VariantDecimal4::try_new(123, 10).expect_err("scale exceeds MAX_PRECISION"); + /// VariantDecimal4::try_new(1234567890, 10).expect_err("value's width exceeds MAX_PRECISION"); + /// ``` + fn try_new(integer: Self::Native, scale: u8) -> Result; + + /// Attempts to convert an unscaled arrow decimal value to the indicated variant decimal type. + /// + /// Unlike [`Self::try_new`], this function accepts a signed scale, and attempts to rescale + /// negative-scale values to their equivalent (larger) scale-0 values. For example, a decimal + /// value of 123 with scale -2 becomes 12300 with scale 0.
+ /// + /// Fails if rescaling fails, or for any of the reasons [`Self::try_new`] could fail. + fn try_new_with_signed_scale(integer: Self::Native, scale: i8) -> Result; + + /// Returns the unscaled integer value + fn integer(&self) -> Self::Native; + + /// Returns the scale (number of digits after the decimal point) + fn scale(&self) -> u8; } -// All decimal values format the same way, using integer arithmetic to avoid floating point precision loss -macro_rules! format_decimal { - ($f:expr, $integer:expr, $scale:expr, $int_type:ty) => {{ - let integer = if $scale == 0 { - $integer - } else { - let divisor = <$int_type>::pow(10, $scale as u32); - let remainder = $integer % divisor; - if remainder != 0 { - // Track the sign explicitly, in case the quotient is zero - let sign = if $integer < 0 { "-" } else { "" }; - // Format an unsigned remainder with leading zeros and strip (unnecessary) trailing zeros. - let remainder = format!("{:0width$}", remainder.abs(), width = $scale as usize); - let remainder = remainder.trim_end_matches('0'); - let quotient = $integer / divisor; - return write!($f, "{}{}.{}", sign, quotient.abs(), remainder); +/// Implements the complete variant decimal type: methods, Display, and VariantDecimalType trait +macro_rules! impl_variant_decimal { + ($struct_name:ident, $native:ty) => { + impl $struct_name { + /// Attempts to create a new instance of this decimal type, failing if the value is too + /// wide or the scale is too large. 
+ pub fn try_new(integer: $native, scale: u8) -> Result { + let max_precision = Self::MAX_PRECISION; + if scale > max_precision { + return Err(ArrowError::InvalidArgumentError(format!( + "Scale {scale} is larger than max precision {max_precision}", + ))); + } + if !(-Self::MAX_UNSCALED_VALUE..=Self::MAX_UNSCALED_VALUE).contains(&integer) { + return Err(ArrowError::InvalidArgumentError(format!( + "{integer} is wider than max precision {max_precision}", + ))); + } + + Ok(Self { integer, scale }) + } + + /// Returns the unscaled integer value of the decimal. + /// + /// For example, if the decimal is `123.45`, this will return `12345`. + pub fn integer(&self) -> $native { + self.integer + } + + /// Returns the scale of the decimal (how many digits after the decimal point). + /// + /// For example, if the decimal is `123.45`, this will return `2`. + pub fn scale(&self) -> u8 { + self.scale + } + } + + impl VariantDecimalType for $struct_name { + type Native = $native; + const MAX_PRECISION: u8 = Self::MAX_PRECISION; + const MAX_UNSCALED_VALUE: $native = <$native>::pow(10, Self::MAX_PRECISION as u32) - 1; + + fn try_new(integer: $native, scale: u8) -> Result { + Self::try_new(integer, scale) + } + + fn try_new_with_signed_scale(integer: $native, scale: i8) -> Result { + let (integer, scale) = if scale < 0 { + let multiplier = <$native>::checked_pow(10, -scale as u32); + let Some(rescaled) = multiplier.and_then(|m| integer.checked_mul(m)) else { + return Err(ArrowError::InvalidArgumentError(format!( + "Overflow when rescaling {integer} with scale {scale}" + ))); + }; + (rescaled, 0u8) + } else { + (integer, scale as u8) + }; + Self::try_new(integer, scale) } - $integer / divisor - }; - write!($f, "{}", integer) - }}; + + fn integer(&self) -> $native { + self.integer() + } + + fn scale(&self) -> u8 { + self.scale() + } + } + + impl fmt::Display for $struct_name { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let integer = if self.scale == 0 { + 
self.integer + } else { + let divisor = <$native>::pow(10, self.scale as u32); + let remainder = self.integer % divisor; + if remainder != 0 { + // Track the sign explicitly, in case the quotient is zero + let sign = if self.integer < 0 { "-" } else { "" }; + // Format an unsigned remainder with leading zeros and strip trailing zeros + let remainder = + format!("{:0width$}", remainder.abs(), width = self.scale as usize); + let remainder = remainder.trim_end_matches('0'); + let quotient = (self.integer / divisor).abs(); + return write!(f, "{sign}{quotient}.{remainder}"); + } + self.integer / divisor + }; + write!(f, "{integer}") + } + } + }; } /// Represents a 4-byte decimal value in the Variant format. @@ -86,33 +222,11 @@ pub struct VariantDecimal4 { } impl VariantDecimal4 { - pub(crate) const MAX_PRECISION: u8 = 9; - pub(crate) const MAX_UNSCALED_VALUE: u32 = u32::pow(10, Self::MAX_PRECISION as u32) - 1; - - pub fn try_new(integer: i32, scale: u8) -> Result { - decimal_try_new!(integer, scale) - } - - /// Returns the underlying value of the decimal. - /// - /// For example, if the decimal is `123.4567`, this will return `1234567`. - pub fn integer(&self) -> i32 { - self.integer - } - - /// Returns the scale of the decimal (how many digits after the decimal point). - /// - /// For example, if the decimal is `123.4567`, this will return `4`. - pub fn scale(&self) -> u8 { - self.scale - } + /// Maximum number of significant digits (9 for 4-byte decimals) + pub const MAX_PRECISION: u8 = arrow_schema::DECIMAL32_MAX_PRECISION; } -impl fmt::Display for VariantDecimal4 { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - format_decimal!(f, self.integer, self.scale, i32) - } -} +impl_variant_decimal!(VariantDecimal4, i32); /// Represents an 8-byte decimal value in the Variant format. 
/// @@ -136,33 +250,11 @@ pub struct VariantDecimal8 { } impl VariantDecimal8 { - pub(crate) const MAX_PRECISION: u8 = 18; - pub(crate) const MAX_UNSCALED_VALUE: u64 = u64::pow(10, Self::MAX_PRECISION as u32) - 1; - - pub fn try_new(integer: i64, scale: u8) -> Result { - decimal_try_new!(integer, scale) - } - - /// Returns the underlying value of the decimal. - /// - /// For example, if the decimal is `123456.78`, this will return `12345678`. - pub fn integer(&self) -> i64 { - self.integer - } - - /// Returns the scale of the decimal (how many digits after the decimal point). - /// - /// For example, if the decimal is `123456.78`, this will return `2`. - pub fn scale(&self) -> u8 { - self.scale - } + /// Maximum number of significant digits (18 for 8-byte decimals) + pub const MAX_PRECISION: u8 = arrow_schema::DECIMAL64_MAX_PRECISION; } -impl fmt::Display for VariantDecimal8 { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - format_decimal!(f, self.integer, self.scale, i64) - } -} +impl_variant_decimal!(VariantDecimal8, i64); /// Represents an 16-byte decimal value in the Variant format. /// @@ -186,33 +278,11 @@ pub struct VariantDecimal16 { } impl VariantDecimal16 { - const MAX_PRECISION: u8 = 38; - const MAX_UNSCALED_VALUE: u128 = u128::pow(10, Self::MAX_PRECISION as u32) - 1; - - pub fn try_new(integer: i128, scale: u8) -> Result { - decimal_try_new!(integer, scale) - } - - /// Returns the underlying value of the decimal. - /// - /// For example, if the decimal is `12345678901234567.890`, this will return `12345678901234567890`. - pub fn integer(&self) -> i128 { - self.integer - } - - /// Returns the scale of the decimal (how many digits after the decimal point). - /// - /// For example, if the decimal is `12345678901234567.890`, this will return `3`. 
- pub fn scale(&self) -> u8 { - self.scale - } + /// Maximum number of significant digits (38 for 16-byte decimals) + pub const MAX_PRECISION: u8 = arrow_schema::DECIMAL128_MAX_PRECISION; } -impl fmt::Display for VariantDecimal16 { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - format_decimal!(f, self.integer, self.scale, i128) - } -} +impl_variant_decimal!(VariantDecimal16, i128); // Infallible conversion from a narrower decimal type to a wider one macro_rules! impl_from_decimal_for_decimal {