From aebc1146b694eeeb6083fb9b2cb710bf4381a9ff Mon Sep 17 00:00:00 2001 From: Matthew Gapp <61894094+matthewgapp@users.noreply.github.com> Date: Fri, 22 Nov 2024 14:55:01 -0800 Subject: [PATCH] Add scalar support with arrow types and overloading --- crates/duckdb/Cargo.toml | 2 + crates/duckdb/examples/hello-ext-capi/main.rs | 4 +- crates/duckdb/examples/hello-ext/main.rs | 4 +- crates/duckdb/src/core/vector.rs | 32 +- crates/duckdb/src/lib.rs | 4 + crates/duckdb/src/r2d2.rs | 18 +- crates/duckdb/src/raw_statement.rs | 7 +- crates/duckdb/src/types/mod.rs | 2 + crates/duckdb/src/types/string.rs | 28 ++ crates/duckdb/src/vscalar/arrow.rs | 335 ++++++++++++ crates/duckdb/src/vscalar/function.rs | 138 +++++ crates/duckdb/src/vscalar/mod.rs | 323 ++++++++++++ crates/duckdb/src/vtab/arrow.rs | 475 +++++++++++++++--- crates/duckdb/src/vtab/excel.rs | 4 +- crates/duckdb/src/vtab/function.rs | 21 +- crates/duckdb/src/vtab/mod.rs | 19 +- 16 files changed, 1330 insertions(+), 86 deletions(-) create mode 100644 crates/duckdb/src/types/string.rs create mode 100644 crates/duckdb/src/vscalar/arrow.rs create mode 100644 crates/duckdb/src/vscalar/function.rs create mode 100644 crates/duckdb/src/vscalar/mod.rs diff --git a/crates/duckdb/Cargo.toml b/crates/duckdb/Cargo.toml index ddea979a..f83d5ef9 100644 --- a/crates/duckdb/Cargo.toml +++ b/crates/duckdb/Cargo.toml @@ -22,6 +22,8 @@ default = [] bundled = ["libduckdb-sys/bundled"] json = ["libduckdb-sys/json", "bundled"] parquet = ["libduckdb-sys/parquet", "bundled"] +vscalar = [] +vscalar-arrow = [] vtab = [] vtab-loadable = ["vtab", "duckdb-loadable-macros"] vtab-excel = ["vtab", "calamine"] diff --git a/crates/duckdb/examples/hello-ext-capi/main.rs b/crates/duckdb/examples/hello-ext-capi/main.rs index 54a4a35f..16b2754d 100644 --- a/crates/duckdb/examples/hello-ext-capi/main.rs +++ b/crates/duckdb/examples/hello-ext-capi/main.rs @@ -4,7 +4,7 @@ extern crate libduckdb_sys; use duckdb::{ core::{DataChunkHandle, Inserter, LogicalTypeHandle, LogicalTypeId}, - vtab::{BindInfo, Free, FunctionInfo, InitInfo, VTab}, + vtab::{BindInfo, Free, TableFunctionInfo, InitInfo, VTab}, Connection, Result, }; use duckdb_loadable_macros::duckdb_entrypoint_c_api; @@ -59,7 +59,7 @@ impl VTab for HelloVTab { Ok(()) } - unsafe fn func(func: &FunctionInfo, output: &mut DataChunkHandle) -> Result<(), Box> { + unsafe fn func(func: &TableFunctionInfo, output: &mut DataChunkHandle) -> Result<(), Box> { let init_info = func.get_init_data::(); let bind_info = func.get_bind_data::(); diff --git a/crates/duckdb/examples/hello-ext/main.rs b/crates/duckdb/examples/hello-ext/main.rs index 6f159e9a..6f79cc60 100644 --- a/crates/duckdb/examples/hello-ext/main.rs +++ b/crates/duckdb/examples/hello-ext/main.rs @@ -4,7 +4,7 @@ extern crate libduckdb_sys; use duckdb::{ core::{DataChunkHandle, Inserter, LogicalTypeHandle, LogicalTypeId}, - vtab::{BindInfo, Free, FunctionInfo, InitInfo, VTab}, + vtab::{BindInfo, Free, TableFunctionInfo, InitInfo, VTab}, Connection, Result, }; use duckdb_loadable_macros::duckdb_entrypoint; @@ -59,7 +59,7 @@ impl VTab for HelloVTab { Ok(()) } - unsafe fn func(func: &FunctionInfo, output: &mut DataChunkHandle) -> Result<(), Box> { + unsafe fn func(func: &TableFunctionInfo, output: &mut DataChunkHandle) -> Result<(), Box> { let init_info = func.get_init_data::(); let bind_info = func.get_bind_data::(); diff --git a/crates/duckdb/src/core/vector.rs b/crates/duckdb/src/core/vector.rs index 92e5622a..f257052c 100644 --- a/crates/duckdb/src/core/vector.rs +++ 
b/crates/duckdb/src/core/vector.rs @@ -1,6 +1,9 @@ use std::{any::Any, ffi::CString, slice}; -use libduckdb_sys::{duckdb_array_type_array_size, duckdb_array_vector_get_child, DuckDbString}; +use libduckdb_sys::{ + duckdb_array_type_array_size, duckdb_array_vector_get_child, duckdb_validity_row_is_valid, + DuckDbString, +}; use super::LogicalTypeHandle; use crate::ffi::{ @@ -55,6 +58,23 @@ impl FlatVector { self.capacity } + pub fn row_is_null(&self, row: u64) -> bool { + // A manual check (idx_t entry_idx = row_idx / 64; idx_t idx_in_entry = row_idx % 64; bool is_valid = validity_mask[entry_idx] & (1 << idx_in_entry);) could be used here, + // since the duckdb_validity_row_is_valid call below is slower. + let valid = unsafe { + let validity = duckdb_vector_get_validity(self.ptr); + + // validity can return a NULL pointer if the entire vector is valid + if validity.is_null() { + return false; + } + + duckdb_validity_row_is_valid(validity, row) + }; + + !valid + } + /// Returns an unsafe mutable pointer to the vector's data pub fn as_mut_ptr<T>(&self) -> *mut T { unsafe { duckdb_vector_get_data(self.ptr).cast() } @@ -65,11 +85,21 @@ impl FlatVector { unsafe { slice::from_raw_parts(self.as_mut_ptr(), self.capacity()) } } + /// Returns a slice of the vector up to a certain length + pub fn as_slice_with_len<T>(&self, len: usize) -> &[T] { + unsafe { slice::from_raw_parts(self.as_mut_ptr(), len) } + } + /// Returns a mutable slice of the vector pub fn as_mut_slice<T>(&mut self) -> &mut [T] { unsafe { slice::from_raw_parts_mut(self.as_mut_ptr(), self.capacity()) } } + /// Returns a mutable slice of the vector up to a certain length + pub fn as_mut_slice_with_len<T>(&mut self, len: usize) -> &mut [T] { + unsafe { slice::from_raw_parts_mut(self.as_mut_ptr(), len) } + } + /// Returns the logical type of the vector pub fn logical_type(&self) -> LogicalTypeHandle { unsafe { LogicalTypeHandle::new(duckdb_vector_get_column_type(self.ptr)) } diff --git a/crates/duckdb/src/lib.rs index d8caf81d..ae793e44 100644 --- a/crates/duckdb/src/lib.rs +++ b/crates/duckdb/src/lib.rs @@ -124,6 +124,10 @@ pub mod types; #[cfg(feature = "vtab")] pub mod vtab; +/// The duckdb scalar function interface +#[cfg(feature = "vscalar")] +pub mod vscalar; + #[cfg(test)] mod test_all_types; diff --git a/crates/duckdb/src/r2d2.rs index 75203671..e1e754e4 100644 --- a/crates/duckdb/src/r2d2.rs +++ b/crates/duckdb/src/r2d2.rs @@ -40,8 +40,9 @@ //! .unwrap() //! } //! ``` -use crate::{Config, Connection, Error, Result}; +use crate::{vscalar::VScalar, vtab::VTab, Config, Connection, Error, Result}; use std::{ + fmt::Debug, path::Path, sync::{Arc, Mutex}, }; @@ -78,6 +79,21 @@ impl DuckdbConnectionManager { connection: Arc::new(Mutex::new(Connection::open_in_memory_with_flags(config)?)), }) } + + /// Register a table function. + pub fn register_table_function<T: VTab>(&self, name: &str) -> Result<()> { + let conn = self.connection.lock().unwrap(); + conn.register_table_function::<T>(name) + } + + /// Register a scalar function.
+ pub fn register_scalar<S: VScalar>(&self, name: &str) -> Result<()> + where + S::State: Debug, + { + let conn = self.connection.lock().unwrap(); + conn.register_scalar_function::<S>(name) + } } impl r2d2::ManageConnection for DuckdbConnectionManager { diff --git a/crates/duckdb/src/raw_statement.rs index 280fa158..31ae32fc 100644 --- a/crates/duckdb/src/raw_statement.rs +++ b/crates/duckdb/src/raw_statement.rs @@ -81,11 +81,11 @@ impl RawStatement { #[inline] pub fn step(&self) -> Option { - self.result?; + let out = self.result?; unsafe { let mut arrays = FFI_ArrowArray::empty(); if ffi::duckdb_query_arrow_array( - self.result_unwrap(), + out, &mut std::ptr::addr_of_mut!(arrays) as *mut _ as *mut ffi::duckdb_arrow_array, ) .ne(&ffi::DuckDBSuccess) { @@ -99,7 +99,7 @@ impl RawStatement { let mut schema = FFI_ArrowSchema::empty(); if ffi::duckdb_query_arrow_schema( - self.result_unwrap(), + out, &mut std::ptr::addr_of_mut!(schema) as *mut _ as *mut ffi::duckdb_arrow_schema, ) != ffi::DuckDBSuccess { @@ -260,6 +260,6 @@ impl RawStatement { unsafe { let mut out: ffi::duckdb_arrow = ptr::null_mut(); let rc = ffi::duckdb_execute_prepared_arrow(self.ptr, &mut out); result_from_duckdb_arrow(rc, out)?; let rows_changed = ffi::duckdb_arrow_rows_changed(out); diff --git a/crates/duckdb/src/types/mod.rs index f8ae5c58..4da528eb 100644 --- a/crates/duckdb/src/types/mod.rs +++ b/crates/duckdb/src/types/mod.rs @@ -5,6 +5,7 @@ pub use self::{ from_sql::{FromSql, FromSqlError, FromSqlResult}, ordered_map::OrderedMap, + string::DuckString, to_sql::{ToSql, ToSqlOutput}, value::Value, value_ref::{EnumType, ListType, TimeUnit, ValueRef}, @@ -25,6 +26,7 @@ mod value; mod value_ref; mod ordered_map; +mod string; /// Empty struct that can be used to fill in a query parameter as `NULL`.
/// diff --git a/crates/duckdb/src/types/string.rs new file mode 100644 index 00000000..994b3fec --- /dev/null +++ b/crates/duckdb/src/types/string.rs @@ -0,0 +1,28 @@ +use libduckdb_sys::{duckdb_string_t, duckdb_string_t_data, duckdb_string_t_length}; + +/// Wrapper for underlying duck string type with a lifetime bound to a &mut duckdb_string_t +pub struct DuckString<'a> { + ptr: &'a mut duckdb_string_t, +} + +impl<'a> DuckString<'a> { + pub(crate) fn new(ptr: &'a mut duckdb_string_t) -> Self { + DuckString { ptr } + } +} + +impl<'a> DuckString<'a> { + /// convert duckdb_string_t to a copy on write string + pub fn as_str(&mut self) -> std::borrow::Cow<'a, str> { + String::from_utf8_lossy(self.as_bytes()) + } + + /// convert duckdb_string_t to a byte slice + pub fn as_bytes(&mut self) -> &'a [u8] { + unsafe { + let len = duckdb_string_t_length(*self.ptr); + let c_ptr = duckdb_string_t_data(self.ptr); + std::slice::from_raw_parts(c_ptr as *const u8, len as usize) + } + } +} diff --git a/crates/duckdb/src/vscalar/arrow.rs new file mode 100644 index 00000000..e4089536 --- /dev/null +++ b/crates/duckdb/src/vscalar/arrow.rs @@ -0,0 +1,335 @@ +use std::sync::Arc; + +use arrow::{ + array::{Array, RecordBatch}, + datatypes::DataType, +}; + +use crate::{ + core::{DataChunkHandle, LogicalTypeId}, + vtab::arrow::{data_chunk_to_arrow, write_arrow_array_to_vector, WritableVector}, +}; + +use super::{ScalarFunctionSignature, ScalarParams, VScalar}; + +/// The possible parameters of a scalar function that accepts and returns arrow types +pub enum ArrowScalarParams { + /// The exact parameters of the scalar function + Exact(Vec<DataType>), + /// The variadic parameter of the scalar function + Variadic(DataType), } + +impl AsRef<[DataType]> for ArrowScalarParams { + fn as_ref(&self) -> &[DataType] { + match self { + ArrowScalarParams::Exact(params) => params.as_ref(), + ArrowScalarParams::Variadic(param) => std::slice::from_ref(param), + } + } +} + +impl From<ArrowScalarParams> for ScalarParams { + fn from(params: ArrowScalarParams) -> Self { + match params { + ArrowScalarParams::Exact(params) => ScalarParams::Exact( + params + .into_iter() + .map(|v| LogicalTypeId::try_from(&v).expect("type should be converted").into()) + .collect(), + ), + ArrowScalarParams::Variadic(param) => ScalarParams::Variadic( + LogicalTypeId::try_from(&param) + .expect("type should be converted") + .into(), + ), + } + } +} + +/// A signature for a scalar function that accepts and returns arrow types +pub struct ArrowFunctionSignature { + /// The parameters of the scalar function + pub parameters: Option<ArrowScalarParams>, + /// The return type of the scalar function + pub return_type: DataType, +} + +impl ArrowFunctionSignature { + /// Create an exact function signature + pub fn exact(params: Vec<DataType>, return_type: DataType) -> Self { + ArrowFunctionSignature { + parameters: Some(ArrowScalarParams::Exact(params)), + return_type, + } + } + + /// Create a variadic function signature + pub fn variadic(param: DataType, return_type: DataType) -> Self { + ArrowFunctionSignature { + parameters: Some(ArrowScalarParams::Variadic(param)), + return_type, + } + } +} + +/// A trait for scalar functions that accept and return arrow types and can be registered with DuckDB +pub trait VArrowScalar: Sized { + /// State that persists across invocations of the scalar function (the lifetime of the connection) + type State: Default; + + /// The actual function that is called by DuckDB + fn invoke(info: &Self::State, input:
RecordBatch) -> Result, Box>; + + /// The possible signatures of the scalar function. These will result in DuckDB scalar function overloads. + /// The invoke method should be able to handle all of these signatures. + fn signatures() -> Vec; +} + +impl VScalar for T +where + T: VArrowScalar, +{ + type State = T::State; + + unsafe fn invoke( + info: &Self::State, + input: &mut DataChunkHandle, + out: &mut dyn WritableVector, + ) -> Result<(), Box> { + let array = T::invoke(info, data_chunk_to_arrow(input)?)?; + write_arrow_array_to_vector(&array, out) + } + + fn signatures() -> Vec { + T::signatures() + .into_iter() + .map(|sig| ScalarFunctionSignature { + parameters: sig.parameters.map(Into::into), + return_type: LogicalTypeId::try_from(&sig.return_type) + .expect("type should be converted") + .into(), + }) + .collect() + } +} + +#[cfg(test)] +mod test { + + use std::{error::Error, sync::Arc}; + + use arrow::{ + array::{Array, RecordBatch, StringArray}, + datatypes::DataType, + }; + + use crate::{vscalar::arrow::ArrowFunctionSignature, Connection}; + + use super::VArrowScalar; + + struct HelloScalarArrow {} + + impl VArrowScalar for HelloScalarArrow { + type State = (); + + fn invoke(_: &Self::State, input: RecordBatch) -> Result, Box> { + let name = input.column(0).as_any().downcast_ref::().unwrap(); + let result = name.iter().map(|v| format!("Hello {}", v.unwrap())).collect::>(); + Ok(Arc::new(StringArray::from(result))) + } + + fn signatures() -> Vec { + vec![ArrowFunctionSignature::exact(vec![DataType::Utf8], DataType::Utf8)] + } + } + + #[derive(Debug)] + struct MockState { + info: String, + } + + impl Default for MockState { + fn default() -> Self { + MockState { + info: "some meta".to_string(), + } + } + } + + impl Drop for MockState { + fn drop(&mut self) { + println!("dropped meta"); + } + } + + struct ArrowMultiplyScalar {} + + impl VArrowScalar for ArrowMultiplyScalar { + type State = MockState; + + fn invoke(_: &Self::State, input: RecordBatch) -> Result, Box> { + let a = input + .column(0) + .as_any() + .downcast_ref::<::arrow::array::Float32Array>() + .unwrap(); + + let b = input + .column(1) + .as_any() + .downcast_ref::<::arrow::array::Float32Array>() + .unwrap(); + + let result = a + .iter() + .zip(b.iter()) + .map(|(a, b)| a.unwrap() * b.unwrap()) + .collect::>(); + Ok(Arc::new(::arrow::array::Float32Array::from(result))) + } + + fn signatures() -> Vec { + vec![ArrowFunctionSignature::exact( + vec![DataType::Float32, DataType::Float32], + DataType::Float32, + )] + } + } + + // accepts a string or a number and parses to int and multiplies by 2 + struct ArrowOverloaded {} + + impl VArrowScalar for ArrowOverloaded { + type State = MockState; + + fn invoke(s: &Self::State, input: RecordBatch) -> Result, Box> { + assert_eq!("some meta", s.info); + + let a = input.column(0); + let b = input.column(1); + + let result = match a.data_type() { + DataType::Utf8 => { + let a = a + .as_any() + .downcast_ref::<::arrow::array::StringArray>() + .unwrap() + .iter() + .map(|v| v.unwrap().parse::().unwrap()) + .collect::>(); + let b = b + .as_any() + .downcast_ref::<::arrow::array::Float32Array>() + .unwrap() + .iter() + .map(|v| v.unwrap()) + .collect::>(); + a.iter().zip(b.iter()).map(|(a, b)| a * b).collect::>() + } + DataType::Float32 => { + let a = a + .as_any() + .downcast_ref::<::arrow::array::Float32Array>() + .unwrap() + .iter() + .map(|v| v.unwrap()) + .collect::>(); + let b = b + .as_any() + .downcast_ref::<::arrow::array::Float32Array>() + .unwrap() + .iter() + .map(|v| 
v.unwrap()) + .collect::>(); + a.iter().zip(b.iter()).map(|(a, b)| a * b).collect::>() + } + _ => panic!("unsupported type"), + }; + + Ok(Arc::new(::arrow::array::Float32Array::from(result))) + } + + fn signatures() -> Vec { + vec![ + ArrowFunctionSignature::exact(vec![DataType::Utf8, DataType::Float32], DataType::Float32), + ArrowFunctionSignature::exact(vec![DataType::Float32, DataType::Float32], DataType::Float32), + ] + } + } + + #[test] + fn test_arrow_scalar() -> Result<(), Box> { + let conn = Connection::open_in_memory()?; + conn.register_scalar_function::("hello")?; + + let batches = conn + .prepare("select hello('foo') as hello from range(10)")? + .query_arrow([])? + .collect::>(); + + for batch in batches.iter() { + let array = batch.column(0); + let array = array.as_any().downcast_ref::<::arrow::array::StringArray>().unwrap(); + for i in 0..array.len() { + assert_eq!(array.value(i), format!("Hello foo")); + } + } + + Ok(()) + } + + #[test] + fn test_arrow_scalar_multiply() -> Result<(), Box> { + let conn = Connection::open_in_memory()?; + conn.register_scalar_function::("multiply_udf")?; + + let batches = conn + .prepare("select multiply_udf(3.0, 2.0) as mult_result from range(10)")? + .query_arrow([])? + .collect::>(); + + for batch in batches.iter() { + let array = batch.column(0); + let array = array.as_any().downcast_ref::<::arrow::array::Float32Array>().unwrap(); + for i in 0..array.len() { + assert_eq!(array.value(i), 6.0); + } + } + Ok(()) + } + + #[test] + fn test_multiple_signatures_scalar() -> Result<(), Box> { + let conn = Connection::open_in_memory()?; + conn.register_scalar_function::("multi_sig_udf")?; + + let batches = conn + .prepare("select multi_sig_udf('3', 5) as message from range(2)")? + .query_arrow([])? + .collect::>(); + + for batch in batches.iter() { + let array = batch.column(0); + let array = array.as_any().downcast_ref::<::arrow::array::Float32Array>().unwrap(); + for i in 0..array.len() { + assert_eq!(array.value(i), 15.0); + } + } + + let batches = conn + .prepare("select multi_sig_udf(12, 10) as message from range(2)")? + .query_arrow([])? 
+ .collect::>(); + + for batch in batches.iter() { + let array = batch.column(0); + let array = array.as_any().downcast_ref::<::arrow::array::Float32Array>().unwrap(); + for i in 0..array.len() { + assert_eq!(array.value(i), 120.0); + } + } + + Ok(()) + } +} diff --git a/crates/duckdb/src/vscalar/function.rs b/crates/duckdb/src/vscalar/function.rs new file mode 100644 index 00000000..c08b3574 --- /dev/null +++ b/crates/duckdb/src/vscalar/function.rs @@ -0,0 +1,138 @@ +pub struct ScalarFunctionSet { + ptr: duckdb_scalar_function_set, +} + +impl ScalarFunctionSet { + pub fn new(name: &str) -> Self { + let c_name = CString::new(name).expect("name should contain valid utf-8"); + Self { + ptr: unsafe { duckdb_create_scalar_function_set(c_name.as_ptr()) }, + } + } + + pub fn add_function(&self, func: ScalarFunction) -> crate::Result<()> { + unsafe { + let rc = duckdb_add_scalar_function_to_set(self.ptr, func.ptr); + if rc != DuckDBSuccess { + return Err(Error::DuckDBFailure(ffi::Error::new(rc), None)); + } + } + + Ok(()) + } + + pub(crate) fn register_with_connection(&self, con: duckdb_connection) -> crate::Result<()> { + unsafe { + let rc = ffi::duckdb_register_scalar_function_set(con, self.ptr); + if rc != ffi::DuckDBSuccess { + return Err(Error::DuckDBFailure(ffi::Error::new(rc), None)); + } + } + Ok(()) + } +} + +/// A function that returns a queryable scalar function +#[derive(Debug)] +pub struct ScalarFunction { + ptr: duckdb_scalar_function, +} + +impl Drop for ScalarFunction { + fn drop(&mut self) { + unsafe { + duckdb_destroy_scalar_function(&mut self.ptr); + } + } +} + +use std::ffi::{c_void, CString}; + +use libduckdb_sys::{ + self as ffi, duckdb_add_scalar_function_to_set, duckdb_connection, duckdb_create_scalar_function, + duckdb_create_scalar_function_set, duckdb_data_chunk, duckdb_delete_callback_t, duckdb_destroy_scalar_function, + duckdb_function_info, duckdb_scalar_function, duckdb_scalar_function_add_parameter, duckdb_scalar_function_set, + duckdb_scalar_function_set_extra_info, duckdb_scalar_function_set_function, duckdb_scalar_function_set_name, + duckdb_scalar_function_set_return_type, duckdb_scalar_function_set_varargs, duckdb_vector, DuckDBSuccess, +}; + +use crate::{core::LogicalTypeHandle, Error}; + +impl ScalarFunction { + /// Creates a new empty scalar function. + pub fn new(name: impl Into) -> Result { + let name: String = name.into(); + let f_ptr = unsafe { duckdb_create_scalar_function() }; + let c_name = CString::new(name).expect("name should contain valid utf-8"); + unsafe { duckdb_scalar_function_set_name(f_ptr, c_name.as_ptr()) }; + + Ok(Self { ptr: f_ptr }) + } + + /// Adds a parameter to the scalar function. + /// + /// # Arguments + /// * `logical_type`: The type of the parameter to add. + pub fn add_parameter(&self, logical_type: &LogicalTypeHandle) -> &Self { + unsafe { + duckdb_scalar_function_add_parameter(self.ptr, logical_type.ptr); + } + self + } + + pub fn add_variadic_parameter(&self, logical_type: &LogicalTypeHandle) -> &Self { + unsafe { + duckdb_scalar_function_set_varargs(self.ptr, logical_type.ptr); + } + self + } + + /// Sets the return type of the scalar function. + /// + /// # Arguments + /// * `logical_type`: The return type of the scalar function. 
+ pub fn set_return_type(&self, logical_type: &LogicalTypeHandle) -> &Self { + unsafe { + duckdb_scalar_function_set_return_type(self.ptr, logical_type.ptr); + } + self + } + + /// Sets the main function of the scalar function + /// + /// # Arguments + /// * `func`: The C callback that implements the scalar function + pub fn set_function( + &self, + func: Option<unsafe extern "C" fn(info: duckdb_function_info, input: duckdb_data_chunk, output: duckdb_vector)>, + ) -> &Self { + unsafe { + duckdb_scalar_function_set_function(self.ptr, func); + } + self + } + + /// Assigns extra information to the scalar function that can be fetched during binding, etc. + /// + /// # Arguments + /// * `extra_info`: The extra information + /// * `destroy`: The callback that will be called to destroy the bind data (if any) + /// + /// # Safety + unsafe fn set_extra_info_impl(&self, extra_info: *mut c_void, destroy: duckdb_delete_callback_t) { + duckdb_scalar_function_set_extra_info(self.ptr, extra_info, destroy); + } + + pub fn set_extra_info<T: Default>(&self) -> &ScalarFunction { + unsafe { + let t = Box::new(T::default()); + let c_void = Box::into_raw(t) as *mut c_void; + self.set_extra_info_impl(c_void, Some(drop_ptr::<T>)); + } + self + } +} + +unsafe extern "C" fn drop_ptr<T>(ptr: *mut c_void) { + let _ = Box::from_raw(ptr as *mut T); +} diff --git a/crates/duckdb/src/vscalar/mod.rs new file mode 100644 index 00000000..54a63815 --- /dev/null +++ b/crates/duckdb/src/vscalar/mod.rs @@ -0,0 +1,323 @@ +use std::ffi::CString; + +use function::{ScalarFunction, ScalarFunctionSet}; +use libduckdb_sys::{ + duckdb_data_chunk, duckdb_function_info, duckdb_scalar_function_get_extra_info, duckdb_scalar_function_set_error, + duckdb_vector, +}; + +use crate::{ + core::{DataChunkHandle, LogicalTypeHandle}, + inner_connection::InnerConnection, + vtab::arrow::WritableVector, + Connection, +}; +mod function; + +/// The duckdb Arrow scalar function interface +#[cfg(feature = "vscalar-arrow")] +pub mod arrow; + +/// Duckdb scalar function trait +pub trait VScalar: Sized { + /// State that persists across invocations of the scalar function (the lifetime of the connection) + type State: Default; + /// The actual function + /// + /// # Safety + /// + /// This function is unsafe because it: + /// + /// - Dereferences the raw DuckDB pointers behind `input` and `output`. + /// + unsafe fn invoke( + state: &Self::State, + input: &mut DataChunkHandle, + output: &mut dyn WritableVector, + ) -> Result<(), Box<dyn std::error::Error>>; + + /// The possible signatures of the scalar function. + /// These will result in DuckDB scalar function overloads. + /// The invoke method should be able to handle all of these signatures.
+ fn signatures() -> Vec; +} + +/// Duckdb scalar function parameters +pub enum ScalarParams { + /// Exact parameters + Exact(Vec), + /// Variadic parameters + Variadic(LogicalTypeHandle), +} + +/// Duckdb scalar function signature +pub struct ScalarFunctionSignature { + parameters: Option, + return_type: LogicalTypeHandle, +} + +impl ScalarFunctionSignature { + /// Create an exact function signature + pub fn exact(params: Vec, return_type: LogicalTypeHandle) -> Self { + ScalarFunctionSignature { + parameters: Some(ScalarParams::Exact(params)), + return_type, + } + } + + /// Create a variadic function signature + pub fn variadic(param: LogicalTypeHandle, return_type: LogicalTypeHandle) -> Self { + ScalarFunctionSignature { + parameters: Some(ScalarParams::Variadic(param)), + return_type, + } + } +} + +impl ScalarFunctionSignature { + pub(crate) fn register_with_scalar(&self, f: &ScalarFunction) { + f.set_return_type(&self.return_type); + + match &self.parameters { + Some(ScalarParams::Exact(params)) => { + for param in params.iter() { + f.add_parameter(param); + } + } + Some(ScalarParams::Variadic(param)) => { + f.add_variadic_parameter(param); + } + None => { + // do nothing + } + } + } +} + +/// An interface to store and retrieve data during the function execution stage +#[derive(Debug)] +struct ScalarFunctionInfo(duckdb_function_info); + +impl From for ScalarFunctionInfo { + fn from(ptr: duckdb_function_info) -> Self { + Self(ptr) + } +} + +impl ScalarFunctionInfo { + pub unsafe fn get_scalar_extra_info(&self) -> &T { + &*(duckdb_scalar_function_get_extra_info(self.0).cast()) + } + + pub unsafe fn set_error(&self, error: &str) { + let c_str = CString::new(error).unwrap(); + duckdb_scalar_function_set_error(self.0, c_str.as_ptr()); + } +} + +unsafe extern "C" fn scalar_func(info: duckdb_function_info, input: duckdb_data_chunk, mut output: duckdb_vector) +where + T: VScalar, +{ + let info = ScalarFunctionInfo::from(info); + let mut input = DataChunkHandle::new_unowned(input); + let result = T::invoke(info.get_scalar_extra_info(), &mut input, &mut output); + if let Err(e) = result { + info.set_error(&e.to_string()); + } +} + +impl Connection { + /// Register the given ScalarFunction with the current db + #[inline] + pub fn register_scalar_function(&self, name: &str) -> crate::Result<()> { + let set = ScalarFunctionSet::new(name); + for signature in S::signatures() { + let scalar_function = ScalarFunction::new(name)?; + signature.register_with_scalar(&scalar_function); + scalar_function.set_function(Some(scalar_func::)); + scalar_function.set_extra_info::(); + set.add_function(scalar_function)?; + } + self.db.borrow_mut().register_scalar_function_set(set) + } +} + +impl InnerConnection { + /// Register the given ScalarFunction with the current db + pub fn register_scalar_function_set(&mut self, f: ScalarFunctionSet) -> crate::Result<()> { + f.register_with_connection(self.con) + } +} + +#[cfg(test)] +mod test { + use std::error::Error; + + use arrow::array::Array; + use libduckdb_sys::duckdb_string_t; + + use crate::{ + core::{DataChunkHandle, Inserter, LogicalTypeHandle, LogicalTypeId}, + types::DuckString, + vtab::arrow::WritableVector, + Connection, + }; + + use super::{ScalarFunctionSignature, VScalar}; + + struct ErrorScalar {} + + impl VScalar for ErrorScalar { + type State = (); + + unsafe fn invoke( + _: &Self::State, + input: &mut DataChunkHandle, + _: &mut dyn WritableVector, + ) -> Result<(), Box> { + let mut msg = input.flat_vector(0).as_slice_with_len::(input.len())[0]; + let 
string = DuckString::new(&mut msg).as_str(); + Err(format!("Error: {}", string).into()) + } + + fn signatures() -> Vec { + vec![ScalarFunctionSignature::exact( + vec![LogicalTypeId::Varchar.into()], + LogicalTypeId::Varchar.into(), + )] + } + } + + #[derive(Debug)] + struct TestState { + #[allow(dead_code)] + inner: i32, + } + + impl Default for TestState { + fn default() -> Self { + TestState { inner: 42 } + } + } + + struct EchoScalar {} + + impl VScalar for EchoScalar { + type State = TestState; + + unsafe fn invoke( + s: &Self::State, + input: &mut DataChunkHandle, + output: &mut dyn WritableVector, + ) -> Result<(), Box> { + assert_eq!(s.inner, 42); + let values = input.flat_vector(0); + let values = values.as_slice_with_len::(input.len()); + let strings = values + .iter() + .map(|ptr| DuckString::new(&mut { *ptr }).as_str().to_string()) + .take(input.len()); + let output = output.flat_vector(); + for s in strings { + output.insert(0, s.to_string().as_str()); + } + Ok(()) + } + + fn signatures() -> Vec { + vec![ScalarFunctionSignature::exact( + vec![LogicalTypeId::Varchar.into()], + LogicalTypeId::Varchar.into(), + )] + } + } + + struct Repeat {} + + impl VScalar for Repeat { + type State = (); + + unsafe fn invoke( + _: &Self::State, + input: &mut DataChunkHandle, + output: &mut dyn WritableVector, + ) -> Result<(), Box> { + let output = output.flat_vector(); + let counts = input.flat_vector(1); + let values = input.flat_vector(0); + let values = values.as_slice_with_len::(input.len()); + let strings = values + .iter() + .map(|ptr| DuckString::new(&mut { *ptr }).as_str().to_string()); + let counts = counts.as_slice_with_len::(input.len()); + for (count, value) in counts.iter().zip(strings).take(input.len()) { + output.insert(0, value.repeat((*count) as usize).as_str()); + } + + Ok(()) + } + + fn signatures() -> Vec { + vec![ScalarFunctionSignature::exact( + vec![ + LogicalTypeHandle::from(LogicalTypeId::Varchar), + LogicalTypeHandle::from(LogicalTypeId::Integer), + ], + LogicalTypeHandle::from(LogicalTypeId::Varchar), + )] + } + } + + #[test] + fn test_scalar() -> Result<(), Box> { + let conn = Connection::open_in_memory()?; + conn.register_scalar_function::("echo")?; + + let mut stmt = conn.prepare("select echo('hi') as hello")?; + let mut rows = stmt.query([])?; + + while let Some(row) = rows.next()? { + let hello: String = row.get(0)?; + assert_eq!(hello, "hi"); + } + + Ok(()) + } + + #[test] + fn test_scalar_error() -> Result<(), Box> { + let conn = Connection::open_in_memory()?; + conn.register_scalar_function::("error_udf")?; + + let mut stmt = conn.prepare("select error_udf('blurg') as hello")?; + if let Err(err) = stmt.query([]) { + assert!(err.to_string().contains("Error: blurg")); + } else { + panic!("Expected an error"); + } + + Ok(()) + } + + #[test] + fn test_repeat_scalar() -> Result<(), Box> { + let conn = Connection::open_in_memory()?; + conn.register_scalar_function::("nobie_repeat")?; + + let batches = conn + .prepare("select nobie_repeat('Ho ho ho πŸŽ…πŸŽ„', 3) as message from range(5)")? + .query_arrow([])? 
+ .collect::>(); + + for batch in batches.iter() { + let array = batch.column(0); + let array = array.as_any().downcast_ref::<::arrow::array::StringArray>().unwrap(); + for i in 0..array.len() { + assert_eq!(array.value(i), "Ho ho ho πŸŽ…πŸŽ„Ho ho ho πŸŽ…πŸŽ„Ho ho ho πŸŽ…πŸŽ„"); + } + } + + Ok(()) + } +} diff --git a/crates/duckdb/src/vtab/arrow.rs b/crates/duckdb/src/vtab/arrow.rs index 219f6f71..a3a68db4 100644 --- a/crates/duckdb/src/vtab/arrow.rs +++ b/crates/duckdb/src/vtab/arrow.rs @@ -1,14 +1,26 @@ -use super::{BindInfo, DataChunkHandle, Free, FunctionInfo, InitInfo, LogicalTypeHandle, LogicalTypeId, VTab}; -use std::ptr::null_mut; +use super::{BindInfo, DataChunkHandle, Free, InitInfo, LogicalTypeHandle, LogicalTypeId, TableFunctionInfo, VTab}; +use std::{ + borrow::Cow, + ffi::{c_char, CStr}, + marker::PhantomData, + ptr::null_mut, + sync::Arc, +}; -use crate::core::{ArrayVector, FlatVector, Inserter, ListVector, StructVector, Vector}; +use crate::{ + core::{ArrayVector, FlatVector, Inserter, ListVector, StructVector, Vector}, + types::DuckString, +}; use arrow::{ array::{ as_boolean_array, as_generic_binary_array, as_large_list_array, as_list_array, as_primitive_array, - as_string_array, as_struct_array, Array, ArrayData, AsArray, BinaryArray, BooleanArray, Decimal128Array, - FixedSizeBinaryArray, FixedSizeListArray, GenericListArray, GenericStringArray, IntervalMonthDayNanoArray, - LargeBinaryArray, LargeStringArray, OffsetSizeTrait, PrimitiveArray, StructArray, + as_string_array, as_struct_array, Array, ArrayData, AsArray, BinaryArray, BooleanArray, Date32Array, + Decimal128Array, Decimal256Array, FixedSizeBinaryArray, FixedSizeListArray, GenericBinaryBuilder, + GenericByteBuilder, GenericListArray, GenericStringArray, IntervalMonthDayNanoArray, LargeBinaryArray, + LargeStringArray, OffsetSizeTrait, PrimitiveArray, StringArray, StructArray, TimestampMicrosecondArray, + TimestampNanosecondArray, }, + buffer::{BooleanBuffer, Buffer, NullBuffer}, compute::cast, }; @@ -18,6 +30,10 @@ use arrow::{ record_batch::RecordBatch, }; +use libduckdb_sys::{ + duckdb_date, duckdb_from_timestamp, duckdb_hugeint, duckdb_interval, duckdb_string_t, duckdb_string_t_data, + duckdb_string_t_length, duckdb_time, duckdb_timestamp, duckdb_vector, +}; use num::{cast::AsPrimitive, ToPrimitive}; /// A pointer to the Arrow record batch for the table function. 
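Note: the vtab/arrow.rs hunks that follow add flat_vector_to_arrow_array and data_chunk_to_arrow; every branch rebuilds DuckDB's validity mask as an arrow NullBuffer before constructing the output array. A minimal sketch of that recurring pattern, built only on the FlatVector helpers added earlier in this patch (the standalone int32_vector_to_arrow helper and its name are illustrative, not part of the patch):

use arrow::{array::Int32Array, buffer::{BooleanBuffer, NullBuffer}};
use duckdb::core::FlatVector;

// Illustrative only: convert the first `len` INTEGER values of a DuckDB FlatVector
// into an arrow Int32Array, carrying the validity mask over as a NullBuffer.
fn int32_vector_to_arrow(vector: &FlatVector, len: usize) -> Int32Array {
    // Raw values, limited to the rows that are actually populated in this chunk.
    let values = vector.as_slice_with_len::<i32>(len);
    // A row is valid in arrow exactly when DuckDB does not report it as null.
    let validity = BooleanBuffer::collect_bool(len, |row| !vector.row_is_null(row as u64));
    Int32Array::from_iter_values_with_nulls(values.iter().copied(), Some(NullBuffer::new(validity)))
}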
@@ -103,7 +119,7 @@ impl VTab for ArrowVTab { Ok(()) } - unsafe fn func(func: &FunctionInfo, output: &mut DataChunkHandle) -> Result<(), Box> { + unsafe fn func(func: &TableFunctionInfo, output: &mut DataChunkHandle) -> Result<(), Box> { let init_info = func.get_init_data::(); let bind_info = func.get_bind_data::(); unsafe { @@ -163,9 +179,6 @@ pub fn to_duckdb_type_id(data_type: &DataType) -> Result Struct, DataType::Union(_, _) => Union, // DataType::Dictionary(_, _) => todo!(), - // duckdb/src/main/capi/helper-c.cpp does not support decimal - // DataType::Decimal128(_, _) => Decimal, - // DataType::Decimal256(_, _) => Decimal, DataType::Decimal128(_, _) => Decimal, DataType::Decimal256(_, _) => Double, DataType::Map(_, _) => Map, @@ -176,6 +189,22 @@ pub fn to_duckdb_type_id(data_type: &DataType) -> Result for LogicalTypeId { + type Error = Box; + + fn try_from(data_type: &DataType) -> Result { + to_duckdb_type_id(data_type) + } +} + +impl TryFrom for LogicalTypeId { + type Error = Box; + + fn try_from(data_type: DataType) -> Result { + to_duckdb_type_id(&data_type) + } +} + /// Convert arrow DataType to duckdb logical type pub fn to_duckdb_logical_type(data_type: &DataType) -> Result> { match data_type { @@ -212,6 +241,378 @@ pub fn to_duckdb_logical_type(data_type: &DataType) -> Result Result, Box> { + let type_id = vector.logical_type().id(); + match type_id { + LogicalTypeId::Integer => { + let data = vector.as_slice_with_len::(len); + + Ok(Arc::new(PrimitiveArray::::from_iter_values_with_nulls( + data.iter().copied(), + Some(NullBuffer::new(BooleanBuffer::collect_bool(data.len(), |row| { + !vector.row_is_null(row as u64) + }))), + ))) + } + LogicalTypeId::Timestamp + | LogicalTypeId::TimestampMs + | LogicalTypeId::TimestampS + | LogicalTypeId::TimestampTZ => { + let data = vector.as_slice_with_len::(len); + let micros = data.iter().map(|duckdb_timestamp { micros }| *micros); + let structs = TimestampMicrosecondArray::from_iter_values_with_nulls( + micros, + Some(NullBuffer::new(BooleanBuffer::collect_bool(data.len(), |row| { + !vector.row_is_null(row as u64) + }))), + ); + + Ok(Arc::new(structs)) + } + LogicalTypeId::Varchar => { + let data = vector.as_slice_with_len::(len); + + let duck_strings = data.iter().enumerate().map(|(i, s)| { + if vector.row_is_null(i as u64) { + None + } else { + let mut ptr = *s; + Some(DuckString::new(&mut ptr).as_str().to_string()) + } + }); + + let values = duck_strings.collect::>(); + + Ok(Arc::new(StringArray::from(values))) + } + LogicalTypeId::Boolean => { + let data = vector.as_slice_with_len::(len); + + Ok(Arc::new(BooleanArray::new( + BooleanBuffer::from_iter(data.iter().copied()), + Some(NullBuffer::new(BooleanBuffer::collect_bool(data.len(), |row| { + !vector.row_is_null(row as u64) + }))), + ))) + } + LogicalTypeId::Float => { + let data = vector.as_slice_with_len::(len); + + Ok(Arc::new(PrimitiveArray::::from_iter_values_with_nulls( + data.iter().copied(), + Some(NullBuffer::new(BooleanBuffer::collect_bool(data.len(), |row| { + !vector.row_is_null(row as u64) + }))), + ))) + } + LogicalTypeId::Double => { + let data = vector.as_slice_with_len::(len); + + Ok(Arc::new(PrimitiveArray::::from_iter_values_with_nulls( + data.iter().copied(), + Some(NullBuffer::new(BooleanBuffer::collect_bool(data.len(), |row| { + !vector.row_is_null(row as u64) + }))), + ))) + } + LogicalTypeId::Date => { + let data = vector.as_slice_with_len::(len); + + Ok(Arc::new(Date32Array::from_iter_values_with_nulls( + data.iter().map(|duckdb_date { days }| *days), + 
Some(NullBuffer::new(BooleanBuffer::collect_bool(data.len(), |row| { + !vector.row_is_null(row as u64) + }))), + ))) + } + LogicalTypeId::Time => { + let data = vector.as_slice_with_len::(len); + + Ok(Arc::new( + PrimitiveArray::::from_iter_values_with_nulls( + data.iter().map(|duckdb_time { micros }| *micros), + Some(NullBuffer::new(BooleanBuffer::collect_bool(data.len(), |row| { + !vector.row_is_null(row as u64) + }))), + ), + )) + } + LogicalTypeId::Smallint => { + let data = vector.as_slice_with_len::(len); + + Ok(Arc::new(PrimitiveArray::::from_iter_values_with_nulls( + data.iter().copied(), + Some(NullBuffer::new(BooleanBuffer::collect_bool(data.len(), |row| { + !vector.row_is_null(row as u64) + }))), + ))) + } + LogicalTypeId::USmallint => { + let data = vector.as_slice_with_len::(len); + + Ok(Arc::new(PrimitiveArray::::from_iter_values_with_nulls( + data.iter().copied(), + Some(NullBuffer::new(BooleanBuffer::collect_bool(data.len(), |row| { + !vector.row_is_null(row as u64) + }))), + ))) + } + LogicalTypeId::Blob => { + let mut data = vector.as_slice_with_len::(len).to_vec(); + + let duck_strings = data.iter_mut().enumerate().map(|(i, ptr)| { + if vector.row_is_null(i as u64) { + None + } else { + Some(DuckString::new(ptr)) + } + }); + + let mut builder = GenericBinaryBuilder::::new(); + for s in duck_strings { + if let Some(mut s) = s { + builder.append_value(s.as_bytes()); + } else { + builder.append_null(); + } + } + + Ok(Arc::new(builder.finish())) + } + LogicalTypeId::Tinyint => { + let data = vector.as_slice_with_len::(len); + + Ok(Arc::new(PrimitiveArray::::from_iter_values_with_nulls( + data.iter().copied(), + Some(NullBuffer::new(BooleanBuffer::collect_bool(data.len(), |row| { + !vector.row_is_null(row as u64) + }))), + ))) + } + LogicalTypeId::Bigint => { + let data = vector.as_slice_with_len::(len); + + Ok(Arc::new(PrimitiveArray::::from_iter_values_with_nulls( + data.iter().copied(), + Some(NullBuffer::new(BooleanBuffer::collect_bool(data.len(), |row| { + !vector.row_is_null(row as u64) + }))), + ))) + } + LogicalTypeId::UBigint => { + let data = vector.as_slice_with_len::(len); + + Ok(Arc::new(PrimitiveArray::::from_iter_values_with_nulls( + data.iter().copied(), + Some(NullBuffer::new(BooleanBuffer::collect_bool(data.len(), |row| { + !vector.row_is_null(row as u64) + }))), + ))) + } + LogicalTypeId::UTinyint => { + let data = vector.as_slice_with_len::(len); + + Ok(Arc::new(PrimitiveArray::::from_iter_values_with_nulls( + data.iter().copied(), + Some(NullBuffer::new(BooleanBuffer::collect_bool(data.len(), |row| { + !vector.row_is_null(row as u64) + }))), + ))) + } + LogicalTypeId::UInteger => { + let data = vector.as_slice_with_len::(len); + + Ok(Arc::new(PrimitiveArray::::from_iter_values_with_nulls( + data.iter().copied(), + Some(NullBuffer::new(BooleanBuffer::collect_bool(data.len(), |row| { + !vector.row_is_null(row as u64) + }))), + ))) + } + LogicalTypeId::TimestampNs => { + // even nano second precision is stored in micros when using the c api + let data = vector.as_slice_with_len::(len); + let nanos = data.iter().map(|duckdb_timestamp { micros }| *micros * 1000); + let structs = TimestampNanosecondArray::from_iter_values_with_nulls( + nanos, + Some(NullBuffer::new(BooleanBuffer::collect_bool(data.len(), |row| { + !vector.row_is_null(row as u64) + }))), + ); + + Ok(Arc::new(structs)) + } + LogicalTypeId::Struct => { + todo!() + } + LogicalTypeId::Decimal => { + todo!() + } + LogicalTypeId::Map => { + todo!() + } + LogicalTypeId::List => { + todo!() + } + 
LogicalTypeId::Union => { + todo!() + } + LogicalTypeId::Interval => { + let _data = vector.as_slice_with_len::(len); + todo!() + } + LogicalTypeId::Hugeint => { + let _data = vector.as_slice_with_len::(len); + todo!() + } + LogicalTypeId::Enum => { + todo!() + } + LogicalTypeId::Uuid => { + todo!() + } + } +} + +/// converts a `DataChunk` to arrow `RecordBatch` +pub fn data_chunk_to_arrow(chunk: &DataChunkHandle) -> Result> { + let len = chunk.len(); + + let columns = (0..chunk.num_columns()) + .map(|i| { + let mut vector = chunk.flat_vector(i); + flat_vector_to_arrow_array(&mut vector, len).map(|array_data| { + assert_eq!(array_data.len(), chunk.len()); + let array: Arc = Arc::new(array_data); + (i.to_string(), array) + }) + }) + .collect::, _>>()?; + + Ok(RecordBatch::try_from_iter(columns.into_iter())?) +} + +struct DataChunkHandleSlice<'a> { + chunk: &'a mut DataChunkHandle, + column_index: usize, +} + +impl<'a> DataChunkHandleSlice<'a> { + fn new(chunk: &'a mut DataChunkHandle, column_index: usize) -> Self { + Self { chunk, column_index } + } +} + +impl<'a> WritableVector for DataChunkHandleSlice<'a> { + fn array_vector(&mut self) -> ArrayVector { + self.chunk.array_vector(self.column_index) + } + + fn flat_vector(&mut self) -> FlatVector { + self.chunk.flat_vector(self.column_index) + } + + fn struct_vector(&mut self) -> StructVector { + self.chunk.struct_vector(self.column_index) + } + + fn list_vector(&mut self) -> ListVector { + self.chunk.list_vector(self.column_index) + } +} + +pub trait WritableVector { + fn flat_vector(&mut self) -> FlatVector; + fn list_vector(&mut self) -> ListVector; + fn array_vector(&mut self) -> ArrayVector; + fn struct_vector(&mut self) -> StructVector; +} + +/// Writes an Arrow array to a `WritableVector`. +pub fn write_arrow_array_to_vector( + col: &Arc, + chunk: &mut dyn WritableVector, +) -> Result<(), Box> { + match col.data_type() { + dt if dt.is_primitive() || matches!(dt, DataType::Boolean) => { + primitive_array_to_vector(col, &mut chunk.flat_vector())?; + } + DataType::Utf8 => { + string_array_to_vector(as_string_array(col.as_ref()), &mut chunk.flat_vector()); + } + DataType::LargeUtf8 => { + string_array_to_vector( + col.as_ref() + .as_any() + .downcast_ref::() + .ok_or_else(|| Box::::from("Unable to downcast to LargeStringArray"))?, + &mut chunk.flat_vector(), + ); + } + DataType::Binary => { + binary_array_to_vector(as_generic_binary_array(col.as_ref()), &mut chunk.flat_vector()); + } + DataType::FixedSizeBinary(_) => { + fixed_size_binary_array_to_vector(col.as_ref().as_fixed_size_binary(), &mut chunk.flat_vector()); + } + DataType::LargeBinary => { + large_binary_array_to_vector( + col.as_ref() + .as_any() + .downcast_ref::() + .ok_or_else(|| Box::::from("Unable to downcast to LargeBinaryArray"))?, + &mut chunk.flat_vector(), + ); + } + DataType::List(_) => { + list_array_to_vector(as_list_array(col.as_ref()), &mut chunk.list_vector())?; + } + DataType::LargeList(_) => { + list_array_to_vector(as_large_list_array(col.as_ref()), &mut chunk.list_vector())?; + } + DataType::FixedSizeList(_, _) => { + fixed_size_list_array_to_vector(as_fixed_size_list_array(col.as_ref()), &mut chunk.array_vector())?; + } + DataType::Struct(_) => { + let struct_array = as_struct_array(col.as_ref()); + let mut struct_vector = chunk.struct_vector(); + struct_array_to_vector(struct_array, &mut struct_vector)?; + } + dt => { + return Err(format!( + "column with data_type {} is not supported yet, please file an issue https://github.com/wangfenjin/duckdb-rs", + dt 
+ ) + .into()); + } + } + + Ok(()) +} + +impl WritableVector for duckdb_vector { + fn array_vector(&mut self) -> ArrayVector { + ArrayVector::from(*self) + } + + fn flat_vector(&mut self) -> FlatVector { + FlatVector::from(*self) + } + + fn list_vector(&mut self) -> ListVector { + ListVector::from(*self) + } + + fn struct_vector(&mut self) -> StructVector { + StructVector::from(*self) + } +} + /// Converts a `RecordBatch` to a `DataChunk` in the DuckDB format. /// /// # Arguments @@ -226,59 +627,7 @@ pub fn record_batch_to_duckdb_data_chunk( assert_eq!(batch.num_columns(), chunk.num_columns()); for i in 0..batch.num_columns() { let col = batch.column(i); - match col.data_type() { - dt if dt.is_primitive() || matches!(dt, DataType::Boolean) => { - primitive_array_to_vector(col, &mut chunk.flat_vector(i))?; - } - DataType::Utf8 => { - string_array_to_vector(as_string_array(col.as_ref()), &mut chunk.flat_vector(i)); - } - DataType::LargeUtf8 => { - string_array_to_vector( - col.as_ref() - .as_any() - .downcast_ref::() - .ok_or_else(|| Box::::from("Unable to downcast to LargeStringArray"))?, - &mut chunk.flat_vector(i), - ); - } - DataType::Binary => { - binary_array_to_vector(as_generic_binary_array(col.as_ref()), &mut chunk.flat_vector(i)); - } - DataType::FixedSizeBinary(_) => { - fixed_size_binary_array_to_vector(col.as_ref().as_fixed_size_binary(), &mut chunk.flat_vector(i)); - } - DataType::LargeBinary => { - large_binary_array_to_vector( - col.as_ref() - .as_any() - .downcast_ref::() - .ok_or_else(|| Box::::from("Unable to downcast to LargeBinaryArray"))?, - &mut chunk.flat_vector(i), - ); - } - DataType::List(_) => { - list_array_to_vector(as_list_array(col.as_ref()), &mut chunk.list_vector(i))?; - } - DataType::LargeList(_) => { - list_array_to_vector(as_large_list_array(col.as_ref()), &mut chunk.list_vector(i))?; - } - DataType::FixedSizeList(_, _) => { - fixed_size_list_array_to_vector(as_fixed_size_list_array(col.as_ref()), &mut chunk.array_vector(i))?; - } - DataType::Struct(_) => { - let struct_array = as_struct_array(col.as_ref()); - let mut struct_vector = chunk.struct_vector(i); - struct_array_to_vector(struct_array, &mut struct_vector)?; - } - _ => { - return Err(format!( - "column {} is not supported yet, please file an issue https://github.com/wangfenjin/duckdb-rs", - batch.schema().field(i) - ) - .into()); - } - } + write_arrow_array_to_vector(col, &mut DataChunkHandleSlice::new(chunk, i))?; } chunk.set_len(batch.num_rows()); Ok(()) diff --git a/crates/duckdb/src/vtab/excel.rs b/crates/duckdb/src/vtab/excel.rs index b9bfad59..ae476e4a 100644 --- a/crates/duckdb/src/vtab/excel.rs +++ b/crates/duckdb/src/vtab/excel.rs @@ -1,4 +1,4 @@ -use super::{BindInfo, DataChunkHandle, Free, FunctionInfo, InitInfo, LogicalTypeHandle, LogicalTypeId, VTab}; +use super::{BindInfo, DataChunkHandle, Free, TableFunctionInfo, InitInfo, LogicalTypeHandle, LogicalTypeId, VTab}; use crate::core::Inserter; use calamine::{open_workbook_auto, DataType, Range, Reader}; @@ -132,7 +132,7 @@ impl VTab for ExcelVTab { Ok(()) } - unsafe fn func(func: &FunctionInfo, output: &mut DataChunkHandle) -> Result<(), Box> { + unsafe fn func(func: &TableFunctionInfo, output: &mut DataChunkHandle) -> Result<(), Box> { let init_info = func.get_init_data::(); let bind_info = func.get_bind_data::(); unsafe { diff --git a/crates/duckdb/src/vtab/function.rs b/crates/duckdb/src/vtab/function.rs index 9d14b510..4a5882f3 100644 --- a/crates/duckdb/src/vtab/function.rs +++ b/crates/duckdb/src/vtab/function.rs @@ -1,3 +1,14 @@ 
+use libduckdb_sys::{ + duckdb_add_scalar_function_to_set, duckdb_connection, duckdb_create_scalar_function, + duckdb_create_scalar_function_set, duckdb_destroy_scalar_function, duckdb_scalar_function, + duckdb_scalar_function_add_parameter, duckdb_scalar_function_get_extra_info, duckdb_scalar_function_set, + duckdb_scalar_function_set_error, duckdb_scalar_function_set_extra_info, duckdb_scalar_function_set_function, + duckdb_scalar_function_set_name, duckdb_scalar_function_set_return_type, duckdb_scalar_function_set_varargs, + duckdb_vector, DuckDBSuccess, +}; + +use crate::Error; + use super::{ ffi::{ duckdb_bind_add_result_column, duckdb_bind_get_extra_info, duckdb_bind_get_named_parameter, @@ -13,6 +24,7 @@ use super::{ }; use std::{ ffi::{c_void, CString}, + fmt::Debug, os::raw::c_char, }; @@ -334,9 +346,9 @@ use super::ffi::{ /// An interface to store and retrieve data during the function execution stage #[derive(Debug)] -pub struct FunctionInfo(duckdb_function_info); +pub struct TableFunctionInfo(duckdb_function_info); -impl FunctionInfo { +impl TableFunctionInfo { /// Report that an error has occurred while executing the function. /// /// # Arguments @@ -344,9 +356,10 @@ impl FunctionInfo { pub fn set_error(&self, error: &str) { let c_str = CString::new(error).unwrap(); unsafe { - duckdb_function_set_error(self.0, c_str.as_ptr() as *const c_char); + duckdb_function_set_error(self.0, c_str.as_ptr()); } } + /// Gets the bind data set by [`BindInfo::set_bind_data`] during the bind. /// /// Note that the bind data should be considered as read-only. @@ -380,7 +393,7 @@ impl FunctionInfo { } } -impl From for FunctionInfo { +impl From for TableFunctionInfo { fn from(ptr: duckdb_function_info) -> Self { Self(ptr) } diff --git a/crates/duckdb/src/vtab/mod.rs b/crates/duckdb/src/vtab/mod.rs index 9249fb1e..b06c85cf 100644 --- a/crates/duckdb/src/vtab/mod.rs +++ b/crates/duckdb/src/vtab/mod.rs @@ -17,7 +17,7 @@ pub use self::arrow::{ #[cfg(feature = "vtab-excel")] mod excel; -pub use function::{BindInfo, FunctionInfo, InitInfo, TableFunction}; +pub use function::{BindInfo, InitInfo, TableFunction, TableFunctionInfo}; pub use value::Value; use crate::core::{DataChunkHandle, LogicalTypeHandle, LogicalTypeId}; @@ -30,7 +30,7 @@ use std::mem::size_of; /// used for the bind_info and init_info /// # Safety /// This function is obviously unsafe -unsafe fn malloc_data_c() -> *mut T { +pub unsafe fn malloc_data_c() -> *mut T { duckdb_malloc(size_of::()).cast() } @@ -39,7 +39,7 @@ unsafe fn malloc_data_c() -> *mut T { /// # Safety /// This function is obviously unsafe /// TODO: maybe we should use a Free trait here -unsafe extern "C" fn drop_data_c(v: *mut c_void) { +pub unsafe extern "C" fn drop_data_c(v: *mut c_void) { let actual = v.cast::(); (*actual).free(); duckdb_free(v); @@ -100,7 +100,7 @@ pub trait VTab: Sized { /// - The `init_info` and `bind_info` data pointed to remains valid and is not freed until after this function completes. /// - No other threads are concurrently mutating the data pointed to by `init_info` and `bind_info` without proper synchronization. /// - The `output` parameter is correctly initialized and can safely be written to. 
- unsafe fn func(func: &FunctionInfo, output: &mut DataChunkHandle) -> Result<(), Box>; + unsafe fn func(func: &TableFunctionInfo, output: &mut DataChunkHandle) -> Result<(), Box>; /// Does the table function support pushdown /// default is false fn supports_pushdown() -> bool { @@ -122,7 +122,7 @@ unsafe extern "C" fn func(info: duckdb_function_info, output: duckdb_data_chu where T: VTab, { - let info = FunctionInfo::from(info); + let info = TableFunctionInfo::from(info); let mut data_chunk_handle = DataChunkHandle::new_unowned(output); let result = T::func(&info, &mut data_chunk_handle); if result.is_err() { @@ -193,7 +193,7 @@ impl InnerConnection { #[cfg(test)] mod test { use super::*; - use crate::core::Inserter; + use crate::{core::Inserter, types::DuckString}; use std::{ error::Error, ffi::{c_char, CString}, @@ -244,7 +244,10 @@ mod test { Ok(()) } - unsafe fn func(func: &FunctionInfo, output: &mut DataChunkHandle) -> Result<(), Box> { + unsafe fn func( + func: &TableFunctionInfo, + output: &mut DataChunkHandle, + ) -> Result<(), Box> { let init_info = func.get_init_data::(); let bind_info = func.get_bind_data::(); @@ -289,7 +292,7 @@ mod test { HelloVTab::init(init_info, data) } - unsafe fn func(func: &FunctionInfo, output: &mut DataChunkHandle) -> Result<(), Box> { + unsafe fn func(func: &TableFunctionInfo, output: &mut DataChunkHandle) -> Result<(), Box> { HelloVTab::func(func, output) }
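Taken together, the scalar pieces added above (the VScalar trait, ScalarFunctionSignature, and Connection::register_scalar_function) are used the way the patch's own tests use them: implement the trait, declare the signatures, and register the type on a connection. A minimal end-to-end sketch under those assumptions; the DoubleScalar type, the double_udf name, and the doubling logic are illustrative, and the WritableVector import path follows the patch's internal tests:

use duckdb::{
    core::{DataChunkHandle, LogicalTypeId},
    vscalar::{ScalarFunctionSignature, VScalar},
    vtab::arrow::WritableVector,
    Connection,
};
use std::error::Error;

// Illustrative UDF: doubles its INTEGER argument.
struct DoubleScalar;

impl VScalar for DoubleScalar {
    type State = ();

    unsafe fn invoke(
        _: &Self::State,
        input: &mut DataChunkHandle,
        output: &mut dyn WritableVector,
    ) -> Result<(), Box<dyn Error>> {
        let len = input.len();
        // Read the INTEGER argument as an i32 slice of the current chunk.
        let arg = input.flat_vector(0);
        let arg = arg.as_slice_with_len::<i32>(len);
        // Write one doubled value per input row into the result vector.
        let mut out = output.flat_vector();
        let out = out.as_mut_slice_with_len::<i32>(len);
        for (i, v) in arg.iter().enumerate() {
            out[i] = *v * 2;
        }
        Ok(())
    }

    fn signatures() -> Vec<ScalarFunctionSignature> {
        vec![ScalarFunctionSignature::exact(
            vec![LogicalTypeId::Integer.into()],
            LogicalTypeId::Integer.into(),
        )]
    }
}

fn main() -> Result<(), Box<dyn Error>> {
    let conn = Connection::open_in_memory()?;
    conn.register_scalar_function::<DoubleScalar>("double_udf")?;
    let doubled: i32 = conn.query_row("SELECT double_udf(21)", [], |row| row.get(0))?;
    assert_eq!(doubled, 42);
    Ok(())
}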