Skip to content

Commit

Permalink
streaming arrow data support (#373)
Browse files Browse the repository at this point in the history
* streaming arrow (#3)

* streaming

* remove invalid schema assign

* use ArrowStream type to represent streaming

* clippy

* doc

* import

* typo

* export arrow stream

* fix: Missing semicolon in docs test

---------

Co-authored-by: peasee <98815791+peasee@users.noreply.github.com>
  • Loading branch information
y-f-u and peasee authored Sep 26, 2024
1 parent 36b83bc commit e12fdb6
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 5 deletions.
33 changes: 32 additions & 1 deletion crates/duckdb/src/arrow_batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use super::{
Statement,
};

/// An handle for the resulting RecordBatch of a query.
/// A handle for the resulting RecordBatch of a query.
#[must_use = "Arrow is lazy and will do nothing unless consumed"]
pub struct Arrow<'stmt> {
pub(crate) stmt: Option<&'stmt Statement<'stmt>>,
Expand All @@ -29,3 +29,34 @@ impl<'stmt> Iterator for Arrow<'stmt> {
Some(RecordBatch::from(&self.stmt?.step()?))
}
}

/// A handle for the resulting RecordBatch of a query in streaming
#[must_use = "Arrow stream is lazy and will not fetch data unless consumed"]
pub struct ArrowStream<'stmt> {
pub(crate) stmt: Option<&'stmt Statement<'stmt>>,
pub(crate) schema: SchemaRef,
}

impl<'stmt> ArrowStream<'stmt> {
#[inline]
pub(crate) fn new(stmt: &'stmt Statement<'stmt>, schema: SchemaRef) -> ArrowStream<'stmt> {
ArrowStream {
stmt: Some(stmt),
schema,
}
}

/// return arrow schema
#[inline]
pub fn get_schema(&self) -> SchemaRef {
self.schema.clone()
}
}

