Skip to content

Commit

Permalink
fix: include the list item field schema (#2950)
Browse files Browse the repository at this point in the history
closes #2947
  • Loading branch information
westonpace authored Oct 2, 2024
1 parent 6e042ac commit a7c6bac
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 36 deletions.
34 changes: 34 additions & 0 deletions python/python/tests/test_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,40 @@ def test_list_field_name(tmp_path):
assert round_tripped.schema.field("list_str").type == weird_string_type


def test_field_meta(tmp_path):
schema = pa.schema(
[
pa.field("primitive", pa.int64(), metadata={"foo": "bar"}),
pa.field(
"list",
pa.list_(pa.field("item", pa.int64(), metadata={"list": "yes"})),
metadata={"foo": "baz"},
),
pa.field(
"struct",
pa.struct([pa.field("a", pa.int64(), metadata={"struct": "yes"})]),
metadata={"foo": "qux"},
),
]
)
table = pa.table(
{
"primitive": [1, 2, 3],
"list": [[1, 2], [3, 4], [5, 6]],
"struct": [{"a": 1}, {"a": 2}, {"a": 3}],
},
schema=schema,
)

with LanceFileWriter(str(tmp_path / "foo.lance")) as writer:
writer.write_batch(table)

reader = LanceFileReader(str(tmp_path / "foo.lance"))
round_tripped = reader.read_all().to_table()

assert round_tripped == table


def test_dictionary(tmp_path):
# Basic round trip
dictionary = pa.array(["foo", "bar", "baz"], pa.string())
Expand Down
5 changes: 1 addition & 4 deletions rust/lance-encoding/src/decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -743,7 +743,6 @@ impl FieldDecoderStrategy for CoreFieldDecoderStrategy {
file_buffers: buffers,
positions_and_sizes: &offsets_column.buffer_offsets_and_sizes,
};
let item_field_name = items_field.name().clone();
let (chain, items_scheduler) = chain.new_child(
/*child_idx=*/ 0,
&field.children[0],
Expand Down Expand Up @@ -793,12 +792,10 @@ impl FieldDecoderStrategy for CoreFieldDecoderStrategy {
} else {
DataType::Int64
};
let items_type = items_field.data_type().clone();
let list_scheduler = Ok(Arc::new(ListFieldScheduler::new(
inner,
items_scheduler,
item_field_name.clone(),
items_type,
items_field.clone(),
offset_type,
null_offset_adjustments,
)) as Arc<dyn FieldScheduler>);
Expand Down
51 changes: 19 additions & 32 deletions rust/lance-encoding/src/encodings/logical/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ impl<'a> SchedulingJob for ListFieldSchedulingJob<'a> {
let next_offsets_decoder = next_offsets.decoders.into_iter().next().unwrap().decoder;

let items_scheduler = self.scheduler.items_scheduler.clone();
let items_type = self.scheduler.items_type.clone();
let items_type = self.scheduler.items_field.data_type().clone();
let io = context.io().clone();
let cache = context.cache().clone();

Expand All @@ -463,10 +463,9 @@ impl<'a> SchedulingJob for ListFieldSchedulingJob<'a> {
item_decoder: None,
rows_drained: 0,
rows_loaded: 0,
item_field_name: self.scheduler.item_field_name.clone(),
items_field: self.scheduler.items_field.clone(),
num_rows,
unloaded: Some(indirect_fut),
items_type: self.scheduler.items_type.clone(),
offset_type: self.scheduler.offset_type.clone(),
data_type: self.scheduler.list_type.clone(),
});
Expand Down Expand Up @@ -500,8 +499,7 @@ impl<'a> SchedulingJob for ListFieldSchedulingJob<'a> {
pub struct ListFieldScheduler {
offsets_scheduler: Arc<dyn FieldScheduler>,
items_scheduler: Arc<dyn FieldScheduler>,
item_field_name: String,
items_type: DataType,
items_field: Arc<Field>,
offset_type: DataType,
list_type: DataType,
offset_page_info: Vec<OffsetPageInfo>,
Expand All @@ -522,26 +520,20 @@ impl ListFieldScheduler {
pub fn new(
offsets_scheduler: Arc<dyn FieldScheduler>,
items_scheduler: Arc<dyn FieldScheduler>,
item_field_name: String,
items_type: DataType,
items_field: Arc<Field>,
// Should be int32 or int64
offset_type: DataType,
offset_page_info: Vec<OffsetPageInfo>,
) -> Self {
let list_type = match &offset_type {
DataType::Int32 => {
DataType::List(Arc::new(Field::new("item", items_type.clone(), true)))
}
DataType::Int64 => {
DataType::LargeList(Arc::new(Field::new("item", items_type.clone(), true)))
}
DataType::Int32 => DataType::List(items_field.clone()),
DataType::Int64 => DataType::LargeList(items_field.clone()),
_ => panic!("Unexpected offset type {}", offset_type),
};
Self {
offsets_scheduler,
items_scheduler,
item_field_name,
items_type,
items_field,
offset_type,
offset_page_info,
list_type,
Expand Down Expand Up @@ -594,8 +586,7 @@ struct ListPageDecoder {
num_rows: u64,
rows_drained: u64,
rows_loaded: u64,
item_field_name: String,
items_type: DataType,
items_field: Arc<Field>,
offset_type: DataType,
data_type: DataType,
}
Expand All @@ -605,8 +596,7 @@ struct ListDecodeTask {
validity: BooleanBuffer,
// Will be None if there are no items (all empty / null lists)
items: Option<Box<dyn DecodeArrayTask>>,
item_field_name: String,
items_type: DataType,
items_field: Arc<Field>,
offset_type: DataType,
}

Expand All @@ -620,15 +610,7 @@ impl DecodeArrayTask for ListDecodeTask {
let wrapped_items = items.decode()?;
Result::Ok(wrapped_items.as_struct().column(0).clone())
})
.unwrap_or_else(|| Ok(new_empty_array(&self.items_type)))?;

// TODO: we default to nullable true here, should probably use the nullability given to
// us from the input schema
let item_field = Arc::new(Field::new(
self.item_field_name,
self.items_type.clone(),
true,
));
.unwrap_or_else(|| Ok(new_empty_array(self.items_field.data_type())))?;

// The offsets are already decoded but they need to be shifted back to 0 and cast
// to the appropriate type
Expand All @@ -651,7 +633,10 @@ impl DecodeArrayTask for ListDecodeTask {
let offsets = OffsetBuffer::new(offsets_i32.values().clone());

Ok(Arc::new(ListArray::try_new(
item_field, offsets, items, validity,
self.items_field.clone(),
offsets,
items,
validity,
)?))
}
DataType::Int64 => {
Expand All @@ -660,7 +645,10 @@ impl DecodeArrayTask for ListDecodeTask {
let offsets = OffsetBuffer::new(offsets_i64.values().clone());

Ok(Arc::new(LargeListArray::try_new(
item_field, offsets, items, validity,
self.items_field.clone(),
offsets,
items,
validity,
)?))
}
_ => panic!("ListDecodeTask with data type that is not i32 or i64"),
Expand Down Expand Up @@ -787,9 +775,8 @@ impl LogicalPageDecoder for ListPageDecoder {
task: Box::new(ListDecodeTask {
offsets,
validity,
item_field_name: self.item_field_name.clone(),
items_field: self.items_field.clone(),
items: item_decode,
items_type: self.items_type.clone(),
offset_type: self.offset_type.clone(),
}) as Box<dyn DecodeArrayTask>,
})
Expand Down

0 comments on commit a7c6bac

Please sign in to comment.