diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index b1eb2a19e31d..fd0765236ea2 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -25,7 +25,9 @@ use crate::arrow::util::pretty; use crate::datasource::file_format::csv::CsvFormatFactory; use crate::datasource::file_format::format_as_file_type; use crate::datasource::file_format::json::JsonFormatFactory; -use crate::datasource::{provider_as_source, MemTable, TableProvider}; +use crate::datasource::{ + provider_as_source, DefaultTableSource, MemTable, TableProvider, +}; use crate::error::Result; use crate::execution::context::{SessionState, TaskContext}; use crate::execution::FunctionRegistry; @@ -62,6 +64,7 @@ use datafusion_functions_aggregate::expr_fn::{ use async_trait::async_trait; use datafusion_catalog::Session; +use datafusion_sql::TableReference; /// Contains options that control how data is /// written out from a DataFrame @@ -1526,8 +1529,6 @@ impl DataFrame { table_name: &str, write_options: DataFrameWriteOptions, ) -> Result, DataFusionError> { - let arrow_schema = Schema::from(self.schema()); - let plan = if write_options.sort_by.is_empty() { self.plan } else { @@ -1536,10 +1537,19 @@ impl DataFrame { .build()? }; + let table_ref: TableReference = table_name.into(); + let table_schema = self.session_state.schema_for_ref(table_ref.clone())?; + let target = match table_schema.table(table_ref.table()).await? { + Some(ref provider) => Ok(Arc::clone(provider)), + _ => plan_err!("No table named '{table_name}'"), + }?; + + let target = Arc::new(DefaultTableSource::new(target)); + let plan = LogicalPlanBuilder::insert_into( plan, table_name.to_owned(), - &arrow_schema, + target, write_options.insert_op, )? .build()?; diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index 642ec93f3671..2f517e397ebe 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -1195,7 +1195,7 @@ mod tests { use crate::datasource::file_format::json::JsonFormat; #[cfg(feature = "parquet")] use crate::datasource::file_format::parquet::ParquetFormat; - use crate::datasource::{provider_as_source, MemTable}; + use crate::datasource::{provider_as_source, DefaultTableSource, MemTable}; use crate::execution::options::ArrowReadOptions; use crate::prelude::*; use crate::{ @@ -2065,6 +2065,8 @@ mod tests { session_ctx.register_table("source", source_table.clone())?; // Convert the source table into a provider so that it can be used in a query let source = provider_as_source(source_table); + let target = session_ctx.table_provider("t").await?; + let target = Arc::new(DefaultTableSource::new(target)); // Create a table scan logical plan to read from the source table let scan_plan = LogicalPlanBuilder::scan("source", source, None)? .filter(filter_predicate)? @@ -2073,7 +2075,7 @@ mod tests { // Therefore, we will have 8 partitions in the final plan. // Create an insert plan to insert the source data into the initial table let insert_into_table = - LogicalPlanBuilder::insert_into(scan_plan, "t", &schema, InsertOp::Append)? + LogicalPlanBuilder::insert_into(scan_plan, "t", target, InsertOp::Append)? .build()?; // Create a physical plan from the insert plan let plan = session_ctx diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs index a996990105b3..f1bda1a66f28 100644 --- a/datafusion/core/src/datasource/memory.rs +++ b/datafusion/core/src/datasource/memory.rs @@ -390,7 +390,7 @@ impl DataSink for MemSink { mod tests { use super::*; - use crate::datasource::provider_as_source; + use crate::datasource::{provider_as_source, DefaultTableSource}; use crate::physical_plan::collect; use crate::prelude::SessionContext; @@ -640,6 +640,7 @@ mod tests { // Create and register the initial table with the provided schema and data let initial_table = Arc::new(MemTable::try_new(schema.clone(), initial_data)?); session_ctx.register_table("t", initial_table.clone())?; + let target = Arc::new(DefaultTableSource::new(initial_table.clone())); // Create and register the source table with the provided schema and inserted data let source_table = Arc::new(MemTable::try_new(schema.clone(), inserted_data)?); session_ctx.register_table("source", source_table.clone())?; @@ -649,7 +650,7 @@ mod tests { let scan_plan = LogicalPlanBuilder::scan("source", source, None)?.build()?; // Create an insert plan to insert the source data into the initial table let insert_into_table = - LogicalPlanBuilder::insert_into(scan_plan, "t", &schema, InsertOp::Append)? + LogicalPlanBuilder::insert_into(scan_plan, "t", target, InsertOp::Append)? .build()?; // Create a physical plan from the insert plan let plan = session_ctx diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 47dee391c751..2303574e88af 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -24,7 +24,7 @@ use std::sync::Arc; use crate::datasource::file_format::file_type_to_format; use crate::datasource::listing::ListingTableUrl; use crate::datasource::physical_plan::FileSinkConfig; -use crate::datasource::source_as_provider; +use crate::datasource::{source_as_provider, DefaultTableSource}; use crate::error::{DataFusionError, Result}; use crate::execution::context::{ExecutionProps, SessionState}; use crate::logical_expr::utils::generate_sort_key; @@ -541,19 +541,22 @@ impl DefaultPhysicalPlanner { .await? } LogicalPlan::Dml(DmlStatement { - table_name, + target, op: WriteOp::Insert(insert_op), .. }) => { - let name = table_name.table(); - let schema = session_state.schema_for_ref(table_name.clone())?; - if let Some(provider) = schema.table(name).await? { + if let Some(provider) = + target.as_any().downcast_ref::() + { let input_exec = children.one()?; provider + .table_provider .insert_into(session_state, input_exec, *insert_op) .await? } else { - return exec_err!("Table '{table_name}' does not exist"); + return exec_err!( + "Table source can't be downcasted to DefaultTableSource" + ); } } LogicalPlan::Window(Window { window_expr, .. }) => { diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 4fdfb84aea42..ab89f752343d 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -384,14 +384,12 @@ impl LogicalPlanBuilder { pub fn insert_into( input: LogicalPlan, table_name: impl Into, - table_schema: &Schema, + target: Arc, insert_op: InsertOp, ) -> Result { - let table_schema = table_schema.clone().to_dfschema_ref()?; - Ok(Self::new(LogicalPlan::Dml(DmlStatement::new( table_name.into(), - table_schema, + target, WriteOp::Insert(insert_op), Arc::new(input), )))) diff --git a/datafusion/expr/src/logical_plan/dml.rs b/datafusion/expr/src/logical_plan/dml.rs index 669bc8e8a7d3..d4d50ac4eae4 100644 --- a/datafusion/expr/src/logical_plan/dml.rs +++ b/datafusion/expr/src/logical_plan/dml.rs @@ -25,7 +25,7 @@ use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::file_options::file_type::FileType; use datafusion_common::{DFSchemaRef, TableReference}; -use crate::LogicalPlan; +use crate::{LogicalPlan, TableSource}; /// Operator that copies the contents of a database to file(s) #[derive(Clone)] @@ -91,12 +91,12 @@ impl Hash for CopyTo { /// The operator that modifies the content of a database (adapted from /// substrait WriteRel) -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Clone)] pub struct DmlStatement { /// The table name pub table_name: TableReference, - /// The schema of the table (must align with Rel input) - pub table_schema: DFSchemaRef, + /// this is target table to insert into + pub target: Arc, /// The type of operation to perform pub op: WriteOp, /// The relation that determines the tuples to add/remove/modify the schema must match with table_schema @@ -104,18 +104,51 @@ pub struct DmlStatement { /// The schema of the output relation pub output_schema: DFSchemaRef, } +impl Eq for DmlStatement {} +impl Hash for DmlStatement { + fn hash(&self, state: &mut H) { + self.table_name.hash(state); + self.target.schema().hash(state); + self.op.hash(state); + self.input.hash(state); + self.output_schema.hash(state); + } +} + +impl PartialEq for DmlStatement { + fn eq(&self, other: &Self) -> bool { + self.table_name == other.table_name + && self.target.schema() == other.target.schema() + && self.op == other.op + && self.input == other.input + && self.output_schema == other.output_schema + } +} + +impl Debug for DmlStatement { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("DmlStatement") + .field("table_name", &self.table_name) + .field("target", &"...") + .field("target_schema", &self.target.schema()) + .field("op", &self.op) + .field("input", &self.input) + .field("output_schema", &self.output_schema) + .finish() + } +} impl DmlStatement { /// Creates a new DML statement with the output schema set to a single `count` column. pub fn new( table_name: TableReference, - table_schema: DFSchemaRef, + target: Arc, op: WriteOp, input: Arc, ) -> Self { Self { table_name, - table_schema, + target, op, input, diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index daf1a1375eac..100a0b7d43dd 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -784,7 +784,7 @@ impl LogicalPlan { } LogicalPlan::Dml(DmlStatement { table_name, - table_schema, + target, op, .. }) => { @@ -792,7 +792,7 @@ impl LogicalPlan { let input = self.only_input(inputs)?; Ok(LogicalPlan::Dml(DmlStatement::new( table_name.clone(), - Arc::clone(table_schema), + Arc::clone(target), op.clone(), Arc::new(input), ))) diff --git a/datafusion/expr/src/logical_plan/tree_node.rs b/datafusion/expr/src/logical_plan/tree_node.rs index 9a6103afd4b4..dfc18c74c70a 100644 --- a/datafusion/expr/src/logical_plan/tree_node.rs +++ b/datafusion/expr/src/logical_plan/tree_node.rs @@ -228,14 +228,14 @@ impl TreeNode for LogicalPlan { }), LogicalPlan::Dml(DmlStatement { table_name, - table_schema, + target, op, input, output_schema, }) => input.map_elements(f)?.update_data(|input| { LogicalPlan::Dml(DmlStatement { table_name, - table_schema, + target, op, input, output_schema, diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 3bc884257dab..1cdfe6d216e3 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -278,7 +278,7 @@ message DmlNode{ Type dml_type = 1; LogicalPlanNode input = 2; TableReference table_name = 3; - datafusion_common.DfSchema schema = 4; + LogicalPlanNode target = 5; } message UnnestNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index add72e4f777e..6e09e9a797ea 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -4764,7 +4764,7 @@ impl serde::Serialize for DmlNode { if self.table_name.is_some() { len += 1; } - if self.schema.is_some() { + if self.target.is_some() { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.DmlNode", len)?; @@ -4779,8 +4779,8 @@ impl serde::Serialize for DmlNode { if let Some(v) = self.table_name.as_ref() { struct_ser.serialize_field("tableName", v)?; } - if let Some(v) = self.schema.as_ref() { - struct_ser.serialize_field("schema", v)?; + if let Some(v) = self.target.as_ref() { + struct_ser.serialize_field("target", v)?; } struct_ser.end() } @@ -4797,7 +4797,7 @@ impl<'de> serde::Deserialize<'de> for DmlNode { "input", "table_name", "tableName", - "schema", + "target", ]; #[allow(clippy::enum_variant_names)] @@ -4805,7 +4805,7 @@ impl<'de> serde::Deserialize<'de> for DmlNode { DmlType, Input, TableName, - Schema, + Target, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -4830,7 +4830,7 @@ impl<'de> serde::Deserialize<'de> for DmlNode { "dmlType" | "dml_type" => Ok(GeneratedField::DmlType), "input" => Ok(GeneratedField::Input), "tableName" | "table_name" => Ok(GeneratedField::TableName), - "schema" => Ok(GeneratedField::Schema), + "target" => Ok(GeneratedField::Target), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -4853,7 +4853,7 @@ impl<'de> serde::Deserialize<'de> for DmlNode { let mut dml_type__ = None; let mut input__ = None; let mut table_name__ = None; - let mut schema__ = None; + let mut target__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::DmlType => { @@ -4874,11 +4874,11 @@ impl<'de> serde::Deserialize<'de> for DmlNode { } table_name__ = map_.next_value()?; } - GeneratedField::Schema => { - if schema__.is_some() { - return Err(serde::de::Error::duplicate_field("schema")); + GeneratedField::Target => { + if target__.is_some() { + return Err(serde::de::Error::duplicate_field("target")); } - schema__ = map_.next_value()?; + target__ = map_.next_value()?; } } } @@ -4886,7 +4886,7 @@ impl<'de> serde::Deserialize<'de> for DmlNode { dml_type: dml_type__.unwrap_or_default(), input: input__, table_name: table_name__, - schema: schema__, + target: target__, }) } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index df32c1a70d61..f5ec45da48f2 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -409,8 +409,8 @@ pub struct DmlNode { pub input: ::core::option::Option<::prost::alloc::boxed::Box>, #[prost(message, optional, tag = "3")] pub table_name: ::core::option::Option, - #[prost(message, optional, tag = "4")] - pub schema: ::core::option::Option, + #[prost(message, optional, boxed, tag = "5")] + pub target: ::core::option::Option<::prost::alloc::boxed::Box>, } /// Nested message and enum types in `DmlNode`. pub mod dml_node { diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 53b683bac66a..641dfe7b5fb8 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -55,8 +55,8 @@ use datafusion::{ }; use datafusion_common::file_options::file_type::FileType; use datafusion_common::{ - context, internal_datafusion_err, internal_err, not_impl_err, DataFusionError, - Result, TableReference, + context, internal_datafusion_err, internal_err, not_impl_err, plan_err, + DataFusionError, Result, TableReference, ToDFSchema, }; use datafusion_expr::{ dml, @@ -71,7 +71,7 @@ use datafusion_expr::{ }; use datafusion_expr::{ AggregateUDF, ColumnUnnestList, DmlStatement, FetchType, RecursiveQuery, SkipType, - Unnest, + TableSource, Unnest, }; use self::to_proto::{serialize_expr, serialize_exprs}; @@ -236,6 +236,45 @@ fn from_table_reference( Ok(table_ref.clone().try_into()?) } +/// Converts [LogicalPlan::TableScan] to [TableSource] +/// method to be used to deserialize nodes +/// serialized by [from_table_source] +fn to_table_source( + node: &Option>, + ctx: &SessionContext, + extension_codec: &dyn LogicalExtensionCodec, +) -> Result> { + if let Some(node) = node { + match node.try_into_logical_plan(ctx, extension_codec)? { + LogicalPlan::TableScan(TableScan { source, .. }) => Ok(source), + _ => plan_err!("expected TableScan node"), + } + } else { + plan_err!("LogicalPlanNode should be provided") + } +} + +/// converts [TableSource] to [LogicalPlan::TableScan] +/// using [LogicalPlan::TableScan] was the best approach to +/// serialize [TableSource] to [LogicalPlan::TableScan] +fn from_table_source( + table_name: TableReference, + target: Arc, + extension_codec: &dyn LogicalExtensionCodec, +) -> Result { + let projected_schema = target.schema().to_dfschema_ref()?; + let r = LogicalPlan::TableScan(TableScan { + table_name, + source: target, + projection: None, + projected_schema, + filters: vec![], + fetch: None, + }); + + LogicalPlanNode::try_from_logical_plan(&r, extension_codec) +} + impl AsLogicalPlan for LogicalPlanNode { fn try_decode(buf: &[u8]) -> Result where @@ -454,7 +493,7 @@ impl AsLogicalPlan for LogicalPlanNode { )? .build() } - CustomScan(scan) => { + LogicalPlanType::CustomScan(scan) => { let schema: Schema = convert_required!(scan.schema)?; let schema = Arc::new(schema); let mut projection = None; @@ -942,7 +981,7 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlanType::Dml(dml_node) => Ok(LogicalPlan::Dml( datafusion::logical_expr::DmlStatement::new( from_table_reference(dml_node.table_name.as_ref(), "DML ")?, - Arc::new(convert_required!(dml_node.schema)?), + to_table_source(&dml_node.target, ctx, extension_codec)?, dml_node.dml_type().into(), Arc::new(into_logical_plan!(dml_node.input, ctx, extension_codec)?), ), @@ -1658,7 +1697,7 @@ impl AsLogicalPlan for LogicalPlanNode { )), LogicalPlan::Dml(DmlStatement { table_name, - table_schema, + target, op, input, .. @@ -1669,7 +1708,11 @@ impl AsLogicalPlan for LogicalPlanNode { Ok(LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Dml(Box::new(DmlNode { input: Some(Box::new(input)), - schema: Some(table_schema.try_into()?), + target: Some(Box::new(from_table_source( + table_name.clone(), + Arc::clone(target), + extension_codec, + )?)), table_name: Some(table_name.clone().into()), dml_type: dml_type.into(), }))), diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 209e9cc787d1..74055d979145 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -1709,10 +1709,10 @@ impl SqlToRel<'_, S> { // Do a table lookup to verify the table exists let table_ref = self.object_name_to_table_reference(table_name.clone())?; let table_source = self.context_provider.get_table_source(table_ref.clone())?; - let schema = (*table_source.schema()).clone(); - let schema = DFSchema::try_from(schema)?; + let schema = table_source.schema().to_dfschema_ref()?; let scan = - LogicalPlanBuilder::scan(table_ref.clone(), table_source, None)?.build()?; + LogicalPlanBuilder::scan(table_ref.clone(), Arc::clone(&table_source), None)? + .build()?; let mut planner_context = PlannerContext::new(); let source = match predicate_expr { @@ -1720,7 +1720,7 @@ impl SqlToRel<'_, S> { Some(predicate_expr) => { let filter_expr = self.sql_to_expr(predicate_expr, &schema, &mut planner_context)?; - let schema = Arc::new(schema.clone()); + let schema = Arc::new(schema); let mut using_columns = HashSet::new(); expr_to_columns(&filter_expr, &mut using_columns)?; let filter_expr = normalize_col_with_schemas_and_ambiguity_check( @@ -1734,7 +1734,7 @@ impl SqlToRel<'_, S> { let plan = LogicalPlan::Dml(DmlStatement::new( table_ref, - schema.into(), + table_source, WriteOp::Delete, Arc::new(source), )); @@ -1847,7 +1847,7 @@ impl SqlToRel<'_, S> { let plan = LogicalPlan::Dml(DmlStatement::new( table_name, - table_schema, + table_source, WriteOp::Update, Arc::new(source), )); @@ -1976,7 +1976,7 @@ impl SqlToRel<'_, S> { let plan = LogicalPlan::Dml(DmlStatement::new( table_name, - Arc::new(table_schema), + Arc::clone(&table_source), WriteOp::Insert(insert_op), Arc::new(source), ));