Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Safety fixes and docs for vtab #415

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions crates/duckdb-loadable-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,16 +133,20 @@ pub fn duckdb_entrypoint(_attr: TokenStream, item: TokenStream) -> TokenStream {
/// Will be called by duckdb
#[no_mangle]
pub unsafe extern "C" fn #c_entrypoint(db: *mut c_void) {
let connection = Connection::open_from_raw(db.cast()).expect("can't open db connection");
#prefixed_original_function(connection).expect("init failed");
unsafe {
let connection = Connection::open_from_raw(db.cast()).expect("can't open db connection");
#prefixed_original_function(connection).expect("init failed");
}
}

/// # Safety
///
/// Predefined function, don't need to change unless you are sure
#[no_mangle]
pub unsafe extern "C" fn #c_entrypoint_version() -> *const c_char {
ffi::duckdb_library_version()
unsafe {
ffi::duckdb_library_version()
}
}


Expand Down
49 changes: 18 additions & 31 deletions crates/duckdb/examples/hello-ext/main.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#![warn(unsafe_op_in_unsafe_fn)]

extern crate duckdb;
extern crate duckdb_loadable_macros;
extern crate libduckdb_sys;
Expand All @@ -12,25 +14,15 @@ use libduckdb_sys as ffi;
use std::{
error::Error,
ffi::{c_char, c_void, CString},
ptr,
};

#[repr(C)]
struct HelloBindData {
name: *mut c_char,
name: String,
}

impl Free for HelloBindData {
fn free(&mut self) {
unsafe {
if self.name.is_null() {
return;
}
drop(CString::from_raw(self.name));
}
}
}
impl Free for HelloBindData {}

