From 5a1729e1243006af56b29caa20678f4ce6d60209 Mon Sep 17 00:00:00 2001 From: Jack Eadie Date: Thu, 25 Apr 2024 08:00:06 +1000 Subject: [PATCH] Support UTF8 in nested Apache Arrow data types (e.g. List) (#300) * support UTF8[] * add tests * fix test * format * clippy * bump cause github is broken --------- Co-authored-by: Max Gabrielsson --- Cargo.toml | 1 + src/vtab/arrow.rs | 86 ++++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 79 insertions(+), 8 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index c5674e9b..d4223645 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -84,6 +84,7 @@ pretty_assertions = "1.4.0" path = "libduckdb-sys" version = "0.10.1" + [package.metadata.docs.rs] features = ['vtab', 'chrono'] all-features = false diff --git a/src/vtab/arrow.rs b/src/vtab/arrow.rs index fa378c9c..3d6c23c2 100644 --- a/src/vtab/arrow.rs +++ b/src/vtab/arrow.rs @@ -422,17 +422,20 @@ fn list_array_to_vector>( match value_array.data_type() { dt if dt.is_primitive() => { primitive_array_to_vector(value_array.as_ref(), &mut child)?; - for i in 0..array.len() { - let offset = array.value_offsets()[i]; - let length = array.value_length(i); - out.set_entry(i, offset.as_(), length.as_()); - } + } + DataType::Utf8 => { + string_array_to_vector(as_string_array(value_array.as_ref()), &mut child); } _ => { return Err("Nested list is not supported yet.".into()); } } + for i in 0..array.len() { + let offset = array.value_offsets()[i]; + let length = array.value_length(i); + out.set_entry(i, offset.as_(), length.as_()); + } Ok(()) } @@ -452,10 +455,19 @@ fn fixed_size_list_array_to_vector( } out.set_len(value_array.len()); } + DataType::Utf8 => { + string_array_to_vector(as_string_array(value_array.as_ref()), &mut child); + } _ => { return Err("Nested list is not supported yet.".into()); } } + for i in 0..array.len() { + let offset = array.value_offset(i); + let length = array.value_length(); + out.set_entry(i, offset as usize, length as usize); + } + out.set_len(value_array.len()); Ok(()) } @@ -543,10 +555,12 @@ mod test { use crate::{Connection, Result}; use arrow::{ array::{ - Array, ArrayRef, AsArray, Date32Array, Date64Array, Decimal256Array, Float64Array, Int32Array, - PrimitiveArray, StringArray, StructArray, Time32SecondArray, Time64MicrosecondArray, - TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, + Array, ArrayRef, AsArray, Date32Array, Date64Array, Decimal256Array, Float64Array, GenericListArray, + Int32Array, ListArray, OffsetSizeTrait, PrimitiveArray, StringArray, StructArray, Time32SecondArray, + Time64MicrosecondArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, + TimestampSecondArray, }, + buffer::{OffsetBuffer, ScalarBuffer}, datatypes::{i256, ArrowPrimitiveType, DataType, Field, Fields, Schema}, record_batch::RecordBatch, }; @@ -676,6 +690,62 @@ mod test { Ok(()) } + fn check_generic_array_roundtrip(arry: GenericListArray) -> Result<(), Box> + where + T: OffsetSizeTrait, + { + let expected_output_array = arry.clone(); + + let db = Connection::open_in_memory()?; + db.register_table_function::("arrow")?; + + // Roundtrip a record batch from Rust to DuckDB and back to Rust + let schema = Schema::new(vec![Field::new("a", arry.data_type().clone(), false)]); + + let rb = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arry.clone())])?; + let param = arrow_recordbatch_to_query_params(rb); + let mut stmt = db.prepare("select a from arrow(?, ?)")?; + let rb = stmt.query_arrow(param)?.next().expect("no record batch"); + + let output_any_array = rb.column(0); + assert!(output_any_array + .data_type() + .equals_datatype(expected_output_array.data_type())); + + match output_any_array.as_list_opt::() { + Some(output_array) => { + assert_eq!(output_array.len(), expected_output_array.len()); + for i in 0..output_array.len() { + assert_eq!(output_array.is_valid(i), expected_output_array.is_valid(i)); + if output_array.is_valid(i) { + assert!(expected_output_array.value(i).eq(&output_array.value(i))); + } + } + } + None => panic!("Expected GenericListArray"), + } + + Ok(()) + } + + #[test] + fn test_array_roundtrip() -> Result<(), Box> { + check_generic_array_roundtrip(ListArray::new( + Arc::new(Field::new("item", DataType::Utf8, true)), + OffsetBuffer::new(ScalarBuffer::from(vec![0, 2, 4, 5])), + Arc::new(StringArray::from(vec![ + Some("foo"), + Some("baz"), + Some("bar"), + Some("foo"), + Some("baz"), + ])), + None, + ))?; + + Ok(()) + } + #[test] fn test_timestamp_roundtrip() -> Result<(), Box> { check_rust_primitive_array_roundtrip(Int32Array::from(vec![1, 2, 3]), Int32Array::from(vec![1, 2, 3]))?;