diff --git a/encodings/dict/src/array.rs b/encodings/dict/src/array.rs index 193c8bfc539..718216bb8ee 100644 --- a/encodings/dict/src/array.rs +++ b/encodings/dict/src/array.rs @@ -4,14 +4,14 @@ use std::fmt::Debug; use arrow_buffer::BooleanBuffer; -use vortex_array::compute::{cast, take}; +use vortex_array::compute::take; use vortex_array::stats::{ArrayStats, StatsSetRef}; use vortex_array::vtable::{ArrayVTable, CanonicalVTable, NotSupported, VTable, ValidityVTable}; use vortex_array::{ Array, ArrayRef, Canonical, EncodingId, EncodingRef, IntoArray, ToCanonical, vtable, }; use vortex_dtype::{DType, match_each_integer_ptype}; -use vortex_error::{VortexExpect as _, VortexResult, vortex_bail, vortex_ensure}; +use vortex_error::{VortexExpect as _, VortexResult, vortex_bail}; use vortex_mask::{AllOr, Mask}; vtable!(Dict); @@ -83,28 +83,11 @@ impl DictArray { /// of the `values` array. Otherwise, this constructor returns an error. /// /// It is an error to provide a nullable `codes` with non-nullable `values`. - pub fn try_new(mut codes: ArrayRef, values: ArrayRef) -> VortexResult { + pub fn try_new(codes: ArrayRef, values: ArrayRef) -> VortexResult { if !codes.dtype().is_unsigned_int() { vortex_bail!(MismatchedTypes: "unsigned int", codes.dtype()); } - let dtype = values.dtype(); - if dtype.is_nullable() { - // If the values are nullable, we force codes to be nullable as well. - codes = cast(&codes, &codes.dtype().as_nullable())?; - } else { - // If the values are non-nullable, we assert the codes are non-nullable as well. - vortex_ensure!( - !codes.dtype().is_nullable(), - "Cannot have nullable codes for non-nullable dict array" - ); - } - - vortex_ensure!( - codes.dtype().nullability() == values.dtype().nullability(), - "Mismatched nullability between codes and values" - ); - Ok(Self { codes, values, @@ -289,6 +272,24 @@ mod test { assert_eq!(indices, [2, 4]); } + #[test] + fn nullable_codes_and_non_null_values() { + let dict = DictArray::try_new( + PrimitiveArray::new( + buffer![0u32, 1, 2, 2, 1], + Validity::from(BooleanBuffer::from(vec![true, false, true, false, true])), + ) + .into_array(), + PrimitiveArray::new(buffer![3, 6, 9], Validity::NonNullable).into_array(), + ) + .unwrap(); + let mask = dict.validity_mask(); + let AllOr::Some(indices) = mask.indices() else { + vortex_panic!("Expected indices from mask") + }; + assert_eq!(indices, [0, 2, 4]); + } + fn make_dict_primitive_chunks( len: usize, unique_values: usize, diff --git a/encodings/dict/src/serde.rs b/encodings/dict/src/serde.rs index fb7e05bbefd..eb33131fcfa 100644 --- a/encodings/dict/src/serde.rs +++ b/encodings/dict/src/serde.rs @@ -7,7 +7,7 @@ use vortex_array::{ Array, ArrayBufferVisitor, ArrayChildVisitor, Canonical, DeserializeMetadata, ProstMetadata, }; use vortex_buffer::ByteBuffer; -use vortex_dtype::{DType, PType}; +use vortex_dtype::{DType, Nullability, PType}; use vortex_error::{VortexResult, vortex_bail, vortex_err}; use crate::builders::dict_encode; @@ -19,6 +19,9 @@ pub struct DictMetadata { values_len: u32, #[prost(enumeration = "PType", tag = "2")] codes_ptype: i32, + // nullable codes are optional since they were added after stabilisation + #[prost(optional, bool, tag = "3")] + is_nullable_codes: Option, } impl SerdeVTable for DictVTable { @@ -33,6 +36,7 @@ impl SerdeVTable for DictVTable { array.values().len() ) })?, + is_nullable_codes: Some(array.codes().dtype().is_nullable()), }))) } @@ -50,7 +54,13 @@ impl SerdeVTable for DictVTable { children.len() ) } - let codes_dtype = DType::Primitive(metadata.codes_ptype(), dtype.nullability()); + let codes_nullable: Nullability = metadata + .is_nullable_codes + // The old behaviour of (without `is_nullable_codes` metadata) used the nullability + // of the values (and whole array). + .unwrap_or_else(|| dtype.is_nullable()) + .into(); + let codes_dtype = DType::Primitive(metadata.codes_ptype(), codes_nullable); let codes = children.get(0, &codes_dtype, len)?; let values = children.get(1, dtype, metadata.values_len as usize)?; @@ -93,6 +103,7 @@ mod test { ProstMetadata(DictMetadata { codes_ptype: PType::U64 as i32, values_len: u32::MAX, + is_nullable_codes: None, }), ); }