Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
Add support for binary comparison
Browse files Browse the repository at this point in the history
  • Loading branch information
zhyass committed Aug 26, 2021
1 parent 81ec49c commit 1861bdf
Show file tree
Hide file tree
Showing 3 changed files with 321 additions and 2 deletions.
295 changes: 295 additions & 0 deletions src/compute/comparison/binary.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,295 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::error::{ArrowError, Result};
use crate::scalar::{BinaryScalar, Scalar};
use crate::{array::*, bitmap::Bitmap};

use super::{super::utils::combine_validities, Operator};

/// Applies the binary predicate `op` to every pair of values taken from
/// two [`BinaryArray`]s and collects the outcomes into a [`BooleanArray`].
///
/// The validity of the result is obtained by passing both inputs'
/// validities to `combine_validities`.
///
/// # Errors
/// Returns [`ArrowError::InvalidArgumentError`] when the two arrays do not
/// have the same length.
fn compare_op<O, F>(lhs: &BinaryArray<O>, rhs: &BinaryArray<O>, op: F) -> Result<BooleanArray>
where
    O: Offset,
    F: Fn(&[u8], &[u8]) -> bool,
{
    if lhs.len() == rhs.len() {
        let validity = combine_validities(lhs.validity(), rhs.validity());

        // Walk both value iterators in lock-step; `op` sees the raw byte slices.
        let pairs = lhs.values_iter().zip(rhs.values_iter());
        let bits = Bitmap::from_trusted_len_iter(pairs.map(|(left, right)| op(left, right)));

        Ok(BooleanArray::from_data(bits, validity))
    } else {
        Err(ArrowError::InvalidArgumentError(
            "Cannot perform comparison operation on arrays of different length".to_string(),
        ))
    }
}

/// Applies the binary predicate `op` between every value of a
/// [`BinaryArray`] and one fixed byte slice, collecting the outcomes into a
/// [`BooleanArray`].
///
/// The result reuses a clone of `lhs`'s validity.
fn compare_op_scalar<O, F>(lhs: &BinaryArray<O>, rhs: &[u8], op: F) -> BooleanArray
where
    O: Offset,
    F: Fn(&[u8], &[u8]) -> bool,
{
    let validity = lhs.validity().clone();

    // `rhs` stays fixed; only the left-hand value changes per slot.
    let outcomes = lhs.values_iter().map(|value| op(value, rhs));

    BooleanArray::from_data(Bitmap::from_trusted_len_iter(outcomes), validity)
}

/// Element-wise `lhs == rhs` between two [`BinaryArray`]s.
///
/// Fails when the arrays differ in length (see `compare_op`).
fn eq<O>(lhs: &BinaryArray<O>, rhs: &BinaryArray<O>) -> Result<BooleanArray>
where
    O: Offset,
{
    compare_op(lhs, rhs, |left, right| left == right)
}

/// Element-wise `value == rhs` between each value of a [`BinaryArray`] and
/// a single byte slice.
fn eq_scalar<O>(lhs: &BinaryArray<O>, rhs: &[u8]) -> BooleanArray
where
    O: Offset,
{
    compare_op_scalar(lhs, rhs, |value, scalar| value == scalar)
}

/// Element-wise `lhs != rhs` between two [`BinaryArray`]s.
///
/// Fails when the arrays differ in length (see `compare_op`).
fn neq<O>(lhs: &BinaryArray<O>, rhs: &BinaryArray<O>) -> Result<BooleanArray>
where
    O: Offset,
{
    compare_op(lhs, rhs, |left, right| left != right)
}

/// Element-wise `value != rhs` between each value of a [`BinaryArray`] and
/// a single byte slice.
fn neq_scalar<O>(lhs: &BinaryArray<O>, rhs: &[u8]) -> BooleanArray
where
    O: Offset,
{
    compare_op_scalar(lhs, rhs, |value, scalar| value != scalar)
}

/// Element-wise `lhs < rhs` between two [`BinaryArray`]s, using the
/// ordering of `[u8]` slices.
///
/// Fails when the arrays differ in length (see `compare_op`).
fn lt<O>(lhs: &BinaryArray<O>, rhs: &BinaryArray<O>) -> Result<BooleanArray>
where
    O: Offset,
{
    compare_op(lhs, rhs, |left, right| left < right)
}

/// Element-wise `value < rhs` between each value of a [`BinaryArray`] and
/// a single byte slice, using the ordering of `[u8]` slices.
fn lt_scalar<O>(lhs: &BinaryArray<O>, rhs: &[u8]) -> BooleanArray
where
    O: Offset,
{
    compare_op_scalar(lhs, rhs, |value, scalar| value < scalar)
}