impl<'stmt> Iterator for ArrowStream<'stmt> {
type Item = RecordBatch;

fn next(&mut self) -> Option<Self::Item> {
Some(RecordBatch::from(&self.stmt?.stream_step(self.get_schema())?))
}
}
2 changes: 1 addition & 1 deletion crates/duckdb/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ pub use crate::r2d2::DuckdbConnectionManager;
pub use crate::{
appender::Appender,
appender_params::{appender_params_from_iter, AppenderParams, AppenderParamsFromIter},
arrow_batch::Arrow,
arrow_batch::{Arrow, ArrowStream},
cache::CachedStatement,
column::Column,
config::{AccessMode, Config, DefaultNullOrder, DefaultOrder},
Expand Down
61 changes: 59 additions & 2 deletions crates/duckdb/src/raw_statement.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{ffi::CStr, ptr, rc::Rc, sync::Arc};
use std::{ffi::CStr, ops::Deref, ptr, rc::Rc, sync::Arc};

use arrow::{
array::StructArray,
Expand All @@ -9,14 +9,15 @@ use arrow::{
use super::{ffi, Result};
#[cfg(feature = "polars")]
use crate::arrow2;
use crate::error::result_from_duckdb_arrow;
use crate::{error::result_from_duckdb_arrow, Error};

// Private newtype for raw sqlite3_stmts that finalize themselves when dropped.
// TODO: destroy statement and result
#[derive(Debug)]
pub struct RawStatement {
ptr: ffi::duckdb_prepared_statement,
result: Option<ffi::duckdb_arrow>,
duckdb_result: Option<ffi::duckdb_result>,
schema: Option<SchemaRef>,
// Cached SQL (trimmed) that we use as the key when we're in the statement
// cache. This is None for statements which didn't come from the statement
Expand All @@ -38,6 +39,7 @@ impl RawStatement {
ptr: stmt,
result: None,
schema: None,
duckdb_result: None,
statement_cache_key: None,
}
}
Expand Down Expand Up @@ -110,6 +112,39 @@ impl RawStatement {
}
}

#[inline]
pub fn streaming_step(&self, schema: SchemaRef) -> Option<StructArray> {
if let Some(result) = self.duckdb_result {
unsafe {
let mut out = ffi::duckdb_stream_fetch_chunk(result);

if out.is_null() {
return None;
}

let mut arrays = FFI_ArrowArray::empty();
ffi::duckdb_result_arrow_array(
result,
out,
&mut std::ptr::addr_of_mut!(arrays) as *mut _ as *mut ffi::duckdb_arrow_array,
);

ffi::duckdb_destroy_data_chunk(&mut out);

if arrays.is_empty() {
return None;
}

let schema = FFI_ArrowSchema::try_from(schema.deref()).ok()?;
let array_data = from_ffi(arrays, &schema).expect("ok");
let struct_array = StructArray::from(array_data);
return Some(struct_array);
}
}

None
}

#[cfg(feature = "polars")]
#[inline]
pub fn step2(&self) -> Option<arrow2::array::StructArray> {
Expand Down Expand Up @@ -242,6 +277,22 @@ impl RawStatement {
}
}

pub fn execute_streaming(&mut self) -> Result<()> {
self.reset_result();
unsafe {
let mut out: ffi::duckdb_result = std::mem::zeroed();

let rc = ffi::duckdb_execute_prepared_streaming(self.ptr, &mut out);
if rc != ffi::DuckDBSuccess {
return Err(Error::DuckDBFailure(ffi::Error::new(rc), None));
}

self.duckdb_result = Some(out);

Ok(())
}
}

#[inline]
pub fn reset_result(&mut self) {
self.schema = None;
Expand All @@ -251,6 +302,12 @@ impl RawStatement {
}
self.result = None;
}
if let Some(mut result) = self.duckdb_result {
unsafe {
ffi::duckdb_destroy_result(&mut result);
}
self.duckdb_result = None;
}
}

#[inline]
Expand Down
32 changes: 31 additions & 1 deletion crates/duckdb/src/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use super::{ffi, AndThenRows, Connection, Error, MappedRows, Params, RawStatemen
#[cfg(feature = "polars")]
use crate::{arrow2, polars_dataframe::Polars};
use crate::{
arrow_batch::Arrow,
arrow_batch::{Arrow, ArrowStream},
error::result_from_duckdb_prepare,
types::{TimeUnit, ToSql, ToSqlOutput},
};
Expand Down Expand Up @@ -109,6 +109,30 @@ impl Statement<'_> {
Ok(Arrow::new(self))
}

/// Execute the prepared statement, returning a handle to the resulting
/// vector of arrow RecordBatch in streaming way
///
/// ## Example
///
/// ```rust,no_run
/// # use duckdb::{Result, Connection};
/// # use arrow::record_batch::RecordBatch;
/// # use arrow::datatypes::SchemaRef;
/// fn get_arrow_data(conn: &Connection, schema: SchemaRef) -> Result<Vec<RecordBatch>> {
/// Ok(conn.prepare("SELECT * FROM test")?.stream_arrow([], schema)?.collect())
/// }
/// ```
///
/// # Failure
///
/// Will return `Err` if binding parameters fails.
#[inline]
pub fn stream_arrow<P: Params>(&mut self, params: P, schema: SchemaRef) -> Result<ArrowStream<'_>> {
params.__bind_in(self)?;
self.stmt.execute_streaming()?;
Ok(ArrowStream::new(self, schema))
}

/// Execute the prepared statement, returning a handle to the resulting
/// vector of polars DataFrame.
///
Expand Down Expand Up @@ -337,6 +361,12 @@ impl Statement<'_> {
self.stmt.step()
}

/// Get next batch records in arrow-rs in a streaming way
#[inline]
pub fn stream_step(&self, schema: SchemaRef) -> Option<StructArray> {
self.stmt.streaming_step(schema)
}

#[cfg(feature = "polars")]
/// Get next batch records in arrow2
#[inline]
Expand Down

0 comments on commit e12fdb6

Please sign in to comment.