diff --git a/arrow-array/src/array/fixed_size_list_array.rs b/arrow-array/src/array/fixed_size_list_array.rs index 44be442c9f85..af814cc61414 100644 --- a/arrow-array/src/array/fixed_size_list_array.rs +++ b/arrow-array/src/array/fixed_size_list_array.rs @@ -343,8 +343,8 @@ impl From for FixedSizeListArray { fn from(data: ArrayData) -> Self { let value_length = match data.data_type() { DataType::FixedSizeList(_, len) => *len, - _ => { - panic!("FixedSizeListArray data should contain a FixedSizeList data type") + data_type => { + panic!("FixedSizeListArray data should contain a FixedSizeList data type, got {data_type:?}") } }; diff --git a/arrow-row/src/lib.rs b/arrow-row/src/lib.rs index 7f8d2cd97cbe..81320420dbe5 100644 --- a/arrow-row/src/lib.rs +++ b/arrow-row/src/lib.rs @@ -144,6 +144,7 @@ use arrow_schema::*; use variable::{decode_binary_view, decode_string_view}; use crate::fixed::{decode_bool, decode_fixed_size_binary, decode_primitive}; +use crate::list::{compute_lengths_fixed_size_list, encode_fixed_size_list}; use crate::variable::{decode_binary, decode_string}; use arrow_array::types::{Int16Type, Int32Type, Int64Type}; @@ -346,6 +347,46 @@ mod variable; /// /// With `[]` represented by an empty byte array, and `null` a null byte array. /// +/// ## Fixed Size List Encoding +/// +/// Fixed Size Lists are encoded by first encoding all child elements to the row format. +/// +/// A non-null list value is then encoded as 0x01 followed by the concatenation of each +/// of the child elements. A null list value is encoded as a null marker. +/// +/// For example given: +/// +/// ```text +/// [1_u8, 2_u8] +/// [3_u8, null] +/// null +/// ``` +/// +/// The elements would be converted to: +/// +/// ```text +/// ┌──┬──┐ ┌──┬──┐ ┌──┬──┐ ┌──┬──┐ +/// 1 │01│01│ 2 │01│02│ 3 │01│03│ null │00│00│ +/// └──┴──┘ └──┴──┘ └──┴──┘ └──┴──┘ +///``` +/// +/// Which would be encoded as +/// +/// ```text +/// ┌──┬──┬──┬──┬──┐ +/// [1_u8, 2_u8] │01│01│01│01│02│ +/// └──┴──┴──┴──┴──┘ +/// └ 1 ┘ └ 2 ┘ +/// ┌──┬──┬──┬──┬──┐ +/// [3_u8, null] │01│01│03│00│00│ +/// └──┴──┴──┴──┴──┘ +/// └ 1 ┘ └null┘ +/// ┌──┐ +/// null │00│ +/// └──┘ +/// +///``` +/// /// # Ordering /// /// ## Float Ordering @@ -433,6 +474,11 @@ impl Codec { let converter = RowConverter::new(vec![field])?; Ok(Self::List(converter)) } + DataType::FixedSizeList(f, _) => { + let field = SortField::new_with_options(f.data_type().clone(), sort_field.options); + let converter = RowConverter::new(vec![field])?; + Ok(Self::List(converter)) + } DataType::Struct(f) => { let sort_fields = f .iter() @@ -474,6 +520,7 @@ impl Codec { let values = match array.data_type() { DataType::List(_) => as_list_array(array).values(), DataType::LargeList(_) => as_large_list_array(array).values(), + DataType::FixedSizeList(_, _) => as_fixed_size_list_array(array).values(), _ => unreachable!(), }; let rows = converter.convert_columns(&[values.clone()])?; @@ -576,9 +623,10 @@ impl RowConverter { fn supports_datatype(d: &DataType) -> bool { match d { _ if !d.is_nested() => true, - DataType::List(f) | DataType::LargeList(f) | DataType::Map(f, _) => { - Self::supports_datatype(f.data_type()) - } + DataType::List(f) + | DataType::LargeList(f) + | DataType::FixedSizeList(f, _) + | DataType::Map(f, _) => Self::supports_datatype(f.data_type()), DataType::Struct(f) => f.iter().all(|x| Self::supports_datatype(x.data_type())), DataType::RunEndEncoded(_, values) => Self::supports_datatype(values.data_type()), _ => false, @@ -1365,6 +1413,11 @@ fn row_lengths(cols: &[ArrayRef], encoders: &[Encoder]) -> LengthTracker { DataType::LargeList(_) => { list::compute_lengths(tracker.materialized(), rows, as_large_list_array(array)) } + DataType::FixedSizeList(_, _) => compute_lengths_fixed_size_list( + &mut tracker, + rows, + as_fixed_size_list_array(array), + ), _ => unreachable!(), }, Encoder::RunEndEncoded(rows) => match array.data_type() { @@ -1482,6 +1535,9 @@ fn encode_column( DataType::LargeList(_) => { list::encode(data, offsets, rows, opts, as_large_list_array(column)) } + DataType::FixedSizeList(_, _) => { + encode_fixed_size_list(data, offsets, rows, opts, as_fixed_size_list_array(column)) + } _ => unreachable!(), }, Encoder::RunEndEncoded(rows) => match column.data_type() { @@ -1582,6 +1638,13 @@ unsafe fn decode_column( DataType::LargeList(_) => { Arc::new(list::decode::(converter, rows, field, validate_utf8)?) } + DataType::FixedSizeList(_, value_length) => Arc::new(list::decode_fixed_size_list( + converter, + rows, + field, + validate_utf8, + value_length.as_usize(), + )?), _ => unreachable!(), }, Codec::RunEndEncoded(converter) => match &field.data_type { @@ -2197,6 +2260,9 @@ mod tests { builder.values().append_null(); builder.append(true); builder.append(true); + builder.values().append_value(17); // MASKED + builder.values().append_null(); // MASKED + builder.append(false); let list = Arc::new(builder.finish()) as ArrayRef; let d = list.data_type().clone(); @@ -2205,11 +2271,12 @@ mod tests { let rows = converter.convert_columns(&[Arc::clone(&list)]).unwrap(); assert!(rows.row(0) > rows.row(1)); // [32, 52, 32] > [32, 52, 12] - assert!(rows.row(2) < rows.row(1)); // [32, 42] < [32, 52, 12] - assert!(rows.row(3) < rows.row(2)); // null < [32, 42] - assert!(rows.row(4) < rows.row(2)); // [32, null] < [32, 42] - assert!(rows.row(5) < rows.row(2)); // [] < [32, 42] + assert!(rows.row(2) < rows.row(1)); // [32, 52] < [32, 52, 12] + assert!(rows.row(3) < rows.row(2)); // null < [32, 52] + assert!(rows.row(4) < rows.row(2)); // [32, null] < [32, 52] + assert!(rows.row(5) < rows.row(2)); // [] < [32, 52] assert!(rows.row(3) < rows.row(5)); // null < [] + assert_eq!(rows.row(3), rows.row(6)); // null = null (different masked values) let back = converter.convert_rows(&rows).unwrap(); assert_eq!(back.len(), 1); @@ -2222,11 +2289,12 @@ mod tests { let rows = converter.convert_columns(&[Arc::clone(&list)]).unwrap(); assert!(rows.row(0) > rows.row(1)); // [32, 52, 32] > [32, 52, 12] - assert!(rows.row(2) < rows.row(1)); // [32, 42] < [32, 52, 12] - assert!(rows.row(3) > rows.row(2)); // null > [32, 42] - assert!(rows.row(4) > rows.row(2)); // [32, null] > [32, 42] - assert!(rows.row(5) < rows.row(2)); // [] < [32, 42] + assert!(rows.row(2) < rows.row(1)); // [32, 52] < [32, 52, 12] + assert!(rows.row(3) > rows.row(2)); // null > [32, 52] + assert!(rows.row(4) > rows.row(2)); // [32, null] > [32, 52] + assert!(rows.row(5) < rows.row(2)); // [] < [32, 52] assert!(rows.row(3) > rows.row(5)); // null > [] + assert_eq!(rows.row(3), rows.row(6)); // null = null (different masked values) let back = converter.convert_rows(&rows).unwrap(); assert_eq!(back.len(), 1); @@ -2239,11 +2307,12 @@ mod tests { let rows = converter.convert_columns(&[Arc::clone(&list)]).unwrap(); assert!(rows.row(0) < rows.row(1)); // [32, 52, 32] < [32, 52, 12] - assert!(rows.row(2) > rows.row(1)); // [32, 42] > [32, 52, 12] - assert!(rows.row(3) > rows.row(2)); // null > [32, 42] - assert!(rows.row(4) > rows.row(2)); // [32, null] > [32, 42] - assert!(rows.row(5) > rows.row(2)); // [] > [32, 42] + assert!(rows.row(2) > rows.row(1)); // [32, 52] > [32, 52, 12] + assert!(rows.row(3) > rows.row(2)); // null > [32, 52] + assert!(rows.row(4) > rows.row(2)); // [32, null] > [32, 52] + assert!(rows.row(5) > rows.row(2)); // [] > [32, 52] assert!(rows.row(3) > rows.row(5)); // null > [] + assert_eq!(rows.row(3), rows.row(6)); // null = null (different masked values) let back = converter.convert_rows(&rows).unwrap(); assert_eq!(back.len(), 1); @@ -2256,11 +2325,12 @@ mod tests { let rows = converter.convert_columns(&[Arc::clone(&list)]).unwrap(); assert!(rows.row(0) < rows.row(1)); // [32, 52, 32] < [32, 52, 12] - assert!(rows.row(2) > rows.row(1)); // [32, 42] > [32, 52, 12] - assert!(rows.row(3) < rows.row(2)); // null < [32, 42] - assert!(rows.row(4) < rows.row(2)); // [32, null] < [32, 42] - assert!(rows.row(5) > rows.row(2)); // [] > [32, 42] + assert!(rows.row(2) > rows.row(1)); // [32, 52] > [32, 52, 12] + assert!(rows.row(3) < rows.row(2)); // null < [32, 52] + assert!(rows.row(4) < rows.row(2)); // [32, null] < [32, 52] + assert!(rows.row(5) > rows.row(2)); // [] > [32, 52] assert!(rows.row(3) < rows.row(5)); // null < [] + assert_eq!(rows.row(3), rows.row(6)); // null = null (different masked values) let back = converter.convert_rows(&rows).unwrap(); assert_eq!(back.len(), 1); @@ -2371,6 +2441,114 @@ mod tests { test_nested_list::(); } + #[test] + fn test_fixed_size_list() { + let mut builder = FixedSizeListBuilder::new(Int32Builder::new(), 3); + builder.values().append_value(32); + builder.values().append_value(52); + builder.values().append_value(32); + builder.append(true); + builder.values().append_value(32); + builder.values().append_value(52); + builder.values().append_value(12); + builder.append(true); + builder.values().append_value(32); + builder.values().append_value(52); + builder.values().append_null(); + builder.append(true); + builder.values().append_value(32); // MASKED + builder.values().append_value(52); // MASKED + builder.values().append_value(13); // MASKED + builder.append(false); + builder.values().append_value(32); + builder.values().append_null(); + builder.values().append_null(); + builder.append(true); + builder.values().append_null(); + builder.values().append_null(); + builder.values().append_null(); + builder.append(true); + builder.values().append_value(17); // MASKED + builder.values().append_null(); // MASKED + builder.values().append_value(77); // MASKED + builder.append(false); + + let list = Arc::new(builder.finish()) as ArrayRef; + let d = list.data_type().clone(); + + // Default sorting (ascending, nulls first) + let converter = RowConverter::new(vec![SortField::new(d.clone())]).unwrap(); + + let rows = converter.convert_columns(&[Arc::clone(&list)]).unwrap(); + assert!(rows.row(0) > rows.row(1)); // [32, 52, 32] > [32, 52, 12] + assert!(rows.row(2) < rows.row(1)); // [32, 52, null] < [32, 52, 12] + assert!(rows.row(3) < rows.row(2)); // null < [32, 52, null] + assert!(rows.row(4) < rows.row(2)); // [32, null, null] < [32, 52, null] + assert!(rows.row(5) < rows.row(2)); // [null, null, null] < [32, 52, null] + assert!(rows.row(3) < rows.row(5)); // null < [null, null, null] + assert_eq!(rows.row(3), rows.row(6)); // null = null (different masked values) + + let back = converter.convert_rows(&rows).unwrap(); + assert_eq!(back.len(), 1); + back[0].to_data().validate_full().unwrap(); + assert_eq!(&back[0], &list); + + // Ascending, null last + let options = SortOptions::default().asc().with_nulls_first(false); + let field = SortField::new_with_options(d.clone(), options); + let converter = RowConverter::new(vec![field]).unwrap(); + let rows = converter.convert_columns(&[Arc::clone(&list)]).unwrap(); + assert!(rows.row(0) > rows.row(1)); // [32, 52, 32] > [32, 52, 12] + assert!(rows.row(2) > rows.row(1)); // [32, 52, null] > [32, 52, 12] + assert!(rows.row(3) > rows.row(2)); // null > [32, 52, null] + assert!(rows.row(4) > rows.row(2)); // [32, null, null] > [32, 52, null] + assert!(rows.row(5) > rows.row(2)); // [null, null, null] > [32, 52, null] + assert!(rows.row(3) > rows.row(5)); // null > [null, null, null] + assert_eq!(rows.row(3), rows.row(6)); // null = null (different masked values) + + let back = converter.convert_rows(&rows).unwrap(); + assert_eq!(back.len(), 1); + back[0].to_data().validate_full().unwrap(); + assert_eq!(&back[0], &list); + + // Descending, nulls last + let options = SortOptions::default().desc().with_nulls_first(false); + let field = SortField::new_with_options(d.clone(), options); + let converter = RowConverter::new(vec![field]).unwrap(); + let rows = converter.convert_columns(&[Arc::clone(&list)]).unwrap(); + assert!(rows.row(0) < rows.row(1)); // [32, 52, 32] < [32, 52, 12] + assert!(rows.row(2) > rows.row(1)); // [32, 52, null] > [32, 52, 12] + assert!(rows.row(3) > rows.row(2)); // null > [32, 52, null] + assert!(rows.row(4) > rows.row(2)); // [32, null, null] > [32, 52, null] + assert!(rows.row(5) > rows.row(2)); // [null, null, null] > [32, 52, null] + assert!(rows.row(3) > rows.row(5)); // null > [null, null, null] + assert_eq!(rows.row(3), rows.row(6)); // null = null (different masked values) + + let back = converter.convert_rows(&rows).unwrap(); + assert_eq!(back.len(), 1); + back[0].to_data().validate_full().unwrap(); + assert_eq!(&back[0], &list); + + // Descending, nulls first + let options = SortOptions::default().desc().with_nulls_first(true); + let field = SortField::new_with_options(d, options); + let converter = RowConverter::new(vec![field]).unwrap(); + let rows = converter.convert_columns(&[Arc::clone(&list)]).unwrap(); + + assert!(rows.row(0) < rows.row(1)); // [32, 52, 32] < [32, 52, 12] + assert!(rows.row(2) < rows.row(1)); // [32, 52, null] > [32, 52, 12] + assert!(rows.row(3) < rows.row(2)); // null < [32, 52, null] + assert!(rows.row(4) < rows.row(2)); // [32, null, null] < [32, 52, null] + assert!(rows.row(5) < rows.row(2)); // [null, null, null] > [32, 52, null] + assert!(rows.row(3) < rows.row(5)); // null < [null, null, null] + assert_eq!(rows.row(3), rows.row(6)); // null = null (different masked values) + + let back = converter.convert_rows(&rows).unwrap(); + assert_eq!(back.len(), 1); + back[0].to_data().validate_full().unwrap(); + assert_eq!(&back[0], &list); + } + fn generate_primitive_array(len: usize, valid_percent: f64) -> PrimitiveArray where K: ArrowPrimitiveType, diff --git a/arrow-row/src/list.rs b/arrow-row/src/list.rs index 46cd0f3d3d81..627214dc9c46 100644 --- a/arrow-row/src/list.rs +++ b/arrow-row/src/list.rs @@ -15,11 +15,11 @@ // specific language governing permissions and limitations // under the License. -use crate::{null_sentinel, RowConverter, Rows, SortField}; -use arrow_array::{Array, GenericListArray, OffsetSizeTrait}; -use arrow_buffer::{Buffer, MutableBuffer}; +use crate::{fixed, null_sentinel, LengthTracker, RowConverter, Rows, SortField}; +use arrow_array::{new_null_array, Array, FixedSizeListArray, GenericListArray, OffsetSizeTrait}; +use arrow_buffer::{ArrowNativeType, Buffer, MutableBuffer}; use arrow_data::ArrayDataBuilder; -use arrow_schema::{ArrowError, SortOptions}; +use arrow_schema::{ArrowError, DataType, SortOptions}; use std::ops::Range; pub fn compute_lengths( @@ -97,7 +97,7 @@ fn encode_one( } } -/// Decodes a string array from `rows` with the provided `options` +/// Decodes an array from `rows` with the provided `options` /// /// # Safety /// @@ -184,3 +184,123 @@ pub unsafe fn decode( Ok(GenericListArray::from(unsafe { builder.build_unchecked() })) } + +pub fn compute_lengths_fixed_size_list( + tracker: &mut LengthTracker, + rows: &Rows, + array: &FixedSizeListArray, +) { + let value_length = array.value_length().as_usize(); + tracker.push_variable((0..array.len()).map(|idx| { + match array.is_valid(idx) { + true => { + 1 + ((idx * value_length)..(idx + 1) * value_length) + .map(|child_idx| rows.row(child_idx).as_ref().len()) + .sum::() + } + false => 1, + } + })) +} + +/// Encodes the provided `FixedSizeListArray` to `out` with the provided `SortOptions` +/// +/// `rows` should contain the encoded child elements +pub fn encode_fixed_size_list( + data: &mut [u8], + offsets: &mut [usize], + rows: &Rows, + opts: SortOptions, + array: &FixedSizeListArray, +) { + let null_sentinel = null_sentinel(opts); + offsets + .iter_mut() + .skip(1) + .enumerate() + .for_each(|(idx, offset)| { + let value_length = array.value_length().as_usize(); + match array.is_valid(idx) { + true => { + data[*offset] = 0x01; + *offset += 1; + for child_idx in (idx * value_length)..(idx + 1) * value_length { + //dbg!(child_idx); + let row = rows.row(child_idx); + let end_offset = *offset + row.as_ref().len(); + data[*offset..end_offset].copy_from_slice(row.as_ref()); + *offset = end_offset; + } + } + false => { + let null_sentinels = 1; + //+ value_length; // 1 for self + for values too + for i in 0..null_sentinels { + data[*offset + i] = null_sentinel; + } + *offset += null_sentinels; + } + }; + }) +} + +/// Decodes a fixed size list array from `rows` with the provided `options` +/// +/// # Safety +/// +/// `rows` must contain valid data for the provided `converter` +pub unsafe fn decode_fixed_size_list( + converter: &RowConverter, + rows: &mut [&[u8]], + field: &SortField, + validate_utf8: bool, + value_length: usize, +) -> Result { + let list_type = &field.data_type; + let element_type = match list_type { + DataType::FixedSizeList(element_field, _) => element_field.data_type(), + _ => { + return Err(ArrowError::InvalidArgumentError(format!( + "Expected FixedSizeListArray, found: {:?}", + list_type + ))) + } + }; + + let len = rows.len(); + let (null_count, nulls) = fixed::decode_nulls(rows); + + let null_element_encoded = converter.convert_columns(&[new_null_array(element_type, 1)])?; + let null_element_encoded = null_element_encoded.row(0); + let null_element_slice = null_element_encoded.as_ref(); + + let mut child_rows = Vec::new(); + for row in rows { + let valid = row[0] == 1; + let mut row_offset = 1; + if !valid { + for _ in 0..value_length { + child_rows.push(null_element_slice); + } + } else { + for _ in 0..value_length { + let mut temp_child_rows = vec![&row[row_offset..]]; + converter.convert_raw(&mut temp_child_rows, validate_utf8)?; + let decoded_bytes = row.len() - row_offset - temp_child_rows[0].len(); + let next_offset = row_offset + decoded_bytes; + child_rows.push(&row[row_offset..next_offset]); + row_offset = next_offset; + } + } + } + + let children = converter.convert_raw(&mut child_rows, validate_utf8)?; + let child_data = children.iter().map(|c| c.to_data()).collect(); + let builder = ArrayDataBuilder::new(list_type.clone()) + .len(len) + .null_count(null_count) + .null_bit_buffer(Some(nulls)) + .child_data(child_data); + + Ok(FixedSizeListArray::from(builder.build_unchecked())) +}