Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
Added support to compare intervals (#746)
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgecarleitao authored Jan 9, 2022
1 parent a4383b1 commit 4c29966
Show file tree
Hide file tree
Showing 7 changed files with 161 additions and 56 deletions.
105 changes: 69 additions & 36 deletions src/compute/comparison/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! Basic comparison kernels.
//! Contains comparison operators
//!
//! The module contains functions that compare either an array and a scalar
//! or two arrays of the same [`DataType`]. The scalar-oriented functions are
//! The module contains functions that compare either an [`Array`] and a [`Scalar`]
//! or two [`Array`]s (of the same [`DataType`]). The scalar-oriented functions are
//! suffixed with `_scalar`.
//!
//! The functions are organized in two variants:
Expand Down Expand Up @@ -45,7 +45,7 @@
//! ```
use crate::array::*;
use crate::datatypes::DataType;
use crate::datatypes::{DataType, IntervalUnit};
use crate::scalar::*;

pub mod binary;
Expand All @@ -54,11 +54,11 @@ pub mod primitive;
pub mod utf8;

mod simd;
pub use simd::{Simd8, Simd8Lanes};
pub use simd::{Simd8, Simd8Lanes, Simd8PartialEq, Simd8PartialOrd};

pub(crate) use primitive::compare_values_op as primitive_compare_values_op;

macro_rules! with_match_primitive_cmp {(
macro_rules! match_eq_ord {(
$key_type:expr, | $_:tt $T:ident | $($body:tt)*
) => ({
macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )}
Expand All @@ -80,8 +80,31 @@ macro_rules! with_match_primitive_cmp {(
}
})}

macro_rules! match_eq {(
$key_type:expr, | $_:tt $T:ident | $($body:tt)*
) => ({
macro_rules! __with_ty__ {( $_ $T:ident ) => ( $($body)* )}
use crate::datatypes::PrimitiveType::*;
use crate::types::{days_ms, months_days_ns};
match $key_type {
Int8 => __with_ty__! { i8 },
Int16 => __with_ty__! { i16 },
Int32 => __with_ty__! { i32 },
Int64 => __with_ty__! { i64 },
Int128 => __with_ty__! { i128 },
DaysMs => __with_ty__! { days_ms },
MonthDayNano => __with_ty__! { months_days_ns },
UInt8 => __with_ty__! { u8 },
UInt16 => __with_ty__! { u16 },
UInt32 => __with_ty__! { u32 },
UInt64 => __with_ty__! { u64 },
Float32 => __with_ty__! { f32 },
Float64 => __with_ty__! { f64 },
}
})}

macro_rules! compare {
($lhs:expr, $rhs:expr, $op:tt) => {{
($lhs:expr, $rhs:expr, $op:tt, $p:tt) => {{
let lhs = $lhs;
let rhs = $rhs;
assert_eq!(
Expand All @@ -96,7 +119,7 @@ macro_rules! compare {
let rhs = rhs.as_any().downcast_ref().unwrap();
boolean::$op(lhs, rhs)
}
Primitive(primitive) => with_match_primitive_cmp!(primitive, |$T| {
Primitive(primitive) => $p!(primitive, |$T| {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref().unwrap();
primitive::$op::<$T>(lhs, rhs)
Expand Down Expand Up @@ -137,7 +160,7 @@ macro_rules! compare {
/// * the arrays do not have the same length
/// * the operation is not supported for the logical type
pub fn eq(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray {
compare!(lhs, rhs, eq)
compare!(lhs, rhs, eq, match_eq)
}

/// `!=` between two [`Array`]s.
Expand All @@ -148,7 +171,7 @@ pub fn eq(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray {
/// * the arrays do not have the same length
/// * the operation is not supported for the logical type
pub fn neq(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray {
compare!(lhs, rhs, neq)
compare!(lhs, rhs, neq, match_eq)
}

/// `<` between two [`Array`]s.
Expand All @@ -159,7 +182,7 @@ pub fn neq(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray {
/// * the arrays do not have the same length
/// * the operation is not supported for the logical type
pub fn lt(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray {
compare!(lhs, rhs, lt)
compare!(lhs, rhs, lt, match_eq_ord)
}

/// `<=` between two [`Array`]s.
Expand All @@ -170,7 +193,7 @@ pub fn lt(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray {
/// * the arrays do not have the same length
/// * the operation is not supported for the logical type
pub fn lt_eq(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray {
compare!(lhs, rhs, lt_eq)
compare!(lhs, rhs, lt_eq, match_eq_ord)
}

/// `>` between two [`Array`]s.
Expand All @@ -181,7 +204,7 @@ pub fn lt_eq(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray {
/// * the arrays do not have the same length
/// * the operation is not supported for the logical type
pub fn gt(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray {
compare!(lhs, rhs, gt)
compare!(lhs, rhs, gt, match_eq_ord)
}

/// `>=` between two [`Array`]s.
Expand All @@ -192,11 +215,11 @@ pub fn gt(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray {
/// * the arrays do not have the same length
/// * the operation is not supported for the logical type
pub fn gt_eq(lhs: &dyn Array, rhs: &dyn Array) -> BooleanArray {
compare!(lhs, rhs, gt_eq)
compare!(lhs, rhs, gt_eq, match_eq_ord)
}

macro_rules! compare_scalar {
($lhs:expr, $rhs:expr, $op:tt) => {{
($lhs:expr, $rhs:expr, $op:tt, $p:tt) => {{
let lhs = $lhs;
let rhs = $rhs;
assert_eq!(
Expand All @@ -215,7 +238,7 @@ macro_rules! compare_scalar {
// validity checked above
boolean::$op(lhs, rhs.value().unwrap())
}
Primitive(primitive) => with_match_primitive_cmp!(primitive, |$T| {
Primitive(primitive) => $p!(primitive, |$T| {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref::<PrimitiveScalar<$T>>().unwrap();
primitive::$op::<$T>(lhs, rhs.value().unwrap())
Expand Down Expand Up @@ -252,7 +275,7 @@ macro_rules! compare_scalar {
/// * they do not have have the same logical type
/// * the operation is not supported for the logical type
pub fn eq_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray {
compare_scalar!(lhs, rhs, eq_scalar)
compare_scalar!(lhs, rhs, eq_scalar, match_eq)
}

/// `!=` between an [`Array`] and a [`Scalar`].
Expand All @@ -262,7 +285,7 @@ pub fn eq_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray {
/// * they do not have have the same logical type
/// * the operation is not supported for the logical type
pub fn neq_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray {
compare_scalar!(lhs, rhs, neq_scalar)
compare_scalar!(lhs, rhs, neq_scalar, match_eq)
}

/// `<` between an [`Array`] and a [`Scalar`].
Expand All @@ -272,7 +295,7 @@ pub fn neq_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray {
/// * they do not have have the same logical type
/// * the operation is not supported for the logical type
pub fn lt_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray {
compare_scalar!(lhs, rhs, lt_scalar)
compare_scalar!(lhs, rhs, lt_scalar, match_eq_ord)
}

/// `<=` between an [`Array`] and a [`Scalar`].
Expand All @@ -282,7 +305,7 @@ pub fn lt_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray {
/// * they do not have have the same logical type
/// * the operation is not supported for the logical type
pub fn lt_eq_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray {
compare_scalar!(lhs, rhs, lt_eq_scalar)
compare_scalar!(lhs, rhs, lt_eq_scalar, match_eq_ord)
}

/// `>` between an [`Array`] and a [`Scalar`].
Expand All @@ -292,7 +315,7 @@ pub fn lt_eq_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray {
/// * they do not have have the same logical type
/// * the operation is not supported for the logical type
pub fn gt_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray {
compare_scalar!(lhs, rhs, gt_scalar)
compare_scalar!(lhs, rhs, gt_scalar, match_eq_ord)
}

/// `>=` between an [`Array`] and a [`Scalar`].
Expand All @@ -302,41 +325,41 @@ pub fn gt_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray {
/// * they do not have have the same logical type
/// * the operation is not supported for the logical type
pub fn gt_eq_scalar(lhs: &dyn Array, rhs: &dyn Scalar) -> BooleanArray {
compare_scalar!(lhs, rhs, gt_eq_scalar)
compare_scalar!(lhs, rhs, gt_eq_scalar, match_eq_ord)
}

/// Returns whether a [`DataType`] is comparable (either array or scalar) comparison.
/// Returns whether a [`DataType`] is comparable (either array or scalar).
pub fn can_eq(data_type: &DataType) -> bool {
can_compare(data_type)
can_partial_eq(data_type)
}

/// Returns whether a [`DataType`] is comparable (either array or scalar) comparison.
/// Returns whether a [`DataType`] is comparable (either array or scalar).
pub fn can_neq(data_type: &DataType) -> bool {
can_compare(data_type)
can_partial_eq(data_type)
}

/// Returns whether a [`DataType`] is comparable (either array or scalar) comparison.
/// Returns whether a [`DataType`] is comparable (either array or scalar).
pub fn can_lt(data_type: &DataType) -> bool {
can_compare(data_type)
can_partial_eq_and_ord(data_type)
}

/// Returns whether a [`DataType`] is comparable (either array or scalar) comparison.
/// Returns whether a [`DataType`] is comparable (either array or scalar).
pub fn can_lt_eq(data_type: &DataType) -> bool {
can_compare(data_type)
can_partial_eq_and_ord(data_type)
}

/// Returns whether a [`DataType`] is comparable (either array or scalar) comparison.
/// Returns whether a [`DataType`] is comparable (either array or scalar).
pub fn can_gt(data_type: &DataType) -> bool {
can_compare(data_type)
can_partial_eq_and_ord(data_type)
}

/// Returns whether a [`DataType`] is comparable (either array or scalar) comparison.
/// Returns whether a [`DataType`] is comparable (either array or scalar).
pub fn can_gt_eq(data_type: &DataType) -> bool {
can_compare(data_type)
can_partial_eq_and_ord(data_type)
}

// The list of operations currently supported.
fn can_compare(data_type: &DataType) -> bool {
fn can_partial_eq_and_ord(data_type: &DataType) -> bool {
matches!(
data_type,
DataType::Boolean
Expand All @@ -345,7 +368,7 @@ fn can_compare(data_type: &DataType) -> bool {
| DataType::Int32
| DataType::Date32
| DataType::Time32(_)
| DataType::Interval(_)
| DataType::Interval(IntervalUnit::YearMonth)
| DataType::Int64
| DataType::Timestamp(_, _)
| DataType::Date64
Expand All @@ -364,3 +387,13 @@ fn can_compare(data_type: &DataType) -> bool {
| DataType::LargeBinary
)
}

// The list of operations currently supported.
fn can_partial_eq(data_type: &DataType) -> bool {
can_partial_eq_and_ord(data_type)
|| matches!(
data_type.to_logical_type(),
DataType::Interval(IntervalUnit::DayTime)
| DataType::Interval(IntervalUnit::MonthDayNano)
)
}
14 changes: 13 additions & 1 deletion src/compute/comparison/primitive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
};

use super::super::utils::combine_validities;
use super::simd::{Simd8, Simd8Lanes};
use super::simd::{Simd8, Simd8Lanes, Simd8PartialEq, Simd8PartialOrd};

pub(crate) fn compare_values_op<T, F>(lhs: &[T], rhs: &[T], op: F) -> MutableBitmap
where
Expand Down Expand Up @@ -87,6 +87,7 @@ where
pub fn eq<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> BooleanArray
where
T: NativeType + Simd8,
T::Simd: Simd8PartialEq,
{
compare_op(lhs, rhs, |a, b| a.eq(b))
}
Expand All @@ -95,6 +96,7 @@ where
pub fn eq_scalar<T>(lhs: &PrimitiveArray<T>, rhs: T) -> BooleanArray
where
T: NativeType + Simd8,
T::Simd: Simd8PartialEq,
{
compare_op_scalar(lhs, rhs, |a, b| a.eq(b))
}
Expand All @@ -103,6 +105,7 @@ where
pub fn neq<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> BooleanArray
where
T: NativeType + Simd8,
T::Simd: Simd8PartialEq,
{
compare_op(lhs, rhs, |a, b| a.neq(b))
}
Expand All @@ -111,6 +114,7 @@ where
pub fn neq_scalar<T>(lhs: &PrimitiveArray<T>, rhs: T) -> BooleanArray
where
T: NativeType + Simd8,
T::Simd: Simd8PartialEq,
{
compare_op_scalar(lhs, rhs, |a, b| a.neq(b))
}
Expand All @@ -119,6 +123,7 @@ where
pub fn lt<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> BooleanArray
where
T: NativeType + Simd8,
T::Simd: Simd8PartialOrd,
{
compare_op(lhs, rhs, |a, b| a.lt(b))
}
Expand All @@ -127,6 +132,7 @@ where
pub fn lt_scalar<T>(lhs: &PrimitiveArray<T>, rhs: T) -> BooleanArray
where
T: NativeType + Simd8,
T::Simd: Simd8PartialOrd,
{
compare_op_scalar(lhs, rhs, |a, b| a.lt(b))
}
Expand All @@ -135,6 +141,7 @@ where
pub fn lt_eq<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> BooleanArray
where
T: NativeType + Simd8,
T::Simd: Simd8PartialOrd,
{
compare_op(lhs, rhs, |a, b| a.lt_eq(b))
}
Expand All @@ -144,6 +151,7 @@ where
pub fn lt_eq_scalar<T>(lhs: &PrimitiveArray<T>, rhs: T) -> BooleanArray
where
T: NativeType + Simd8,
T::Simd: Simd8PartialOrd,
{
compare_op_scalar(lhs, rhs, |a, b| a.lt_eq(b))
}
Expand All @@ -153,6 +161,7 @@ where
pub fn gt<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> BooleanArray
where
T: NativeType + Simd8,
T::Simd: Simd8PartialOrd,
{
compare_op(lhs, rhs, |a, b| a.gt(b))
}
Expand All @@ -162,6 +171,7 @@ where
pub fn gt_scalar<T>(lhs: &PrimitiveArray<T>, rhs: T) -> BooleanArray
where
T: NativeType + Simd8,
T::Simd: Simd8PartialOrd,
{
compare_op_scalar(lhs, rhs, |a, b| a.gt(b))
}
Expand All @@ -171,6 +181,7 @@ where
pub fn gt_eq<T>(lhs: &PrimitiveArray<T>, rhs: &PrimitiveArray<T>) -> BooleanArray
where
T: NativeType + Simd8,
T::Simd: Simd8PartialOrd,
{
compare_op(lhs, rhs, |a, b| a.gt_eq(b))
}
Expand All @@ -180,6 +191,7 @@ where
pub fn gt_eq_scalar<T>(lhs: &PrimitiveArray<T>, rhs: T) -> BooleanArray
where
T: NativeType + Simd8,
T::Simd: Simd8PartialOrd,
{
compare_op_scalar(lhs, rhs, |a, b| a.gt_eq(b))
}
Expand Down
Loading

0 comments on commit 4c29966

Please sign in to comment.