Skip to content

Commit

Permalink
Fix inconsistent array type for binary numerical operators result bet…
Browse files Browse the repository at this point in the history
…ween array and scalar (#6269)

* Cast binary numerical operators result between array and scalar to primitive array

* Add order by to stablize query result

* Fix tests

* Fix clippy
  • Loading branch information
viirya authored May 9, 2023
1 parent c7f5183 commit 1dd3674
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 64 deletions.
28 changes: 20 additions & 8 deletions datafusion/core/tests/sqllogictests/test_files/aggregate.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1719,14 +1719,26 @@ select max(x_dict) from value_dict where x_dict > 3;
----
5

query error DataFusion error: External error: Arrow error: Invalid argument error: RowConverter column schema mismatch, expected Int64 got Dictionary\(Int64, Int64\)
select sum(x_dict) from value_dict group by x_dict % 2;
query I
select sum(x_dict) from value_dict group by x_dict % 2 order by sum(x_dict);
----
8
13

query error DataFusion error: External error: Arrow error: Invalid argument error: RowConverter column schema mismatch, expected Int64 got Dictionary\(Int64, Int64\)
select avg(x_dict) from value_dict group by x_dict % 2;
query R
select avg(x_dict) from value_dict group by x_dict % 2 order by avg(x_dict);
----
2.6
2.666666666667

query error DataFusion error: External error: Arrow error: Invalid argument error: RowConverter column schema mismatch, expected Int64 got Dictionary\(Int64, Int64\)
select min(x_dict) from value_dict group by x_dict % 2;
query I
select min(x_dict) from value_dict group by x_dict % 2 order by min(x_dict);
----
1
2

query error DataFusion error: External error: Arrow error: Invalid argument error: RowConverter column schema mismatch, expected Int64 got Dictionary\(Int64, Int64\)
select max(x_dict) from value_dict group by x_dict % 2;
query I
select max(x_dict) from value_dict group by x_dict % 2 order by max(x_dict);
----
4
5
122 changes: 66 additions & 56 deletions datafusion/physical-expr/src/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ use arrow::compute::kernels::comparison::{
eq_dyn_utf8_scalar, gt_dyn_utf8_scalar, gt_eq_dyn_utf8_scalar, lt_dyn_utf8_scalar,
lt_eq_dyn_utf8_scalar, neq_dyn_utf8_scalar,
};
use arrow::compute::{try_unary, unary, CastOptions};
use arrow::compute::{cast, try_unary, unary, CastOptions};
use arrow::datatypes::*;

use adapter::{eq_dyn, gt_dyn, gt_eq_dyn, lt_dyn, lt_eq_dyn, neq_dyn};
Expand Down Expand Up @@ -694,6 +694,9 @@ impl PhysicalExpr for BinaryExpr {
(ColumnarValue::Array(array), ColumnarValue::Scalar(scalar)) => {
// if left is array and right is literal - use scalar operations
self.evaluate_array_scalar(array, scalar.clone(), &result_type)?
.map(|r| {
r.and_then(|a| to_result_type_array(&self.op, a, &result_type))
})
}
(ColumnarValue::Scalar(scalar), ColumnarValue::Array(array)) => {
// if right is literal and left is array - reverse operator and parameters
Expand Down Expand Up @@ -1027,6 +1030,35 @@ pub(crate) fn array_eq_scalar(lhs: &dyn Array, rhs: &ScalarValue) -> Result<Arra
)?
}

/// Casts dictionary array to result type for binary numerical operators. Such operators
/// between array and scalar produce a dictionary array other than primitive array of the
/// same operators between array and array. This leads to inconsistent result types causing
/// errors in the following query execution. For such operators between array and scalar,
/// we cast the dictionary array to primitive array.
fn to_result_type_array(
op: &Operator,
array: ArrayRef,
result_type: &DataType,
) -> Result<ArrayRef> {
if op.is_numerical_operators() {
match array.data_type() {
DataType::Dictionary(_, value_type) => {
if value_type.as_ref() == result_type {
Ok(cast(&array, result_type)?)
} else {
Err(DataFusionError::Internal(format!(
"Incompatible Dictionary value type {:?} with result type {:?} of Binary operator {:?}",
value_type, result_type, op
)))
}
}
_ => Ok(array),
}
} else {
Ok(array)
}
}

impl BinaryExpr {
/// Evaluate the expression of the left input is an array and
/// right is literal - use scalar operations
Expand Down Expand Up @@ -2699,13 +2731,8 @@ mod tests {

let a = dict_builder.finish();

let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();

dict_builder.append(2)?;
dict_builder.append_null();
dict_builder.append(3)?;
dict_builder.append(6)?;
let expected = dict_builder.finish();
let expected: PrimitiveArray<Int32Type> =
PrimitiveArray::from(vec![Some(2), None, Some(3), Some(6)]);

apply_arithmetic_scalar(
Arc::new(schema),
Expand Down Expand Up @@ -2742,13 +2769,17 @@ mod tests {
let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
let a = DictionaryArray::try_new(keys, decimal_array)?;

let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
let decimal_array = Arc::new(create_decimal_array(
&[Some(value + 1), None, Some(value), Some(value + 2)],
&[
Some(value + 1),
Some(value),
None,
Some(value + 2),
Some(value + 1),
],
11,
0,
));
let expected = DictionaryArray::try_new(keys, decimal_array)?;

apply_arithmetic_scalar(
Arc::new(schema),
Expand All @@ -2758,7 +2789,7 @@ mod tests {
Box::new(DataType::Int8),
Box::new(ScalarValue::Decimal128(Some(1), 10, 0)),
),
Arc::new(expected),
decimal_array,
)?;

Ok(())
Expand Down Expand Up @@ -2918,13 +2949,8 @@ mod tests {

let a = dict_builder.finish();

let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();

dict_builder.append(0)?;
dict_builder.append_null();
dict_builder.append(1)?;
dict_builder.append(4)?;
let expected = dict_builder.finish();
let expected: PrimitiveArray<Int32Type> =
PrimitiveArray::from(vec![Some(0), None, Some(1), Some(4)]);

apply_arithmetic_scalar(
Arc::new(schema),
Expand Down Expand Up @@ -2961,13 +2987,17 @@ mod tests {
let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
let a = DictionaryArray::try_new(keys, decimal_array)?;

let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
let decimal_array = Arc::new(create_decimal_array(
&[Some(value - 1), None, Some(value - 2), Some(value)],
&[
Some(value - 1),
Some(value - 2),
None,
Some(value),
Some(value - 1),
],
11,
0,
));
let expected = DictionaryArray::try_new(keys, decimal_array)?;

apply_arithmetic_scalar(
Arc::new(schema),
Expand All @@ -2977,7 +3007,7 @@ mod tests {
Box::new(DataType::Int8),
Box::new(ScalarValue::Decimal128(Some(1), 10, 0)),
),
Arc::new(expected),
decimal_array,
)?;

Ok(())
Expand Down Expand Up @@ -3133,13 +3163,8 @@ mod tests {

let a = dict_builder.finish();

let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();

dict_builder.append(2)?;
dict_builder.append_null();
dict_builder.append(4)?;
dict_builder.append(10)?;
let expected = dict_builder.finish();
let expected: PrimitiveArray<Int32Type> =
PrimitiveArray::from(vec![Some(2), None, Some(4), Some(10)]);

apply_arithmetic_scalar(
Arc::new(schema),
Expand Down Expand Up @@ -3176,13 +3201,11 @@ mod tests {
let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
let a = DictionaryArray::try_new(keys, decimal_array)?;

let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
let decimal_array = Arc::new(create_decimal_array(
&[Some(246), None, Some(244), Some(248)],
&[Some(246), Some(244), None, Some(248), Some(246)],
21,
0,
));
let expected = DictionaryArray::try_new(keys, decimal_array)?;

apply_arithmetic_scalar(
Arc::new(schema),
Expand All @@ -3192,7 +3215,7 @@ mod tests {
Box::new(DataType::Int8),
Box::new(ScalarValue::Decimal128(Some(2), 10, 0)),
),
Arc::new(expected),
decimal_array,
)?;

Ok(())
Expand Down Expand Up @@ -3360,13 +3383,8 @@ mod tests {

let a = dict_builder.finish();

let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();

dict_builder.append(0)?;
dict_builder.append_null();
dict_builder.append(1)?;
dict_builder.append(2)?;
let expected = dict_builder.finish();
let expected: PrimitiveArray<Int32Type> =
PrimitiveArray::from(vec![Some(0), None, Some(1), Some(2)]);

apply_arithmetic_scalar(
Arc::new(schema),
Expand Down Expand Up @@ -3403,18 +3421,17 @@ mod tests {
let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
let a = DictionaryArray::try_new(keys, decimal_array)?;

let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
let decimal_array = Arc::new(create_decimal_array(
&[
Some(6150000000000),
None,
Some(6100000000000),
None,
Some(6200000000000),
Some(6150000000000),
],
21,
11,
));
let expected = DictionaryArray::try_new(keys, decimal_array)?;

apply_arithmetic_scalar(
Arc::new(schema),
Expand All @@ -3424,7 +3441,7 @@ mod tests {
Box::new(DataType::Int8),
Box::new(ScalarValue::Decimal128(Some(2), 10, 0)),
),
Arc::new(expected),
decimal_array,
)?;

Ok(())
Expand Down Expand Up @@ -3582,13 +3599,8 @@ mod tests {

let a = dict_builder.finish();

let mut dict_builder = PrimitiveDictionaryBuilder::<Int8Type, Int32Type>::new();

dict_builder.append(1)?;
dict_builder.append_null();
dict_builder.append(0)?;
dict_builder.append(1)?;
let expected = dict_builder.finish();
let expected: PrimitiveArray<Int32Type> =
PrimitiveArray::from(vec![Some(1), None, Some(0), Some(1)]);

apply_arithmetic_scalar(
Arc::new(schema),
Expand Down Expand Up @@ -3625,13 +3637,11 @@ mod tests {
let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
let a = DictionaryArray::try_new(keys, decimal_array)?;

let keys = Int8Array::from(vec![0, 2, 1, 3, 0]);
let decimal_array = Arc::new(create_decimal_array(
&[Some(1), None, Some(0), Some(0)],
&[Some(1), Some(0), None, Some(0), Some(1)],
10,
0,
));
let expected = DictionaryArray::try_new(keys, decimal_array)?;

apply_arithmetic_scalar(
Arc::new(schema),
Expand All @@ -3641,7 +3651,7 @@ mod tests {
Box::new(DataType::Int8),
Box::new(ScalarValue::Decimal128(Some(2), 10, 0)),
),
Arc::new(expected),
decimal_array,
)?;

Ok(())
Expand Down

0 comments on commit 1dd3674

Please sign in to comment.