Skip to content

Commit

Permalink
Use trait
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Oct 6, 2022
1 parent bebfa85 commit 0ea6f98
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 95 deletions.
112 changes: 20 additions & 92 deletions arrow/src/compute/kernels/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,18 @@ use crate::array::*;
use crate::buffer::{buffer_unary_not, Buffer, MutableBuffer};
use crate::compute::util::combine_option_bitmap;
use crate::datatypes::{
ArrowNativeType, ArrowNumericType, DataType, Date32Type, Date64Type, Float32Type,
Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, IntervalDayTimeType,
IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, Time32MillisecondType,
Time32SecondType, Time64MicrosecondType, Time64NanosecondType, TimeUnit,
TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
native_op::ArrowNativeTypeOp, ArrowNativeType, ArrowNumericType, DataType,
Date32Type, Date64Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type,
Int8Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit,
IntervalYearMonthType, Time32MillisecondType, Time32SecondType,
Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type,
UInt32Type, UInt64Type, UInt8Type,
};
#[allow(unused_imports)]
use crate::downcast_dictionary_array;
use crate::error::{ArrowError, Result};
use crate::util::bit_util;
use num::ToPrimitive;
use regex::Regex;
use std::collections::HashMap;

Expand Down Expand Up @@ -1336,7 +1336,7 @@ macro_rules! dyn_compare_utf8_scalar {
/// Please refer to `f32::total_cmp` and `f64::total_cmp`.
pub fn eq_dyn_scalar<T>(left: &dyn Array, right: T) -> Result<BooleanArray>
where
T: num::ToPrimitive + std::fmt::Debug,
T: ArrowNativeTypeOp,
{
match left.data_type() {
DataType::Dictionary(key_type, _value_type) => {
Expand Down Expand Up @@ -3051,24 +3051,12 @@ where
pub fn eq_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
where
T: ArrowNumericType,
T::Native: num::ToPrimitive + std::fmt::Debug,
T::Native: ArrowNativeTypeOp,
{
#[cfg(feature = "simd")]
return simd_compare_op_scalar(left, right, T::eq, |a, b| a == b);
#[cfg(not(feature = "simd"))]
match left.data_type() {
DataType::Float32 => {
let left = as_primitive_array::<Float32Type>(left);
let right = try_to_type!(right, to_f32)?;
compare_op_scalar(left, |a: f32| a.total_cmp(&right).is_eq())
}
DataType::Float64 => {
let left = as_primitive_array::<Float64Type>(left);
let right = try_to_type!(right, to_f64)?;
compare_op_scalar(left, |a: f64| a.total_cmp(&right).is_eq())
}
_ => compare_op_scalar(left, |a| a == right),
}
return compare_op_scalar(left, |a| a.is_eq(right));
}

/// Applies an unary and infallible comparison function to a primitive array.
Expand Down Expand Up @@ -3099,24 +3087,12 @@ where
pub fn neq_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
where
T: ArrowNumericType,
T::Native: num::ToPrimitive + std::fmt::Debug,
T::Native: ArrowNativeTypeOp,
{
#[cfg(feature = "simd")]
return simd_compare_op_scalar(left, right, T::ne, |a, b| a != b);
#[cfg(not(feature = "simd"))]
match left.data_type() {
DataType::Float32 => {
let left = as_primitive_array::<Float32Type>(left);
let right = try_to_type!(right, to_f32)?;
compare_op_scalar(left, |a: f32| a.total_cmp(&right).is_ne())
}
DataType::Float64 => {
let left = as_primitive_array::<Float64Type>(left);
let right = try_to_type!(right, to_f64)?;
compare_op_scalar(left, |a: f64| a.total_cmp(&right).is_ne())
}
_ => compare_op_scalar(left, |a| a != right),
}
return compare_op_scalar(left, |a| a.is_ne(right));
}

/// Perform `left < right` operation on two [`PrimitiveArray`]s. Null values are less than non-null
Expand All @@ -3140,24 +3116,12 @@ where
pub fn lt_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
where
T: ArrowNumericType,
T::Native: num::ToPrimitive + std::fmt::Debug,
T::Native: ArrowNativeTypeOp,
{
#[cfg(feature = "simd")]
return simd_compare_op_scalar(left, right, T::lt, |a, b| a < b);
#[cfg(not(feature = "simd"))]
match left.data_type() {
DataType::Float32 => {
let left = as_primitive_array::<Float32Type>(left);
let right = try_to_type!(right, to_f32)?;
compare_op_scalar(left, |a: f32| a.total_cmp(&right).is_lt())
}
DataType::Float64 => {
let left = as_primitive_array::<Float64Type>(left);
let right = try_to_type!(right, to_f64)?;
compare_op_scalar(left, |a: f64| a.total_cmp(&right).is_lt())
}
_ => compare_op_scalar(left, |a| a < right),
}
return compare_op_scalar(left, |a| a.is_lt(right));
}

/// Perform `left <= right` operation on two [`PrimitiveArray`]s. Null values are less than non-null
Expand All @@ -3184,24 +3148,12 @@ where
pub fn lt_eq_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
where
T: ArrowNumericType,
T::Native: num::ToPrimitive + std::fmt::Debug,
T::Native: ArrowNativeTypeOp,
{
#[cfg(feature = "simd")]
return simd_compare_op_scalar(left, right, T::le, |a, b| a <= b);
#[cfg(not(feature = "simd"))]
match left.data_type() {
DataType::Float32 => {
let left = as_primitive_array::<Float32Type>(left);
let right = try_to_type!(right, to_f32)?;
compare_op_scalar(left, |a: f32| a.total_cmp(&right).is_le())
}
DataType::Float64 => {
let left = as_primitive_array::<Float64Type>(left);
let right = try_to_type!(right, to_f64)?;
compare_op_scalar(left, |a: f64| a.total_cmp(&right).is_le())
}
_ => compare_op_scalar(left, |a| a <= right),
}
return compare_op_scalar(left, |a| a.is_le(right));
}

/// Perform `left > right` operation on two [`PrimitiveArray`]s. Non-null values are greater than null
Expand All @@ -3225,24 +3177,12 @@ where
pub fn gt_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
where
T: ArrowNumericType,
T::Native: num::ToPrimitive + std::fmt::Debug,
T::Native: ArrowNativeTypeOp,
{
#[cfg(feature = "simd")]
return simd_compare_op_scalar(left, right, T::gt, |a, b| a > b);
#[cfg(not(feature = "simd"))]
match left.data_type() {
DataType::Float32 => {
let left = as_primitive_array::<Float32Type>(left);
let right = try_to_type!(right, to_f32)?;
compare_op_scalar(left, |a: f32| a.total_cmp(&right).is_gt())
}
DataType::Float64 => {
let left = as_primitive_array::<Float64Type>(left);
let right = try_to_type!(right, to_f64)?;
compare_op_scalar(left, |a: f64| a.total_cmp(&right).is_gt())
}
_ => compare_op_scalar(left, |a| a > right),
}
return compare_op_scalar(left, |a| a.is_gt(right));
}

/// Perform `left >= right` operation on two [`PrimitiveArray`]s. Non-null values are greater than null
Expand All @@ -3269,24 +3209,12 @@ where
pub fn gt_eq_scalar<T>(left: &PrimitiveArray<T>, right: T::Native) -> Result<BooleanArray>
where
T: ArrowNumericType,
T::Native: num::ToPrimitive + std::fmt::Debug,
T::Native: ArrowNativeTypeOp,
{
#[cfg(feature = "simd")]
return simd_compare_op_scalar(left, right, T::ge, |a, b| a >= b);
#[cfg(not(feature = "simd"))]
match left.data_type() {
DataType::Float32 => {
let left = as_primitive_array::<Float32Type>(left);
let right = try_to_type!(right, to_f32)?;
compare_op_scalar(left, |a: f32| a.total_cmp(&right).is_ge())
}
DataType::Float64 => {
let left = as_primitive_array::<Float64Type>(left);
let right = try_to_type!(right, to_f64)?;
compare_op_scalar(left, |a: f64| a.total_cmp(&right).is_ge())
}
_ => compare_op_scalar(left, |a| a >= right),
}
return compare_op_scalar(left, |a| a.is_ge(right));
}

/// Checks if a [`GenericListArray`] contains a value in the [`PrimitiveArray`]
Expand Down
61 changes: 58 additions & 3 deletions arrow/src/datatypes/native.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ pub(crate) mod native_op {
+ Mul<Output = Self>
+ Div<Output = Self>
+ Zero
+ num::ToPrimitive
{
fn add_checked(self, rhs: Self) -> Result<Self> {
Ok(self + rhs)
Expand Down Expand Up @@ -81,6 +82,30 @@ pub(crate) mod native_op {
fn div_wrapping(self, rhs: Self) -> Self {
self / rhs
}

fn is_eq(self, rhs: Self) -> bool {
self == rhs
}

fn is_ne(self, rhs: Self) -> bool {
self != rhs
}

fn is_lt(self, rhs: Self) -> bool {
self < rhs
}

fn is_le(self, rhs: Self) -> bool {
self <= rhs
}

fn is_gt(self, rhs: Self) -> bool {
self > rhs
}

fn is_ge(self, rhs: Self) -> bool {
self >= rhs
}
}
}

Expand Down Expand Up @@ -156,6 +181,36 @@ native_type_op!(u16);
native_type_op!(u32);
native_type_op!(u64);

impl native_op::ArrowNativeTypeOp for f16 {}
impl native_op::ArrowNativeTypeOp for f32 {}
impl native_op::ArrowNativeTypeOp for f64 {}
macro_rules! native_type_float_op {
($t:tt) => {
impl native_op::ArrowNativeTypeOp for $t {
fn is_eq(self, rhs: Self) -> bool {
self.total_cmp(&rhs).is_eq()
}

fn is_ne(self, rhs: Self) -> bool {
self.total_cmp(&rhs).is_ne()
}

fn is_lt(self, rhs: Self) -> bool {
self.total_cmp(&rhs).is_lt()
}

fn is_le(self, rhs: Self) -> bool {
self.total_cmp(&rhs).is_le()
}

fn is_gt(self, rhs: Self) -> bool {
self.total_cmp(&rhs).is_gt()
}

fn is_ge(self, rhs: Self) -> bool {
self.total_cmp(&rhs).is_ge()
}
}
};
}

native_type_float_op!(f16);
native_type_float_op!(f32);
native_type_float_op!(f64);

0 comments on commit 0ea6f98

Please sign in to comment.