From d1109244d88bb19e7b8c66140dc52d78ae19aadc Mon Sep 17 00:00:00 2001 From: Elliana May Date: Thu, 18 Apr 2024 01:48:11 +0800 Subject: [PATCH] fix: arrow vtab panic (#293) --- libduckdb-sys/Cargo.toml | 2 +- src/vtab/arrow.rs | 112 ++++++++++++++++++++++++++------------- 2 files changed, 77 insertions(+), 37 deletions(-) diff --git a/libduckdb-sys/Cargo.toml b/libduckdb-sys/Cargo.toml index f407a1b1..bae62ded 100644 --- a/libduckdb-sys/Cargo.toml +++ b/libduckdb-sys/Cargo.toml @@ -39,4 +39,4 @@ serde_json = { version = "1.0" } tar = "0.4.38" [dev-dependencies] -arrow = { version = "49", default-features = false, features = ["ffi"] } +arrow = { version = "51", default-features = false, features = ["ffi"] } diff --git a/src/vtab/arrow.rs b/src/vtab/arrow.rs index 9e6e85be..fa378c9c 100644 --- a/src/vtab/arrow.rs +++ b/src/vtab/arrow.rs @@ -2,6 +2,7 @@ use super::{ vector::{FlatVector, ListVector, Vector}, BindInfo, DataChunk, Free, FunctionInfo, InitInfo, LogicalType, LogicalTypeId, StructVector, VTab, }; +use std::ptr::null_mut; use crate::vtab::vector::Inserter; use arrow::array::{ @@ -74,8 +75,11 @@ impl VTab for ArrowVTab { type InitData = ArrowInitData; unsafe fn bind(bind: &BindInfo, data: *mut ArrowBindData) -> Result<(), Box> { + (*data).rb = null_mut(); let param_count = bind.get_parameter_count(); - assert!(param_count == 2); + if param_count != 2 { + return Err(format!("Bad param count: {param_count}, expected 2").into()); + } let array = bind.get_parameter(0).to_int64(); let schema = bind.get_parameter(1).to_int64(); unsafe { @@ -106,6 +110,7 @@ impl VTab for ArrowVTab { output.set_len(0); } else { let rb = Box::from_raw((*bind_info).rb); + (*bind_info).rb = null_mut(); // erase ref in case of failure in record_batch_to_duckdb_data_chunk record_batch_to_duckdb_data_chunk(&rb, output)?; (*bind_info).rb = Box::into_raw(rb); (*init_info).done = true; @@ -156,7 +161,7 @@ pub fn to_duckdb_type_id(data_type: &DataType) -> Result List, DataType::Struct(_) => Struct, DataType::Union(_, _) => Union, - DataType::Dictionary(_, _) => todo!(), + // DataType::Dictionary(_, _) => todo!(), // duckdb/src/main/capi/helper-c.cpp does not support decimal // DataType::Decimal128(_, _) => Decimal, // DataType::Decimal256(_, _) => Decimal, @@ -194,8 +199,9 @@ pub fn to_duckdb_logical_type(data_type: &DataType) -> Result { - primitive_array_to_vector(col, &mut chunk.flat_vector(i)); + primitive_array_to_vector(col, &mut chunk.flat_vector(i))?; } DataType::Utf8 => { string_array_to_vector(as_string_array(col.as_ref()), &mut chunk.flat_vector(i)); } DataType::List(_) => { - list_array_to_vector(as_list_array(col.as_ref()), &mut chunk.list_vector(i)); + list_array_to_vector(as_list_array(col.as_ref()), &mut chunk.list_vector(i))?; } DataType::LargeList(_) => { - list_array_to_vector(as_large_list_array(col.as_ref()), &mut chunk.list_vector(i)); + list_array_to_vector(as_large_list_array(col.as_ref()), &mut chunk.list_vector(i))?; } DataType::FixedSizeList(_, _) => { - fixed_size_list_array_to_vector(as_fixed_size_list_array(col.as_ref()), &mut chunk.list_vector(i)); + fixed_size_list_array_to_vector(as_fixed_size_list_array(col.as_ref()), &mut chunk.list_vector(i))?; } DataType::Struct(_) => { let struct_array = as_struct_array(col.as_ref()); let mut struct_vector = chunk.struct_vector(i); - struct_array_to_vector(struct_array, &mut struct_vector); + struct_array_to_vector(struct_array, &mut struct_vector)?; } _ => { - unimplemented!( + return Err(format!( "column {} is not supported yet, please file an issue https://github.com/wangfenjin/duckdb-rs", batch.schema().field(i) - ); + ) + .into()); } } } @@ -262,7 +269,7 @@ fn primitive_array_to_flat_vector_cast( out_vector.copy::(array.as_primitive::().values()); } -fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) { +fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) -> Result<(), Box> { match array.data_type() { DataType::Boolean => { boolean_array_to_vector(as_boolean_array(array), out.as_mut_any().downcast_mut().unwrap()); @@ -315,7 +322,6 @@ fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) { out.as_mut_any().downcast_mut().unwrap(), ); } - DataType::Float16 => todo!("Float16 is not supported yet"), DataType::Float32 => { primitive_array_to_flat_vector::( as_primitive_array(array), @@ -337,7 +343,6 @@ fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) { out.as_mut_any().downcast_mut().unwrap(), ); } - DataType::Decimal256(_, _) => todo!("Decimal256 is not supported yet"), // DuckDB Only supports timetamp_tz in microsecond precision DataType::Timestamp(_, Some(tz)) => primitive_array_to_flat_vector_cast::( @@ -376,11 +381,9 @@ fn primitive_array_to_vector(array: &dyn Array, out: &mut dyn Vector) { DataType::Time64(_) => { primitive_array_to_flat_vector_cast::(Time64MicrosecondType::DATA_TYPE, array, out) } - _ => todo!( - "Converting '{dtype:#?}' to primitive flat vector is not supported", - dtype = array.data_type() - ), + datatype => return Err(format!("Data type \"{datatype}\" not yet supported by ArrowVTab").into()), } + Ok(()) } /// Convert Arrow [Decimal128Array] to a duckdb vector. @@ -410,12 +413,15 @@ fn string_array_to_vector(array: &StringArray, out: &mut FlatVector) { } } -fn list_array_to_vector>(array: &GenericListArray, out: &mut ListVector) { +fn list_array_to_vector>( + array: &GenericListArray, + out: &mut ListVector, +) -> Result<(), Box> { let value_array = array.values(); let mut child = out.child(value_array.len()); match value_array.data_type() { dt if dt.is_primitive() => { - primitive_array_to_vector(value_array.as_ref(), &mut child); + primitive_array_to_vector(value_array.as_ref(), &mut child)?; for i in 0..array.len() { let offset = array.value_offsets()[i]; let length = array.value_length(i); @@ -423,18 +429,22 @@ fn list_array_to_vector>(array: &Generic } } _ => { - println!("Nested list is not supported yet."); - todo!() + return Err("Nested list is not supported yet.".into()); } } + + Ok(()) } -fn fixed_size_list_array_to_vector(array: &FixedSizeListArray, out: &mut ListVector) { +fn fixed_size_list_array_to_vector( + array: &FixedSizeListArray, + out: &mut ListVector, +) -> Result<(), Box> { let value_array = array.values(); let mut child = out.child(value_array.len()); match value_array.data_type() { dt if dt.is_primitive() => { - primitive_array_to_vector(value_array.as_ref(), &mut child); + primitive_array_to_vector(value_array.as_ref(), &mut child)?; for i in 0..array.len() { let offset = array.value_offset(i); let length = array.value_length(); @@ -443,10 +453,11 @@ fn fixed_size_list_array_to_vector(array: &FixedSizeListArray, out: &mut ListVec out.set_len(value_array.len()); } _ => { - println!("Nested list is not supported yet."); - todo!() + return Err("Nested list is not supported yet.".into()); } } + + Ok(()) } /// Force downcast of an [`Array`], such as an [`ArrayRef`], to @@ -455,32 +466,32 @@ fn as_fixed_size_list_array(arr: &dyn Array) -> &FixedSizeListArray { arr.as_any().downcast_ref::().unwrap() } -fn struct_array_to_vector(array: &StructArray, out: &mut StructVector) { +fn struct_array_to_vector(array: &StructArray, out: &mut StructVector) -> Result<(), Box> { for i in 0..array.num_columns() { let column = array.column(i); match column.data_type() { dt if dt.is_primitive() || matches!(dt, DataType::Boolean) => { - primitive_array_to_vector(column, &mut out.child(i)); + primitive_array_to_vector(column, &mut out.child(i))?; } DataType::Utf8 => { string_array_to_vector(as_string_array(column.as_ref()), &mut out.child(i)); } DataType::List(_) => { - list_array_to_vector(as_list_array(column.as_ref()), &mut out.list_vector_child(i)); + list_array_to_vector(as_list_array(column.as_ref()), &mut out.list_vector_child(i))?; } DataType::LargeList(_) => { - list_array_to_vector(as_large_list_array(column.as_ref()), &mut out.list_vector_child(i)); + list_array_to_vector(as_large_list_array(column.as_ref()), &mut out.list_vector_child(i))?; } DataType::FixedSizeList(_, _) => { fixed_size_list_array_to_vector( as_fixed_size_list_array(column.as_ref()), &mut out.list_vector_child(i), - ); + )?; } DataType::Struct(_) => { let struct_array = as_struct_array(column.as_ref()); let mut struct_vector = out.struct_vector_child(i); - struct_array_to_vector(struct_array, &mut struct_vector); + struct_array_to_vector(struct_array, &mut struct_vector)?; } _ => { unimplemented!( @@ -490,6 +501,7 @@ fn struct_array_to_vector(array: &StructArray, out: &mut StructVector) { } } } + Ok(()) } /// Pass RecordBatch to duckdb. @@ -531,11 +543,11 @@ mod test { use crate::{Connection, Result}; use arrow::{ array::{ - Array, ArrayRef, AsArray, Date32Array, Date64Array, Float64Array, Int32Array, PrimitiveArray, StringArray, - StructArray, Time32SecondArray, Time64MicrosecondArray, TimestampMicrosecondArray, - TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, + Array, ArrayRef, AsArray, Date32Array, Date64Array, Decimal256Array, Float64Array, Int32Array, + PrimitiveArray, StringArray, StructArray, Time32SecondArray, Time64MicrosecondArray, + TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, }, - datatypes::{ArrowPrimitiveType, DataType, Field, Fields, Schema}, + datatypes::{i256, ArrowPrimitiveType, DataType, Field, Fields, Schema}, record_batch::RecordBatch, }; use std::{error::Error, sync::Arc}; @@ -749,4 +761,32 @@ mod test { assert_eq!(column.value(0), "TIMESTAMP WITH TIME ZONE"); Ok(()) } + + #[test] + fn test_arrow_error() { + let arc: ArrayRef = Arc::new(Decimal256Array::from(vec![i256::from(1), i256::from(2), i256::from(3)])); + let batch = RecordBatch::try_from_iter(vec![("x", arc)]).unwrap(); + + let db = Connection::open_in_memory().unwrap(); + db.register_table_function::("arrow").unwrap(); + + let mut stmt = db.prepare("SELECT * FROM arrow(?, ?)").unwrap(); + + let res = match stmt.execute(arrow_recordbatch_to_query_params(batch)) { + Ok(..) => None, + Err(e) => Some(e), + } + .unwrap(); + + assert_eq!( + res, + crate::error::Error::DuckDBFailure( + crate::ffi::Error { + code: crate::ffi::ErrorCode::Unknown, + extended_code: 1 + }, + Some("Invalid Input Error: Data type \"Decimal256(76, 10)\" not yet supported by ArrowVTab".to_owned()) + ) + ); + } }