feat(ipc): support for reading union arrays through IPC #1140

Merged · 1 commit · Jan 6, 2022
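For orientation, here is a minimal sketch of the roundtrip this change enables. Only the reader changes in this PR; the writer path is untouched, but before this change the reader could not reconstruct union columns. The code mirrors the tests added in the diff below; the `main` wrapper and the particular values are illustrative, not part of the PR.

use std::sync::Arc;
use arrow::array::{Array, ArrayRef, UnionBuilder};
use arrow::datatypes::{Field, Float64Type, Int32Type, Schema};
use arrow::error::Result;
use arrow::{ipc, record_batch::RecordBatch};

fn main() -> Result<()> {
    // Build a dense union column with two child types.
    let mut builder = UnionBuilder::new_dense(3);
    builder.append::<Int32Type>("a", 1)?;
    builder.append::<Float64Type>("b", 3.0)?;
    builder.append::<Int32Type>("a", 4)?;
    let union = builder.build()?;

    let schema = Arc::new(Schema::new(vec![Field::new(
        "union",
        union.data_type().clone(),
        false,
    )]));
    let batch = RecordBatch::try_new(schema, vec![Arc::new(union) as ArrayRef])?;

    // Write the batch to the IPC file format in memory, then read it back.
    let mut buf = Vec::new();
    let mut writer = ipc::writer::FileWriter::try_new(&mut buf, &batch.schema())?;
    writer.write(&batch)?;
    writer.finish()?;
    drop(writer);

    let mut reader = ipc::reader::FileReader::try_new(std::io::Cursor::new(buf))?;
    let read_back = reader.next().unwrap()?;
    assert_eq!(batch.num_rows(), read_back.num_rows());
    Ok(())
}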
124 changes: 108 additions & 16 deletions arrow/src/ipc/reader.rs
@@ -27,7 +27,7 @@ use std::sync::Arc;
use crate::array::*;
use crate::buffer::Buffer;
use crate::compute::cast;
-use crate::datatypes::{DataType, Field, IntervalUnit, Schema, SchemaRef};
+use crate::datatypes::{DataType, Field, IntervalUnit, Schema, SchemaRef, UnionMode};
use crate::error::{ArrowError, Result};
use crate::ipc;
use crate::record_batch::{RecordBatch, RecordBatchReader};
@@ -60,7 +60,7 @@ fn create_array(
dictionaries: &[Option<ArrayRef>],
mut node_index: usize,
mut buffer_index: usize,
-) -> (ArrayRef, usize, usize) {
+) -> Result<(ArrayRef, usize, usize)> {
use DataType::*;
let array = match data_type {
Utf8 | Binary | LargeBinary | LargeUtf8 => {
@@ -105,7 +105,7 @@
dictionaries,
node_index,
buffer_index,
-            );
+            )?;
node_index = triple.1;
buffer_index = triple.2;

@@ -127,7 +127,7 @@
dictionaries,
node_index,
buffer_index,
-            );
+            )?;
node_index = triple.1;
buffer_index = triple.2;

@@ -152,7 +152,7 @@
dictionaries,
node_index,
buffer_index,
-                );
+                )?;
node_index = triple.1;
buffer_index = triple.2;
struct_arrays.push((struct_field.clone(), triple.0));
@@ -184,6 +184,55 @@ fn create_array(
value_array,
)
}
+        Union(fields, mode) => {
+            let union_node = nodes[node_index];
+            node_index += 1;
+
+            let len = union_node.length() as usize;
+
+            let null_buffer: Buffer = read_buffer(&buffers[buffer_index], data);
+            // Type ids are one i8 per slot, so slice to `len` bytes.
+            let type_ids: Buffer =
+                read_buffer(&buffers[buffer_index + 1], data)[..len].into();
+
+            buffer_index += 2;
+
+            // Only dense unions carry a value-offsets buffer: one i32 (4 bytes) per slot.
+            let value_offsets = match mode {
+                UnionMode::Dense => {
+                    let buffer = read_buffer(&buffers[buffer_index], data);
+                    buffer_index += 1;
+                    Some(buffer[..len * 4].into())
+                }
+                UnionMode::Sparse => None,
+            };
+
+            let mut children = vec![];
+
+            for field in fields {
+                let triple = create_array(
+                    nodes,
+                    field.data_type(),
+                    data,
+                    buffers,
+                    dictionaries,
+                    node_index,
+                    buffer_index,
+                )?;
+
+                node_index = triple.1;
+                buffer_index = triple.2;
+
+                children.push((field.clone(), triple.0));
+            }
+
+            let array = UnionArray::try_new(
+                type_ids,
+                value_offsets,
+                children,
+                Some(null_buffer),
+            )?;
+
+            Arc::new(array)
+        }
Null => {
let length = nodes[node_index].length() as usize;
let data = ArrayData::builder(data_type.clone())
@@ -209,7 +258,7 @@ fn create_array(
array
}
};
-    (array, node_index, buffer_index)
+    Ok((array, node_index, buffer_index))
}
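Two notes on the new arm. First, `create_array` now returns `Result` because `UnionArray::try_new` validates its inputs and can fail, which is why the sibling arms gain `?` propagations. Second, the buffer accounting for a union is: one field node, then a validity buffer and a type-ids buffer (one i8 per slot), and for dense unions only, an i32 value-offsets buffer, before recursing into each child. A hedged restatement of that rule — `n_buffers` is a hypothetical helper, not part of this PR, and it simplifies the non-union cases:

use arrow::datatypes::{DataType, Field, UnionMode};

// Hypothetical helper restating the buffer accounting in the Union arm above.
fn n_buffers(dt: &DataType) -> usize {
    match dt {
        DataType::Union(fields, mode) => {
            let own = match mode {
                UnionMode::Dense => 3,  // validity + type ids + value offsets
                UnionMode::Sparse => 2, // validity + type ids
            };
            own + fields.iter().map(|f| n_buffers(f.data_type())).sum::<usize>()
        }
        // Primitive children: validity + values (simplified; the real rules
        // live in create_primitive_array).
        _ => 2,
    }
}

fn main() {
    let fields = vec![
        Field::new("a", DataType::Int32, false),
        Field::new("b", DataType::Float64, false),
    ];
    // 3 buffers for the dense union itself + 2 per primitive child.
    assert_eq!(n_buffers(&DataType::Union(fields, UnionMode::Dense)), 7);
}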

/// Reads the correct number of buffers based on data type and null_count, and creates a
@@ -438,7 +487,7 @@ pub fn read_record_batch(
dictionaries,
node_index,
buffer_index,
-        );
+        )?;
node_index = triple.1;
buffer_index = triple.2;
arrays.push(triple.0);
@@ -1165,6 +1214,19 @@ mod tests {
})
}

+    fn roundtrip_ipc(rb: &RecordBatch) -> RecordBatch {
+        let mut buf = Vec::new();
+        let mut writer =
+            ipc::writer::FileWriter::try_new(&mut buf, &rb.schema()).unwrap();
+        writer.write(rb).unwrap();
+        writer.finish().unwrap();
+        // Release the writer's mutable borrow of `buf` before moving it into the reader.
+        drop(writer);
+
+        let mut reader =
+            ipc::reader::FileReader::try_new(std::io::Cursor::new(buf)).unwrap();
+        reader.next().unwrap().unwrap()
+    }

#[test]
fn test_roundtrip_nested_dict() {
let inner: DictionaryArray<datatypes::Int32Type> =
@@ -1183,18 +1245,48 @@
false,
)]));

-        let batch = RecordBatch::try_new(schema.clone(), vec![struct_array]).unwrap();
+        let batch = RecordBatch::try_new(schema, vec![struct_array]).unwrap();

-        let mut buf = Vec::new();
-        let mut writer = ipc::writer::FileWriter::try_new(&mut buf, &schema).unwrap();
-        writer.write(&batch).unwrap();
-        writer.finish().unwrap();
-        drop(writer);
-
-        let reader = ipc::reader::FileReader::try_new(std::io::Cursor::new(buf)).unwrap();
-        let batch2: std::result::Result<Vec<_>, _> = reader.collect();
-
-        assert_eq!(batch, batch2.unwrap()[0]);
+        assert_eq!(batch, roundtrip_ipc(&batch));
    }

+    fn check_union_with_builder(mut builder: UnionBuilder) {
+        builder.append::<datatypes::Int32Type>("a", 1).unwrap();
+        builder.append_null().unwrap();
+        builder.append::<datatypes::Float64Type>("c", 3.0).unwrap();
+        builder.append::<datatypes::Int32Type>("a", 4).unwrap();
+        builder.append::<datatypes::Int64Type>("d", 11).unwrap();
+        let union = builder.build().unwrap();
+
+        let schema = Arc::new(Schema::new(vec![Field::new(
+            "union",
+            union.data_type().clone(),
+            false,
+        )]));
+
+        let union_array = Arc::new(union) as ArrayRef;
+
+        let rb = RecordBatch::try_new(schema, vec![union_array]).unwrap();
+        let rb2 = roundtrip_ipc(&rb);
+        // TODO: equality not yet implemented for union, so we check that the length of the array is
+        // the same and that all of the buffers are the same instead.
+        assert_eq!(rb.schema(), rb2.schema());
+        assert_eq!(rb.num_columns(), rb2.num_columns());
+        assert_eq!(rb.num_rows(), rb2.num_rows());
+        let union1 = rb.column(0);
+        let union2 = rb2.column(0);
+
+        assert_eq!(union1.data().buffers(), union2.data().buffers());
+    }

+    #[test]
+    fn test_roundtrip_dense_union() {
+        check_union_with_builder(UnionBuilder::new_dense(6));
+    }
+
+    #[test]
+    fn test_roundtrip_sparse_union() {
+        check_union_with_builder(UnionBuilder::new_sparse(6));
+    }
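Since equality is not yet implemented for unions (see the TODO above), consumers typically inspect a union read back through IPC via UnionArray's accessors. A hedged sketch — `describe_union` is illustrative, not part of this PR:

use arrow::array::{Array, UnionArray};
use arrow::record_batch::RecordBatch;

// Illustrative only: walk the union column of a batch returned by the reader.
fn describe_union(batch: &RecordBatch) {
    let union = batch
        .column(0)
        .as_any()
        .downcast_ref::<UnionArray>()
        .expect("column 0 is a union");
    for i in 0..union.len() {
        // type_id(i) identifies the child that backs logical slot i;
        // value(i) returns that single value as a one-element ArrayRef.
        println!("slot {}: type_id {}, value {:?}", i, union.type_id(i), union.value(i));
    }
}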

/// Read gzipped JSON file