/// Element-wise `lhs <= rhs` between two [`BinaryArray`]s, using the
/// ordering of `[u8]` slices.
///
/// Fails when the arrays differ in length (see `compare_op`).
fn lt_eq<O>(lhs: &BinaryArray<O>, rhs: &BinaryArray<O>) -> Result<BooleanArray>
where
    O: Offset,
{
    compare_op(lhs, rhs, |left, right| left <= right)
}

/// Element-wise `value <= rhs` between each value of a [`BinaryArray`] and
/// a single byte slice, using the ordering of `[u8]` slices.
fn lt_eq_scalar<O>(lhs: &BinaryArray<O>, rhs: &[u8]) -> BooleanArray
where
    O: Offset,
{
    compare_op_scalar(lhs, rhs, |value, scalar| value <= scalar)
}

/// Element-wise `lhs > rhs` between two [`BinaryArray`]s, using the
/// ordering of `[u8]` slices.
///
/// Fails when the arrays differ in length (see `compare_op`).
fn gt<O>(lhs: &BinaryArray<O>, rhs: &BinaryArray<O>) -> Result<BooleanArray>
where
    O: Offset,
{
    compare_op(lhs, rhs, |left, right| left > right)
}

/// Element-wise `value > rhs` between each value of a [`BinaryArray`] and
/// a single byte slice, using the ordering of `[u8]` slices.
fn gt_scalar<O>(lhs: &BinaryArray<O>, rhs: &[u8]) -> BooleanArray
where
    O: Offset,
{
    compare_op_scalar(lhs, rhs, |value, scalar| value > scalar)
}

/// Element-wise `lhs >= rhs` between two [`BinaryArray`]s, using the
/// ordering of `[u8]` slices.
///
/// Fails when the arrays differ in length (see `compare_op`).
fn gt_eq<O>(lhs: &BinaryArray<O>, rhs: &BinaryArray<O>) -> Result<BooleanArray>
where
    O: Offset,
{
    compare_op(lhs, rhs, |left, right| left >= right)
}

/// Element-wise `value >= rhs` between each value of a [`BinaryArray`] and
/// a single byte slice, using the ordering of `[u8]` slices.
fn gt_eq_scalar<O>(lhs: &BinaryArray<O>, rhs: &[u8]) -> BooleanArray
where
    O: Offset,
{
    compare_op_scalar(lhs, rhs, |value, scalar| value >= scalar)
}

pub fn compare<O: Offset>(
lhs: &BinaryArray<O>,
rhs: &BinaryArray<O>,
op: Operator,
) -> Result<BooleanArray> {
match op {
Operator::Eq => eq(lhs, rhs),
Operator::Neq => neq(lhs, rhs),
Operator::Gt => gt(lhs, rhs),
Operator::GtEq => gt_eq(lhs, rhs),
Operator::Lt => lt(lhs, rhs),
Operator::LtEq => lt_eq(lhs, rhs),
}
}

/// Compares every value of a [`BinaryArray`] against a [`BinaryScalar`]
/// with the comparison selected by `op`.
///
/// When the scalar is not valid, the result is an all-null
/// [`BooleanArray`] of the same length as `lhs`; otherwise the work is
/// delegated to [`compare_scalar_non_null`].
pub fn compare_scalar<O: Offset>(
    lhs: &BinaryArray<O>,
    rhs: &BinaryScalar<O>,
    op: Operator,
) -> BooleanArray {
    if rhs.is_valid() {
        compare_scalar_non_null(lhs, rhs.value(), op)
    } else {
        BooleanArray::new_null(lhs.len())
    }
}

