Skip to content

Commit

Permalink
clean up func
Browse files Browse the repository at this point in the history
Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
  • Loading branch information
jayzhan211 committed Feb 20, 2024
1 parent 76d960f commit 28c7fc6
Showing 1 changed file with 60 additions and 87 deletions.
147 changes: 60 additions & 87 deletions arrow-ord/src/cmp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,10 @@ use arrow_buffer::bit_util::ceil;
use arrow_buffer::{BooleanBuffer, MutableBuffer, NullBuffer};
use arrow_schema::{ArrowError, DataType};
use arrow_select::take::take;
use std::cmp::Ordering;
use std::ops::Not;

use crate::ord::{build_compare, DynComparator};
use crate::ord::build_compare;

#[derive(Debug, Copy, Clone)]
enum Op {
Expand Down Expand Up @@ -168,6 +169,44 @@ pub fn not_distinct(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<BooleanArray, Ar
compare_op(Op::NotDistinct, lhs, rhs)
}

/// Lexicographically compares two arrays element by element.
///
/// The elements at each index below the shorter length are compared in
/// order and the first non-equal ordering is returned. When that shared
/// prefix is entirely equal, the lengths decide: the shorter array orders
/// before the longer one, and equal lengths yield `Ordering::Equal`.
///
/// When both inputs have type `DataType::List`, each pair of child arrays
/// is compared by recursing into this function. Every other pairing of
/// types is delegated to the comparator returned by `build_compare`.
///
/// # Errors
///
/// Propagates any error returned by `build_compare` or by a recursive call.
///
/// NOTE(review): the list branch reads `value(idx)` without consulting the
/// validity of the entry — confirm how null list entries should order.
fn compare_list(l: &dyn Array, r: &dyn Array) -> Result<Ordering, ArrowError> {
    let left_len = l.len();
    let right_len = r.len();
    // Only indices present on both sides can be compared pairwise.
    let shared_len = left_len.min(right_len);

    match (l.data_type(), r.data_type()) {
        (DataType::List(_), DataType::List(_)) => {
            let left_list = l.as_list::<i32>();
            let right_list = r.as_list::<i32>();
            for idx in 0..shared_len {
                let left_child = left_list.value(idx);
                let right_child = right_list.value(idx);
                match compare_list(left_child.as_ref(), right_child.as_ref())? {
                    Ordering::Equal => {}
                    decided => return Ok(decided),
                }
            }
        }
        _ => {
            let comparator = build_compare(l, r)?;
            // Lazily evaluated, so comparison stops at the first difference.
            let first_difference = (0..shared_len)
                .map(|idx| comparator(idx, idx))
                .find(|ord| *ord != Ordering::Equal);
            if let Some(decided) = first_difference {
                return Ok(decided);
            }
        }
    }

    // Equal shared prefix: fall back to comparing the lengths.
    Ok(left_len.cmp(&right_len))
}

fn process_nested(
l: &dyn Array,
r: &dyn Array,
Expand All @@ -178,97 +217,31 @@ fn process_nested(
) -> Result<Option<BooleanArray>, ArrowError> {
use arrow_schema::DataType::*;
if let (List(_), List(_)) = (l_t, r_t) {
// Process nested data types
/// Compares `l` against `r` with `compare_list` and reports whether the
/// resulting ordering equals `target_ord`.
///
/// `len` is forwarded to `BooleanArray::builder`.
///
/// # Errors
///
/// Propagates any error returned by `compare_list`.
///
/// NOTE(review): exactly one boolean is appended regardless of how many
/// rows the inputs hold, so the returned array always has length 1 —
/// confirm whether a per-row comparison was intended here.
fn process(
    l: &dyn Array,
    r: &dyn Array,
    len: usize,
    target_ord: Ordering,
) -> Result<BooleanArray, ArrowError> {
    let matches_target = compare_list(l, r)? == target_ord;

    let mut builder = BooleanArray::builder(len);
    builder.append_value(matches_target);
    Ok(builder.finish())
}

match op {
Op::Less => {
let l = l.as_list::<i32>();
let r = r.as_list::<i32>();
let mut values = BooleanArray::builder(len);
for i in 0..l.len() {
let l = l.value(i);
let r = r.value(i);
let l_t = l.data_type();
let r_t = r.data_type();
let l_len = l.len();
let r_len = r.len();
let min_len = std::cmp::min(l_len, r_len);

if !l_t.is_nested() && !r_t.is_nested() {
let cmp = build_compare(&l, &r)?;

fn post_process(len: usize, cmp: DynComparator, r_is_longer: bool) -> bool {
for j in 0..len {
let ord = cmp(j, j);
if ord == std::cmp::Ordering::Less {
return true;
}
if ord == std::cmp::Ordering::Greater {
return false;
}
}
r_is_longer
}
values.append_value(post_process(min_len, cmp, r_len > l_len));
} else {
// Since `compare_op` does not support inconsistent lengths, we compare the
// prefix with `compare_op` only, and compare the left if the prefix is equal
let l = l.slice(0, min_len);
let r = r.slice(0, min_len);

let lt_res = lt(&l, &r)?;
let eq_res = eq(&l, &r)?;

fn post_process(
lt: &BooleanArray,
eq: &BooleanArray,
r_is_longer: bool,
) -> bool {
for j in 0..lt.len() {
if lt.value(j) {
return true;
}
if !eq.value(j) {
return false;
}
}
r_is_longer
}

values.append_value(post_process(&lt_res, &eq_res, r_len > l_len));
}
}

let values = values.finish();
let values = process(l, r, len, Ordering::Less)?;
Ok(Some(values))
}
Op::Equal => {
let l = l.as_list::<i32>();
let r = r.as_list::<i32>();
let mut values = BooleanArray::builder(len);
for i in 0..l.len() {
let l = l.value(i);
let r = r.value(i);
let l_len = l.len();
let r_len = r.len();
if l_len != r_len {
values.append_value(false);
continue;
}

let eq_res = eq(&l, &r)?;
fn post_process(eq: &BooleanArray) -> bool {
for j in 0..eq.len() {
if !eq.value(j) {
return false;
}
}
true
}

values.append_value(post_process(&eq_res));
}

let values = values.finish();
let values = process(l, r, len, Ordering::Equal)?;
Ok(Some(values))
}
_ => Err(ArrowError::NotYetImplemented(format!(
Expand Down

0 comments on commit 28c7fc6

Please sign in to comment.