diff --git a/datafusion/core/src/physical_optimizer/projection_pushdown.rs b/datafusion/core/src/physical_optimizer/projection_pushdown.rs index e8f3bf01ecaa..08a78013b298 100644 --- a/datafusion/core/src/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/src/physical_optimizer/projection_pushdown.rs @@ -1287,6 +1287,7 @@ fn new_join_children( #[cfg(test)] mod tests { use super::*; + use std::any::Any; use std::sync::Arc; use crate::datasource::file_format::file_compression_type::FileCompressionType; @@ -1313,7 +1314,10 @@ mod tests { use datafusion_common::{JoinSide, JoinType, Result, ScalarValue, Statistics}; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; - use datafusion_expr::{ColumnarValue, Operator}; + use datafusion_expr::{ + ColumnarValue, Operator, ScalarFunctionDefinition, ScalarUDF, ScalarUDFImpl, + Signature, Volatility, + }; use datafusion_physical_expr::expressions::{ BinaryExpr, CaseExpr, CastExpr, Column, Literal, NegativeExpr, }; @@ -1329,6 +1333,42 @@ mod tests { use itertools::Itertools; + /// Mocked UDF + #[derive(Debug)] + struct DummyUDF { + signature: Signature, + } + + impl DummyUDF { + fn new() -> Self { + Self { + signature: Signature::variadic_any(Volatility::Immutable), + } + } + } + + impl ScalarUDFImpl for DummyUDF { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "dummy_udf" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Int32) + } + + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + unimplemented!("DummyUDF::invoke") + } + } + #[test] fn test_update_matching_exprs() -> Result<()> { let exprs: Vec> = vec![ @@ -1345,7 +1385,9 @@ mod tests { Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 4)))), Arc::new(ScalarFunctionExpr::new( "scalar_expr", - Arc::new(|_: &[ColumnarValue]| unimplemented!("not implemented")), + ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl( + DummyUDF::new(), + ))), vec![ Arc::new(BinaryExpr::new( Arc::new(Column::new("b", 1)), @@ -1412,7 +1454,9 @@ mod tests { Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 5)))), Arc::new(ScalarFunctionExpr::new( "scalar_expr", - Arc::new(|_: &[ColumnarValue]| unimplemented!("not implemented")), + ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl( + DummyUDF::new(), + ))), vec![ Arc::new(BinaryExpr::new( Arc::new(Column::new("b", 1)), @@ -1482,7 +1526,9 @@ mod tests { Arc::new(NegativeExpr::new(Arc::new(Column::new("f", 4)))), Arc::new(ScalarFunctionExpr::new( "scalar_expr", - Arc::new(|_: &[ColumnarValue]| unimplemented!("not implemented")), + ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl( + DummyUDF::new(), + ))), vec![ Arc::new(BinaryExpr::new( Arc::new(Column::new("b", 1)), @@ -1549,7 +1595,9 @@ mod tests { Arc::new(NegativeExpr::new(Arc::new(Column::new("f_new", 5)))), Arc::new(ScalarFunctionExpr::new( "scalar_expr", - Arc::new(|_: &[ColumnarValue]| unimplemented!("not implemented")), + ScalarFunctionDefinition::UDF(Arc::new(ScalarUDF::new_from_impl( + DummyUDF::new(), + ))), vec![ Arc::new(BinaryExpr::new( Arc::new(Column::new("b_new", 1)), diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index ff653192c02a..3c1c0663235a 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -44,6 +44,7 @@ use arrow_array::Array; use datafusion_common::{exec_err, Result, ScalarValue}; use datafusion_expr::execution_props::ExecutionProps; pub use datafusion_expr::FuncMonotonicity; +use datafusion_expr::ScalarFunctionDefinition; use datafusion_expr::{ type_coercion::functions::data_types, BuiltinScalarFunction, ColumnarValue, ScalarFunctionImplementation, @@ -57,7 +58,7 @@ pub fn create_physical_expr( fun: &BuiltinScalarFunction, input_phy_exprs: &[Arc], input_schema: &Schema, - execution_props: &ExecutionProps, + _execution_props: &ExecutionProps, ) -> Result> { let input_expr_types = input_phy_exprs .iter() @@ -69,14 +70,12 @@ pub fn create_physical_expr( let data_type = fun.return_type(&input_expr_types)?; - let fun_expr: ScalarFunctionImplementation = - create_physical_fun(fun, execution_props)?; - let monotonicity = fun.monotonicity(); + let fun_def = ScalarFunctionDefinition::BuiltIn(*fun); Ok(Arc::new(ScalarFunctionExpr::new( &format!("{fun}"), - fun_expr, + fun_def, input_phy_exprs.to_vec(), data_type, monotonicity, @@ -195,7 +194,6 @@ where /// Create a physical scalar function. pub fn create_physical_fun( fun: &BuiltinScalarFunction, - _execution_props: &ExecutionProps, ) -> Result { Ok(match fun { // math functions diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 1c9f0e609c3c..d34084236690 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -34,22 +34,22 @@ use std::fmt::{self, Debug, Formatter}; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use crate::functions::out_ordering; +use crate::functions::{create_physical_fun, out_ordering}; use crate::physical_expr::{down_cast_any_ref, physical_exprs_equal}; use crate::sort_properties::SortProperties; use crate::PhysicalExpr; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; -use datafusion_common::Result; +use datafusion_common::{internal_err, Result}; use datafusion_expr::{ expr_vec_fmt, BuiltinScalarFunction, ColumnarValue, FuncMonotonicity, - ScalarFunctionImplementation, + ScalarFunctionDefinition, }; /// Physical expression of a scalar function pub struct ScalarFunctionExpr { - fun: ScalarFunctionImplementation, + fun: ScalarFunctionDefinition, name: String, args: Vec>, return_type: DataType, @@ -79,7 +79,7 @@ impl ScalarFunctionExpr { /// Create a new Scalar function pub fn new( name: &str, - fun: ScalarFunctionImplementation, + fun: ScalarFunctionDefinition, args: Vec>, return_type: DataType, monotonicity: Option, @@ -96,7 +96,7 @@ impl ScalarFunctionExpr { } /// Get the scalar function implementation - pub fn fun(&self) -> &ScalarFunctionImplementation { + pub fn fun(&self) -> &ScalarFunctionDefinition { &self.fun } @@ -172,8 +172,18 @@ impl PhysicalExpr for ScalarFunctionExpr { }; // evaluate the function - let fun = self.fun.as_ref(); - (fun)(&inputs) + match self.fun { + ScalarFunctionDefinition::BuiltIn(ref fun) => { + let fun = create_physical_fun(fun)?; + (fun)(&inputs) + } + ScalarFunctionDefinition::UDF(ref fun) => fun.invoke(&inputs), + ScalarFunctionDefinition::Name(_) => { + internal_err!( + "Name function must be resolved to one of the other variants prior to physical planning" + ) + } + } } fn children(&self) -> Vec> { diff --git a/datafusion/physical-expr/src/udf.rs b/datafusion/physical-expr/src/udf.rs index ede3e5badbb1..4fc94bfa15ec 100644 --- a/datafusion/physical-expr/src/udf.rs +++ b/datafusion/physical-expr/src/udf.rs @@ -20,7 +20,9 @@ use crate::{PhysicalExpr, ScalarFunctionExpr}; use arrow_schema::Schema; use datafusion_common::{DFSchema, Result}; pub use datafusion_expr::ScalarUDF; -use datafusion_expr::{type_coercion::functions::data_types, Expr}; +use datafusion_expr::{ + type_coercion::functions::data_types, Expr, ScalarFunctionDefinition, +}; use std::sync::Arc; /// Create a physical expression of the UDF. @@ -45,9 +47,10 @@ pub fn create_physical_expr( let return_type = fun.return_type_from_exprs(args, input_dfschema, &input_expr_types)?; + let fun_def = ScalarFunctionDefinition::UDF(Arc::new(fun.clone())); Ok(Arc::new(ScalarFunctionExpr::new( fun.name(), - fun.fun(), + fun_def, input_phy_exprs.to_vec(), return_type, fun.monotonicity()?, diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index b5683dc1425e..1d25463cdb85 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -1458,6 +1458,7 @@ message PhysicalExprNode { message PhysicalScalarUdfNode { string name = 1; repeated PhysicalExprNode args = 2; + optional bytes fun_definition = 3; ArrowType return_type = 4; } diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index f5be49dc9de7..d93a49be8e1e 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -20391,6 +20391,9 @@ impl serde::Serialize for PhysicalScalarUdfNode { if !self.args.is_empty() { len += 1; } + if self.fun_definition.is_some() { + len += 1; + } if self.return_type.is_some() { len += 1; } @@ -20401,6 +20404,10 @@ impl serde::Serialize for PhysicalScalarUdfNode { if !self.args.is_empty() { struct_ser.serialize_field("args", &self.args)?; } + if let Some(v) = self.fun_definition.as_ref() { + #[allow(clippy::needless_borrow)] + struct_ser.serialize_field("funDefinition", pbjson::private::base64::encode(&v).as_str())?; + } if let Some(v) = self.return_type.as_ref() { struct_ser.serialize_field("returnType", v)?; } @@ -20416,6 +20423,8 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { const FIELDS: &[&str] = &[ "name", "args", + "fun_definition", + "funDefinition", "return_type", "returnType", ]; @@ -20424,6 +20433,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { enum GeneratedField { Name, Args, + FunDefinition, ReturnType, } impl<'de> serde::Deserialize<'de> for GeneratedField { @@ -20448,6 +20458,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { match value { "name" => Ok(GeneratedField::Name), "args" => Ok(GeneratedField::Args), + "funDefinition" | "fun_definition" => Ok(GeneratedField::FunDefinition), "returnType" | "return_type" => Ok(GeneratedField::ReturnType), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } @@ -20470,6 +20481,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { { let mut name__ = None; let mut args__ = None; + let mut fun_definition__ = None; let mut return_type__ = None; while let Some(k) = map_.next_key()? { match k { @@ -20485,6 +20497,14 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { } args__ = Some(map_.next_value()?); } + GeneratedField::FunDefinition => { + if fun_definition__.is_some() { + return Err(serde::de::Error::duplicate_field("funDefinition")); + } + fun_definition__ = + map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x| x.0) + ; + } GeneratedField::ReturnType => { if return_type__.is_some() { return Err(serde::de::Error::duplicate_field("returnType")); @@ -20496,6 +20516,7 @@ impl<'de> serde::Deserialize<'de> for PhysicalScalarUdfNode { Ok(PhysicalScalarUdfNode { name: name__.unwrap_or_default(), args: args__.unwrap_or_default(), + fun_definition: fun_definition__, return_type: return_type__, }) } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index e1c9af105bbd..8b025028dc6b 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2092,6 +2092,8 @@ pub struct PhysicalScalarUdfNode { pub name: ::prost::alloc::string::String, #[prost(message, repeated, tag = "2")] pub args: ::prost::alloc::vec::Vec, + #[prost(bytes = "vec", optional, tag = "3")] + pub fun_definition: ::core::option::Option<::prost::alloc::vec::Vec>, #[prost(message, optional, tag = "4")] pub return_type: ::core::option::Option, } diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index 184c048c1bdd..ca54d4e803ca 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -59,9 +59,12 @@ use datafusion_common::stats::Precision; use datafusion_common::{not_impl_err, DataFusionError, JoinSide, Result, ScalarValue}; use chrono::{TimeZone, Utc}; +use datafusion_expr::ScalarFunctionDefinition; use object_store::path::Path; use object_store::ObjectMeta; +use super::{DefaultPhysicalExtensionCodec, PhysicalExtensionCodec}; + impl From<&protobuf::PhysicalColumn> for Column { fn from(c: &protobuf::PhysicalColumn) -> Column { Column::new(&c.name, c.index as usize) @@ -82,7 +85,8 @@ pub fn parse_physical_sort_expr( input_schema: &Schema, ) -> Result { if let Some(expr) = &proto.expr { - let expr = parse_physical_expr(expr.as_ref(), registry, input_schema)?; + let codec = DefaultPhysicalExtensionCodec {}; + let expr = parse_physical_expr(expr.as_ref(), registry, input_schema, &codec)?; let options = SortOptions { descending: !proto.asc, nulls_first: proto.nulls_first, @@ -110,7 +114,9 @@ pub fn parse_physical_sort_exprs( .iter() .map(|sort_expr| { if let Some(expr) = &sort_expr.expr { - let expr = parse_physical_expr(expr.as_ref(), registry, input_schema)?; + let codec = DefaultPhysicalExtensionCodec {}; + let expr = + parse_physical_expr(expr.as_ref(), registry, input_schema, &codec)?; let options = SortOptions { descending: !sort_expr.asc, nulls_first: sort_expr.nulls_first, @@ -137,16 +143,17 @@ pub fn parse_physical_window_expr( registry: &dyn FunctionRegistry, input_schema: &Schema, ) -> Result> { + let codec = DefaultPhysicalExtensionCodec {}; let window_node_expr = proto .args .iter() - .map(|e| parse_physical_expr(e, registry, input_schema)) + .map(|e| parse_physical_expr(e, registry, input_schema, &codec)) .collect::>>()?; let partition_by = proto .partition_by .iter() - .map(|p| parse_physical_expr(p, registry, input_schema)) + .map(|p| parse_physical_expr(p, registry, input_schema, &codec)) .collect::>>()?; let order_by = proto @@ -191,6 +198,7 @@ pub fn parse_physical_expr( proto: &protobuf::PhysicalExprNode, registry: &dyn FunctionRegistry, input_schema: &Schema, + codec: &dyn PhysicalExtensionCodec, ) -> Result> { let expr_type = proto .expr_type @@ -270,7 +278,7 @@ pub fn parse_physical_expr( )?, e.list .iter() - .map(|x| parse_physical_expr(x, registry, input_schema)) + .map(|x| parse_physical_expr(x, registry, input_schema, codec)) .collect::, _>>()?, &e.negated, input_schema, @@ -278,7 +286,7 @@ pub fn parse_physical_expr( ExprType::Case(e) => Arc::new(CaseExpr::try_new( e.expr .as_ref() - .map(|e| parse_physical_expr(e.as_ref(), registry, input_schema)) + .map(|e| parse_physical_expr(e.as_ref(), registry, input_schema, codec)) .transpose()?, e.when_then_expr .iter() @@ -301,7 +309,7 @@ pub fn parse_physical_expr( .collect::>>()?, e.else_expr .as_ref() - .map(|e| parse_physical_expr(e.as_ref(), registry, input_schema)) + .map(|e| parse_physical_expr(e.as_ref(), registry, input_schema, codec)) .transpose()?, )?), ExprType::Cast(e) => Arc::new(CastExpr::new( @@ -334,7 +342,7 @@ pub fn parse_physical_expr( let args = e .args .iter() - .map(|x| parse_physical_expr(x, registry, input_schema)) + .map(|x| parse_physical_expr(x, registry, input_schema, codec)) .collect::, _>>()?; // TODO Do not create new the ExecutionProps @@ -348,19 +356,22 @@ pub fn parse_physical_expr( )? } ExprType::ScalarUdf(e) => { - let udf = registry.udf(e.name.as_str())?; + let udf = match &e.fun_definition { + Some(buf) => codec.try_decode_udf(&e.name, buf)?, + None => registry.udf(e.name.as_str())?, + }; let signature = udf.signature(); - let scalar_fun = udf.fun().clone(); + let scalar_fun_def = ScalarFunctionDefinition::UDF(udf.clone()); let args = e .args .iter() - .map(|x| parse_physical_expr(x, registry, input_schema)) + .map(|x| parse_physical_expr(x, registry, input_schema, codec)) .collect::, _>>()?; Arc::new(ScalarFunctionExpr::new( e.name.as_str(), - scalar_fun, + scalar_fun_def, args, convert_required!(e.return_type)?, None, @@ -394,7 +405,8 @@ fn parse_required_physical_expr( field: &str, input_schema: &Schema, ) -> Result> { - expr.map(|e| parse_physical_expr(e, registry, input_schema)) + let codec = DefaultPhysicalExtensionCodec {}; + expr.map(|e| parse_physical_expr(e, registry, input_schema, &codec)) .transpose()? .ok_or_else(|| { DataFusionError::Internal(format!("Missing required field {field:?}")) @@ -439,10 +451,11 @@ pub fn parse_protobuf_hash_partitioning( ) -> Result> { match partitioning { Some(hash_part) => { + let codec = DefaultPhysicalExtensionCodec {}; let expr = hash_part .hash_expr .iter() - .map(|e| parse_physical_expr(e, registry, input_schema)) + .map(|e| parse_physical_expr(e, registry, input_schema, &codec)) .collect::>, _>>()?; Ok(Some(Partitioning::Hash( @@ -503,6 +516,7 @@ pub fn parse_protobuf_file_scan_config( let mut output_ordering = vec![]; for node_collection in &proto.output_ordering { + let codec = DefaultPhysicalExtensionCodec {}; let sort_expr = node_collection .physical_sort_expr_nodes .iter() @@ -510,7 +524,7 @@ pub fn parse_protobuf_file_scan_config( let expr = node .expr .as_ref() - .map(|e| parse_physical_expr(e.as_ref(), registry, &schema)) + .map(|e| parse_physical_expr(e.as_ref(), registry, &schema, &codec)) .unwrap()?; Ok(PhysicalSortExpr { expr, diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 004948da938f..da31c5e762bc 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -20,6 +20,7 @@ use std::fmt::Debug; use std::sync::Arc; use self::from_proto::parse_physical_window_expr; +use self::to_proto::serialize_physical_expr; use crate::common::{byte_to_string, proto_error, str_to_byte}; use crate::convert_required; @@ -138,7 +139,12 @@ impl AsExecutionPlan for PhysicalPlanNode { .zip(projection.expr_name.iter()) .map(|(expr, name)| { Ok(( - parse_physical_expr(expr, registry, input.schema().as_ref())?, + parse_physical_expr( + expr, + registry, + input.schema().as_ref(), + extension_codec, + )?, name.to_string(), )) }) @@ -156,7 +162,12 @@ impl AsExecutionPlan for PhysicalPlanNode { .expr .as_ref() .map(|expr| { - parse_physical_expr(expr, registry, input.schema().as_ref()) + parse_physical_expr( + expr, + registry, + input.schema().as_ref(), + extension_codec, + ) }) .transpose()? .ok_or_else(|| { @@ -208,6 +219,7 @@ impl AsExecutionPlan for PhysicalPlanNode { expr, registry, base_config.file_schema.as_ref(), + extension_codec, ) }) .transpose()?; @@ -254,7 +266,12 @@ impl AsExecutionPlan for PhysicalPlanNode { .hash_expr .iter() .map(|e| { - parse_physical_expr(e, registry, input.schema().as_ref()) + parse_physical_expr( + e, + registry, + input.schema().as_ref(), + extension_codec, + ) }) .collect::>, _>>()?; @@ -329,7 +346,12 @@ impl AsExecutionPlan for PhysicalPlanNode { .partition_keys .iter() .map(|expr| { - parse_physical_expr(expr, registry, input.schema().as_ref()) + parse_physical_expr( + expr, + registry, + input.schema().as_ref(), + extension_codec, + ) }) .collect::>>>()?; @@ -396,8 +418,13 @@ impl AsExecutionPlan for PhysicalPlanNode { .iter() .zip(hash_agg.group_expr_name.iter()) .map(|(expr, name)| { - parse_physical_expr(expr, registry, input.schema().as_ref()) - .map(|expr| (expr, name.to_string())) + parse_physical_expr( + expr, + registry, + input.schema().as_ref(), + extension_codec, + ) + .map(|expr| (expr, name.to_string())) }) .collect::, _>>()?; @@ -406,8 +433,13 @@ impl AsExecutionPlan for PhysicalPlanNode { .iter() .zip(hash_agg.group_expr_name.iter()) .map(|(expr, name)| { - parse_physical_expr(expr, registry, input.schema().as_ref()) - .map(|expr| (expr, name.to_string())) + parse_physical_expr( + expr, + registry, + input.schema().as_ref(), + extension_codec, + ) + .map(|expr| (expr, name.to_string())) }) .collect::, _>>()?; @@ -434,7 +466,14 @@ impl AsExecutionPlan for PhysicalPlanNode { .map(|expr| { expr.expr .as_ref() - .map(|e| parse_physical_expr(e, registry, &physical_schema)) + .map(|e| { + parse_physical_expr( + e, + registry, + &physical_schema, + extension_codec, + ) + }) .transpose() }) .collect::, _>>()?; @@ -451,7 +490,7 @@ impl AsExecutionPlan for PhysicalPlanNode { match expr_type { ExprType::AggregateExpr(agg_node) => { let input_phy_expr: Vec> = agg_node.expr.iter() - .map(|e| parse_physical_expr(e, registry, &physical_schema).unwrap()).collect(); + .map(|e| parse_physical_expr(e, registry, &physical_schema, extension_codec).unwrap()).collect(); let ordering_req: Vec = agg_node.ordering_req.iter() .map(|e| parse_physical_sort_expr(e, registry, &physical_schema).unwrap()).collect(); agg_node.aggregate_function.as_ref().map(|func| { @@ -524,11 +563,13 @@ impl AsExecutionPlan for PhysicalPlanNode { &col.left.clone().unwrap(), registry, left_schema.as_ref(), + extension_codec, )?; let right = parse_physical_expr( &col.right.clone().unwrap(), registry, right_schema.as_ref(), + extension_codec, )?; Ok((left, right)) }) @@ -555,6 +596,7 @@ impl AsExecutionPlan for PhysicalPlanNode { proto_error("Unexpected empty filter expression") })?, registry, &schema, + extension_codec, )?; let column_indices = f.column_indices .iter() @@ -635,11 +677,13 @@ impl AsExecutionPlan for PhysicalPlanNode { &col.left.clone().unwrap(), registry, left_schema.as_ref(), + extension_codec, )?; let right = parse_physical_expr( &col.right.clone().unwrap(), registry, right_schema.as_ref(), + extension_codec, )?; Ok((left, right)) }) @@ -666,6 +710,7 @@ impl AsExecutionPlan for PhysicalPlanNode { proto_error("Unexpected empty filter expression") })?, registry, &schema, + extension_codec, )?; let column_indices = f.column_indices .iter() @@ -805,7 +850,7 @@ impl AsExecutionPlan for PhysicalPlanNode { })? .as_ref(); Ok(PhysicalSortExpr { - expr: parse_physical_expr(expr, registry, input.schema().as_ref())?, + expr: parse_physical_expr(expr, registry, input.schema().as_ref(), extension_codec)?, options: SortOptions { descending: !sort_expr.asc, nulls_first: sort_expr.nulls_first, @@ -852,7 +897,7 @@ impl AsExecutionPlan for PhysicalPlanNode { })? .as_ref(); Ok(PhysicalSortExpr { - expr: parse_physical_expr(expr, registry, input.schema().as_ref())?, + expr: parse_physical_expr(expr, registry, input.schema().as_ref(), extension_codec)?, options: SortOptions { descending: !sort_expr.asc, nulls_first: sort_expr.nulls_first, @@ -916,6 +961,7 @@ impl AsExecutionPlan for PhysicalPlanNode { proto_error("Unexpected empty filter expression") })?, registry, &schema, + extension_codec, )?; let column_indices = f.column_indices .iter() @@ -1088,7 +1134,7 @@ impl AsExecutionPlan for PhysicalPlanNode { let expr = exec .expr() .iter() - .map(|expr| expr.0.clone().try_into()) + .map(|expr| serialize_physical_expr(expr.0.clone(), extension_codec)) .collect::>>()?; let expr_name = exec.expr().iter().map(|expr| expr.1.clone()).collect(); return Ok(protobuf::PhysicalPlanNode { @@ -1128,7 +1174,10 @@ impl AsExecutionPlan for PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::Filter(Box::new( protobuf::FilterExecNode { input: Some(Box::new(input)), - expr: Some(exec.predicate().clone().try_into()?), + expr: Some(serialize_physical_expr( + exec.predicate().clone(), + extension_codec, + )?), default_filter_selectivity: exec.default_selectivity() as u32, }, ))), @@ -1183,8 +1232,8 @@ impl AsExecutionPlan for PhysicalPlanNode { .on() .iter() .map(|tuple| { - let l = tuple.0.to_owned().try_into()?; - let r = tuple.1.to_owned().try_into()?; + let l = serialize_physical_expr(tuple.0.to_owned(), extension_codec)?; + let r = serialize_physical_expr(tuple.1.to_owned(), extension_codec)?; Ok::<_, DataFusionError>(protobuf::JoinOn { left: Some(l), right: Some(r), @@ -1196,7 +1245,10 @@ impl AsExecutionPlan for PhysicalPlanNode { .filter() .as_ref() .map(|f| { - let expression = f.expression().to_owned().try_into()?; + let expression = serialize_physical_expr( + f.expression().to_owned(), + extension_codec, + )?; let column_indices = f .column_indices() .iter() @@ -1254,8 +1306,8 @@ impl AsExecutionPlan for PhysicalPlanNode { .on() .iter() .map(|tuple| { - let l = tuple.0.to_owned().try_into()?; - let r = tuple.1.to_owned().try_into()?; + let l = serialize_physical_expr(tuple.0.to_owned(), extension_codec)?; + let r = serialize_physical_expr(tuple.1.to_owned(), extension_codec)?; Ok::<_, DataFusionError>(protobuf::JoinOn { left: Some(l), right: Some(r), @@ -1267,7 +1319,10 @@ impl AsExecutionPlan for PhysicalPlanNode { .filter() .as_ref() .map(|f| { - let expression = f.expression().to_owned().try_into()?; + let expression = serialize_physical_expr( + f.expression().to_owned(), + extension_codec, + )?; let column_indices = f .column_indices() .iter() @@ -1304,7 +1359,10 @@ impl AsExecutionPlan for PhysicalPlanNode { .iter() .map(|expr| { Ok(protobuf::PhysicalSortExprNode { - expr: Some(Box::new(expr.expr.to_owned().try_into()?)), + expr: Some(Box::new(serialize_physical_expr( + expr.expr.to_owned(), + extension_codec, + )?)), asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }) @@ -1321,7 +1379,10 @@ impl AsExecutionPlan for PhysicalPlanNode { .iter() .map(|expr| { Ok(protobuf::PhysicalSortExprNode { - expr: Some(Box::new(expr.expr.to_owned().try_into()?)), + expr: Some(Box::new(serialize_physical_expr( + expr.expr.to_owned(), + extension_codec, + )?)), asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }) @@ -1423,14 +1484,14 @@ impl AsExecutionPlan for PhysicalPlanNode { .group_expr() .null_expr() .iter() - .map(|expr| expr.0.to_owned().try_into()) + .map(|expr| serialize_physical_expr(expr.0.to_owned(), extension_codec)) .collect::>>()?; let group_expr = exec .group_expr() .expr() .iter() - .map(|expr| expr.0.to_owned().try_into()) + .map(|expr| serialize_physical_expr(expr.0.to_owned(), extension_codec)) .collect::>>()?; return Ok(protobuf::PhysicalPlanNode { @@ -1512,7 +1573,7 @@ impl AsExecutionPlan for PhysicalPlanNode { if let Some(exec) = plan.downcast_ref::() { let predicate = exec .predicate() - .map(|pred| pred.clone().try_into()) + .map(|pred| serialize_physical_expr(pred.clone(), extension_codec)) .transpose()?; return Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::ParquetScan( @@ -1559,7 +1620,9 @@ impl AsExecutionPlan for PhysicalPlanNode { PartitionMethod::Hash(protobuf::PhysicalHashRepartition { hash_expr: exprs .iter() - .map(|expr| expr.clone().try_into()) + .map(|expr| { + serialize_physical_expr(expr.clone(), extension_codec) + }) .collect::>>()?, partition_count: *partition_count as u64, }) @@ -1592,7 +1655,10 @@ impl AsExecutionPlan for PhysicalPlanNode { .iter() .map(|expr| { let sort_expr = Box::new(protobuf::PhysicalSortExprNode { - expr: Some(Box::new(expr.expr.to_owned().try_into()?)), + expr: Some(Box::new(serialize_physical_expr( + expr.expr.to_owned(), + extension_codec, + )?)), asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }); @@ -1658,7 +1724,10 @@ impl AsExecutionPlan for PhysicalPlanNode { .iter() .map(|expr| { let sort_expr = Box::new(protobuf::PhysicalSortExprNode { - expr: Some(Box::new(expr.expr.to_owned().try_into()?)), + expr: Some(Box::new(serialize_physical_expr( + expr.expr.to_owned(), + extension_codec, + )?)), asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }); @@ -1695,7 +1764,10 @@ impl AsExecutionPlan for PhysicalPlanNode { .filter() .as_ref() .map(|f| { - let expression = f.expression().to_owned().try_into()?; + let expression = serialize_physical_expr( + f.expression().to_owned(), + extension_codec, + )?; let column_indices = f .column_indices() .iter() @@ -1743,7 +1815,7 @@ impl AsExecutionPlan for PhysicalPlanNode { let partition_keys = exec .partition_keys .iter() - .map(|e| e.clone().try_into()) + .map(|e| serialize_physical_expr(e.clone(), extension_codec)) .collect::>>()?; return Ok(protobuf::PhysicalPlanNode { @@ -1773,7 +1845,7 @@ impl AsExecutionPlan for PhysicalPlanNode { let partition_keys = exec .partition_keys .iter() - .map(|e| e.clone().try_into()) + .map(|e| serialize_physical_expr(e.clone(), extension_codec)) .collect::>>()?; let input_order_mode = match &exec.input_order_mode { @@ -1816,7 +1888,10 @@ impl AsExecutionPlan for PhysicalPlanNode { .map(|requirement| { let expr: PhysicalSortExpr = requirement.to_owned().into(); let sort_expr = protobuf::PhysicalSortExprNode { - expr: Some(Box::new(expr.expr.to_owned().try_into()?)), + expr: Some(Box::new(serialize_physical_expr( + expr.expr.to_owned(), + extension_codec, + )?)), asc: !expr.options.descending, nulls_first: expr.options.nulls_first, }; diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index ba77b30b7f8d..b66709d0c5bd 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -22,7 +22,6 @@ use std::{ sync::Arc, }; -use crate::logical_plan::csv_writer_options_to_proto; use crate::protobuf::{ self, copy_to_node, physical_aggregate_expr_node, physical_window_expr_node, scalar_value::Value, ArrowOptions, AvroOptions, PhysicalSortExprNode, @@ -31,13 +30,10 @@ use crate::protobuf::{ #[cfg(feature = "parquet")] use datafusion::datasource::file_format::parquet::ParquetSink; -use datafusion::datasource::{ - file_format::csv::CsvSink, - file_format::json::JsonSink, - listing::{FileRange, PartitionedFile}, - physical_plan::FileScanConfig, - physical_plan::FileSinkConfig, -}; + +use datafusion_expr::ScalarFunctionDefinition; + +use crate::logical_plan::csv_writer_options_to_proto; use datafusion::logical_expr::BuiltinScalarFunction; use datafusion::physical_expr::window::{NthValueKind, SlidingAggregateWindowExpr}; use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; @@ -46,16 +42,24 @@ use datafusion::physical_plan::expressions::{ ArrayAgg, Avg, BinaryExpr, BitAnd, BitOr, BitXor, BoolAnd, BoolOr, CaseExpr, CastExpr, Column, Correlation, Count, Covariance, CovariancePop, CumeDist, DistinctArrayAgg, DistinctBitXor, DistinctCount, DistinctSum, FirstValue, Grouping, - InListExpr, IsNotNullExpr, IsNullExpr, LastValue, LikeExpr, Literal, Max, Median, - Min, NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, - Rank, RankType, Regr, RegrType, RowNumber, Stddev, StddevPop, StringAgg, Sum, - TryCastExpr, Variance, VariancePop, WindowShift, + InListExpr, IsNotNullExpr, IsNullExpr, LastValue, Literal, Max, Median, Min, + NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, + RankType, Regr, RegrType, RowNumber, Stddev, StddevPop, StringAgg, Sum, TryCastExpr, + Variance, VariancePop, WindowShift, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; use datafusion::physical_plan::{ AggregateExpr, ColumnStatistics, PhysicalExpr, Statistics, WindowExpr, }; +use datafusion::{ + datasource::{ + file_format::{csv::CsvSink, json::JsonSink}, + listing::{FileRange, PartitionedFile}, + physical_plan::{FileScanConfig, FileSinkConfig}, + }, + physical_plan::expressions::LikeExpr, +}; use datafusion_common::config::{ ColumnOptions, CsvOptions, FormatOptions, JsonOptions, ParquetOptions, TableParquetOptions, @@ -68,14 +72,17 @@ use datafusion_common::{ DataFusionError, JoinSide, Result, }; +use super::{DefaultPhysicalExtensionCodec, PhysicalExtensionCodec}; + impl TryFrom> for protobuf::PhysicalExprNode { type Error = DataFusionError; fn try_from(a: Arc) -> Result { + let codec = DefaultPhysicalExtensionCodec {}; let expressions: Vec = a .expressions() .iter() - .map(|e| e.clone().try_into()) + .map(|e| serialize_physical_expr(e.clone(), &codec)) .collect::>>()?; let ordering_req: Vec = a @@ -237,16 +244,16 @@ impl TryFrom> for protobuf::PhysicalWindowExprNode { } else { return not_impl_err!("WindowExpr not supported: {window_expr:?}"); }; - + let codec = DefaultPhysicalExtensionCodec {}; let args = args .into_iter() - .map(|e| e.try_into()) + .map(|e| serialize_physical_expr(e, &codec)) .collect::>>()?; let partition_by = window_expr .partition_by() .iter() - .map(|p| p.clone().try_into()) + .map(|p| serialize_physical_expr(p.clone(), &codec)) .collect::>>()?; let order_by = window_expr @@ -374,195 +381,250 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { Ok(AggrFn { inner, distinct }) } -impl TryFrom> for protobuf::PhysicalExprNode { - type Error = DataFusionError; - - fn try_from(value: Arc) -> Result { - let expr = value.as_any(); - - if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::Column( - protobuf::PhysicalColumn { - name: expr.name().to_string(), - index: expr.index() as u32, - }, - )), - }) - } else if let Some(expr) = expr.downcast_ref::() { - let binary_expr = Box::new(protobuf::PhysicalBinaryExprNode { - l: Some(Box::new(expr.left().to_owned().try_into()?)), - r: Some(Box::new(expr.right().to_owned().try_into()?)), - op: format!("{:?}", expr.op()), - }); +/// Serialize a `PhysicalExpr` to default protobuf representation. +/// +/// If required, a [`PhysicalExtensionCodec`] can be provided which can handle +/// serialization of udfs requiring specialized serialization (see [`PhysicalExtensionCodec::try_encode_udf`]) +pub fn serialize_physical_expr( + value: Arc, + codec: &dyn PhysicalExtensionCodec, +) -> Result { + let expr = value.as_any(); + + if let Some(expr) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::Column( + protobuf::PhysicalColumn { + name: expr.name().to_string(), + index: expr.index() as u32, + }, + )), + }) + } else if let Some(expr) = expr.downcast_ref::() { + let binary_expr = Box::new(protobuf::PhysicalBinaryExprNode { + l: Some(Box::new(serialize_physical_expr( + expr.left().clone(), + codec, + )?)), + r: Some(Box::new(serialize_physical_expr( + expr.right().clone(), + codec, + )?)), + op: format!("{:?}", expr.op()), + }); - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::BinaryExpr( - binary_expr, - )), - }) - } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some( - protobuf::physical_expr_node::ExprType::Case( - Box::new( - protobuf::PhysicalCaseNode { - expr: expr - .expr() - .map(|exp| exp.clone().try_into().map(Box::new)) - .transpose()?, - when_then_expr: expr - .when_then_expr() - .iter() - .map(|(when_expr, then_expr)| { - try_parse_when_then_expr(when_expr, then_expr) - }) - .collect::, - Self::Error, - >>()?, - else_expr: expr - .else_expr() - .map(|a| a.clone().try_into().map(Box::new)) - .transpose()?, - }, - ), + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::BinaryExpr( + binary_expr, + )), + }) + } else if let Some(expr) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some( + protobuf::physical_expr_node::ExprType::Case( + Box::new( + protobuf::PhysicalCaseNode { + expr: expr + .expr() + .map(|exp| { + serialize_physical_expr(exp.clone(), codec) + .map(Box::new) + }) + .transpose()?, + when_then_expr: expr + .when_then_expr() + .iter() + .map(|(when_expr, then_expr)| { + try_parse_when_then_expr(when_expr, then_expr, codec) + }) + .collect::, + DataFusionError, + >>()?, + else_expr: expr + .else_expr() + .map(|a| { + serialize_physical_expr(a.clone(), codec) + .map(Box::new) + }) + .transpose()?, + }, ), ), - }) - } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::NotExpr( - Box::new(protobuf::PhysicalNot { - expr: Some(Box::new(expr.arg().to_owned().try_into()?)), - }), - )), - }) - } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::IsNullExpr( - Box::new(protobuf::PhysicalIsNull { - expr: Some(Box::new(expr.arg().to_owned().try_into()?)), - }), - )), - }) - } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::IsNotNullExpr( - Box::new(protobuf::PhysicalIsNotNull { - expr: Some(Box::new(expr.arg().to_owned().try_into()?)), - }), - )), - }) - } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some( - protobuf::physical_expr_node::ExprType::InList( - Box::new( - protobuf::PhysicalInListNode { - expr: Some(Box::new(expr.expr().to_owned().try_into()?)), - list: expr - .list() - .iter() - .map(|a| a.clone().try_into()) - .collect::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::NotExpr(Box::new( + protobuf::PhysicalNot { + expr: Some(Box::new(serialize_physical_expr( + expr.arg().to_owned(), + codec, + )?)), + }, + ))), + }) + } else if let Some(expr) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::IsNullExpr( + Box::new(protobuf::PhysicalIsNull { + expr: Some(Box::new(serialize_physical_expr( + expr.arg().to_owned(), + codec, + )?)), + }), + )), + }) + } else if let Some(expr) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::IsNotNullExpr( + Box::new(protobuf::PhysicalIsNotNull { + expr: Some(Box::new(serialize_physical_expr( + expr.arg().to_owned(), + codec, + )?)), + }), + )), + }) + } else if let Some(expr) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some( + protobuf::physical_expr_node::ExprType::InList( + Box::new( + protobuf::PhysicalInListNode { + expr: Some(Box::new(serialize_physical_expr( + expr.expr().to_owned(), + codec, + )?)), + list: expr + .list() + .iter() + .map(|a| serialize_physical_expr(a.clone(), codec)) + .collect::, - Self::Error, + DataFusionError, >>()?, - negated: expr.negated(), - }, - ), + negated: expr.negated(), + }, ), ), - }) - } else if let Some(expr) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::Negative( - Box::new(protobuf::PhysicalNegativeNode { - expr: Some(Box::new(expr.arg().to_owned().try_into()?)), - }), - )), - }) - } else if let Some(lit) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::Literal( - lit.value().try_into()?, - )), - }) - } else if let Some(cast) = expr.downcast_ref::() { + ), + }) + } else if let Some(expr) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::Negative(Box::new( + protobuf::PhysicalNegativeNode { + expr: Some(Box::new(serialize_physical_expr( + expr.arg().to_owned(), + codec, + )?)), + }, + ))), + }) + } else if let Some(lit) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::Literal( + lit.value().try_into()?, + )), + }) + } else if let Some(cast) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::Cast(Box::new( + protobuf::PhysicalCastNode { + expr: Some(Box::new(serialize_physical_expr( + cast.expr().to_owned(), + codec, + )?)), + arrow_type: Some(cast.cast_type().try_into()?), + }, + ))), + }) + } else if let Some(cast) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::TryCast(Box::new( + protobuf::PhysicalTryCastNode { + expr: Some(Box::new(serialize_physical_expr( + cast.expr().to_owned(), + codec, + )?)), + arrow_type: Some(cast.cast_type().try_into()?), + }, + ))), + }) + } else if let Some(expr) = expr.downcast_ref::() { + let args: Vec = expr + .args() + .iter() + .map(|e| serialize_physical_expr(e.to_owned(), codec)) + .collect::, _>>()?; + if let Ok(fun) = BuiltinScalarFunction::from_str(expr.name()) { + let fun: protobuf::ScalarFunction = (&fun).try_into()?; + Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::Cast(Box::new( - protobuf::PhysicalCastNode { - expr: Some(Box::new(cast.expr().clone().try_into()?)), - arrow_type: Some(cast.cast_type().try_into()?), + expr_type: Some(protobuf::physical_expr_node::ExprType::ScalarFunction( + protobuf::PhysicalScalarFunctionNode { + name: expr.name().to_string(), + fun: fun.into(), + args, + return_type: Some(expr.return_type().try_into()?), }, - ))), - }) - } else if let Some(cast) = expr.downcast_ref::() { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::TryCast( - Box::new(protobuf::PhysicalTryCastNode { - expr: Some(Box::new(cast.expr().clone().try_into()?)), - arrow_type: Some(cast.cast_type().try_into()?), - }), )), }) - } else if let Some(expr) = expr.downcast_ref::() { - let args: Vec = expr - .args() - .iter() - .map(|e| e.to_owned().try_into()) - .collect::, _>>()?; - if let Ok(fun) = BuiltinScalarFunction::from_str(expr.name()) { - let fun: protobuf::ScalarFunction = (&fun).try_into()?; - - Ok(protobuf::PhysicalExprNode { - expr_type: Some( - protobuf::physical_expr_node::ExprType::ScalarFunction( - protobuf::PhysicalScalarFunctionNode { - name: expr.name().to_string(), - fun: fun.into(), - args, - return_type: Some(expr.return_type().try_into()?), - }, - ), - ), - }) - } else { - Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::ScalarUdf( - protobuf::PhysicalScalarUdfNode { - name: expr.name().to_string(), - args, - return_type: Some(expr.return_type().try_into()?), - }, - )), - }) + } else { + let mut buf = Vec::new(); + match expr.fun() { + ScalarFunctionDefinition::UDF(udf) => { + codec.try_encode_udf(udf, &mut buf)?; + } + _ => { + return not_impl_err!( + "Proto serialization error: Trying to serialize a unresolved function" + ); + } } - } else if let Some(expr) = expr.downcast_ref::() { + + let fun_definition = if buf.is_empty() { None } else { Some(buf) }; Ok(protobuf::PhysicalExprNode { - expr_type: Some(protobuf::physical_expr_node::ExprType::LikeExpr( - Box::new(protobuf::PhysicalLikeExprNode { - negated: expr.negated(), - case_insensitive: expr.case_insensitive(), - expr: Some(Box::new(expr.expr().to_owned().try_into()?)), - pattern: Some(Box::new(expr.pattern().to_owned().try_into()?)), - }), + expr_type: Some(protobuf::physical_expr_node::ExprType::ScalarUdf( + protobuf::PhysicalScalarUdfNode { + name: expr.name().to_string(), + args, + fun_definition, + return_type: Some(expr.return_type().try_into()?), + }, )), }) - } else { - internal_err!("physical_plan::to_proto() unsupported expression {value:?}") } + } else if let Some(expr) = expr.downcast_ref::() { + Ok(protobuf::PhysicalExprNode { + expr_type: Some(protobuf::physical_expr_node::ExprType::LikeExpr(Box::new( + protobuf::PhysicalLikeExprNode { + negated: expr.negated(), + case_insensitive: expr.case_insensitive(), + expr: Some(Box::new(serialize_physical_expr( + expr.expr().to_owned(), + codec, + )?)), + pattern: Some(Box::new(serialize_physical_expr( + expr.pattern().to_owned(), + codec, + )?)), + }, + ))), + }) + } else { + internal_err!("physical_plan::to_proto() unsupported expression {value:?}") } } fn try_parse_when_then_expr( when_expr: &Arc, then_expr: &Arc, + codec: &dyn PhysicalExtensionCodec, ) -> Result { Ok(protobuf::PhysicalWhenThen { - when_expr: Some(when_expr.clone().try_into()?), - then_expr: Some(then_expr.clone().try_into()?), + when_expr: Some(serialize_physical_expr(when_expr.clone(), codec)?), + then_expr: Some(serialize_physical_expr(then_expr.clone(), codec)?), }) } @@ -683,6 +745,7 @@ impl TryFrom<&FileScanConfig> for protobuf::FileScanExecConf { fn try_from( conf: &FileScanConfig, ) -> Result { + let codec = DefaultPhysicalExtensionCodec {}; let file_groups = conf .file_groups .iter() @@ -694,7 +757,7 @@ impl TryFrom<&FileScanConfig> for protobuf::FileScanExecConf { let expr_node_vec = order .iter() .map(|sort_expr| { - let expr = sort_expr.expr.clone().try_into()?; + let expr = serialize_physical_expr(sort_expr.expr.clone(), &codec)?; Ok(PhysicalSortExprNode { expr: Some(Box::new(expr)), asc: !sort_expr.options.descending, @@ -757,10 +820,11 @@ impl TryFrom>> for protobuf::MaybeFilter { type Error = DataFusionError; fn try_from(expr: Option>) -> Result { + let codec = DefaultPhysicalExtensionCodec {}; match expr { None => Ok(protobuf::MaybeFilter { expr: None }), Some(expr) => Ok(protobuf::MaybeFilter { - expr: Some(expr.try_into()?), + expr: Some(serialize_physical_expr(expr, &codec)?), }), } } @@ -786,8 +850,9 @@ impl TryFrom for protobuf::PhysicalSortExprNode { type Error = DataFusionError; fn try_from(sort_expr: PhysicalSortExpr) -> std::result::Result { + let codec = DefaultPhysicalExtensionCodec {}; Ok(PhysicalSortExprNode { - expr: Some(Box::new(sort_expr.expr.try_into()?)), + expr: Some(Box::new(serialize_physical_expr(sort_expr.expr, &codec)?)), asc: !sort_expr.options.descending, nulls_first: sort_expr.options.nulls_first, }) diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 7f0c6286a19d..4924128ae190 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use std::any::Any; use std::ops::Deref; use std::sync::Arc; use std::vec; @@ -32,7 +33,7 @@ use datafusion::datasource::physical_plan::{ wrap_partition_type_in_dict, wrap_partition_value_in_dict, FileScanConfig, FileSinkConfig, ParquetExec, }; -use datafusion::execution::context::ExecutionProps; +use datafusion::execution::FunctionRegistry; use datafusion::logical_expr::{ create_udf, BuiltinScalarFunction, JoinType, Operator, Volatility, }; @@ -49,7 +50,6 @@ use datafusion::physical_plan::expressions::{ NotExpr, NthValue, PhysicalSortExpr, StringAgg, Sum, }; use datafusion::physical_plan::filter::FilterExec; -use datafusion::physical_plan::functions; use datafusion::physical_plan::insert::FileSinkExec; use datafusion::physical_plan::joins::{ HashJoinExec, NestedLoopJoinExec, PartitionMode, StreamJoinPartitionMode, @@ -73,13 +73,19 @@ use datafusion_common::file_options::csv_writer::CsvWriterOptions; use datafusion_common::file_options::json_writer::JsonWriterOptions; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::stats::Precision; -use datafusion_common::Result; +use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result}; use datafusion_expr::{ - Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, Signature, - SimpleAggregateUDF, WindowFrame, WindowFrameBound, + Accumulator, AccumulatorFactoryFunction, AggregateUDF, ColumnarValue, + ScalarFunctionDefinition, ScalarUDF, ScalarUDFImpl, Signature, SimpleAggregateUDF, + WindowFrame, WindowFrameBound, +}; +use datafusion_proto::physical_plan::from_proto::parse_physical_expr; +use datafusion_proto::physical_plan::to_proto::serialize_physical_expr; +use datafusion_proto::physical_plan::{ + AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec, }; -use datafusion_proto::physical_plan::{AsExecutionPlan, DefaultPhysicalExtensionCodec}; use datafusion_proto::protobuf; +use prost::Message; /// Perform a serde roundtrip and assert that the string representation of the before and after plans /// are identical. Note that this often isn't sufficient to guarantee that no information is @@ -603,14 +609,11 @@ fn roundtrip_builtin_scalar_function() -> Result<()> { let input = Arc::new(EmptyExec::new(schema.clone())); - let execution_props = ExecutionProps::new(); - - let fun_expr = - functions::create_physical_fun(&BuiltinScalarFunction::Sin, &execution_props)?; + let fun_def = ScalarFunctionDefinition::BuiltIn(BuiltinScalarFunction::Sin); let expr = ScalarFunctionExpr::new( "sin", - fun_expr, + fun_def, vec![col("a", &schema)?], DataType::Float64, None, @@ -646,9 +649,11 @@ fn roundtrip_scalar_udf() -> Result<()> { scalar_fn.clone(), ); + let fun_def = ScalarFunctionDefinition::UDF(Arc::new(udf.clone())); + let expr = ScalarFunctionExpr::new( "dummy", - scalar_fn, + fun_def, vec![col("a", &schema)?], DataType::Int64, None, @@ -665,6 +670,134 @@ fn roundtrip_scalar_udf() -> Result<()> { roundtrip_test_with_context(Arc::new(project), ctx) } +#[test] +fn roundtrip_scalar_udf_extension_codec() { + #[derive(Debug)] + struct MyRegexUdf { + signature: Signature, + // regex as original string + pattern: String, + } + + impl MyRegexUdf { + fn new(pattern: String) -> Self { + Self { + signature: Signature::uniform( + 1, + vec![DataType::Int32], + Volatility::Immutable, + ), + pattern, + } + } + } + + /// Implement the ScalarUDFImpl trait for MyRegexUdf + impl ScalarUDFImpl for MyRegexUdf { + fn as_any(&self) -> &dyn Any { + self + } + fn name(&self) -> &str { + "regex_udf" + } + fn signature(&self) -> &Signature { + &self.signature + } + fn return_type(&self, args: &[DataType]) -> Result { + if !matches!(args.first(), Some(&DataType::Utf8)) { + return plan_err!("regex_udf only accepts Utf8 arguments"); + } + Ok(DataType::Int32) + } + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + unimplemented!() + } + } + + #[derive(Clone, PartialEq, ::prost::Message)] + pub struct MyRegexUdfNode { + #[prost(string, tag = "1")] + pub pattern: String, + } + + #[derive(Debug)] + pub struct ScalarUDFExtensionCodec {} + + impl PhysicalExtensionCodec for ScalarUDFExtensionCodec { + fn try_decode( + &self, + _buf: &[u8], + _inputs: &[Arc], + _registry: &dyn FunctionRegistry, + ) -> Result> { + not_impl_err!("No extension codec provided") + } + + fn try_encode( + &self, + _node: Arc, + _buf: &mut Vec, + ) -> Result<()> { + not_impl_err!("No extension codec provided") + } + + fn try_decode_udf(&self, name: &str, buf: &[u8]) -> Result> { + if name == "regex_udf" { + let proto = MyRegexUdfNode::decode(buf).map_err(|err| { + DataFusionError::Internal(format!( + "failed to decode regex_udf: {}", + err + )) + })?; + + Ok(Arc::new(ScalarUDF::new_from_impl(MyRegexUdf::new( + proto.pattern, + )))) + } else { + not_impl_err!("unrecognized scalar UDF implementation, cannot decode") + } + } + + fn try_encode_udf(&self, node: &ScalarUDF, buf: &mut Vec) -> Result<()> { + let binding = node.inner(); + if let Some(udf) = binding.as_any().downcast_ref::() { + let proto = MyRegexUdfNode { + pattern: udf.pattern.clone(), + }; + proto.encode(buf).map_err(|e| { + DataFusionError::Internal(format!("failed to encode udf: {e:?}")) + })?; + } + Ok(()) + } + } + + let pattern = ".*"; + let udf = ScalarUDF::from(MyRegexUdf::new(pattern.to_string())); + let test_expr = ScalarFunctionExpr::new( + udf.name(), + ScalarFunctionDefinition::UDF(Arc::new(udf.clone())), + vec![], + DataType::Int32, + None, + false, + ); + let fmt_expr = format!("{test_expr:?}"); + let ctx = SessionContext::new(); + + ctx.register_udf(udf.clone()); + let extension_codec = ScalarUDFExtensionCodec {}; + let proto: protobuf::PhysicalExprNode = + match serialize_physical_expr(Arc::new(test_expr), &extension_codec) { + Ok(proto) => proto, + Err(e) => panic!("failed to serialize expr: {e:?}"), + }; + let field_a = Field::new("a", DataType::Int32, false); + let schema = Arc::new(Schema::new(vec![field_a])); + let round_trip = + parse_physical_expr(&proto, &ctx, &schema, &extension_codec).unwrap(); + assert_eq!(fmt_expr, format!("{round_trip:?}")); +} #[test] fn roundtrip_distinct_count() -> Result<()> { let field_a = Field::new("a", DataType::Int64, false);