Skip to content

Commit

Permalink
feat: cast from/to interval and duration
Browse files Browse the repository at this point in the history
  • Loading branch information
Weijun-H committed Apr 5, 2023
1 parent 7bac07a commit 6d58d1c
Showing 1 changed file with 244 additions and 1 deletion.
245 changes: 244 additions & 1 deletion arrow-cast/src/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,93 @@ where
}
}

/// Cast an [`IntervalMonthDayNanoArray`] to a duration array of type `D`.
///
/// A `MonthDayNano` interval is stored as one `i128` holding three fields:
/// months in the top 32 bits, days in the next 32 bits and nanoseconds in the
/// low 64 bits (the layout used by `IntervalMonthDayNanoType::make_value`).
/// Months and days have no fixed length, so only intervals whose month and
/// day fields are both zero can be expressed as a duration. For those, the
/// nanosecond field is divided by the number of nanoseconds per unit of `D`,
/// truncating toward zero.
///
/// # Errors
///
/// If an interval has a non-zero month or day field the slot becomes null
/// when `cast_options.safe` is set; otherwise an
/// [`ArrowError::ComputeError`] is returned.
///
/// # Panics
///
/// Panics if `array` is not an `IntervalMonthDayNanoArray` or if `D` is not
/// one of the four `Duration` types; the cast dispatcher is expected to
/// guarantee both.
fn cast_interval_to_duration<D: ArrowTemporalType<Native = i64>>(
    array: &dyn Array,
    cast_options: &CastOptions,
) -> Result<ArrayRef, ArrowError> {
    let array = array
        .as_any()
        .downcast_ref::<IntervalMonthDayNanoArray>()
        .expect("caller must pass an IntervalMonthDayNanoArray");

    let mut builder = PrimitiveBuilder::<D>::new();

    // Nanoseconds per one unit of the target duration type.
    let scale: i64 = match D::DATA_TYPE {
        DataType::Duration(TimeUnit::Second) => 1_000_000_000,
        DataType::Duration(TimeUnit::Millisecond) => 1_000_000,
        DataType::Duration(TimeUnit::Microsecond) => 1_000,
        DataType::Duration(TimeUnit::Nanosecond) => 1,
        _ => unreachable!(),
    };

    for i in 0..array.len() {
        if array.is_null(i) {
            builder.append_null();
            continue;
        }

        // Unpack the three fields; the truncating casts are intentional and
        // mirror `IntervalMonthDayNanoType::to_parts`.
        let v = array.value(i);
        let months = (v >> 96) as i32;
        let days = (v >> 64) as i32;
        let nanos = v as i64;

        if months == 0 && days == 0 {
            // `scale` is positive, so this division can neither overflow nor
            // divide by zero.
            builder.append_value(nanos / scale);
        } else if cast_options.safe {
            builder.append_null();
        } else {
            return Err(ArrowError::ComputeError(format!(
                "Cannot cast to {:?}. Overflowing on {:?}",
                D::DATA_TYPE,
                v
            )));
        }
    }

    Ok(Arc::new(builder.finish()) as ArrayRef)
}

/// Cast a duration array of type `D` to an [`IntervalMonthDayNanoArray`].
///
/// Each duration is converted to nanoseconds and stored in the nanosecond
/// field of the interval (the low 64 bits of the `i128`, per the layout used
/// by `IntervalMonthDayNanoType::make_value`); the month and day fields are
/// left at zero.
///
/// # Errors
///
/// If the nanosecond count does not fit in an `i64` the slot becomes null
/// when `cast_options.safe` is set; otherwise an
/// [`ArrowError::ComputeError`] is returned.
///
/// # Panics
///
/// Panics if `array` is not a `PrimitiveArray<D>` or if its data type is not
/// one of the four `Duration` types; the cast dispatcher is expected to
/// guarantee both.
fn cast_duration_to_interval<D: ArrowTemporalType<Native = i64>>(
    array: &dyn Array,
    cast_options: &CastOptions,
) -> Result<ArrayRef, ArrowError> {
    let array = array
        .as_any()
        .downcast_ref::<PrimitiveArray<D>>()
        .expect("caller must pass a duration array matching `D`");

    let mut builder = IntervalMonthDayNanoBuilder::new();

    // Nanoseconds per one unit of the source duration type.
    let scale: i64 = match array.data_type() {
        DataType::Duration(TimeUnit::Second) => 1_000_000_000,
        DataType::Duration(TimeUnit::Millisecond) => 1_000_000,
        DataType::Duration(TimeUnit::Microsecond) => 1_000,
        DataType::Duration(TimeUnit::Nanosecond) => 1,
        _ => unreachable!(),
    };

    for i in 0..array.len() {
        if array.is_null(i) {
            builder.append_null();
            continue;
        }

        let v: i64 = array.value(i);
        // The interval's nanosecond field is only 64 bits wide, so the
        // multiplication has to be checked in `i64`; an `i128` product would
        // never overflow and would spill into the day field instead.
        match v.checked_mul(scale) {
            // Reinterpret as `u64` before widening so that a negative
            // nanosecond count is zero-extended and cannot sign-extend into
            // the day and month fields.
            Some(nanos) => builder.append_value(i128::from(nanos as u64)),
            None if cast_options.safe => builder.append_null(),
            None => {
                return Err(ArrowError::ComputeError(format!(
                    "Cannot cast to Interval(MonthDayNano). Overflowing on {:?}",
                    v
                )));
            }
        }
    }

    Ok(Arc::new(builder.finish()))
}

