Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: cast from/to interval and duration #4020

Merged
merged 9 commits into from
Apr 16, 2023
Merged
247 changes: 246 additions & 1 deletion arrow-cast/src/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,8 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
IntervalUnit::MonthDayNano => false,
}
}
(Duration(_), Interval(IntervalUnit::MonthDayNano)) => true,
(Interval(IntervalUnit::MonthDayNano), Duration(_)) => true,
(_, _) => false,
}
}
Expand Down Expand Up @@ -458,6 +460,93 @@ where
}
}

/// Cast the array from interval to duration
fn cast_interval_to_duration<D: ArrowTemporalType<Native = i64>>(
array: &dyn Array,
cast_options: &CastOptions,
) -> Result<ArrayRef, ArrowError> {
let array = array
.as_any()
.downcast_ref::<IntervalMonthDayNanoArray>()
.unwrap();

let mut builder = PrimitiveBuilder::<D>::new();

let scale = match D::DATA_TYPE {
DataType::Duration(TimeUnit::Second) => 1_000_000_000,
DataType::Duration(TimeUnit::Millisecond) => 1_000_000,
DataType::Duration(TimeUnit::Microsecond) => 1_000,
DataType::Duration(TimeUnit::Nanosecond) => 1,
_ => unreachable!(),
};

for i in 0..array.len() {
if array.is_null(i) {
builder.append_null();
} else {
let v = array.value(i) / scale;
if v > i64::MAX as i128 {
if cast_options.safe {
builder.append_null();
} else {
return Err(ArrowError::ComputeError(format!(
"Cannot cast to {:?}. Overflowing on {:?}",
D::DATA_TYPE,
v
)));
}
} else {
builder.append_value(v as i64);
}
}
}

Ok(Arc::new(builder.finish()) as ArrayRef)
}

/// Cast the array from duration and interval
fn cast_duration_to_interval<D: ArrowTemporalType<Native = i64>>(
array: &dyn Array,
cast_options: &CastOptions,
) -> Result<ArrayRef, ArrowError> {
let array = array.as_any().downcast_ref::<PrimitiveArray<D>>().unwrap();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we propagate the error?


let mut builder = IntervalMonthDayNanoBuilder::new();

let scale = match array.data_type() {
DataType::Duration(TimeUnit::Second) => 1_000_000_000,
DataType::Duration(TimeUnit::Millisecond) => 1_000_000,
DataType::Duration(TimeUnit::Microsecond) => 1_000,
DataType::Duration(TimeUnit::Nanosecond) => 1,
_ => unreachable!(),
};

for i in 0..array.len() {
if array.is_null(i) {
builder.append_null();
} else {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For array with no nulls, I think we can have a faster path?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PrimitiveArray::from_trusted_len_iter may fit here.

let v = array.value(i) as i128;
let v = v.mul_checked(scale);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For scale 1 case, maybe we can skip this call?

match v {
Ok(v) => builder.append_value(v),
Err(e) => {
if cast_options.safe {
builder.append_null()
} else {
return Err(ArrowError::ComputeError(format!(
"Cannot cast to {:?}. Overflowing on {:?}",
D::DATA_TYPE,
e
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

e is ArrowError, so you will get nested ArrowError.

)));
}
}
};
}
}

Ok(Arc::new(builder.finish()))
}

/// Cast the primitive array using [`PrimitiveArray::reinterpret_cast`]
fn cast_reinterpret_arrays<
I: ArrowPrimitiveType,
Expand Down Expand Up @@ -2014,7 +2103,30 @@ pub fn cast_with_options(
(Duration(TimeUnit::Nanosecond), Int64) => {
cast_reinterpret_arrays::<DurationNanosecondType, Int64Type>(array)
}

(Duration(TimeUnit::Second), Interval(IntervalUnit::MonthDayNano)) => {
cast_duration_to_interval::<DurationSecondType>(array, cast_options)
}
(Duration(TimeUnit::Millisecond), Interval(IntervalUnit::MonthDayNano)) => {
cast_duration_to_interval::<DurationMillisecondType>(array, cast_options)
}
(Duration(TimeUnit::Microsecond), Interval(IntervalUnit::MonthDayNano)) => {
cast_duration_to_interval::<DurationMicrosecondType>(array, cast_options)
}
(Duration(TimeUnit::Nanosecond), Interval(IntervalUnit::MonthDayNano)) => {
cast_duration_to_interval::<DurationNanosecondType>(array, cast_options)
}
(DataType::Interval(IntervalUnit::MonthDayNano), DataType::Duration(TimeUnit::Second)) => {
cast_interval_to_duration::<DurationSecondType>(array, cast_options)
}
(DataType::Interval(IntervalUnit::MonthDayNano), DataType::Duration(TimeUnit::Millisecond)) => {
cast_interval_to_duration::<DurationMillisecondType>(array, cast_options)
}
(DataType::Interval(IntervalUnit::MonthDayNano), DataType::Duration(TimeUnit::Microsecond)) => {
cast_interval_to_duration::<DurationMicrosecondType>(array, cast_options)
}
(DataType::Interval(IntervalUnit::MonthDayNano), DataType::Duration(TimeUnit::Nanosecond)) => {
cast_interval_to_duration::<DurationNanosecondType>(array, cast_options)
}
(Interval(IntervalUnit::YearMonth), Int64) => {
cast_numeric_arrays::<IntervalYearMonthType, Int64Type>(array, cast_options)
}
Expand Down Expand Up @@ -8246,4 +8358,137 @@ mod tests {
);
assert_eq!("Invalid argument error: 1234567000 is too large to store in a Decimal256 of precision 7. Max is 9999999", err.unwrap_err().to_string());
}

