Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TableProvider::insert_into into FFI Bindings #14391

Merged
merged 3 commits into from
Feb 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions datafusion/ffi/src/execution_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use datafusion::{
execution::{SendableRecordBatchStream, TaskContext},
physical_plan::{DisplayAs, ExecutionPlan, PlanProperties},
};
use tokio::runtime::Runtime;
use tokio::runtime::Handle;

use crate::{
plan_properties::FFI_PlanProperties, record_batch_stream::FFI_RecordBatchStream,
Expand Down Expand Up @@ -72,7 +72,7 @@ unsafe impl Sync for FFI_ExecutionPlan {}
pub struct ExecutionPlanPrivateData {
pub plan: Arc<dyn ExecutionPlan>,
pub context: Arc<TaskContext>,
pub runtime: Option<Arc<Runtime>>,
pub runtime: Option<Handle>,
}

unsafe extern "C" fn properties_fn_wrapper(
Expand Down Expand Up @@ -110,7 +110,7 @@ unsafe extern "C" fn execute_fn_wrapper(
let private_data = plan.private_data as *const ExecutionPlanPrivateData;
let plan = &(*private_data).plan;
let ctx = &(*private_data).context;
let runtime = (*private_data).runtime.as_ref().map(Arc::clone);
let runtime = (*private_data).runtime.clone();

match plan.execute(partition, Arc::clone(ctx)) {
Ok(rbs) => RResult::ROk(FFI_RecordBatchStream::new(rbs, runtime)),
Expand Down Expand Up @@ -153,7 +153,7 @@ impl FFI_ExecutionPlan {
pub fn new(
plan: Arc<dyn ExecutionPlan>,
context: Arc<TaskContext>,
runtime: Option<Arc<Runtime>>,
runtime: Option<Handle>,
) -> Self {
let private_data = Box::new(ExecutionPlanPrivateData {
plan,
Expand Down
49 changes: 49 additions & 0 deletions datafusion/ffi/src/insert_op.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use abi_stable::StableAbi;
use datafusion::logical_expr::logical_plan::dml::InsertOp;

/// FFI safe version of [`InsertOp`].
///
/// The enum is `#[repr(C)]` and derives [`StableAbi`] so that it can be passed
/// by value across the FFI boundary. Each variant maps one-to-one onto the
/// variant of the same name in [`InsertOp`]; see the `From` conversions in
/// this module.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, StableAbi)]
// The `FFI_` prefix follows the naming used by the other FFI structs in this
// crate (e.g. `FFI_ExecutionPlan`), which is not camel case.
#[allow(non_camel_case_types)]
pub enum FFI_InsertOp {
    /// Counterpart of [`InsertOp::Append`].
    Append,
    /// Counterpart of [`InsertOp::Overwrite`].
    Overwrite,
    /// Counterpart of [`InsertOp::Replace`].
    Replace,
}

/// Maps an operation received over the FFI boundary back onto the native
/// DataFusion [`InsertOp`] variant of the same name.
impl From<FFI_InsertOp> for InsertOp {
    fn from(op: FFI_InsertOp) -> Self {
        match op {
            FFI_InsertOp::Append => Self::Append,
            FFI_InsertOp::Overwrite => Self::Overwrite,
            FFI_InsertOp::Replace => Self::Replace,
        }
    }
}

/// Maps a native DataFusion [`InsertOp`] onto the FFI-safe variant of the
/// same name so it can be handed across the FFI boundary.
impl From<InsertOp> for FFI_InsertOp {
    fn from(op: InsertOp) -> Self {
        match op {
            InsertOp::Append => Self::Append,
            InsertOp::Overwrite => Self::Overwrite,
            InsertOp::Replace => Self::Replace,
        }
    }
}
1 change: 1 addition & 0 deletions datafusion/ffi/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

pub mod arrow_wrappers;
pub mod execution_plan;
pub mod insert_op;
pub mod plan_properties;
pub mod record_batch_stream;
pub mod session_config;
Expand Down
8 changes: 4 additions & 4 deletions datafusion/ffi/src/record_batch_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.

use std::{ffi::c_void, sync::Arc, task::Poll};
use std::{ffi::c_void, task::Poll};

use abi_stable::{
std_types::{ROption, RResult, RString},
Expand All @@ -33,7 +33,7 @@ use datafusion::{
execution::{RecordBatchStream, SendableRecordBatchStream},
};
use futures::{Stream, TryStreamExt};
use tokio::runtime::Runtime;
use tokio::runtime::Handle;

use crate::arrow_wrappers::{WrappedArray, WrappedSchema};

Expand Down Expand Up @@ -61,7 +61,7 @@ pub struct FFI_RecordBatchStream {

pub struct RecordBatchStreamPrivateData {
pub rbs: SendableRecordBatchStream,
pub runtime: Option<Arc<Runtime>>,
pub runtime: Option<Handle>,
}

impl From<SendableRecordBatchStream> for FFI_RecordBatchStream {
Expand All @@ -71,7 +71,7 @@ impl From<SendableRecordBatchStream> for FFI_RecordBatchStream {
}

impl FFI_RecordBatchStream {
pub fn new(stream: SendableRecordBatchStream, runtime: Option<Arc<Runtime>>) -> Self {
pub fn new(stream: SendableRecordBatchStream, runtime: Option<Handle>) -> Self {
let private_data = Box::into_raw(Box::new(RecordBatchStreamPrivateData {
rbs: stream,
runtime,
Expand Down
151 changes: 145 additions & 6 deletions datafusion/ffi/src/table_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ use datafusion::{
catalog::{Session, TableProvider},
datasource::TableType,
error::DataFusionError,
execution::session_state::SessionStateBuilder,
logical_expr::TableProviderFilterPushDown,
execution::{session_state::SessionStateBuilder, TaskContext},
logical_expr::{logical_plan::dml::InsertOp, TableProviderFilterPushDown},
physical_plan::ExecutionPlan,
prelude::{Expr, SessionContext},
};
Expand All @@ -40,7 +40,7 @@ use datafusion_proto::{
protobuf::LogicalExprList,
};
use prost::Message;
use tokio::runtime::Runtime;
use tokio::runtime::Handle;

use crate::{
arrow_wrappers::WrappedSchema,
Expand All @@ -50,6 +50,7 @@ use crate::{

use super::{
execution_plan::{FFI_ExecutionPlan, ForeignExecutionPlan},
insert_op::FFI_InsertOp,
session_config::FFI_SessionConfig,
};
use datafusion::error::Result;
Expand Down Expand Up @@ -133,6 +134,14 @@ pub struct FFI_TableProvider {
-> RResult<RVec<FFI_TableProviderFilterPushDown>, RString>,
>,

pub insert_into:
unsafe extern "C" fn(
provider: &Self,
session_config: &FFI_SessionConfig,
input: &FFI_ExecutionPlan,
insert_op: FFI_InsertOp,
) -> FfiFuture<RResult<FFI_ExecutionPlan, RString>>,

/// Used to create a clone on the provider of the execution plan. This should
/// only need to be called by the receiver of the plan.
pub clone: unsafe extern "C" fn(plan: &Self) -> Self,
Expand All @@ -153,7 +162,7 @@ unsafe impl Sync for FFI_TableProvider {}

struct ProviderPrivateData {
provider: Arc<dyn TableProvider + Send>,
runtime: Option<Arc<Runtime>>,
runtime: Option<Handle>,
}

unsafe extern "C" fn schema_fn_wrapper(provider: &FFI_TableProvider) -> WrappedSchema {
Expand Down Expand Up @@ -276,6 +285,53 @@ unsafe extern "C" fn scan_fn_wrapper(
.into_ffi()
}

unsafe extern "C" fn insert_into_fn_wrapper(
provider: &FFI_TableProvider,
session_config: &FFI_SessionConfig,
input: &FFI_ExecutionPlan,
insert_op: FFI_InsertOp,
) -> FfiFuture<RResult<FFI_ExecutionPlan, RString>> {
let private_data = provider.private_data as *mut ProviderPrivateData;
let internal_provider = &(*private_data).provider;
let session_config = session_config.clone();
let input = input.clone();
let runtime = &(*private_data).runtime;

async move {
let config = match ForeignSessionConfig::try_from(&session_config) {
Ok(c) => c,
Err(e) => return RResult::RErr(e.to_string().into()),
};
let session = SessionStateBuilder::new()
.with_default_features()
.with_config(config.0)
.build();
let ctx = SessionContext::new_with_state(session);

let input = match ForeignExecutionPlan::try_from(&input) {
Ok(input) => Arc::new(input),
Err(e) => return RResult::RErr(e.to_string().into()),
};

let insert_op = InsertOp::from(insert_op);

let plan = match internal_provider
.insert_into(&ctx.state(), input, insert_op)
.await
{
Ok(p) => p,
Err(e) => return RResult::RErr(e.to_string().into()),
};

RResult::ROk(FFI_ExecutionPlan::new(
plan,
ctx.task_ctx(),
runtime.clone(),
))
}
.into_ffi()
}

unsafe extern "C" fn release_fn_wrapper(provider: &mut FFI_TableProvider) {
let private_data = Box::from_raw(provider.private_data as *mut ProviderPrivateData);
drop(private_data);
Expand All @@ -295,6 +351,7 @@ unsafe extern "C" fn clone_fn_wrapper(provider: &FFI_TableProvider) -> FFI_Table
scan: scan_fn_wrapper,
table_type: table_type_fn_wrapper,
supports_filters_pushdown: provider.supports_filters_pushdown,
insert_into: provider.insert_into,
clone: clone_fn_wrapper,
release: release_fn_wrapper,
version: super::version,
Expand All @@ -313,7 +370,7 @@ impl FFI_TableProvider {
pub fn new(
provider: Arc<dyn TableProvider + Send>,
can_support_pushdown_filters: bool,
runtime: Option<Arc<Runtime>>,
runtime: Option<Handle>,
) -> Self {
let private_data = Box::new(ProviderPrivateData { provider, runtime });

Expand All @@ -325,6 +382,7 @@ impl FFI_TableProvider {
true => Some(supports_filters_pushdown_fn_wrapper),
false => None,
},
insert_into: insert_into_fn_wrapper,
clone: clone_fn_wrapper,
release: release_fn_wrapper,
version: super::version,
Expand Down Expand Up @@ -443,6 +501,37 @@ impl TableProvider for ForeignTableProvider {
}
}
}

async fn insert_into(
&self,
session: &dyn Session,
input: Arc<dyn ExecutionPlan>,
insert_op: InsertOp,
) -> Result<Arc<dyn ExecutionPlan>> {
let session_config: FFI_SessionConfig = session.config().into();

let rc = Handle::try_current().ok();
let input =
FFI_ExecutionPlan::new(input, Arc::new(TaskContext::from(session)), rc);
let insert_op: FFI_InsertOp = insert_op.into();

let plan = unsafe {
let maybe_plan =
(self.0.insert_into)(&self.0, &session_config, &input, insert_op).await;

match maybe_plan {
RResult::ROk(p) => ForeignExecutionPlan::try_from(&p)?,
RResult::RErr(e) => {
return Err(DataFusionError::Internal(format!(
"Unable to perform insert_into via FFI: {}",
e
)))
}
}
};

Ok(Arc::new(plan))
}
}

#[cfg(test)]
Expand All @@ -453,7 +542,7 @@ mod tests {
use super::*;

#[tokio::test]
async fn test_round_trip_ffi_table_provider() -> Result<()> {
async fn test_round_trip_ffi_table_provider_scan() -> Result<()> {
use arrow::datatypes::Field;
use datafusion::arrow::{
array::Float32Array, datatypes::DataType, record_batch::RecordBatch,
Expand Down Expand Up @@ -493,4 +582,54 @@ mod tests {

Ok(())
}

/// Registers a `MemTable` behind the FFI wrappers, inserts one row through
/// SQL and checks that the row can be read back through the same
/// foreign provider.
#[tokio::test]
async fn test_round_trip_ffi_table_provider_insert_into() -> Result<()> {
    use arrow::datatypes::Field;
    use datafusion::arrow::{
        array::Float32Array, datatypes::DataType, record_batch::RecordBatch,
    };
    use datafusion::datasource::MemTable;

    let schema =
        Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)]));

    // define data in two partitions
    let batch1 = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0]))],
    )?;
    let batch2 = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![Arc::new(Float32Array::from(vec![64.0]))],
    )?;

    let ctx = SessionContext::new();

    let provider =
        Arc::new(MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?);

    let ffi_provider = FFI_TableProvider::new(provider, true, None);

    let foreign_table_provider: ForeignTableProvider = (&ffi_provider).into();

    ctx.register_table("t", Arc::new(foreign_table_provider))?;

    let result = ctx
        .sql("INSERT INTO t VALUES (128.0);")
        .await?
        .collect()
        .await?;

    // The insert yields a single batch with a single row.
    assert_eq!(result.len(), 1);
    assert_eq!(result[0].num_rows(), 1);

    // Read the table back through the foreign provider. Of the original
    // values 2.0, 4.0, 8.0 and 64.0 plus the inserted 128.0, only 2.0 fails
    // the `a > 3.0` filter, so four rows must remain.
    let filtered = ctx
        .table("t")
        .await?
        .select(vec![col("a")])?
        .filter(col("a").gt(lit(3.0)))?
        .collect()
        .await?;

    let matching_rows: usize = filtered.iter().map(|batch| batch.num_rows()).sum();
    assert_eq!(matching_rows, 4);

    Ok(())
}
}
8 changes: 4 additions & 4 deletions datafusion/ffi/src/tests/async_provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ use datafusion::{
};
use futures::Stream;
use tokio::{
runtime::Runtime,
runtime::Handle,
sync::{broadcast, mpsc},
};

Expand All @@ -59,7 +59,7 @@ fn async_table_provider_thread(
mut shutdown: mpsc::Receiver<bool>,
mut batch_request: mpsc::Receiver<bool>,
batch_sender: broadcast::Sender<Option<RecordBatch>>,
tokio_rt: mpsc::Sender<Arc<Runtime>>,
tokio_rt: mpsc::Sender<Handle>,
) {
let runtime = Arc::new(
tokio::runtime::Builder::new_current_thread()
Expand All @@ -68,7 +68,7 @@ fn async_table_provider_thread(
);
let _runtime_guard = runtime.enter();
tokio_rt
.blocking_send(Arc::clone(&runtime))
.blocking_send(runtime.handle().clone())
.expect("Unable to send tokio runtime back to main thread");

runtime.block_on(async move {
Expand All @@ -91,7 +91,7 @@ fn async_table_provider_thread(
let _ = shutdown.blocking_recv();
}

pub fn start_async_provider() -> (AsyncTableProvider, Arc<Runtime>) {
pub fn start_async_provider() -> (AsyncTableProvider, Handle) {
let (batch_request_tx, batch_request_rx) = mpsc::channel(10);
let (record_batch_tx, record_batch_rx) = broadcast::channel(10);
let (tokio_rt_tx, mut tokio_rt_rx) = mpsc::channel(10);
Expand Down