Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve performance of DictionaryArray::try_new()  #1435

Merged
merged 5 commits into from
Mar 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 23 additions & 7 deletions arrow/src/array/array_dictionary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ use super::{
make_array, Array, ArrayData, ArrayRef, PrimitiveArray, PrimitiveBuilder,
StringArray, StringBuilder, StringDictionaryBuilder,
};
use crate::datatypes::ArrowNativeType;
use crate::datatypes::{ArrowDictionaryKeyType, ArrowPrimitiveType, DataType};
use crate::datatypes::{
ArrowDictionaryKeyType, ArrowNativeType, ArrowPrimitiveType, DataType,
};
use crate::error::Result;

/// A dictionary array where each element is a single value indexed by an integer key.
Expand Down Expand Up @@ -96,8 +97,8 @@ impl<'a, K: ArrowPrimitiveType> DictionaryArray<K> {
Box::new(values.data_type().clone()),
);

// Note: This does more work than necessary by rebuilding /
// revalidating all the data
// Note: This uses ArrayDataBuilder::build_unchecked and afterwards
// calls the new function, which only validates that the keys are in bounds.
let mut data = ArrayData::builder(dict_data_type)
.len(keys.len())
.add_buffer(keys.data().buffers()[0].clone())
Expand All @@ -112,7 +113,14 @@ impl<'a, K: ArrowPrimitiveType> DictionaryArray<K> {
_ => data = data.null_count(0),
}

Ok(data.build()?.into())
// Safety: `validate` ensures key type is correct, and
// `validate_dictionary_offset` ensures all offsets are within range
let array = unsafe { data.build_unchecked() };
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
let array = unsafe { data.build_unchecked() };
// Safety: `validate` ensures key type is correct, and
// `validate_dictionary_offset` ensures all offsets are within range
let array = unsafe { data.build_unchecked() };


array.validate()?;
array.validate_dictionary_offset()?;

Ok(array.into())
}

/// Return an array view of the keys of this dictionary as a PrimitiveArray.
Expand Down Expand Up @@ -308,8 +316,8 @@ impl<T: ArrowPrimitiveType> fmt::Debug for DictionaryArray<T> {
mod tests {
use super::*;

use crate::array::Int8Array;
use crate::datatypes::Int16Type;
use crate::array::{Float32Array, Int8Array};
use crate::datatypes::{Float32Type, Int16Type};
use crate::{
array::Int16DictionaryArray, array::PrimitiveDictionaryBuilder,
datatypes::DataType,
Expand Down Expand Up @@ -574,4 +582,12 @@ mod tests {
let keys: Int32Array = [Some(-100)].into_iter().collect();
DictionaryArray::<Int32Type>::try_new(&keys, &values).unwrap();
}

// Checks that `DictionaryArray::try_new` rejects a non-integer key type:
// building a dictionary keyed by `Float32` is expected to fail with the
// "Dictionary key type must be integer" error.
#[test]
#[should_panic(expected = "Dictionary key type must be integer, but was Float32")]
fn test_try_wrong_dictionary_key_type() {
let values: StringArray = [Some("foo"), Some("bar")].into_iter().collect();
let keys: Float32Array = [Some(0_f32), None, Some(3_f32)].into_iter().collect();
// `unwrap` turns the returned `Err` into the panic that `should_panic` matches.
DictionaryArray::<Float32Type>::try_new(&keys, &values).unwrap();
}
}
80 changes: 40 additions & 40 deletions arrow/src/array/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,7 @@ impl ArrayData {
// At the moment, constructing a DictionaryArray will also check this
if !DataType::is_dictionary_key_type(key_type) {
return Err(ArrowError::InvalidArgumentError(format!(
"Dictionary values must be integer, but was {}",
"Dictionary key type must be integer, but was {}",
key_type
)));
}
Expand Down Expand Up @@ -926,8 +926,8 @@ impl ArrayData {
///
/// 1. Null count is correct
/// 2. All offsets are valid
/// 3. All String data is valid UTF-8
/// 3. All dictionary offsets are valid
/// 3. All String data is valid UTF-8
/// 4. All dictionary offsets are valid
///
/// Does not (yet) check
/// 1. Union type_ids are valid see [#85](https://github.com/apache/arrow-rs/issues/85)
Expand All @@ -949,68 +949,68 @@ impl ArrayData {
)));
}

self.validate_dictionary_offset()?;

// validate all children recursively
self.child_data
.iter()
.enumerate()
.try_for_each(|(i, child_data)| {
child_data.validate_full().map_err(|e| {
ArrowError::InvalidArgumentError(format!(
"{} child #{} invalid: {}",
self.data_type, i, e
))
})
})?;

Ok(())
}

