Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support string concat || for StringViewArray #12063

Merged
merged 10 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 15 additions & 16 deletions datafusion/expr-common/src/type_coercion/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -912,26 +912,22 @@ fn dictionary_coercion(

/// Coercion rules for string concat.
/// This is a union of string coercion rules and specified rules:
/// 1. At lease one side of lhs and rhs should be string type (Utf8 / LargeUtf8)
/// 1. At least one side of lhs and rhs should be string type (Utf8 / LargeUtf8)
/// 2. Data type of the other side should be able to cast to string type
fn string_concat_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
use arrow::datatypes::DataType::*;
match (lhs_type, rhs_type) {
// If Utf8View is in any side, we coerce to Utf8.
// Ref: https://github.com/apache/datafusion/pull/11796
(Utf8View, Utf8View | Utf8 | LargeUtf8) | (Utf8 | LargeUtf8, Utf8View) => {
Some(Utf8)
string_coercion(lhs_type, rhs_type).or(match (lhs_type, rhs_type) {
(Utf8View, from_type) | (from_type, Utf8View) => {
string_concat_internal_coercion(from_type, &Utf8View)
}
_ => string_coercion(lhs_type, rhs_type).or(match (lhs_type, rhs_type) {
(Utf8, from_type) | (from_type, Utf8) => {
string_concat_internal_coercion(from_type, &Utf8)
}
(LargeUtf8, from_type) | (from_type, LargeUtf8) => {
string_concat_internal_coercion(from_type, &LargeUtf8)
}
_ => None,
}),
}
(Utf8, from_type) | (from_type, Utf8) => {
string_concat_internal_coercion(from_type, &Utf8)
}
(LargeUtf8, from_type) | (from_type, LargeUtf8) => {
string_concat_internal_coercion(from_type, &LargeUtf8)
}
_ => None,
})
}

fn array_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType> {
Expand All @@ -942,6 +938,8 @@ fn array_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType>
}
}

/// If `from_type` can be cast to `to_type`, return `to_type`; otherwise
/// return `None`.
fn string_concat_internal_coercion(
from_type: &DataType,
to_type: &DataType,
Expand All @@ -967,6 +965,7 @@ fn string_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataType>
}
// Then, if LargeUtf8 is in any side, we coerce to LargeUtf8.
(LargeUtf8, Utf8 | LargeUtf8) | (Utf8, LargeUtf8) => Some(LargeUtf8),
// Utf8 coerces to Utf8
(Utf8, Utf8) => Some(Utf8),
_ => None,
}
Expand Down
53 changes: 24 additions & 29 deletions datafusion/physical-expr/src/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ use datafusion_expr::type_coercion::binary::get_result_type;
use datafusion_expr::{ColumnarValue, Operator};
use datafusion_physical_expr_common::datum::{apply, apply_cmp, apply_cmp_for_nested};

use crate::expressions::binary::kernels::concat_elements_utf8view;
use kernels::{
bitwise_and_dyn, bitwise_and_dyn_scalar, bitwise_or_dyn, bitwise_or_dyn_scalar,
bitwise_shift_left_dyn, bitwise_shift_left_dyn_scalar, bitwise_shift_right_dyn,
Expand Down Expand Up @@ -131,34 +132,6 @@ impl std::fmt::Display for BinaryExpr {
}
}

/// Invoke a compute kernel on a pair of binary data arrays
///
/// Arguments:
/// * `$LEFT` / `$RIGHT` - the two operand arrays; both are downcast to `$DT`
/// * `$OP` - kernel name prefix; the function actually called is `<$OP>_utf8`
/// * `$DT` - the concrete array type both operands are expected to be
///
/// Expands to an `Ok(Arc<..>)` holding the kernel output. An error returned
/// by the kernel is propagated with `?`, so the macro must be used inside a
/// function that returns a compatible `Result`.
///
/// Panics (via `expect`) if either operand cannot be downcast to `$DT`.
macro_rules! compute_utf8_op {
($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
// Downcast the left operand to the concrete array type.
let ll = $LEFT
.as_any()
.downcast_ref::<$DT>()
.expect("compute_op failed to downcast left side array");
// The right operand is downcast to the same concrete type as the left.
let rr = $RIGHT
.as_any()
.downcast_ref::<$DT>()
.expect("compute_op failed to downcast right side array");
// `paste` glues `$OP` and `_utf8` into a single identifier, e.g.
// `concat_elements` becomes `concat_elements_utf8`.
Ok(Arc::new(paste::expr! {[<$OP _utf8>]}(&ll, &rr)?))
}};
}

