From fb93436a511254e169762dfbe0421df592f3eef2 Mon Sep 17 00:00:00 2001 From: Jorge Leitao Date: Sun, 19 Jun 2022 11:37:57 +0200 Subject: [PATCH] Replaced macros by generics (#1084) --- src/compute/arithmetics/mod.rs | 115 ++++++++++++++++----------------- 1 file changed, 56 insertions(+), 59 deletions(-) diff --git a/src/compute/arithmetics/mod.rs b/src/compute/arithmetics/mod.rs index 2b6bb7e6c0c..07391675ed9 100644 --- a/src/compute/arithmetics/mod.rs +++ b/src/compute/arithmetics/mod.rs @@ -21,17 +21,17 @@ use crate::{ bitmap::Bitmap, datatypes::{DataType, IntervalUnit, TimeUnit}, scalar::{PrimitiveScalar, Scalar}, + types::NativeType, }; -// Macro to evaluate match branch in arithmetic function. -macro_rules! primitive { - ($lhs:expr, $rhs:expr, $op:tt, $type:ty) => {{ - let lhs = $lhs.as_any().downcast_ref().unwrap(); - let rhs = $rhs.as_any().downcast_ref().unwrap(); - - let result = basic::$op::<$type>(lhs, rhs); - Box::new(result) as Box - }}; +fn binary_dyn, &PrimitiveArray) -> PrimitiveArray>( + lhs: &dyn Array, + rhs: &dyn Array, + op: F, +) -> Box { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + op(lhs, rhs).boxed() } // Macro to create a `match` statement with dynamic dispatch to functions based on @@ -42,18 +42,18 @@ macro_rules! arith { let rhs = $rhs; use DataType::*; match (lhs.data_type(), rhs.data_type()) { - (Int8, Int8) => primitive!(lhs, rhs, $op, i8), - (Int16, Int16) => primitive!(lhs, rhs, $op, i16), - (Int32, Int32) => primitive!(lhs, rhs, $op, i32), + (Int8, Int8) => binary_dyn::(lhs, rhs, basic::$op), + (Int16, Int16) => binary_dyn::(lhs, rhs, basic::$op), + (Int32, Int32) => binary_dyn::(lhs, rhs, basic::$op), (Int64, Int64) | (Duration(_), Duration(_)) => { - primitive!(lhs, rhs, $op, i64) + binary_dyn::(lhs, rhs, basic::$op) } - (UInt8, UInt8) => primitive!(lhs, rhs, $op, u8), - (UInt16, UInt16) => primitive!(lhs, rhs, $op, u16), - (UInt32, UInt32) => primitive!(lhs, rhs, $op, u32), - (UInt64, UInt64) => primitive!(lhs, rhs, $op, u64), - (Float32, Float32) => primitive!(lhs, rhs, $op, f32), - (Float64, Float64) => primitive!(lhs, rhs, $op, f64), + (UInt8, UInt8) => binary_dyn::(lhs, rhs, basic::$op), + (UInt16, UInt16) => binary_dyn::(lhs, rhs, basic::$op), + (UInt32, UInt32) => binary_dyn::(lhs, rhs, basic::$op), + (UInt64, UInt64) => binary_dyn::(lhs, rhs, basic::$op), + (Float32, Float32) => binary_dyn::(lhs, rhs, basic::$op), + (Float64, Float64) => binary_dyn::(lhs, rhs, basic::$op), $ ( (Decimal(_, _), Decimal(_, _)) => { let lhs = lhs.as_any().downcast_ref().unwrap(); @@ -101,30 +101,27 @@ macro_rules! arith { }}; } -// Macro to evaluate match branch in arithmetic function. -macro_rules! primitive_scalar { - ($lhs:expr, $rhs:expr, $op:tt, $type:ty) => {{ - let lhs = $lhs - .as_any() - .downcast_ref::>() - .unwrap(); - let rhs = $rhs - .as_any() - .downcast_ref::>() - .unwrap(); - - let rhs = if let Some(rhs) = rhs.value() { - rhs - } else { - return Box::new(PrimitiveArray::<$type>::new_null( - lhs.data_type().clone(), - lhs.len(), - )) as Box; - }; - - let result = basic::$op::<$type>(lhs, &rhs); - Box::new(result) as Box - }}; +fn binary_scalar, &T) -> PrimitiveArray>( + lhs: &PrimitiveArray, + rhs: &PrimitiveScalar, + op: F, +) -> PrimitiveArray { + let rhs = if let Some(rhs) = rhs.value() { + rhs + } else { + return PrimitiveArray::::new_null(lhs.data_type().clone(), lhs.len()); + }; + op(lhs, &rhs) +} + +fn binary_scalar_dyn, &T) -> PrimitiveArray>( + lhs: &dyn Array, + rhs: &dyn Scalar, + op: F, +) -> Box { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + binary_scalar(lhs, rhs, op).boxed() } // Macro to create a `match` statement with dynamic dispatch to functions based on @@ -135,23 +132,23 @@ macro_rules! arith_scalar { let rhs = $rhs; use DataType::*; match (lhs.data_type(), rhs.data_type()) { - (Int8, Int8) => primitive_scalar!(lhs, rhs, $op, i8), - (Int16, Int16) => primitive_scalar!(lhs, rhs, $op, i16), - (Int32, Int32) => primitive_scalar!(lhs, rhs, $op, i32), + (Int8, Int8) => binary_scalar_dyn::(lhs, rhs, basic::$op), + (Int16, Int16) => binary_scalar_dyn::(lhs, rhs, basic::$op), + (Int32, Int32) => binary_scalar_dyn::(lhs, rhs, basic::$op), (Int64, Int64) | (Duration(_), Duration(_)) => { - primitive_scalar!(lhs, rhs, $op, i64) + binary_scalar_dyn::(lhs, rhs, basic::$op) } - (UInt8, UInt8) => primitive_scalar!(lhs, rhs, $op, u8), - (UInt16, UInt16) => primitive_scalar!(lhs, rhs, $op, u16), - (UInt32, UInt32) => primitive_scalar!(lhs, rhs, $op, u32), - (UInt64, UInt64) => primitive_scalar!(lhs, rhs, $op, u64), - (Float32, Float32) => primitive_scalar!(lhs, rhs, $op, f32), - (Float64, Float64) => primitive_scalar!(lhs, rhs, $op, f64), + (UInt8, UInt8) => binary_scalar_dyn::(lhs, rhs, basic::$op), + (UInt16, UInt16) => binary_scalar_dyn::(lhs, rhs, basic::$op), + (UInt32, UInt32) => binary_scalar_dyn::(lhs, rhs, basic::$op), + (UInt64, UInt64) => binary_scalar_dyn::(lhs, rhs, basic::$op), + (Float32, Float32) => binary_scalar_dyn::(lhs, rhs, basic::$op), + (Float64, Float64) => binary_scalar_dyn::(lhs, rhs, basic::$op), $ ( (Decimal(_, _), Decimal(_, _)) => { let lhs = lhs.as_any().downcast_ref().unwrap(); let rhs = rhs.as_any().downcast_ref().unwrap(); - Box::new(decimal::$op_decimal(lhs, rhs)) as Box + decimal::$op_decimal(lhs, rhs).boxed() } )? $ ( @@ -160,7 +157,7 @@ macro_rules! arith_scalar { | (Date32, Duration(_)) => { let lhs = lhs.as_any().downcast_ref().unwrap(); let rhs = rhs.as_any().downcast_ref().unwrap(); - Box::new(time::$op_duration::(lhs, rhs)) as Box + time::$op_duration::(lhs, rhs).boxed() } (Time64(TimeUnit::Microsecond), Duration(_)) | (Time64(TimeUnit::Nanosecond), Duration(_)) @@ -168,21 +165,21 @@ macro_rules! arith_scalar { | (Timestamp(_, _), Duration(_)) => { let lhs = lhs.as_any().downcast_ref().unwrap(); let rhs = rhs.as_any().downcast_ref().unwrap(); - Box::new(time::$op_duration::(lhs, rhs)) as Box + time::$op_duration::(lhs, rhs).boxed() } )? $ ( (Timestamp(_, _), Interval(IntervalUnit::MonthDayNano)) => { let lhs = lhs.as_any().downcast_ref().unwrap(); let rhs = rhs.as_any().downcast_ref().unwrap(); - time::$op_interval(lhs, rhs).map(|x| Box::new(x) as Box).unwrap() + time::$op_interval(lhs, rhs).unwrap().boxed() } )? $ ( (Timestamp(_, None), Timestamp(_, None)) => { let lhs = lhs.as_any().downcast_ref().unwrap(); let rhs = rhs.as_any().downcast_ref().unwrap(); - time::$op_timestamp(lhs, rhs).map(|x| Box::new(x) as Box).unwrap() + time::$op_timestamp(lhs, rhs).unwrap().boxed() } )? _ => todo!( @@ -197,7 +194,7 @@ macro_rules! arith_scalar { /// Adds two [`Array`]s. /// # Panic /// This function panics iff -/// * the opertion is not supported for the logical types (use [`can_add`] to check) +/// * the operation is not supported for the logical types (use [`can_add`] to check) /// * the arrays have a different length /// * one of the arrays is a timestamp with timezone and the timezone is not valid. pub fn add(lhs: &dyn Array, rhs: &dyn Array) -> Box {