Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix is_distinct from for float NaN values #5446

Merged
merged 2 commits into from
Mar 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,24 @@ e 4 97 -13181 2047637360 6176835796788944083 158 53000 2042457019 97260165026400
e 5 -86 32514 -467659022 -8012578250188146150 254 2684 2861911482 2126626171973341689 0.12559289 0.014793053078 gxfHWUF8XgY2KdFxigxvNEXe2V2XMl
e 5 64 -26526 1689098844 8950618259486183091 224 45253 662099130 16127995415060805595 0.2897315 0.575945048386 56MZa5O1hVtX4c5sbnCfxuX5kDChqI

# distinct_from logic for floats
query BBBBBBBBBBB
select
'nan'::float is distinct from 'nan'::float v7,
'nan'::float is not distinct from 'nan'::float v8,
'nan'::float is not distinct from null v9,
'nan'::float is distinct from null v10,
null is distinct from 'nan'::float v11,
null is not distinct from 'nan'::float v12,
1::float is distinct from 2::float v13,
'nan'::float is distinct from 1::float v14,
'nan'::float is not distinct from 1::float v15,
1::float is not distinct from null v16,
1::float is distinct from null v17
;
----
false true false true true false true true false false true
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❤️


########
# Clean up after the test
########
Expand Down
26 changes: 26 additions & 0 deletions datafusion/core/tests/sqllogictests/test_files/select.slt
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,29 @@ select '1' from foo order by column1;
# foo distinct order by
statement error DataFusion error: Error during planning: For SELECT DISTINCT, ORDER BY expressions column1 must appear in select list
select distinct '1' from foo order by column1;

# IS [NOT] DISTINCT FROM semantics for float and double NaN values
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've duplicated this test in both select.slt and pg_compat_simple.slt. The reason is that PG doesn't support the `double` datatype, so the `double` cases can only live in select.slt.

query BBBBBBBBBBBBBBBBB
select
'nan'::double is distinct from 'nan'::double v1,
'nan'::double is not distinct from 'nan'::double v2,
'nan'::double is not distinct from null v3,
'nan'::double is distinct from null v4,
null is distinct from 'nan'::double v5,
null is not distinct from 'nan'::double v6,
'nan'::float is distinct from 'nan'::float v7,
'nan'::float is not distinct from 'nan'::float v8,
'nan'::float is not distinct from null v9,
'nan'::float is distinct from null v10,
null is distinct from 'nan'::float v11,
null is not distinct from 'nan'::float v12,
1::float is distinct from 2::float v13,
'nan'::float is distinct from 1::float v14,
'nan'::float is not distinct from 1::float v15,
1::float is not distinct from null v16,
1::float is distinct from null v17
;
----
false true false true true false false true false true true false true true false false true


59 changes: 44 additions & 15 deletions datafusion/physical-expr/src/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,12 @@ use kernels::{
use kernels_arrow::{
add_decimal_dyn_scalar, add_dyn_decimal, divide_decimal_dyn_scalar,
divide_dyn_opt_decimal, is_distinct_from, is_distinct_from_bool,
is_distinct_from_decimal, is_distinct_from_null, is_distinct_from_utf8,
is_not_distinct_from, is_not_distinct_from_bool, is_not_distinct_from_decimal,
is_not_distinct_from_null, is_not_distinct_from_utf8, modulus_decimal,
modulus_decimal_scalar, multiply_decimal_dyn_scalar, multiply_dyn_decimal,
subtract_decimal_dyn_scalar, subtract_dyn_decimal,
is_distinct_from_decimal, is_distinct_from_f32, is_distinct_from_f64,
is_distinct_from_null, is_distinct_from_utf8, is_not_distinct_from,
is_not_distinct_from_bool, is_not_distinct_from_decimal, is_not_distinct_from_f32,
is_not_distinct_from_f64, is_not_distinct_from_null, is_not_distinct_from_utf8,
modulus_decimal, modulus_decimal_scalar, multiply_decimal_dyn_scalar,
multiply_dyn_decimal, subtract_decimal_dyn_scalar, subtract_dyn_decimal,
};

use arrow::datatypes::{DataType, Schema, TimeUnit};
Expand Down Expand Up @@ -184,16 +185,44 @@ macro_rules! compute_decimal_op {
}};
}

