Patch: arrow/src/compute/kernels/take.rs

Gives the struct branch of the take kernel a validity (null) buffer.
Before this change the result was built with StructArray::from(pairs),
i.e. from the taken child arrays only. After it, an `is_valid` buffer is
collected from the indices: a slot is valid only when its index is
non-null and `struct_.is_valid(index)` holds for the source array, and the
result is built with StructArray::from((fields, is_valid)).

Test changes: `create_test_struct` now takes the rows as
Vec<Option<(Option<bool>, Option<i32>)>> and assembles them with a
StructBuilder, so whole rows can be null. `test_take_struct` expects a
null count of 1 and `test_take_struct_with_null_indices` (renamed from
`test_take_struct_with_nulls`) expects 3.

NOTE(review): leading indentation and the position of blank context lines
were reconstructed to agree with the hunk line counts; confirm against the
upstream tree before applying.

diff --git a/arrow/src/compute/kernels/take.rs b/arrow/src/compute/kernels/take.rs
index 66bfd613fd80..4f1fe09e335c 100644
--- a/arrow/src/compute/kernels/take.rs
+++ b/arrow/src/compute/kernels/take.rs
@@ -231,9 +231,22 @@ where
                 .map(|a| take_impl(a.as_ref(), indices, Some(options.clone())))
                 .collect();
             let arrays = arrays?;
-            let pairs: Vec<(Field, ArrayRef)> =
+            let fields: Vec<(Field, ArrayRef)> =
                 fields.clone().into_iter().zip(arrays).collect();
