diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast.rs
index 2c1dae5187fa..bc37174b94f2 100644
--- a/arrow-cast/src/cast.rs
+++ b/arrow-cast/src/cast.rs
@@ -303,6 +303,8 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
                 IntervalUnit::MonthDayNano => false,
             }
         }
+        (Duration(_), Interval(IntervalUnit::MonthDayNano)) => true,
+        (Interval(IntervalUnit::MonthDayNano), Duration(_)) => true,
         (_, _) => false,
     }
 }
@@ -458,6 +460,143 @@ where
     }
 }
 
+/// Cast the array from interval to duration
+///
+/// A duration is a fixed span of time, so only an interval whose month and day
+/// components are both zero can be converted. Any other value becomes null when
+/// `cast_options.safe` is set and an [`ArrowError::ComputeError`] otherwise. The
+/// nanosecond component is divided by the target unit, truncating towards zero.
+fn cast_interval_to_duration<D: ArrowTemporalType<Native = i64>>(
+    array: &dyn Array,
+    cast_options: &CastOptions,
+) -> Result<ArrayRef, ArrowError> {
+    let array = array
+        .as_any()
+        .downcast_ref::<IntervalMonthDayNanoArray>()
+        .ok_or_else(|| {
+            ArrowError::ComputeError(
+                "Internal Error: Cannot cast interval to IntervalArray of expected type"
+                    .to_string(),
+            )
+        })?;
+
+    // Nanoseconds per unit of the target duration type
+    let scale: i64 = match D::DATA_TYPE {
+        DataType::Duration(TimeUnit::Second) => 1_000_000_000,
+        DataType::Duration(TimeUnit::Millisecond) => 1_000_000,
+        DataType::Duration(TimeUnit::Microsecond) => 1_000,
+        DataType::Duration(TimeUnit::Nanosecond) => 1,
+        _ => unreachable!(),
+    };
+
+    // Decompose the packed value rather than dividing the raw i128: the month and
+    // day components share the same integer, so the raw value is a nanosecond
+    // count only when both are zero. The nanosecond component is an i64, so the
+    // scaled result always fits the native type of the duration.
+    let to_duration = |v: i128| {
+        let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(v);
+        if months == 0 && days == 0 {
+            Some(nanos / scale)
+        } else {
+            None
+        }
+    };
+
+    if cast_options.safe {
+        let iter = array.iter().map(|v| v.and_then(to_duration));
+        // SAFETY: mapping the array iterator one-to-one keeps its exact length
+        Ok(Arc::new(unsafe {
+            PrimitiveArray::<D>::from_trusted_len_iter(iter)
+        }))
+    } else {
+        let vec = array
+            .iter()
+            .map(|v| {
+                v.map(|v| {
+                    to_duration(v).ok_or_else(|| {
+                        ArrowError::ComputeError(format!(
+                            "Cannot cast to {:?}. Non-zero months or days in {:?}",
+                            D::DATA_TYPE,
+                            IntervalMonthDayNanoType::to_parts(v)
+                        ))
+                    })
+                })
+                .transpose()
+            })
+            .collect::<Result<Vec<_>, _>>()?;
+        // SAFETY: the iterator of a `Vec` reports its exact length
+        Ok(Arc::new(unsafe {
+            PrimitiveArray::<D>::from_trusted_len_iter(vec.iter())
+        }))
+    }
+}
+
+/// Cast the array from duration to interval
+///
+/// The duration is converted to nanoseconds and stored in the nanosecond
+/// component of the interval; the month and day components are always zero.
+/// A value that overflows an `i64` number of nanoseconds becomes null when
+/// `cast_options.safe` is set and an [`ArrowError::ComputeError`] otherwise.
+fn cast_duration_to_interval<D: ArrowTemporalType<Native = i64>>(
+    array: &dyn Array,
+    cast_options: &CastOptions,
+) -> Result<ArrayRef, ArrowError> {
+    let array = array
+        .as_any()
+        .downcast_ref::<PrimitiveArray<D>>()
+        .ok_or_else(|| {
+            ArrowError::ComputeError(
+                "Internal Error: Cannot cast duration to DurationArray of expected type"
+                    .to_string(),
+            )
+        })?;
+
+    // Nanoseconds per unit of the source duration type
+    let scale: i64 = match array.data_type() {
+        DataType::Duration(TimeUnit::Second) => 1_000_000_000,
+        DataType::Duration(TimeUnit::Millisecond) => 1_000_000,
+        DataType::Duration(TimeUnit::Microsecond) => 1_000,
+        DataType::Duration(TimeUnit::Nanosecond) => 1,
+        _ => unreachable!(),
+    };
+
+    // Build the interval from its components rather than widening with `as i128`:
+    // sign extension of a negative nanosecond count would spill into the month
+    // and day components.
+    let to_interval = |v: i64| {
+        v.checked_mul(scale)
+            .map(|nanos| IntervalMonthDayNanoType::make_value(0, 0, nanos))
+    };
+
+    if cast_options.safe {
+        let iter = array.iter().map(|v| v.and_then(to_interval));
+        // SAFETY: mapping the array iterator one-to-one keeps its exact length
+        Ok(Arc::new(unsafe {
+            PrimitiveArray::<IntervalMonthDayNanoType>::from_trusted_len_iter(iter)
+        }))
+    } else {
+        let vec = array
+            .iter()
+            .map(|v| {
+                v.map(|v| {
+                    to_interval(v).ok_or_else(|| {
+                        ArrowError::ComputeError(format!(
+                            "Cannot cast to {:?}. Overflowing on {:?}",
+                            IntervalMonthDayNanoType::DATA_TYPE,
+                            v
+                        ))
+                    })
+                })
+                .transpose()
+            })
+            .collect::<Result<Vec<_>, _>>()?;
+        // SAFETY: the iterator of a `Vec` reports its exact length
+        Ok(Arc::new(unsafe {
+            PrimitiveArray::<IntervalMonthDayNanoType>::from_trusted_len_iter(vec.iter())
+        }))
+    }
+}
+
 /// Cast the primitive array using [`PrimitiveArray::reinterpret_cast`]
 fn cast_reinterpret_arrays<
     I: ArrowPrimitiveType,
