Skip to content

Commit

Permalink
Add scalar support with arrow types and overloading
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewgapp committed Dec 1, 2024
1 parent 2bd811e commit aebc114
Show file tree
Hide file tree
Showing 16 changed files with 1,330 additions and 86 deletions.
2 changes: 2 additions & 0 deletions crates/duckdb/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ default = []
bundled = ["libduckdb-sys/bundled"]
json = ["libduckdb-sys/json", "bundled"]
parquet = ["libduckdb-sys/parquet", "bundled"]
vscalar = []
vscalar-arrow = []
vtab = []
vtab-loadable = ["vtab", "duckdb-loadable-macros"]
vtab-excel = ["vtab", "calamine"]
Expand Down
4 changes: 2 additions & 2 deletions crates/duckdb/examples/hello-ext-capi/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ extern crate libduckdb_sys;

use duckdb::{
core::{DataChunkHandle, Inserter, LogicalTypeHandle, LogicalTypeId},
vtab::{BindInfo, Free, FunctionInfo, InitInfo, VTab},
vtab::{BindInfo, Free, TableFunctionInfo, InitInfo, VTab},
Connection, Result,
};
use duckdb_loadable_macros::duckdb_entrypoint_c_api;
Expand Down Expand Up @@ -59,7 +59,7 @@ impl VTab for HelloVTab {
Ok(())
}

