Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
Replaced macros by generics (#1084)
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgecarleitao authored Jun 19, 2022
1 parent 16d01ef commit fb93436
Showing 1 changed file with 56 additions and 59 deletions.
115 changes: 56 additions & 59 deletions src/compute/arithmetics/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,17 @@ use crate::{
bitmap::Bitmap,
datatypes::{DataType, IntervalUnit, TimeUnit},
scalar::{PrimitiveScalar, Scalar},
types::NativeType,
};

// Macro to evaluate match branch in arithmetic function.
macro_rules! primitive {
($lhs:expr, $rhs:expr, $op:tt, $type:ty) => {{
let lhs = $lhs.as_any().downcast_ref().unwrap();
let rhs = $rhs.as_any().downcast_ref().unwrap();

let result = basic::$op::<$type>(lhs, rhs);
Box::new(result) as Box<dyn Array>
}};
fn binary_dyn<T: NativeType, F: Fn(&PrimitiveArray<T>, &PrimitiveArray<T>) -> PrimitiveArray<T>>(
lhs: &dyn Array,
rhs: &dyn Array,
op: F,
) -> Box<dyn Array> {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref().unwrap();
op(lhs, rhs).boxed()
}

// Macro to create a `match` statement with dynamic dispatch to functions based on
Expand All @@ -42,18 +42,18 @@ macro_rules! arith {
let rhs = $rhs;
use DataType::*;
match (lhs.data_type(), rhs.data_type()) {
(Int8, Int8) => primitive!(lhs, rhs, $op, i8),
(Int16, Int16) => primitive!(lhs, rhs, $op, i16),
(Int32, Int32) => primitive!(lhs, rhs, $op, i32),
(Int8, Int8) => binary_dyn::<i8, _>(lhs, rhs, basic::$op),
(Int16, Int16) => binary_dyn::<i16, _>(lhs, rhs, basic::$op),
(Int32, Int32) => binary_dyn::<i32, _>(lhs, rhs, basic::$op),
(Int64, Int64) | (Duration(_), Duration(_)) => {
primitive!(lhs, rhs, $op, i64)
binary_dyn::<i64, _>(lhs, rhs, basic::$op)
}
(UInt8, UInt8) => primitive!(lhs, rhs, $op, u8),
(UInt16, UInt16) => primitive!(lhs, rhs, $op, u16),
(UInt32, UInt32) => primitive!(lhs, rhs, $op, u32),
(UInt64, UInt64) => primitive!(lhs, rhs, $op, u64),
(Float32, Float32) => primitive!(lhs, rhs, $op, f32),
(Float64, Float64) => primitive!(lhs, rhs, $op, f64),
(UInt8, UInt8) => binary_dyn::<u8, _>(lhs, rhs, basic::$op),
(UInt16, UInt16) => binary_dyn::<u16, _>(lhs, rhs, basic::$op),
(UInt32, UInt32) => binary_dyn::<u32, _>(lhs, rhs, basic::$op),
(UInt64, UInt64) => binary_dyn::<u64, _>(lhs, rhs, basic::$op),
(Float32, Float32) => binary_dyn::<f32, _>(lhs, rhs, basic::$op),
(Float64, Float64) => binary_dyn::<f64, _>(lhs, rhs, basic::$op),
$ (
(Decimal(_, _), Decimal(_, _)) => {
let lhs = lhs.as_any().downcast_ref().unwrap();
Expand Down Expand Up @@ -101,30 +101,27 @@ macro_rules! arith {
}};
}

// Macro to evaluate match branch in arithmetic function.
macro_rules! primitive_scalar {
($lhs:expr, $rhs:expr, $op:tt, $type:ty) => {{
let lhs = $lhs
.as_any()
.downcast_ref::<PrimitiveArray<$type>>()
.unwrap();
let rhs = $rhs
.as_any()
.downcast_ref::<PrimitiveScalar<$type>>()
.unwrap();

let rhs = if let Some(rhs) = rhs.value() {
rhs
} else {
return Box::new(PrimitiveArray::<$type>::new_null(
lhs.data_type().clone(),
lhs.len(),
)) as Box<dyn Array>;
};

let result = basic::$op::<$type>(lhs, &rhs);
Box::new(result) as Box<dyn Array>
}};
fn binary_scalar<T: NativeType, F: Fn(&PrimitiveArray<T>, &T) -> PrimitiveArray<T>>(
lhs: &PrimitiveArray<T>,
rhs: &PrimitiveScalar<T>,
op: F,
) -> PrimitiveArray<T> {
let rhs = if let Some(rhs) = rhs.value() {
rhs
} else {
return PrimitiveArray::<T>::new_null(lhs.data_type().clone(), lhs.len());
};
op(lhs, &rhs)
}

fn binary_scalar_dyn<T: NativeType, F: Fn(&PrimitiveArray<T>, &T) -> PrimitiveArray<T>>(
lhs: &dyn Array,
rhs: &dyn Scalar,
op: F,
) -> Box<dyn Array> {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref().unwrap();
binary_scalar(lhs, rhs, op).boxed()
}

// Macro to create a `match` statement with dynamic dispatch to functions based on
Expand All @@ -135,23 +132,23 @@ macro_rules! arith_scalar {
let rhs = $rhs;
use DataType::*;
match (lhs.data_type(), rhs.data_type()) {
(Int8, Int8) => primitive_scalar!(lhs, rhs, $op, i8),
(Int16, Int16) => primitive_scalar!(lhs, rhs, $op, i16),
(Int32, Int32) => primitive_scalar!(lhs, rhs, $op, i32),
(Int8, Int8) => binary_scalar_dyn::<i8, _>(lhs, rhs, basic::$op),
(Int16, Int16) => binary_scalar_dyn::<i16, _>(lhs, rhs, basic::$op),
(Int32, Int32) => binary_scalar_dyn::<i32, _>(lhs, rhs, basic::$op),
(Int64, Int64) | (Duration(_), Duration(_)) => {
primitive_scalar!(lhs, rhs, $op, i64)
binary_scalar_dyn::<i64, _>(lhs, rhs, basic::$op)
}
(UInt8, UInt8) => primitive_scalar!(lhs, rhs, $op, u8),
(UInt16, UInt16) => primitive_scalar!(lhs, rhs, $op, u16),
(UInt32, UInt32) => primitive_scalar!(lhs, rhs, $op, u32),
(UInt64, UInt64) => primitive_scalar!(lhs, rhs, $op, u64),
(Float32, Float32) => primitive_scalar!(lhs, rhs, $op, f32),
(Float64, Float64) => primitive_scalar!(lhs, rhs, $op, f64),
(UInt8, UInt8) => binary_scalar_dyn::<u8, _>(lhs, rhs, basic::$op),
(UInt16, UInt16) => binary_scalar_dyn::<u16, _>(lhs, rhs, basic::$op),
(UInt32, UInt32) => binary_scalar_dyn::<u32, _>(lhs, rhs, basic::$op),
(UInt64, UInt64) => binary_scalar_dyn::<u64, _>(lhs, rhs, basic::$op),
(Float32, Float32) => binary_scalar_dyn::<f32, _>(lhs, rhs, basic::$op),
(Float64, Float64) => binary_scalar_dyn::<f64, _>(lhs, rhs, basic::$op),
$ (
(Decimal(_, _), Decimal(_, _)) => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref().unwrap();
Box::new(decimal::$op_decimal(lhs, rhs)) as Box<dyn Array>
decimal::$op_decimal(lhs, rhs).boxed()
}
)?
$ (
Expand All @@ -160,29 +157,29 @@ macro_rules! arith_scalar {
| (Date32, Duration(_)) => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref().unwrap();
Box::new(time::$op_duration::<i32>(lhs, rhs)) as Box<dyn Array>
time::$op_duration::<i32>(lhs, rhs).boxed()
}
(Time64(TimeUnit::Microsecond), Duration(_))
| (Time64(TimeUnit::Nanosecond), Duration(_))
| (Date64, Duration(_))
| (Timestamp(_, _), Duration(_)) => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref().unwrap();
Box::new(time::$op_duration::<i64>(lhs, rhs)) as Box<dyn Array>
time::$op_duration::<i64>(lhs, rhs).boxed()
}
)?
$ (
(Timestamp(_, _), Interval(IntervalUnit::MonthDayNano)) => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref().unwrap();
time::$op_interval(lhs, rhs).map(|x| Box::new(x) as Box<dyn Array>).unwrap()
time::$op_interval(lhs, rhs).unwrap().boxed()
}
)?
$ (
(Timestamp(_, None), Timestamp(_, None)) => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref().unwrap();
time::$op_timestamp(lhs, rhs).map(|x| Box::new(x) as Box<dyn Array>).unwrap()
time::$op_timestamp(lhs, rhs).unwrap().boxed()
}
)?
_ => todo!(
Expand All @@ -197,7 +194,7 @@ macro_rules! arith_scalar {
/// Adds two [`Array`]s.
/// # Panic
/// This function panics iff
/// * the opertion is not supported for the logical types (use [`can_add`] to check)
/// * the operation is not supported for the logical types (use [`can_add`] to check)
/// * the arrays have a different length
/// * one of the arrays is a timestamp with timezone and the timezone is not valid.
pub fn add(lhs: &dyn Array, rhs: &dyn Array) -> Box<dyn Array> {
Expand Down

0 comments on commit fb93436

Please sign in to comment.