-            Ok(Arc::new(StructArray::from(pairs)) as ArrayRef)
+
+            // Create the null bit buffer.
+            let is_valid: Buffer = indices
+                .iter()
+                .map(|index| {
+                    if let Some(index) = index {
+                        struct_.is_valid(ArrowNativeType::to_usize(&index).unwrap())
+                    } else {
+                        false
+                    }
+                })
+                .collect();
+
+            Ok(Arc::new(StructArray::from((fields, is_valid))) as ArrayRef)
         }
         DataType::Dictionary(key_type, _) => match key_type.as_ref() {
             DataType::Int8 => downcast_dict_take!(Int8Type, values, indices),
@@ -824,20 +837,34 @@ mod tests {
     }
 
     // create a simple struct for testing purposes
-    fn create_test_struct() -> StructArray {
-        let boolean_data = BooleanArray::from(vec![true, false, false, true])
-            .data()
-            .clone();
-        let int_data = Int32Array::from(vec![42, 28, 19, 31]).data().clone();
-        let mut field_types = vec![];
-        field_types.push(Field::new("a", DataType::Boolean, true));
-        field_types.push(Field::new("b", DataType::Int32, true));
-        let struct_array_data = ArrayData::builder(DataType::Struct(field_types))
-            .len(4)
-            .add_child_data(boolean_data)
-            .add_child_data(int_data)
-            .build();
-        StructArray::from(struct_array_data)
+    fn create_test_struct(
+        values: Vec<Option<(Option<bool>, Option<i32>)>>,
+    ) -> StructArray {
+        let mut struct_builder = StructBuilder::new(
+            vec![
+                Field::new("a", DataType::Boolean, true),
+                Field::new("b", DataType::Int32, true),
+            ],
+            vec![
+                Box::new(BooleanBuilder::new(values.len())),
+                Box::new(Int32Builder::new(values.len())),
+            ],
+        );
+
+        for value in values {
+            struct_builder
+                .field_builder::<BooleanBuilder>(0)
+                .unwrap()
+                .append_option(value.and_then(|v| v.0))
+                .unwrap();
+            struct_builder
+                .field_builder::<Int32Builder>(1)
+                .unwrap()
+                .append_option(value.and_then(|v| v.1))
+                .unwrap();
+            struct_builder.append(value.is_some()).unwrap();
+        }
+        struct_builder.finish()
     }
 
     #[test]
@@ -1491,61 +1518,59 @@ mod tests {
 
     #[test]
     fn test_take_struct() {
-        let array = create_test_struct();
-
-        let index = UInt32Array::from(vec![0, 3, 1, 0, 2]);
-        let a = take(&array, &index, None).unwrap();
-        let a: &StructArray = a.as_any().downcast_ref::<StructArray>().unwrap();
-        assert_eq!(index.len(), a.len());
-        assert_eq!(0, a.null_count());
+        let array = create_test_struct(vec![
+            Some((Some(true), Some(42))),
+            Some((Some(false), Some(28))),
+            Some((Some(false), Some(19))),
+            Some((Some(true), Some(31))),
+            None,
+        ]);
 
-        let expected_bool_data = BooleanArray::from(vec![true, true, false, true, false])
-            .data()
-            .clone();
-        let expected_int_data = Int32Array::from(vec![42, 31, 28, 42, 19]).data().clone();
-        let mut field_types = vec![];
-        field_types.push(Field::new("a", DataType::Boolean, true));
-        field_types.push(Field::new("b", DataType::Int32, true));
-        let struct_array_data = ArrayData::builder(DataType::Struct(field_types))
-            .len(5)
-            .add_child_data(expected_bool_data)
-            .add_child_data(expected_int_data)
-            .build();
-        let struct_array = StructArray::from(struct_array_data);
+        let index = UInt32Array::from(vec![0, 3, 1, 0, 2, 4]);
+        let actual = take(&array, &index, None).unwrap();
+        let actual: &StructArray = actual.as_any().downcast_ref::<StructArray>().unwrap();
+        assert_eq!(index.len(), actual.len());
+        assert_eq!(1, actual.null_count());
+
+        let expected = create_test_struct(vec![
+            Some((Some(true), Some(42))),
+            Some((Some(true), Some(31))),
+            Some((Some(false), Some(28))),
+            Some((Some(true), Some(42))),
+            Some((Some(false), Some(19))),
+            None,
+        ]);
 
-        assert_eq!(a, &struct_array);
+        assert_eq!(&expected, actual);
     }
 
     #[test]
-    fn test_take_struct_with_nulls() {
-        let array = create_test_struct();
+    fn test_take_struct_with_null_indices() {
+        let array = create_test_struct(vec![
+            Some((Some(true), Some(42))),
+            Some((Some(false), Some(28))),
+            Some((Some(false), Some(19))),
+            Some((Some(true), Some(31))),
+            None,
+        ]);
 
-        let index = UInt32Array::from(vec![None, Some(3), Some(1), None, Some(0)]);
-        let a = take(&array, &index, None).unwrap();
-        let a: &StructArray = a.as_any().downcast_ref::<StructArray>().unwrap();
-        assert_eq!(index.len(), a.len());
-        assert_eq!(0, a.null_count());
+        let index =
+            UInt32Array::from(vec![None, Some(3), Some(1), None, Some(0), Some(4)]);
+        let actual = take(&array, &index, None).unwrap();
+        let actual: &StructArray = actual.as_any().downcast_ref::<StructArray>().unwrap();
+        assert_eq!(index.len(), actual.len());
+        assert_eq!(3, actual.null_count()); // 2 because of indices, 1 because of struct array
 
-        let expected_bool_data =
-            BooleanArray::from(vec![None, Some(true), Some(false), None, Some(true)])
-                .data()
-                .clone();
-        let expected_int_data =
-            Int32Array::from(vec![None, Some(31), Some(28), None, Some(42)])
-                .data()
-                .clone();
+        let expected = create_test_struct(vec![
+            None,
+            Some((Some(true), Some(31))),
+            Some((Some(false), Some(28))),
+            None,
+            Some((Some(true), Some(42))),
+            None,
+        ]);
 
-        let mut field_types = vec![];
-        field_types.push(Field::new("a", DataType::Boolean, true));
-        field_types.push(Field::new("b", DataType::Int32, true));
-        let struct_array_data = ArrayData::builder(DataType::Struct(field_types))
-            .len(5)
-            // TODO: see https://issues.apache.org/jira/browse/ARROW-5408 for why count != 2
-            .add_child_data(expected_bool_data)
-            .add_child_data(expected_int_data)
-            .build();
-        let struct_array = StructArray::from(struct_array_data);
-        assert_eq!(a, &struct_array);
+        assert_eq!(&expected, actual);
     }
 
     #[test]