@@ -2014,7 +2153,30 @@ pub fn cast_with_options(
         (Duration(TimeUnit::Nanosecond), Int64) => {
             cast_reinterpret_arrays::<DurationNanosecondType, Int64Type>(array)
         }
-
+        (Duration(TimeUnit::Second), Interval(IntervalUnit::MonthDayNano)) => {
+            cast_duration_to_interval::<DurationSecondType>(array, cast_options)
+        }
+        (Duration(TimeUnit::Millisecond), Interval(IntervalUnit::MonthDayNano)) => {
+            cast_duration_to_interval::<DurationMillisecondType>(array, cast_options)
+        }
+        (Duration(TimeUnit::Microsecond), Interval(IntervalUnit::MonthDayNano)) => {
+            cast_duration_to_interval::<DurationMicrosecondType>(array, cast_options)
+        }
+        (Duration(TimeUnit::Nanosecond), Interval(IntervalUnit::MonthDayNano)) => {
+            cast_duration_to_interval::<DurationNanosecondType>(array, cast_options)
+        }
+        (Interval(IntervalUnit::MonthDayNano), Duration(TimeUnit::Second)) => {
+            cast_interval_to_duration::<DurationSecondType>(array, cast_options)
+        }
+        (Interval(IntervalUnit::MonthDayNano), Duration(TimeUnit::Millisecond)) => {
+            cast_interval_to_duration::<DurationMillisecondType>(array, cast_options)
+        }
+        (Interval(IntervalUnit::MonthDayNano), Duration(TimeUnit::Microsecond)) => {
+            cast_interval_to_duration::<DurationMicrosecondType>(array, cast_options)
+        }
+        (Interval(IntervalUnit::MonthDayNano), Duration(TimeUnit::Nanosecond)) => {
+            cast_interval_to_duration::<DurationNanosecondType>(array, cast_options)
+        }
         (Interval(IntervalUnit::YearMonth), Int64) => {
             cast_numeric_arrays::<IntervalYearMonthType, Int64Type>(array, cast_options)
         }
@@ -8269,4 +8431,302 @@ mod tests {
         );
         assert_eq!("Invalid argument error: 1234567000 is too large to store in a Decimal256 of precision 7. Max is 9999999", err.unwrap_err().to_string());
     }
