Skip to content

Commit

Permalink
string coercible
Browse files Browse the repository at this point in the history
Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
  • Loading branch information
jayzhan211 committed Oct 31, 2024
1 parent 5f99515 commit 840139e
Show file tree
Hide file tree
Showing 10 changed files with 54 additions and 13 deletions.
35 changes: 35 additions & 0 deletions datafusion/common/src/types/native.rs
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,41 @@ impl From<DataType> for NativeType {
}
}

impl NativeType {
#[inline]
pub fn is_numeric(&self) -> bool {
use NativeType::*;
matches!(
self,
UInt8
| UInt16
| UInt32
| UInt64
| Int8
| Int16
| Int32
| Int64
| Float16
| Float32
| Float64
)
}

/// This function is the NativeType version of `can_cast_types`.
/// It handles general coercion rules that are widely applicable.
/// Avoid adding specific coercion cases here.
/// Aim to keep this logic as SIMPLE as possible!
pub fn can_cast_to(&self, target_type: &Self) -> bool {
// In Postgres, most functions coerce numeric strings to numeric inputs,
// but they do not accept numeric inputs as strings.
if self.is_numeric() && target_type == &NativeType::String {
return false;
}

true
}
}

// Singleton instances
// TODO: Replace with LazyLock
// pub static LOGICAL_STRING: OnceLock<LogicalTypeRef> = OnceLock::new();
Expand Down
3 changes: 0 additions & 3 deletions datafusion/expr-common/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@

use arrow::datatypes::DataType;
use datafusion_common::types::LogicalTypeRef;
// use datafusion_common::types::LogicalType;

// use crate::logical_type::LogicalTypeRef;

/// Constant that is used as a placeholder for any valid timezone.
/// This is used where a function can accept a timestamp type with any
Expand Down
16 changes: 13 additions & 3 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use arrow::{
};
use datafusion_common::{
exec_err, internal_datafusion_err, internal_err, plan_err,
types::NativeType,
types::{logical_string, NativeType},
utils::{coerced_fixed_size_list_to_list, list_ndims},
Result,
};
Expand Down Expand Up @@ -402,6 +402,10 @@ fn get_valid_types(
.map(|valid_type| current_types.iter().map(|_| valid_type.clone()).collect())
.collect(),
TypeSignature::String(number) => {
// TODO: we can switch to coercible after all the string functions support utf8view since it is choosen as the default string type.
//
// let data_types = get_valid_types(&TypeSignature::Coercible(vec![logical_string(); *number]), current_types)?.swap_remove(0);

if *number < 1 {
return plan_err!(
"The signature expected at least one argument but received {}",
Expand All @@ -425,7 +429,6 @@ fn get_valid_types(
new_types.push(data_type.to_owned());
}
NativeType::Null => {
// TODO: we can switch to Utf8 if all the related string function supports ut8view
new_types.push(DataType::Utf8);
}
_ => {
Expand All @@ -437,6 +440,7 @@ fn get_valid_types(
}

let data_types = new_types;

// Find the common string type for the given types
fn find_common_type(
lhs_type: &DataType,
Expand Down Expand Up @@ -529,9 +533,15 @@ fn get_valid_types(
let logical_data_type: NativeType = data_type.into();
if logical_data_type == *target_type.native() {
new_types.push(data_type.to_owned());
} else {
} else if logical_data_type.can_cast_to(target_type.native()) {
let casted_type = target_type.default_cast_for(data_type)?;
new_types.push(casted_type);
} else {
return plan_err!(
"The signature expected {:?} but received {:?}",
target_type.native(),
logical_data_type
);
}
}

Expand Down
1 change: 0 additions & 1 deletion datafusion/functions-aggregate/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ arrow-schema = { workspace = true }
datafusion-common = { workspace = true }
datafusion-execution = { workspace = true }
datafusion-expr = { workspace = true }
datafusion-expr-common = { workspace = true }
datafusion-functions-aggregate-common = { workspace = true }
datafusion-physical-expr = { workspace = true }
datafusion-physical-expr-common = { workspace = true }
Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions/src/string/bit_length.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ impl ScalarUDFImpl for BitLengthFunc {
ScalarValue::LargeUtf8(v) => Ok(ColumnarValue::Scalar(
ScalarValue::Int64(v.as_ref().map(|x| (x.len() * 8) as i64)),
)),
_ => unreachable!(),
_ => unreachable!("bit length"),
},
}
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions/src/string/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ impl ScalarUDFImpl for ConcatFunc {
}
};
}
_ => unreachable!(),
_ => unreachable!("concat"),
}
}

Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions/src/string/concat_ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ impl ScalarUDFImpl for ConcatWsFunc {
ColumnarValueRef::NonNullableArray(string_array)
}
}
_ => unreachable!(),
_ => unreachable!("concat ws"),
};

let mut columns = Vec::with_capacity(args.len() - 1);
Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions/src/string/lower.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ mod tests {
let args = vec![ColumnarValue::Array(input)];
let result = match func.invoke(&args)? {
ColumnarValue::Array(result) => result,
_ => unreachable!(),
_ => unreachable!("lower"),
};
assert_eq!(&expected, &result);
Ok(())
Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions/src/string/upper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ mod tests {
let args = vec![ColumnarValue::Array(input)];
let result = match func.invoke(&args)? {
ColumnarValue::Array(result) => result,
_ => unreachable!(),
_ => unreachable!("upper"),
};
assert_eq!(&expected, &result);
Ok(())
Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions/src/unicode/lpad.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ pub fn lpad<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
length_array,
&args[2],
),
(_, _) => unreachable!(),
(_, _) => unreachable!("lpad"),
}
}

Expand Down

0 comments on commit 840139e

Please sign in to comment.