/// Cast the primitive array using [`PrimitiveArray::reinterpret_cast`]
fn cast_reinterpret_arrays<
I: ArrowPrimitiveType,
Expand Down Expand Up @@ -2014,7 +2101,30 @@ pub fn cast_with_options(
(Duration(TimeUnit::Nanosecond), Int64) => {
cast_reinterpret_arrays::<DurationNanosecondType, Int64Type>(array)
}

(Duration(TimeUnit::Second), Interval(IntervalUnit::MonthDayNano)) => {
cast_duration_to_interval::<DurationSecondType>(array, cast_options)
}
(Duration(TimeUnit::Millisecond), Interval(IntervalUnit::MonthDayNano)) => {
cast_duration_to_interval::<DurationMillisecondType>(array, cast_options)
}
(Duration(TimeUnit::Microsecond), Interval(IntervalUnit::MonthDayNano)) => {
cast_duration_to_interval::<DurationMicrosecondType>(array, cast_options)
}
(Duration(TimeUnit::Nanosecond), Interval(IntervalUnit::MonthDayNano)) => {
cast_duration_to_interval::<DurationNanosecondType>(array, cast_options)
}
(DataType::Interval(IntervalUnit::MonthDayNano), DataType::Duration(TimeUnit::Second)) => {
cast_interval_to_duration::<DurationSecondType>(array, cast_options)
}
(DataType::Interval(IntervalUnit::MonthDayNano), DataType::Duration(TimeUnit::Millisecond)) => {
cast_interval_to_duration::<DurationMillisecondType>(array, cast_options)
}
(DataType::Interval(IntervalUnit::MonthDayNano), DataType::Duration(TimeUnit::Microsecond)) => {
cast_interval_to_duration::<DurationMicrosecondType>(array, cast_options)
}
(DataType::Interval(IntervalUnit::MonthDayNano), DataType::Duration(TimeUnit::Nanosecond)) => {
cast_interval_to_duration::<DurationNanosecondType>(array, cast_options)
}
(Interval(IntervalUnit::YearMonth), Int64) => {
cast_numeric_arrays::<IntervalYearMonthType, Int64Type>(array, cast_options)
}
Expand Down Expand Up @@ -8246,4 +8356,137 @@ mod tests {
);
assert_eq!("Invalid argument error: 1234567000 is too large to store in a Decimal256 of precision 7. Max is 9999999", err.unwrap_err().to_string());
}

#[test]
fn test_cast_from_duration_to_interval() {
    // Every duration unit is cast to the same target type.
    let interval_type = DataType::Interval(IntervalUnit::MonthDayNano);

    // One single-element input per duration unit, each holding 1234567,
    // paired with the interval value the cast is expected to produce.
    let cases: [(ArrayRef, i128); 4] = [
        (
            Arc::new(DurationSecondArray::from(vec![1234567])) as ArrayRef,
            1234567000000000,
        ),
        (
            Arc::new(DurationMillisecondArray::from(vec![1234567])) as ArrayRef,
            1234567000000,
        ),
        (
            Arc::new(DurationMicrosecondArray::from(vec![1234567])) as ArrayRef,
            1234567000,
        ),
        (
            Arc::new(DurationNanosecondArray::from(vec![1234567])) as ArrayRef,
            1234567,
        ),
    ];

    for (input, expected) in cases {
        let output = cast(&input, &interval_type).unwrap();
        let intervals = output
            .as_any()
            .downcast_ref::<IntervalMonthDayNanoArray>()
            .unwrap();
        assert_eq!(intervals.data_type(), &interval_type);
        assert_eq!(intervals.value(0), expected);
    }
}

#[test]
fn test_cast_from_interval_to_duration() {
    // Casts a single-element interval array holding the raw value 1234567 to
    // a duration of the given unit, then checks the resulting data type and
    // the value in slot 0. A macro is used because each unit downcasts to a
    // different concrete array type.
    macro_rules! check_cast {
        ($unit:expr, $duration_array:ty, $expected:expr) => {{
            let target = DataType::Duration($unit);
            let input =
                Arc::new(IntervalMonthDayNanoArray::from(vec![1234567])) as ArrayRef;
            let output = cast(&input, &target).unwrap();
            let durations = output
                .as_any()
                .downcast_ref::<$duration_array>()
                .unwrap();
            assert_eq!(durations.data_type(), &target);
            assert_eq!(durations.value(0), $expected);
        }};
    }

    // from interval month day nano to duration second
    check_cast!(TimeUnit::Second, DurationSecondArray, 0);
    // from interval month day nano to duration millisecond
    check_cast!(TimeUnit::Millisecond, DurationMillisecondArray, 1);
    // from interval month day nano to duration microsecond
    check_cast!(TimeUnit::Microsecond, DurationMicrosecondArray, 1234);
    // from interval month day nano to duration nanosecond
    check_cast!(TimeUnit::Nanosecond, DurationNanosecondArray, 1234567);
}
}

0 comments on commit 6d58d1c

Please sign in to comment.