+
+    /// helper function to test casting from duration to interval
+    fn cast_from_duration_to_interval<T: ArrowTemporalType>(
+        array: Vec<i64>,
+        cast_options: &CastOptions,
+    ) -> Result<PrimitiveArray<IntervalMonthDayNanoType>, ArrowError>
+    where
+        arrow_array::PrimitiveArray<T>: From<Vec<i64>>,
+    {
+        let array = PrimitiveArray::<T>::from(array);
+        let array = Arc::new(array) as ArrayRef;
+        let casted_array = cast_with_options(
+            &array,
+            &DataType::Interval(IntervalUnit::MonthDayNano),
+            cast_options,
+        )?;
+        casted_array
+            .as_any()
+            .downcast_ref::<IntervalMonthDayNanoArray>()
+            .ok_or_else(|| {
+                ArrowError::ComputeError(
+                    "Failed to downcast to IntervalMonthDayNanoArray".to_string(),
+                )
+            })
+            .cloned()
+    }
+
+    #[test]
+    fn test_cast_from_duration_to_interval() {
+        // from duration second to interval month day nano
+        let array = vec![1234567];
+        let casted_array = cast_from_duration_to_interval::<DurationSecondType>(
+            array,
+            &DEFAULT_CAST_OPTIONS,
+        )
+        .unwrap();
+        assert_eq!(
+            casted_array.data_type(),
+            &DataType::Interval(IntervalUnit::MonthDayNano)
+        );
+        assert_eq!(casted_array.value(0), 1234567000000000);
+
+        let array = vec![i64::MAX];
+        let casted_array = cast_from_duration_to_interval::<DurationSecondType>(
+            array.clone(),
+            &DEFAULT_CAST_OPTIONS,
+        )
+        .unwrap();
+        assert!(!casted_array.is_valid(0));
+
+        let casted_array = cast_from_duration_to_interval::<DurationSecondType>(
+            array,
+            &CastOptions { safe: false },
+        );
+        assert!(casted_array.is_err());
+
+        // from duration millisecond to interval month day nano
+        let array = vec![1234567];
+        let casted_array = cast_from_duration_to_interval::<DurationMillisecondType>(
+            array,
+            &DEFAULT_CAST_OPTIONS,
+        )
+        .unwrap();
+        assert_eq!(
+            casted_array.data_type(),
+            &DataType::Interval(IntervalUnit::MonthDayNano)
+        );
+        assert_eq!(casted_array.value(0), 1234567000000);
+
+        let array = vec![i64::MAX];
+        let casted_array = cast_from_duration_to_interval::<DurationMillisecondType>(
+            array.clone(),
+            &DEFAULT_CAST_OPTIONS,
+        )
+        .unwrap();
+        assert!(!casted_array.is_valid(0));
+
+        let casted_array = cast_from_duration_to_interval::<DurationMillisecondType>(
+            array,
+            &CastOptions { safe: false },
+        );
+        assert!(casted_array.is_err());
+
+        // from duration microsecond to interval month day nano
+        let array = vec![1234567];
+        let casted_array = cast_from_duration_to_interval::<DurationMicrosecondType>(
+            array,
+            &DEFAULT_CAST_OPTIONS,
+        )
+        .unwrap();
+        assert_eq!(
+            casted_array.data_type(),
+            &DataType::Interval(IntervalUnit::MonthDayNano)
+        );
+        assert_eq!(casted_array.value(0), 1234567000);
+
+        let array = vec![i64::MAX];
+        let casted_array = cast_from_duration_to_interval::<DurationMicrosecondType>(
+            array.clone(),
+            &DEFAULT_CAST_OPTIONS,
+        )
+        .unwrap();
+        assert!(!casted_array.is_valid(0));
+
+        let casted_array = cast_from_duration_to_interval::<DurationMicrosecondType>(
+            array,
+            &CastOptions { safe: false },
+        );
+        assert!(casted_array.is_err());
+
+        // from duration nanosecond to interval month day nano
+        let array = vec![1234567];
+        let casted_array = cast_from_duration_to_interval::<DurationNanosecondType>(
+            array,
+            &DEFAULT_CAST_OPTIONS,
+        )
+        .unwrap();
+        assert_eq!(
+            casted_array.data_type(),
+            &DataType::Interval(IntervalUnit::MonthDayNano)
+        );
+        assert_eq!(casted_array.value(0), 1234567);
+
+        let array = vec![i64::MAX];
+        let casted_array = cast_from_duration_to_interval::<DurationNanosecondType>(
+            array,
+            &CastOptions { safe: false },
+        )
+        .unwrap();
+        assert_eq!(casted_array.value(0), 9223372036854775807);
+
+        // negative durations only populate the nanosecond component
+        let array = vec![-1];
+        let casted_array = cast_from_duration_to_interval::<DurationSecondType>(
+            array,
+            &DEFAULT_CAST_OPTIONS,
+        )
+        .unwrap();
+        assert_eq!(
+            casted_array.value(0),
+            IntervalMonthDayNanoType::make_value(0, 0, -1_000_000_000)
+        );
+    }
+
+    /// helper function to test casting from interval to duration
+    fn cast_from_interval_to_duration<T: ArrowTemporalType>(
+        array: Vec<i128>,
+        cast_options: &CastOptions,
+    ) -> Result<PrimitiveArray<T>, ArrowError> {
+        let array = IntervalMonthDayNanoArray::from(array);
+        let array = Arc::new(array) as ArrayRef;
+        let casted_array = cast_with_options(&array, &T::DATA_TYPE, cast_options)?;
+        casted_array
+            .as_any()
+            .downcast_ref::<PrimitiveArray<T>>()
+            .ok_or_else(|| {
+                ArrowError::ComputeError(format!(
+                    "Failed to downcast to {}",
+                    T::DATA_TYPE
+                ))
+            })
+            .cloned()
+    }
+
+    #[test]
+    fn test_cast_from_interval_to_duration() {
+        // from interval month day nano to duration second
+        let array = vec![1234567];
+        let casted_array = cast_from_interval_to_duration::<DurationSecondType>(
+            array,
+            &DEFAULT_CAST_OPTIONS,
+        )
+        .unwrap();
+        assert_eq!(
+            casted_array.data_type(),
+            &DataType::Duration(TimeUnit::Second)
+        );
+        assert_eq!(casted_array.value(0), 0);
+
+        let array = vec![i128::MAX];
+        let casted_array = cast_from_interval_to_duration::<DurationSecondType>(
+            array.clone(),
+            &DEFAULT_CAST_OPTIONS,
+        )
+        .unwrap();
+        assert!(!casted_array.is_valid(0));
+
+        let casted_array = cast_from_interval_to_duration::<DurationSecondType>(
+            array,
+            &CastOptions { safe: false },
+        );
+        assert!(casted_array.is_err());
+
+        // from interval month day nano to duration millisecond
+        let array = vec![1234567];
+        let casted_array = cast_from_interval_to_duration::<DurationMillisecondType>(
+            array,
+            &DEFAULT_CAST_OPTIONS,
+        )
+        .unwrap();
+        assert_eq!(casted_array.value(0), 1);
+
+        let array = vec![i128::MAX];
+        let casted_array = cast_from_interval_to_duration::<DurationMillisecondType>(
+            array.clone(),
+            &DEFAULT_CAST_OPTIONS,
+        )
+        .unwrap();
+        assert!(!casted_array.is_valid(0));
+
+        let casted_array = cast_from_interval_to_duration::<DurationMillisecondType>(
+            array,
+            &CastOptions { safe: false },
+        );
+        assert!(casted_array.is_err());
+
+        // from interval month day nano to duration microsecond
+        let array = vec![1234567];
+        let casted_array = cast_from_interval_to_duration::<DurationMicrosecondType>(
+            array,
+            &DEFAULT_CAST_OPTIONS,
+        )
+        .unwrap();
+        assert_eq!(
+            casted_array.data_type(),
+            &DataType::Duration(TimeUnit::Microsecond)
+        );
+        assert_eq!(casted_array.value(0), 1234);
+
+        let array = vec![i128::MAX];
+        let casted_array = cast_from_interval_to_duration::<DurationMicrosecondType>(
+            array.clone(),
+            &DEFAULT_CAST_OPTIONS,
+        )
+        .unwrap();
+        assert!(!casted_array.is_valid(0));
+
+        let casted_array = cast_from_interval_to_duration::<DurationMicrosecondType>(
+            array,
+            &CastOptions { safe: false },
+        );
+        assert!(casted_array.is_err());
+
+        // from interval month day nano to duration nanosecond
+        let array = vec![1234567];
+        let casted_array = cast_from_interval_to_duration::<DurationNanosecondType>(
+            array,
+            &DEFAULT_CAST_OPTIONS,
+        )
+        .unwrap();
+        assert_eq!(
+            casted_array.data_type(),
+            &DataType::Duration(TimeUnit::Nanosecond)
+        );
+        assert_eq!(casted_array.value(0), 1234567);
+
+        let array = vec![i128::MAX];
+        let casted_array = cast_from_interval_to_duration::<DurationNanosecondType>(
+            array.clone(),
+            &DEFAULT_CAST_OPTIONS,
+        )
+        .unwrap();
+        assert_eq!(
+            casted_array.data_type(),
+            &DataType::Duration(TimeUnit::Nanosecond)
+        );
+        assert!(!casted_array.is_valid(0));
+
+        let casted_array = cast_from_interval_to_duration::<DurationNanosecondType>(
+            array,
+            &CastOptions { safe: false },
+        );
+        assert!(casted_array.is_err());
+
+        // negative nanoseconds are preserved
+        let array = vec![IntervalMonthDayNanoType::make_value(0, 0, -1234567)];
+        let casted_array = cast_from_interval_to_duration::<DurationMicrosecondType>(
+            array,
+            &DEFAULT_CAST_OPTIONS,
+        )
+        .unwrap();
+        assert_eq!(casted_array.value(0), -1234);
+
+        // intervals with a month or day component have no fixed duration
+        let array = vec![IntervalMonthDayNanoType::make_value(0, 1, 0)];
+        let casted_array = cast_from_interval_to_duration::<DurationNanosecondType>(
+            array.clone(),
+            &DEFAULT_CAST_OPTIONS,
+        )
+        .unwrap();
+        assert!(!casted_array.is_valid(0));
+
+        let casted_array = cast_from_interval_to_duration::<DurationNanosecondType>(
+            array,
+            &CastOptions { safe: false },
+        );
+        assert!(casted_array.is_err());
+    }
 }