/// Applies the `f32`-specific variant of a binary kernel to two arrays.
///
/// `$LEFT` and `$RIGHT` are downcast to the concrete array type `$DT`
/// (panicking if either operand is a different array type), and the function
/// named `<$OP>_f32` is then called on the two typed arrays. A kernel error is
/// propagated with `?`, so the macro must expand inside a function returning
/// a `Result`; on success it evaluates to `Ok(Arc::new(..))`.
macro_rules! compute_f32_op {
    ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
        let left_array = $LEFT
            .as_any()
            .downcast_ref::<$DT>()
            .expect("compute_op failed to downcast left side array");
        let right_array = $RIGHT
            .as_any()
            .downcast_ref::<$DT>()
            .expect("compute_op failed to downcast right side array");
        // `paste` glues the `_f32` suffix onto the kernel name at expansion time.
        let kernel_output = paste::expr! {[<$OP _f32>]}(left_array, right_array)?;
        Ok(Arc::new(kernel_output))
    }};
}

/// Applies the `f64`-specific variant of a binary kernel to two arrays.
///
/// `$LEFT` and `$RIGHT` are downcast to the concrete array type `$DT`
/// (panicking if either operand is a different array type), and the function
/// named `<$OP>_f64` is then called on the two typed arrays. A kernel error is
/// propagated with `?`, so the macro must expand inside a function returning
/// a `Result`; on success it evaluates to `Ok(Arc::new(..))`.
macro_rules! compute_f64_op {
    ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
        let left_array = $LEFT
            .as_any()
            .downcast_ref::<$DT>()
            .expect("compute_op failed to downcast left side array");
        let right_array = $RIGHT
            .as_any()
            .downcast_ref::<$DT>()
            .expect("compute_op failed to downcast right side array");
        // `paste` glues the `_f64` suffix onto the kernel name at expansion time.
        let kernel_output = paste::expr! {[<$OP _f64>]}(left_array, right_array)?;
        Ok(Arc::new(kernel_output))
    }};
}

