From 6d58d1cd14247c29ee07125b36def2558a202f62 Mon Sep 17 00:00:00 2001 From: Weijun Huang Date: Wed, 5 Apr 2023 11:55:39 +0200 Subject: [PATCH] feat: cast from/to interval and duration --- arrow-cast/src/cast.rs | 245 ++++++++++++++++++++++++++++++++++++++++- 1 file changed, 244 insertions(+), 1 deletion(-) diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast.rs index 0ea6332a7ea5..a0e198b96cfb 100644 --- a/arrow-cast/src/cast.rs +++ b/arrow-cast/src/cast.rs @@ -458,6 +458,93 @@ where } } +/// Cast the array from interval to duration +fn cast_interval_to_duration>( + array: &dyn Array, + cast_options: &CastOptions, +) -> Result { + let array = array + .as_any() + .downcast_ref::() + .unwrap(); + + let mut builder = PrimitiveBuilder::::new(); + + let scale = match D::DATA_TYPE { + DataType::Duration(TimeUnit::Second) => 1_000_000_000, + DataType::Duration(TimeUnit::Millisecond) => 1_000_000, + DataType::Duration(TimeUnit::Microsecond) => 1_000, + DataType::Duration(TimeUnit::Nanosecond) => 1, + _ => unreachable!(), + }; + + for i in 0..array.len() { + if array.is_null(i) { + builder.append_null(); + } else { + let v = array.value(i) / scale; + if v > i64::MAX as i128 { + if cast_options.safe { + builder.append_null(); + } else { + return Err(ArrowError::ComputeError(format!( + "Cannot cast to {:?}. Overflowing on {:?}", + D::DATA_TYPE, + v + ))); + } + } else { + builder.append_value(v as i64); + } + } + } + + Ok(Arc::new(builder.finish()) as ArrayRef) +} + +/// Cast the array from duration and interval +fn cast_duration_to_interval>( + array: &dyn Array, + cast_options: &CastOptions, +) -> Result { + let array = array.as_any().downcast_ref::>().unwrap(); + + let mut builder = IntervalMonthDayNanoBuilder::new(); + + let scale = match array.data_type() { + DataType::Duration(TimeUnit::Second) => 1_000_000_000, + DataType::Duration(TimeUnit::Millisecond) => 1_000_000, + DataType::Duration(TimeUnit::Microsecond) => 1_000, + DataType::Duration(TimeUnit::Nanosecond) => 1, + _ => unreachable!(), + }; + + for i in 0..array.len() { + if array.is_null(i) { + builder.append_null(); + } else { + let v = array.value(i) as i128; + let v = v.mul_checked(scale); + match v { + Ok(v) => builder.append_value(v), + Err(e) => { + if cast_options.safe { + builder.append_null() + } else { + return Err(ArrowError::ComputeError(format!( + "Cannot cast to {:?}. Overflowing on {:?}", + D::DATA_TYPE, + e + ))); + } + } + }; + } + } + + Ok(Arc::new(builder.finish())) +} + /// Cast the primitive array using [`PrimitiveArray::reinterpret_cast`] fn cast_reinterpret_arrays< I: ArrowPrimitiveType, @@ -2014,7 +2101,30 @@ pub fn cast_with_options( (Duration(TimeUnit::Nanosecond), Int64) => { cast_reinterpret_arrays::(array) } - + (Duration(TimeUnit::Second), Interval(IntervalUnit::MonthDayNano)) => { + cast_duration_to_interval::(array, cast_options) + } + (Duration(TimeUnit::Millisecond), Interval(IntervalUnit::MonthDayNano)) => { + cast_duration_to_interval::(array, cast_options) + } + (Duration(TimeUnit::Microsecond), Interval(IntervalUnit::MonthDayNano)) => { + cast_duration_to_interval::(array, cast_options) + } + (Duration(TimeUnit::Nanosecond), Interval(IntervalUnit::MonthDayNano)) => { + cast_duration_to_interval::(array, cast_options) + } + (DataType::Interval(IntervalUnit::MonthDayNano), DataType::Duration(TimeUnit::Second)) => { + cast_interval_to_duration::(array, cast_options) + } + (DataType::Interval(IntervalUnit::MonthDayNano), DataType::Duration(TimeUnit::Millisecond)) => { + cast_interval_to_duration::(array, cast_options) + } + (DataType::Interval(IntervalUnit::MonthDayNano), DataType::Duration(TimeUnit::Microsecond)) => { + cast_interval_to_duration::(array, cast_options) + } + (DataType::Interval(IntervalUnit::MonthDayNano), DataType::Duration(TimeUnit::Nanosecond)) => { + cast_interval_to_duration::(array, cast_options) + } (Interval(IntervalUnit::YearMonth), Int64) => { cast_numeric_arrays::(array, cast_options) } @@ -8246,4 +8356,137 @@ mod tests { ); assert_eq!("Invalid argument error: 1234567000 is too large to store in a Decimal256 of precision 7. Max is 9999999", err.unwrap_err().to_string()); } + + #[test] + fn test_cast_from_duration_to_interval() { + // from duration second to interval month day nano + let array = vec![1234567]; + let array = DurationSecondArray::from(array); + let array = Arc::new(array) as ArrayRef; + let casted_array = + cast(&array, &DataType::Interval(IntervalUnit::MonthDayNano)).unwrap(); + let casted_array = casted_array + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!( + casted_array.data_type(), + &DataType::Interval(IntervalUnit::MonthDayNano) + ); + assert_eq!(casted_array.value(0), 1234567000000000); + + // from duration millisecond to interval month day nano + let array = vec![1234567]; + let array = DurationMillisecondArray::from(array); + let array = Arc::new(array) as ArrayRef; + let casted_array = + cast(&array, &DataType::Interval(IntervalUnit::MonthDayNano)).unwrap(); + let casted_array = casted_array + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!( + casted_array.data_type(), + &DataType::Interval(IntervalUnit::MonthDayNano) + ); + assert_eq!(casted_array.value(0), 1234567000000); + + // from duration microsecond to interval month day nano + let array = vec![1234567]; + let array = DurationMicrosecondArray::from(array); + let array = Arc::new(array) as ArrayRef; + let casted_array = + cast(&array, &DataType::Interval(IntervalUnit::MonthDayNano)).unwrap(); + let casted_array = casted_array + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!( + casted_array.data_type(), + &DataType::Interval(IntervalUnit::MonthDayNano) + ); + assert_eq!(casted_array.value(0), 1234567000); + + // from duration nanosecond to interval month day nano + let array = vec![1234567]; + let array = DurationNanosecondArray::from(array); + let array = Arc::new(array) as ArrayRef; + let casted_array = + cast(&array, &DataType::Interval(IntervalUnit::MonthDayNano)).unwrap(); + let casted_array = casted_array + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!( + casted_array.data_type(), + &DataType::Interval(IntervalUnit::MonthDayNano) + ); + assert_eq!(casted_array.value(0), 1234567); + } + + #[test] + fn test_cast_from_interval_to_duration() { + // from interval month day nano to duration second + let array = vec![1234567]; + let array = IntervalMonthDayNanoArray::from(array); + let array = Arc::new(array) as ArrayRef; + let casted_array = cast(&array, &DataType::Duration(TimeUnit::Second)).unwrap(); + let casted_array = casted_array + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!( + casted_array.data_type(), + &DataType::Duration(TimeUnit::Second) + ); + assert_eq!(casted_array.value(0), 0); + + // from interval month day nano to duration millisecond + let array = vec![1234567]; + let array = IntervalMonthDayNanoArray::from(array); + let array = Arc::new(array) as ArrayRef; + let casted_array = + cast(&array, &DataType::Duration(TimeUnit::Millisecond)).unwrap(); + let casted_array = casted_array + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!( + casted_array.data_type(), + &DataType::Duration(TimeUnit::Millisecond) + ); + assert_eq!(casted_array.value(0), 1); + + // from interval month day nano to duration microsecond + let array = vec![1234567]; + let array = IntervalMonthDayNanoArray::from(array); + let array = Arc::new(array) as ArrayRef; + let casted_array = + cast(&array, &DataType::Duration(TimeUnit::Microsecond)).unwrap(); + let casted_array = casted_array + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!( + casted_array.data_type(), + &DataType::Duration(TimeUnit::Microsecond) + ); + assert_eq!(casted_array.value(0), 1234); + // from interval month day nano to duration nanosecond + let array = vec![1234567]; + let array = IntervalMonthDayNanoArray::from(array); + let array = Arc::new(array) as ArrayRef; + let casted_array = + cast(&array, &DataType::Duration(TimeUnit::Nanosecond)).unwrap(); + let casted_array = casted_array + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!( + casted_array.data_type(), + &DataType::Duration(TimeUnit::Nanosecond) + ); + assert_eq!(casted_array.value(0), 1234567); + } }