Skip to content

Commit

Permalink
ARROW-12426: [Rust] Fix concatenation of arrow dictionaries
Browse files Browse the repository at this point in the history
  • Loading branch information
tustvold committed Apr 21, 2021
1 parent 5479e19 commit 2a5214e
Show file tree
Hide file tree
Showing 3 changed files with 189 additions and 14 deletions.
120 changes: 106 additions & 14 deletions arrow/src/array/transform/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@
// specific language governing permissions and limitations
// under the License.

use crate::{buffer::MutableBuffer, datatypes::DataType, util::bit_util};
use crate::{
buffer::MutableBuffer,
datatypes::DataType,
error::{ArrowError, Result},
util::bit_util,
};

use super::{
data::{into_buffers, new_buffers},
Expand Down Expand Up @@ -166,6 +171,65 @@ impl<'a> std::fmt::Debug for MutableArrayData<'a> {
}
}

/// Builds an `Extend` for a dictionary array that adds `offset` to every
/// key it copies out of `array`.
///
/// Before building, both `max` and `offset` are checked against the
/// dictionary's key type. `None` is returned when either of them cannot be
/// represented in that type, or when `array` is not a dictionary array at all.
///
/// # Panics
///
/// Panics (via `unreachable!`) if the dictionary key type is not one of the
/// eight integer types handled below.
fn build_extend_dictionary(
    array: &ArrayData,
    offset: usize,
    max: usize,
) -> Option<Extend> {
    use crate::datatypes::*;
    use std::convert::TryInto;

    // Verifies that `max` and then `offset` convert losslessly into `$native`
    // (returning `None` from the enclosing function otherwise) and builds the
    // offsetting extend for that key type.
    macro_rules! checked_extend {
        ($native:ty) => {{
            let _: $native = max.try_into().ok()?;
            let key_offset: $native = offset.try_into().ok()?;
            Some(primitive::build_extend_with_offset(array, key_offset))
        }};
    }

    // Anything that is not a dictionary cannot be handled here.
    let key_type = match array.data_type() {
        DataType::Dictionary(key_type, _) => key_type.as_ref(),
        _ => return None,
    };

    match key_type {
        DataType::UInt8 => checked_extend!(u8),
        DataType::UInt16 => checked_extend!(u16),
        DataType::UInt32 => checked_extend!(u32),
        DataType::UInt64 => checked_extend!(u64),
        DataType::Int8 => checked_extend!(i8),
        DataType::Int16 => checked_extend!(i16),
        DataType::Int32 => checked_extend!(i32),
        DataType::Int64 => checked_extend!(i64),
        _ => unreachable!(),
    }
}

