Skip to content

Commit

Permalink
VTab::bind and init are now safe
Browse files Browse the repository at this point in the history
Rather than passing a pointer to a block of uninitialized memory, which can easily lead to UB, these functions now just return Rust objects.

This improves duckdb#414 by reducing the amount of unsafe code needed from extensions.
  • Loading branch information
sourcefrog committed Dec 21, 2024
1 parent 80c7c0c commit 5947f6d
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 76 deletions.
15 changes: 4 additions & 11 deletions crates/duckdb/examples/hello-ext/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ use libduckdb_sys as ffi;
use std::{
error::Error,
ffi::{c_char, c_void, CString},
ptr,
};

struct HelloBindData {
Expand All @@ -35,20 +34,14 @@ impl VTab for HelloVTab {
type InitData = HelloInitData;
type BindData = HelloBindData;

unsafe fn bind(bind: &BindInfo, data: *mut HelloBindData) -> Result<(), Box<dyn std::error::Error>> {
fn bind(bind: &BindInfo) -> Result<Self::BindData, Box<dyn std::error::Error>> {
bind.add_result_column("column0", LogicalTypeHandle::from(LogicalTypeId::Varchar));
let name = bind.get_parameter(0).to_string();
unsafe {
ptr::write(data, HelloBindData { name });
}
Ok(())
Ok(HelloBindData { name })
}

unsafe fn init(_: &InitInfo, data: *mut HelloInitData) -> Result<(), Box<dyn std::error::Error>> {
unsafe {
ptr::write(data, HelloInitData { done: false });
}
Ok(())
fn init(_: &InitInfo) -> Result<Self::InitData, Box<dyn std::error::Error>> {
Ok(HelloInitData { done: false })
}

unsafe fn func(func: &FunctionInfo, output: &mut DataChunkHandle) -> Result<(), Box<dyn std::error::Error>> {
Expand Down
107 changes: 42 additions & 65 deletions crates/duckdb/src/vtab/mod.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
use crate::{error::Error, inner_connection::InnerConnection, Connection, Result};
// #![warn(unsafe_op_in_unsafe_fn)]

use super::{ffi, ffi::duckdb_free};
use std::ffi::c_void;

use crate::{error::Error, inner_connection::InnerConnection, Connection, Result};

use super::ffi;

mod function;
mod value;

Expand All @@ -23,26 +26,12 @@ pub use value::Value;
use crate::core::{DataChunkHandle, LogicalTypeHandle};
use ffi::{duckdb_bind_info, duckdb_data_chunk, duckdb_function_info, duckdb_init_info};

use ffi::duckdb_malloc;
use std::mem::size_of;

/// duckdb_malloc a struct of type T
/// used for the bind_info and init_info
/// # Safety
/// This function is obviously unsafe
unsafe fn malloc_data_c<T>() -> *mut T {
duckdb_malloc(size_of::<T>()).cast()
}

/// free bind or info data
/// Given a raw pointer to a box, free the box and the data contained within it.
///
/// # Safety
/// This function is obviously unsafe
/// TODO: maybe we should use a Free trait here
unsafe extern "C" fn drop_data_c<T: Free>(v: *mut c_void) {
let actual = v.cast::<T>();
(*actual).free();
duckdb_free(v);
/// The pointer must be a valid pointer to a `Box<T>` created by `Box::into_raw`.
unsafe extern "C" fn drop_boxed<T>(v: *mut c_void) {
drop(unsafe { Box::from_raw(v.cast::<T>()) });
}

/// Free trait for the bind and init data
Expand All @@ -59,27 +48,15 @@ pub trait VTab: Sized {
/// The data type of the bind data
type InitData: Sized + Free;
/// The data type of the init data
type BindData: Sized + Free;
type BindData: Sized + Free; // TODO: and maybe Send + Sync as this might be called from multiple threads?

// TODO: Get rid of Free, just use Drop?

/// Bind data to the table function
///
/// # Safety
///
/// `data` points to an *uninitialized* block of memory large enough to hold a `Self::BindData` value.
/// The implementation should initialize this memory with the appropriate data for the table function,
/// without reading the existing memory,
/// typically using [`std::ptr::write`] or similar.
unsafe fn bind(bind: &BindInfo, data: *mut Self::BindData) -> Result<(), Box<dyn std::error::Error>>;
fn bind(bind: &BindInfo) -> Result<Self::BindData, Box<dyn std::error::Error>>;

/// Initialize the table function
///
/// # Safety
///
/// `data` points to an *uninitialized* block of memory large enough to hold a `Self::InitData` value.
/// The implementation should initialize this memory with the appropriate data for the table function,
/// without reading the existing memory,
/// typically using [`std::ptr::write`] or similar.
unsafe fn init(init: &InitInfo, data: *mut Self::InitData) -> Result<(), Box<dyn std::error::Error>>;
fn init(init: &InitInfo) -> Result<Self::InitData, Box<dyn std::error::Error>>;

/// The actual function
///
Expand Down Expand Up @@ -130,11 +107,16 @@ where
T: VTab,
{
let info = InitInfo::from(info);
let data = malloc_data_c::<T::InitData>();
let result = T::init(&info, data);
info.set_init_data(data.cast(), Some(drop_data_c::<T::InitData>));
if result.is_err() {
info.set_error(&result.err().unwrap().to_string());
match T::init(&info) {
Ok(init_data) => {
info.set_init_data(
Box::into_raw(Box::new(init_data)) as *mut c_void,
Some(drop_boxed::<T::InitData>),
);
}
Err(e) => {
info.set_error(&e.to_string());
}
}
}

Expand All @@ -143,11 +125,16 @@ where
T: VTab,
{
let info = BindInfo::from(info);
let data = malloc_data_c::<T::BindData>();
let result = T::bind(&info, data);
info.set_bind_data(data.cast(), Some(drop_data_c::<T::BindData>));
if result.is_err() {
info.set_error(&result.err().unwrap().to_string());
match T::bind(&info) {
Ok(bind_data) => {
info.set_bind_data(
Box::into_raw(Box::new(bind_data)) as *mut c_void,
Some(drop_boxed::<T::BindData>),
);
}
Err(e) => {
info.set_error(&e.to_string());
}
}
}

Expand Down Expand Up @@ -193,7 +180,6 @@ mod test {
use std::{
error::Error,
ffi::{c_char, CString},
ptr,
};

struct HelloBindData {
Expand All @@ -214,20 +200,14 @@ mod test {
type InitData = HelloInitData;
type BindData = HelloBindData;

unsafe fn bind(bind: &BindInfo, data: *mut HelloBindData) -> Result<(), Box<dyn std::error::Error>> {
fn bind(bind: &BindInfo) -> Result<Self::BindData, Box<dyn std::error::Error>> {
bind.add_result_column("column0", LogicalTypeHandle::from(LogicalTypeId::Varchar));
let name = bind.get_parameter(0).to_string();
unsafe {
ptr::write(data, HelloBindData { name });
}
Ok(())
Ok(HelloBindData { name })
}

unsafe fn init(_: &InitInfo, data: *mut HelloInitData) -> Result<(), Box<dyn std::error::Error>> {
unsafe {
ptr::write(data, HelloInitData { done: false });
}
Ok(())
fn init(_: &InitInfo) -> Result<Self::InitData, Box<dyn std::error::Error>> {
Ok(HelloInitData { done: false })
}

unsafe fn func(func: &FunctionInfo, output: &mut DataChunkHandle) -> Result<(), Box<dyn std::error::Error>> {
Expand Down Expand Up @@ -256,18 +236,15 @@ mod test {
type InitData = HelloInitData;
type BindData = HelloBindData;

unsafe fn bind(bind: &BindInfo, data: *mut HelloBindData) -> Result<(), Box<dyn Error>> {
fn bind(bind: &BindInfo) -> Result<Self::BindData, Box<dyn Error>> {
bind.add_result_column("column0", LogicalTypeHandle::from(LogicalTypeId::Varchar));
let name = bind.get_named_parameter("name").unwrap().to_string();
assert!(bind.get_named_parameter("unknown_name").is_none());
unsafe {
ptr::write(data, HelloBindData { name });
}
Ok(())
Ok(HelloBindData { name })
}

unsafe fn init(init_info: &InitInfo, data: *mut HelloInitData) -> Result<(), Box<dyn Error>> {
HelloVTab::init(init_info, data)
fn init(init_info: &InitInfo) -> Result<Self::InitData, Box<dyn Error>> {
HelloVTab::init(init_info)
}

unsafe fn func(func: &FunctionInfo, output: &mut DataChunkHandle) -> Result<(), Box<dyn Error>> {
Expand Down

0 comments on commit 5947f6d

Please sign in to comment.