From 1861bdfae5bced97d19d33d3840a85015896fcdc Mon Sep 17 00:00:00 2001 From: zhyass <34016424+zhyass@users.noreply.github.com> Date: Thu, 26 Aug 2021 22:16:03 +0800 Subject: [PATCH] Add support for binary comparison --- src/compute/comparison/binary.rs | 295 ++++++++++++++++++++++++++++++ src/compute/comparison/mod.rs | 26 ++- src/io/parquet/read/binary/mod.rs | 2 +- 3 files changed, 321 insertions(+), 2 deletions(-) create mode 100644 src/compute/comparison/binary.rs diff --git a/src/compute/comparison/binary.rs b/src/compute/comparison/binary.rs new file mode 100644 index 00000000000..e039434f8f3 --- /dev/null +++ b/src/compute/comparison/binary.rs @@ -0,0 +1,295 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::error::{ArrowError, Result}; +use crate::scalar::{BinaryScalar, Scalar}; +use crate::{array::*, bitmap::Bitmap}; + +use super::{super::utils::combine_validities, Operator}; + +/// Evaluate `op(lhs, rhs)` for [`BinaryArray`]s using a specified +/// comparison function. +fn compare_op(lhs: &BinaryArray, rhs: &BinaryArray, op: F) -> Result +where + O: Offset, + F: Fn(&[u8], &[u8]) -> bool, +{ + if lhs.len() != rhs.len() { + return Err(ArrowError::InvalidArgumentError( + "Cannot perform comparison operation on arrays of different length".to_string(), + )); + } + + let validity = combine_validities(lhs.validity(), rhs.validity()); + + let values = lhs + .values_iter() + .zip(rhs.values_iter()) + .map(|(lhs, rhs)| op(lhs, rhs)); + let values = Bitmap::from_trusted_len_iter(values); + + Ok(BooleanArray::from_data(values, validity)) +} + +/// Evaluate `op(lhs, rhs)` for [`BinaryArray`] and scalar using +/// a specified comparison function. +fn compare_op_scalar(lhs: &BinaryArray, rhs: &[u8], op: F) -> BooleanArray +where + O: Offset, + F: Fn(&[u8], &[u8]) -> bool, +{ + let validity = lhs.validity().clone(); + + let values = lhs.values_iter().map(|lhs| op(lhs, rhs)); + let values = Bitmap::from_trusted_len_iter(values); + + BooleanArray::from_data(values, validity) +} + +/// Perform `lhs == rhs` operation on [`BinaryArray`]. +fn eq(lhs: &BinaryArray, rhs: &BinaryArray) -> Result { + compare_op(lhs, rhs, |a, b| a == b) +} + +/// Perform `lhs == rhs` operation on [`BinaryArray`] and a scalar. +fn eq_scalar(lhs: &BinaryArray, rhs: &[u8]) -> BooleanArray { + compare_op_scalar(lhs, rhs, |a, b| a == b) +} + +/// Perform `lhs != rhs` operation on [`BinaryArray`]. +fn neq(lhs: &BinaryArray, rhs: &BinaryArray) -> Result { + compare_op(lhs, rhs, |a, b| a != b) +} + +/// Perform `lhs != rhs` operation on [`BinaryArray`] and a scalar. +fn neq_scalar(lhs: &BinaryArray, rhs: &[u8]) -> BooleanArray { + compare_op_scalar(lhs, rhs, |a, b| a != b) +} + +/// Perform `lhs < rhs` operation on [`BinaryArray`]. +fn lt(lhs: &BinaryArray, rhs: &BinaryArray) -> Result { + compare_op(lhs, rhs, |a, b| a < b) +} + +/// Perform `lhs < rhs` operation on [`BinaryArray`] and a scalar. +fn lt_scalar(lhs: &BinaryArray, rhs: &[u8]) -> BooleanArray { + compare_op_scalar(lhs, rhs, |a, b| a < b) +} + +/// Perform `lhs <= rhs` operation on [`BinaryArray`]. +fn lt_eq(lhs: &BinaryArray, rhs: &BinaryArray) -> Result { + compare_op(lhs, rhs, |a, b| a <= b) +} + +/// Perform `lhs <= rhs` operation on [`BinaryArray`] and a scalar. +fn lt_eq_scalar(lhs: &BinaryArray, rhs: &[u8]) -> BooleanArray { + compare_op_scalar(lhs, rhs, |a, b| a <= b) +} + +/// Perform `lhs > rhs` operation on [`BinaryArray`]. +fn gt(lhs: &BinaryArray, rhs: &BinaryArray) -> Result { + compare_op(lhs, rhs, |a, b| a > b) +} + +/// Perform `lhs > rhs` operation on [`BinaryArray`] and a scalar. +fn gt_scalar(lhs: &BinaryArray, rhs: &[u8]) -> BooleanArray { + compare_op_scalar(lhs, rhs, |a, b| a > b) +} + +/// Perform `lhs >= rhs` operation on [`BinaryArray`]. +fn gt_eq(lhs: &BinaryArray, rhs: &BinaryArray) -> Result { + compare_op(lhs, rhs, |a, b| a >= b) +} + +/// Perform `lhs >= rhs` operation on [`BinaryArray`] and a scalar. +fn gt_eq_scalar(lhs: &BinaryArray, rhs: &[u8]) -> BooleanArray { + compare_op_scalar(lhs, rhs, |a, b| a >= b) +} + +pub fn compare( + lhs: &BinaryArray, + rhs: &BinaryArray, + op: Operator, +) -> Result { + match op { + Operator::Eq => eq(lhs, rhs), + Operator::Neq => neq(lhs, rhs), + Operator::Gt => gt(lhs, rhs), + Operator::GtEq => gt_eq(lhs, rhs), + Operator::Lt => lt(lhs, rhs), + Operator::LtEq => lt_eq(lhs, rhs), + } +} + +pub fn compare_scalar( + lhs: &BinaryArray, + rhs: &BinaryScalar, + op: Operator, +) -> BooleanArray { + if !rhs.is_valid() { + return BooleanArray::new_null(lhs.len()); + } + compare_scalar_non_null(lhs, rhs.value(), op) +} + +pub fn compare_scalar_non_null( + lhs: &BinaryArray, + rhs: &[u8], + op: Operator, +) -> BooleanArray { + match op { + Operator::Eq => eq_scalar(lhs, rhs), + Operator::Neq => neq_scalar(lhs, rhs), + Operator::Gt => gt_scalar(lhs, rhs), + Operator::GtEq => gt_eq_scalar(lhs, rhs), + Operator::Lt => lt_scalar(lhs, rhs), + Operator::LtEq => lt_eq_scalar(lhs, rhs), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn test_generic, &BinaryArray) -> Result>( + lhs: Vec<&[u8]>, + rhs: Vec<&[u8]>, + op: F, + expected: Vec, + ) { + let lhs = BinaryArray::::from_slice(lhs); + let rhs = BinaryArray::::from_slice(rhs); + let expected = BooleanArray::from_slice(expected); + assert_eq!(op(&lhs, &rhs).unwrap(), expected); + } + + fn test_generic_scalar, &[u8]) -> BooleanArray>( + lhs: Vec<&[u8]>, + rhs: &[u8], + op: F, + expected: Vec, + ) { + let lhs = BinaryArray::::from_slice(lhs); + let expected = BooleanArray::from_slice(expected); + assert_eq!(op(&lhs, rhs), expected); + } + + #[test] + fn test_gt_eq() { + test_generic::( + vec![ + "arrow".as_bytes(), + "datafusion".as_bytes(), + "flight".as_bytes(), + "parquet".as_bytes(), + ], + vec![ + "flight".as_bytes(), + "flight".as_bytes(), + "flight".as_bytes(), + "flight".as_bytes(), + ], + gt_eq, + vec![false, false, true, true], + ) + } + + #[test] + fn test_gt_eq_scalar() { + test_generic_scalar::( + vec![ + "arrow".as_bytes(), + "datafusion".as_bytes(), + "flight".as_bytes(), + "parquet".as_bytes(), + ], + "flight".as_bytes(), + gt_eq_scalar, + vec![false, false, true, true], + ) + } + + #[test] + fn test_eq() { + test_generic::( + vec![ + "arrow".as_bytes(), + "arrow".as_bytes(), + "arrow".as_bytes(), + "arrow".as_bytes(), + ], + vec![ + "arrow".as_bytes(), + "parquet".as_bytes(), + "datafusion".as_bytes(), + "flight".as_bytes(), + ], + eq, + vec![true, false, false, false], + ) + } + + #[test] + fn test_eq_scalar() { + test_generic_scalar::( + vec![ + "arrow".as_bytes(), + "parquet".as_bytes(), + "datafusion".as_bytes(), + "flight".as_bytes(), + ], + "arrow".as_bytes(), + eq_scalar, + vec![true, false, false, false], + ) + } + + #[test] + fn test_neq() { + test_generic::( + vec![ + "arrow".as_bytes(), + "arrow".as_bytes(), + "arrow".as_bytes(), + "arrow".as_bytes(), + ], + vec![ + "arrow".as_bytes(), + "parquet".as_bytes(), + "datafusion".as_bytes(), + "flight".as_bytes(), + ], + neq, + vec![false, true, true, true], + ) + } + + #[test] + fn test_neq_scalar() { + test_generic_scalar::( + vec![ + "arrow".as_bytes(), + "parquet".as_bytes(), + "datafusion".as_bytes(), + "flight".as_bytes(), + ], + "arrow".as_bytes(), + neq_scalar, + vec![false, true, true, true], + ) + } +} diff --git a/src/compute/comparison/mod.rs b/src/compute/comparison/mod.rs index ceaf548ca68..5c3514145c4 100644 --- a/src/compute/comparison/mod.rs +++ b/src/compute/comparison/mod.rs @@ -22,6 +22,7 @@ use crate::datatypes::{DataType, IntervalUnit}; use crate::error::{ArrowError, Result}; use crate::scalar::Scalar; +mod binary; mod boolean; mod primitive; mod utf8; @@ -131,6 +132,16 @@ pub fn compare(lhs: &dyn Array, rhs: &dyn Array, operator: Operator) -> Result(lhs, rhs, operator) } + DataType::Binary => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + binary::compare::(lhs, rhs, operator) + } + DataType::LargeBinary => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + binary::compare::(lhs, rhs, operator) + } _ => Err(ArrowError::NotYetImplemented(format!( "Comparison between {:?} is not supported", data_type @@ -233,6 +244,16 @@ pub fn compare_scalar( let rhs = rhs.as_any().downcast_ref().unwrap(); utf8::compare_scalar::(lhs, rhs, operator) } + DataType::Binary => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + binary::compare_scalar::(lhs, rhs, operator) + } + DataType::LargeBinary => { + let lhs = lhs.as_any().downcast_ref().unwrap(); + let rhs = rhs.as_any().downcast_ref().unwrap(); + binary::compare_scalar::(lhs, rhs, operator) + } _ => { return Err(ArrowError::NotYetImplemented(format!( "Comparison between {:?} is not supported", @@ -242,6 +263,7 @@ pub fn compare_scalar( }) } +pub use binary::compare_scalar_non_null as binary_compare_scalar; pub use boolean::compare_scalar_non_null as boolean_compare_scalar; pub use primitive::compare_scalar_non_null as primitive_compare_scalar; pub(crate) use primitive::compare_values_op as primitive_compare_values_op; @@ -259,7 +281,7 @@ pub use utf8::compare_scalar_non_null as utf8_compare_scalar; /// assert_eq!(can_compare(&data_type), true); /// /// let data_type = DataType::LargeBinary; -/// assert_eq!(can_compare(&data_type), false) +/// assert_eq!(can_compare(&data_type), true) /// ``` pub fn can_compare(data_type: &DataType) -> bool { matches!( @@ -285,6 +307,8 @@ pub fn can_compare(data_type: &DataType) -> bool { | DataType::Utf8 | DataType::LargeUtf8 | DataType::Decimal(_, _) + | DataType::Binary + | DataType::LargeBinary ) } diff --git a/src/io/parquet/read/binary/mod.rs b/src/io/parquet/read/binary/mod.rs index 54da7b3f1db..1bfd2b04235 100644 --- a/src/io/parquet/read/binary/mod.rs +++ b/src/io/parquet/read/binary/mod.rs @@ -3,6 +3,6 @@ mod dictionary; mod nested; pub use basic::iter_to_array; -pub use dictionary::iter_to_array as iter_to_dict_array; pub use basic::stream_to_array; +pub use dictionary::iter_to_array as iter_to_dict_array; pub use nested::iter_to_array as iter_to_array_nested;