Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: include the list item field schema #2950

Merged
Merged 1 commit on Oct 2, 2024.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions python/python/tests/test_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,40 @@ def test_list_field_name(tmp_path):
assert round_tripped.schema.field("list_str").type == weird_string_type


def test_field_meta(tmp_path):
    # Round-trip a table whose schema attaches field-level metadata at three
    # nesting positions — a top-level primitive field, the item field inside a
    # list, and a child field inside a struct — and verify that all of it
    # survives a Lance file write followed by a read.
    fields = [
        pa.field("primitive", pa.int64(), metadata={"foo": "bar"}),
        pa.field(
            "list",
            pa.list_(pa.field("item", pa.int64(), metadata={"list": "yes"})),
            metadata={"foo": "baz"},
        ),
        pa.field(
            "struct",
            pa.struct([pa.field("a", pa.int64(), metadata={"struct": "yes"})]),
            metadata={"foo": "qux"},
        ),
    ]
    data = {
        "primitive": [1, 2, 3],
        "list": [[1, 2], [3, 4], [5, 6]],
        "struct": [{"a": 1}, {"a": 2}, {"a": 3}],
    }
    table = pa.table(data, schema=pa.schema(fields))

    path = str(tmp_path / "foo.lance")
    with LanceFileWriter(path) as writer:
        writer.write_batch(table)

    # Table equality in pyarrow includes schema (and thus field metadata),
    # so this assertion covers both the data and the metadata round-trip.
    round_tripped = LanceFileReader(path).read_all().to_table()
    assert round_tripped == table


