From 2a5214e7a9298b396cfd5a111d035c5300e19916 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Wed, 21 Apr 2021 11:23:56 +0100 Subject: [PATCH] ARROW-12426: [Rust] Fix concatenation of arrow dictionaries --- arrow/src/array/transform/mod.rs | 120 ++++++++++++++++++++++--- arrow/src/array/transform/primitive.rs | 15 ++++ arrow/src/compute/kernels/concat.rs | 68 ++++++++++++++ 3 files changed, 189 insertions(+), 14 deletions(-) diff --git a/arrow/src/array/transform/mod.rs b/arrow/src/array/transform/mod.rs index 4dc7b56d1c37..e7ec41e97a75 100644 --- a/arrow/src/array/transform/mod.rs +++ b/arrow/src/array/transform/mod.rs @@ -15,7 +15,12 @@ // specific language governing permissions and limitations // under the License. -use crate::{buffer::MutableBuffer, datatypes::DataType, util::bit_util}; +use crate::{ + buffer::MutableBuffer, + datatypes::DataType, + error::{ArrowError, Result}, + util::bit_util, +}; use super::{ data::{into_buffers, new_buffers}, @@ -166,6 +171,65 @@ impl<'a> std::fmt::Debug for MutableArrayData<'a> { } } +/// Builds an extend that adds `offset` to the source primitive +/// Additionally validates that `max` fits into +/// the underlying primitive returning None if not +fn build_extend_dictionary( + array: &ArrayData, + offset: usize, + max: usize, +) -> Option { + use crate::datatypes::*; + use std::convert::TryInto; + + match array.data_type() { + DataType::Dictionary(child_data_type, _) => match child_data_type.as_ref() { + DataType::UInt8 => { + let _: u8 = max.try_into().ok()?; + let offset: u8 = offset.try_into().ok()?; + Some(primitive::build_extend_with_offset(array, offset)) + } + DataType::UInt16 => { + let _: u16 = max.try_into().ok()?; + let offset: u16 = offset.try_into().ok()?; + Some(primitive::build_extend_with_offset(array, offset)) + } + DataType::UInt32 => { + let _: u32 = max.try_into().ok()?; + let offset: u32 = offset.try_into().ok()?; + Some(primitive::build_extend_with_offset(array, offset)) + } + 
DataType::UInt64 => { + let _: u64 = max.try_into().ok()?; + let offset: u64 = offset.try_into().ok()?; + Some(primitive::build_extend_with_offset(array, offset)) + } + DataType::Int8 => { + let _: i8 = max.try_into().ok()?; + let offset: i8 = offset.try_into().ok()?; + Some(primitive::build_extend_with_offset(array, offset)) + } + DataType::Int16 => { + let _: i16 = max.try_into().ok()?; + let offset: i16 = offset.try_into().ok()?; + Some(primitive::build_extend_with_offset(array, offset)) + } + DataType::Int32 => { + let _: i32 = max.try_into().ok()?; + let offset: i32 = offset.try_into().ok()?; + Some(primitive::build_extend_with_offset(array, offset)) + } + DataType::Int64 => { + let _: i64 = max.try_into().ok()?; + let offset: i64 = offset.try_into().ok()?; + Some(primitive::build_extend_with_offset(array, offset)) + } + _ => unreachable!(), + }, + _ => None, + } +} + fn build_extend(array: &ArrayData) -> Extend { use crate::datatypes::*; match array.data_type() { @@ -199,17 +263,7 @@ fn build_extend(array: &ArrayData) -> Extend { } DataType::List(_) => list::build_extend::(array), DataType::LargeList(_) => list::build_extend::(array), - DataType::Dictionary(child_data_type, _) => match child_data_type.as_ref() { - DataType::UInt8 => primitive::build_extend::(array), - DataType::UInt16 => primitive::build_extend::(array), - DataType::UInt32 => primitive::build_extend::(array), - DataType::UInt64 => primitive::build_extend::(array), - DataType::Int8 => primitive::build_extend::(array), - DataType::Int16 => primitive::build_extend::(array), - DataType::Int32 => primitive::build_extend::(array), - DataType::Int64 => primitive::build_extend::(array), - _ => unreachable!(), - }, + DataType::Dictionary(_, _) => unreachable!("should use build_extend_dictionary"), DataType::Struct(_) => structure::build_extend(array), DataType::FixedSizeBinary(_) => fixed_binary::build_extend(array), DataType::Float16 => unreachable!(), @@ -339,7 +393,29 @@ impl<'a> 
MutableArrayData<'a> { }; let dictionary = match &data_type { - DataType::Dictionary(_, _) => Some(arrays[0].child_data()[0].clone()), + DataType::Dictionary(_, _) => match arrays.len() { + 0 => unreachable!(), + 1 => Some(arrays[0].child_data()[0].clone()), + _ => { + // Concat dictionaries together + let dictionaries: Vec<_> = + arrays.iter().map(|array| &array.child_data()[0]).collect(); + let lengths: Vec<_> = dictionaries + .iter() + .map(|dictionary| dictionary.len()) + .collect(); + let capacity = lengths.iter().sum(); + + let mut mutable = + MutableArrayData::new(dictionaries, false, capacity); + + for (i, len) in lengths.iter().enumerate() { + mutable.extend(i, 0, *len) + } + + Some(mutable.freeze()) + } + }, _ => None, }; @@ -353,7 +429,23 @@ impl<'a> MutableArrayData<'a> { let null_bytes = bit_util::ceil(capacity, 8); let null_buffer = MutableBuffer::from_len_zeroed(null_bytes); - let extend_values = arrays.iter().map(|array| build_extend(array)).collect(); + let extend_values = match &data_type { + DataType::Dictionary(_, _) => { + let mut next_offset = 0; + let extend_values: Result> = arrays + .iter() + .map(|array| { + let offset = next_offset; + next_offset += array.child_data()[0].len(); + build_extend_dictionary(array, offset, next_offset) + .ok_or(ArrowError::DictionaryKeyOverflowError) + }) + .collect(); + + extend_values.expect("MutableArrayData::new is infallible") + } + _ => arrays.iter().map(|array| build_extend(array)).collect(), + }; let data = _MutableArrayData { data_type: data_type.clone(), diff --git a/arrow/src/array/transform/primitive.rs b/arrow/src/array/transform/primitive.rs index 032bb4a87794..4c765c0c0d95 100644 --- a/arrow/src/array/transform/primitive.rs +++ b/arrow/src/array/transform/primitive.rs @@ -16,6 +16,7 @@ // under the License. 
use std::mem::size_of; +use std::ops::Add; use crate::{array::ArrayData, datatypes::ArrowNativeType}; @@ -32,6 +33,20 @@ pub(super) fn build_extend(array: &ArrayData) -> Extend { ) } +pub(super) fn build_extend_with_offset(array: &ArrayData, offset: T) -> Extend +where + T: ArrowNativeType + Add, +{ + let values = array.buffer::(0); + Box::new( + move |mutable: &mut _MutableArrayData, _, start: usize, len: usize| { + mutable + .buffer1 + .extend(values[start..start + len].iter().map(|x| *x + offset)); + }, + ) +} + pub(super) fn extend_nulls( mutable: &mut _MutableArrayData, len: usize, diff --git a/arrow/src/compute/kernels/concat.rs b/arrow/src/compute/kernels/concat.rs index 32880286a724..35ff183ed91c 100644 --- a/arrow/src/compute/kernels/concat.rs +++ b/arrow/src/compute/kernels/concat.rs @@ -384,4 +384,72 @@ mod tests { Ok(()) } + + fn collect_string_dictionary( + dictionary: &DictionaryArray, + ) -> Vec> { + let values = dictionary.values(); + let values = values.as_any().downcast_ref::().unwrap(); + + dictionary + .keys() + .iter() + .map(|key| key.map(|key| values.value(key as _).to_string())) + .collect() + } + + fn concat_dictionary( + input_1: DictionaryArray, + input_2: DictionaryArray, + ) -> Vec> { + let concat = concat(&[&input_1 as _, &input_2 as _]).unwrap(); + let concat = concat + .as_any() + .downcast_ref::>() + .unwrap(); + + collect_string_dictionary(concat) + } + + #[test] + fn test_string_dictionary_array() { + let input_1: DictionaryArray = + vec!["hello", "A", "B", "hello", "hello", "C"] + .into_iter() + .collect(); + let input_2: DictionaryArray = + vec!["hello", "E", "E", "hello", "F", "E"] + .into_iter() + .collect(); + + let expected: Vec<_> = vec![ + "hello", "A", "B", "hello", "hello", "C", "hello", "E", "E", "hello", "F", + "E", + ] + .into_iter() + .map(|x| Some(x.to_string())) + .collect(); + + let concat = concat_dictionary(input_1, input_2); + assert_eq!(concat, expected); + } + + #[test] + fn 
test_string_dictionary_array_nulls() { + let input_1: DictionaryArray = + vec![Some("foo"), Some("bar"), None, Some("fiz")] + .into_iter() + .collect(); + let input_2: DictionaryArray = vec![None].into_iter().collect(); + let expected = vec![ + Some("foo".to_string()), + Some("bar".to_string()), + None, + Some("fiz".to_string()), + None, + ]; + + let concat = concat_dictionary(input_1, input_2); + assert_eq!(concat, expected); + } }