@@ -23,8 +23,8 @@ use std::num::NonZeroUsize;
2323
2424use crate :: type_coercion:: aggregates:: NUMERICS ;
2525use arrow:: datatypes:: { DataType , IntervalUnit , TimeUnit } ;
26- use datafusion_common:: Result ;
2726use datafusion_common:: types:: { LogicalType , LogicalTypeRef , NativeType } ;
27+ use datafusion_common:: { HashSet , Result } ;
2828use itertools:: Itertools ;
2929
3030/// Constant that is used as a placeholder for any valid timezone.
@@ -232,14 +232,18 @@ impl TypeSignatureClass {
232232 /// We return the largest common type for the given `TypeSignatureClass`
233233 pub fn default_casted_type ( & self , data_type : & DataType ) -> Result < DataType > {
234234 Ok ( match self {
235- TypeSignatureClass :: Native ( logical_type) => return logical_type. native ( ) . default_cast_for ( data_type) ,
236- TypeSignatureClass :: Timestamp => DataType :: Timestamp ( TimeUnit :: Nanosecond , None ) ,
237- TypeSignatureClass :: Date => DataType :: Date64 ,
238- TypeSignatureClass :: Time => DataType :: Time64 ( TimeUnit :: Nanosecond ) ,
239- TypeSignatureClass :: Interval => DataType :: Interval ( IntervalUnit :: DayTime ) ,
240- TypeSignatureClass :: Duration => DataType :: Duration ( TimeUnit :: Nanosecond ) ,
241- TypeSignatureClass :: Integer => DataType :: Int64 ,
242- } )
235+ TypeSignatureClass :: Native ( logical_type) => {
236+ return logical_type. native ( ) . default_cast_for ( data_type)
237+ }
238+ TypeSignatureClass :: Timestamp => {
239+ DataType :: Timestamp ( TimeUnit :: Nanosecond , None )
240+ }
241+ TypeSignatureClass :: Date => DataType :: Date64 ,
242+ TypeSignatureClass :: Time => DataType :: Time64 ( TimeUnit :: Nanosecond ) ,
243+ TypeSignatureClass :: Interval => DataType :: Interval ( IntervalUnit :: DayTime ) ,
244+ TypeSignatureClass :: Duration => DataType :: Duration ( TimeUnit :: Nanosecond ) ,
245+ TypeSignatureClass :: Integer => DataType :: Int64 ,
246+ } )
243247 }
244248}
245249
@@ -400,6 +404,23 @@ impl TypeSignature {
400404 . cloned ( )
401405 . map ( |data_type| vec ! [ data_type; * arg_count] )
402406 . collect ( ) ,
407+ TypeSignature :: CoercibleV2 ( coercions) => coercions
408+ . iter ( )
409+ . map ( |c| {
410+ let mut all_types: HashSet < DataType > =
411+ get_possible_types_from_signature_classes ( & c. desired_type )
412+ . into_iter ( )
413+ . collect ( ) ;
414+ let allowed_casts: Vec < DataType > = c
415+ . allowed_casts
416+ . iter ( )
417+ . flat_map ( |t| get_possible_types_from_signature_classes ( t) )
418+ . collect ( ) ;
419+ all_types. extend ( allowed_casts. into_iter ( ) ) ;
420+ all_types. into_iter ( ) . collect :: < Vec < _ > > ( )
421+ } )
422+ . multi_cartesian_product ( )
423+ . collect ( ) ,
403424 TypeSignature :: Coercible ( types) => types
404425 . iter ( )
405426 . map ( |logical_type| match logical_type {
@@ -451,12 +472,40 @@ impl TypeSignature {
451472 | TypeSignature :: Nullary
452473 | TypeSignature :: VariadicAny
453474 | TypeSignature :: ArraySignature ( _)
454- | TypeSignature :: CoercibleV2 ( _)
455475 | TypeSignature :: UserDefined => vec ! [ ] ,
456476 }
457477 }
458478}
459479
480+ fn get_possible_types_from_signature_classes (
481+ signature_classes : & TypeSignatureClass ,
482+ ) -> Vec < DataType > {
483+ match signature_classes {
484+ TypeSignatureClass :: Native ( l) => get_data_types ( l. native ( ) ) ,
485+ TypeSignatureClass :: Timestamp => {
486+ vec ! [
487+ DataType :: Timestamp ( TimeUnit :: Nanosecond , None ) ,
488+ DataType :: Timestamp ( TimeUnit :: Nanosecond , Some ( TIMEZONE_WILDCARD . into( ) ) ) ,
489+ ]
490+ }
491+ TypeSignatureClass :: Date => {
492+ vec ! [ DataType :: Date64 ]
493+ }
494+ TypeSignatureClass :: Time => {
495+ vec ! [ DataType :: Time64 ( TimeUnit :: Nanosecond ) ]
496+ }
497+ TypeSignatureClass :: Interval => {
498+ vec ! [ DataType :: Interval ( IntervalUnit :: DayTime ) ]
499+ }
500+ TypeSignatureClass :: Duration => {
501+ vec ! [ DataType :: Duration ( TimeUnit :: Nanosecond ) ]
502+ }
503+ TypeSignatureClass :: Integer => {
504+ vec ! [ DataType :: Int64 ]
505+ }
506+ }
507+ }
508+
460509fn get_data_types ( native_type : & NativeType ) -> Vec < DataType > {
461510 match native_type {
462511 NativeType :: Null => vec ! [ DataType :: Null ] ,
0 commit comments