fn build_extend(array: &ArrayData) -> Extend {
use crate::datatypes::*;
match array.data_type() {
Expand Down Expand Up @@ -199,17 +263,7 @@ fn build_extend(array: &ArrayData) -> Extend {
}
DataType::List(_) => list::build_extend::<i32>(array),
DataType::LargeList(_) => list::build_extend::<i64>(array),
DataType::Dictionary(child_data_type, _) => match child_data_type.as_ref() {
DataType::UInt8 => primitive::build_extend::<u8>(array),
DataType::UInt16 => primitive::build_extend::<u16>(array),
DataType::UInt32 => primitive::build_extend::<u32>(array),
DataType::UInt64 => primitive::build_extend::<u64>(array),
DataType::Int8 => primitive::build_extend::<i8>(array),
DataType::Int16 => primitive::build_extend::<i16>(array),
DataType::Int32 => primitive::build_extend::<i32>(array),
DataType::Int64 => primitive::build_extend::<i64>(array),
_ => unreachable!(),
},
DataType::Dictionary(_, _) => unreachable!("should use build_extend_dictionary"),
DataType::Struct(_) => structure::build_extend(array),
DataType::FixedSizeBinary(_) => fixed_binary::build_extend(array),
DataType::Float16 => unreachable!(),
Expand Down Expand Up @@ -339,7 +393,29 @@ impl<'a> MutableArrayData<'a> {
};

let dictionary = match &data_type {
DataType::Dictionary(_, _) => Some(arrays[0].child_data()[0].clone()),
DataType::Dictionary(_, _) => match arrays.len() {
0 => unreachable!(),
1 => Some(arrays[0].child_data()[0].clone()),
_ => {
// Concat dictionaries together
let dictionaries: Vec<_> =
arrays.iter().map(|array| &array.child_data()[0]).collect();
let lengths: Vec<_> = dictionaries
.iter()
.map(|dictionary| dictionary.len())
.collect();
let capacity = lengths.iter().sum();

let mut mutable =
MutableArrayData::new(dictionaries, false, capacity);

for (i, len) in lengths.iter().enumerate() {
mutable.extend(i, 0, *len)
}

Some(mutable.freeze())
}
},
_ => None,
};

Expand All @@ -353,7 +429,23 @@ impl<'a> MutableArrayData<'a> {
let null_bytes = bit_util::ceil(capacity, 8);
let null_buffer = MutableBuffer::from_len_zeroed(null_bytes);

let extend_values = arrays.iter().map(|array| build_extend(array)).collect();
let extend_values = match &data_type {
DataType::Dictionary(_, _) => {
let mut next_offset = 0;
let extend_values: Result<Vec<_>> = arrays
.iter()
.map(|array| {
let offset = next_offset;
next_offset += array.child_data()[0].len();
build_extend_dictionary(array, offset, next_offset)
.ok_or(ArrowError::DictionaryKeyOverflowError)
})
.collect();

extend_values.expect("MutableArrayData::new is infallible")
}
_ => arrays.iter().map(|array| build_extend(array)).collect(),
};

let data = _MutableArrayData {
data_type: data_type.clone(),
Expand Down
15 changes: 15 additions & 0 deletions arrow/src/array/transform/primitive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// under the License.

use std::mem::size_of;
use std::ops::Add;

use crate::{array::ArrayData, datatypes::ArrowNativeType};

Expand All @@ -32,6 +33,20 @@ pub(super) fn build_extend<T: ArrowNativeType>(array: &ArrayData) -> Extend {
)
}

/// Builds an `Extend` closure that copies values of type `T` from buffer 0 of
/// `array`, adding `offset` to each one on the way.
///
/// When invoked with `start` and `len`, the returned closure appends the
/// shifted values from `start..start + len` to `mutable.buffer1`. The
/// closure's second argument is ignored.
pub(super) fn build_extend_with_offset<T>(array: &ArrayData, offset: T) -> Extend
where
    T: ArrowNativeType + Add<Output = T>,
{
    let source = array.buffer::<T>(0);
    Box::new(
        move |mutable: &mut _MutableArrayData, _, start: usize, len: usize| {
            // Shift every value in the requested window by `offset`.
            let shifted = source[start..start + len]
                .iter()
                .map(|&value| value + offset);
            mutable.buffer1.extend(shifted);
        },
    )
}

pub(super) fn extend_nulls<T: ArrowNativeType>(
mutable: &mut _MutableArrayData,
len: usize,
Expand Down
68 changes: 68 additions & 0 deletions arrow/src/compute/kernels/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -384,4 +384,72 @@ mod tests {

Ok(())
}

/// Resolves every key of `dictionary` to an owned copy of the string it
/// refers to, keeping null keys as `None`.
///
/// Panics if the dictionary's values are not a `StringArray`.
fn collect_string_dictionary(
    dictionary: &DictionaryArray<Int32Type>,
) -> Vec<Option<String>> {
    let dict_values = dictionary.values();
    let strings = dict_values
        .as_any()
        .downcast_ref::<StringArray>()
        .unwrap();

    let mut resolved = Vec::new();
    for key in dictionary.keys().iter() {
        // A null key stays `None`; a valid key is looked up in the values.
        resolved.push(key.map(|index| strings.value(index as _).to_string()));
    }
    resolved
}

/// Concatenates the two dictionary arrays with `concat` and returns the
/// logical string values of the result, one entry per row.
///
/// Panics if concatenation fails or if the result is not a
/// `DictionaryArray<Int32Type>`.
fn concat_dictionary(
    input_1: DictionaryArray<Int32Type>,
    input_2: DictionaryArray<Int32Type>,
) -> Vec<Option<String>> {
    let merged = concat(&[&input_1 as _, &input_2 as _]).unwrap();
    let merged_dictionary = merged
        .as_any()
        .downcast_ref::<DictionaryArray<Int32Type>>()
        .unwrap();

    collect_string_dictionary(merged_dictionary)
}

/// Concatenating two string dictionaries must yield the rows of the first
/// input followed by the rows of the second, resolved to their strings.
#[test]
fn test_string_dictionary_array() {
    let first = vec!["hello", "A", "B", "hello", "hello", "C"];
    let second = vec!["hello", "E", "E", "hello", "F", "E"];

    // The expected output is simply both inputs back to back.
    let expected: Vec<Option<String>> = first
        .iter()
        .chain(&second)
        .map(|value| Some(String::from(*value)))
        .collect();

    let input_1: DictionaryArray<Int32Type> = first.into_iter().collect();
    let input_2: DictionaryArray<Int32Type> = second.into_iter().collect();

    let actual = concat_dictionary(input_1, input_2);
    assert_eq!(actual, expected);
}

/// Null rows must survive concatenation, including when one of the inputs
/// consists of a single null.
#[test]
fn test_string_dictionary_array_nulls() {
    let input_1: DictionaryArray<Int32Type> =
        vec![Some("foo"), Some("bar"), None, Some("fiz")]
            .into_iter()
            .collect();
    let input_2: DictionaryArray<Int32Type> = vec![None].into_iter().collect();

    let expected: Vec<Option<String>> =
        vec![Some("foo"), Some("bar"), None, Some("fiz"), None]
            .into_iter()
            .map(|value| value.map(String::from))
            .collect();

    let actual = concat_dictionary(input_1, input_2);
    assert_eq!(actual, expected);
}
}

0 comments on commit 2a5214e

Please sign in to comment.