def test_dictionary(tmp_path):
# Basic round trip
dictionary = pa.array(["foo", "bar", "baz"], pa.string())
Expand Down
5 changes: 1 addition & 4 deletions rust/lance-encoding/src/decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -743,7 +743,6 @@ impl FieldDecoderStrategy for CoreFieldDecoderStrategy {
file_buffers: buffers,
positions_and_sizes: &offsets_column.buffer_offsets_and_sizes,
};
let item_field_name = items_field.name().clone();
let (chain, items_scheduler) = chain.new_child(
/*child_idx=*/ 0,
&field.children[0],
Expand Down Expand Up @@ -793,12 +792,10 @@ impl FieldDecoderStrategy for CoreFieldDecoderStrategy {
} else {
DataType::Int64
};
let items_type = items_field.data_type().clone();
let list_scheduler = Ok(Arc::new(ListFieldScheduler::new(
inner,
items_scheduler,
item_field_name.clone(),
items_type,
items_field.clone(),
offset_type,
null_offset_adjustments,
)) as Arc<dyn FieldScheduler>);
Expand Down
51 changes: 19 additions & 32 deletions rust/lance-encoding/src/encodings/logical/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ impl<'a> SchedulingJob for ListFieldSchedulingJob<'a> {
let next_offsets_decoder = next_offsets.decoders.into_iter().next().unwrap().decoder;

let items_scheduler = self.scheduler.items_scheduler.clone();
let items_type = self.scheduler.items_type.clone();
let items_type = self.scheduler.items_field.data_type().clone();
let io = context.io().clone();
let cache = context.cache().clone();

Expand All @@ -463,10 +463,9 @@ impl<'a> SchedulingJob for ListFieldSchedulingJob<'a> {
item_decoder: None,
rows_drained: 0,
rows_loaded: 0,
item_field_name: self.scheduler.item_field_name.clone(),
items_field: self.scheduler.items_field.clone(),
num_rows,
unloaded: Some(indirect_fut),
items_type: self.scheduler.items_type.clone(),
offset_type: self.scheduler.offset_type.clone(),
data_type: self.scheduler.list_type.clone(),
});
Expand Down Expand Up @@ -500,8 +499,7 @@ impl<'a> SchedulingJob for ListFieldSchedulingJob<'a> {
pub struct ListFieldScheduler {
offsets_scheduler: Arc<dyn FieldScheduler>,
items_scheduler: Arc<dyn FieldScheduler>,
item_field_name: String,
items_type: DataType,
items_field: Arc<Field>,
offset_type: DataType,
list_type: DataType,
offset_page_info: Vec<OffsetPageInfo>,
Expand All @@ -522,26 +520,20 @@ impl ListFieldScheduler {
pub fn new(
offsets_scheduler: Arc<dyn FieldScheduler>,
items_scheduler: Arc<dyn FieldScheduler>,
item_field_name: String,
items_type: DataType,
items_field: Arc<Field>,
// Should be int32 or int64
offset_type: DataType,
offset_page_info: Vec<OffsetPageInfo>,
) -> Self {
let list_type = match &offset_type {
DataType::Int32 => {
DataType::List(Arc::new(Field::new("item", items_type.clone(), true)))
}
DataType::Int64 => {
DataType::LargeList(Arc::new(Field::new("item", items_type.clone(), true)))
}
DataType::Int32 => DataType::List(items_field.clone()),
DataType::Int64 => DataType::LargeList(items_field.clone()),
_ => panic!("Unexpected offset type {}", offset_type),
};
Self {
offsets_scheduler,
items_scheduler,
item_field_name,
items_type,
items_field,
offset_type,
offset_page_info,
list_type,
Expand Down Expand Up @@ -594,8 +586,7 @@ struct ListPageDecoder {
num_rows: u64,
rows_drained: u64,
rows_loaded: u64,
item_field_name: String,
items_type: DataType,
items_field: Arc<Field>,
offset_type: DataType,
data_type: DataType,
}
Expand All @@ -605,8 +596,7 @@ struct ListDecodeTask {
validity: BooleanBuffer,
// Will be None if there are no items (all empty / null lists)
items: Option<Box<dyn DecodeArrayTask>>,
item_field_name: String,
items_type: DataType,
items_field: Arc<Field>,
offset_type: DataType,
}

Expand All @@ -620,15 +610,7 @@ impl DecodeArrayTask for ListDecodeTask {
let wrapped_items = items.decode()?;
Result::Ok(wrapped_items.as_struct().column(0).clone())
})
.unwrap_or_else(|| Ok(new_empty_array(&self.items_type)))?;

// TODO: we default to nullable true here, should probably use the nullability given to
// us from the input schema
let item_field = Arc::new(Field::new(
self.item_field_name,
self.items_type.clone(),
true,
));
.unwrap_or_else(|| Ok(new_empty_array(self.items_field.data_type())))?;

// The offsets are already decoded but they need to be shifted back to 0 and cast
// to the appropriate type
Expand All @@ -651,7 +633,10 @@ impl DecodeArrayTask for ListDecodeTask {
let offsets = OffsetBuffer::new(offsets_i32.values().clone());

Ok(Arc::new(ListArray::try_new(
item_field, offsets, items, validity,
self.items_field.clone(),
offsets,
items,
validity,
)?))
}
DataType::Int64 => {
Expand All @@ -660,7 +645,10 @@ impl DecodeArrayTask for ListDecodeTask {
let offsets = OffsetBuffer::new(offsets_i64.values().clone());

Ok(Arc::new(LargeListArray::try_new(
item_field, offsets, items, validity,
self.items_field.clone(),
offsets,
items,
validity,
)?))
}
_ => panic!("ListDecodeTask with data type that is not i32 or i64"),
Expand Down Expand Up @@ -787,9 +775,8 @@ impl LogicalPageDecoder for ListPageDecoder {
task: Box::new(ListDecodeTask {
offsets,
validity,
item_field_name: self.item_field_name.clone(),
items_field: self.items_field.clone(),
items: item_decode,
items_type: self.items_type.clone(),
offset_type: self.offset_type.clone(),
}) as Box<dyn DecodeArrayTask>,
})
Expand Down
Loading