From 78a2a63bacced679da51843ae6135f5034760764 Mon Sep 17 00:00:00 2001 From: Jorge Leitao Date: Tue, 5 Jul 2022 08:20:06 -0700 Subject: [PATCH] Improved dictionary (#1137) --- src/array/dictionary/ffi.rs | 16 +- src/array/dictionary/fmt.rs | 2 +- src/array/dictionary/iterator.rs | 17 +- src/array/dictionary/mod.rs | 235 +++++++++++++++--- src/array/dictionary/mutable.rs | 37 ++- src/array/equal/dictionary.rs | 2 +- src/array/equal/mod.rs | 12 + src/array/growable/dictionary.rs | 34 ++- src/array/ord.rs | 5 +- src/array/specification.rs | 16 ++ src/compute/arithmetics/mod.rs | 7 +- src/compute/cast/dictionary_to.rs | 46 +++- src/compute/sort/utf8.rs | 11 +- src/compute/take/dict.rs | 10 +- src/io/avro/read/nested.rs | 24 +- src/io/csv/write/serialize.rs | 6 +- src/io/ipc/read/array/dictionary.rs | 4 +- src/io/ipc/read/deserialize.rs | 1 + src/io/json/read/deserialize.rs | 12 +- src/io/json_integration/read/array.rs | 6 +- .../read/deserialize/binary/dictionary.rs | 7 +- src/io/parquet/read/deserialize/dictionary.rs | 55 +++- .../fixed_size_binary/dictionary.rs | 7 +- .../read/deserialize/primitive/dictionary.rs | 11 +- src/io/parquet/read/statistics/dictionary.rs | 2 +- src/io/parquet/write/dictionary.rs | 26 +- tests/it/array/dictionary/mod.rs | 152 +++++++++++ tests/it/array/equal/dictionary.rs | 2 +- tests/it/array/growable/dictionary.rs | 36 +-- tests/it/array/growable/mod.rs | 9 +- tests/it/compute/arithmetics/mod.rs | 14 +- tests/it/io/avro/read.rs | 3 +- tests/it/io/csv/write.rs | 6 +- tests/it/io/parquet/mod.rs | 56 ++--- tests/it/io/parquet/read_indexes.rs | 8 +- tests/it/io/print.rs | 2 +- 36 files changed, 668 insertions(+), 231 deletions(-) diff --git a/src/array/dictionary/ffi.rs b/src/array/dictionary/ffi.rs index 6bc467630a2..1e06f81e6ed 100644 --- a/src/array/dictionary/ffi.rs +++ b/src/array/dictionary/ffi.rs @@ -1,6 +1,6 @@ use crate::{ array::{FromFfi, PrimitiveArray, ToFfi}, - error::Result, + error::Error, ffi, }; @@ -25,16 +25,20 @@ unsafe impl ToFfi for DictionaryArray { } impl FromFfi for DictionaryArray { - unsafe fn try_from_ffi(array: A) -> Result { + unsafe fn try_from_ffi(array: A) -> Result { // keys: similar to PrimitiveArray, but the datatype is the inner one let validity = unsafe { array.validity() }?; let values = unsafe { array.buffer::(1) }?; - let data_type = K::PRIMITIVE.into(); - let keys = PrimitiveArray::::try_new(data_type, values, validity)?; - let values = array.dictionary()?.unwrap(); + let data_type = array.data_type().clone(); + + let keys = PrimitiveArray::::try_new(K::PRIMITIVE.into(), values, validity)?; + let values = array + .dictionary()? 
+ .ok_or_else(|| Error::oos("Dictionary Array must contain a dictionary in ffi"))?; let values = ffi::try_from(values)?; - Ok(DictionaryArray::::from_data(keys, values)) + // the assumption of this trait + DictionaryArray::::try_new_unchecked(data_type, keys, values) } } diff --git a/src/array/dictionary/fmt.rs b/src/array/dictionary/fmt.rs index e3e3b475d8f..f5f76624001 100644 --- a/src/array/dictionary/fmt.rs +++ b/src/array/dictionary/fmt.rs @@ -15,7 +15,7 @@ pub fn write_value( let values = array.values(); if keys.is_valid(index) { - let key = keys.value(index).to_usize().unwrap(); + let key = array.key_value(index); get_display(values.as_ref(), null)(f, key) } else { write!(f, "{}", null) diff --git a/src/array/dictionary/iterator.rs b/src/array/dictionary/iterator.rs index 21b1cc7bad4..0249a1b589e 100644 --- a/src/array/dictionary/iterator.rs +++ b/src/array/dictionary/iterator.rs @@ -1,4 +1,4 @@ -use crate::bitmap::utils::{zip_validity, ZipValidity}; +use crate::bitmap::utils::ZipValidity; use crate::scalar::Scalar; use crate::trusted_len::TrustedLen; @@ -66,18 +66,3 @@ impl<'a, K: DictionaryKey> IntoIterator for &'a DictionaryArray { self.iter() } } - -impl<'a, K: DictionaryKey> DictionaryArray { - /// Returns an iterator of `Option>` - pub fn iter(&'a self) -> ZipIter<'a, K> { - zip_validity( - DictionaryValuesIter::new(self), - self.keys.validity().as_ref().map(|x| x.iter()), - ) - } - - /// Returns an iterator of `Box` - pub fn values_iter(&'a self) -> ValuesIter<'a, K> { - DictionaryValuesIter::new(self) - } -} diff --git a/src/array/dictionary/mod.rs b/src/array/dictionary/mod.rs index 0c9d9267375..0d5b4502eb2 100644 --- a/src/array/dictionary/mod.rs +++ b/src/array/dictionary/mod.rs @@ -1,7 +1,14 @@ +use std::hint::unreachable_unchecked; + use crate::{ - bitmap::Bitmap, + bitmap::{ + utils::{zip_validity, ZipValidity}, + Bitmap, + }, datatypes::{DataType, IntegerType}, + error::Error, scalar::{new_scalar, Scalar}, + trusted_len::TrustedLen, types::NativeType, }; @@ -13,12 +20,23 @@ pub use iterator::*; pub use mutable::*; use super::{new_empty_array, primitive::PrimitiveArray, Array}; -use crate::scalar::NullScalar; +use super::{new_null_array, specification::check_indexes}; /// Trait denoting [`NativeType`]s that can be used as keys of a dictionary. -pub trait DictionaryKey: NativeType + num_traits::NumCast + num_traits::FromPrimitive { +pub trait DictionaryKey: NativeType + TryInto + TryFrom { /// The corresponding [`IntegerType`] of this key const KEY_TYPE: IntegerType; + + /// Represents this key as a `usize`. + /// # Safety + /// The caller _must_ have checked that the value can be casted to `usize`. + #[inline] + unsafe fn as_usize(self) -> usize { + match self.try_into() { + Ok(v) => v, + Err(_) => unreachable_unchecked(), + } + } } impl DictionaryKey for i8 { @@ -46,8 +64,13 @@ impl DictionaryKey for u64 { const KEY_TYPE: IntegerType = IntegerType::UInt64; } -/// An [`Array`] whose values are encoded by keys. This [`Array`] is useful when the cardinality of +/// An [`Array`] whose values are stored as indices. This [`Array`] is useful when the cardinality of /// values is low compared to the length of the [`Array`]. +/// +/// # Safety +/// This struct guarantees that each item of [`DictionaryArray::keys`] is castable to `usize` and +/// its value is smaller than [`DictionaryArray::values`]`.len()`. 
In other words, you can safely +/// use `unchecked` calls to retrive the values #[derive(Clone)] pub struct DictionaryArray { data_type: DataType, @@ -55,38 +78,152 @@ pub struct DictionaryArray { values: Box, } +fn check_data_type( + key_type: IntegerType, + data_type: &DataType, + values_data_type: &DataType, +) -> Result<(), Error> { + if let DataType::Dictionary(key, value, _) = data_type.to_logical_type() { + if *key != key_type { + return Err(Error::oos( + "DictionaryArray must be initialized with a DataType::Dictionary whose integer is compatible to its keys", + )); + } + if value.as_ref().to_logical_type() != values_data_type.to_logical_type() { + return Err(Error::oos( + "DictionaryArray must be initialized with a DataType::Dictionary whose value is equal to its values", + )); + } + } else { + return Err(Error::oos( + "DictionaryArray must be initialized with logical DataType::Dictionary", + )); + } + Ok(()) +} + impl DictionaryArray { + /// Returns a new [`DictionaryArray`]. + /// # Implementation + /// This function is `O(N)` where `N` is the length of keys + /// # Errors + /// This function errors iff + /// * the `data_type`'s logical type is not a `DictionaryArray` + /// * the `data_type`'s keys is not compatible with `keys` + /// * the `data_type`'s values's data_type is not equal with `values.data_type()` + /// * any of the keys's values is not represented in `usize` or is `>= values.len()` + pub fn try_new( + data_type: DataType, + keys: PrimitiveArray, + values: Box, + ) -> Result { + check_data_type(K::KEY_TYPE, &data_type, values.data_type())?; + + check_indexes(keys.values(), values.len())?; + + Ok(Self { + data_type, + keys, + values, + }) + } + + /// Returns a new [`DictionaryArray`]. + /// # Implementation + /// This function is `O(N)` where `N` is the length of keys + /// # Errors + /// This function errors iff + /// * any of the keys's values is not represented in `usize` or is `>= values.len()` + pub fn try_from_keys(keys: PrimitiveArray, values: Box) -> Result { + let data_type = Self::default_data_type(values.data_type().clone()); + Self::try_new(data_type, keys, values) + } + + /// Returns a new [`DictionaryArray`]. + /// # Errors + /// This function errors iff + /// * the `data_type`'s logical type is not a `DictionaryArray` + /// * the `data_type`'s keys is not compatible with `keys` + /// * the `data_type`'s values's data_type is not equal with `values.data_type()` + /// # Safety + /// The caller must ensure that every keys's values is represented in `usize` and is `< values.len()` + pub unsafe fn try_new_unchecked( + data_type: DataType, + keys: PrimitiveArray, + values: Box, + ) -> Result { + check_data_type(K::KEY_TYPE, &data_type, values.data_type())?; + + Ok(Self { + data_type, + keys, + values, + }) + } + /// Returns a new empty [`DictionaryArray`]. 
pub fn new_empty(data_type: DataType) -> Self { - let values = Self::get_child(&data_type); + let values = Self::try_get_child(&data_type).unwrap(); let values = new_empty_array(values.clone()); - let data_type = K::PRIMITIVE.into(); - Self::from_data(PrimitiveArray::::new_empty(data_type), values) + Self::try_new( + data_type, + PrimitiveArray::::new_empty(K::PRIMITIVE.into()), + values, + ) + .unwrap() } /// Returns an [`DictionaryArray`] whose all elements are null #[inline] pub fn new_null(data_type: DataType, length: usize) -> Self { - let values = Self::get_child(&data_type); - let data_type = K::PRIMITIVE.into(); - Self::from_data( - PrimitiveArray::::new_null(data_type, length), - new_empty_array(values.clone()), + let values = Self::try_get_child(&data_type).unwrap(); + let values = new_null_array(values.clone(), 1); + Self::try_new( + data_type, + PrimitiveArray::::new_null(K::PRIMITIVE.into(), length), + values, ) + .unwrap() } - /// The canonical method to create a new [`DictionaryArray`]. - pub fn from_data(keys: PrimitiveArray, values: Box) -> Self { - let data_type = - DataType::Dictionary(K::KEY_TYPE, Box::new(values.data_type().clone()), false); + /// Returns an iterator of [`Option>`]. + /// # Implementation + /// This function will allocate a new [`Scalar`] per item and is usually not performant. + /// Consider calling `keys_iter` and `values`, downcasting `values`, and iterating over that. + pub fn iter(&self) -> ZipValidity, DictionaryValuesIter> { + zip_validity( + DictionaryValuesIter::new(self), + self.keys.validity().as_ref().map(|x| x.iter()), + ) + } - Self { - data_type, - keys, - values, + /// Returns an iterator of [`Box`] + /// # Implementation + /// This function will allocate a new [`Scalar`] per item and is usually not performant. + /// Consider calling `keys_iter` and `values`, downcasting `values`, and iterating over that. + pub fn values_iter(&self) -> DictionaryValuesIter { + DictionaryValuesIter::new(self) + } + + /// Returns the [`DataType`] of this [`DictionaryArray`] + #[inline] + pub fn data_type(&self) -> &DataType { + &self.data_type + } + + /// Returns whether the values of this [`DictionaryArray`] are ordered + #[inline] + pub fn is_ordered(&self) -> bool { + match self.data_type.to_logical_type() { + DataType::Dictionary(_, _, is_ordered) => *is_ordered, + _ => unreachable!(), } } + pub(crate) fn default_data_type(values_datatype: DataType) -> DataType { + DataType::Dictionary(K::KEY_TYPE, Box::new(values_datatype), false) + } + /// Creates a new [`DictionaryArray`] by slicing the existing [`DictionaryArray`]. /// # Panics /// iff `offset + length > self.len()`. 
@@ -124,10 +261,7 @@ impl DictionaryArray { pub fn set_validity(&mut self, validity: Option) { self.keys.set_validity(validity); } -} -// accessors -impl DictionaryArray { /// Returns the length of this array #[inline] pub fn len(&self) -> usize { @@ -147,6 +281,29 @@ impl DictionaryArray { &self.keys } + /// Returns an iterator of the keys' values of the [`DictionaryArray`] as `usize` + #[inline] + pub fn keys_values_iter(&self) -> impl TrustedLen + Clone + '_ { + // safety - invariant of the struct + self.keys.values_iter().map(|x| unsafe { x.as_usize() }) + } + + /// Returns an iterator of the keys' of the [`DictionaryArray`] as `usize` + #[inline] + pub fn keys_iter(&self) -> impl TrustedLen> + Clone + '_ { + // safety - invariant of the struct + self.keys.iter().map(|x| x.map(|x| unsafe { x.as_usize() })) + } + + /// Returns the keys' value of the [`DictionaryArray`] as `usize` + /// # Panics + /// This function panics iff `index >= self.len()` + #[inline] + pub fn key_value(&self, index: usize) -> usize { + // safety - invariant of the struct + unsafe { self.keys.values()[index].as_usize() } + } + /// Returns the values of the [`DictionaryArray`]. #[inline] pub fn values(&self) -> &Box { @@ -154,14 +311,16 @@ impl DictionaryArray { } /// Returns the value of the [`DictionaryArray`] at position `i`. + /// # Implementation + /// This function will allocate a new [`Scalar`] and is usually not performant. + /// Consider calling `keys` and `values`, downcasting `values`, and iterating over that. + /// # Panic + /// This function panics iff `index >= self.len()` #[inline] pub fn value(&self, index: usize) -> Box { - if self.keys.is_null(index) { - Box::new(NullScalar::new()) - } else { - let index = self.keys.value(index).to_usize().unwrap(); - new_scalar(self.values.as_ref(), index) - } + // safety - invariant of this struct + let index = unsafe { self.keys.value(index).as_usize() }; + new_scalar(self.values.as_ref(), index) } /// Boxes self into a [`Box`]. 
@@ -173,15 +332,16 @@ impl DictionaryArray { pub fn arced(self) -> std::sync::Arc { std::sync::Arc::new(self) } -} -impl DictionaryArray { - pub(crate) fn get_child(data_type: &DataType) -> &DataType { - match data_type { + pub(crate) fn try_get_child(data_type: &DataType) -> Result<&DataType, Error> { + Ok(match data_type.to_logical_type() { DataType::Dictionary(_, values, _) => values.as_ref(), - DataType::Extension(_, inner, _) => Self::get_child(inner), - _ => panic!("DictionaryArray must be initialized with DataType::Dictionary"), - } + _ => { + return Err(Error::oos( + "Dictionaries must be initialized with DataType::Dictionary", + )) + } + }) } } @@ -213,12 +373,15 @@ impl Array for DictionaryArray { fn slice(&self, offset: usize, length: usize) -> Box { Box::new(self.slice(offset, length)) } + unsafe fn slice_unchecked(&self, offset: usize, length: usize) -> Box { Box::new(self.slice_unchecked(offset, length)) } + fn with_validity(&self, validity: Option) -> Box { Box::new(self.clone().with_validity(validity)) } + fn to_boxed(&self) -> Box { Box::new(self.clone()) } diff --git a/src/array/dictionary/mutable.rs b/src/array/dictionary/mutable.rs index ba0d393f973..c5ca7efb310 100644 --- a/src/array/dictionary/mutable.rs +++ b/src/array/dictionary/mutable.rs @@ -32,12 +32,21 @@ pub struct MutableDictionaryArray { data_type: DataType, keys: MutablePrimitiveArray, map: HashedMap, + // invariant: `keys.len() <= values.len()` values: M, } impl From> for DictionaryArray { fn from(mut other: MutableDictionaryArray) -> Self { - DictionaryArray::::from_data(other.keys.into(), other.values.as_box()) + // Safety - the invariant of this struct ensures that this is up-held + unsafe { + DictionaryArray::::try_new_unchecked( + other.data_type, + other.keys.into(), + other.values.as_box(), + ) + .unwrap() + } } } @@ -91,7 +100,7 @@ impl MutableDictionaryArray { Ok(false) } None => { - let key = K::from_usize(self.map.len()).ok_or(Error::Overflow)?; + let key = K::try_from(self.map.len()).map_err(|_| Error::Overflow)?; self.map.insert(hash, key); self.keys.push(Some(key)); Ok(true) @@ -105,7 +114,7 @@ impl MutableDictionaryArray { } /// returns a mutable reference to the inner values. 
- pub fn mut_values(&mut self) -> &mut M { + fn mut_values(&mut self) -> &mut M { &mut self.values } @@ -141,6 +150,18 @@ impl MutableDictionaryArray { pub fn keys(&self) -> &MutablePrimitiveArray { &self.keys } + + fn take_into(&mut self) -> DictionaryArray { + // Safety - the invariant of this struct ensures that this is up-held + unsafe { + DictionaryArray::::try_new( + self.data_type.clone(), + std::mem::take(&mut self.keys).into(), + self.values.as_box(), + ) + .unwrap() + } + } } impl MutableArray for MutableDictionaryArray { @@ -153,17 +174,11 @@ impl MutableArray for MutableDictio } fn as_box(&mut self) -> Box { - Box::new(DictionaryArray::::from_data( - std::mem::take(&mut self.keys).into(), - self.values.as_box(), - )) + Box::new(self.take_into()) } fn as_arc(&mut self) -> Arc { - Arc::new(DictionaryArray::::from_data( - std::mem::take(&mut self.keys).into(), - self.values.as_box(), - )) + Arc::new(self.take_into()) } fn data_type(&self) -> &DataType { diff --git a/src/array/equal/dictionary.rs b/src/array/equal/dictionary.rs index 8c879ff8370..d65634095fb 100644 --- a/src/array/equal/dictionary.rs +++ b/src/array/equal/dictionary.rs @@ -1,4 +1,4 @@ -use crate::array::{Array, DictionaryArray, DictionaryKey}; +use crate::array::{DictionaryArray, DictionaryKey}; pub(super) fn equal(lhs: &DictionaryArray, rhs: &DictionaryArray) -> bool { if !(lhs.data_type() == rhs.data_type() && lhs.len() == rhs.len()) { diff --git a/src/array/equal/mod.rs b/src/array/equal/mod.rs index 737edb35e76..aa2ea602882 100644 --- a/src/array/equal/mod.rs +++ b/src/array/equal/mod.rs @@ -87,6 +87,12 @@ impl PartialEq<&dyn Array> for Utf8Array { } } +impl PartialEq> for &dyn Array { + fn eq(&self, other: &Utf8Array) -> bool { + equal(*self, other) + } +} + impl PartialEq> for BinaryArray { fn eq(&self, other: &Self) -> bool { binary::equal(self, other) @@ -99,6 +105,12 @@ impl PartialEq<&dyn Array> for BinaryArray { } } +impl PartialEq> for &dyn Array { + fn eq(&self, other: &BinaryArray) -> bool { + equal(*self, other) + } +} + impl PartialEq for FixedSizeBinaryArray { fn eq(&self, other: &Self) -> bool { fixed_size_binary::equal(self, other) diff --git a/src/array/growable/dictionary.rs b/src/array/growable/dictionary.rs index 908a1394fa9..0b3c432525a 100644 --- a/src/array/growable/dictionary.rs +++ b/src/array/growable/dictionary.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use crate::{ array::{Array, DictionaryArray, DictionaryKey, PrimitiveArray}, bitmap::MutableBitmap, + datatypes::DataType, }; use super::{ @@ -16,6 +17,7 @@ use super::{ /// This growable does not perform collision checks and instead concatenates /// the values of each [`DictionaryArray`] one after the other. pub struct GrowableDictionary<'a, K: DictionaryKey> { + data_type: DataType, keys_values: Vec<&'a [K]>, key_values: Vec, key_validity: MutableBitmap, @@ -44,6 +46,8 @@ impl<'a, T: DictionaryKey> GrowableDictionary<'a, T> { /// # Panics /// If `arrays` is empty. pub fn new(arrays: &[&'a DictionaryArray], mut use_validity: bool, capacity: usize) -> Self { + let data_type = arrays[0].data_type().clone(); + // if any of the arrays has nulls, insertions from any array requires setting bits // as there is at least one array with nulls. 
if arrays.iter().any(|array| array.null_count() > 0) { @@ -69,6 +73,7 @@ impl<'a, T: DictionaryKey> GrowableDictionary<'a, T> { let (values, offsets) = concatenate_values(&arrays_keys, &arrays_values, capacity); Self { + data_type, offsets, values, keys_values, @@ -83,10 +88,11 @@ impl<'a, T: DictionaryKey> GrowableDictionary<'a, T> { let validity = std::mem::take(&mut self.key_validity); let key_values = std::mem::take(&mut self.key_values); - let data_type = T::PRIMITIVE.into(); - let keys = PrimitiveArray::::from_data(data_type, key_values.into(), validity.into()); + let keys = + PrimitiveArray::::try_new(T::PRIMITIVE.into(), key_values.into(), validity.into()) + .unwrap(); - DictionaryArray::::from_data(keys, self.values.clone()) + DictionaryArray::::try_new(self.data_type.clone(), keys, self.values.clone()).unwrap() } } @@ -101,7 +107,17 @@ impl<'a, T: DictionaryKey> Growable<'a> for GrowableDictionary<'a, T> { values .iter() // `.unwrap_or(0)` because this operation does not check for null values, which may contain any key. - .map(|x| T::from_usize(offset + x.to_usize().unwrap_or(0)).unwrap()), + .map(|x| { + let x: usize = offset + (*x).try_into().unwrap_or(0); + let x: T = match x.try_into() { + Ok(key) => key, + // todo: convert this to an error. + Err(_) => { + panic!("The maximum key is too small") + } + }; + x + }), ); } @@ -127,12 +143,10 @@ impl<'a, T: DictionaryKey> From> for DictionaryArray) -> Self { let data_type = T::PRIMITIVE.into(); - let keys = PrimitiveArray::::from_data( - data_type, - val.key_values.into(), - val.key_validity.into(), - ); + let keys = + PrimitiveArray::::try_new(data_type, val.key_values.into(), val.key_validity.into()) + .unwrap(); - DictionaryArray::::from_data(keys, val.values) + DictionaryArray::::try_new(val.data_type.clone(), keys, val.values).unwrap() } } diff --git a/src/array/ord.rs b/src/array/ord.rs index 8fe3adec311..639317165ab 100644 --- a/src/array/ord.rs +++ b/src/array/ord.rs @@ -141,8 +141,9 @@ where let comparator = build_compare(left.values().as_ref(), right.values().as_ref())?; Ok(Box::new(move |i: usize, j: usize| { - let key_left = left_keys[i].to_usize().unwrap(); - let key_right = right_keys[j].to_usize().unwrap(); + // safety: all dictionaries keys are guaranteed to be castable to usize + let key_left = unsafe { left_keys[i].as_usize() }; + let key_right = unsafe { right_keys[j].as_usize() }; (comparator)(key_left, key_right) })) } diff --git a/src/array/specification.rs b/src/array/specification.rs index 0ede2582329..8c35194f96b 100644 --- a/src/array/specification.rs +++ b/src/array/specification.rs @@ -97,6 +97,22 @@ pub fn try_check_offsets(offsets: &[O], values_len: usize) -> Result< } } +pub fn check_indexes(keys: &[K], len: usize) -> Result<()> +where + K: Copy + TryInto, +{ + keys.iter().try_for_each(|key| { + let key: usize = (*key) + .try_into() + .map_err(|_| Error::oos("The dictionary key must fit in a `usize`"))?; + if key >= len { + Err(Error::oos("The dictionary key must be smaller")) + } else { + Ok(()) + } + }) +} + #[cfg(test)] mod tests { use proptest::prelude::*; diff --git a/src/compute/arithmetics/mod.rs b/src/compute/arithmetics/mod.rs index d0ece217956..b3a52cd7e19 100644 --- a/src/compute/arithmetics/mod.rs +++ b/src/compute/arithmetics/mod.rs @@ -434,9 +434,12 @@ pub fn neg(array: &dyn Array) -> Box { Dictionary(key) => match_integer_type!(key, |$T| { let array = array.as_any().downcast_ref::>().unwrap(); - let values = neg(array.values().as_ref()).into(); + let values = 
neg(array.values().as_ref()); - Box::new(DictionaryArray::<$T>::from_data(array.keys().clone(), values)) as Box + // safety - this operation only applies to values and thus preserves the dictionary's invariant + unsafe{ + DictionaryArray::<$T>::try_new_unchecked(array.data_type().clone(), array.keys().clone(), values).unwrap().boxed() + } }), _ => todo!(), } diff --git a/src/compute/cast/dictionary_to.rs b/src/compute/cast/dictionary_to.rs index 16b1c06f105..01d3325cf9d 100644 --- a/src/compute/cast/dictionary_to.rs +++ b/src/compute/cast/dictionary_to.rs @@ -15,9 +15,7 @@ macro_rules! key_cast { if cast_keys.null_count() > $keys.null_count() { return Err(Error::Overflow); } - Ok(Box::new(DictionaryArray::<$to_type>::from_data( - cast_keys, $values, - ))) + DictionaryArray::try_new($array.data_type().clone(), $keys.clone(), $values.clone()) }}; } @@ -31,9 +29,14 @@ pub fn dictionary_to_dictionary_values( ) -> Result> { let keys = from.keys(); let values = from.values(); + let length = values.len(); let values = cast(values.as_ref(), values_type, CastOptions::default())?; - Ok(DictionaryArray::from_data(keys.clone(), values)) + + assert_eq!(values.len(), length); // this is guaranteed by `cast` + unsafe { + DictionaryArray::try_new_unchecked(from.data_type().clone(), keys.clone(), values.clone()) + } } /// Similar to dictionary_to_dictionary_values, but overflowing cast is wrapped @@ -43,6 +46,7 @@ pub fn wrapping_dictionary_to_dictionary_values( ) -> Result> { let keys = from.keys(); let values = from.values(); + let length = values.len(); let values = cast( values.as_ref(), @@ -52,7 +56,10 @@ pub fn wrapping_dictionary_to_dictionary_values( partial: false, }, )?; - Ok(DictionaryArray::from_data(keys.clone(), values)) + assert_eq!(values.len(), length); // this is guaranteed by `cast` + unsafe { + DictionaryArray::try_new_unchecked(from.data_type().clone(), keys.clone(), values.clone()) + } } /// Casts a [`DictionaryArray`] to a new [`DictionaryArray`] backed by a @@ -64,18 +71,25 @@ pub fn dictionary_to_dictionary_keys( from: &DictionaryArray, ) -> Result> where - K1: DictionaryKey, - K2: DictionaryKey, + K1: DictionaryKey + num_traits::NumCast, + K2: DictionaryKey + num_traits::NumCast, { let keys = from.keys(); let values = from.values(); + let is_ordered = from.is_ordered(); let casted_keys = primitive_to_primitive::(keys, &K2::PRIMITIVE.into()); if casted_keys.null_count() > keys.null_count() { Err(Error::Overflow) } else { - Ok(DictionaryArray::from_data(casted_keys, values.clone())) + let data_type = DataType::Dictionary( + K2::KEY_TYPE, + Box::new(values.data_type().clone()), + is_ordered, + ); + // some of the values may not fit in `usize` and thus this needs to be checked + DictionaryArray::try_new(data_type, casted_keys, values.clone()) } } @@ -89,17 +103,24 @@ where { let keys = from.keys(); let values = from.values(); + let is_ordered = from.is_ordered(); let casted_keys = primitive_as_primitive::(keys, &K2::PRIMITIVE.into()); if casted_keys.null_count() > keys.null_count() { Err(Error::Overflow) } else { - Ok(DictionaryArray::from_data(casted_keys, values.clone())) + let data_type = DataType::Dictionary( + K2::KEY_TYPE, + Box::new(values.data_type().clone()), + is_ordered, + ); + // some of the values may not fit in `usize` and thus this needs to be checked + DictionaryArray::try_new(data_type, casted_keys, values.clone()) } } -pub(super) fn dictionary_cast_dyn( +pub(super) fn dictionary_cast_dyn( array: &dyn Array, to_type: &DataType, options: CastOptions, @@ -117,6 +138,7 
@@ pub(super) fn dictionary_cast_dyn( match_integer_type!(to_keys_type, |$T| { key_cast!(keys, values, array, &data_type, $T) }) + .map(|x| x.boxed()) } _ => unpack_dictionary::(keys, values.as_ref(), to_type, options), } @@ -130,7 +152,7 @@ fn unpack_dictionary( options: CastOptions, ) -> Result> where - K: DictionaryKey, + K: DictionaryKey + num_traits::NumCast, { // attempt to cast the dict values to the target type // use the take kernel to expand out the dictionary @@ -146,7 +168,7 @@ where /// The resulting array has the same length. pub fn dictionary_to_values(from: &DictionaryArray) -> Box where - K: DictionaryKey, + K: DictionaryKey + num_traits::NumCast, { // take requires first casting i64 let indices = primitive_to_primitive::<_, i64>(from.keys(), &DataType::Int64); diff --git a/src/compute/sort/utf8.rs b/src/compute/sort/utf8.rs index 4d00dc42ecb..e2e2da1bc56 100644 --- a/src/compute/sort/utf8.rs +++ b/src/compute/sort/utf8.rs @@ -28,10 +28,13 @@ pub(super) fn indices_sorted_unstable_by_dictionary>() .unwrap(); - let get = |idx| unsafe { - let index = keys.value_unchecked(idx as usize); - // Note: there is no check that the keys are within bounds of the dictionary. - dict.value(index.to_usize().unwrap()) + let get = |index| unsafe { + // safety: indices_sorted_unstable_by is guaranteed to get items in bounds + let index = keys.value_unchecked(index); + // safety: dictionaries are guaranteed to have valid usize keys + let index = index.as_usize(); + // safety: dictionaries are guaranteed to have keys in bounds + dict.value_unchecked(index) }; let cmp = |lhs: &&str, rhs: &&str| lhs.cmp(rhs); diff --git a/src/compute/take/dict.rs b/src/compute/take/dict.rs index aa602850f80..1d1efc4fb3f 100644 --- a/src/compute/take/dict.rs +++ b/src/compute/take/dict.rs @@ -30,5 +30,13 @@ where I: Index, { let keys = take_primitive::(values.keys(), indices); - DictionaryArray::::from_data(keys, values.values().clone()) + // safety - this operation takes a subset of keys and thus preserves the dictionary's invariant + unsafe { + DictionaryArray::::try_new_unchecked( + values.data_type().clone(), + keys, + values.values().clone(), + ) + .unwrap() + } } diff --git a/src/io/avro/read/nested.rs b/src/io/avro/read/nested.rs index e6138e74d33..450930cad73 100644 --- a/src/io/avro/read/nested.rs +++ b/src/io/avro/read/nested.rs @@ -157,17 +157,25 @@ impl MutableArray for FixedItemsUtf8Dictionary { } fn as_box(&mut self) -> Box { - Box::new(DictionaryArray::from_data( - std::mem::take(&mut self.keys).into(), - Box::new(self.values.clone()), - )) + Box::new( + DictionaryArray::try_new( + self.data_type.clone(), + std::mem::take(&mut self.keys).into(), + Box::new(self.values.clone()), + ) + .unwrap(), + ) } fn as_arc(&mut self) -> std::sync::Arc { - std::sync::Arc::new(DictionaryArray::from_data( - std::mem::take(&mut self.keys).into(), - Box::new(self.values.clone()), - )) + std::sync::Arc::new( + DictionaryArray::try_new( + self.data_type.clone(), + std::mem::take(&mut self.keys).into(), + Box::new(self.values.clone()), + ) + .unwrap(), + ) } fn data_type(&self) -> &DataType { diff --git a/src/io/csv/write/serialize.rs b/src/io/csv/write/serialize.rs index f1031a979cb..6f704d1cc08 100644 --- a/src/io/csv/write/serialize.rs +++ b/src/io/csv/write/serialize.rs @@ -471,17 +471,15 @@ fn serialize_utf8_dict<'a, K: DictionaryKey, O: Offset>( array: &'a dyn Any, ) -> Box + 'a> { let array = array.downcast_ref::>().unwrap(); - let keys = array.keys(); let values = array .values() .as_any() .downcast_ref::>() 
.unwrap(); Box::new(BufStreamingIterator::new( - keys.iter(), + array.keys_iter(), move |x, buf| { - if let Some(x) = x { - let i = x.to_usize().unwrap(); + if let Some(i) = x { if !values.is_null(i) { let val = values.value(i); buf.extend_from_slice(val.as_bytes()); diff --git a/src/io/ipc/read/array/dictionary.rs b/src/io/ipc/read/array/dictionary.rs index 69c617852f6..f3d92f65a2f 100644 --- a/src/io/ipc/read/array/dictionary.rs +++ b/src/io/ipc/read/array/dictionary.rs @@ -3,6 +3,7 @@ use std::convert::TryInto; use std::io::{Read, Seek}; use crate::array::{DictionaryArray, DictionaryKey}; +use crate::datatypes::DataType; use crate::error::{Error, Result}; use super::super::Dictionaries; @@ -12,6 +13,7 @@ use super::{read_primitive, skip_primitive}; #[allow(clippy::too_many_arguments)] pub fn read_dictionary( field_nodes: &mut VecDeque, + data_type: DataType, id: Option, buffers: &mut VecDeque, reader: &mut R, @@ -53,7 +55,7 @@ where scratch, )?; - Ok(DictionaryArray::::from_data(keys, values)) + DictionaryArray::::try_new(data_type, keys, values) } pub fn skip_dictionary( diff --git a/src/io/ipc/read/deserialize.rs b/src/io/ipc/read/deserialize.rs index fec5676def5..13300033b66 100644 --- a/src/io/ipc/read/deserialize.rs +++ b/src/io/ipc/read/deserialize.rs @@ -192,6 +192,7 @@ pub fn read( match_integer_type!(key_type, |$T| { read_dictionary::<$T, _>( field_nodes, + data_type, ipc_field.dictionary_id, buffers, reader, diff --git a/src/io/json/read/deserialize.rs b/src/io/json/read/deserialize.rs index aba6cc6f01d..b7add12bc97 100644 --- a/src/io/json/read/deserialize.rs +++ b/src/io/json/read/deserialize.rs @@ -282,7 +282,7 @@ fn deserialize_dictionary<'a, K: DictionaryKey, A: Borrow>>( rows: &[A], data_type: DataType, ) -> DictionaryArray { - let child = DictionaryArray::::get_child(&data_type); + let child = DictionaryArray::::try_get_child(&data_type).unwrap(); let mut map = HashedMap::::default(); @@ -296,8 +296,11 @@ fn deserialize_dictionary<'a, K: DictionaryKey, A: Borrow>>( Some((hash, v)) => match map.get(&hash) { Some(key) => Some(*key), None => { - // todo: convert this to an error. - let key = K::from_usize(map.len()).unwrap(); + let key = match map.len().try_into() { + Ok(key) => key, + // todo: convert this to an error. 
+ Err(_) => panic!("The maximum key is too small for this json struct"), + }; inner.push(v); map.insert(hash, key); Some(key) @@ -307,8 +310,9 @@ fn deserialize_dictionary<'a, K: DictionaryKey, A: Borrow>>( }) .collect::>(); + drop(extractor); let values = _deserialize(&inner, child.clone()); - DictionaryArray::::from_data(keys, values) + DictionaryArray::::try_new(data_type, keys, values).unwrap() } pub(crate) fn _deserialize<'a, A: Borrow>>( diff --git a/src/io/json_integration/read/array.rs b/src/io/json_integration/read/array.rs index e434cbe6918..2d7d0ec39a6 100644 --- a/src/io/json_integration/read/array.rs +++ b/src/io/json_integration/read/array.rs @@ -230,7 +230,7 @@ fn to_map( Ok(Box::new(MapArray::new(data_type, offsets, field, validity))) } -fn to_dictionary( +fn to_dictionary( data_type: DataType, field: &IpcField, json_col: &ArrowJsonColumn, @@ -244,7 +244,7 @@ fn to_dictionary( let keys = to_primitive(json_col, K::PRIMITIVE.into()); - let inner_data_type = DictionaryArray::::get_child(&data_type); + let inner_data_type = DictionaryArray::::try_get_child(&data_type)?; let values = to_array( inner_data_type.clone(), field, @@ -252,7 +252,7 @@ fn to_dictionary( dictionaries, )?; - Ok(Box::new(DictionaryArray::::from_data(keys, values))) + DictionaryArray::::try_new(data_type, keys, values).map(|a| a.boxed()) } /// Construct an [`Array`] from the JSON integration format diff --git a/src/io/parquet/read/deserialize/binary/dictionary.rs b/src/io/parquet/read/deserialize/binary/dictionary.rs index b5448a077bf..bf656929e65 100644 --- a/src/io/parquet/read/deserialize/binary/dictionary.rs +++ b/src/io/parquet/read/deserialize/binary/dictionary.rs @@ -23,6 +23,7 @@ where { iter: I, data_type: DataType, + values_data_type: DataType, values: Dict, items: VecDeque<(Vec, MutableBitmap)>, chunk_size: Option, @@ -36,13 +37,14 @@ where I: DataPages, { pub fn new(iter: I, data_type: DataType, chunk_size: Option) -> Self { - let data_type = match data_type { + let values_data_type = match &data_type { DataType::Dictionary(_, values, _) => values.as_ref().clone(), _ => unreachable!(), }; Self { iter, data_type, + values_data_type, values: Dict::Empty, items: VecDeque::new(), chunk_size, @@ -90,8 +92,9 @@ where &mut self.iter, &mut self.items, &mut self.values, + self.data_type.clone(), self.chunk_size, - |dict| read_dict::(self.data_type.clone(), dict), + |dict| read_dict::(self.values_data_type.clone(), dict), ); match maybe_state { MaybeNext::Some(Ok(dict)) => Some(Ok(dict)), diff --git a/src/io/parquet/read/deserialize/dictionary.rs b/src/io/parquet/read/deserialize/dictionary.rs index 62ef51adc9c..f89a5e06d5a 100644 --- a/src/io/parquet/read/deserialize/dictionary.rs +++ b/src/io/parquet/read/deserialize/dictionary.rs @@ -10,6 +10,7 @@ use parquet2::{ use crate::{ array::{Array, DictionaryArray, DictionaryKey, PrimitiveArray}, bitmap::MutableBitmap, + datatypes::DataType, error::{Error, Result}, }; @@ -158,13 +159,30 @@ where &mut page.validity, Some(remaining), values, - &mut page.values.by_ref().map(|x| K::from_u32(x).unwrap()), + &mut page.values.by_ref().map(|x| { + let x: usize = x.try_into().unwrap(); + match x.try_into() { + Ok(key) => key, + // todo: convert this to an error. + Err(_) => panic!("The maximum key is too small"), + } + }), ), State::Required(page) => { values.extend( page.values .by_ref() - .map(|x| K::from_u32(x).unwrap()) + .map(|x| { + let x: usize = x.try_into().unwrap(); + let x: K = match x.try_into() { + Ok(key) => key, + // todo: convert this to an error. 
+ Err(_) => { + panic!("The maximum key is too small") + } + }; + x + }) .take(remaining), ); } @@ -173,13 +191,33 @@ where page_validity, Some(remaining), values, - &mut page_values.by_ref().map(|x| K::from_u32(x).unwrap()), + &mut page_values.by_ref().map(|x| { + let x: usize = x.try_into().unwrap(); + let x: K = match x.try_into() { + Ok(key) => key, + // todo: convert this to an error. + Err(_) => { + panic!("The maximum key is too small") + } + }; + x + }), ), State::FilteredRequired(page) => { values.extend( page.values .by_ref() - .map(|x| K::from_u32(x).unwrap()) + .map(|x| { + let x: usize = x.try_into().unwrap(); + let x: K = match x.try_into() { + Ok(key) => key, + // todo: convert this to an error. + Err(_) => { + panic!("The maximum key is too small") + } + }; + x + }) .take(remaining), ); } @@ -203,7 +241,7 @@ impl Dict { } fn finish_key(values: Vec, validity: MutableBitmap) -> PrimitiveArray { - PrimitiveArray::from_data(K::PRIMITIVE.into(), values.into(), validity.into()) + PrimitiveArray::new(K::PRIMITIVE.into(), values.into(), validity.into()) } #[inline] @@ -216,13 +254,14 @@ pub(super) fn next_dict< iter: &'a mut I, items: &mut VecDeque<(Vec, MutableBitmap)>, dict: &mut Dict, + data_type: DataType, chunk_size: Option, read_dict: F, ) -> MaybeNext>> { if items.len() > 1 { let (values, validity) = items.pop_front().unwrap(); let keys = finish_key(values, validity); - return MaybeNext::Some(Ok(DictionaryArray::from_data(keys, dict.unwrap()))); + return MaybeNext::Some(DictionaryArray::try_new(data_type, keys, dict.unwrap())); } match iter.next() { Err(e) => MaybeNext::Some(Err(e.into())), @@ -255,7 +294,7 @@ pub(super) fn next_dict< let (values, validity) = items.pop_front().unwrap(); let keys = PrimitiveArray::from_data(K::PRIMITIVE.into(), values.into(), validity.into()); - MaybeNext::Some(Ok(DictionaryArray::from_data(keys, dict.unwrap()))) + MaybeNext::Some(DictionaryArray::try_new(data_type, keys, dict.unwrap())) } } Ok(None) => { @@ -266,7 +305,7 @@ pub(super) fn next_dict< let keys = finish_key(values, validity); - MaybeNext::Some(Ok(DictionaryArray::from_data(keys, dict.unwrap()))) + MaybeNext::Some(DictionaryArray::try_new(data_type, keys, dict.unwrap())) } else { MaybeNext::None } diff --git a/src/io/parquet/read/deserialize/fixed_size_binary/dictionary.rs b/src/io/parquet/read/deserialize/fixed_size_binary/dictionary.rs index 1db6b48003e..4d44ef4f724 100644 --- a/src/io/parquet/read/deserialize/fixed_size_binary/dictionary.rs +++ b/src/io/parquet/read/deserialize/fixed_size_binary/dictionary.rs @@ -22,6 +22,7 @@ where { iter: I, data_type: DataType, + values_data_type: DataType, values: Dict, items: VecDeque<(Vec, MutableBitmap)>, chunk_size: Option, @@ -33,13 +34,14 @@ where I: DataPages, { pub fn new(iter: I, data_type: DataType, chunk_size: Option) -> Self { - let data_type = match data_type { + let values_data_type = match &data_type { DataType::Dictionary(_, values, _) => values.as_ref().clone(), _ => unreachable!(), }; Self { iter, data_type, + values_data_type, values: Dict::Empty, items: VecDeque::new(), chunk_size, @@ -73,8 +75,9 @@ where &mut self.iter, &mut self.items, &mut self.values, + self.data_type.clone(), self.chunk_size, - |dict| read_dict(self.data_type.clone(), dict), + |dict| read_dict(self.values_data_type.clone(), dict), ); match maybe_state { MaybeNext::Some(Ok(dict)) => Some(Ok(dict)), diff --git a/src/io/parquet/read/deserialize/primitive/dictionary.rs b/src/io/parquet/read/deserialize/primitive/dictionary.rs index 
c7293633795..2b3f0ca5491 100644 --- a/src/io/parquet/read/deserialize/primitive/dictionary.rs +++ b/src/io/parquet/read/deserialize/primitive/dictionary.rs @@ -45,6 +45,7 @@ where { iter: I, data_type: DataType, + values_data_type: DataType, values: Dict, items: VecDeque<(Vec, MutableBitmap)>, chunk_size: Option, @@ -62,13 +63,14 @@ where F: Copy + Fn(P) -> T, { pub fn new(iter: I, data_type: DataType, chunk_size: Option, op: F) -> Self { - let data_type = match data_type { - DataType::Dictionary(_, values, _) => *values, - _ => data_type, + let values_data_type = match &data_type { + DataType::Dictionary(_, values, _) => *(values.clone()), + _ => unreachable!(), }; Self { iter, data_type, + values_data_type, values: Dict::Empty, items: VecDeque::new(), chunk_size, @@ -93,8 +95,9 @@ where &mut self.iter, &mut self.items, &mut self.values, + self.data_type.clone(), self.chunk_size, - |dict| read_dict::(self.data_type.clone(), self.op, dict), + |dict| read_dict::(self.values_data_type.clone(), self.op, dict), ); match maybe_state { MaybeNext::Some(Ok(dict)) => Some(Ok(dict)), diff --git a/src/io/parquet/read/statistics/dictionary.rs b/src/io/parquet/read/statistics/dictionary.rs index d29feb46802..b2553c62f1b 100644 --- a/src/io/parquet/read/statistics/dictionary.rs +++ b/src/io/parquet/read/statistics/dictionary.rs @@ -41,7 +41,7 @@ impl MutableArray for DynMutableDictionary { match self.data_type.to_physical_type() { PhysicalType::Dictionary(key) => match_integer_type!(key, |$T| { let keys = PrimitiveArray::<$T>::from_iter((0..inner.len() as $T).map(Some)); - Box::new(DictionaryArray::<$T>::from_data(keys, inner)) + Box::new(DictionaryArray::<$T>::try_new(self.data_type.clone(), keys, inner).unwrap()) }), _ => todo!(), } diff --git a/src/io/parquet/write/dictionary.rs b/src/io/parquet/write/dictionary.rs index 05c56207bfe..f6a9bcabbb6 100644 --- a/src/io/parquet/write/dictionary.rs +++ b/src/io/parquet/write/dictionary.rs @@ -20,26 +20,25 @@ use crate::datatypes::DataType; use crate::error::{Error, Result}; use crate::io::parquet::write::utils; use crate::{ - array::{Array, DictionaryArray, DictionaryKey, PrimitiveArray}, + array::{Array, DictionaryArray, DictionaryKey}, io::parquet::read::schema::is_nullable, }; fn encode_keys( - array: &PrimitiveArray, - validity: Option<&Bitmap>, + array: &DictionaryArray, type_: PrimitiveType, statistics: ParquetStatistics, options: WriteOptions, ) -> Result { + let validity = array.values().validity(); let is_optional = is_nullable(&type_.field_info); let mut buffer = vec![]; let null_count = if let Some(validity) = validity { - let projected_validity = array.iter().map(|x| { - x.map(|x| validity.get_bit(x.to_usize().unwrap())) - .unwrap_or(false) - }); + let projected_validity = array + .keys_iter() + .map(|x| x.map(|x| validity.get_bit(x)).unwrap_or(false)); let projected_val = Bitmap::from_trusted_len_iter(projected_validity); let null_count = projected_val.unset_bits(); @@ -68,8 +67,7 @@ fn encode_keys( // encode indices // compute the required number of bits if let Some(validity) = validity { - let keys = array.iter().flatten().filter_map(|x| { - let index = x.to_usize().unwrap(); + let keys = array.keys_iter().flatten().filter_map(|index| { // discard indices whose values are null, since they are part of the def levels. if validity.get_bit(index) { Some(index as u32) @@ -87,7 +85,7 @@ fn encode_keys( // followed by the encoded indices. 
encode_u32(&mut buffer, keys, num_bits)?; } else { - let keys = array.iter().flatten().map(|x| x.to_usize().unwrap() as u32); + let keys = array.keys_iter().flatten().map(|x| x as u32); let num_bits = utils::get_bit_width(keys.clone().max().unwrap_or(0) as u64) as u8; let keys = utils::ExactSizedIter::new(keys, array.len() - array.null_count()); @@ -202,13 +200,7 @@ pub fn array_to_pages( let dict_page = EncodedPage::Dict(dict_page); // write DataPage pointing to DictPage - let data_page = encode_keys( - array.keys(), - array.values().validity(), - type_, - statistics, - options, - )?; + let data_page = encode_keys(array, type_, statistics, options)?; let iter = std::iter::once(Ok(dict_page)).chain(std::iter::once(Ok(data_page))); Ok(DynIter::new(Box::new(iter))) diff --git a/tests/it/array/dictionary/mod.rs b/tests/it/array/dictionary/mod.rs index f7a9b61988d..f436e2685db 100644 --- a/tests/it/array/dictionary/mod.rs +++ b/tests/it/array/dictionary/mod.rs @@ -1 +1,153 @@ mod mutable; + +use arrow2::{array::*, datatypes::DataType}; + +#[test] +fn try_new_ok() { + let values = Utf8Array::::from_slice(&["a", "aa"]); + let data_type = + DataType::Dictionary(i32::KEY_TYPE, Box::new(values.data_type().clone()), false); + let array = DictionaryArray::try_new( + data_type, + PrimitiveArray::from_vec(vec![1, 0]), + values.boxed(), + ) + .unwrap(); + + assert_eq!(array.keys(), &PrimitiveArray::from_vec(vec![1i32, 0])); + assert_eq!( + &Utf8Array::::from_slice(&["a", "aa"]) as &dyn Array, + array.values().as_ref(), + ); + assert!(!array.is_ordered()); + + assert_eq!(format!("{:?}", array), "DictionaryArray[aa, a]"); +} + +#[test] +fn try_new_incorrect_key() { + let values = Utf8Array::::from_slice(&["a", "aa"]); + let data_type = + DataType::Dictionary(i16::KEY_TYPE, Box::new(values.data_type().clone()), false); + + let r = DictionaryArray::try_new( + data_type, + PrimitiveArray::from_vec(vec![1, 0]), + values.boxed(), + ) + .is_err(); + + assert!(r); +} + +#[test] +fn try_new_incorrect_dt() { + let values = Utf8Array::::from_slice(&["a", "aa"]); + let data_type = DataType::Int32; + + let r = DictionaryArray::try_new( + data_type, + PrimitiveArray::from_vec(vec![1, 0]), + values.boxed(), + ) + .is_err(); + + assert!(r); +} + +#[test] +fn try_new_incorrect_values_dt() { + let values = Utf8Array::::from_slice(&["a", "aa"]); + let data_type = DataType::Dictionary(i32::KEY_TYPE, Box::new(DataType::LargeUtf8), false); + + let r = DictionaryArray::try_new( + data_type, + PrimitiveArray::from_vec(vec![1, 0]), + values.boxed(), + ) + .is_err(); + + assert!(r); +} + +#[test] +fn try_new_out_of_bounds() { + let values = Utf8Array::::from_slice(&["a", "aa"]); + + let r = DictionaryArray::try_from_keys(PrimitiveArray::from_vec(vec![2, 0]), values.boxed()) + .is_err(); + + assert!(r); +} + +#[test] +fn try_new_out_of_bounds_neg() { + let values = Utf8Array::::from_slice(&["a", "aa"]); + + let r = DictionaryArray::try_from_keys(PrimitiveArray::from_vec(vec![-1, 0]), values.boxed()) + .is_err(); + + assert!(r); +} + +#[test] +fn new_null() { + let dt = DataType::Dictionary(i16::KEY_TYPE, Box::new(DataType::Int32), false); + let array = DictionaryArray::::new_null(dt, 2); + + assert_eq!(format!("{:?}", array), "DictionaryArray[None, None]"); +} + +#[test] +fn new_empty() { + let dt = DataType::Dictionary(i16::KEY_TYPE, Box::new(DataType::Int32), false); + let array = DictionaryArray::::new_empty(dt); + + assert_eq!(format!("{:?}", array), "DictionaryArray[]"); +} + +#[test] +fn with_validity() { + let values = 
Utf8Array::::from_slice(&["a", "aa"]); + let array = + DictionaryArray::try_from_keys(PrimitiveArray::from_vec(vec![1, 0]), values.boxed()) + .unwrap(); + + let array = array.with_validity(Some([true, false].into())); + + assert_eq!(format!("{:?}", array), "DictionaryArray[aa, None]"); +} + +#[test] +fn rev_iter() { + let values = Utf8Array::::from_slice(&["a", "aa"]); + let array = + DictionaryArray::try_from_keys(PrimitiveArray::from_vec(vec![1, 0]), values.boxed()) + .unwrap(); + + let mut iter = array.into_iter(); + assert_eq!(iter.by_ref().rev().count(), 2); + assert_eq!(iter.size_hint(), (0, Some(0))); +} + +#[test] +fn iter_values() { + let values = Utf8Array::::from_slice(&["a", "aa"]); + let array = + DictionaryArray::try_from_keys(PrimitiveArray::from_vec(vec![1, 0]), values.boxed()) + .unwrap(); + + let mut iter = array.values_iter(); + assert_eq!(iter.by_ref().count(), 2); + assert_eq!(iter.size_hint(), (0, Some(0))); +} + +#[test] +fn keys_values_iter() { + let values = Utf8Array::::from_slice(&["a", "aa"]); + let array = + DictionaryArray::try_from_keys(PrimitiveArray::from_vec(vec![1, 0]), values.boxed()) + .unwrap(); + + assert_eq!(array.keys_values_iter().collect::>(), vec![1, 0]); +} diff --git a/tests/it/array/equal/dictionary.rs b/tests/it/array/equal/dictionary.rs index 1bbe9627d6f..8e25d083e1a 100644 --- a/tests/it/array/equal/dictionary.rs +++ b/tests/it/array/equal/dictionary.rs @@ -6,7 +6,7 @@ fn create_dictionary_array(values: &[Option<&str>], keys: &[Option]) -> Dic let keys = Int16Array::from(keys); let values = Utf8Array::::from(values); - DictionaryArray::from_data(keys, Box::new(values)) + DictionaryArray::try_from_keys(keys, values.boxed()).unwrap() } #[test] diff --git a/tests/it/array/growable/dictionary.rs b/tests/it/array/growable/dictionary.rs index a45c0b298be..5d0eeef226d 100644 --- a/tests/it/array/growable/dictionary.rs +++ b/tests/it/array/growable/dictionary.rs @@ -1,6 +1,5 @@ use arrow2::array::growable::{Growable, GrowableDictionary}; use arrow2::array::*; -use arrow2::datatypes::DataType; use arrow2::error::Result; #[test] @@ -13,10 +12,11 @@ fn test_single() -> Result<()> { let array = array.into(); // same values, less keys - let expected = DictionaryArray::::from_data( - PrimitiveArray::from(vec![Some(1), Some(0)]), + let expected = DictionaryArray::try_from_keys( + PrimitiveArray::from_vec(vec![1, 0]), Box::new(Utf8Array::::from(&original_data)), - ); + ) + .unwrap(); let mut growable = GrowableDictionary::new(&[&array], false, 0); @@ -28,25 +28,6 @@ fn test_single() -> Result<()> { Ok(()) } -#[test] -fn test_negative_keys() { - let vals = vec![Some("a"), Some("b"), Some("c")]; - let keys = vec![0, 1, 2, -1]; - - let keys = PrimitiveArray::from_data( - DataType::Int32, - keys.into(), - Some(vec![true, true, true, false].into()), - ); - - let arr = DictionaryArray::from_data(keys, Box::new(Utf8Array::::from(vals))); - // check that we don't panic with negative keys to usize conversion - let mut growable = GrowableDictionary::new(&[&arr], false, 0); - growable.extend(0, 0, 4); - let out: DictionaryArray = growable.into(); - assert_eq!(out, arr); -} - #[test] fn test_multi() -> Result<()> { let mut original_data1 = vec![Some("a"), Some("b"), None, Some("a")]; @@ -65,10 +46,11 @@ fn test_multi() -> Result<()> { // same values, less keys original_data1.extend(original_data2.iter().cloned()); - let expected = DictionaryArray::::from_data( - PrimitiveArray::from(vec![Some(1), None, Some(3), None]), - Box::new(Utf8Array::::from_slice(&["a", "b", 
"c", "b", "a"])), - ); + let expected = DictionaryArray::try_from_keys( + PrimitiveArray::from(&[Some(1), None, Some(3), None]), + Utf8Array::::from_slice(&["a", "b", "c", "b", "a"]).boxed(), + ) + .unwrap(); let mut growable = GrowableDictionary::new(&[&array1, &array2], false, 0); diff --git a/tests/it/array/growable/mod.rs b/tests/it/array/growable/mod.rs index acc5da6ddc5..a6e25ecba37 100644 --- a/tests/it/array/growable/mod.rs +++ b/tests/it/array/growable/mod.rs @@ -38,9 +38,10 @@ fn test_make_growable() { FixedSizeBinaryArray::new(DataType::FixedSizeBinary(2), b"abcd".to_vec().into(), None); make_growable(&[&array], false, 2); - let array = DictionaryArray::::from_data( - Int32Array::from_slice([1, 2]), - Box::new(Int32Array::from_slice([1, 2])), - ); + let array = DictionaryArray::try_from_keys( + Int32Array::from_slice([1, 0]), + Int32Array::from_slice([1, 2]).boxed(), + ) + .unwrap(); make_growable(&[&array], false, 2); } diff --git a/tests/it/compute/arithmetics/mod.rs b/tests/it/compute/arithmetics/mod.rs index f697cd8041a..a22f146d7b2 100644 --- a/tests/it/compute/arithmetics/mod.rs +++ b/tests/it/compute/arithmetics/mod.rs @@ -95,14 +95,16 @@ fn test_neg() { #[test] fn test_neg_dict() { - let a = DictionaryArray::::from_data( + let a = DictionaryArray::try_from_keys( UInt8Array::from_slice(&[0, 0, 1]), - Box::new(Int8Array::from_slice(&[1, 2])), - ); + Int8Array::from_slice(&[1, 2]).boxed(), + ) + .unwrap(); let result = neg(&a); - let expected = DictionaryArray::::from_data( + let expected = DictionaryArray::try_from_keys( UInt8Array::from_slice(&[0, 0, 1]), - Box::new(Int8Array::from_slice(&[-1, -2])), - ); + Int8Array::from_slice(&[-1, -2]).boxed(), + ) + .unwrap(); assert_eq!(expected, result.as_ref()); } diff --git a/tests/it/io/avro/read.rs b/tests/it/io/avro/read.rs index bef9dcb252a..deb8ffd802f 100644 --- a/tests/it/io/avro/read.rs +++ b/tests/it/io/avro/read.rs @@ -106,10 +106,11 @@ pub(super) fn data() -> Chunk> { None, ) .boxed(), - DictionaryArray::::from_data( + DictionaryArray::try_from_keys( Int32Array::from_slice([1, 0]), Box::new(Utf8Array::::from_slice(["SPADES", "HEARTS"])), ) + .unwrap() .boxed(), PrimitiveArray::::from_slice([12345678i128, -12345678i128]) .to(DataType::Decimal(18, 5)) diff --git a/tests/it/io/csv/write.rs b/tests/it/io/csv/write.rs index 44bec50497f..80d7212175f 100644 --- a/tests/it/io/csv/write.rs +++ b/tests/it/io/csv/write.rs @@ -16,7 +16,7 @@ fn data() -> Chunk> { let c6 = PrimitiveArray::::from_vec(vec![1234, 24680, 85563]) .to(DataType::Time32(TimeUnit::Second)); let keys = UInt32Array::from_slice(&[2, 0, 1]); - let c7 = DictionaryArray::from_data(keys, Box::new(c1.clone())); + let c7 = DictionaryArray::try_from_keys(keys, Box::new(c1.clone())).unwrap(); Chunk::new(vec![ Box::new(c1) as Box, @@ -256,13 +256,13 @@ fn data_array(column: &str) -> (Chunk>, Vec<&'static str>) { "dictionary[u32]" => { let keys = UInt32Array::from_slice(&[2, 1, 0]); let values = Utf8Array::::from_slice(["a b", "c", "d"]).boxed(); - let array = DictionaryArray::from_data(keys, values); + let array = DictionaryArray::try_from_keys(keys, values).unwrap(); (array.boxed(), vec!["d", "c", "a b"]) } "dictionary[u64]" => { let keys = UInt64Array::from_slice(&[2, 1, 0]); let values = Utf8Array::::from_slice(["a b", "c", "d"]).boxed(); - let array = DictionaryArray::from_data(keys, values); + let array = DictionaryArray::try_from_keys(keys, values).unwrap(); (array.boxed(), vec!["d", "c", "a b"]) } _ => todo!(), diff --git a/tests/it/io/parquet/mod.rs 
b/tests/it/io/parquet/mod.rs index 197cf20be56..4f3ddd2b6a3 100644 --- a/tests/it/io/parquet/mod.rs +++ b/tests/it/io/parquet/mod.rs @@ -359,7 +359,7 @@ pub fn pyarrow_nullable(column: &str) -> Box { "int32_dict" => { let keys = PrimitiveArray::::from([Some(0), Some(1), None, Some(1)]); let values = Box::new(PrimitiveArray::::from_slice([10, 200])); - Box::new(DictionaryArray::::from_data(keys, values)) + Box::new(DictionaryArray::try_from_keys(keys, values).unwrap()) } "decimal_9" => { let values = i64_values @@ -440,10 +440,7 @@ pub fn pyarrow_nullable_statistics(column: &str) -> Statistics { }, "int32_dict" => { let new_dict = |array: Box| -> Box { - Box::new(DictionaryArray::::from_data( - vec![Some(0)].into(), - array, - )) + Box::new(DictionaryArray::try_from_keys(vec![Some(0)].into(), array).unwrap()) }; Statistics { @@ -999,42 +996,45 @@ fn arrow_type() -> Result<()> { let array2 = Utf8Array::::from([Some("a"), None, Some("bb")]); let indices = PrimitiveArray::from_values((0..3u64).map(|x| x % 2)); - let values = PrimitiveArray::from_slice([1.0f32, 3.0]); - let array3 = DictionaryArray::from_data(indices.clone(), Box::new(values)); + let values = PrimitiveArray::from_slice([1.0f32, 3.0]).boxed(); + let array3 = DictionaryArray::try_from_keys(indices.clone(), values).unwrap(); - let values = BinaryArray::::from_slice([b"ab", b"ac"]); - let array4 = DictionaryArray::from_data(indices.clone(), Box::new(values)); + let values = BinaryArray::::from_slice([b"ab", b"ac"]).boxed(); + let array4 = DictionaryArray::try_from_keys(indices.clone(), values).unwrap(); let values = FixedSizeBinaryArray::from_data( DataType::FixedSizeBinary(2), vec![b'a', b'b', b'a', b'c'].into(), None, - ); - let array5 = DictionaryArray::from_data(indices.clone(), Box::new(values)); + ) + .boxed(); + let array5 = DictionaryArray::try_from_keys(indices.clone(), values).unwrap(); - let values = PrimitiveArray::from_slice([1i16, 3]); - let array6 = DictionaryArray::from_data(indices.clone(), Box::new(values)); + let values = PrimitiveArray::from_slice([1i16, 3]).boxed(); + let array6 = DictionaryArray::try_from_keys(indices.clone(), values).unwrap(); - let values = PrimitiveArray::from_slice([1i64, 3]).to(DataType::Timestamp( - TimeUnit::Millisecond, - Some("UTC".to_string()), - )); - let array7 = DictionaryArray::from_data(indices.clone(), Box::new(values)); + let values = PrimitiveArray::from_slice([1i64, 3]) + .to(DataType::Timestamp( + TimeUnit::Millisecond, + Some("UTC".to_string()), + )) + .boxed(); + let array7 = DictionaryArray::try_from_keys(indices.clone(), values).unwrap(); - let values = PrimitiveArray::from_slice([1.0f64, 3.0]); - let array8 = DictionaryArray::from_data(indices.clone(), Box::new(values)); + let values = PrimitiveArray::from_slice([1.0f64, 3.0]).boxed(); + let array8 = DictionaryArray::try_from_keys(indices.clone(), values).unwrap(); - let values = PrimitiveArray::from_slice([1u8, 3]); - let array9 = DictionaryArray::from_data(indices.clone(), Box::new(values)); + let values = PrimitiveArray::from_slice([1u8, 3]).boxed(); + let array9 = DictionaryArray::try_from_keys(indices.clone(), values).unwrap(); - let values = PrimitiveArray::from_slice([1u16, 3]); - let array10 = DictionaryArray::from_data(indices.clone(), Box::new(values)); + let values = PrimitiveArray::from_slice([1u16, 3]).boxed(); + let array10 = DictionaryArray::try_from_keys(indices.clone(), values).unwrap(); - let values = PrimitiveArray::from_slice([1u32, 3]); - let array11 = DictionaryArray::from_data(indices.clone(), 
Box::new(values)); + let values = PrimitiveArray::from_slice([1u32, 3]).boxed(); + let array11 = DictionaryArray::try_from_keys(indices.clone(), values).unwrap(); - let values = PrimitiveArray::from_slice([1u64, 3]); - let array12 = DictionaryArray::from_data(indices, Box::new(values)); + let values = PrimitiveArray::from_slice([1u64, 3]).boxed(); + let array12 = DictionaryArray::try_from_keys(indices, values).unwrap(); let array13 = PrimitiveArray::::from_slice([1, 2, 3]) .to(DataType::Interval(IntervalUnit::YearMonth)); diff --git a/tests/it/io/parquet/read_indexes.rs b/tests/it/io/parquet/read_indexes.rs index 92b647acf1d..ee91ba846a5 100644 --- a/tests/it/io/parquet/read_indexes.rs +++ b/tests/it/io/parquet/read_indexes.rs @@ -208,12 +208,12 @@ fn indexed_optional_boolean() -> Result<()> { #[test] fn indexed_dict() -> Result<()> { let indices = PrimitiveArray::from_values((0..6u64).map(|x| x % 2)); - let values = PrimitiveArray::from_slice([4i32, 6i32]); - let array = DictionaryArray::from_data(indices, Box::new(values)); + let values = PrimitiveArray::from_slice([4i32, 6i32]).boxed(); + let array = DictionaryArray::try_from_keys(indices, values).unwrap(); let indices = PrimitiveArray::from_slice(&[0u64]); - let values = PrimitiveArray::from_slice([4i32, 6i32]); - let expected = DictionaryArray::from_data(indices, Box::new(values)); + let values = PrimitiveArray::from_slice([4i32, 6i32]).boxed(); + let expected = DictionaryArray::try_from_keys(indices, values).unwrap(); let expected = expected.boxed(); diff --git a/tests/it/io/print.rs b/tests/it/io/print.rs index a2f4eb5c9b7..6e37836d6b8 100644 --- a/tests/it/io/print.rs +++ b/tests/it/io/print.rs @@ -97,7 +97,7 @@ fn write_dictionary() -> Result<()> { fn dictionary_validities() -> Result<()> { let keys = PrimitiveArray::::from([Some(1), None, Some(0)]); let values = PrimitiveArray::::from([None, Some(10)]); - let array = DictionaryArray::::from_data(keys, Box::new(values)); + let array = DictionaryArray::try_from_keys(keys, Box::new(values)).unwrap(); let columns = Chunk::new(vec![&array as &dyn Array]);
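
The tests above exercise the fallible constructors this patch introduces; the standalone sketch below (not part of the commit) condenses that usage. It assumes the same `arrow2::{array::*, datatypes::DataType}` imports as the test module, uses illustrative keys and values, and relies only on the `try_new` / `try_from_keys` / `key_value` / `is_ordered` API added in src/array/dictionary/mod.rs above.

use arrow2::error::Result;
use arrow2::{array::*, datatypes::DataType};

fn dictionary_example() -> Result<()> {
    let values = Utf8Array::<i32>::from_slice(&["a", "aa"]);

    // `try_from_keys` derives `DataType::Dictionary` from the values' data type
    // and validates that every key is representable as `usize` and in bounds.
    let array = DictionaryArray::try_from_keys(
        PrimitiveArray::from_vec(vec![1i32, 0]),
        values.clone().boxed(),
    )?;
    assert_eq!(array.key_value(0), 1);
    assert!(!array.is_ordered());

    // `try_new` takes the logical data type explicitly and checks it against
    // both the key type and the values' data type.
    let data_type =
        DataType::Dictionary(i32::KEY_TYPE, Box::new(values.data_type().clone()), false);
    let array = DictionaryArray::try_new(
        data_type,
        PrimitiveArray::from_vec(vec![0i32, 1]),
        values.clone().boxed(),
    )?;
    assert_eq!(array.len(), 2);

    // Out-of-bounds keys are rejected at construction time instead of
    // panicking later, matching the `try_new_out_of_bounds` test above.
    assert!(DictionaryArray::try_from_keys(
        PrimitiveArray::from_vec(vec![2i32, 0]),
        values.boxed()
    )
    .is_err());

    Ok(())
}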