@@ -401,25 +401,35 @@ fn get_valid_types(
401401 let mut fixed_size = array_coercion != Some ( & ListCoercion :: FixedSizedListToList ) ;
402402 let mut list_sizes = Vec :: with_capacity ( arguments. len ( ) ) ;
403403 let mut element_types = Vec :: with_capacity ( arguments. len ( ) ) ;
404+ let mut nested_item_nullability = Vec :: with_capacity ( arguments. len ( ) ) ;
404405 for ( argument, current_type) in arguments. iter ( ) . zip ( current_types. iter ( ) ) {
405406 match argument {
406- ArrayFunctionArgument :: Index | ArrayFunctionArgument :: String => ( ) ,
407+ ArrayFunctionArgument :: Index | ArrayFunctionArgument :: String => {
408+ nested_item_nullability. push ( None ) ;
409+ }
407410 ArrayFunctionArgument :: Element => {
408- element_types. push ( current_type. clone ( ) )
411+ element_types. push ( current_type. clone ( ) ) ;
412+ nested_item_nullability. push ( None ) ;
409413 }
410414 ArrayFunctionArgument :: Array => match current_type {
411- DataType :: Null => element_types. push ( DataType :: Null ) ,
415+ DataType :: Null => {
416+ element_types. push ( DataType :: Null ) ;
417+ nested_item_nullability. push ( None ) ;
418+ }
412419 DataType :: List ( field) => {
413420 element_types. push ( field. data_type ( ) . clone ( ) ) ;
421+ nested_item_nullability. push ( Some ( field. is_nullable ( ) ) ) ;
414422 fixed_size = false ;
415423 }
416424 DataType :: LargeList ( field) => {
417425 element_types. push ( field. data_type ( ) . clone ( ) ) ;
426+ nested_item_nullability. push ( Some ( field. is_nullable ( ) ) ) ;
418427 large_list = true ;
419428 fixed_size = false ;
420429 }
421430 DataType :: FixedSizeList ( field, size) => {
422431 element_types. push ( field. data_type ( ) . clone ( ) ) ;
432+ nested_item_nullability. push ( Some ( field. is_nullable ( ) ) ) ;
423433 list_sizes. push ( * size)
424434 }
425435 arg_type => {
@@ -429,33 +439,49 @@ fn get_valid_types(
429439 }
430440 }
431441
442+ debug_assert_eq ! ( nested_item_nullability. len( ) , arguments. len( ) ) ;
443+
432444 let Some ( element_type) = type_union_resolution ( & element_types) else {
433445 return Ok ( vec ! [ vec![ ] ] ) ;
434446 } ;
435447
436448 if !fixed_size {
437449 list_sizes. clear ( )
438- }
450+ } ;
439451
440452 let mut list_sizes = list_sizes. into_iter ( ) ;
441- let valid_types = arguments. iter ( ) . zip ( current_types. iter ( ) ) . map (
442- |( argument_type, current_type) | match argument_type {
443- ArrayFunctionArgument :: Index => DataType :: Int64 ,
444- ArrayFunctionArgument :: String => DataType :: Utf8 ,
445- ArrayFunctionArgument :: Element => element_type. clone ( ) ,
446- ArrayFunctionArgument :: Array => {
447- if current_type. is_null ( ) {
448- DataType :: Null
449- } else if large_list {
450- DataType :: new_large_list ( element_type. clone ( ) , true )
451- } else if let Some ( size) = list_sizes. next ( ) {
452- DataType :: new_fixed_size_list ( element_type. clone ( ) , size, true )
453- } else {
454- DataType :: new_list ( element_type. clone ( ) , true )
453+ let valid_types = arguments
454+ . iter ( )
455+ . zip ( current_types. iter ( ) )
456+ . zip ( nested_item_nullability. into_iter ( ) )
457+ . map ( |( ( argument_type, current_type) , is_nested_item_nullable) | {
458+ match argument_type {
459+ ArrayFunctionArgument :: Index => DataType :: Int64 ,
460+ ArrayFunctionArgument :: String => DataType :: Utf8 ,
461+ ArrayFunctionArgument :: Element => element_type. clone ( ) ,
462+ ArrayFunctionArgument :: Array => {
463+ if current_type. is_null ( ) {
464+ DataType :: Null
465+ } else if large_list {
466+ DataType :: new_large_list (
467+ element_type. clone ( ) ,
468+ is_nested_item_nullable. unwrap_or ( true ) ,
469+ )
470+ } else if let Some ( size) = list_sizes. next ( ) {
471+ DataType :: new_fixed_size_list (
472+ element_type. clone ( ) ,
473+ size,
474+ is_nested_item_nullable. unwrap_or ( true ) ,
475+ )
476+ } else {
477+ DataType :: new_list (
478+ element_type. clone ( ) ,
479+ is_nested_item_nullable. unwrap_or ( true ) ,
480+ )
481+ }
455482 }
456483 }
457- } ,
458- ) ;
484+ } ) ;
459485
460486 Ok ( vec ! [ valid_types. collect( ) ] )
461487 }
@@ -1343,6 +1369,18 @@ mod tests {
13431369 vec![ vec![ ] ]
13441370 ) ;
13451371
1372+ let data_types = vec ! [
1373+ DataType :: new_fixed_size_list( DataType :: Int64 , 3 , false ) ,
1374+ DataType :: new_list( DataType :: Int32 , false ) ,
1375+ ] ;
1376+ assert_eq ! (
1377+ get_valid_types( function, & signature. type_signature, & data_types) ?,
1378+ vec![ vec![
1379+ DataType :: new_list( DataType :: Int64 , false ) ,
1380+ DataType :: new_list( DataType :: Int64 , false ) ,
1381+ ] ]
1382+ ) ;
1383+
13461384 Ok ( ( ) )
13471385 }
13481386}
0 commit comments