pub fn compare_scalar_non_null<O: Offset>(
lhs: &BinaryArray<O>,
rhs: &[u8],
op: Operator,
) -> BooleanArray {
match op {
Operator::Eq => eq_scalar(lhs, rhs),
Operator::Neq => neq_scalar(lhs, rhs),
Operator::Gt => gt_scalar(lhs, rhs),
Operator::GtEq => gt_eq_scalar(lhs, rhs),
Operator::Lt => lt_scalar(lhs, rhs),
Operator::LtEq => lt_eq_scalar(lhs, rhs),
}
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds two arrays from the given slices, runs the array-vs-array
    /// kernel `op` on them and checks the outcome against `expected`.
    fn test_generic<O: Offset, F: Fn(&BinaryArray<O>, &BinaryArray<O>) -> Result<BooleanArray>>(
        lhs: Vec<&[u8]>,
        rhs: Vec<&[u8]>,
        op: F,
        expected: Vec<bool>,
    ) {
        let lhs = BinaryArray::<O>::from_slice(lhs);
        let rhs = BinaryArray::<O>::from_slice(rhs);
        let expected = BooleanArray::from_slice(expected);
        assert_eq!(op(&lhs, &rhs).unwrap(), expected);
    }

    /// Builds an array from `lhs`, runs the array-vs-scalar kernel `op`
    /// against `rhs` and checks the outcome against `expected`.
    fn test_generic_scalar<O: Offset, F: Fn(&BinaryArray<O>, &[u8]) -> BooleanArray>(
        lhs: Vec<&[u8]>,
        rhs: &[u8],
        op: F,
        expected: Vec<bool>,
    ) {
        let lhs = BinaryArray::<O>::from_slice(lhs);
        let expected = BooleanArray::from_slice(expected);
        assert_eq!(op(&lhs, rhs), expected);
    }

    #[test]
    fn test_gt_eq() {
        test_generic::<i32, _>(
            vec![
                "arrow".as_bytes(),
                "datafusion".as_bytes(),
                "flight".as_bytes(),
                "parquet".as_bytes(),
            ],
            vec![
                "flight".as_bytes(),
                "flight".as_bytes(),
                "flight".as_bytes(),
                "flight".as_bytes(),
            ],
            gt_eq,
            vec![false, false, true, true],
        )
    }

    #[test]
    fn test_gt_eq_scalar() {
        test_generic_scalar::<i32, _>(
            vec![
                "arrow".as_bytes(),
                "datafusion".as_bytes(),
                "flight".as_bytes(),
                "parquet".as_bytes(),
            ],
            "flight".as_bytes(),
            gt_eq_scalar,
            vec![false, false, true, true],
        )
    }

    #[test]
    fn test_gt() {
        test_generic::<i32, _>(
            vec![
                "arrow".as_bytes(),
                "datafusion".as_bytes(),
                "flight".as_bytes(),
                "parquet".as_bytes(),
            ],
            vec![
                "flight".as_bytes(),
                "flight".as_bytes(),
                "flight".as_bytes(),
                "flight".as_bytes(),
            ],
            gt,
            vec![false, false, false, true],
        )
    }

    #[test]
    fn test_gt_scalar() {
        test_generic_scalar::<i32, _>(
            vec![
                "arrow".as_bytes(),
                "datafusion".as_bytes(),
                "flight".as_bytes(),
                "parquet".as_bytes(),
            ],
            "flight".as_bytes(),
            gt_scalar,
            vec![false, false, false, true],
        )
    }

    #[test]
    fn test_lt_eq() {
        test_generic::<i32, _>(
            vec![
                "arrow".as_bytes(),
                "datafusion".as_bytes(),
                "flight".as_bytes(),
                "parquet".as_bytes(),
            ],
            vec![
                "flight".as_bytes(),
                "flight".as_bytes(),
                "flight".as_bytes(),
                "flight".as_bytes(),
            ],
            lt_eq,
            vec![true, true, true, false],
        )
    }

    #[test]
    fn test_lt_eq_scalar() {
        test_generic_scalar::<i32, _>(
            vec![
                "arrow".as_bytes(),
                "datafusion".as_bytes(),
                "flight".as_bytes(),
                "parquet".as_bytes(),
            ],
            "flight".as_bytes(),
            lt_eq_scalar,
            vec![true, true, true, false],
        )
    }

    #[test]
    fn test_lt() {
        test_generic::<i32, _>(
            vec![
                "arrow".as_bytes(),
                "datafusion".as_bytes(),
                "flight".as_bytes(),
                "parquet".as_bytes(),
            ],
            vec![
                "flight".as_bytes(),
                "flight".as_bytes(),
                "flight".as_bytes(),
                "flight".as_bytes(),
            ],
            lt,
            vec![true, true, false, false],
        )
    }

    #[test]
    fn test_lt_scalar() {
        test_generic_scalar::<i32, _>(
            vec![
                "arrow".as_bytes(),
                "datafusion".as_bytes(),
                "flight".as_bytes(),
                "parquet".as_bytes(),
            ],
            "flight".as_bytes(),
            lt_scalar,
            vec![true, true, false, false],
        )
    }

    #[test]
    fn test_eq() {
        test_generic::<i32, _>(
            vec![
                "arrow".as_bytes(),
                "arrow".as_bytes(),
                "arrow".as_bytes(),
                "arrow".as_bytes(),
            ],
            vec![
                "arrow".as_bytes(),
                "parquet".as_bytes(),
                "datafusion".as_bytes(),
                "flight".as_bytes(),
            ],
            eq,
            vec![true, false, false, false],
        )
    }

    #[test]
    fn test_eq_scalar() {
        test_generic_scalar::<i32, _>(
            vec![
                "arrow".as_bytes(),
                "parquet".as_bytes(),
                "datafusion".as_bytes(),
                "flight".as_bytes(),
            ],
            "arrow".as_bytes(),
            eq_scalar,
            vec![true, false, false, false],
        )
    }

    #[test]
    fn test_neq() {
        test_generic::<i32, _>(
            vec![
                "arrow".as_bytes(),
                "arrow".as_bytes(),
                "arrow".as_bytes(),
                "arrow".as_bytes(),
            ],
            vec![
                "arrow".as_bytes(),
                "parquet".as_bytes(),
                "datafusion".as_bytes(),
                "flight".as_bytes(),
            ],
            neq,
            vec![false, true, true, true],
        )
    }

    #[test]
    fn test_neq_scalar() {
        test_generic_scalar::<i32, _>(
            vec![
                "arrow".as_bytes(),
                "parquet".as_bytes(),
                "datafusion".as_bytes(),
                "flight".as_bytes(),
            ],
            "arrow".as_bytes(),
            neq_scalar,
            vec![false, true, true, true],
        )
    }

    /// Arrays of different lengths must be rejected rather than silently
    /// truncated to the shorter one.
    #[test]
    fn test_compare_length_mismatch() {
        let lhs = BinaryArray::<i32>::from_slice(vec!["arrow".as_bytes(), "flight".as_bytes()]);
        let rhs = BinaryArray::<i32>::from_slice(vec!["arrow".as_bytes()]);
        assert!(compare(&lhs, &rhs, Operator::Eq).is_err());
    }
}
26 changes: 25 additions & 1 deletion src/compute/comparison/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use crate::datatypes::{DataType, IntervalUnit};
use crate::error::{ArrowError, Result};
use crate::scalar::Scalar;

