Skip to content

Commit a96dcc8

Browse files
authored
Fix array types coercion: preserve child element nullability for list types (#17306)
1 parent f99a6cf commit a96dcc8

File tree

1 file changed

+58
-20
lines changed

1 file changed

+58
-20
lines changed

datafusion/expr/src/type_coercion/functions.rs

Lines changed: 58 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -401,25 +401,35 @@ fn get_valid_types(
401401
let mut fixed_size = array_coercion != Some(&ListCoercion::FixedSizedListToList);
402402
let mut list_sizes = Vec::with_capacity(arguments.len());
403403
let mut element_types = Vec::with_capacity(arguments.len());
404+
let mut nested_item_nullability = Vec::with_capacity(arguments.len());
404405
for (argument, current_type) in arguments.iter().zip(current_types.iter()) {
405406
match argument {
406-
ArrayFunctionArgument::Index | ArrayFunctionArgument::String => (),
407+
ArrayFunctionArgument::Index | ArrayFunctionArgument::String => {
408+
nested_item_nullability.push(None);
409+
}
407410
ArrayFunctionArgument::Element => {
408-
element_types.push(current_type.clone())
411+
element_types.push(current_type.clone());
412+
nested_item_nullability.push(None);
409413
}
410414
ArrayFunctionArgument::Array => match current_type {
411-
DataType::Null => element_types.push(DataType::Null),
415+
DataType::Null => {
416+
element_types.push(DataType::Null);
417+
nested_item_nullability.push(None);
418+
}
412419
DataType::List(field) => {
413420
element_types.push(field.data_type().clone());
421+
nested_item_nullability.push(Some(field.is_nullable()));
414422
fixed_size = false;
415423
}
416424
DataType::LargeList(field) => {
417425
element_types.push(field.data_type().clone());
426+
nested_item_nullability.push(Some(field.is_nullable()));
418427
large_list = true;
419428
fixed_size = false;
420429
}
421430
DataType::FixedSizeList(field, size) => {
422431
element_types.push(field.data_type().clone());
432+
nested_item_nullability.push(Some(field.is_nullable()));
423433
list_sizes.push(*size)
424434
}
425435
arg_type => {
@@ -429,33 +439,49 @@ fn get_valid_types(
429439
}
430440
}
431441

442+
debug_assert_eq!(nested_item_nullability.len(), arguments.len());
443+
432444
let Some(element_type) = type_union_resolution(&element_types) else {
433445
return Ok(vec![vec![]]);
434446
};
435447

436448
if !fixed_size {
437449
list_sizes.clear()
438-
}
450+
};
439451

440452
let mut list_sizes = list_sizes.into_iter();
441-
let valid_types = arguments.iter().zip(current_types.iter()).map(
442-
|(argument_type, current_type)| match argument_type {
443-
ArrayFunctionArgument::Index => DataType::Int64,
444-
ArrayFunctionArgument::String => DataType::Utf8,
445-
ArrayFunctionArgument::Element => element_type.clone(),
446-
ArrayFunctionArgument::Array => {
447-
if current_type.is_null() {
448-
DataType::Null
449-
} else if large_list {
450-
DataType::new_large_list(element_type.clone(), true)
451-
} else if let Some(size) = list_sizes.next() {
452-
DataType::new_fixed_size_list(element_type.clone(), size, true)
453-
} else {
454-
DataType::new_list(element_type.clone(), true)
453+
let valid_types = arguments
454+
.iter()
455+
.zip(current_types.iter())
456+
.zip(nested_item_nullability)
457+
.map(|((argument_type, current_type), is_nested_item_nullable)| {
458+
match argument_type {
459+
ArrayFunctionArgument::Index => DataType::Int64,
460+
ArrayFunctionArgument::String => DataType::Utf8,
461+
ArrayFunctionArgument::Element => element_type.clone(),
462+
ArrayFunctionArgument::Array => {
463+
if current_type.is_null() {
464+
DataType::Null
465+
} else if large_list {
466+
DataType::new_large_list(
467+
element_type.clone(),
468+
is_nested_item_nullable.unwrap_or(true),
469+
)
470+
} else if let Some(size) = list_sizes.next() {
471+
DataType::new_fixed_size_list(
472+
element_type.clone(),
473+
size,
474+
is_nested_item_nullable.unwrap_or(true),
475+
)
476+
} else {
477+
DataType::new_list(
478+
element_type.clone(),
479+
is_nested_item_nullable.unwrap_or(true),
480+
)
481+
}
455482
}
456483
}
457-
},
458-
);
484+
});
459485

460486
Ok(vec![valid_types.collect()])
461487
}
@@ -1343,6 +1369,18 @@ mod tests {
13431369
vec![vec![]]
13441370
);
13451371

1372+
let data_types = vec![
1373+
DataType::new_fixed_size_list(DataType::Int64, 3, false),
1374+
DataType::new_list(DataType::Int32, false),
1375+
];
1376+
assert_eq!(
1377+
get_valid_types(function, &signature.type_signature, &data_types)?,
1378+
vec![vec![
1379+
DataType::new_list(DataType::Int64, false),
1380+
DataType::new_list(DataType::Int64, false),
1381+
]]
1382+
);
1383+
13461384
Ok(())
13471385
}
13481386
}

0 commit comments

Comments
 (0)