macro_rules! compute_null_op {
($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
let ll = $LEFT
.as_any()
.downcast_ref::<$DT>()
.expect("compute_op failed to downcast array");
.expect("compute_op failed to downcast left side array");
let rr = $RIGHT
.as_any()
.downcast_ref::<$DT>()
.expect("compute_op failed to downcast array");
.expect("compute_op failed to downcast right side array");
Ok(Arc::new(paste::expr! {[<$OP _null>]}(&ll, &rr)?))
}};
}
Expand All @@ -204,11 +233,11 @@ macro_rules! compute_utf8_op {
let ll = $LEFT
.as_any()
.downcast_ref::<$DT>()
.expect("compute_op failed to downcast array");
.expect("compute_op failed to downcast left side array");
let rr = $RIGHT
.as_any()
.downcast_ref::<$DT>()
.expect("compute_op failed to downcast array");
.expect("compute_op failed to downcast right side array");
Ok(Arc::new(paste::expr! {[<$OP _utf8>]}(&ll, &rr)?))
}};
}
Expand All @@ -219,7 +248,7 @@ macro_rules! compute_utf8_op_scalar {
let ll = $LEFT
.as_any()
.downcast_ref::<$DT>()
.expect("compute_op failed to downcast array");
.expect("compute_op failed to downcast left side array");
if let ScalarValue::Utf8(Some(string_value)) = $RIGHT {
Ok(Arc::new(paste::expr! {[<$OP _utf8_scalar>]}(
&ll,
Expand Down Expand Up @@ -318,7 +347,7 @@ macro_rules! compute_op_scalar {
let ll = $LEFT
.as_any()
.downcast_ref::<$DT>()
.expect("compute_op failed to downcast array");
.expect("compute_op failed to downcast left side array");
Ok(Arc::new(paste::expr! {[<$OP _scalar>]}(
&ll,
$RIGHT.try_into()?,
Expand Down Expand Up @@ -392,11 +421,11 @@ macro_rules! compute_op {
let ll = $LEFT
.as_any()
.downcast_ref::<$DT>()
.expect("compute_op failed to downcast array");
.expect("compute_op failed to downcast left side array");
let rr = $RIGHT
.as_any()
.downcast_ref::<$DT>()
.expect("compute_op failed to downcast array");
.expect("compute_op failed to downcast right side array");
Ok(Arc::new($OP(&ll, &rr)?))
}};
// invoke unary operator
Expand Down Expand Up @@ -540,8 +569,8 @@ macro_rules! binary_array_op {
DataType::UInt16 => compute_op!($LEFT, $RIGHT, $OP, UInt16Array),
DataType::UInt32 => compute_op!($LEFT, $RIGHT, $OP, UInt32Array),
DataType::UInt64 => compute_op!($LEFT, $RIGHT, $OP, UInt64Array),
DataType::Float32 => compute_op!($LEFT, $RIGHT, $OP, Float32Array),
DataType::Float64 => compute_op!($LEFT, $RIGHT, $OP, Float64Array),
DataType::Float32 => compute_f32_op!($LEFT, $RIGHT, $OP, Float32Array),
DataType::Float64 => compute_f64_op!($LEFT, $RIGHT, $OP, Float64Array),
DataType::Utf8 => compute_utf8_op!($LEFT, $RIGHT, $OP, StringArray),
DataType::Timestamp(TimeUnit::Nanosecond, _) => {
compute_op!($LEFT, $RIGHT, $OP, TimestampNanosecondArray)
Expand Down
122 changes: 100 additions & 22 deletions datafusion/physical-expr/src/expressions/binary/kernels_arrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ use std::sync::Arc;
// Simple (low performance) kernels until optimized kernels are added to arrow
// See https://github.com/apache/arrow-rs/issues/960

/// Evaluates `IS DISTINCT FROM` for a single pair of floating-point slots.
///
/// Arguments are the two slot values (`$LEFT`, `$RIGHT`) followed by their
/// null flags (`$LEFT_ISNULL`, `$RIGHT_ISNULL`). The macro expands to a `bool`
/// that is `true` when the two slots are distinct:
///
/// * exactly one side is null: distinct;
/// * both sides are null: not distinct. The values are deliberately not
///   inspected here, because the value stored behind a null slot is
///   unspecified and must not influence the result;
/// * both sides are non-null: NaN is treated as equal to NaN (unlike IEEE 754
///   `==`), and every other pair is compared with the ordinary `!=`.
///
/// Each argument expression is evaluated exactly once.
macro_rules! distinct_float {
    ($LEFT:expr, $RIGHT:expr, $LEFT_ISNULL:expr, $RIGHT_ISNULL:expr) => {{
        let (left, right) = ($LEFT, $RIGHT);
        let (left_is_null, right_is_null) = ($LEFT_ISNULL, $RIGHT_ISNULL);
        if left_is_null || right_is_null {
            // At least one null: distinct only when the other side is non-null.
            left_is_null != right_is_null
        } else {
            // Once the NaN flags agree and `left` is not NaN, `right` is not
            // NaN either, so plain `!=` is safe for the remaining case.
            left.is_nan() != right.is_nan() || (!left.is_nan() && left != right)
        }
    }};
}

pub(crate) fn is_distinct_from_bool(
left: &BooleanArray,
right: &BooleanArray,
Expand Down Expand Up @@ -62,22 +70,13 @@ pub(crate) fn is_distinct_from<T>(
where
T: ArrowNumericType,
{
let left_data = left.data();
let right_data = right.data();
let array_len = left_data.len().min(right_data.len());

let left_values = left.values();
let right_values = right.values();

let distinct = arrow_buffer::MutableBuffer::collect_bool(array_len, |i| {
left_data.is_null(i) != right_data.is_null(i) || left_values[i] != right_values[i]
});

let array_data = ArrayData::builder(arrow_schema::DataType::Boolean)
.len(array_len)
.add_buffer(distinct.into());

Ok(BooleanArray::from(unsafe { array_data.build_unchecked() }))
distinct(
left,
right,
|left_value, right_value, left_isnull, right_isnull| {
left_isnull != right_isnull || left_value != right_value
},
)
}

pub(crate) fn is_not_distinct_from<T>(
Expand All @@ -87,25 +86,104 @@ pub(crate) fn is_not_distinct_from<T>(
where
T: ArrowNumericType,
{
let left_data = left.data();
let right_data = right.data();
let array_len = left_data.len().min(right_data.len());
distinct(
left,
right,
|left_value, right_value, left_isnull, right_isnull| {
!(left_isnull != right_isnull || left_value != right_value)
},
)
}

fn distinct<
T,
F: FnMut(
<T as ArrowPrimitiveType>::Native,
<T as ArrowPrimitiveType>::Native,
bool,
bool,
) -> bool,
>(
left: &PrimitiveArray<T>,
right: &PrimitiveArray<T>,
mut op: F,
) -> Result<BooleanArray>
where
T: ArrowNumericType,
{
let left_values = left.values();
let right_values = right.values();
let left_data = left.data();
let right_data = right.data();

let array_len = left_data.len().min(right_data.len());
let distinct = arrow_buffer::MutableBuffer::collect_bool(array_len, |i| {
!(left_data.is_null(i) != right_data.is_null(i)
|| left_values[i] != right_values[i])
op(
left_values[i],
right_values[i],
left_data.is_null(i),
right_data.is_null(i),
)
});

let array_data = ArrayData::builder(arrow_schema::DataType::Boolean)
.len(array_len)
.add_buffer(distinct.into());

Ok(BooleanArray::from(unsafe { array_data.build_unchecked() }))
}

/// `IS DISTINCT FROM` kernel for `Float32Array` operands.
///
/// Each row is decided by `distinct_float!`, which, unlike a plain `!=`, does
/// not report two NaN values as distinct. Row iteration and construction of
/// the resulting `BooleanArray` are delegated to `distinct`.
pub(crate) fn is_distinct_from_f32(
    left: &Float32Array,
    right: &Float32Array,
) -> Result<BooleanArray> {
    let rows_differ = |lhs: f32, rhs: f32, lhs_null: bool, rhs_null: bool| {
        distinct_float!(lhs, rhs, lhs_null, rhs_null)
    };
    distinct(left, right, rows_differ)
}

/// `IS NOT DISTINCT FROM` kernel for `Float32Array` operands.
///
/// Each row is the logical negation of `distinct_float!`, so two NaN values
/// count as "not distinct". Row iteration and construction of the resulting
/// `BooleanArray` are delegated to `distinct`.
pub(crate) fn is_not_distinct_from_f32(
    left: &Float32Array,
    right: &Float32Array,
) -> Result<BooleanArray> {
    let rows_match = |lhs: f32, rhs: f32, lhs_null: bool, rhs_null: bool| {
        !distinct_float!(lhs, rhs, lhs_null, rhs_null)
    };
    distinct(left, right, rows_match)
}

/// `IS DISTINCT FROM` kernel for `Float64Array` operands.
///
/// Each row is decided by `distinct_float!`, which, unlike a plain `!=`, does
/// not report two NaN values as distinct. Row iteration and construction of
/// the resulting `BooleanArray` are delegated to `distinct`.
pub(crate) fn is_distinct_from_f64(
    left: &Float64Array,
    right: &Float64Array,
) -> Result<BooleanArray> {
    let rows_differ = |lhs: f64, rhs: f64, lhs_null: bool, rhs_null: bool| {
        distinct_float!(lhs, rhs, lhs_null, rhs_null)
    };
    distinct(left, right, rows_differ)
}

/// `IS NOT DISTINCT FROM` kernel for `Float64Array` operands.
///
/// Each row is the logical negation of `distinct_float!`, so two NaN values
/// count as "not distinct". Row iteration and construction of the resulting
/// `BooleanArray` are delegated to `distinct`.
pub(crate) fn is_not_distinct_from_f64(
    left: &Float64Array,
    right: &Float64Array,
) -> Result<BooleanArray> {
    let rows_match = |lhs: f64, rhs: f64, lhs_null: bool, rhs_null: bool| {
        !distinct_float!(lhs, rhs, lhs_null, rhs_null)
    };
    distinct(left, right, rows_match)
}

pub(crate) fn is_distinct_from_utf8<OffsetSize: OffsetSizeTrait>(
left: &GenericStringArray<OffsetSize>,
right: &GenericStringArray<OffsetSize>,
Expand Down