Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions arrow-array/src/array/fixed_size_list_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,8 +343,8 @@ impl From<ArrayData> for FixedSizeListArray {
fn from(data: ArrayData) -> Self {
let value_length = match data.data_type() {
DataType::FixedSizeList(_, len) => *len,
_ => {
panic!("FixedSizeListArray data should contain a FixedSizeList data type")
data_type => {
panic!("FixedSizeListArray data should contain a FixedSizeList data type, got {data_type:?}")
}
};

Expand Down
137 changes: 134 additions & 3 deletions arrow-row/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ use arrow_schema::*;
use variable::{decode_binary_view, decode_string_view};

use crate::fixed::{decode_bool, decode_fixed_size_binary, decode_primitive};
use crate::list::{compute_lengths_fixed_size_list, encode_fixed_size_list};
use crate::variable::{decode_binary, decode_string};

mod fixed;
Expand Down Expand Up @@ -409,6 +410,11 @@ impl Codec {
let converter = RowConverter::new(vec![field])?;
Ok(Self::List(converter))
}
DataType::FixedSizeList(f, _) => {
let field = SortField::new_with_options(f.data_type().clone(), sort_field.options);
let converter = RowConverter::new(vec![field])?;
Ok(Self::List(converter))
}
DataType::Struct(f) => {
let sort_fields = f
.iter()
Expand Down Expand Up @@ -450,6 +456,7 @@ impl Codec {
let values = match array.data_type() {
DataType::List(_) => as_list_array(array).values(),
DataType::LargeList(_) => as_large_list_array(array).values(),
DataType::FixedSizeList(_, _) => as_fixed_size_list_array(array).values(),
_ => unreachable!(),
};
let rows = converter.convert_columns(&[values.clone()])?;
Expand Down Expand Up @@ -536,9 +543,10 @@ impl RowConverter {
fn supports_datatype(d: &DataType) -> bool {
match d {
_ if !d.is_nested() => true,
DataType::List(f) | DataType::LargeList(f) | DataType::Map(f, _) => {
Self::supports_datatype(f.data_type())
}
DataType::List(f)
| DataType::LargeList(f)
| DataType::FixedSizeList(f, _)
| DataType::Map(f, _) => Self::supports_datatype(f.data_type()),
DataType::Struct(f) => f.iter().all(|x| Self::supports_datatype(x.data_type())),
_ => false,
}
Expand Down Expand Up @@ -1244,6 +1252,11 @@ fn row_lengths(cols: &[ArrayRef], encoders: &[Encoder]) -> Vec<usize> {
DataType::LargeList(_) => {
list::compute_lengths(&mut lengths, rows, as_large_list_array(array))
}
DataType::FixedSizeList(_, _) => compute_lengths_fixed_size_list(
&mut lengths,
rows,
as_fixed_size_list_array(array),
),
_ => unreachable!(),
},
}
Expand Down Expand Up @@ -1340,6 +1353,9 @@ fn encode_column(
DataType::LargeList(_) => {
list::encode(data, offsets, rows, opts, as_large_list_array(column))
}
DataType::FixedSizeList(_, _) => {
encode_fixed_size_list(data, offsets, rows, opts, as_fixed_size_list_array(column))
}
_ => unreachable!(),
},
}
Expand Down Expand Up @@ -1425,6 +1441,13 @@ unsafe fn decode_column(
DataType::LargeList(_) => {
Arc::new(list::decode::<i64>(converter, rows, field, validate_utf8)?)
}
DataType::FixedSizeList(_, value_length) => Arc::new(list::decode_fixed_size_list(
converter,
rows,
field,
validate_utf8,
value_length.as_usize(),
)?),
_ => unreachable!(),
},
};
Expand Down Expand Up @@ -2190,6 +2213,114 @@ mod tests {
test_nested_list::<i64>();
}

// Round-trips a `FixedSizeList<Int32>` column (list size 3) through `RowConverter`
// under all four combinations of ascending/descending and nulls-first/nulls-last,
// checking both the relative ordering of the encoded rows and that decoding
// reproduces the original array.
//
// Row layout built below:
//   0: [32, 52, 32]
//   1: [32, 52, 12]
//   2: [32, 52, null]
//   3: null               (child values 32, 52, 13 are masked by the null list)
//   4: [32, null, null]
//   5: [null, null, null]
//   6: null               (child values 17, null, 77 are masked by the null list)
#[test]
fn test_fixed_size_list() {
let mut builder = FixedSizeListBuilder::new(Int32Builder::new(), 3);
builder.values().append_value(32);
builder.values().append_value(52);
builder.values().append_value(32);
builder.append(true);
builder.values().append_value(32);
builder.values().append_value(52);
builder.values().append_value(12);
builder.append(true);
builder.values().append_value(32);
builder.values().append_value(52);
builder.values().append_null();
builder.append(true);
builder.values().append_value(32); // MASKED
builder.values().append_value(52); // MASKED
builder.values().append_value(13); // MASKED
builder.append(false);
builder.values().append_value(32);
builder.values().append_null();
builder.values().append_null();
builder.append(true);
builder.values().append_null();
builder.values().append_null();
builder.values().append_null();
builder.append(true);
builder.values().append_value(17); // MASKED
builder.values().append_null(); // MASKED
builder.values().append_value(77); // MASKED
builder.append(false);

let list = Arc::new(builder.finish()) as ArrayRef;
let d = list.data_type().clone();

// Default sorting (ascending, nulls first)
let converter = RowConverter::new(vec![SortField::new(d.clone())]).unwrap();

let rows = converter.convert_columns(&[Arc::clone(&list)]).unwrap();
assert!(rows.row(0) > rows.row(1)); // [32, 52, 32] > [32, 52, 12]
assert!(rows.row(2) < rows.row(1)); // [32, 52, null] < [32, 52, 12]
assert!(rows.row(3) < rows.row(2)); // null < [32, 52, null]
assert!(rows.row(4) < rows.row(2)); // [32, null, null] < [32, 52, null]
assert!(rows.row(5) < rows.row(2)); // [null, null, null] < [32, 52, null]
assert!(rows.row(3) < rows.row(5)); // null < [null, null, null]
assert_eq!(rows.row(3), rows.row(6)); // null = null (different masked values)

// Decoding must reproduce the input, and the result must be fully valid array data
let back = converter.convert_rows(&rows).unwrap();
assert_eq!(back.len(), 1);
back[0].to_data().validate_full().unwrap();
assert_eq!(&back[0], &list);

// Ascending, nulls last
let options = SortOptions::default().asc().with_nulls_first(false);
let field = SortField::new_with_options(d.clone(), options);
let converter = RowConverter::new(vec![field]).unwrap();
let rows = converter.convert_columns(&[Arc::clone(&list)]).unwrap();
assert!(rows.row(0) > rows.row(1)); // [32, 52, 32] > [32, 52, 12]
assert!(rows.row(2) > rows.row(1)); // [32, 52, null] > [32, 52, 12]
assert!(rows.row(3) > rows.row(2)); // null > [32, 52, null]
assert!(rows.row(4) > rows.row(2)); // [32, null, null] > [32, 52, null]
assert!(rows.row(5) > rows.row(2)); // [null, null, null] > [32, 52, null]
assert!(rows.row(3) > rows.row(5)); // null > [null, null, null]
assert_eq!(rows.row(3), rows.row(6)); // null = null (different masked values)

let back = converter.convert_rows(&rows).unwrap();
assert_eq!(back.len(), 1);
back[0].to_data().validate_full().unwrap();
assert_eq!(&back[0], &list);

// Descending, nulls last
let options = SortOptions::default().desc().with_nulls_first(false);
let field = SortField::new_with_options(d.clone(), options);
let converter = RowConverter::new(vec![field]).unwrap();
let rows = converter.convert_columns(&[Arc::clone(&list)]).unwrap();
assert!(rows.row(0) < rows.row(1)); // [32, 52, 32] < [32, 52, 12]
assert!(rows.row(2) > rows.row(1)); // [32, 52, null] > [32, 52, 12]
assert!(rows.row(3) > rows.row(2)); // null > [32, 52, null]
assert!(rows.row(4) > rows.row(2)); // [32, null, null] > [32, 52, null]
assert!(rows.row(5) > rows.row(2)); // [null, null, null] > [32, 52, null]
assert!(rows.row(3) > rows.row(5)); // null > [null, null, null]
assert_eq!(rows.row(3), rows.row(6)); // null = null (different masked values)

let back = converter.convert_rows(&rows).unwrap();
assert_eq!(back.len(), 1);
back[0].to_data().validate_full().unwrap();
assert_eq!(&back[0], &list);

// Descending, nulls first
let options = SortOptions::default().desc().with_nulls_first(true);
let field = SortField::new_with_options(d, options);
let converter = RowConverter::new(vec![field]).unwrap();
let rows = converter.convert_columns(&[Arc::clone(&list)]).unwrap();

assert!(rows.row(0) < rows.row(1)); // [32, 52, 32] < [32, 52, 12]
assert!(rows.row(2) < rows.row(1)); // [32, 52, null] < [32, 52, 12]
assert!(rows.row(3) < rows.row(2)); // null < [32, 52, null]
assert!(rows.row(4) < rows.row(2)); // [32, null, null] < [32, 52, null]
assert!(rows.row(5) < rows.row(2)); // [null, null, null] < [32, 52, null]
assert!(rows.row(3) < rows.row(5)); // null < [null, null, null]
assert_eq!(rows.row(3), rows.row(6)); // null = null (different masked values)

let back = converter.convert_rows(&rows).unwrap();
assert_eq!(back.len(), 1);
back[0].to_data().validate_full().unwrap();
assert_eq!(&back[0], &list);
}

fn generate_primitive_array<K>(len: usize, valid_percent: f64) -> PrimitiveArray<K>
where
K: ArrowPrimitiveType,
Expand Down
130 changes: 125 additions & 5 deletions arrow-row/src/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
// specific language governing permissions and limitations
// under the License.

use crate::{null_sentinel, RowConverter, Rows, SortField};
use arrow_array::{Array, GenericListArray, OffsetSizeTrait};
use arrow_buffer::{Buffer, MutableBuffer};
use crate::{fixed, null_sentinel, RowConverter, Rows, SortField};
use arrow_array::{new_null_array, Array, FixedSizeListArray, GenericListArray, OffsetSizeTrait};
use arrow_buffer::{ArrowNativeType, Buffer, MutableBuffer};
use arrow_data::ArrayDataBuilder;
use arrow_schema::{ArrowError, SortOptions};
use arrow_schema::{ArrowError, DataType, SortOptions};
use std::ops::Range;

pub fn compute_lengths<O: OffsetSizeTrait>(
Expand Down Expand Up @@ -97,7 +97,7 @@ fn encode_one(
}
}

/// Decodes a string array from `rows` with the provided `options`
/// Decodes an array from `rows` with the provided `options`
///
/// # Safety
///
Expand Down Expand Up @@ -184,3 +184,123 @@ pub unsafe fn decode<O: OffsetSizeTrait>(

Ok(GenericListArray::from(unsafe { builder.build_unchecked() }))
}

/// Adds the encoded length of every list in `array` to the matching entry of `lengths`
///
/// `rows` should contain the encoded child elements, `array.value_length()` per list
/// and in list order.
///
/// A valid list contributes one marker byte plus the encoded length of each of its
/// children; a null list contributes only the single sentinel byte. These sizes match
/// what `encode_fixed_size_list` writes.
///
/// The contribution is *added* to `lengths[idx]` rather than assigned, so bytes
/// already accounted for by other columns of the same row are preserved.
/// NOTE(review): this assumes `lengths` is a running per-row total shared by all
/// columns and zero-initialised by the caller (`row_lengths`) - confirm against it.
pub fn compute_lengths_fixed_size_list(
    lengths: &mut [usize],
    rows: &Rows,
    array: &FixedSizeListArray,
) {
    let value_length = array.value_length().as_usize();
    lengths.iter_mut().enumerate().for_each(|(idx, length)| {
        let encoded_len = if array.is_valid(idx) {
            // Marker byte followed by the `value_length` encoded children of this list
            1 + ((idx * value_length)..(idx + 1) * value_length)
                .map(|child_idx| rows.row(child_idx).as_ref().len())
                .sum::<usize>()
        } else {
            // A null list is encoded as the null sentinel only
            1
        };
        *length += encoded_len;
    })
}

/// Encodes the provided `FixedSizeListArray` to `data` with the provided `SortOptions`
///
/// `rows` should contain the encoded child elements
///
/// For the list at index `idx` the write position is `offsets[idx + 1]`, which is
/// advanced past the bytes written:
///
/// - a valid list is written as a `0x01` marker byte followed by the already
///   encoded bytes of its `value_length` children, copied verbatim from `rows`
/// - a null list is written as the single null sentinel byte for `opts`
///
/// # Panics
///
/// Panics if `data` is too short for the bytes being written, or if `rows` holds
/// fewer child rows than the valid lists require.
pub fn encode_fixed_size_list(
    data: &mut [u8],
    offsets: &mut [usize],
    rows: &Rows,
    opts: SortOptions,
    array: &FixedSizeListArray,
) {
    let null_sentinel = null_sentinel(opts);
    // The list width is the same for every element, so compute it once up front
    let value_length = array.value_length().as_usize();
    offsets
        .iter_mut()
        .skip(1)
        .enumerate()
        .for_each(|(idx, offset)| {
            if array.is_valid(idx) {
                data[*offset] = 0x01;
                *offset += 1;
                for child_idx in (idx * value_length)..(idx + 1) * value_length {
                    let row = rows.row(child_idx);
                    let end_offset = *offset + row.as_ref().len();
                    data[*offset..end_offset].copy_from_slice(row.as_ref());
                    *offset = end_offset;
                }
            } else {
                // Null lists carry no child bytes, only the sentinel
                data[*offset] = null_sentinel;
                *offset += 1;
            }
        })
}

/// Decodes a fixed size list array from `rows` with the provided `options`
///
/// `converter` is the converter for the list's child element type, `field` is the
/// sort field of the list column itself and `value_length` is the number of child
/// elements per list.
///
/// Every entry of `rows` is advanced past the bytes consumed for this column, so the
/// slices are left positioned at the start of whatever follows the list in each row.
///
/// # Errors
///
/// Returns `ArrowError::InvalidArgumentError` if `field` is not a `FixedSizeList`,
/// and propagates any error from decoding the child elements.
///
/// # Safety
///
/// `rows` must contain valid data for the provided `converter`
pub unsafe fn decode_fixed_size_list(
    converter: &RowConverter,
    rows: &mut [&[u8]],
    field: &SortField,
    validate_utf8: bool,
    value_length: usize,
) -> Result<FixedSizeListArray, ArrowError> {
    let list_type = &field.data_type;
    let element_type = match list_type {
        DataType::FixedSizeList(element_field, _) => element_field.data_type(),
        _ => {
            return Err(ArrowError::InvalidArgumentError(format!(
                "Expected FixedSizeListArray, found: {:?}",
                list_type
            )))
        }
    };

    let len = rows.len();
    let (null_count, nulls) = fixed::decode_nulls(rows);

    // A null list stores no child bytes, but the child array still needs
    // `value_length` slots for it; those are filled with an encoded null element.
    let null_element_encoded = converter.convert_columns(&[new_null_array(element_type, 1)])?;
    let null_element_encoded = null_element_encoded.row(0);
    let null_element_slice = null_element_encoded.as_ref();

    let mut child_rows = Vec::with_capacity(len * value_length);
    for row in rows {
        let valid = row[0] == 1;
        // Skip the leading marker / null sentinel byte
        let mut row_offset = 1;
        if !valid {
            for _ in 0..value_length {
                child_rows.push(null_element_slice);
            }
        } else {
            for _ in 0..value_length {
                // Child encodings may be variable length, so decode one element from
                // the remaining bytes purely to learn how many bytes it occupies:
                // `convert_raw` advances the slice it is given past what it consumed.
                let mut temp_child_rows = [&row[row_offset..]];
                converter.convert_raw(&mut temp_child_rows, validate_utf8)?;
                let decoded_bytes = row.len() - row_offset - temp_child_rows[0].len();
                let next_offset = row_offset + decoded_bytes;
                child_rows.push(&row[row_offset..next_offset]);
                row_offset = next_offset;
            }
        }
        // Leave the row positioned after this list for the next column's decoder
        *row = &row[row_offset..];
    }

    let children = converter.convert_raw(&mut child_rows, validate_utf8)?;
    let child_data = children.iter().map(|c| c.to_data()).collect();
    let builder = ArrayDataBuilder::new(list_type.clone())
        .len(len)
        .null_count(null_count)
        .null_bit_buffer(Some(nulls))
        .child_data(child_data);

    Ok(FixedSizeListArray::from(builder.build_unchecked()))
}
Loading