Skip to content

Commit b8fdd90

Browse files
authored
[Variant] Define and use VariantDecimalType trait (#8562)
# Which issue does this PR close? We generally require a GitHub issue to be filed for all bug fixes and enhancements and this helps us generate change logs for our releases. You can link an issue to this PR using the GitHub syntax. - Closes #NNN. # Rationale for this change `VariantDecimalXX` structs are structurally near-identical but lack any trait to that can expose that regularity. # What changes are included in this PR? Define and use a new `VariantDecimalType` trait that exposes common functionality of all three variant decimal types. # Are these changes tested? Yes, existing unit tests cover the changes. # Are there any user-facing changes? New pub trait.
1 parent 2f5ae5c commit b8fdd90

File tree

6 files changed

+293
-332
lines changed

6 files changed

+293
-332
lines changed

parquet-variant-compute/src/arrow_to_variant.rs

Lines changed: 52 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -15,25 +15,22 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use crate::type_conversion::{CastOptions, decimal_to_variant_decimal};
18+
use crate::type_conversion::CastOptions;
1919
use arrow::array::{
2020
Array, AsArray, FixedSizeListArray, GenericBinaryArray, GenericListArray, GenericListViewArray,
2121
GenericStringArray, OffsetSizeTrait, PrimitiveArray,
2222
};
2323
use arrow::compute::kernels::cast;
2424
use arrow::datatypes::{
25-
ArrowNativeType, ArrowPrimitiveType, ArrowTemporalType, ArrowTimestampType, Date32Type,
26-
Date64Type, Float16Type, Float32Type, Float64Type, Int8Type, Int16Type, Int32Type, Int64Type,
27-
RunEndIndexType, Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
28-
Time64NanosecondType, TimestampMicrosecondType, TimestampMillisecondType,
29-
TimestampNanosecondType, TimestampSecondType, UInt8Type, UInt16Type, UInt32Type, UInt64Type,
25+
self as datatypes, ArrowNativeType, ArrowPrimitiveType, ArrowTemporalType, ArrowTimestampType,
26+
DecimalType, RunEndIndexType,
3027
};
3128
use arrow::temporal_conversions::{as_date, as_datetime, as_time};
3229
use arrow_schema::{ArrowError, DataType, TimeUnit};
3330
use chrono::{DateTime, TimeZone, Utc};
3431
use parquet_variant::{
3532
ObjectFieldBuilder, Variant, VariantBuilderExt, VariantDecimal4, VariantDecimal8,
36-
VariantDecimal16,
33+
VariantDecimal16, VariantDecimalType,
3734
};
3835
use std::collections::HashMap;
3936
use std::ops::Range;
@@ -46,31 +43,31 @@ use std::ops::Range;
4643
pub(crate) enum ArrowToVariantRowBuilder<'a> {
4744
Null(NullArrowToVariantBuilder),
4845
Boolean(BooleanArrowToVariantBuilder<'a>),
49-
PrimitiveInt8(PrimitiveArrowToVariantBuilder<'a, Int8Type>),
50-
PrimitiveInt16(PrimitiveArrowToVariantBuilder<'a, Int16Type>),
51-
PrimitiveInt32(PrimitiveArrowToVariantBuilder<'a, Int32Type>),
52-
PrimitiveInt64(PrimitiveArrowToVariantBuilder<'a, Int64Type>),
53-
PrimitiveUInt8(PrimitiveArrowToVariantBuilder<'a, UInt8Type>),
54-
PrimitiveUInt16(PrimitiveArrowToVariantBuilder<'a, UInt16Type>),
55-
PrimitiveUInt32(PrimitiveArrowToVariantBuilder<'a, UInt32Type>),
56-
PrimitiveUInt64(PrimitiveArrowToVariantBuilder<'a, UInt64Type>),
57-
PrimitiveFloat16(PrimitiveArrowToVariantBuilder<'a, Float16Type>),
58-
PrimitiveFloat32(PrimitiveArrowToVariantBuilder<'a, Float32Type>),
59-
PrimitiveFloat64(PrimitiveArrowToVariantBuilder<'a, Float64Type>),
60-
Decimal32(Decimal32ArrowToVariantBuilder<'a>),
61-
Decimal64(Decimal64ArrowToVariantBuilder<'a>),
62-
Decimal128(Decimal128ArrowToVariantBuilder<'a>),
46+
PrimitiveInt8(PrimitiveArrowToVariantBuilder<'a, datatypes::Int8Type>),
47+
PrimitiveInt16(PrimitiveArrowToVariantBuilder<'a, datatypes::Int16Type>),
48+
PrimitiveInt32(PrimitiveArrowToVariantBuilder<'a, datatypes::Int32Type>),
49+
PrimitiveInt64(PrimitiveArrowToVariantBuilder<'a, datatypes::Int64Type>),
50+
PrimitiveUInt8(PrimitiveArrowToVariantBuilder<'a, datatypes::UInt8Type>),
51+
PrimitiveUInt16(PrimitiveArrowToVariantBuilder<'a, datatypes::UInt16Type>),
52+
PrimitiveUInt32(PrimitiveArrowToVariantBuilder<'a, datatypes::UInt32Type>),
53+
PrimitiveUInt64(PrimitiveArrowToVariantBuilder<'a, datatypes::UInt64Type>),
54+
PrimitiveFloat16(PrimitiveArrowToVariantBuilder<'a, datatypes::Float16Type>),
55+
PrimitiveFloat32(PrimitiveArrowToVariantBuilder<'a, datatypes::Float32Type>),
56+
PrimitiveFloat64(PrimitiveArrowToVariantBuilder<'a, datatypes::Float64Type>),
57+
Decimal32(DecimalArrowToVariantBuilder<'a, datatypes::Decimal32Type, VariantDecimal4>),
58+
Decimal64(DecimalArrowToVariantBuilder<'a, datatypes::Decimal64Type, VariantDecimal8>),
59+
Decimal128(DecimalArrowToVariantBuilder<'a, datatypes::Decimal128Type, VariantDecimal16>),
6360
Decimal256(Decimal256ArrowToVariantBuilder<'a>),
64-
TimestampSecond(TimestampArrowToVariantBuilder<'a, TimestampSecondType>),
65-
TimestampMillisecond(TimestampArrowToVariantBuilder<'a, TimestampMillisecondType>),
66-
TimestampMicrosecond(TimestampArrowToVariantBuilder<'a, TimestampMicrosecondType>),
67-
TimestampNanosecond(TimestampArrowToVariantBuilder<'a, TimestampNanosecondType>),
68-
Date32(DateArrowToVariantBuilder<'a, Date32Type>),
69-
Date64(DateArrowToVariantBuilder<'a, Date64Type>),
70-
Time32Second(TimeArrowToVariantBuilder<'a, Time32SecondType>),
71-
Time32Millisecond(TimeArrowToVariantBuilder<'a, Time32MillisecondType>),
72-
Time64Microsecond(TimeArrowToVariantBuilder<'a, Time64MicrosecondType>),
73-
Time64Nanosecond(TimeArrowToVariantBuilder<'a, Time64NanosecondType>),
61+
TimestampSecond(TimestampArrowToVariantBuilder<'a, datatypes::TimestampSecondType>),
62+
TimestampMillisecond(TimestampArrowToVariantBuilder<'a, datatypes::TimestampMillisecondType>),
63+
TimestampMicrosecond(TimestampArrowToVariantBuilder<'a, datatypes::TimestampMicrosecondType>),
64+
TimestampNanosecond(TimestampArrowToVariantBuilder<'a, datatypes::TimestampNanosecondType>),
65+
Date32(DateArrowToVariantBuilder<'a, datatypes::Date32Type>),
66+
Date64(DateArrowToVariantBuilder<'a, datatypes::Date64Type>),
67+
Time32Second(TimeArrowToVariantBuilder<'a, datatypes::Time32SecondType>),
68+
Time32Millisecond(TimeArrowToVariantBuilder<'a, datatypes::Time32MillisecondType>),
69+
Time64Microsecond(TimeArrowToVariantBuilder<'a, datatypes::Time64MicrosecondType>),
70+
Time64Nanosecond(TimeArrowToVariantBuilder<'a, datatypes::Time64NanosecondType>),
7471
Binary(BinaryArrowToVariantBuilder<'a, i32>),
7572
LargeBinary(BinaryArrowToVariantBuilder<'a, i64>),
7673
BinaryView(BinaryViewArrowToVariantBuilder<'a>),
@@ -87,9 +84,9 @@ pub(crate) enum ArrowToVariantRowBuilder<'a> {
8784
Map(MapArrowToVariantBuilder<'a>),
8885
Union(UnionArrowToVariantBuilder<'a>),
8986
Dictionary(DictionaryArrowToVariantBuilder<'a>),
90-
RunEndEncodedInt16(RunEndEncodedArrowToVariantBuilder<'a, Int16Type>),
91-
RunEndEncodedInt32(RunEndEncodedArrowToVariantBuilder<'a, Int32Type>),
92-
RunEndEncodedInt64(RunEndEncodedArrowToVariantBuilder<'a, Int64Type>),
87+
RunEndEncodedInt16(RunEndEncodedArrowToVariantBuilder<'a, datatypes::Int16Type>),
88+
RunEndEncodedInt32(RunEndEncodedArrowToVariantBuilder<'a, datatypes::Int32Type>),
89+
RunEndEncodedInt64(RunEndEncodedArrowToVariantBuilder<'a, datatypes::Int64Type>),
9390
}
9491

9592
impl<'a> ArrowToVariantRowBuilder<'a> {
@@ -174,13 +171,13 @@ pub(crate) fn make_arrow_to_variant_row_builder<'a>(
174171
DataType::Float32 => PrimitiveFloat32(PrimitiveArrowToVariantBuilder::new(array)),
175172
DataType::Float64 => PrimitiveFloat64(PrimitiveArrowToVariantBuilder::new(array)),
176173
DataType::Decimal32(_, scale) => {
177-
Decimal32(Decimal32ArrowToVariantBuilder::new(array, options, *scale))
174+
Decimal32(DecimalArrowToVariantBuilder::new(array, options, *scale))
178175
}
179176
DataType::Decimal64(_, scale) => {
180-
Decimal64(Decimal64ArrowToVariantBuilder::new(array, options, *scale))
177+
Decimal64(DecimalArrowToVariantBuilder::new(array, options, *scale))
181178
}
182179
DataType::Decimal128(_, scale) => {
183-
Decimal128(Decimal128ArrowToVariantBuilder::new(array, options, *scale))
180+
Decimal128(DecimalArrowToVariantBuilder::new(array, options, *scale))
184181
}
185182
DataType::Decimal256(_, scale) => {
186183
Decimal256(Decimal256ArrowToVariantBuilder::new(array, options, *scale))
@@ -320,26 +317,28 @@ pub(crate) fn make_arrow_to_variant_row_builder<'a>(
320317
// worth the trouble, tho, because it makes for some pretty bulky and unwieldy macro expansions.
321318
macro_rules! define_row_builder {
322319
(
323-
struct $name:ident<$lifetime:lifetime $(, $generic:ident: $bound:path )?>
320+
struct $name:ident<$lifetime:lifetime $(, $generic:ident $( : $bound:path )? )*>
324321
$( where $where_path:path: $where_bound:path $(,)? )?
325-
$({ $($field:ident: $field_type:ty),+ $(,)? })?,
322+
$({ $( $field:ident: $field_type:ty ),+ $(,)? })?,
326323
|$array_param:ident| -> $array_type:ty { $init_expr:expr }
327-
$(, |$value:ident| $(-> Option<$option_ty:ty>)? $value_transform:expr)?
324+
$(, |$value:ident| $(-> Option<$option_ty:ty>)? $value_transform:expr )?
328325
) => {
329-
pub(crate) struct $name<$lifetime $(, $generic: $bound )?>
326+
pub(crate) struct $name<$lifetime $(, $generic: $( $bound )? )*>
330327
$( where $where_path: $where_bound )?
331328
{
332329
array: &$lifetime $array_type,
333330
$( $( $field: $field_type, )+ )?
331+
_phantom: std::marker::PhantomData<($( $generic, )*)>, // capture all type params
334332
}
335333

336-
impl<$lifetime $(, $generic: $bound+ )?> $name<$lifetime $(, $generic)?>
334+
impl<$lifetime $(, $generic: $( $bound )? )*> $name<$lifetime $(, $generic)*>
337335
$( where $where_path: $where_bound )?
338336
{
339-
pub(crate) fn new($array_param: &$lifetime dyn Array $(, $( $field: $field_type ),+ )?) -> Self {
337+
pub(crate) fn new($array_param: &$lifetime dyn Array $( $(, $field: $field_type )+ )?) -> Self {
340338
Self {
341339
array: $init_expr,
342340
$( $( $field, )+ )?
341+
_phantom: std::marker::PhantomData,
343342
}
344343
}
345344

@@ -401,43 +400,27 @@ define_row_builder!(
401400
);
402401

403402
define_row_builder!(
404-
struct Decimal32ArrowToVariantBuilder<'a> {
405-
options: &'a CastOptions,
406-
scale: i8,
407-
},
408-
|array| -> arrow::array::Decimal32Array { array.as_primitive() },
409-
|value| -> Option<_> { decimal_to_variant_decimal!(value, scale, i32, VariantDecimal4) }
410-
);
411-
412-
define_row_builder!(
413-
struct Decimal64ArrowToVariantBuilder<'a> {
414-
options: &'a CastOptions,
415-
scale: i8,
416-
},
417-
|array| -> arrow::array::Decimal64Array { array.as_primitive() },
418-
|value| -> Option<_> { decimal_to_variant_decimal!(value, scale, i64, VariantDecimal8) }
419-
);
420-
421-
define_row_builder!(
422-
struct Decimal128ArrowToVariantBuilder<'a> {
403+
struct DecimalArrowToVariantBuilder<'a, A: DecimalType, V>
404+
where
405+
V: VariantDecimalType<Native = A::Native>,
406+
{
423407
options: &'a CastOptions,
424408
scale: i8,
425409
},
426-
|array| -> arrow::array::Decimal128Array { array.as_primitive() },
427-
|value| -> Option<_> { decimal_to_variant_decimal!(value, scale, i128, VariantDecimal16) }
410+
|array| -> PrimitiveArray<A> { array.as_primitive() },
411+
|value| -> Option<_> { V::try_new_with_signed_scale(value, *scale).ok() }
428412
);
429413

414+
// Decimal256 needs a two-stage conversion via i128
430415
define_row_builder!(
431416
struct Decimal256ArrowToVariantBuilder<'a> {
432417
options: &'a CastOptions,
433418
scale: i8,
434419
},
435420
|array| -> arrow::array::Decimal256Array { array.as_primitive() },
436421
|value| -> Option<_> {
437-
// Decimal256 needs special handling - convert to i128 if possible
438-
value.to_i128().and_then(|i128_val| {
439-
decimal_to_variant_decimal!(i128_val, scale, i128, VariantDecimal16)
440-
})
422+
let value = value.to_i128();
423+
value.and_then(|v| VariantDecimal16::try_new_with_signed_scale(v, *scale).ok())
441424
}
442425
);
443426

@@ -911,6 +894,7 @@ mod tests {
911894
use super::*;
912895
use crate::{VariantArray, VariantArrayBuilder};
913896
use arrow::array::{ArrayRef, BooleanArray, Int32Array, StringArray};
897+
use arrow::datatypes::Int32Type;
914898
use std::sync::Arc;
915899

916900
/// Builds a VariantArray from an Arrow array using the row builder.

parquet-variant-compute/src/type_conversion.rs

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -150,22 +150,3 @@ macro_rules! primitive_conversion_single_value {
150150
}};
151151
}
152152
pub(crate) use primitive_conversion_single_value;
153-
154-
/// Convert a decimal value to a `VariantDecimal`
155-
macro_rules! decimal_to_variant_decimal {
156-
($v:ident, $scale:expr, $value_type:ty, $variant_type:ty) => {{
157-
let (v, scale) = if *$scale < 0 {
158-
// For negative scale, we need to multiply the value by 10^|scale|
159-
// For example: 123 with scale -2 becomes 12300 with scale 0
160-
let multiplier = <$value_type>::pow(10, (-*$scale) as u32);
161-
(<$value_type>::checked_mul($v, multiplier), 0u8)
162-
} else {
163-
(Some($v), *$scale as u8)
164-
};
165-
166-
// Return an Option to allow callers to decide whether to error (strict)
167-
// or append null (non-strict) on conversion failure
168-
v.and_then(|v| <$variant_type>::try_new(v, scale).ok())
169-
}};
170-
}
171-
pub(crate) use decimal_to_variant_decimal;

0 commit comments

Comments
 (0)