Skip to content

Commit cab2eef

Browse files
committed
fix possible types
1 parent 2968ba6 commit cab2eef

File tree

2 files changed

+61
-31
lines changed

2 files changed

+61
-31
lines changed

datafusion/expr-common/src/signature.rs

Lines changed: 59 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ use std::num::NonZeroUsize;
2323

2424
use crate::type_coercion::aggregates::NUMERICS;
2525
use arrow::datatypes::{DataType, IntervalUnit, TimeUnit};
26-
use datafusion_common::Result;
2726
use datafusion_common::types::{LogicalType, LogicalTypeRef, NativeType};
27+
use datafusion_common::{HashSet, Result};
2828
use itertools::Itertools;
2929

3030
/// Constant that is used as a placeholder for any valid timezone.
@@ -232,14 +232,18 @@ impl TypeSignatureClass {
232232
/// We return the largest common type for the given `TypeSignatureClass`
233233
pub fn default_casted_type(&self, data_type: &DataType) -> Result<DataType> {
234234
Ok(match self {
235-
TypeSignatureClass::Native(logical_type) => return logical_type.native().default_cast_for(data_type),
236-
TypeSignatureClass::Timestamp => DataType::Timestamp(TimeUnit::Nanosecond, None),
237-
TypeSignatureClass::Date => DataType::Date64,
238-
TypeSignatureClass::Time => DataType::Time64(TimeUnit::Nanosecond),
239-
TypeSignatureClass::Interval => DataType::Interval(IntervalUnit::DayTime),
240-
TypeSignatureClass::Duration => DataType::Duration(TimeUnit::Nanosecond),
241-
TypeSignatureClass::Integer => DataType::Int64,
242-
})
235+
TypeSignatureClass::Native(logical_type) => {
236+
return logical_type.native().default_cast_for(data_type)
237+
}
238+
TypeSignatureClass::Timestamp => {
239+
DataType::Timestamp(TimeUnit::Nanosecond, None)
240+
}
241+
TypeSignatureClass::Date => DataType::Date64,
242+
TypeSignatureClass::Time => DataType::Time64(TimeUnit::Nanosecond),
243+
TypeSignatureClass::Interval => DataType::Interval(IntervalUnit::DayTime),
244+
TypeSignatureClass::Duration => DataType::Duration(TimeUnit::Nanosecond),
245+
TypeSignatureClass::Integer => DataType::Int64,
246+
})
243247
}
244248
}
245249

@@ -400,6 +404,23 @@ impl TypeSignature {
400404
.cloned()
401405
.map(|data_type| vec![data_type; *arg_count])
402406
.collect(),
407+
TypeSignature::CoercibleV2(coercions) => coercions
408+
.iter()
409+
.map(|c| {
410+
let mut all_types: HashSet<DataType> =
411+
get_possible_types_from_signature_classes(&c.desired_type)
412+
.into_iter()
413+
.collect();
414+
let allowed_casts: Vec<DataType> = c
415+
.allowed_casts
416+
.iter()
417+
.flat_map(|t| get_possible_types_from_signature_classes(t))
418+
.collect();
419+
all_types.extend(allowed_casts.into_iter());
420+
all_types.into_iter().collect::<Vec<_>>()
421+
})
422+
.multi_cartesian_product()
423+
.collect(),
403424
TypeSignature::Coercible(types) => types
404425
.iter()
405426
.map(|logical_type| match logical_type {
@@ -451,12 +472,40 @@ impl TypeSignature {
451472
| TypeSignature::Nullary
452473
| TypeSignature::VariadicAny
453474
| TypeSignature::ArraySignature(_)
454-
| TypeSignature::CoercibleV2(_)
455475
| TypeSignature::UserDefined => vec![],
456476
}
457477
}
458478
}
459479

480+
fn get_possible_types_from_signature_classes(
481+
signature_classes: &TypeSignatureClass,
482+
) -> Vec<DataType> {
483+
match signature_classes {
484+
TypeSignatureClass::Native(l) => get_data_types(l.native()),
485+
TypeSignatureClass::Timestamp => {
486+
vec![
487+
DataType::Timestamp(TimeUnit::Nanosecond, None),
488+
DataType::Timestamp(TimeUnit::Nanosecond, Some(TIMEZONE_WILDCARD.into())),
489+
]
490+
}
491+
TypeSignatureClass::Date => {
492+
vec![DataType::Date64]
493+
}
494+
TypeSignatureClass::Time => {
495+
vec![DataType::Time64(TimeUnit::Nanosecond)]
496+
}
497+
TypeSignatureClass::Interval => {
498+
vec![DataType::Interval(IntervalUnit::DayTime)]
499+
}
500+
TypeSignatureClass::Duration => {
501+
vec![DataType::Duration(TimeUnit::Nanosecond)]
502+
}
503+
TypeSignatureClass::Integer => {
504+
vec![DataType::Int64]
505+
}
506+
}
507+
}
508+
460509
fn get_data_types(native_type: &NativeType) -> Vec<DataType> {
461510
match native_type {
462511
NativeType::Null => vec![DataType::Null],

datafusion/expr/src/type_coercion/functions.rs

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -617,7 +617,8 @@ fn get_valid_types(
617617
.iter()
618618
.any(|t| is_matched_type(t, &current_logical_type))
619619
{
620-
let casted_type = param.desired_type.default_casted_type(current_type)?;
620+
let casted_type =
621+
param.desired_type.default_casted_type(current_type)?;
621622
new_types.push(casted_type);
622623
} else {
623624
return internal_err!(
@@ -627,26 +628,6 @@ fn get_valid_types(
627628
current_type
628629
);
629630
}
630-
631-
// if let Some(casted_type) = get_casted_type(
632-
// &param.desired_type,
633-
// &current_logical_type,
634-
// current_type,
635-
// )
636-
// .or_else(|| {
637-
// param.allowed_casts.iter().find_map(|t| {
638-
// get_casted_type(t, &current_logical_type, current_type)
639-
// })
640-
// }) {
641-
// new_types.push(casted_type);
642-
// } else {
643-
// return internal_err!(
644-
// "Expect {} but received NativeType: {}, DataType: {}",
645-
// param.desired_type,
646-
// current_logical_type,
647-
// current_type
648-
// );
649-
// }
650631
}
651632

652633
vec![new_types]

0 commit comments

Comments
 (0)