unsafe fn func(func: &FunctionInfo, output: &mut DataChunkHandle) -> Result<(), Box<dyn std::error::Error>> {
unsafe fn func(func: &TableFunctionInfo, output: &mut DataChunkHandle) -> Result<(), Box<dyn std::error::Error>> {
let init_info = func.get_init_data::<HelloInitData>();
let bind_info = func.get_bind_data::<HelloBindData>();

Expand Down
4 changes: 2 additions & 2 deletions crates/duckdb/examples/hello-ext/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ extern crate libduckdb_sys;

use duckdb::{
core::{DataChunkHandle, Inserter, LogicalTypeHandle, LogicalTypeId},
vtab::{BindInfo, Free, FunctionInfo, InitInfo, VTab},
vtab::{BindInfo, Free, TableFunctionInfo, InitInfo, VTab},
Connection, Result,
};
use duckdb_loadable_macros::duckdb_entrypoint;
Expand Down Expand Up @@ -59,7 +59,7 @@ impl VTab for HelloVTab {
Ok(())
}

unsafe fn func(func: &FunctionInfo, output: &mut DataChunkHandle) -> Result<(), Box<dyn std::error::Error>> {
unsafe fn func(func: &TableFunctionInfo, output: &mut DataChunkHandle) -> Result<(), Box<dyn std::error::Error>> {
let init_info = func.get_init_data::<HelloInitData>();
let bind_info = func.get_bind_data::<HelloBindData>();

Expand Down
32 changes: 31 additions & 1 deletion crates/duckdb/src/core/vector.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use std::{any::Any, ffi::CString, slice};

use libduckdb_sys::{duckdb_array_type_array_size, duckdb_array_vector_get_child, DuckDbString};
use libduckdb_sys::{
duckdb_array_type_array_size, duckdb_array_vector_get_child, duckdb_validity_row_is_valid,
DuckDbString,
};

use super::LogicalTypeHandle;
use crate::ffi::{
Expand Down Expand Up @@ -55,6 +58,23 @@ impl FlatVector {
self.capacity
}

pub fn row_is_null(&self, row: u64) -> bool {
// use idx_t entry_idx = row_idx / 64; idx_t idx_in_entry = row_idx % 64; bool is_valid = validity_mask[entry_idx] & (1 « idx_in_entry);
// as the row is valid function is slower
let valid = unsafe {
let validity = duckdb_vector_get_validity(self.ptr);

// validity can return a NULL pointer if the entire vector is valid
if validity.is_null() {
return false;
}

duckdb_validity_row_is_valid(validity, row)
};

!valid
}

/// Returns an unsafe mutable pointer to the vector’s
pub fn as_mut_ptr<T>(&self) -> *mut T {
unsafe { duckdb_vector_get_data(self.ptr).cast() }
Expand All @@ -65,11 +85,21 @@ impl FlatVector {
unsafe { slice::from_raw_parts(self.as_mut_ptr(), self.capacity()) }
}

/// Returns a slice of the vector up to a certain length
pub fn as_slice_with_len<T>(&self, len: usize) -> &[T] {
unsafe { slice::from_raw_parts(self.as_mut_ptr(), len) }
}

/// Returns a mutable slice of the vector
pub fn as_mut_slice<T>(&mut self) -> &mut [T] {
unsafe { slice::from_raw_parts_mut(self.as_mut_ptr(), self.capacity()) }
}

/// Returns a mutable slice of the vector up to a certain length
pub fn as_mut_slice_with_len<T>(&mut self, len: usize) -> &mut [T] {
unsafe { slice::from_raw_parts_mut(self.as_mut_ptr(), len) }
}

/// Returns the logical type of the vector
pub fn logical_type(&self) -> LogicalTypeHandle {
unsafe { LogicalTypeHandle::new(duckdb_vector_get_column_type(self.ptr)) }
Expand Down
4 changes: 4 additions & 0 deletions crates/duckdb/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ pub mod types;
#[cfg(feature = "vtab")]
pub mod vtab;

/// The duckdb table function interface
#[cfg(feature = "vscalar")]
pub mod vscalar;

#[cfg(test)]
mod test_all_types;

Expand Down
18 changes: 17 additions & 1 deletion crates/duckdb/src/r2d2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@
//! .unwrap()
//! }
//! ```
use crate::{Config, Connection, Error, Result};
use crate::{vscalar::VScalar, vtab::VTab, Config, Connection, Error, Result};
use std::{
fmt::Debug,
path::Path,
sync::{Arc, Mutex},
};
Expand Down Expand Up @@ -78,6 +79,21 @@ impl DuckdbConnectionManager {
connection: Arc::new(Mutex::new(Connection::open_in_memory_with_flags(config)?)),
})
}

/// Register a table function.
pub fn register_table_function<T: VTab>(&self, name: &str) -> Result<()> {
let conn = self.connection.lock().unwrap();
conn.register_table_function::<T>(name)
}

/// Register a scalar function.
pub fn register_scalar<S: VScalar>(&self, name: &str) -> Result<()>
where
S::State: Debug,
{
let conn = self.connection.lock().unwrap();
conn.register_scalar_function::<S>(name)
}
}

impl r2d2::ManageConnection for DuckdbConnectionManager {
Expand Down
7 changes: 4 additions & 3 deletions crates/duckdb/src/raw_statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,11 @@ impl RawStatement {

#[inline]
pub fn step(&self) -> Option<StructArray> {
self.result?;
let out = self.result?;
unsafe {
let mut arrays = FFI_ArrowArray::empty();
if ffi::duckdb_query_arrow_array(
self.result_unwrap(),
out,
&mut std::ptr::addr_of_mut!(arrays) as *mut _ as *mut ffi::duckdb_arrow_array,
)
.ne(&ffi::DuckDBSuccess)
Expand All @@ -99,7 +99,7 @@ impl RawStatement {

let mut schema = FFI_ArrowSchema::empty();
if ffi::duckdb_query_arrow_schema(
self.result_unwrap(),
out,
&mut std::ptr::addr_of_mut!(schema) as *mut _ as *mut ffi::duckdb_arrow_schema,
) != ffi::DuckDBSuccess
{
Expand Down Expand Up @@ -260,6 +260,7 @@ impl RawStatement {
unsafe {
let mut out: ffi::duckdb_arrow = ptr::null_mut();
let rc = ffi::duckdb_execute_prepared_arrow(self.ptr, &mut out);
println!("error code: {}", rc);
result_from_duckdb_arrow(rc, out)?;

let rows_changed = ffi::duckdb_arrow_rows_changed(out);
Expand Down
2 changes: 2 additions & 0 deletions crates/duckdb/src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
pub use self::{
from_sql::{FromSql, FromSqlError, FromSqlResult},
ordered_map::OrderedMap,
string::DuckString,
to_sql::{ToSql, ToSqlOutput},
value::Value,
value_ref::{EnumType, ListType, TimeUnit, ValueRef},
Expand All @@ -25,6 +26,7 @@ mod value;
mod value_ref;

mod ordered_map;
mod string;

/// Empty struct that can be used to fill in a query parameter as `NULL`.
///
Expand Down
28 changes: 28 additions & 0 deletions crates/duckdb/src/types/string.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
use libduckdb_sys::{duckdb_string_t, duckdb_string_t_data, duckdb_string_t_length};

/// Wrapper for underlying duck string type with a lifetime bound to a &mut duckdb_string_t
pub struct DuckString<'a> {
ptr: &'a mut duckdb_string_t,
}

impl<'a> DuckString<'a> {
pub(crate) fn new(ptr: &'a mut duckdb_string_t) -> Self {
DuckString { ptr }
}
}

impl<'a> DuckString<'a> {
/// convert duckdb_string_t to a copy on write string
pub fn as_str(&mut self) -> std::borrow::Cow<'a, str> {
String::from_utf8_lossy(self.as_bytes())
}

/// convert duckdb_string_t to a byte slice
pub fn as_bytes(&mut self) -> &'a [u8] {
unsafe {
let len = duckdb_string_t_length(*self.ptr);
let c_ptr = duckdb_string_t_data(self.ptr);
std::slice::from_raw_parts(c_ptr as *const u8, len as usize)
}
}
}
Loading

0 comments on commit aebc114

Please sign in to comment.