From a6993eb693a2c22b2c90f4d63bfdb742e8835fbf Mon Sep 17 00:00:00 2001 From: Georgi Krastev Date: Sat, 13 Apr 2024 23:47:56 +0200 Subject: [PATCH 1/4] Use PhysicalExtensionCodec consistently --- .../proto/src/physical_plan/from_proto.rs | 88 ++++++++++++--- datafusion/proto/src/physical_plan/mod.rs | 51 ++++----- .../tests/cases/roundtrip_physical_plan.rs | 100 +++++++++--------- 3 files changed, 151 insertions(+), 88 deletions(-) diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index aaca4dc48236..f2c5b4b080b2 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -76,9 +76,10 @@ impl From<&protobuf::PhysicalColumn> for Column { /// # Arguments /// /// * `proto` - Input proto with physical sort expression node -/// * `registry` - A registry knows how to build logical expressions out of user-defined function' names +/// * `registry` - A registry knows how to build logical expressions out of user-defined function names /// * `input_schema` - The Arrow schema for the input, used for determining expression data types /// when performing type coercion. +/// * `codec` - An extension codec used to decode custom UDFs. pub fn parse_physical_sort_expr( proto: &protobuf::PhysicalSortExprNode, registry: &dyn FunctionRegistry, @@ -102,9 +103,10 @@ pub fn parse_physical_sort_expr( /// # Arguments /// /// * `proto` - Input proto with vector of physical sort expression node -/// * `registry` - A registry knows how to build logical expressions out of user-defined function' names +/// * `registry` - A registry knows how to build logical expressions out of user-defined function names /// * `input_schema` - The Arrow schema for the input, used for determining expression data types /// when performing type coercion. +/// * `codec` - An extension codec used to decode custom UDFs. pub fn parse_physical_sort_exprs( proto: &[protobuf::PhysicalSortExprNode], registry: &dyn FunctionRegistry, @@ -123,9 +125,9 @@ pub fn parse_physical_sort_exprs( /// /// # Arguments /// -/// * `proto` - Input proto with physical window exprression node. +/// * `proto` - Input proto with physical window expression node. /// * `name` - Name of the window expression. -/// * `registry` - A registry knows how to build logical expressions out of user-defined function' names +/// * `registry` - A registry knows how to build logical expressions out of user-defined function names /// * `input_schema` - The Arrow schema for the input, used for determining expression data types /// when performing type coercion. pub fn parse_physical_window_expr( @@ -133,15 +135,29 @@ pub fn parse_physical_window_expr( registry: &dyn FunctionRegistry, input_schema: &Schema, ) -> Result> { - let codec = DefaultPhysicalExtensionCodec {}; + parse_physical_window_expr_ext( + proto, + registry, + input_schema, + &DefaultPhysicalExtensionCodec {}, + ) +} + +// TODO: Make this the public function on next major release. 
+pub(crate) fn parse_physical_window_expr_ext( + proto: &protobuf::PhysicalWindowExprNode, + registry: &dyn FunctionRegistry, + input_schema: &Schema, + codec: &dyn PhysicalExtensionCodec, +) -> Result> { let window_node_expr = - parse_physical_exprs(&proto.args, registry, input_schema, &codec)?; + parse_physical_exprs(&proto.args, registry, input_schema, codec)?; let partition_by = - parse_physical_exprs(&proto.partition_by, registry, input_schema, &codec)?; + parse_physical_exprs(&proto.partition_by, registry, input_schema, codec)?; let order_by = - parse_physical_sort_exprs(&proto.order_by, registry, input_schema, &codec)?; + parse_physical_sort_exprs(&proto.order_by, registry, input_schema, codec)?; let window_frame = proto .window_frame @@ -187,9 +203,10 @@ where /// # Arguments /// /// * `proto` - Input proto with physical expression node -/// * `registry` - A registry knows how to build logical expressions out of user-defined function' names +/// * `registry` - A registry knows how to build logical expressions out of user-defined function names /// * `input_schema` - The Arrow schema for the input, used for determining expression data types /// when performing type coercion. +/// * `codec` - An extension codec used to decode custom UDFs. pub fn parse_physical_expr( proto: &protobuf::PhysicalExprNode, registry: &dyn FunctionRegistry, @@ -213,6 +230,7 @@ pub fn parse_physical_expr( registry, "left", input_schema, + codec, )?, logical_plan::from_proto::from_proto_binary_op(&binary_expr.op)?, parse_required_physical_expr( @@ -220,6 +238,7 @@ pub fn parse_physical_expr( registry, "right", input_schema, + codec, )?, )), ExprType::AggregateExpr(_) => { @@ -241,6 +260,7 @@ pub fn parse_physical_expr( registry, "expr", input_schema, + codec, )?)) } ExprType::IsNotNullExpr(e) => { @@ -249,6 +269,7 @@ pub fn parse_physical_expr( registry, "expr", input_schema, + codec, )?)) } ExprType::NotExpr(e) => Arc::new(NotExpr::new(parse_required_physical_expr( @@ -256,6 +277,7 @@ pub fn parse_physical_expr( registry, "expr", input_schema, + codec, )?)), ExprType::Negative(e) => { Arc::new(NegativeExpr::new(parse_required_physical_expr( @@ -263,6 +285,7 @@ pub fn parse_physical_expr( registry, "expr", input_schema, + codec, )?)) } ExprType::InList(e) => in_list( @@ -271,6 +294,7 @@ pub fn parse_physical_expr( registry, "expr", input_schema, + codec, )?, parse_physical_exprs(&e.list, registry, input_schema, codec)?, &e.negated, @@ -290,12 +314,14 @@ pub fn parse_physical_expr( registry, "when_expr", input_schema, + codec, )?, parse_required_physical_expr( e.then_expr.as_ref(), registry, "then_expr", input_schema, + codec, )?, )) }) @@ -311,6 +337,7 @@ pub fn parse_physical_expr( registry, "expr", input_schema, + codec, )?, convert_required!(e.arrow_type)?, None, @@ -321,6 +348,7 @@ pub fn parse_physical_expr( registry, "expr", input_schema, + codec, )?, convert_required!(e.arrow_type)?, )), @@ -371,12 +399,14 @@ pub fn parse_physical_expr( registry, "expr", input_schema, + codec, )?, parse_required_physical_expr( like_expr.pattern.as_deref(), registry, "pattern", input_schema, + codec, )?, )), }; @@ -389,9 +419,9 @@ fn parse_required_physical_expr( registry: &dyn FunctionRegistry, field: &str, input_schema: &Schema, + codec: &dyn PhysicalExtensionCodec, ) -> Result> { - let codec = DefaultPhysicalExtensionCodec {}; - expr.map(|e| parse_physical_expr(e, registry, input_schema, &codec)) + expr.map(|e| parse_physical_expr(e, registry, input_schema, codec)) .transpose()? 
.ok_or_else(|| { DataFusionError::Internal(format!("Missing required field {field:?}")) @@ -433,15 +463,29 @@ pub fn parse_protobuf_hash_partitioning( partitioning: Option<&protobuf::PhysicalHashRepartition>, registry: &dyn FunctionRegistry, input_schema: &Schema, +) -> Result> { + parse_protobuf_hash_partitioning_ext( + partitioning, + registry, + input_schema, + &DefaultPhysicalExtensionCodec {}, + ) +} + +// TODO: Make this the public function on next major release. +fn parse_protobuf_hash_partitioning_ext( + partitioning: Option<&protobuf::PhysicalHashRepartition>, + registry: &dyn FunctionRegistry, + input_schema: &Schema, + codec: &dyn PhysicalExtensionCodec, ) -> Result> { match partitioning { Some(hash_part) => { - let codec = DefaultPhysicalExtensionCodec {}; let expr = parse_physical_exprs( &hash_part.hash_expr, registry, input_schema, - &codec, + codec, )?; Ok(Some(Partitioning::Hash( @@ -456,6 +500,19 @@ pub fn parse_protobuf_hash_partitioning( pub fn parse_protobuf_file_scan_config( proto: &protobuf::FileScanExecConf, registry: &dyn FunctionRegistry, +) -> Result { + parse_protobuf_file_scan_config_ext( + proto, + registry, + &DefaultPhysicalExtensionCodec {}, + ) +} + +// TODO: Make this the public function on next major release. +pub(crate) fn parse_protobuf_file_scan_config_ext( + proto: &protobuf::FileScanExecConf, + registry: &dyn FunctionRegistry, + codec: &dyn PhysicalExtensionCodec, ) -> Result { let schema: Arc = Arc::new(convert_required!(proto.schema)?); let projection = proto @@ -489,7 +546,7 @@ pub fn parse_protobuf_file_scan_config( .collect::>>()?; // Remove partition columns from the schema after recreating table_partition_cols - // because the partition columns are not in the file. They are present to allow the + // because the partition columns are not in the file. They are present to allow // the partition column types to be reconstructed after serde. 
let file_schema = Arc::new(Schema::new( schema @@ -502,12 +559,11 @@ pub fn parse_protobuf_file_scan_config( let mut output_ordering = vec![]; for node_collection in &proto.output_ordering { - let codec = DefaultPhysicalExtensionCodec {}; let sort_expr = parse_physical_sort_exprs( &node_collection.physical_sort_expr_nodes, registry, &schema, - &codec, + codec, )?; output_ordering.push(sort_expr); } diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 4d95c847bf99..0890cf0a7e60 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -19,22 +19,8 @@ use std::convert::TryInto; use std::fmt::Debug; use std::sync::Arc; -use self::from_proto::parse_physical_window_expr; -use self::to_proto::serialize_physical_expr; - -use crate::common::{byte_to_string, proto_error, str_to_byte}; -use crate::convert_required; -use crate::physical_plan::from_proto::{ - parse_physical_expr, parse_physical_sort_expr, parse_physical_sort_exprs, - parse_protobuf_file_scan_config, -}; -use crate::protobuf::physical_aggregate_expr_node::AggregateFunction; -use crate::protobuf::physical_expr_node::ExprType; -use crate::protobuf::physical_plan_node::PhysicalPlanType; -use crate::protobuf::repartition_exec_node::PartitionMethod; -use crate::protobuf::{ - self, window_agg_exec_node, PhysicalPlanNode, PhysicalSortExprNodeCollection, -}; +use prost::bytes::BufMut; +use prost::Message; use datafusion::arrow::compute::SortOptions; use datafusion::arrow::datatypes::SchemaRef; @@ -79,8 +65,22 @@ use datafusion::physical_plan::{ use datafusion_common::{internal_err, not_impl_err, DataFusionError, Result}; use datafusion_expr::ScalarUDF; -use prost::bytes::BufMut; -use prost::Message; +use crate::common::{byte_to_string, proto_error, str_to_byte}; +use crate::convert_required; +use crate::physical_plan::from_proto::{ + parse_physical_expr, parse_physical_sort_expr, parse_physical_sort_exprs, + parse_physical_window_expr_ext, parse_protobuf_file_scan_config, + parse_protobuf_file_scan_config_ext, +}; +use crate::protobuf::physical_aggregate_expr_node::AggregateFunction; +use crate::protobuf::physical_expr_node::ExprType; +use crate::protobuf::physical_plan_node::PhysicalPlanType; +use crate::protobuf::repartition_exec_node::PartitionMethod; +use crate::protobuf::{ + self, window_agg_exec_node, PhysicalPlanNode, PhysicalSortExprNodeCollection, +}; + +use self::to_proto::serialize_physical_expr; pub mod from_proto; pub mod to_proto; @@ -188,9 +188,10 @@ impl AsExecutionPlan for PhysicalPlanNode { } } PhysicalPlanType::CsvScan(scan) => Ok(Arc::new(CsvExec::new( - parse_protobuf_file_scan_config( + parse_protobuf_file_scan_config_ext( scan.base_conf.as_ref().unwrap(), registry, + extension_codec, )?, scan.has_header, str_to_byte(&scan.delimiter, "delimiter")?, @@ -230,12 +231,13 @@ impl AsExecutionPlan for PhysicalPlanNode { Default::default(), ))) } - PhysicalPlanType::AvroScan(scan) => { - Ok(Arc::new(AvroExec::new(parse_protobuf_file_scan_config( + PhysicalPlanType::AvroScan(scan) => Ok(Arc::new(AvroExec::new( + parse_protobuf_file_scan_config_ext( scan.base_conf.as_ref().unwrap(), registry, - )?))) - } + extension_codec, + )?, + ))), PhysicalPlanType::CoalesceBatches(coalesce_batches) => { let input: Arc = into_physical_plan( &coalesce_batches.input, @@ -334,10 +336,11 @@ impl AsExecutionPlan for PhysicalPlanNode { .window_expr .iter() .map(|window_expr| { - parse_physical_window_expr( + parse_physical_window_expr_ext( window_expr, 
registry, input_schema.as_ref(), + extension_codec, ) }) .collect::, _>>()?; diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index f97cfea765bf..9cc63d5ce1b5 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -21,6 +21,8 @@ use std::sync::Arc; use std::vec; use arrow::csv::WriterBuilder; +use prost::Message; + use datafusion::arrow::array::ArrayRef; use datafusion::arrow::compute::kernels::sort::SortOptions; use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema}; @@ -77,19 +79,18 @@ use datafusion_expr::{ ScalarFunctionDefinition, ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF, WindowFrame, WindowFrameBound, }; -use datafusion_proto::physical_plan::from_proto::parse_physical_expr; -use datafusion_proto::physical_plan::to_proto::serialize_physical_expr; use datafusion_proto::physical_plan::{ AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec, }; use datafusion_proto::protobuf; -use prost::Message; /// Perform a serde roundtrip and assert that the string representation of the before and after plans /// are identical. Note that this often isn't sufficient to guarantee that no information is /// lost during serde because the string representation of a plan often only shows a subset of state. fn roundtrip_test(exec_plan: Arc) -> Result<()> { - let _ = roundtrip_test_and_return(exec_plan); + let ctx = SessionContext::new(); + let codec = DefaultPhysicalExtensionCodec {}; + roundtrip_test_and_return(exec_plan, &ctx, &codec)?; Ok(()) } @@ -101,15 +102,15 @@ fn roundtrip_test(exec_plan: Arc) -> Result<()> { /// farther in tests. fn roundtrip_test_and_return( exec_plan: Arc, + ctx: &SessionContext, + codec: &dyn PhysicalExtensionCodec, ) -> Result> { - let ctx = SessionContext::new(); - let codec = DefaultPhysicalExtensionCodec {}; let proto: protobuf::PhysicalPlanNode = - protobuf::PhysicalPlanNode::try_from_physical_plan(exec_plan.clone(), &codec) + protobuf::PhysicalPlanNode::try_from_physical_plan(exec_plan.clone(), codec) .expect("to proto"); let runtime = ctx.runtime_env(); let result_exec_plan: Arc = proto - .try_into_physical_plan(&ctx, runtime.deref(), &codec) + .try_into_physical_plan(ctx, runtime.deref(), codec) .expect("from proto"); assert_eq!(format!("{exec_plan:?}"), format!("{result_exec_plan:?}")); Ok(result_exec_plan) @@ -123,17 +124,10 @@ fn roundtrip_test_and_return( /// performing serde on some plans. 
fn roundtrip_test_with_context( exec_plan: Arc, - ctx: SessionContext, + ctx: &SessionContext, ) -> Result<()> { let codec = DefaultPhysicalExtensionCodec {}; - let proto: protobuf::PhysicalPlanNode = - protobuf::PhysicalPlanNode::try_from_physical_plan(exec_plan.clone(), &codec) - .expect("to proto"); - let runtime = ctx.runtime_env(); - let result_exec_plan: Arc = proto - .try_into_physical_plan(&ctx, runtime.deref(), &codec) - .expect("from proto"); - assert_eq!(format!("{exec_plan:?}"), format!("{result_exec_plan:?}")); + roundtrip_test_and_return(exec_plan, ctx, &codec)?; Ok(()) } @@ -444,7 +438,7 @@ fn roundtrip_aggregate_udaf() -> Result<()> { Arc::new(EmptyExec::new(schema.clone())), schema, )?), - ctx, + &ctx, ) } @@ -642,7 +636,7 @@ fn roundtrip_scalar_udf() -> Result<()> { ctx.register_udf(udf); - roundtrip_test_with_context(Arc::new(project), ctx) + roundtrip_test_with_context(Arc::new(project), &ctx) } #[test] @@ -657,11 +651,7 @@ fn roundtrip_scalar_udf_extension_codec() { impl MyRegexUdf { fn new(pattern: String) -> Self { Self { - signature: Signature::uniform( - 1, - vec![DataType::Int32], - Volatility::Immutable, - ), + signature: Signature::exact(vec![DataType::Utf8], Volatility::Immutable), pattern, } } @@ -682,7 +672,7 @@ fn roundtrip_scalar_udf_extension_codec() { if !matches!(args.first(), Some(&DataType::Utf8)) { return plan_err!("regex_udf only accepts Utf8 arguments"); } - Ok(DataType::Int32) + Ok(DataType::Boolean) } fn invoke(&self, _args: &[ColumnarValue]) -> Result { unimplemented!() @@ -747,32 +737,40 @@ fn roundtrip_scalar_udf_extension_codec() { } } + let field_text = Field::new("text", DataType::Utf8, true); + let field_published = Field::new("published", DataType::Boolean, false); + let schema = Arc::new(Schema::new(vec![field_text, field_published])); + let input = Arc::new(EmptyExec::new(schema.clone())); + let pattern = ".*"; let udf = ScalarUDF::from(MyRegexUdf::new(pattern.to_string())); - let test_expr = ScalarFunctionExpr::new( + let udf_expr = Arc::new(ScalarFunctionExpr::new( udf.name(), ScalarFunctionDefinition::UDF(Arc::new(udf.clone())), - vec![], - DataType::Int32, + vec![col("text", &schema).expect("text")], + DataType::Boolean, None, false, + )); + + let filter = Arc::new( + FilterExec::try_new( + Arc::new(BinaryExpr::new( + col("published", &schema).expect("published"), + Operator::And, + udf_expr, + )), + input, + ) + .expect("filter"), ); - let fmt_expr = format!("{test_expr:?}"); - let ctx = SessionContext::new(); - ctx.register_udf(udf.clone()); - let extension_codec = ScalarUDFExtensionCodec {}; - let proto: protobuf::PhysicalExprNode = - match serialize_physical_expr(Arc::new(test_expr), &extension_codec) { - Ok(proto) => proto, - Err(e) => panic!("failed to serialize expr: {e:?}"), - }; - let field_a = Field::new("a", DataType::Int32, false); - let schema = Arc::new(Schema::new(vec![field_a])); - let round_trip = - parse_physical_expr(&proto, &ctx, &schema, &extension_codec).unwrap(); - assert_eq!(fmt_expr, format!("{round_trip:?}")); + let ctx = SessionContext::new(); + let codec = ScalarUDFExtensionCodec {}; + ctx.register_udf(udf); + roundtrip_test_and_return(filter, &ctx, &codec).unwrap(); } + #[test] fn roundtrip_distinct_count() -> Result<()> { let field_a = Field::new("a", DataType::Int64, false); @@ -896,12 +894,18 @@ fn roundtrip_csv_sink() -> Result<()> { }), )]; - let roundtrip_plan = roundtrip_test_and_return(Arc::new(DataSinkExec::new( - input, - data_sink, - schema.clone(), - Some(sort_order), - ))) + let ctx = 
SessionContext::new();
+    let codec = DefaultPhysicalExtensionCodec {};
+    let roundtrip_plan = roundtrip_test_and_return(
+        Arc::new(DataSinkExec::new(
+            input,
+            data_sink,
+            schema.clone(),
+            Some(sort_order),
+        )),
+        &ctx,
+        &codec,
+    )
     .unwrap();
 
     let roundtrip_plan = roundtrip_plan

From 7d16e388c377cbb5d108f89c41c7a1f393290ebc Mon Sep 17 00:00:00 2001
From: Georgi Krastev
Date: Sun, 14 Apr 2024 17:05:55 +0200
Subject: [PATCH 2/4] Use PhysicalExtensionCodec consistently also when
 serializing

---
 datafusion/proto/src/physical_plan/mod.rs     |  60 ++++--
 .../proto/src/physical_plan/to_proto.rs       | 504 +++++++++---------
 .../tests/cases/roundtrip_physical_plan.rs    |  22 +-
 3 files changed, 307 insertions(+), 279 deletions(-)

diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs
index 0890cf0a7e60..b8c0354bfd6a 100644
--- a/datafusion/proto/src/physical_plan/mod.rs
+++ b/datafusion/proto/src/physical_plan/mod.rs
@@ -72,20 +72,22 @@ use crate::physical_plan::from_proto::{
     parse_physical_window_expr_ext, parse_protobuf_file_scan_config,
     parse_protobuf_file_scan_config_ext,
 };
+use crate::physical_plan::to_proto::{
+    serialize_file_scan_config, serialize_maybe_filter, serialize_physical_aggr_expr,
+    serialize_physical_window_expr,
+};
 use crate::protobuf::physical_aggregate_expr_node::AggregateFunction;
 use crate::protobuf::physical_expr_node::ExprType;
 use crate::protobuf::physical_plan_node::PhysicalPlanType;
 use crate::protobuf::repartition_exec_node::PartitionMethod;
-use crate::protobuf::{
-    self, window_agg_exec_node, PhysicalPlanNode, PhysicalSortExprNodeCollection,
-};
+use crate::protobuf::{self, window_agg_exec_node};
 
 use self::to_proto::serialize_physical_expr;
 
 pub mod from_proto;
 pub mod to_proto;
 
-impl AsExecutionPlan for PhysicalPlanNode {
+impl AsExecutionPlan for protobuf::PhysicalPlanNode {
     fn try_decode(buf: &[u8]) -> Result<Self>
     where
         Self: Sized,
@@ -1452,14 +1454,17 @@ impl AsExecutionPlan for PhysicalPlanNode {
                 let filter = exec
                     .filter_expr()
                     .iter()
-                    .map(|expr| expr.to_owned().try_into())
+                    .map(|expr| serialize_maybe_filter(expr.to_owned(), extension_codec))
                     .collect::<Result<Vec<_>>>()?;
 
                 let agg = exec
                     .aggr_expr()
                     .iter()
-                    .map(|expr| expr.to_owned().try_into())
+                    .map(|expr| {
+                        serialize_physical_aggr_expr(expr.to_owned(), extension_codec)
+                    })
                     .collect::<Result<Vec<_>>>()?;
+
                 let agg_names = exec
                     .aggr_expr()
                     .iter()
@@ -1559,7 +1564,10 @@ impl AsExecutionPlan for PhysicalPlanNode {
             return Ok(protobuf::PhysicalPlanNode {
                 physical_plan_type: Some(PhysicalPlanType::CsvScan(
                     protobuf::CsvScanExecNode {
-                        base_conf: Some(exec.base_config().try_into()?),
+                        base_conf: Some(serialize_file_scan_config(
+                            exec.base_config(),
+                            extension_codec,
+                        )?),
                         has_header: exec.has_header(),
                         delimiter: byte_to_string(exec.delimiter(), "delimiter")?,
                         quote: byte_to_string(exec.quote(), "quote")?,
@@ -1584,7 +1592,10 @@ impl AsExecutionPlan for PhysicalPlanNode {
             return Ok(protobuf::PhysicalPlanNode {
                 physical_plan_type: Some(PhysicalPlanType::ParquetScan(
                     protobuf::ParquetScanExecNode {
-                        base_conf: Some(exec.base_config().try_into()?),
+                        base_conf: Some(serialize_file_scan_config(
+                            exec.base_config(),
+                            extension_codec,
+                        )?),
                         predicate,
                     },
                 )),
@@ -1595,7 +1606,10 @@ impl AsExecutionPlan for PhysicalPlanNode {
             return Ok(protobuf::PhysicalPlanNode {
                 physical_plan_type: Some(PhysicalPlanType::AvroScan(
                     protobuf::AvroScanExecNode {
-                        base_conf: Some(exec.base_config().try_into()?),
+                        base_conf: Some(serialize_file_scan_config(
+                            exec.base_config(),
+                            extension_codec,
+                        )?),
                     },
                 )),
}); @@ -1691,7 +1705,7 @@ impl AsExecutionPlan for PhysicalPlanNode { } if let Some(union) = plan.downcast_ref::() { - let mut inputs: Vec = vec![]; + let mut inputs: Vec = vec![]; for input in union.inputs() { inputs.push(protobuf::PhysicalPlanNode::try_from_physical_plan( input.to_owned(), @@ -1706,7 +1720,7 @@ impl AsExecutionPlan for PhysicalPlanNode { } if let Some(interleave) = plan.downcast_ref::() { - let mut inputs: Vec = vec![]; + let mut inputs: Vec = vec![]; for input in interleave.inputs() { inputs.push(protobuf::PhysicalPlanNode::try_from_physical_plan( input.to_owned(), @@ -1812,11 +1826,11 @@ impl AsExecutionPlan for PhysicalPlanNode { extension_codec, )?; - let window_expr = - exec.window_expr() - .iter() - .map(|e| e.clone().try_into()) - .collect::>>()?; + let window_expr = exec + .window_expr() + .iter() + .map(|e| serialize_physical_window_expr(e.clone(), extension_codec)) + .collect::>>()?; let partition_keys = exec .partition_keys @@ -1842,11 +1856,11 @@ impl AsExecutionPlan for PhysicalPlanNode { extension_codec, )?; - let window_expr = - exec.window_expr() - .iter() - .map(|e| e.clone().try_into()) - .collect::>>()?; + let window_expr = exec + .window_expr() + .iter() + .map(|e| serialize_physical_window_expr(e.clone(), extension_codec)) + .collect::>>()?; let partition_keys = exec .partition_keys @@ -1904,7 +1918,7 @@ impl AsExecutionPlan for PhysicalPlanNode { Ok(sort_expr) }) .collect::>>()?; - Some(PhysicalSortExprNodeCollection { + Some(protobuf::PhysicalSortExprNodeCollection { physical_sort_expr_nodes: expr, }) } @@ -2047,7 +2061,7 @@ impl PhysicalExtensionCodec for DefaultPhysicalExtensionCodec { } fn into_physical_plan( - node: &Option>, + node: &Option>, registry: &dyn FunctionRegistry, runtime: &RuntimeEnv, extension_codec: &dyn PhysicalExtensionCodec, diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index e1574f48fb8e..5728fe45d9bb 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -22,18 +22,8 @@ use std::{ sync::Arc, }; -use crate::protobuf::{ - self, copy_to_node, physical_aggregate_expr_node, physical_window_expr_node, - scalar_value::Value, ArrowOptions, AvroOptions, PhysicalSortExprNode, - PhysicalSortExprNodeCollection, ScalarValue, -}; - #[cfg(feature = "parquet")] use datafusion::datasource::file_format::parquet::ParquetSink; - -use datafusion_expr::ScalarFunctionDefinition; - -use crate::logical_plan::csv_writer_options_to_proto; use datafusion::logical_expr::BuiltinScalarFunction; use datafusion::physical_expr::window::{NthValueKind, SlidingAggregateWindowExpr}; use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; @@ -71,195 +61,203 @@ use datafusion_common::{ stats::Precision, DataFusionError, JoinSide, Result, }; +use datafusion_expr::ScalarFunctionDefinition; + +use crate::logical_plan::csv_writer_options_to_proto; +use crate::protobuf::{ + self, copy_to_node, physical_aggregate_expr_node, physical_window_expr_node, + scalar_value::Value, ArrowOptions, AvroOptions, PhysicalSortExprNode, + PhysicalSortExprNodeCollection, ScalarValue, +}; use super::{DefaultPhysicalExtensionCodec, PhysicalExtensionCodec}; impl TryFrom> for protobuf::PhysicalExprNode { type Error = DataFusionError; - fn try_from(a: Arc) -> Result { - let codec = DefaultPhysicalExtensionCodec {}; - let expressions = serialize_physical_exprs(a.expressions(), &codec)?; - - let ordering_req = a.order_bys().unwrap_or(&[]).to_vec(); - let 
ordering_req = serialize_physical_sort_exprs(ordering_req, &codec)?; - - if let Some(a) = a.as_any().downcast_ref::() { - let name = a.fun().name().to_string(); - return Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::AggregateExpr( - protobuf::PhysicalAggregateExprNode { - aggregate_function: Some(physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(name)), - expr: expressions, - ordering_req, - distinct: false, - }, - )), - }); - } - - let AggrFn { - inner: aggr_function, - distinct, - } = aggr_expr_to_aggr_fn(a.as_ref())?; + fn try_from(a: Arc) -> Result { + serialize_physical_aggr_expr(a, &DefaultPhysicalExtensionCodec {}) + } +} - Ok(protobuf::PhysicalExprNode { +pub(crate) fn serialize_physical_aggr_expr( + aggr_expr: Arc, + codec: &dyn PhysicalExtensionCodec, +) -> Result { + let expressions = serialize_physical_exprs(aggr_expr.expressions(), codec)?; + let ordering_req = aggr_expr.order_bys().unwrap_or(&[]).to_vec(); + let ordering_req = serialize_physical_sort_exprs(ordering_req, codec)?; + + if let Some(a) = aggr_expr.as_any().downcast_ref::() { + let name = a.fun().name().to_string(); + return Ok(protobuf::PhysicalExprNode { expr_type: Some(protobuf::physical_expr_node::ExprType::AggregateExpr( protobuf::PhysicalAggregateExprNode { - aggregate_function: Some( - physical_aggregate_expr_node::AggregateFunction::AggrFunction( - aggr_function as i32, - ), - ), + aggregate_function: Some(physical_aggregate_expr_node::AggregateFunction::UserDefinedAggrFunction(name)), expr: expressions, ordering_req, - distinct, + distinct: false, }, )), - }) + }); } + + let AggrFn { + inner: aggr_function, + distinct, + } = aggr_expr_to_aggr_fn(aggr_expr.as_ref())?; + + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::AggregateExpr( + protobuf::PhysicalAggregateExprNode { + aggregate_function: Some( + physical_aggregate_expr_node::AggregateFunction::AggrFunction( + aggr_function as i32, + ), + ), + expr: expressions, + ordering_req, + distinct, + }, + )), + }) } impl TryFrom> for protobuf::PhysicalWindowExprNode { type Error = DataFusionError; - fn try_from( - window_expr: Arc, - ) -> std::result::Result { - let expr = window_expr.as_any(); + fn try_from(window_expr: Arc) -> Result { + serialize_physical_window_expr(window_expr, &DefaultPhysicalExtensionCodec {}) + } +} - let mut args = window_expr.expressions().to_vec(); - let window_frame = window_expr.get_window_frame(); +pub(crate) fn serialize_physical_window_expr( + window_expr: Arc, + codec: &dyn PhysicalExtensionCodec, +) -> Result { + let expr = window_expr.as_any(); + let mut args = window_expr.expressions().to_vec(); + let window_frame = window_expr.get_window_frame(); - let window_function = if let Some(built_in_window_expr) = - expr.downcast_ref::() + let window_function = if let Some(built_in_window_expr) = + expr.downcast_ref::() + { + let expr = built_in_window_expr.get_built_in_func_expr(); + let built_in_fn_expr = expr.as_any(); + + let builtin_fn = if built_in_fn_expr.downcast_ref::().is_some() { + protobuf::BuiltInWindowFunction::RowNumber + } else if let Some(rank_expr) = built_in_fn_expr.downcast_ref::() { + match rank_expr.get_type() { + RankType::Basic => protobuf::BuiltInWindowFunction::Rank, + RankType::Dense => protobuf::BuiltInWindowFunction::DenseRank, + RankType::Percent => protobuf::BuiltInWindowFunction::PercentRank, + } + } else if built_in_fn_expr.downcast_ref::().is_some() { + 
protobuf::BuiltInWindowFunction::CumeDist + } else if let Some(ntile_expr) = built_in_fn_expr.downcast_ref::() { + args.insert( + 0, + Arc::new(Literal::new(datafusion_common::ScalarValue::Int64(Some( + ntile_expr.get_n() as i64, + )))), + ); + protobuf::BuiltInWindowFunction::Ntile + } else if let Some(window_shift_expr) = + built_in_fn_expr.downcast_ref::() { - let expr = built_in_window_expr.get_built_in_func_expr(); - let built_in_fn_expr = expr.as_any(); - - let builtin_fn = if built_in_fn_expr.downcast_ref::().is_some() { - protobuf::BuiltInWindowFunction::RowNumber - } else if let Some(rank_expr) = built_in_fn_expr.downcast_ref::() { - match rank_expr.get_type() { - RankType::Basic => protobuf::BuiltInWindowFunction::Rank, - RankType::Dense => protobuf::BuiltInWindowFunction::DenseRank, - RankType::Percent => protobuf::BuiltInWindowFunction::PercentRank, - } - } else if built_in_fn_expr.downcast_ref::().is_some() { - protobuf::BuiltInWindowFunction::CumeDist - } else if let Some(ntile_expr) = built_in_fn_expr.downcast_ref::() { - args.insert( - 0, - Arc::new(Literal::new(datafusion_common::ScalarValue::Int64(Some( - ntile_expr.get_n() as i64, - )))), - ); - protobuf::BuiltInWindowFunction::Ntile - } else if let Some(window_shift_expr) = - built_in_fn_expr.downcast_ref::() - { - args.insert( - 1, - Arc::new(Literal::new(datafusion_common::ScalarValue::Int64(Some( - window_shift_expr.get_shift_offset(), - )))), - ); - args.insert( - 2, - Arc::new(Literal::new(window_shift_expr.get_default_value())), - ); - - if window_shift_expr.get_shift_offset() >= 0 { - protobuf::BuiltInWindowFunction::Lag - } else { - protobuf::BuiltInWindowFunction::Lead - } - } else if let Some(nth_value_expr) = - built_in_fn_expr.downcast_ref::() - { - match nth_value_expr.get_kind() { - NthValueKind::First => protobuf::BuiltInWindowFunction::FirstValue, - NthValueKind::Last => protobuf::BuiltInWindowFunction::LastValue, - NthValueKind::Nth(n) => { - args.insert( - 1, - Arc::new(Literal::new( - datafusion_common::ScalarValue::Int64(Some(n)), - )), - ); - protobuf::BuiltInWindowFunction::NthValue - } - } + args.insert( + 1, + Arc::new(Literal::new(datafusion_common::ScalarValue::Int64(Some( + window_shift_expr.get_shift_offset(), + )))), + ); + args.insert( + 2, + Arc::new(Literal::new(window_shift_expr.get_default_value())), + ); + + if window_shift_expr.get_shift_offset() >= 0 { + protobuf::BuiltInWindowFunction::Lag } else { - return not_impl_err!("BuiltIn function not supported: {expr:?}"); - }; - - physical_window_expr_node::WindowFunction::BuiltInFunction(builtin_fn as i32) - } else if let Some(plain_aggr_window_expr) = - expr.downcast_ref::() - { - let AggrFn { inner, distinct } = aggr_expr_to_aggr_fn( - plain_aggr_window_expr.get_aggregate_expr().as_ref(), - )?; - - if distinct { - // TODO - return not_impl_err!( - "Distinct aggregate functions not supported in window expressions" - ); + protobuf::BuiltInWindowFunction::Lead } - - if !window_frame.start_bound.is_unbounded() { - return Err(DataFusionError::Internal(format!("Invalid PlainAggregateWindowExpr = {window_expr:?} with WindowFrame = {window_frame:?}"))); + } else if let Some(nth_value_expr) = built_in_fn_expr.downcast_ref::() { + match nth_value_expr.get_kind() { + NthValueKind::First => protobuf::BuiltInWindowFunction::FirstValue, + NthValueKind::Last => protobuf::BuiltInWindowFunction::LastValue, + NthValueKind::Nth(n) => { + args.insert( + 1, + Arc::new(Literal::new(datafusion_common::ScalarValue::Int64( + Some(n), + ))), + ); + 
protobuf::BuiltInWindowFunction::NthValue + } } + } else { + return not_impl_err!("BuiltIn function not supported: {expr:?}"); + }; - physical_window_expr_node::WindowFunction::AggrFunction(inner as i32) - } else if let Some(sliding_aggr_window_expr) = - expr.downcast_ref::() - { - let AggrFn { inner, distinct } = aggr_expr_to_aggr_fn( - sliding_aggr_window_expr.get_aggregate_expr().as_ref(), - )?; - - if distinct { - // TODO - return not_impl_err!( - "Distinct aggregate functions not supported in window expressions" - ); - } + physical_window_expr_node::WindowFunction::BuiltInFunction(builtin_fn as i32) + } else if let Some(plain_aggr_window_expr) = + expr.downcast_ref::() + { + let AggrFn { inner, distinct } = + aggr_expr_to_aggr_fn(plain_aggr_window_expr.get_aggregate_expr().as_ref())?; + + if distinct { + // TODO + return not_impl_err!( + "Distinct aggregate functions not supported in window expressions" + ); + } - if window_frame.start_bound.is_unbounded() { - return Err(DataFusionError::Internal(format!("Invalid SlidingAggregateWindowExpr = {window_expr:?} with WindowFrame = {window_frame:?}"))); - } + if !window_frame.start_bound.is_unbounded() { + return Err(DataFusionError::Internal(format!("Invalid PlainAggregateWindowExpr = {window_expr:?} with WindowFrame = {window_frame:?}"))); + } - physical_window_expr_node::WindowFunction::AggrFunction(inner as i32) - } else { - return not_impl_err!("WindowExpr not supported: {window_expr:?}"); - }; - let codec = DefaultPhysicalExtensionCodec {}; - let args = serialize_physical_exprs(args, &codec)?; - let partition_by = - serialize_physical_exprs(window_expr.partition_by().to_vec(), &codec)?; + physical_window_expr_node::WindowFunction::AggrFunction(inner as i32) + } else if let Some(sliding_aggr_window_expr) = + expr.downcast_ref::() + { + let AggrFn { inner, distinct } = + aggr_expr_to_aggr_fn(sliding_aggr_window_expr.get_aggregate_expr().as_ref())?; + + if distinct { + // TODO + return not_impl_err!( + "Distinct aggregate functions not supported in window expressions" + ); + } - let order_by = - serialize_physical_sort_exprs(window_expr.order_by().to_vec(), &codec)?; + if window_frame.start_bound.is_unbounded() { + return Err(DataFusionError::Internal(format!("Invalid SlidingAggregateWindowExpr = {window_expr:?} with WindowFrame = {window_frame:?}"))); + } - let window_frame: protobuf::WindowFrame = window_frame - .as_ref() - .try_into() - .map_err(|e| DataFusionError::Internal(format!("{e}")))?; - - let name = window_expr.name().to_string(); - - Ok(protobuf::PhysicalWindowExprNode { - args, - partition_by, - order_by, - window_frame: Some(window_frame), - window_function: Some(window_function), - name, - }) - } + physical_window_expr_node::WindowFunction::AggrFunction(inner as i32) + } else { + return not_impl_err!("WindowExpr not supported: {window_expr:?}"); + }; + + let args = serialize_physical_exprs(args, codec)?; + let partition_by = + serialize_physical_exprs(window_expr.partition_by().to_vec(), codec)?; + let order_by = serialize_physical_sort_exprs(window_expr.order_by().to_vec(), codec)?; + let window_frame: protobuf::WindowFrame = window_frame + .as_ref() + .try_into() + .map_err(|e| DataFusionError::Internal(format!("{e}")))?; + + Ok(protobuf::PhysicalWindowExprNode { + args, + partition_by, + order_by, + window_frame: Some(window_frame), + window_function: Some(window_function), + name: window_expr.name().to_string(), + }) } struct AggrFn { @@ -366,7 +364,7 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { 
pub fn serialize_physical_sort_exprs( sort_exprs: I, codec: &dyn PhysicalExtensionCodec, -) -> Result, DataFusionError> +) -> Result> where I: IntoIterator, { @@ -379,7 +377,7 @@ where pub fn serialize_physical_sort_expr( sort_expr: PhysicalSortExpr, codec: &dyn PhysicalExtensionCodec, -) -> Result { +) -> Result { let PhysicalSortExpr { expr, options } = sort_expr; let expr = serialize_physical_expr(expr, codec)?; Ok(PhysicalSortExprNode { @@ -392,7 +390,7 @@ pub fn serialize_physical_sort_expr( pub fn serialize_physical_exprs( values: I, codec: &dyn PhysicalExtensionCodec, -) -> Result, DataFusionError> +) -> Result> where I: IntoIterator>, { @@ -409,7 +407,7 @@ where pub fn serialize_physical_expr( value: Arc, codec: &dyn PhysicalExtensionCodec, -) -> Result { +) -> Result { let expr = value.as_any(); if let Some(expr) = expr.downcast_ref::() { @@ -637,7 +635,7 @@ fn try_parse_when_then_expr( impl TryFrom<&PartitionedFile> for protobuf::PartitionedFile { type Error = DataFusionError; - fn try_from(pf: &PartitionedFile) -> Result { + fn try_from(pf: &PartitionedFile) -> Result { let last_modified = pf.object_meta.last_modified; let last_modified_ns = last_modified.timestamp_nanos_opt().ok_or_else(|| { DataFusionError::Plan(format!( @@ -661,7 +659,7 @@ impl TryFrom<&PartitionedFile> for protobuf::PartitionedFile { impl TryFrom<&FileRange> for protobuf::FileRange { type Error = DataFusionError; - fn try_from(value: &FileRange) -> Result { + fn try_from(value: &FileRange) -> Result { Ok(protobuf::FileRange { start: value.start, end: value.end, @@ -748,59 +746,64 @@ impl From<&ColumnStatistics> for protobuf::ColumnStats { impl TryFrom<&FileScanConfig> for protobuf::FileScanExecConf { type Error = DataFusionError; - fn try_from( - conf: &FileScanConfig, - ) -> Result { - let codec = DefaultPhysicalExtensionCodec {}; - let file_groups = conf - .file_groups - .iter() - .map(|p| p.as_slice().try_into()) - .collect::, _>>()?; - let mut output_orderings = vec![]; - for order in &conf.output_ordering { - let ordering = serialize_physical_sort_exprs(order.to_vec(), &codec)?; - output_orderings.push(ordering) - } - - // Fields must be added to the schema so that they can persist in the protobuf - // and then they are to be removed from the schema in `parse_protobuf_file_scan_config` - let mut fields = conf - .file_schema - .fields() - .iter() - .cloned() - .collect::>(); - fields.extend(conf.table_partition_cols.iter().cloned().map(Arc::new)); - let schema = Arc::new(datafusion::arrow::datatypes::Schema::new(fields.clone())); + fn try_from(conf: &FileScanConfig) -> Result { + serialize_file_scan_config(conf, &DefaultPhysicalExtensionCodec {}) + } +} - Ok(protobuf::FileScanExecConf { - file_groups, - statistics: Some((&conf.statistics).into()), - limit: conf.limit.map(|l| protobuf::ScanLimit { limit: l as u32 }), - projection: conf - .projection - .as_ref() - .unwrap_or(&vec![]) - .iter() - .map(|n| *n as u32) - .collect(), - schema: Some(schema.as_ref().try_into()?), - table_partition_cols: conf - .table_partition_cols - .iter() - .map(|x| x.name().clone()) - .collect::>(), - object_store_url: conf.object_store_url.to_string(), - output_ordering: output_orderings - .into_iter() - .map(|e| PhysicalSortExprNodeCollection { - physical_sort_expr_nodes: e, - }) - .collect::>(), - }) +pub(crate) fn serialize_file_scan_config( + conf: &FileScanConfig, + codec: &dyn PhysicalExtensionCodec, +) -> Result { + let file_groups = conf + .file_groups + .iter() + .map(|p| p.as_slice().try_into()) + .collect::, 
_>>()?; + + let mut output_orderings = vec![]; + for order in &conf.output_ordering { + let ordering = serialize_physical_sort_exprs(order.to_vec(), codec)?; + output_orderings.push(ordering) } + + // Fields must be added to the schema so that they can persist in the protobuf, + // and then they are to be removed from the schema in `parse_protobuf_file_scan_config` + let mut fields = conf + .file_schema + .fields() + .iter() + .cloned() + .collect::>(); + fields.extend(conf.table_partition_cols.iter().cloned().map(Arc::new)); + let schema = Arc::new(arrow::datatypes::Schema::new(fields.clone())); + + Ok(protobuf::FileScanExecConf { + file_groups, + statistics: Some((&conf.statistics).into()), + limit: conf.limit.map(|l| protobuf::ScanLimit { limit: l as u32 }), + projection: conf + .projection + .as_ref() + .unwrap_or(&vec![]) + .iter() + .map(|n| *n as u32) + .collect(), + schema: Some(schema.as_ref().try_into()?), + table_partition_cols: conf + .table_partition_cols + .iter() + .map(|x| x.name().clone()) + .collect::>(), + object_store_url: conf.object_store_url.to_string(), + output_ordering: output_orderings + .into_iter() + .map(|e| PhysicalSortExprNodeCollection { + physical_sort_expr_nodes: e, + }) + .collect::>(), + }) } impl From for protobuf::JoinSide { @@ -815,43 +818,42 @@ impl From for protobuf::JoinSide { impl TryFrom>> for protobuf::MaybeFilter { type Error = DataFusionError; - fn try_from(expr: Option>) -> Result { - let codec = DefaultPhysicalExtensionCodec {}; - match expr { - None => Ok(protobuf::MaybeFilter { expr: None }), - Some(expr) => Ok(protobuf::MaybeFilter { - expr: Some(serialize_physical_expr(expr, &codec)?), - }), - } + fn try_from(expr: Option>) -> Result { + serialize_maybe_filter(expr, &DefaultPhysicalExtensionCodec {}) + } +} + +pub(crate) fn serialize_maybe_filter( + expr: Option>, + codec: &dyn PhysicalExtensionCodec, +) -> Result { + match expr { + None => Ok(protobuf::MaybeFilter { expr: None }), + Some(expr) => Ok(protobuf::MaybeFilter { + expr: Some(serialize_physical_expr(expr, codec)?), + }), } } impl TryFrom>> for protobuf::MaybePhysicalSortExprs { type Error = DataFusionError; - fn try_from(sort_exprs: Option>) -> Result { - match sort_exprs { - None => Ok(protobuf::MaybePhysicalSortExprs { sort_expr: vec![] }), - Some(sort_exprs) => Ok(protobuf::MaybePhysicalSortExprs { - sort_expr: sort_exprs - .into_iter() - .map(|sort_expr| sort_expr.try_into()) - .collect::>>()?, - }), - } + fn try_from(sort_exprs: Option>) -> Result { + let codec = DefaultPhysicalExtensionCodec {}; + Ok(protobuf::MaybePhysicalSortExprs { + sort_expr: serialize_physical_sort_exprs( + sort_exprs.unwrap_or_default(), + &codec, + )?, + }) } } -impl TryFrom for protobuf::PhysicalSortExprNode { +impl TryFrom for PhysicalSortExprNode { type Error = DataFusionError; - fn try_from(sort_expr: PhysicalSortExpr) -> std::result::Result { - let codec = DefaultPhysicalExtensionCodec {}; - Ok(PhysicalSortExprNode { - expr: Some(Box::new(serialize_physical_expr(sort_expr.expr, &codec)?)), - asc: !sort_expr.options.descending, - nulls_first: sort_expr.options.nulls_first, - }) + fn try_from(sort_expr: PhysicalSortExpr) -> Result { + serialize_physical_sort_expr(sort_expr, &DefaultPhysicalExtensionCodec {}) } } diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 9cc63d5ce1b5..7d71f66d516e 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ 
b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -37,7 +37,7 @@ use datafusion::datasource::physical_plan::{ }; use datafusion::execution::FunctionRegistry; use datafusion::logical_expr::{create_udf, JoinType, Operator, Volatility}; -use datafusion::physical_expr::expressions::NthValueAgg; +use datafusion::physical_expr::expressions::{Count, NthValueAgg}; use datafusion::physical_expr::window::SlidingAggregateWindowExpr; use datafusion::physical_expr::{PhysicalSortRequirement, ScalarFunctionExpr}; use datafusion::physical_plan::aggregates::{ @@ -739,7 +739,8 @@ fn roundtrip_scalar_udf_extension_codec() { let field_text = Field::new("text", DataType::Utf8, true); let field_published = Field::new("published", DataType::Boolean, false); - let schema = Arc::new(Schema::new(vec![field_text, field_published])); + let field_author = Field::new("author", DataType::Utf8, false); + let schema = Arc::new(Schema::new(vec![field_text, field_published, field_author])); let input = Arc::new(EmptyExec::new(schema.clone())); let pattern = ".*"; @@ -758,17 +759,28 @@ fn roundtrip_scalar_udf_extension_codec() { Arc::new(BinaryExpr::new( col("published", &schema).expect("published"), Operator::And, - udf_expr, + udf_expr.clone(), )), input, ) .expect("filter"), ); + let aggregate = Arc::new( + AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::new(vec![], vec![], vec![]), + vec![Arc::new(Count::new(udf_expr, "count", DataType::Boolean))], + vec![None], + filter, + schema.clone(), + ) + .expect("aggregate"), + ); + let ctx = SessionContext::new(); let codec = ScalarUDFExtensionCodec {}; - ctx.register_udf(udf); - roundtrip_test_and_return(filter, &ctx, &codec).unwrap(); + roundtrip_test_and_return(aggregate, &ctx, &codec).unwrap(); } #[test] From 280feda78624c686b93692ed302ff96438bf5978 Mon Sep 17 00:00:00 2001 From: Georgi Krastev Date: Mon, 15 Apr 2024 10:01:15 +0200 Subject: [PATCH 3/4] Add a test for window aggregation with UDF codec --- .../tests/cases/roundtrip_physical_plan.rs | 66 +++++++++++-------- 1 file changed, 38 insertions(+), 28 deletions(-) diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 7d71f66d516e..642860d6397b 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -37,7 +37,7 @@ use datafusion::datasource::physical_plan::{ }; use datafusion::execution::FunctionRegistry; use datafusion::logical_expr::{create_udf, JoinType, Operator, Volatility}; -use datafusion::physical_expr::expressions::{Count, NthValueAgg}; +use datafusion::physical_expr::expressions::{Count, Max, NthValueAgg}; use datafusion::physical_expr::window::SlidingAggregateWindowExpr; use datafusion::physical_expr::{PhysicalSortRequirement, ScalarFunctionExpr}; use datafusion::physical_plan::aggregates::{ @@ -640,7 +640,7 @@ fn roundtrip_scalar_udf() -> Result<()> { } #[test] -fn roundtrip_scalar_udf_extension_codec() { +fn roundtrip_scalar_udf_extension_codec() -> Result<()> { #[derive(Debug)] struct MyRegexUdf { signature: Signature, @@ -662,18 +662,22 @@ fn roundtrip_scalar_udf_extension_codec() { fn as_any(&self) -> &dyn Any { self } + fn name(&self) -> &str { "regex_udf" } + fn signature(&self) -> &Signature { &self.signature } + fn return_type(&self, args: &[DataType]) -> Result { if !matches!(args.first(), Some(&DataType::Utf8)) { return plan_err!("regex_udf only accepts Utf8 arguments"); } - Ok(DataType::Boolean) + 
Ok(DataType::Int64) } + fn invoke(&self, _args: &[ColumnarValue]) -> Result { unimplemented!() } @@ -748,39 +752,45 @@ fn roundtrip_scalar_udf_extension_codec() { let udf_expr = Arc::new(ScalarFunctionExpr::new( udf.name(), ScalarFunctionDefinition::UDF(Arc::new(udf.clone())), - vec![col("text", &schema).expect("text")], - DataType::Boolean, + vec![col("text", &schema)?], + DataType::Int64, None, false, )); - let filter = Arc::new( - FilterExec::try_new( - Arc::new(BinaryExpr::new( - col("published", &schema).expect("published"), - Operator::And, - udf_expr.clone(), - )), - input, - ) - .expect("filter"), - ); + let filter = Arc::new(FilterExec::try_new( + Arc::new(BinaryExpr::new( + col("published", &schema)?, + Operator::And, + Arc::new(BinaryExpr::new(udf_expr.clone(), Operator::Gt, lit(0))), + )), + input, + )?); - let aggregate = Arc::new( - AggregateExec::try_new( - AggregateMode::Final, - PhysicalGroupBy::new(vec![], vec![], vec![]), - vec![Arc::new(Count::new(udf_expr, "count", DataType::Boolean))], - vec![None], - filter, - schema.clone(), - ) - .expect("aggregate"), - ); + let window = Arc::new(WindowAggExec::try_new( + vec![Arc::new(PlainAggregateWindowExpr::new( + Arc::new(Max::new(udf_expr.clone(), "max", DataType::Int64)), + &[col("author", &schema)?], + &[], + Arc::new(WindowFrame::new(None)), + ))], + filter, + vec![col("author", &schema)?], + )?); + + let aggregate = Arc::new(AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::new(vec![], vec![], vec![]), + vec![Arc::new(Count::new(udf_expr, "count", DataType::Int64))], + vec![None], + window, + schema.clone(), + )?); let ctx = SessionContext::new(); let codec = ScalarUDFExtensionCodec {}; - roundtrip_test_and_return(aggregate, &ctx, &codec).unwrap(); + roundtrip_test_and_return(aggregate, &ctx, &codec)?; + Ok(()) } #[test] From 1f0040c5d205f6e0cf56aa20cd61f3713b37f4d4 Mon Sep 17 00:00:00 2001 From: Georgi Krastev Date: Mon, 15 Apr 2024 10:27:46 +0200 Subject: [PATCH 4/4] Commit binary incompatible changes --- .../proto/src/physical_plan/from_proto.rs | 63 ++++------------- datafusion/proto/src/physical_plan/mod.rs | 16 ++--- .../proto/src/physical_plan/to_proto.rs | 68 ++----------------- 3 files changed, 27 insertions(+), 120 deletions(-) diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index f2c5b4b080b2..81e4c92ffc68 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -21,13 +21,11 @@ use std::collections::HashMap; use std::convert::{TryFrom, TryInto}; use std::sync::Arc; -use crate::common::proto_error; -use crate::convert_required; -use crate::logical_plan::{self, csv_writer_options_from_proto}; -use crate::protobuf::physical_expr_node::ExprType; -use crate::protobuf::{self, copy_to_node}; - use arrow::compute::SortOptions; +use chrono::{TimeZone, Utc}; +use object_store::path::Path; +use object_store::ObjectMeta; + use datafusion::arrow::datatypes::Schema; use datafusion::datasource::file_format::csv::CsvSink; use datafusion::datasource::file_format::json::JsonSink; @@ -57,13 +55,15 @@ use datafusion_common::file_options::json_writer::JsonWriterOptions; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::stats::Precision; use datafusion_common::{not_impl_err, DataFusionError, JoinSide, Result, ScalarValue}; - -use chrono::{TimeZone, Utc}; use datafusion_expr::ScalarFunctionDefinition; -use object_store::path::Path; -use object_store::ObjectMeta; 
-use super::{DefaultPhysicalExtensionCodec, PhysicalExtensionCodec}; +use crate::common::proto_error; +use crate::convert_required; +use crate::logical_plan::{self, csv_writer_options_from_proto}; +use crate::protobuf::physical_expr_node::ExprType; +use crate::protobuf::{self, copy_to_node}; + +use super::PhysicalExtensionCodec; impl From<&protobuf::PhysicalColumn> for Column { fn from(c: &protobuf::PhysicalColumn) -> Column { @@ -130,24 +130,11 @@ pub fn parse_physical_sort_exprs( /// * `registry` - A registry knows how to build logical expressions out of user-defined function names /// * `input_schema` - The Arrow schema for the input, used for determining expression data types /// when performing type coercion. +/// * `codec` - An extension codec used to decode custom UDFs. pub fn parse_physical_window_expr( proto: &protobuf::PhysicalWindowExprNode, registry: &dyn FunctionRegistry, input_schema: &Schema, -) -> Result> { - parse_physical_window_expr_ext( - proto, - registry, - input_schema, - &DefaultPhysicalExtensionCodec {}, - ) -} - -// TODO: Make this the public function on next major release. -pub(crate) fn parse_physical_window_expr_ext( - proto: &protobuf::PhysicalWindowExprNode, - registry: &dyn FunctionRegistry, - input_schema: &Schema, codec: &dyn PhysicalExtensionCodec, ) -> Result> { let window_node_expr = @@ -463,20 +450,6 @@ pub fn parse_protobuf_hash_partitioning( partitioning: Option<&protobuf::PhysicalHashRepartition>, registry: &dyn FunctionRegistry, input_schema: &Schema, -) -> Result> { - parse_protobuf_hash_partitioning_ext( - partitioning, - registry, - input_schema, - &DefaultPhysicalExtensionCodec {}, - ) -} - -// TODO: Make this the public function on next major release. -fn parse_protobuf_hash_partitioning_ext( - partitioning: Option<&protobuf::PhysicalHashRepartition>, - registry: &dyn FunctionRegistry, - input_schema: &Schema, codec: &dyn PhysicalExtensionCodec, ) -> Result> { match partitioning { @@ -500,18 +473,6 @@ fn parse_protobuf_hash_partitioning_ext( pub fn parse_protobuf_file_scan_config( proto: &protobuf::FileScanExecConf, registry: &dyn FunctionRegistry, -) -> Result { - parse_protobuf_file_scan_config_ext( - proto, - registry, - &DefaultPhysicalExtensionCodec {}, - ) -} - -// TODO: Make this the public function on next major release. 
-pub(crate) fn parse_protobuf_file_scan_config_ext( - proto: &protobuf::FileScanExecConf, - registry: &dyn FunctionRegistry, codec: &dyn PhysicalExtensionCodec, ) -> Result { let schema: Arc = Arc::new(convert_required!(proto.schema)?); diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index b8c0354bfd6a..a481e7090fb3 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -69,8 +69,7 @@ use crate::common::{byte_to_string, proto_error, str_to_byte}; use crate::convert_required; use crate::physical_plan::from_proto::{ parse_physical_expr, parse_physical_sort_expr, parse_physical_sort_exprs, - parse_physical_window_expr_ext, parse_protobuf_file_scan_config, - parse_protobuf_file_scan_config_ext, + parse_physical_window_expr, parse_protobuf_file_scan_config, }; use crate::physical_plan::to_proto::{ serialize_file_scan_config, serialize_maybe_filter, serialize_physical_aggr_expr, @@ -190,7 +189,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { } } PhysicalPlanType::CsvScan(scan) => Ok(Arc::new(CsvExec::new( - parse_protobuf_file_scan_config_ext( + parse_protobuf_file_scan_config( scan.base_conf.as_ref().unwrap(), registry, extension_codec, @@ -213,6 +212,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { let base_config = parse_protobuf_file_scan_config( scan.base_conf.as_ref().unwrap(), registry, + extension_codec, )?; let predicate = scan .predicate @@ -233,13 +233,13 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { Default::default(), ))) } - PhysicalPlanType::AvroScan(scan) => Ok(Arc::new(AvroExec::new( - parse_protobuf_file_scan_config_ext( + PhysicalPlanType::AvroScan(scan) => { + Ok(Arc::new(AvroExec::new(parse_protobuf_file_scan_config( scan.base_conf.as_ref().unwrap(), registry, extension_codec, - )?, - ))), + )?))) + } PhysicalPlanType::CoalesceBatches(coalesce_batches) => { let input: Arc = into_physical_plan( &coalesce_batches.input, @@ -338,7 +338,7 @@ impl AsExecutionPlan for protobuf::PhysicalPlanNode { .window_expr .iter() .map(|window_expr| { - parse_physical_window_expr_ext( + parse_physical_window_expr( window_expr, registry, input_schema.as_ref(), diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 5728fe45d9bb..b4c23e4d0c3c 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -70,17 +70,9 @@ use crate::protobuf::{ PhysicalSortExprNodeCollection, ScalarValue, }; -use super::{DefaultPhysicalExtensionCodec, PhysicalExtensionCodec}; +use super::PhysicalExtensionCodec; -impl TryFrom> for protobuf::PhysicalExprNode { - type Error = DataFusionError; - - fn try_from(a: Arc) -> Result { - serialize_physical_aggr_expr(a, &DefaultPhysicalExtensionCodec {}) - } -} - -pub(crate) fn serialize_physical_aggr_expr( +pub fn serialize_physical_aggr_expr( aggr_expr: Arc, codec: &dyn PhysicalExtensionCodec, ) -> Result { @@ -123,15 +115,7 @@ pub(crate) fn serialize_physical_aggr_expr( }) } -impl TryFrom> for protobuf::PhysicalWindowExprNode { - type Error = DataFusionError; - - fn try_from(window_expr: Arc) -> Result { - serialize_physical_window_expr(window_expr, &DefaultPhysicalExtensionCodec {}) - } -} - -pub(crate) fn serialize_physical_window_expr( +pub fn serialize_physical_window_expr( window_expr: Arc, codec: &dyn PhysicalExtensionCodec, ) -> Result { @@ -454,7 +438,7 @@ pub fn serialize_physical_expr( .when_then_expr() .iter() 
.map(|(when_expr, then_expr)| { - try_parse_when_then_expr(when_expr, then_expr, codec) + serialize_when_then_expr(when_expr, then_expr, codec) }) .collect::, @@ -621,7 +605,7 @@ pub fn serialize_physical_expr( } } -fn try_parse_when_then_expr( +fn serialize_when_then_expr( when_expr: &Arc, then_expr: &Arc, codec: &dyn PhysicalExtensionCodec, @@ -744,15 +728,7 @@ impl From<&ColumnStatistics> for protobuf::ColumnStats { } } -impl TryFrom<&FileScanConfig> for protobuf::FileScanExecConf { - type Error = DataFusionError; - - fn try_from(conf: &FileScanConfig) -> Result { - serialize_file_scan_config(conf, &DefaultPhysicalExtensionCodec {}) - } -} - -pub(crate) fn serialize_file_scan_config( +pub fn serialize_file_scan_config( conf: &FileScanConfig, codec: &dyn PhysicalExtensionCodec, ) -> Result { @@ -815,15 +791,7 @@ impl From for protobuf::JoinSide { } } -impl TryFrom>> for protobuf::MaybeFilter { - type Error = DataFusionError; - - fn try_from(expr: Option>) -> Result { - serialize_maybe_filter(expr, &DefaultPhysicalExtensionCodec {}) - } -} - -pub(crate) fn serialize_maybe_filter( +pub fn serialize_maybe_filter( expr: Option>, codec: &dyn PhysicalExtensionCodec, ) -> Result { @@ -835,28 +803,6 @@ pub(crate) fn serialize_maybe_filter( } } -impl TryFrom>> for protobuf::MaybePhysicalSortExprs { - type Error = DataFusionError; - - fn try_from(sort_exprs: Option>) -> Result { - let codec = DefaultPhysicalExtensionCodec {}; - Ok(protobuf::MaybePhysicalSortExprs { - sort_expr: serialize_physical_sort_exprs( - sort_exprs.unwrap_or_default(), - &codec, - )?, - }) - } -} - -impl TryFrom for PhysicalSortExprNode { - type Error = DataFusionError; - - fn try_from(sort_expr: PhysicalSortExpr) -> Result { - serialize_physical_sort_expr(sort_expr, &DefaultPhysicalExtensionCodec {}) - } -} - impl TryFrom<&JsonSink> for protobuf::JsonSink { type Error = DataFusionError;
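Taken together, the series threads a single PhysicalExtensionCodec through every
serialize and parse entry point, so a custom UDF embedded in a plan can be decoded
through the codec rather than looked up in the function registry (patch 3 drops the
ctx.register_udf call from the round-trip test for exactly this reason). As a rough
illustration of the downstream side, here is a minimal codec sketch. It is not part
of the patch series: NamedUdfCodec is a hypothetical type, and it assumes the
try_encode_udf/try_decode_udf hooks (defaulted on the trait) that the
ScalarUDFExtensionCodec in the tests above relies on.

use std::sync::Arc;

use datafusion::execution::FunctionRegistry;
use datafusion::physical_plan::ExecutionPlan;
use datafusion_common::{internal_err, Result};
use datafusion_expr::ScalarUDF;
use datafusion_proto::physical_plan::PhysicalExtensionCodec;

/// Hypothetical codec that rebuilds one known UDF by name.
#[derive(Debug)]
struct NamedUdfCodec {
    udf: Arc<ScalarUDF>,
}

impl PhysicalExtensionCodec for NamedUdfCodec {
    // This codec carries no custom execution plan nodes, only a UDF.
    fn try_decode(
        &self,
        _buf: &[u8],
        _inputs: &[Arc<dyn ExecutionPlan>],
        _registry: &dyn FunctionRegistry,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        internal_err!("NamedUdfCodec does not decode plan nodes")
    }

    fn try_encode(
        &self,
        _node: Arc<dyn ExecutionPlan>,
        _buf: &mut Vec<u8>,
    ) -> Result<()> {
        internal_err!("NamedUdfCodec does not encode plan nodes")
    }

    // Assumed hooks: the name alone identifies the UDF, so nothing is
    // written to `buf` on encode and nothing is read back on decode.
    fn try_encode_udf(&self, _udf: &ScalarUDF, _buf: &mut Vec<u8>) -> Result<()> {
        Ok(())
    }

    fn try_decode_udf(&self, name: &str, _buf: &[u8]) -> Result<Arc<ScalarUDF>> {
        if name == self.udf.name() {
            Ok(Arc::clone(&self.udf))
        } else {
            internal_err!("NamedUdfCodec does not know UDF {name}")
        }
    }
}

Wiring it up mirrors roundtrip_test_and_return above: pass the codec to
protobuf::PhysicalPlanNode::try_from_physical_plan(plan, &codec) and
try_into_physical_plan(&ctx, runtime.deref(), &codec), or, at expression level, to
serialize_physical_expr(expr, &codec) and parse_physical_expr(&proto, &ctx, &schema, &codec).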