#[test]
fn test_cast_from_duration_to_interval() {
// from duration second to interval month day nano
let array = vec![1234567];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you also please test a value if i64::MAX here and ensure an error is returned?

let array = DurationSecondArray::from(array);
let array = Arc::new(array) as ArrayRef;
let casted_array =
cast(&array, &DataType::Interval(IntervalUnit::MonthDayNano)).unwrap();
let casted_array = casted_array
.as_any()
.downcast_ref::<IntervalMonthDayNanoArray>()
.unwrap();
assert_eq!(
casted_array.data_type(),
&DataType::Interval(IntervalUnit::MonthDayNano)
);
assert_eq!(casted_array.value(0), 1234567000000000);

// from duration millisecond to interval month day nano
let array = vec![1234567];
let array = DurationMillisecondArray::from(array);
let array = Arc::new(array) as ArrayRef;
let casted_array =
cast(&array, &DataType::Interval(IntervalUnit::MonthDayNano)).unwrap();
let casted_array = casted_array
.as_any()
.downcast_ref::<IntervalMonthDayNanoArray>()
.unwrap();
assert_eq!(
casted_array.data_type(),
&DataType::Interval(IntervalUnit::MonthDayNano)
);
assert_eq!(casted_array.value(0), 1234567000000);

// from duration microsecond to interval month day nano
let array = vec![1234567];
let array = DurationMicrosecondArray::from(array);
let array = Arc::new(array) as ArrayRef;
let casted_array =
cast(&array, &DataType::Interval(IntervalUnit::MonthDayNano)).unwrap();
let casted_array = casted_array
.as_any()
.downcast_ref::<IntervalMonthDayNanoArray>()
.unwrap();
assert_eq!(
casted_array.data_type(),
&DataType::Interval(IntervalUnit::MonthDayNano)
);
assert_eq!(casted_array.value(0), 1234567000);

// from duration nanosecond to interval month day nano
let array = vec![1234567];
let array = DurationNanosecondArray::from(array);
let array = Arc::new(array) as ArrayRef;
let casted_array =
cast(&array, &DataType::Interval(IntervalUnit::MonthDayNano)).unwrap();
let casted_array = casted_array
.as_any()
.downcast_ref::<IntervalMonthDayNanoArray>()
.unwrap();
assert_eq!(
casted_array.data_type(),
&DataType::Interval(IntervalUnit::MonthDayNano)
);
assert_eq!(casted_array.value(0), 1234567);
}

#[test]
fn test_cast_from_interval_to_duration() {
// from interval month day nano to duration second
let array = vec![1234567];
let array = IntervalMonthDayNanoArray::from(array);
let array = Arc::new(array) as ArrayRef;
let casted_array = cast(&array, &DataType::Duration(TimeUnit::Second)).unwrap();
let casted_array = casted_array
.as_any()
.downcast_ref::<DurationSecondArray>()
.unwrap();

assert_eq!(
casted_array.data_type(),
&DataType::Duration(TimeUnit::Second)
);
assert_eq!(casted_array.value(0), 0);

// from interval month day nano to duration millisecond
let array = vec![1234567];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could probably refactor this repetition into a function and avoid so much boiler plate in the tests

let array = IntervalMonthDayNanoArray::from(array);
let array = Arc::new(array) as ArrayRef;
let casted_array =
cast(&array, &DataType::Duration(TimeUnit::Millisecond)).unwrap();
let casted_array = casted_array
.as_any()
.downcast_ref::<DurationMillisecondArray>()
.unwrap();
assert_eq!(
casted_array.data_type(),
&DataType::Duration(TimeUnit::Millisecond)
);
assert_eq!(casted_array.value(0), 1);

// from interval month day nano to duration microsecond
let array = vec![1234567];
let array = IntervalMonthDayNanoArray::from(array);
let array = Arc::new(array) as ArrayRef;
let casted_array =
cast(&array, &DataType::Duration(TimeUnit::Microsecond)).unwrap();
let casted_array = casted_array
.as_any()
.downcast_ref::<DurationMicrosecondArray>()
.unwrap();
assert_eq!(
casted_array.data_type(),
&DataType::Duration(TimeUnit::Microsecond)
);
assert_eq!(casted_array.value(0), 1234);
// from interval month day nano to duration nanosecond
let array = vec![1234567];
let array = IntervalMonthDayNanoArray::from(array);
let array = Arc::new(array) as ArrayRef;
let casted_array =
cast(&array, &DataType::Duration(TimeUnit::Nanosecond)).unwrap();
let casted_array = casted_array
.as_any()
.downcast_ref::<DurationNanosecondArray>()
.unwrap();
assert_eq!(
casted_array.data_type(),
&DataType::Duration(TimeUnit::Nanosecond)
);
assert_eq!(casted_array.value(0), 1234567);
}
}