#[repr(C)]
struct HelloInitData {
done: bool,
}
Expand All @@ -45,37 +37,32 @@ impl VTab for HelloVTab {

unsafe fn bind(bind: &BindInfo, data: *mut HelloBindData) -> Result<(), Box<dyn std::error::Error>> {
bind.add_result_column("column0", LogicalTypeHandle::from(LogicalTypeId::Varchar));
let param = bind.get_parameter(0).to_string();
let name = bind.get_parameter(0).to_string();
unsafe {
(*data).name = CString::new(param).unwrap().into_raw();
ptr::write(data, HelloBindData { name });
}
Ok(())
}

unsafe fn init(_: &InitInfo, data: *mut HelloInitData) -> Result<(), Box<dyn std::error::Error>> {
unsafe {
(*data).done = false;
ptr::write(data, HelloInitData { done: false });
}
Ok(())
}

unsafe fn func(func: &FunctionInfo, output: &mut DataChunkHandle) -> Result<(), Box<dyn std::error::Error>> {
let init_info = func.get_init_data::<HelloInitData>();
let bind_info = func.get_bind_data::<HelloBindData>();
let init_info = unsafe { func.get_init_data::<HelloInitData>().as_mut().unwrap() };
let bind_info = unsafe { func.get_bind_data::<HelloBindData>().as_mut().unwrap() };

unsafe {
if (*init_info).done {
output.set_len(0);
} else {
(*init_info).done = true;
let vector = output.flat_vector(0);
let name = CString::from_raw((*bind_info).name);
let result = CString::new(format!("Hello {}", name.to_str()?))?;
// Can't consume the CString
(*bind_info).name = CString::into_raw(name);
vector.insert(0, result);
output.set_len(1);
}
if init_info.done {
output.set_len(0);
} else {
init_info.done = true;
let vector = output.flat_vector(0);
let result = CString::new(format!("Hello {}", bind_info.name))?;
vector.insert(0, result);
output.set_len(1);
}
Ok(())
}
Expand Down
81 changes: 31 additions & 50 deletions crates/duckdb/src/vtab/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ mod excel;
pub use function::{BindInfo, FunctionInfo, InitInfo, TableFunction};
pub use value::Value;

use crate::core::{DataChunkHandle, LogicalTypeHandle, LogicalTypeId};
use crate::core::{DataChunkHandle, LogicalTypeHandle};
use ffi::{duckdb_bind_info, duckdb_data_chunk, duckdb_function_info, duckdb_init_info};

use ffi::duckdb_malloc;
Expand Down Expand Up @@ -65,27 +65,22 @@ pub trait VTab: Sized {
///
/// # Safety
///
/// This function is unsafe because it dereferences raw pointers (`data`) and manipulates the memory directly.
/// The caller must ensure that:
///
/// - The `data` pointer is valid and points to a properly initialized `BindData` instance.
/// - The lifetime of `data` must outlive the execution of `bind` to avoid dangling pointers, especially since
/// `bind` does not take ownership of `data`.
/// - Concurrent access to `data` (if applicable) must be properly synchronized.
/// - The `bind` object must be valid and correctly initialized.
/// `data` points to an *uninitialized* block of memory large enough to hold a `Self::BindData` value.
/// The implementation should initialize this memory with the appropriate data for the table function,
/// without reading the existing memory,
/// typically using [`std::ptr::write`] or similar.
unsafe fn bind(bind: &BindInfo, data: *mut Self::BindData) -> Result<(), Box<dyn std::error::Error>>;

/// Initialize the table function
///
/// # Safety
///
/// This function is unsafe because it performs raw pointer dereferencing on the `data` argument.
/// The caller is responsible for ensuring that:
///
/// - The `data` pointer is non-null and points to a valid `InitData` instance.
/// - There is no data race when accessing `data`, meaning if `data` is accessed from multiple threads,
/// proper synchronization is required.
/// - The lifetime of `data` extends beyond the scope of this call to avoid use-after-free errors.
/// `data` points to an *uninitialized* block of memory large enough to hold a `Self::InitData` value.
/// The implementation should initialize this memory with the appropriate data for the table function,
/// without reading the existing memory,
/// typically using [`std::ptr::write`] or similar.
unsafe fn init(init: &InitInfo, data: *mut Self::InitData) -> Result<(), Box<dyn std::error::Error>>;

/// The actual function
///
/// # Safety
Expand Down Expand Up @@ -194,28 +189,19 @@ impl InnerConnection {
mod test {
use super::*;
use crate::core::Inserter;
use crate::core::LogicalTypeId;
use std::{
error::Error,
ffi::{c_char, CString},
ptr,
};

#[repr(C)]
struct HelloBindData {
name: *mut c_char,
name: String,
}

impl Free for HelloBindData {
fn free(&mut self) {
unsafe {
if self.name.is_null() {
return;
}
drop(CString::from_raw(self.name));
}
}
}
impl Free for HelloBindData {}

#[repr(C)]
struct HelloInitData {
done: bool,
}
Expand All @@ -230,37 +216,32 @@ mod test {

unsafe fn bind(bind: &BindInfo, data: *mut HelloBindData) -> Result<(), Box<dyn std::error::Error>> {
bind.add_result_column("column0", LogicalTypeHandle::from(LogicalTypeId::Varchar));
let param = bind.get_parameter(0).to_string();
let name = bind.get_parameter(0).to_string();
unsafe {
(*data).name = CString::new(param).unwrap().into_raw();
ptr::write(data, HelloBindData { name });
}
Ok(())
}

unsafe fn init(_: &InitInfo, data: *mut HelloInitData) -> Result<(), Box<dyn std::error::Error>> {
unsafe {
(*data).done = false;
ptr::write(data, HelloInitData { done: false });
}
Ok(())
}

unsafe fn func(func: &FunctionInfo, output: &mut DataChunkHandle) -> Result<(), Box<dyn std::error::Error>> {
let init_info = func.get_init_data::<HelloInitData>();
let bind_info = func.get_bind_data::<HelloBindData>();

unsafe {
if (*init_info).done {
output.set_len(0);
} else {
(*init_info).done = true;
let vector = output.flat_vector(0);
let name = CString::from_raw((*bind_info).name);
let result = CString::new(format!("Hello {}", name.to_str()?))?;
// Can't consume the CString
(*bind_info).name = CString::into_raw(name);
vector.insert(0, result);
output.set_len(1);
}
let init_info = unsafe { func.get_init_data::<HelloInitData>().as_mut().unwrap() };
let bind_info = unsafe { func.get_bind_data::<HelloBindData>().as_ref().unwrap() };

if init_info.done {
output.set_len(0);
} else {
init_info.done = true;
let vector = output.flat_vector(0);
let result = CString::new(format!("Hello {}", bind_info.name))?;
vector.insert(0, result);
output.set_len(1);
}
Ok(())
}
Expand All @@ -277,10 +258,10 @@ mod test {

unsafe fn bind(bind: &BindInfo, data: *mut HelloBindData) -> Result<(), Box<dyn Error>> {
bind.add_result_column("column0", LogicalTypeHandle::from(LogicalTypeId::Varchar));
let param = bind.get_named_parameter("name").unwrap().to_string();
let name = bind.get_named_parameter("name").unwrap().to_string();
assert!(bind.get_named_parameter("unknown_name").is_none());
unsafe {
(*data).name = CString::new(param).unwrap().into_raw();
ptr::write(data, HelloBindData { name });
}
Ok(())
}
Expand Down
Loading