pub fn validate_dictionary_offset(&self) -> Result<()> {
match &self.data_type {
DataType::Utf8 => {
self.validate_utf8::<i32>()?;
}
DataType::LargeUtf8 => {
self.validate_utf8::<i64>()?;
}
DataType::Binary => {
self.validate_offsets_full::<i32>(self.buffers[1].len())?;
}
DataType::Utf8 => self.validate_utf8::<i32>(),
DataType::LargeUtf8 => self.validate_utf8::<i64>(),
DataType::Binary => self.validate_offsets_full::<i32>(self.buffers[1].len()),
DataType::LargeBinary => {
self.validate_offsets_full::<i64>(self.buffers[1].len())?;
self.validate_offsets_full::<i64>(self.buffers[1].len())
}
DataType::List(_) | DataType::Map(_, _) => {
let child = &self.child_data[0];
self.validate_offsets_full::<i32>(child.len + child.offset)?;
self.validate_offsets_full::<i32>(child.len + child.offset)
}
DataType::LargeList(_) => {
let child = &self.child_data[0];
self.validate_offsets_full::<i64>(child.len + child.offset)?;
self.validate_offsets_full::<i64>(child.len + child.offset)
}
DataType::Union(_, _) => {
// Validate Union Array as part of implementing new Union semantics
// See comments in `ArrayData::validate()`
// https://github.com/apache/arrow-rs/issues/85
//
// TODO file follow on ticket for full union validation
Ok(())
}
DataType::Dictionary(key_type, _value_type) => {
let dictionary_length: i64 = self.child_data[0].len.try_into().unwrap();
let max_value = dictionary_length - 1;
match key_type.as_ref() {
DataType::UInt8 => self.check_bounds::<u8>(max_value)?,
DataType::UInt16 => self.check_bounds::<u16>(max_value)?,
DataType::UInt32 => self.check_bounds::<u32>(max_value)?,
DataType::UInt64 => self.check_bounds::<u64>(max_value)?,
DataType::Int8 => self.check_bounds::<i8>(max_value)?,
DataType::Int16 => self.check_bounds::<i16>(max_value)?,
DataType::Int32 => self.check_bounds::<i32>(max_value)?,
DataType::Int64 => self.check_bounds::<i64>(max_value)?,
DataType::UInt8 => self.check_bounds::<u8>(max_value),
DataType::UInt16 => self.check_bounds::<u16>(max_value),
DataType::UInt32 => self.check_bounds::<u32>(max_value),
DataType::UInt64 => self.check_bounds::<u64>(max_value),
DataType::Int8 => self.check_bounds::<i8>(max_value),
DataType::Int16 => self.check_bounds::<i16>(max_value),
DataType::Int32 => self.check_bounds::<i32>(max_value),
DataType::Int64 => self.check_bounds::<i64>(max_value),
_ => unreachable!(),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think a possible solution would be to extract the dictionary validation logic out of ArrayData::validate_full into a separate function. DictionaryArray::try_new could then use ArrayDataBuilder::build_unchecked and afterwards call the new function which only validates that the keys are in bounds.

I think "the dictionary validation logic" refers only to the logic inside the `DataType::Dictionary` match arm.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The validation logic for the other data types is not related to dictionary offsets.

}
}
_ => {
// No extra validation check required for other types
Ok(())
}
};

// validate all children recursively
self.child_data
.iter()
.enumerate()
.try_for_each(|(i, child_data)| {
child_data.validate_full().map_err(|e| {
ArrowError::InvalidArgumentError(format!(
"{} child #{} invalid: {}",
self.data_type, i, e
))
})
})?;

Ok(())
}
}

/// Calls the `validate(item_index, range)` function for each of
Expand Down Expand Up @@ -1736,7 +1736,7 @@ mod tests {

// Test creating a dictionary with a non integer type
#[test]
#[should_panic(expected = "Dictionary values must be integer, but was Utf8")]
#[should_panic(expected = "Dictionary key type must be integer, but was Utf8")]
fn test_non_int_dictionary() {
let i32_buffer = Buffer::from_slice_ref(&[0i32, 2i32]);
let data_type =
Expand Down