Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Replaced some macros by generics #1084

Merged
merged 1 commit into from
Jun 19, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 56 additions & 59 deletions src/compute/arithmetics/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,17 @@ use crate::{
bitmap::Bitmap,
datatypes::{DataType, IntervalUnit, TimeUnit},
scalar::{PrimitiveScalar, Scalar},
types::NativeType,
};

// Macro to evaluate match branch in arithmetic function.
macro_rules! primitive {
($lhs:expr, $rhs:expr, $op:tt, $type:ty) => {{
let lhs = $lhs.as_any().downcast_ref().unwrap();
let rhs = $rhs.as_any().downcast_ref().unwrap();

let result = basic::$op::<$type>(lhs, rhs);
Box::new(result) as Box<dyn Array>
}};
fn binary_dyn<T: NativeType, F: Fn(&PrimitiveArray<T>, &PrimitiveArray<T>) -> PrimitiveArray<T>>(
lhs: &dyn Array,
rhs: &dyn Array,
op: F,
) -> Box<dyn Array> {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref().unwrap();
op(lhs, rhs).boxed()
}

// Macro to create a `match` statement with dynamic dispatch to functions based on
Expand All @@ -42,18 +42,18 @@ macro_rules! arith {
let rhs = $rhs;
use DataType::*;
match (lhs.data_type(), rhs.data_type()) {
(Int8, Int8) => primitive!(lhs, rhs, $op, i8),
(Int16, Int16) => primitive!(lhs, rhs, $op, i16),
(Int32, Int32) => primitive!(lhs, rhs, $op, i32),
(Int8, Int8) => binary_dyn::<i8, _>(lhs, rhs, basic::$op),
(Int16, Int16) => binary_dyn::<i16, _>(lhs, rhs, basic::$op),
(Int32, Int32) => binary_dyn::<i32, _>(lhs, rhs, basic::$op),
(Int64, Int64) | (Duration(_), Duration(_)) => {
primitive!(lhs, rhs, $op, i64)
binary_dyn::<i64, _>(lhs, rhs, basic::$op)
}
(UInt8, UInt8) => primitive!(lhs, rhs, $op, u8),
(UInt16, UInt16) => primitive!(lhs, rhs, $op, u16),
(UInt32, UInt32) => primitive!(lhs, rhs, $op, u32),
(UInt64, UInt64) => primitive!(lhs, rhs, $op, u64),
(Float32, Float32) => primitive!(lhs, rhs, $op, f32),
(Float64, Float64) => primitive!(lhs, rhs, $op, f64),
(UInt8, UInt8) => binary_dyn::<u8, _>(lhs, rhs, basic::$op),
(UInt16, UInt16) => binary_dyn::<u16, _>(lhs, rhs, basic::$op),
(UInt32, UInt32) => binary_dyn::<u32, _>(lhs, rhs, basic::$op),
(UInt64, UInt64) => binary_dyn::<u64, _>(lhs, rhs, basic::$op),
(Float32, Float32) => binary_dyn::<f32, _>(lhs, rhs, basic::$op),
(Float64, Float64) => binary_dyn::<f64, _>(lhs, rhs, basic::$op),
$ (
(Decimal(_, _), Decimal(_, _)) => {
let lhs = lhs.as_any().downcast_ref().unwrap();
Expand Down Expand Up @@ -101,30 +101,27 @@ macro_rules! arith {
}};
}

// Macro to evaluate match branch in arithmetic function.
macro_rules! primitive_scalar {
($lhs:expr, $rhs:expr, $op:tt, $type:ty) => {{
let lhs = $lhs
.as_any()
.downcast_ref::<PrimitiveArray<$type>>()
.unwrap();
let rhs = $rhs
.as_any()
.downcast_ref::<PrimitiveScalar<$type>>()
.unwrap();

let rhs = if let Some(rhs) = rhs.value() {
rhs
} else {
return Box::new(PrimitiveArray::<$type>::new_null(
lhs.data_type().clone(),
lhs.len(),
)) as Box<dyn Array>;
};

let result = basic::$op::<$type>(lhs, &rhs);
Box::new(result) as Box<dyn Array>
}};
fn binary_scalar<T: NativeType, F: Fn(&PrimitiveArray<T>, &T) -> PrimitiveArray<T>>(
lhs: &PrimitiveArray<T>,
rhs: &PrimitiveScalar<T>,
op: F,
) -> PrimitiveArray<T> {
let rhs = if let Some(rhs) = rhs.value() {
rhs
} else {
return PrimitiveArray::<T>::new_null(lhs.data_type().clone(), lhs.len());
};
op(lhs, &rhs)
}

fn binary_scalar_dyn<T: NativeType, F: Fn(&PrimitiveArray<T>, &T) -> PrimitiveArray<T>>(
lhs: &dyn Array,
rhs: &dyn Scalar,
op: F,
) -> Box<dyn Array> {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref().unwrap();
binary_scalar(lhs, rhs, op).boxed()
}

// Macro to create a `match` statement with dynamic dispatch to functions based on
Expand All @@ -135,23 +132,23 @@ macro_rules! arith_scalar {
let rhs = $rhs;
use DataType::*;
match (lhs.data_type(), rhs.data_type()) {
(Int8, Int8) => primitive_scalar!(lhs, rhs, $op, i8),
(Int16, Int16) => primitive_scalar!(lhs, rhs, $op, i16),
(Int32, Int32) => primitive_scalar!(lhs, rhs, $op, i32),
(Int8, Int8) => binary_scalar_dyn::<i8, _>(lhs, rhs, basic::$op),
(Int16, Int16) => binary_scalar_dyn::<i16, _>(lhs, rhs, basic::$op),
(Int32, Int32) => binary_scalar_dyn::<i32, _>(lhs, rhs, basic::$op),
(Int64, Int64) | (Duration(_), Duration(_)) => {
primitive_scalar!(lhs, rhs, $op, i64)
binary_scalar_dyn::<i64, _>(lhs, rhs, basic::$op)
}
(UInt8, UInt8) => primitive_scalar!(lhs, rhs, $op, u8),
(UInt16, UInt16) => primitive_scalar!(lhs, rhs, $op, u16),
(UInt32, UInt32) => primitive_scalar!(lhs, rhs, $op, u32),
(UInt64, UInt64) => primitive_scalar!(lhs, rhs, $op, u64),
(Float32, Float32) => primitive_scalar!(lhs, rhs, $op, f32),
(Float64, Float64) => primitive_scalar!(lhs, rhs, $op, f64),
(UInt8, UInt8) => binary_scalar_dyn::<u8, _>(lhs, rhs, basic::$op),
(UInt16, UInt16) => binary_scalar_dyn::<u16, _>(lhs, rhs, basic::$op),
(UInt32, UInt32) => binary_scalar_dyn::<u32, _>(lhs, rhs, basic::$op),
(UInt64, UInt64) => binary_scalar_dyn::<u64, _>(lhs, rhs, basic::$op),
(Float32, Float32) => binary_scalar_dyn::<f32, _>(lhs, rhs, basic::$op),
(Float64, Float64) => binary_scalar_dyn::<f64, _>(lhs, rhs, basic::$op),
$ (
(Decimal(_, _), Decimal(_, _)) => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref().unwrap();
Box::new(decimal::$op_decimal(lhs, rhs)) as Box<dyn Array>
decimal::$op_decimal(lhs, rhs).boxed()
}
)?
$ (
Expand All @@ -160,29 +157,29 @@ macro_rules! arith_scalar {
| (Date32, Duration(_)) => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref().unwrap();
Box::new(time::$op_duration::<i32>(lhs, rhs)) as Box<dyn Array>
time::$op_duration::<i32>(lhs, rhs).boxed()
}
(Time64(TimeUnit::Microsecond), Duration(_))
| (Time64(TimeUnit::Nanosecond), Duration(_))
| (Date64, Duration(_))
| (Timestamp(_, _), Duration(_)) => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref().unwrap();
Box::new(time::$op_duration::<i64>(lhs, rhs)) as Box<dyn Array>
time::$op_duration::<i64>(lhs, rhs).boxed()
}
)?
$ (
(Timestamp(_, _), Interval(IntervalUnit::MonthDayNano)) => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref().unwrap();
time::$op_interval(lhs, rhs).map(|x| Box::new(x) as Box<dyn Array>).unwrap()
time::$op_interval(lhs, rhs).unwrap().boxed()
}
)?
$ (
(Timestamp(_, None), Timestamp(_, None)) => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref().unwrap();
time::$op_timestamp(lhs, rhs).map(|x| Box::new(x) as Box<dyn Array>).unwrap()
time::$op_timestamp(lhs, rhs).unwrap().boxed()
}
)?
_ => todo!(
Expand All @@ -197,7 +194,7 @@ macro_rules! arith_scalar {
/// Adds two [`Array`]s.
/// # Panic
/// This function panics iff
/// * the opertion is not supported for the logical types (use [`can_add`] to check)
/// * the operation is not supported for the logical types (use [`can_add`] to check)
/// * the arrays have a different length
/// * one of the arrays is a timestamp with timezone and the timezone is not valid.
pub fn add(lhs: &dyn Array, rhs: &dyn Array) -> Box<dyn Array> {
Expand Down