mod binary;
mod boolean;
mod primitive;
mod utf8;
Expand Down Expand Up @@ -131,6 +132,16 @@ pub fn compare(lhs: &dyn Array, rhs: &dyn Array, operator: Operator) -> Result<B
let rhs = rhs.as_any().downcast_ref().unwrap();
primitive::compare::<i128>(lhs, rhs, operator)
}
DataType::Binary => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref().unwrap();
binary::compare::<i32>(lhs, rhs, operator)
}
DataType::LargeBinary => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref().unwrap();
binary::compare::<i64>(lhs, rhs, operator)
}
_ => Err(ArrowError::NotYetImplemented(format!(
"Comparison between {:?} is not supported",
data_type
Expand Down Expand Up @@ -233,6 +244,16 @@ pub fn compare_scalar(
let rhs = rhs.as_any().downcast_ref().unwrap();
utf8::compare_scalar::<i64>(lhs, rhs, operator)
}
DataType::Binary => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref().unwrap();
binary::compare_scalar::<i32>(lhs, rhs, operator)
}
DataType::LargeBinary => {
let lhs = lhs.as_any().downcast_ref().unwrap();
let rhs = rhs.as_any().downcast_ref().unwrap();
binary::compare_scalar::<i64>(lhs, rhs, operator)
}
_ => {
return Err(ArrowError::NotYetImplemented(format!(
"Comparison between {:?} is not supported",
Expand All @@ -242,6 +263,7 @@ pub fn compare_scalar(
})
}

pub use binary::compare_scalar_non_null as binary_compare_scalar;
pub use boolean::compare_scalar_non_null as boolean_compare_scalar;
pub use primitive::compare_scalar_non_null as primitive_compare_scalar;
pub(crate) use primitive::compare_values_op as primitive_compare_values_op;
Expand All @@ -259,7 +281,7 @@ pub use utf8::compare_scalar_non_null as utf8_compare_scalar;
/// assert_eq!(can_compare(&data_type), true);
///
/// let data_type = DataType::LargeBinary;
/// assert_eq!(can_compare(&data_type), false)
/// assert_eq!(can_compare(&data_type), true)
/// ```
pub fn can_compare(data_type: &DataType) -> bool {
matches!(
Expand All @@ -285,6 +307,8 @@ pub fn can_compare(data_type: &DataType) -> bool {
| DataType::Utf8
| DataType::LargeUtf8
| DataType::Decimal(_, _)
| DataType::Binary
| DataType::LargeBinary
)
}

Expand Down
2 changes: 1 addition & 1 deletion src/io/parquet/read/binary/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@ mod dictionary;
mod nested;

pub use basic::iter_to_array;
pub use dictionary::iter_to_array as iter_to_dict_array;
pub use basic::stream_to_array;
pub use dictionary::iter_to_array as iter_to_dict_array;
pub use nested::iter_to_array as iter_to_array_nested;

0 comments on commit 1861bdf

Please sign in to comment.