/// Apply the string kernel `<$OP>_utf8` to `$LEFT` and `$RIGHT`, choosing the
/// concrete array type from the data type of `$LEFT`.
///
/// * `DataType::Utf8` dispatches with `StringArray`
/// * `DataType::LargeUtf8` dispatches with `LargeStringArray`
/// * any other data type yields an internal error naming the type and `$OP`
///
/// Only the left operand's type is inspected; `compute_utf8_op!` downcasts
/// the right operand to the same array type.
macro_rules! binary_string_array_op {
($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
match $LEFT.data_type() {
DataType::Utf8 => compute_utf8_op!($LEFT, $RIGHT, $OP, StringArray),
DataType::LargeUtf8 => compute_utf8_op!($LEFT, $RIGHT, $OP, LargeStringArray),
// No other data type is handled by this macro.
other => internal_err!(
"Data type {:?} not supported for binary operation '{}' on string arrays",
other, stringify!($OP)
),
}
}};
}

/// Invoke a boolean kernel on a pair of arrays
macro_rules! boolean_op {
($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
Expand Down Expand Up @@ -662,14 +635,36 @@ impl BinaryExpr {
BitwiseXor => bitwise_xor_dyn(left, right),
BitwiseShiftRight => bitwise_shift_right_dyn(left, right),
BitwiseShiftLeft => bitwise_shift_left_dyn(left, right),
StringConcat => binary_string_array_op!(left, right, concat_elements),
StringConcat => concat_elements(left, right),
AtArrow | ArrowAt => {
unreachable!("ArrowAt and AtArrow should be rewritten to function")
}
}
}
}

/// Element-wise string concatenation of `left` and `right`.
///
/// The kernel is selected from the data type of `left` alone; `right` is
/// accessed as the same string array type:
///
/// * `Utf8` uses `concat_elements_utf8` with `i32` offsets
/// * `LargeUtf8` uses `concat_elements_utf8` with `i64` offsets
/// * `Utf8View` uses `concat_elements_utf8view`
///
/// Any other data type produces an internal error. Errors returned by the
/// kernels are propagated to the caller.
fn concat_elements(left: Arc<dyn Array>, right: Arc<dyn Array>) -> Result<ArrayRef> {
    let concatenated: ArrayRef = match left.data_type() {
        DataType::Utf8 => {
            let (lhs, rhs) = (left.as_string::<i32>(), right.as_string::<i32>());
            Arc::new(concat_elements_utf8(lhs, rhs)?)
        }
        DataType::LargeUtf8 => {
            let (lhs, rhs) = (left.as_string::<i64>(), right.as_string::<i64>());
            Arc::new(concat_elements_utf8(lhs, rhs)?)
        }
        DataType::Utf8View => {
            let (lhs, rhs) = (left.as_string_view(), right.as_string_view());
            Arc::new(concat_elements_utf8view(lhs, rhs)?)
        }
        other => {
            return internal_err!(
                "Data type {other:?} not supported for binary operation 'concat_elements' on string arrays"
            );
        }
    };
    Ok(concatenated)
}

/// Create a binary expression whose arguments are correctly coerced.
/// This function errors if it is not possible to coerce the arguments
/// to computational types supported by the operator.
Expand Down
33 changes: 33 additions & 0 deletions datafusion/physical-expr/src/expressions/binary/kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use arrow::datatypes::DataType;
use datafusion_common::internal_err;
use datafusion_common::{Result, ScalarValue};

use arrow_schema::ArrowError;
use std::sync::Arc;

/// Downcasts $LEFT and $RIGHT to $ARRAY_TYPE and then calls $KERNEL($LEFT, $RIGHT)
Expand Down Expand Up @@ -131,3 +132,35 @@ create_dyn_scalar_kernel!(bitwise_or_dyn_scalar, bitwise_or_scalar);
create_dyn_scalar_kernel!(bitwise_xor_dyn_scalar, bitwise_xor_scalar);
create_dyn_scalar_kernel!(bitwise_shift_right_dyn_scalar, bitwise_shift_right_scalar);
create_dyn_scalar_kernel!(bitwise_shift_left_dyn_scalar, bitwise_shift_left_scalar);

pub fn concat_elements_utf8view(
left: &StringViewArray,
right: &StringViewArray,
) -> std::result::Result<StringViewArray, ArrowError> {
let capacity = left
.data_buffers()
.iter()
.zip(right.data_buffers().iter())
.map(|(b1, b2)| b1.len() + b2.len())
.sum();
let mut result = StringViewBuilder::with_capacity(capacity);

// Avoid reallocations by writing to a reused buffer (note we
// could be even more efficient r by creating the view directly
// here and avoid the buffer but that would be more complex)
let mut buffer = String::new();

for (left, right) in left.iter().zip(right.iter()) {
if let (Some(left), Some(right)) = (left, right) {
use std::fmt::Write;
buffer.clear();
write!(&mut buffer, "{left}{right}")
.expect("writing into string buffer failed");
result.append_value(&buffer);
} else {
// at least one of the values is null, so the output is also null
result.append_null()
}
}
Ok(result.finish())
}
74 changes: 69 additions & 5 deletions datafusion/sqllogictest/test_files/string_view.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1145,6 +1145,63 @@ FROM test;
0
NULL

