Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 21 additions & 20 deletions encodings/dict/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
use std::fmt::Debug;

use arrow_buffer::BooleanBuffer;
use vortex_array::compute::{cast, take};
use vortex_array::compute::take;
use vortex_array::stats::{ArrayStats, StatsSetRef};
use vortex_array::vtable::{ArrayVTable, CanonicalVTable, NotSupported, VTable, ValidityVTable};
use vortex_array::{
Array, ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, ToCanonical, vtable,
};
use vortex_dtype::{DType, match_each_integer_ptype};
use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_ensure};
use vortex_error::{VortexExpect as _, VortexResult, vortex_bail};
use vortex_mask::{AllOr, Mask};

vtable!(Dict);
Expand Down Expand Up @@ -83,28 +83,11 @@ impl DictArray {
/// of the `values` array. Otherwise, this constructor returns an error.
///
/// Nullable `codes` are accepted even when `values` is non-nullable.
pub fn try_new(mut codes: ArrayRef, values: ArrayRef) -> VortexResult<Self> {
pub fn try_new(codes: ArrayRef, values: ArrayRef) -> VortexResult<Self> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add a unit test that checks the case where codes and values have different nullability?

if !codes.dtype().is_unsigned_int() {
vortex_bail!(MismatchedTypes: "unsigned int", codes.dtype());
}

let dtype = values.dtype();
if dtype.is_nullable() {
// If the values are nullable, we force codes to be nullable as well.
codes = cast(&codes, &codes.dtype().as_nullable())?;
} else {
// If the values are non-nullable, we assert the codes are non-nullable as well.
vortex_ensure!(
!codes.dtype().is_nullable(),
"Cannot have nullable codes for non-nullable dict array"
);
}

vortex_ensure!(
codes.dtype().nullability() == values.dtype().nullability(),
"Mismatched nullability between codes and values"
);

Ok(Self {
codes,
values,
Expand Down Expand Up @@ -289,6 +272,24 @@ mod test {
assert_eq!(indices, [2, 4]);
}

/// A dictionary built from nullable codes and non-nullable values must construct
/// successfully, and its validity mask must report exactly the valid code positions.
#[test]
fn nullable_codes_and_non_null_values() {
    // Codes at positions 1 and 3 are marked invalid.
    let validity = Validity::from(BooleanBuffer::from(vec![true, false, true, false, true]));
    let codes = PrimitiveArray::new(buffer![0u32, 1, 2, 2, 1], validity).into_array();
    let values = PrimitiveArray::new(buffer![3, 6, 9], Validity::NonNullable).into_array();

    let dict = DictArray::try_new(codes, values).unwrap();

    let mask = dict.validity_mask();
    match mask.indices() {
        AllOr::Some(indices) => assert_eq!(indices, [0, 2, 4]),
        _ => vortex_panic!("Expected indices from mask"),
    }
}

fn make_dict_primitive_chunks<T: NativePType, U: NativePType>(
len: usize,
unique_values: usize,
Expand Down
15 changes: 13 additions & 2 deletions encodings/dict/src/serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use vortex_array::{
Array, ArrayBufferVisitor, ArrayChildVisitor, Canonical, DeserializeMetadata, ProstMetadata,
};
use vortex_buffer::ByteBuffer;
use vortex_dtype::{DType, PType};
use vortex_dtype::{DType, Nullability, PType};
use vortex_error::{VortexResult, vortex_bail, vortex_err};

use crate::builders::dict_encode;
Expand All @@ -19,6 +19,9 @@ pub struct DictMetadata {
values_len: u32,
#[prost(enumeration = "PType", tag = "2")]
codes_ptype: i32,
// `is_nullable_codes` is optional because the field was added after stabilisation.
#[prost(optional, bool, tag = "3")]
is_nullable_codes: Option<bool>,
}

impl SerdeVTable<DictVTable> for DictVTable {
Expand All @@ -33,6 +36,7 @@ impl SerdeVTable<DictVTable> for DictVTable {
array.values().len()
)
})?,
is_nullable_codes: Some(array.codes().dtype().is_nullable()),
})))
}

Expand All @@ -50,7 +54,13 @@ impl SerdeVTable<DictVTable> for DictVTable {
children.len()
)
}
let codes_dtype = DType::Primitive(metadata.codes_ptype(), dtype.nullability());
let codes_nullable: Nullability = metadata
.is_nullable_codes
// The old behaviour (without `is_nullable_codes` metadata) used the nullability
// of the values (and whole array).
.unwrap_or_else(|| dtype.is_nullable())
.into();
let codes_dtype = DType::Primitive(metadata.codes_ptype(), codes_nullable);
let codes = children.get(0, &codes_dtype, len)?;
let values = children.get(1, dtype, metadata.values_len as usize)?;

Expand Down Expand Up @@ -93,6 +103,7 @@ mod test {
ProstMetadata(DictMetadata {
codes_ptype: PType::U64 as i32,
values_len: u32::MAX,
is_nullable_codes: None,
}),
);
}
Expand Down
Loading