# || mixed types
# expect all results to be the same for each row as they all have the same values
query TTTTTTTT
SELECT
column1_utf8view || column2_utf8view,
column1_utf8 || column2_utf8view,
column1_large_utf8 || column2_utf8view,
column1_dict || column2_utf8view,
-- reverse argument order
column2_utf8view || column1_utf8view,
column2_utf8view || column1_utf8,
column2_utf8view || column1_large_utf8,
column2_utf8view || column1_dict
FROM test;
----
AndrewX AndrewX AndrewX AndrewX XAndrew XAndrew XAndrew XAndrew
XiangpengXiangpeng XiangpengXiangpeng XiangpengXiangpeng XiangpengXiangpeng XiangpengXiangpeng XiangpengXiangpeng XiangpengXiangpeng XiangpengXiangpeng
RaphaelR RaphaelR RaphaelR RaphaelR RRaphael RRaphael RRaphael RRaphael
NULL NULL NULL NULL NULL NULL NULL NULL

# || constants
# expect all results to be the same for each row as they all have the same values
query TTTTTTTT
SELECT
column1_utf8view || 'foo',
column1_utf8 || 'foo',
column1_large_utf8 || 'foo',
column1_dict || 'foo',
-- reverse argument order
'foo' || column1_utf8view,
'foo' || column1_utf8,
'foo' || column1_large_utf8,
'foo' || column1_dict
FROM test;
----
Andrewfoo Andrewfoo Andrewfoo Andrewfoo fooAndrew fooAndrew fooAndrew fooAndrew
Xiangpengfoo Xiangpengfoo Xiangpengfoo Xiangpengfoo fooXiangpeng fooXiangpeng fooXiangpeng fooXiangpeng
Raphaelfoo Raphaelfoo Raphaelfoo Raphaelfoo fooRaphael fooRaphael fooRaphael fooRaphael
NULL NULL NULL NULL NULL NULL NULL NULL

# || same type (column1 has null, so also tests NULL || NULL)
# expect all results to be the same for each row as they all have the same values
query TTT
SELECT
column1_utf8view || column1_utf8view,
column1_utf8 || column1_utf8,
column1_large_utf8 || column1_large_utf8
-- Dictionary/Dictionary coercion doesn't work
-- https://github.com/apache/datafusion/issues/12101
--column1_dict || column1_dict
FROM test;
----
AndrewAndrew AndrewAndrew AndrewAndrew
XiangpengXiangpeng XiangpengXiangpeng XiangpengXiangpeng
RaphaelRaphael RaphaelRaphael RaphaelRaphael
NULL NULL NULL

statement ok
drop table test;

Expand All @@ -1168,18 +1225,25 @@ select t.dt from dates t where arrow_cast('2024-01-01', 'Utf8View') < t.dt;
statement ok
drop table dates;

### Tests for `||` with Utf8View specifically

statement ok
create table temp as values
('value1', arrow_cast('rust', 'Utf8View'), arrow_cast('fast', 'Utf8View')),
('value2', arrow_cast('datafusion', 'Utf8View'), arrow_cast('cool', 'Utf8View'));

query TTT
select arrow_typeof(column1), arrow_typeof(column2), arrow_typeof(column3) from temp;
----
Utf8 Utf8View Utf8View
Utf8 Utf8View Utf8View

query T
select column2||' is fast' from temp;
----
rust is fast
datafusion is fast


query T
select column2 || ' is ' || column3 from temp;
----
Expand All @@ -1190,15 +1254,15 @@ query TT
explain select column2 || 'is' || column3 from temp;
----
logical_plan
01)Projection: CAST(temp.column2 AS Utf8) || Utf8("is") || CAST(temp.column3 AS Utf8)
01)Projection: temp.column2 || Utf8View("is") || temp.column3 AS temp.column2 || Utf8("is") || temp.column3
02)--TableScan: temp projection=[column2, column3]


# should not cast the column2 to utf8
query TT
explain select column2||' is fast' from temp;
----
logical_plan
01)Projection: CAST(temp.column2 AS Utf8) || Utf8(" is fast")
01)Projection: temp.column2 || Utf8View(" is fast") AS temp.column2 || Utf8(" is fast")
02)--TableScan: temp projection=[column2]


Expand All @@ -1212,7 +1276,7 @@ query TT
explain select column2||column3 from temp;
----
logical_plan
01)Projection: CAST(temp.column2 AS Utf8) || CAST(temp.column3 AS Utf8)
01)Projection: temp.column2 || temp.column3
02)--TableScan: temp projection=[column2, column3]

query T
Expand Down