diff --git a/Cargo.lock b/Cargo.lock
index cc237e3f7..012663127 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -349,7 +349,7 @@ dependencies = [
  "arrow",
  "auron-jni-bridge",
  "auron-memmgr",
- "auron-serde",
+ "auron-planner",
  "chrono",
  "datafusion",
  "datafusion-ext-commons",
@@ -399,7 +399,7 @@ dependencies = [
 ]
 
 [[package]]
-name = "auron-serde"
+name = "auron-planner"
 version = "0.1.0"
 dependencies = [
  "arrow",
diff --git a/Cargo.toml b/Cargo.toml
index 34d89c3c8..47470c13a 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -23,7 +23,7 @@ members = [
     "native-engine/datafusion-ext-plans",
     "native-engine/auron",
     "native-engine/auron-jni-bridge",
-    "native-engine/auron-serde",
+    "native-engine/auron-planner",
     "native-engine/auron-memmgr",
 ]
 
@@ -48,7 +48,7 @@ overflow-checks = false
 
 [workspace.dependencies]
 auron = { path = "./native-engine/auron" }
 auron-jni-bridge = { path = "./native-engine/auron-jni-bridge" }
-auron-serde = { path = "./native-engine/auron-serde" }
+auron-planner = { path = "./native-engine/auron-planner" }
 auron-memmgr = { path = "./native-engine/auron-memmgr" }
 datafusion-ext-commons = { path = "./native-engine/datafusion-ext-commons" }
 datafusion-ext-exprs = { path = "./native-engine/datafusion-ext-exprs" }
diff --git a/dev/mvn-build-helper/proto/pom.xml b/dev/mvn-build-helper/proto/pom.xml
index d6a15c1e0..94b158cea 100644
--- a/dev/mvn-build-helper/proto/pom.xml
+++ b/dev/mvn-build-helper/proto/pom.xml
@@ -48,7 +48,7 @@
         0.6.1
         true
-        ../../../native-engine/auron-serde/proto
+        ../../../native-engine/auron-planner/proto
         com.google.protobuf:protoc:${protobufVersion}:exe:${os.detected.classifier}
         true
         true
diff --git a/native-engine/auron-serde/Cargo.toml b/native-engine/auron-planner/Cargo.toml
similarity index 98%
rename from native-engine/auron-serde/Cargo.toml
rename to native-engine/auron-planner/Cargo.toml
index adb63dbff..72d30be33 100644
--- a/native-engine/auron-serde/Cargo.toml
+++ b/native-engine/auron-planner/Cargo.toml
@@ -16,7 +16,7 @@
 #
 
 [package]
-name = "auron-serde"
+name = "auron-planner"
 version = "0.1.0"
 edition = "2024"
diff --git a/native-engine/auron-serde/build.rs b/native-engine/auron-planner/build.rs
similarity index 100%
rename from native-engine/auron-serde/build.rs
rename to native-engine/auron-planner/build.rs
diff --git a/native-engine/auron-serde/proto/auron.proto b/native-engine/auron-planner/proto/auron.proto
similarity index 100%
rename from native-engine/auron-serde/proto/auron.proto
rename to native-engine/auron-planner/proto/auron.proto
diff --git a/native-engine/auron-serde/src/error.rs b/native-engine/auron-planner/src/error.rs
similarity index 100%
rename from native-engine/auron-serde/src/error.rs
rename to native-engine/auron-planner/src/error.rs
diff --git a/native-engine/auron-serde/src/lib.rs b/native-engine/auron-planner/src/lib.rs
similarity index 99%
rename from native-engine/auron-serde/src/lib.rs
rename to native-engine/auron-planner/src/lib.rs
index 98cf7177c..d546fea28 100644
--- a/native-engine/auron-serde/src/lib.rs
+++ b/native-engine/auron-planner/src/lib.rs
@@ -28,7 +28,7 @@ pub mod protobuf {
 }
 
 pub mod error;
-pub mod from_proto;
+pub mod planner;
 
 pub(crate) fn proto_error<S: Into<String>>(message: S) -> PlanSerDeError {
     PlanSerDeError::General(message.into())
 }
@@ -58,9 +58,9 @@ macro_rules! into_required {
 
 #[macro_export]
 macro_rules! convert_box_required {
-    ($PB:expr) => {{
+    ($self:expr, $PB:expr) => {{
         if let Some(field) = $PB.as_ref() {
-            field.as_ref().try_into()
+            $self.create_plan(field.as_ref())
         } else {
             Err(proto_error("Missing required field in protobuf"))
         }
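
Note: the macro change above is the pivot of the whole refactor — plan-tree recursion no longer goes through a `TryInto` impl on the protobuf node, but through an explicit planner value passed in as `$self`. A rough sketch of what a call site now expands to (illustrative only, derived from the macro bodies above, not part of the patch):

    // before: convert_box_required!(projection.input)
    //     if let Some(field) = projection.input.as_ref() {
    //         field.as_ref().try_into()
    //     } else {
    //         Err(proto_error("Missing required field in protobuf"))
    //     }
    //
    // after: convert_box_required!(self, projection.input)
    //     if let Some(field) = projection.input.as_ref() {
    //         self.create_plan(field.as_ref())
    //     } else {
    //         Err(proto_error("Missing required field in protobuf"))
    //     }

Threading `$self` through is what lets per-task state such as `partition_id` reach every recursive conversion.
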
diff --git a/native-engine/auron-serde/src/from_proto.rs b/native-engine/auron-planner/src/planner.rs
similarity index 75%
rename from native-engine/auron-serde/src/from_proto.rs
rename to native-engine/auron-planner/src/planner.rs
index 9237a6c25..a95bba5ed 100644
--- a/native-engine/auron-serde/src/from_proto.rs
+++ b/native-engine/auron-planner/src/planner.rs
@@ -13,8 +13,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-//! Serde code to convert from protocol buffers to Rust data structures.
-
 use std::{
     any::Any,
     convert::{TryFrom, TryInto},
@@ -29,7 +27,7 @@ use arrow::{
 };
 use base64::{Engine, prelude::BASE64_URL_SAFE_NO_PAD};
 use datafusion::{
-    common::{Result, ScalarValue, stats::Precision},
+    common::{ExprSchema, Result, ScalarValue, stats::Precision},
     datasource::{
         file_format::file_compression_type::FileCompressionType,
         listing::{FileRange, PartitionedFile},
@@ -104,32 +102,43 @@ use crate::{
     },
 };
 
-impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
-    type Error = PlanSerDeError;
+type PlanError = PlanSerDeError;
 
-    fn try_into(self) -> Result<Arc<dyn ExecutionPlan>, Self::Error> {
-        let plan = self.physical_plan_type.as_ref().ok_or_else(|| {
+pub struct PhysicalPlanner {
+    partition_id: usize,
+}
+
+impl PhysicalPlanner {
+    pub fn new(partition_id: usize) -> Self {
+        Self { partition_id }
+    }
+
+    pub fn create_plan<'a>(
+        &'a self,
+        spark_plan: &'a protobuf::PhysicalPlanNode,
+    ) -> Result<Arc<dyn ExecutionPlan>, PlanError> {
+        let plan = spark_plan.physical_plan_type.as_ref().ok_or_else(|| {
             proto_error(format!(
                 "physical_plan::from_proto() Unsupported physical plan '{:?}'",
-                self
+                spark_plan
             ))
         })?;
         match plan {
             PhysicalPlanType::Projection(projection) => {
-                let input: Arc<dyn ExecutionPlan> = convert_box_required!(projection.input)?;
+                let input: Arc<dyn ExecutionPlan> = convert_box_required!(self, projection.input)?;
                 let input_schema = input.schema();
                 let data_types: Vec<DataType> = projection
                     .data_type
                     .iter()
                     .map(|data_type| data_type.try_into())
-                    .collect::<Result<Vec<_>, Self::Error>>()?;
+                    .collect::<Result<Vec<_>, PlanError>>()?;
                 let exprs = projection
                     .expr
                     .iter()
                     .zip(projection.expr_name.iter())
                     .zip(data_types)
                     .map(|((expr, name), data_type)| {
-                        let physical_expr = try_parse_physical_expr(expr, &input_schema)?;
+                        let physical_expr = self.try_parse_physical_expr(expr, &input_schema)?;
                         let casted_expr = if physical_expr.data_type(&input_schema)? == data_type {
                             physical_expr
                         } else {
@@ -137,17 +146,17 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
                         };
                         Ok((casted_expr, name.to_string()))
                     })
-                    .collect::<Result<Vec<_>, Self::Error>>()?;
+                    .collect::<Result<Vec<_>, PlanError>>()?;
                 Ok(Arc::new(ProjectExec::try_new(exprs, input)?))
             }
             PhysicalPlanType::Filter(filter) => {
-                let input: Arc<dyn ExecutionPlan> = convert_box_required!(filter.input)?;
+                let input: Arc<dyn ExecutionPlan> = convert_box_required!(self, filter.input)?;
                 let predicates = filter
                     .expr
                     .iter()
-                    .map(|expr| try_parse_physical_expr(expr, &input.schema()))
-                    .collect::<Result<_, Self::Error>>()?;
+                    .map(|expr| self.try_parse_physical_expr(expr, &input.schema()))
+                    .collect::<Result<_, PlanError>>()?;
                 Ok(Arc::new(FilterExec::try_new(predicates, input)?))
             }
             PhysicalPlanType::ParquetScan(scan) => {
@@ -156,7 +165,8 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
                     .pruning_predicates
                     .iter()
                     .filter_map(|predicate| {
-                        try_parse_physical_expr(predicate, &conf.file_schema).ok()
+                        self.try_parse_physical_expr(predicate, &conf.file_schema)
+                            .ok()
                     })
                     .fold(phys_expr::lit(true), |a, b| {
                         Arc::new(BinaryExpr::new(a, Operator::And, b))
@@ -173,7 +183,8 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
                     .pruning_predicates
                     .iter()
                     .filter_map(|predicate| {
-                        try_parse_physical_expr(predicate, &conf.file_schema).ok()
+                        self.try_parse_physical_expr(predicate, &conf.file_schema)
+                            .ok()
                     })
                     .fold(phys_expr::lit(true), |a, b| {
                         Arc::new(BinaryExpr::new(a, Operator::And, b))
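
Note: both scan arms fold their pruning predicates the same way — each protobuf predicate is parsed best-effort (`filter_map` + `.ok()`, so an unparseable predicate is skipped rather than failing the whole plan; pruning only needs to be conservative), and the survivors are AND-ed onto a `lit(true)` identity. A standalone sketch of the pattern, where `parse` stands in for `self.try_parse_physical_expr` (illustrative only):

    let pruning_predicate = predicates
        .iter()
        .filter_map(|p| parse(p).ok())
        .fold(phys_expr::lit(true), |a, b| {
            Arc::new(BinaryExpr::new(a, Operator::And, b))
        });

`lit(true)` is the neutral element for AND, so a scan with no usable predicates simply prunes nothing.
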
@@ -186,19 +197,21 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
             }
             PhysicalPlanType::HashJoin(hash_join) => {
                 let schema = Arc::new(convert_required!(hash_join.schema)?);
-                let left: Arc<dyn ExecutionPlan> = convert_box_required!(hash_join.left)?;
-                let right: Arc<dyn ExecutionPlan> = convert_box_required!(hash_join.right)?;
+                let left: Arc<dyn ExecutionPlan> = convert_box_required!(self, hash_join.left)?;
+                let right: Arc<dyn ExecutionPlan> = convert_box_required!(self, hash_join.right)?;
                 let on: Vec<(PhysicalExprRef, PhysicalExprRef)> = hash_join
                     .on
                     .iter()
                     .map(|col| {
-                        let left_key =
-                            try_parse_physical_expr(&col.left.as_ref().unwrap(), &left.schema())?;
-                        let right_key =
-                            try_parse_physical_expr(&col.right.as_ref().unwrap(), &right.schema())?;
+                        let left_key = self
+                            .try_parse_physical_expr(&col.left.as_ref().unwrap(), &left.schema())?;
+                        let right_key = self.try_parse_physical_expr(
+                            &col.right.as_ref().unwrap(),
+                            &right.schema(),
+                        )?;
                         Ok((left_key, right_key))
                     })
-                    .collect::<Result<_, Self::Error>>()?;
+                    .collect::<Result<_, PlanError>>()?;
 
                 let join_type =
                     protobuf::JoinType::try_from(hash_join.join_type).expect("invalid JoinType");
@@ -224,19 +237,23 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
             }
             PhysicalPlanType::SortMergeJoin(sort_merge_join) => {
                 let schema = Arc::new(convert_required!(sort_merge_join.schema)?);
-                let left: Arc<dyn ExecutionPlan> = convert_box_required!(sort_merge_join.left)?;
-                let right: Arc<dyn ExecutionPlan> = convert_box_required!(sort_merge_join.right)?;
+                let left: Arc<dyn ExecutionPlan> =
+                    convert_box_required!(self, sort_merge_join.left)?;
+                let right: Arc<dyn ExecutionPlan> =
+                    convert_box_required!(self, sort_merge_join.right)?;
                 let on: Vec<(PhysicalExprRef, PhysicalExprRef)> = sort_merge_join
                     .on
                     .iter()
                     .map(|col| {
-                        let left_key =
-                            try_parse_physical_expr(&col.left.as_ref().unwrap(), &left.schema())?;
-                        let right_key =
-                            try_parse_physical_expr(&col.right.as_ref().unwrap(), &right.schema())?;
+                        let left_key = self
+                            .try_parse_physical_expr(&col.left.as_ref().unwrap(), &left.schema())?;
+                        let right_key = self.try_parse_physical_expr(
+                            &col.right.as_ref().unwrap(),
+                            &right.schema(),
+                        )?;
                         Ok((left_key, right_key))
                     })
-                    .collect::<Result<_, Self::Error>>()?;
+                    .collect::<Result<_, PlanError>>()?;
 
                 let sort_options = sort_merge_join
                     .sort_options
@@ -262,9 +279,10 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
                 )?))
             }
             PhysicalPlanType::ShuffleWriter(shuffle_writer) => {
-                let input: Arc<dyn ExecutionPlan> = convert_box_required!(shuffle_writer.input)?;
+                let input: Arc<dyn ExecutionPlan> =
+                    convert_box_required!(self, shuffle_writer.input)?;
 
-                let output_partitioning = parse_protobuf_partitioning(
+                let output_partitioning = self.parse_protobuf_partitioning(
                     input.clone(),
                     shuffle_writer.output_partitioning.as_ref(),
                 )?;
@@ -278,9 +296,9 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
             }
             PhysicalPlanType::RssShuffleWriter(rss_shuffle_writer) => {
                 let input: Arc<dyn ExecutionPlan> =
-                    convert_box_required!(rss_shuffle_writer.input)?;
+                    convert_box_required!(self, rss_shuffle_writer.input)?;
 
-                let output_partitioning = parse_protobuf_partitioning(
+                let output_partitioning = self.parse_protobuf_partitioning(
                     input.clone(),
                     rss_shuffle_writer.output_partitioning.as_ref(),
                 )?;
@@ -291,7 +309,7 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
                 )?))
             }
             PhysicalPlanType::IpcWriter(ipc_writer) => {
-                let input: Arc<dyn ExecutionPlan> = convert_box_required!(ipc_writer.input)?;
+                let input: Arc<dyn ExecutionPlan> = convert_box_required!(self, ipc_writer.input)?;
 
                 Ok(Arc::new(IpcWriterExec::new(
                     input,
@@ -307,14 +325,16 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
                 )))
             }
             PhysicalPlanType::Debug(debug) => {
-                let input: Arc<dyn ExecutionPlan> = convert_box_required!(debug.input)?;
+                let input: Arc<dyn ExecutionPlan> = convert_box_required!(self, debug.input)?;
                 Ok(Arc::new(DebugExec::new(input, debug.debug_id.clone())))
             }
             PhysicalPlanType::Sort(sort) => {
-                let input: Arc<dyn ExecutionPlan> = convert_box_required!(sort.input)?;
-                let exprs = try_parse_physical_sort_expr(&input, sort).unwrap_or_else(|e| {
-                    panic!("Failed to parse physical sort expressions: {}", e);
-                });
+                let input: Arc<dyn ExecutionPlan> = convert_box_required!(self, sort.input)?;
+                let exprs = self
+                    .try_parse_physical_sort_expr(&input, sort)
+                    .unwrap_or_else(|e| {
+                        panic!("Failed to parse physical sort expressions: {}", e);
+                    });
 
                 // always preserve partitioning
                 Ok(Arc::new(SortExec::new(
@@ -324,29 +344,33 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
                 )))
             }
             PhysicalPlanType::BroadcastJoinBuildHashMap(bhm) => {
-                let input: Arc<dyn ExecutionPlan> = convert_box_required!(bhm.input)?;
+                let input: Arc<dyn ExecutionPlan> = convert_box_required!(self, bhm.input)?;
                 let keys = bhm
                     .keys
                    .iter()
-                    .map(|expr| try_parse_physical_expr(expr, &input.schema()))
-                    .collect::<Result<Vec<_>, Self::Error>>()?;
+                    .map(|expr| self.try_parse_physical_expr(expr, &input.schema()))
+                    .collect::<Result<Vec<_>, PlanError>>()?;
                 Ok(Arc::new(BroadcastJoinBuildHashMapExec::new(input, keys)))
             }
             PhysicalPlanType::BroadcastJoin(broadcast_join) => {
                 let schema = Arc::new(convert_required!(broadcast_join.schema)?);
-                let left: Arc<dyn ExecutionPlan> = convert_box_required!(broadcast_join.left)?;
-                let right: Arc<dyn ExecutionPlan> = convert_box_required!(broadcast_join.right)?;
+                let left: Arc<dyn ExecutionPlan> =
+                    convert_box_required!(self, broadcast_join.left)?;
+                let right: Arc<dyn ExecutionPlan> =
+                    convert_box_required!(self, broadcast_join.right)?;
                 let on: Vec<(PhysicalExprRef, PhysicalExprRef)> = broadcast_join
                     .on
                     .iter()
                     .map(|col| {
-                        let left_key =
-                            try_parse_physical_expr(&col.left.as_ref().unwrap(), &left.schema())?;
-                        let right_key =
-                            try_parse_physical_expr(&col.right.as_ref().unwrap(), &right.schema())?;
+                        let left_key = self
+                            .try_parse_physical_expr(&col.left.as_ref().unwrap(), &left.schema())?;
+                        let right_key = self.try_parse_physical_expr(
+                            &col.right.as_ref().unwrap(),
+                            &right.schema(),
+                        )?;
                         Ok((left_key, right_key))
                     })
-                    .collect::<Result<_, Self::Error>>()?;
+                    .collect::<Result<_, PlanError>>()?;
 
                 let join_type = protobuf::JoinType::try_from(broadcast_join.join_type)
                     .expect("invalid JoinType");
@@ -381,10 +405,13 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
                     .input
                     .iter()
                     .map(|input| {
-                        let input_exec = convert_required!(input.input)?;
+                        let input_exec =
+                            self.create_plan(input.input.as_ref().ok_or_else(|| {
+                                proto_error("Missing required field in protobuf")
+                            })?)?;
                         Ok(UnionInput(input_exec, input.partition as usize))
                     })
-                    .collect::<Result<Vec<_>, Self::Error>>()?;
+                    .collect::<Result<Vec<_>, PlanError>>()?;
 
                 Ok(Arc::new(UnionExec::new(
                     inputs,
@@ -401,14 +428,15 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
                 )))
             }
             PhysicalPlanType::RenameColumns(rename_columns) => {
-                let input: Arc<dyn ExecutionPlan> = convert_box_required!(rename_columns.input)?;
+                let input: Arc<dyn ExecutionPlan> =
+                    convert_box_required!(self, rename_columns.input)?;
                 Ok(Arc::new(RenameColumnsExec::try_new(
                     input,
                     rename_columns.renamed_column_names.clone(),
                 )?))
             }
             PhysicalPlanType::Agg(agg) => {
-                let input: Arc<dyn ExecutionPlan> = convert_box_required!(agg.input)?;
+                let input: Arc<dyn ExecutionPlan> = convert_box_required!(self, agg.input)?;
                 let input_schema = input.schema();
 
                 let exec_mode = match protobuf::AggExecMode::try_from(agg.exec_mode)
@@ -435,12 +463,13 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
                     .iter()
                     .zip(agg.grouping_expr_name.iter())
                     .map(|(expr, name)| {
-                        try_parse_physical_expr(expr, &input_schema).and_then(|expr| {
-                            Ok(GroupingExpr {
-                                expr,
-                                field_name: name.to_owned(),
+                        self.try_parse_physical_expr(expr, &input_schema)
+                            .and_then(|expr| {
+                                Ok(GroupingExpr {
+                                    expr,
+                                    field_name: name.to_owned(),
+                                })
                             })
-                        })
                     })
                     .collect::<Result<Vec<_>, _>>()?;
 
@@ -468,7 +497,7 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
                 let agg_children_exprs = agg_node
                     .children
                     .iter()
-                    .map(|expr| try_parse_physical_expr(expr, &input_schema))
+                    .map(|expr| self.try_parse_physical_expr(expr, &input_schema))
                     .collect::<Result<Vec<_>, _>>()?;
 
                 let return_type = convert_required!(agg_node.return_type)?;
@@ -503,7 +532,7 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
                 )?))
             }
             PhysicalPlanType::Limit(limit) => {
-                let input: Arc<dyn ExecutionPlan> = convert_box_required!(limit.input)?;
+                let input: Arc<dyn ExecutionPlan> = convert_box_required!(self, limit.input)?;
                 Ok(Arc::new(LimitExec::new(input, limit.limit)))
             }
             PhysicalPlanType::FfiReader(ffi_reader) => {
@@ -515,12 +544,13 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
                 )))
             }
             PhysicalPlanType::CoalesceBatches(coalesce_batches) => {
-                let input: Arc<dyn ExecutionPlan> = convert_box_required!(coalesce_batches.input)?;
+                let input: Arc<dyn ExecutionPlan> =
+                    convert_box_required!(self, coalesce_batches.input)?;
                 Ok(Arc::new(LimitExec::new(input, coalesce_batches.batch_size)))
             }
             PhysicalPlanType::Expand(expand) => {
                 let schema = Arc::new(convert_required!(expand.schema)?);
-                let input: Arc<dyn ExecutionPlan> = convert_box_required!(expand.input)?;
+                let input: Arc<dyn ExecutionPlan> = convert_box_required!(self, expand.input)?;
                 let projections = expand
                     .projections
                     .iter()
                     .map(|projection| {
                         projection
                             .expr
                             .iter()
-                            .map(|expr| try_parse_physical_expr(expr, &input.schema()))
-                            .collect::<Result<Vec<_>, Self::Error>>()
+                            .map(|expr| self.try_parse_physical_expr(expr, &input.schema()))
+                            .collect::<Result<Vec<_>, PlanError>>()
                     })
                     .collect::<Result<Vec<_>, _>>()?;
 
                 Ok(Arc::new(ExpandExec::try_new(schema, projections, input)?))
             }
             PhysicalPlanType::Window(window) => {
-                let input: Arc<dyn ExecutionPlan> = convert_box_required!(window.input)?;
+                let input: Arc<dyn ExecutionPlan> = convert_box_required!(self, window.input)?;
                 let window_exprs = window
                     .window_expr
                     .iter()
@@ -547,7 +577,7 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
                             .ok_or_else(|| {
                                 proto_error(format!(
                                     "physical_plan::from_proto() Unexpected sort expr {:?}",
-                                    self
+                                    spark_plan
                                 ))
                             })?
                             .try_into()?,
@@ -556,8 +586,8 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
                         let children = w
                             .children
                             .iter()
-                            .map(|expr| try_parse_physical_expr(expr, &input.schema()))
-                            .collect::<Result<Vec<_>, Self::Error>>()?;
+                            .map(|expr| self.try_parse_physical_expr(expr, &input.schema()))
+                            .collect::<Result<Vec<_>, PlanError>>()?;
                         let return_type = convert_required!(w.return_type)?;
 
                         let window_func = match w.func_type() {
@@ -606,7 +636,7 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
                                 }
                             },
                         };
-                        Ok::<_, Self::Error>(WindowExpr::new(
+                        Ok::<_, PlanError>(WindowExpr::new(
                             window_func,
                             children,
                             field,
@@ -618,8 +648,8 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
                 let partition_specs = window
                     .partition_spec
                     .iter()
-                    .map(|expr| try_parse_physical_expr(expr, &input.schema()))
-                    .collect::<Result<Vec<_>, Self::Error>>()?;
+                    .map(|expr| self.try_parse_physical_expr(expr, &input.schema()))
+                    .collect::<Result<Vec<_>, PlanError>>()?;
 
                 let order_specs = window
                     .order_spec
@@ -627,7 +657,7 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
                     .map(|expr| {
                         let expr = expr.expr_type.as_ref().ok_or_else(|| {
                             proto_error(format!(
-                                "physical_plan::from_proto() Unexpected expr {self:?}",
+                                "physical_plan::from_proto() Unexpected expr {spark_plan:?}",
                             ))
                         })?;
                         if let protobuf::physical_expr_node::ExprType::Sort(sort_expr) = expr {
@@ -637,12 +667,12 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
                                 .ok_or_else(|| {
                                     proto_error(format!(
                                         "physical_plan::from_proto() Unexpected sort expr {:?}",
-                                        self
+                                        spark_plan
                                     ))
                                 })?
                                 .as_ref();
                             Ok(PhysicalSortExpr {
-                                expr: try_parse_physical_expr(expr, &input.schema())?,
+                                expr: self.try_parse_physical_expr(expr, &input.schema())?,
                                 options: SortOptions {
                                     descending: !sort_expr.asc,
                                     nulls_first: sort_expr.nulls_first,
@@ -651,7 +681,7 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
                         } else {
                             Err(PlanSerDeError::General(format!(
                                 "physical_plan::from_proto() {:?}",
-                                self
+                                spark_plan
                             )))
                         }
                     })
@@ -670,7 +700,7 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
                 )?))
             }
             PhysicalPlanType::Generate(generate) => {
-                let input: Arc<dyn ExecutionPlan> = convert_box_required!(generate.input)?;
+                let input: Arc<dyn ExecutionPlan> = convert_box_required!(self, generate.input)?;
                 let input_schema = input.schema();
                 let pb_generator = generate.generator.as_ref().expect("missing generator");
                 let pb_generator_children = &pb_generator.child;
@@ -679,7 +709,7 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
 
                 let children = pb_generator_children
                     .iter()
-                    .map(|expr| try_parse_physical_expr(expr, &input_schema))
+                    .map(|expr| self.try_parse_physical_expr(expr, &input_schema))
                     .collect::<Result<Vec<_>, _>>()?;
 
                 let generator = match pb_generate_func {
@@ -733,7 +763,7 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
                     props.push((prop.key.clone(), prop.value.clone()));
                 }
                 Ok(Arc::new(ParquetSinkExec::new(
-                    convert_box_required!(parquet_sink.input)?,
+                    convert_box_required!(self, parquet_sink.input)?,
                     parquet_sink.fs_resource_id.clone(),
                     parquet_sink.num_dyn_parts as usize,
                     props,
@@ -741,125 +771,18 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
             }
         }
     }
-}
-
-impl From<&protobuf::PhysicalColumn> for Column {
-    fn from(c: &protobuf::PhysicalColumn) -> Column {
-        Column::new(&c.name, c.index as usize)
-    }
-}
-
-impl From<&protobuf::BoundReference> for Column {
-    fn from(c: &protobuf::BoundReference) -> Column {
-        Column::new("__bound_reference__", c.index as usize)
-    }
-}
-
-impl From<protobuf::ScalarFunction> for Arc<ScalarUDF> {
-    fn from(f: protobuf::ScalarFunction) -> Self {
-        use datafusion::functions as f;
-        use datafusion_spark::function as spark_fun;
-        use protobuf::ScalarFunction;
-
-        match f {
-            ScalarFunction::Sqrt => f::math::sqrt(),
-            ScalarFunction::Sin => f::math::sin(),
-            ScalarFunction::Cos => f::math::cos(),
-            ScalarFunction::Tan => f::math::tan(),
-            ScalarFunction::Asin => f::math::asin(),
-            ScalarFunction::Acos => f::math::acos(),
-            ScalarFunction::Atan => f::math::atan(),
-            ScalarFunction::Exp => f::math::exp(),
-            ScalarFunction::Log => f::math::log(),
-            ScalarFunction::Ln => f::math::ln(),
-            ScalarFunction::Log10 => f::math::log10(),
-            ScalarFunction::Floor => f::math::floor(),
-            ScalarFunction::Ceil => f::math::ceil(),
-            ScalarFunction::Round => f::math::round(),
-            ScalarFunction::Trunc => f::math::trunc(),
-            ScalarFunction::Abs => f::math::abs(),
-            ScalarFunction::OctetLength => f::string::octet_length(),
-            ScalarFunction::Concat => f::string::concat(),
-            ScalarFunction::Lower => f::string::lower(),
-            ScalarFunction::Upper => f::string::upper(),
-            ScalarFunction::Trim => f::string::btrim(),
-            ScalarFunction::Ltrim => f::string::ltrim(),
-            ScalarFunction::Rtrim => f::string::rtrim(),
-            ScalarFunction::ToTimestamp => f::datetime::to_timestamp(),
-            ScalarFunction::NullIf => f::core::nullif(),
-            ScalarFunction::Nvl2 => f::core::nvl2(),
-            ScalarFunction::Nvl => f::core::nvl(),
-            ScalarFunction::DatePart => f::datetime::date_part(),
-            ScalarFunction::DateTrunc => f::datetime::date_trunc(),
-            // ScalarFunction::Md5 => f::crypto::md5(),
-            // ScalarFunction::Sha224 => f::crypto::sha224(),
-            // ScalarFunction::Sha256 => f::crypto::sha256(),
-            // ScalarFunction::Sha384 => f::crypto::sha384(),
-            // ScalarFunction::Sha512 => f::crypto::sha512(),
-            ScalarFunction::Digest => f::crypto::digest(),
-            ScalarFunction::ToTimestampMillis => f::datetime::to_timestamp_millis(),
-            ScalarFunction::Log2 => f::math::log2(),
-            ScalarFunction::Signum => f::math::signum(),
-            ScalarFunction::Ascii => f::string::ascii(),
-            ScalarFunction::BitLength => f::string::bit_length(),
-            ScalarFunction::Btrim => f::string::btrim(),
-            ScalarFunction::CharacterLength => f::unicode::character_length(),
-            ScalarFunction::Chr => f::string::chr(),
-            ScalarFunction::ConcatWithSeparator => f::string::concat_ws(),
-            // ScalarFunction::InitCap => f::unicode::initcap(),
-            ScalarFunction::Left => f::unicode::left(),
-            ScalarFunction::Lpad => f::unicode::lpad(),
-            ScalarFunction::Random => f::math::random(),
-            ScalarFunction::RegexpReplace => f::regex::regexp_replace(),
-            ScalarFunction::Repeat => f::string::repeat(),
-            ScalarFunction::Replace => f::string::replace(),
-            ScalarFunction::Reverse => f::unicode::reverse(),
-            ScalarFunction::Right => f::unicode::right(),
-            ScalarFunction::Rpad => f::unicode::rpad(),
-            ScalarFunction::SplitPart => f::string::split_part(),
-            ScalarFunction::StartsWith => f::string::starts_with(),
-            ScalarFunction::Levenshtein => f::string::levenshtein(),
-
-            ScalarFunction::FindInSet => f::unicode::find_in_set(),
-            ScalarFunction::Strpos => f::unicode::strpos(),
-            ScalarFunction::Substr => f::unicode::substr(),
-            // ScalarFunction::ToHex => f::string::to_hex(),
-            ScalarFunction::ToTimestampMicros => f::datetime::to_timestamp_micros(),
-            ScalarFunction::ToTimestampSeconds => f::datetime::to_timestamp_seconds(),
-            ScalarFunction::Now => f::datetime::now(),
-            ScalarFunction::Translate => f::unicode::translate(),
-            ScalarFunction::RegexpMatch => f::regex::regexp_match(),
-            ScalarFunction::Greatest => f::core::greatest(),
-            ScalarFunction::Coalesce => f::core::coalesce(),
-            ScalarFunction::Least => f::core::least(),
-
-            // -- datafusion-spark functions
-            // math functions
-            ScalarFunction::Expm1 => spark_fun::math::expm1(),
-            ScalarFunction::Factorial => spark_fun::math::factorial(),
-            ScalarFunction::Hex => spark_fun::math::hex(),
-
-            ScalarFunction::Power => f::math::power(),
-            ScalarFunction::IsNaN => f::math::isnan(),
-
-            ScalarFunction::AuronExtFunctions => {
-                unreachable!()
-            }
-        }
-    }
-}
-
-fn try_parse_physical_expr(
-    expr: &protobuf::PhysicalExprNode,
-    input_schema: &SchemaRef,
-) -> Result<PhysicalExprRef, PlanSerDeError> {
-    let expr_type = expr
-        .expr_type
-        .as_ref()
-        .ok_or_else(|| proto_error("Unexpected empty physical expression"))?;
-
-    let pexpr: PhysicalExprRef =
-        match expr_type {
+    fn try_parse_physical_expr(
+        &self,
+        expr: &protobuf::PhysicalExprNode,
+        input_schema: &SchemaRef,
+    ) -> Result<PhysicalExprRef, PlanSerDeError> {
+        let expr_type = expr
+            .expr_type
+            .as_ref()
+            .ok_or_else(|| proto_error("Unexpected empty physical expression"))?;
+
+        let pexpr: PhysicalExprRef = match expr_type {
             ExprType::Column(c) => Arc::new(Column::new(&c.name, input_schema.index_of(&c.name)?)),
             ExprType::Literal(scalar) => Arc::new(Literal::new(scalar.try_into()?)),
             ExprType::BoundReference(bound_reference) => {
@@ -867,9 +790,9 @@ fn try_parse_physical_expr(
                 Arc::new(pcol)
             }
             ExprType::BinaryExpr(binary_expr) => Arc::new(BinaryExpr::new(
-                try_parse_physical_expr_box_required(&binary_expr.l.clone(), input_schema)?,
+                self.try_parse_physical_expr_box_required(&binary_expr.l.clone(), input_schema)?,
                 from_proto_binary_op(&binary_expr.op)?,
-                try_parse_physical_expr_box_required(&binary_expr.r.clone(), input_schema)?,
+                self.try_parse_physical_expr_box_required(&binary_expr.r.clone(), input_schema)?,
            )),
             ExprType::AggExpr(_) => {
                 return Err(PlanSerDeError::General(
@@ -882,26 +805,25 @@ fn try_parse_physical_expr(
                 ));
             }
             ExprType::IsNullExpr(e) => Arc::new(IsNullExpr::new(
-                try_parse_physical_expr_box_required(&e.expr, input_schema)?,
+                self.try_parse_physical_expr_box_required(&e.expr, input_schema)?,
             )),
             ExprType::IsNotNullExpr(e) => Arc::new(IsNotNullExpr::new(
-                try_parse_physical_expr_box_required(&e.expr, input_schema)?,
+                self.try_parse_physical_expr_box_required(&e.expr, input_schema)?,
+            )),
+            ExprType::NotExpr(e) => Arc::new(NotExpr::new(
+                self.try_parse_physical_expr_box_required(&e.expr, input_schema)?,
             )),
-            ExprType::NotExpr(e) => Arc::new(NotExpr::new(try_parse_physical_expr_box_required(
-                &e.expr,
-                input_schema,
-            )?)),
             ExprType::Negative(e) => Arc::new(NegativeExpr::new(
-                try_parse_physical_expr_box_required(&e.expr, input_schema)?,
+                self.try_parse_physical_expr_box_required(&e.expr, input_schema)?,
             )),
            ExprType::InList(e) => {
-                let expr = try_parse_physical_expr_box_required(&e.expr, input_schema)?;
+                let expr = self.try_parse_physical_expr_box_required(&e.expr, input_schema)?;
                 let dt = expr.data_type(input_schema)?;
                 let list_exprs = e
                     .list
                     .iter()
                     .map(|x| -> Result<PhysicalExprRef, PlanSerDeError> {
-                        let e = try_parse_physical_expr(x, input_schema)?;
+                        let e = self.try_parse_physical_expr(x, input_schema)?;
                         if e.data_type(input_schema)? != dt {
                             return Ok(Arc::new(TryCastExpr::new(e, dt.clone())));
                         }
@@ -913,29 +835,29 @@ fn try_parse_physical_expr(
             ExprType::Case(e) => Arc::new(CaseExpr::try_new(
                 e.expr
                     .as_ref()
-                    .map(|e| try_parse_physical_expr(e.as_ref(), input_schema))
+                    .map(|e| self.try_parse_physical_expr(e.as_ref(), input_schema))
                     .transpose()?,
                 e.when_then_expr
                     .iter()
                     .map(|e| {
                         Ok((
-                            try_parse_physical_expr_required(&e.when_expr, input_schema)?,
-                            try_parse_physical_expr_required(&e.then_expr, input_schema)?,
+                            self.try_parse_physical_expr_required(&e.when_expr, input_schema)?,
+                            self.try_parse_physical_expr_required(&e.then_expr, input_schema)?,
                         ))
                     })
                     .collect::<Result<Vec<_>, PlanSerDeError>>()?,
                 e.else_expr
                     .as_ref()
-                    .map(|e| try_parse_physical_expr(e.as_ref(), input_schema))
+                    .map(|e| self.try_parse_physical_expr(e.as_ref(), input_schema))
                     .transpose()?,
             )?),
             ExprType::Cast(e) => Arc::new(CastExpr::new(
-                try_parse_physical_expr_box_required(&e.expr, input_schema)?,
+                self.try_parse_physical_expr_box_required(&e.expr, input_schema)?,
                 convert_required!(e.arrow_type)?,
                 None,
             )),
             ExprType::TryCast(e) => {
-                let expr = try_parse_physical_expr_box_required(&e.expr, input_schema)?;
+                let expr = self.try_parse_physical_expr_box_required(&e.expr, input_schema)?;
                 let cast_type = convert_required!(e.arrow_type)?;
                 Arc::new(TryCastExpr::new(expr, cast_type))
             }
@@ -945,12 +867,15 @@ fn try_parse_physical_expr(
                 let args = e
                     .args
                     .iter()
-                    .map(|x| try_parse_physical_expr(x, input_schema))
+                    .map(|x| self.try_parse_physical_expr(x, input_schema))
                     .collect::<Result<Vec<_>, _>>()?;
 
                 let scalar_udf = if scalar_function == protobuf::ScalarFunction::AuronExtFunctions {
                     let fun_name = &e.name;
-                    let fun = datafusion_ext_functions::create_auron_ext_function(fun_name)?;
+                    let fun = datafusion_ext_functions::create_auron_ext_function(
+                        fun_name,
+                        self.partition_id,
+                    )?;
                     Arc::new(create_udf(
                         &format!("spark_ext_function_{}", fun_name),
                         args.iter()
@@ -981,7 +906,7 @@ fn try_parse_physical_expr(
                 e.return_nullable,
                 e.params
                     .iter()
-                    .map(|x| try_parse_physical_expr(x, input_schema))
+                    .map(|x| self.try_parse_physical_expr(x, input_schema))
                     .collect::<Result<Vec<_>, _>>()?,
                 e.expr_string.clone(),
             )?),
@@ -993,48 +918,48 @@ fn try_parse_physical_expr(
                 )?)
             }
             ExprType::GetIndexedFieldExpr(e) => {
-                let expr = try_parse_physical_expr_box_required(&e.expr, input_schema)?;
+                let expr = self.try_parse_physical_expr_box_required(&e.expr, input_schema)?;
                 let key = convert_required!(e.key)?;
                 Arc::new(GetIndexedFieldExpr::new(expr, key))
             }
             ExprType::GetMapValueExpr(e) => {
-                let expr = try_parse_physical_expr_box_required(&e.expr, input_schema)?;
+                let expr = self.try_parse_physical_expr_box_required(&e.expr, input_schema)?;
                 let key = convert_required!(e.key)?;
                 Arc::new(GetMapValueExpr::new(expr, key))
             }
             ExprType::StringStartsWithExpr(e) => {
-                let expr = try_parse_physical_expr_box_required(&e.expr, input_schema)?;
+                let expr = self.try_parse_physical_expr_box_required(&e.expr, input_schema)?;
                 Arc::new(StringStartsWithExpr::new(expr, e.prefix.clone()))
             }
             ExprType::StringEndsWithExpr(e) => {
-                let expr = try_parse_physical_expr_box_required(&e.expr, input_schema)?;
+                let expr = self.try_parse_physical_expr_box_required(&e.expr, input_schema)?;
                 Arc::new(StringEndsWithExpr::new(expr, e.suffix.clone()))
             }
             ExprType::StringContainsExpr(e) => {
-                let expr = try_parse_physical_expr_box_required(&e.expr, input_schema)?;
+                let expr = self.try_parse_physical_expr_box_required(&e.expr, input_schema)?;
                 Arc::new(StringContainsExpr::new(expr, e.infix.clone()))
             }
             ExprType::RowNumExpr(_) => Arc::new(RowNumExpr::default()),
             ExprType::BloomFilterMightContainExpr(e) => Arc::new(BloomFilterMightContainExpr::new(
                 e.uuid.clone(),
-                try_parse_physical_expr_box_required(&e.bloom_filter_expr, input_schema)?,
-                try_parse_physical_expr_box_required(&e.value_expr, input_schema)?,
+                self.try_parse_physical_expr_box_required(&e.bloom_filter_expr, input_schema)?,
+                self.try_parse_physical_expr_box_required(&e.value_expr, input_schema)?,
             )),
             ExprType::ScAndExpr(e) => {
-                let l = try_parse_physical_expr_box_required(&e.left, input_schema)?;
-                let r = try_parse_physical_expr_box_required(&e.right, input_schema)?;
+                let l = self.try_parse_physical_expr_box_required(&e.left, input_schema)?;
+                let r = self.try_parse_physical_expr_box_required(&e.right, input_schema)?;
                 Arc::new(SCAndExpr::new(l, r))
             }
             ExprType::ScOrExpr(e) => {
-                let l = try_parse_physical_expr_box_required(&e.left, input_schema)?;
-                let r = try_parse_physical_expr_box_required(&e.right, input_schema)?;
+                let l = self.try_parse_physical_expr_box_required(&e.left, input_schema)?;
+                let r = self.try_parse_physical_expr_box_required(&e.right, input_schema)?;
                 Arc::new(SCOrExpr::new(l, r))
             }
             ExprType::LikeExpr(e) => Arc::new(LikeExpr::new(
                 e.negated,
                 e.case_insensitive,
-                try_parse_physical_expr_box_required(&e.expr, input_schema)?,
-                try_parse_physical_expr_box_required(&e.pattern, input_schema)?,
+                self.try_parse_physical_expr_box_required(&e.expr, input_schema)?,
+                self.try_parse_physical_expr_box_required(&e.pattern, input_schema)?,
             )),
             ExprType::NamedStruct(e) => {
@@ -1042,167 +967,282 @@ fn try_parse_physical_expr(
                 Arc::new(NamedStructExpr::try_new(
                     e.values
                         .iter()
-                        .map(|x| try_parse_physical_expr(x, input_schema))
+                        .map(|x| self.try_parse_physical_expr(x, input_schema))
                         .collect::<Result<Vec<_>, _>>()?,
                     data_type,
                 )?)
             }
         };
-    Ok(pexpr)
-}
+        Ok(pexpr)
+    }
 
-fn try_parse_physical_expr_required(
-    proto: &Option<protobuf::PhysicalExprNode>,
-    input_schema: &SchemaRef,
-) -> Result<PhysicalExprRef, PlanSerDeError> {
-    if let Some(field) = proto.as_ref() {
-        try_parse_physical_expr(field, input_schema)
-    } else {
-        Err(proto_error("Missing required field in protobuf"))
-    }
-}
+    fn try_parse_physical_expr_required(
+        &self,
+        proto: &Option<protobuf::PhysicalExprNode>,
+        input_schema: &SchemaRef,
+    ) -> Result<PhysicalExprRef, PlanSerDeError> {
+        if let Some(field) = proto.as_ref() {
+            self.try_parse_physical_expr(field, input_schema)
+        } else {
+            Err(proto_error("Missing required field in protobuf"))
+        }
+    }
 
-fn try_parse_physical_expr_box_required(
-    proto: &Option<Box<protobuf::PhysicalExprNode>>,
-    input_schema: &SchemaRef,
-) -> Result<PhysicalExprRef, PlanSerDeError> {
-    if let Some(field) = proto.as_ref() {
-        try_parse_physical_expr(field, input_schema)
-    } else {
-        Err(proto_error("Missing required field in protobuf"))
-    }
-}
+    fn try_parse_physical_expr_box_required(
+        &self,
+        proto: &Option<Box<protobuf::PhysicalExprNode>>,
+        input_schema: &SchemaRef,
+    ) -> Result<PhysicalExprRef, PlanSerDeError> {
+        if let Some(field) = proto.as_ref() {
+            self.try_parse_physical_expr(field, input_schema)
+        } else {
+            Err(proto_error("Missing required field in protobuf"))
+        }
+    }
 
-fn try_parse_physical_sort_expr(
-    input: &Arc<dyn ExecutionPlan>,
-    sort: &Box<protobuf::SortExecNode>,
-) -> Result<Vec<PhysicalSortExpr>, PlanSerDeError> {
-    let physical_sort_expr = sort
-        .expr
-        .iter()
-        .map(|expr| {
-            let expr = expr.expr_type.as_ref().ok_or_else(|| {
-                proto_error(format!(
-                    "physical_plan::from_proto() Unexpected expr {:?}",
-                    input
-                ))
-            })?;
-            if let ExprType::Sort(sort_expr) = expr {
-                let expr = sort_expr
-                    .expr
-                    .as_ref()
-                    .ok_or_else(|| {
-                        proto_error(format!(
-                            "physical_plan::from_proto() Unexpected sort expr {:?}",
-                            input
-                        ))
-                    })?
-                    .as_ref();
-                Ok(PhysicalSortExpr {
-                    expr: try_parse_physical_expr(expr, &input.schema())?,
-                    options: SortOptions {
-                        descending: !sort_expr.asc,
-                        nulls_first: sort_expr.nulls_first,
-                    },
-                })
-            } else {
-                Err(PlanSerDeError::General(format!(
-                    "physical_plan::from_proto() {:?}",
-                    input
-                )))
-            }
-        })
-        .collect::<Result<Vec<_>, _>>()?;
-    Ok(physical_sort_expr)
-}
+    fn try_parse_physical_sort_expr(
+        &self,
+        input: &Arc<dyn ExecutionPlan>,
+        sort: &Box<protobuf::SortExecNode>,
+    ) -> Result<Vec<PhysicalSortExpr>, PlanSerDeError> {
+        let physical_sort_expr = sort
+            .expr
+            .iter()
+            .map(|expr| {
+                let expr = expr.expr_type.as_ref().ok_or_else(|| {
+                    proto_error(format!(
+                        "physical_plan::from_proto() Unexpected expr {:?}",
+                        input
+                    ))
+                })?;
+                if let ExprType::Sort(sort_expr) = expr {
+                    let expr = sort_expr
+                        .expr
+                        .as_ref()
+                        .ok_or_else(|| {
+                            proto_error(format!(
+                                "physical_plan::from_proto() Unexpected sort expr {:?}",
+                                input
+                            ))
+                        })?
+                        .as_ref();
+                    Ok(PhysicalSortExpr {
+                        expr: self.try_parse_physical_expr(expr, &input.schema())?,
+                        options: SortOptions {
+                            descending: !sort_expr.asc,
+                            nulls_first: sort_expr.nulls_first,
+                        },
+                    })
+                } else {
+                    Err(PlanSerDeError::General(format!(
+                        "physical_plan::from_proto() {:?}",
+                        input
+                    )))
+                }
+            })
+            .collect::<Result<Vec<_>, _>>()?;
+        Ok(physical_sort_expr)
+    }
+
+    pub fn parse_protobuf_partitioning(
+        &self,
+        input: Arc<dyn ExecutionPlan>,
+        partitioning: Option<&Box<protobuf::PhysicalRepartition>>,
+    ) -> Result<Option<Partitioning>, PlanSerDeError> {
+        partitioning.map_or(Ok(None), |p| {
+            let plan = p.repartition_type.as_ref().ok_or_else(|| {
+                proto_error(format!(
+                    "partition::from_proto() Unsupported partition '{:?}'",
+                    p
+                ))
+            })?;
+            match plan {
+                RepartitionType::SingleRepartition(..) => {
+                    Ok(Some(Partitioning::SinglePartitioning()))
+                }
+                RepartitionType::HashRepartition(hash_part) => {
+                    // let hash_part = p.hash_repartition;
+                    let expr = hash_part
+                        .hash_expr
+                        .iter()
+                        .map(|e| self.try_parse_physical_expr(e, &input.schema()))
+                        .collect::<Result<Vec<_>, _>>()?;
+                    Ok(Some(Partitioning::HashPartitioning(
+                        expr,
+                        hash_part.partition_count.try_into().unwrap(),
+                    )))
+                }
+
+                RepartitionType::RoundRobinRepartition(round_robin_part) => {
+                    Ok(Some(Partitioning::RoundRobinPartitioning(
+                        round_robin_part.partition_count.try_into().unwrap(),
+                    )))
+                }
+
+                RepartitionType::RangeRepartition(range_part) => {
+                    if range_part.partition_count == 1 {
+                        Ok(Some(Partitioning::SinglePartitioning()))
+                    } else {
+                        let sort = range_part.sort_expr.clone().unwrap();
+                        let exprs = self
+                            .try_parse_physical_sort_expr(&input, &sort)
+                            .unwrap_or_else(|e| {
+                                panic!("Failed to parse physical sort expressions: {}", e);
+                            });
+
+                        let value_list: Vec<ScalarValue> = range_part
+                            .list_value
+                            .iter()
+                            .map(|v| v.try_into())
+                            .collect::<Result<Vec<_>, _>>()?;
+
+                        let sort_row_converter = Arc::new(SyncMutex::new(RowConverter::new(
+                            exprs
+                                .iter()
+                                .map(|expr: &PhysicalSortExpr| {
+                                    Ok(SortField::new_with_options(
+                                        expr.expr.data_type(&input.schema())?,
+                                        expr.options,
+                                    ))
+                                })
+                                .collect::<Result<Vec<SortField>>>()?,
+                        )?));
+
+                        let bound_cols: Vec<ArrayRef> = value_list
+                            .iter()
+                            .map(|x| {
+                                if let ScalarValue::List(single) = x {
+                                    return single.value(0);
+                                } else {
+                                    unreachable!("expect list scalar value");
+                                }
+                            })
+                            .collect::<Vec<_>>();
+
+                        let bound_rows = sort_row_converter.lock().convert_columns(&bound_cols)?;
+                        Ok(Some(Partitioning::RangePartitioning(
+                            exprs,
+                            range_part.partition_count.try_into().unwrap(),
+                            Arc::new(bound_rows),
+                        )))
+                    }
+                }
+            }
+        })
+    }
+}
 
-pub fn parse_protobuf_partitioning(
-    input: Arc<dyn ExecutionPlan>,
-    partitioning: Option<&Box<protobuf::PhysicalRepartition>>,
-) -> Result<Option<Partitioning>, PlanSerDeError> {
-    partitioning.map_or(Ok(None), |p| {
-        let plan = p.repartition_type.as_ref().ok_or_else(|| {
-            proto_error(format!(
-                "partition::from_proto() Unsupported partition '{:?}'",
-                p
-            ))
-        })?;
-        match plan {
-            RepartitionType::SingleRepartition(..) => Ok(Some(Partitioning::SinglePartitioning())),
-            RepartitionType::HashRepartition(hash_part) => {
-                // let hash_part = p.hash_repartition;
-                let expr = hash_part
-                    .hash_expr
-                    .iter()
-                    .map(|e| try_parse_physical_expr(e, &input.schema()))
-                    .collect::<Result<Vec<_>, _>>()?;
-                Ok(Some(Partitioning::HashPartitioning(
-                    expr,
-                    hash_part.partition_count.try_into().unwrap(),
-                )))
-            }
-
-            RepartitionType::RoundRobinRepartition(round_robin_part) => {
-                Ok(Some(Partitioning::RoundRobinPartitioning(
-                    round_robin_part.partition_count.try_into().unwrap(),
-                )))
-            }
-
-            RepartitionType::RangeRepartition(range_part) => {
-                if range_part.partition_count == 1 {
-                    Ok(Some(Partitioning::SinglePartitioning()))
-                } else {
-                    let sort = range_part.sort_expr.clone().unwrap();
-                    let exprs = try_parse_physical_sort_expr(&input, &sort).unwrap_or_else(|e| {
-                        panic!("Failed to parse physical sort expressions: {}", e);
-                    });
-
-                    let value_list: Vec<ScalarValue> = range_part
-                        .list_value
-                        .iter()
-                        .map(|v| v.try_into())
-                        .collect::<Result<Vec<_>, _>>()?;
-
-                    let sort_row_converter = Arc::new(SyncMutex::new(RowConverter::new(
-                        exprs
-                            .iter()
-                            .map(|expr: &PhysicalSortExpr| {
-                                Ok(SortField::new_with_options(
-                                    expr.expr.data_type(&input.schema())?,
-                                    expr.options,
-                                ))
-                            })
-                            .collect::<Result<Vec<SortField>>>()?,
-                    )?));
-
-                    let bound_cols: Vec<ArrayRef> = value_list
-                        .iter()
-                        .map(|x| {
-                            if let ScalarValue::List(single) = x {
-                                return single.value(0);
-                            } else {
-                                unreachable!("expect list scalar value");
-                            }
-                        })
-                        .collect::<Vec<_>>();
-
-                    let bound_rows = sort_row_converter.lock().convert_columns(&bound_cols)?;
-                    Ok(Some(Partitioning::RangePartitioning(
-                        exprs,
-                        range_part.partition_count.try_into().unwrap(),
-                        Arc::new(bound_rows),
-                    )))
-                }
-            }
-        }
-    })
-}
+impl From<&protobuf::PhysicalColumn> for Column {
+    fn from(c: &protobuf::PhysicalColumn) -> Column {
+        Column::new(&c.name, c.index as usize)
+    }
+}
+
+impl From<&protobuf::BoundReference> for Column {
+    fn from(c: &protobuf::BoundReference) -> Column {
+        Column::new("__bound_reference__", c.index as usize)
+    }
+}
+
+impl From<protobuf::ScalarFunction> for Arc<ScalarUDF> {
+    fn from(f: protobuf::ScalarFunction) -> Self {
+        use datafusion::functions as f;
+        use datafusion_spark::function as spark_fun;
+        use protobuf::ScalarFunction;
+
+        match f {
+            ScalarFunction::Sqrt => f::math::sqrt(),
+            ScalarFunction::Sin => f::math::sin(),
+            ScalarFunction::Cos => f::math::cos(),
+            ScalarFunction::Tan => f::math::tan(),
+            ScalarFunction::Asin => f::math::asin(),
+            ScalarFunction::Acos => f::math::acos(),
+            ScalarFunction::Atan => f::math::atan(),
+            ScalarFunction::Exp => f::math::exp(),
+            ScalarFunction::Log => f::math::log(),
+            ScalarFunction::Ln => f::math::ln(),
+            ScalarFunction::Log10 => f::math::log10(),
+            ScalarFunction::Floor => f::math::floor(),
+            ScalarFunction::Ceil => f::math::ceil(),
+            ScalarFunction::Round => f::math::round(),
+            ScalarFunction::Trunc => f::math::trunc(),
+            ScalarFunction::Abs => f::math::abs(),
+            ScalarFunction::OctetLength => f::string::octet_length(),
+            ScalarFunction::Concat => f::string::concat(),
+            ScalarFunction::Lower => f::string::lower(),
+            ScalarFunction::Upper => f::string::upper(),
+            ScalarFunction::Trim => f::string::btrim(),
+            ScalarFunction::Ltrim => f::string::ltrim(),
+            ScalarFunction::Rtrim => f::string::rtrim(),
+            ScalarFunction::ToTimestamp => f::datetime::to_timestamp(),
+            ScalarFunction::NullIf => f::core::nullif(),
+            ScalarFunction::Nvl2 => f::core::nvl2(),
+            ScalarFunction::Nvl => f::core::nvl(),
+            ScalarFunction::DatePart => f::datetime::date_part(),
+            ScalarFunction::DateTrunc => f::datetime::date_trunc(),
+            // ScalarFunction::Md5 => f::crypto::md5(),
+            // ScalarFunction::Sha224 => f::crypto::sha224(),
+            // ScalarFunction::Sha256 => f::crypto::sha256(),
+            // ScalarFunction::Sha384 => f::crypto::sha384(),
+            // ScalarFunction::Sha512 => f::crypto::sha512(),
+            ScalarFunction::Digest => f::crypto::digest(),
+            ScalarFunction::ToTimestampMillis => f::datetime::to_timestamp_millis(),
+            ScalarFunction::Log2 => f::math::log2(),
+            ScalarFunction::Signum => f::math::signum(),
+            ScalarFunction::Ascii => f::string::ascii(),
+            ScalarFunction::BitLength => f::string::bit_length(),
+            ScalarFunction::Btrim => f::string::btrim(),
+            ScalarFunction::CharacterLength => f::unicode::character_length(),
+            ScalarFunction::Chr => f::string::chr(),
+            ScalarFunction::ConcatWithSeparator => f::string::concat_ws(),
+            // ScalarFunction::InitCap => f::unicode::initcap(),
+            ScalarFunction::Left => f::unicode::left(),
+            ScalarFunction::Lpad => f::unicode::lpad(),
+            ScalarFunction::Random => f::math::random(),
+            ScalarFunction::RegexpReplace => f::regex::regexp_replace(),
+            ScalarFunction::Repeat => f::string::repeat(),
+            ScalarFunction::Replace => f::string::replace(),
+            ScalarFunction::Reverse => f::unicode::reverse(),
+            ScalarFunction::Right => f::unicode::right(),
+            ScalarFunction::Rpad => f::unicode::rpad(),
+            ScalarFunction::SplitPart => f::string::split_part(),
+            ScalarFunction::StartsWith => f::string::starts_with(),
+            ScalarFunction::Levenshtein => f::string::levenshtein(),
+
+            ScalarFunction::FindInSet => f::unicode::find_in_set(),
+            ScalarFunction::Strpos => f::unicode::strpos(),
+            ScalarFunction::Substr => f::unicode::substr(),
+            // ScalarFunction::ToHex => f::string::to_hex(),
+            ScalarFunction::ToTimestampMicros => f::datetime::to_timestamp_micros(),
+            ScalarFunction::ToTimestampSeconds => f::datetime::to_timestamp_seconds(),
+            ScalarFunction::Now => f::datetime::now(),
+            ScalarFunction::Translate => f::unicode::translate(),
+            ScalarFunction::RegexpMatch => f::regex::regexp_match(),
+            ScalarFunction::Greatest => f::core::greatest(),
+            ScalarFunction::Coalesce => f::core::coalesce(),
+            ScalarFunction::Least => f::core::least(),
+
+            // -- datafusion-spark functions
+            // math functions
+            ScalarFunction::Expm1 => spark_fun::math::expm1(),
+            ScalarFunction::Factorial => spark_fun::math::factorial(),
+            ScalarFunction::Hex => spark_fun::math::hex(),
+
+            ScalarFunction::Power => f::math::power(),
+            ScalarFunction::IsNaN => f::math::isnan(),
+
+            ScalarFunction::AuronExtFunctions => {
+                unreachable!()
+            }
+        }
+    }
+}
 
 impl TryFrom<&protobuf::PartitionedFile> for PartitionedFile {
     type Error = PlanSerDeError;
 
-    fn try_from(val: &protobuf::PartitionedFile) -> Result<Self, Self::Error> {
+    fn try_from(val: &protobuf::PartitionedFile) -> Result<Self, PlanError> {
         Ok(PartitionedFile {
             object_meta: ObjectMeta {
                 location: Path::from(format!("/{}", BASE64_URL_SAFE_NO_PAD.encode(&val.path))),
@@ -1227,7 +1267,7 @@ impl TryFrom<&protobuf::PartitionedFile> for PartitionedFile {
 
 impl TryFrom<&protobuf::FileRange> for FileRange {
     type Error = PlanSerDeError;
 
-    fn try_from(value: &protobuf::FileRange) -> Result<Self, Self::Error> {
+    fn try_from(value: &protobuf::FileRange) -> Result<Self, PlanError> {
         Ok(FileRange {
             start: value.start,
             end: value.end,
@@ -1238,7 +1278,7 @@ impl TryFrom<&protobuf::FileRange> for FileRange {
 
 impl TryFrom<&protobuf::FileGroup> for Vec<PartitionedFile> {
     type Error = PlanSerDeError;
 
-    fn try_from(val: &protobuf::FileGroup) -> Result<Self, Self::Error> {
+    fn try_from(val: &protobuf::FileGroup) -> Result<Self, PlanError> {
         val.files
             .iter()
             .map(|f| f.try_into())
             .collect()
@@ -1269,7 +1309,7 @@ impl From<&protobuf::ColumnStats> for ColumnStatistics {
 
 impl TryInto<Statistics> for &protobuf::Statistics {
     type Error = PlanSerDeError;
 
-    fn try_into(self) -> Result<Statistics, Self::Error> {
+    fn try_into(self) -> Result<Statistics, PlanError> {
         let column_statistics = self
             .column_stats
             .iter()
@@ -1287,7 +1327,7 @@ impl TryInto<Statistics> for &protobuf::Statistics {
 
 impl TryInto<FileScanConfig> for &protobuf::FileScanExecConf {
     type Error = PlanSerDeError;
 
-    fn try_into(self) -> Result<FileScanConfig, Self::Error> {
+    fn try_into(self) -> Result<FileScanConfig, PlanError> {
         let schema: SchemaRef = Arc::new(convert_required!(self.schema)?);
         let projection = self
             .projection
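
Note: that closes out planner.rs. End to end, plan construction changes from `plan.try_into()` to an explicit planner value, mirroring the rt.rs hunk below. A condensed usage sketch (error handling elided; illustrative, not part of the patch):

    let planner = PhysicalPlanner::new(partition_id);
    let execution_plan: Arc<dyn ExecutionPlan> = planner.create_plan(plan)?;

Keeping the planner as a value rather than a trait impl on the protobuf node leaves room for more per-task context later — today it carries only `partition_id`.
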
diff --git a/native-engine/auron/Cargo.toml b/native-engine/auron/Cargo.toml
index 9e3f1353c..81bfa9864 100644
--- a/native-engine/auron/Cargo.toml
+++ b/native-engine/auron/Cargo.toml
@@ -34,7 +34,7 @@ http-service = []
 arrow = { workspace = true }
 auron-jni-bridge = { workspace = true }
 auron-memmgr = { workspace = true }
-auron-serde = { workspace = true }
+auron-planner = { workspace = true }
 datafusion = { workspace = true }
 datafusion-ext-commons = { workspace = true }
 datafusion-ext-plans = { workspace = true }
diff --git a/native-engine/auron/src/rt.rs b/native-engine/auron/src/rt.rs
index 7389e79a9..a8913b40d 100644
--- a/native-engine/auron/src/rt.rs
+++ b/native-engine/auron/src/rt.rs
@@ -29,7 +29,7 @@ use auron_jni_bridge::{
     is_task_running, jni_call, jni_call_static, jni_convert_byte_array, jni_exception_check,
     jni_exception_occurred, jni_new_global_ref, jni_new_object, jni_new_string,
 };
-use auron_serde::protobuf::TaskDefinition;
+use auron_planner::{planner::PhysicalPlanner, protobuf::TaskDefinition};
 use datafusion::{
     common::Result,
     error::DataFusionError,
@@ -83,9 +83,11 @@
         let plan = &task_definition.plan.expect("plan is empty");
         drop(raw_task_definition);
 
+        let planner = PhysicalPlanner::new(partition_id);
+
         // get execution plan
-        let execution_plan: Arc<dyn ExecutionPlan> = plan
-            .try_into()
+        let execution_plan: Arc<dyn ExecutionPlan> = planner
+            .create_plan(plan)
             .or_else(|err| df_execution_err!("cannot create execution plan: {err:?}"))?;
 
         let exec_ctx = ExecutionContext::new(
diff --git a/native-engine/datafusion-ext-functions/src/lib.rs b/native-engine/datafusion-ext-functions/src/lib.rs
index e1989de42..0117359e8 100644
--- a/native-engine/datafusion-ext-functions/src/lib.rs
+++ b/native-engine/datafusion-ext-functions/src/lib.rs
@@ -34,7 +34,10 @@ mod spark_round;
 mod spark_strings;
 mod spark_unscaled_value;
 
-pub fn create_auron_ext_function(name: &str) -> Result<ScalarFunctionImplementation> {
+pub fn create_auron_ext_function(
+    name: &str,
+    spark_partition_id: usize,
+) -> Result<ScalarFunctionImplementation> {
     // auron ext functions, if used for spark should be start with 'Spark_',
     // if used for flink should be start with 'Flink_',
     // same to other engines.
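
Note: the new `spark_partition_id` parameter lets extension functions close over the task's partition instead of being built globally. A hypothetical registration inside `create_auron_ext_function` could look like this — the function name and body are illustrative only, and the import paths follow current DataFusion conventions and may differ by version:

    use std::sync::Arc;

    use datafusion::common::{Result, ScalarValue};
    use datafusion::logical_expr::ScalarFunctionImplementation;
    use datafusion::physical_plan::ColumnarValue;

    // Hypothetical partition-aware ext function: every invocation within this
    // task returns the partition id captured at plan-construction time.
    fn make_partition_id_fun(spark_partition_id: usize) -> ScalarFunctionImplementation {
        Arc::new(move |_args: &[ColumnarValue]| -> Result<ColumnarValue> {
            Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(
                spark_partition_id as i32,
            ))))
        })
    }

This matches the rt.rs hunk above: the partition id enters at `PhysicalPlanner::new(partition_id)` and flows down to every `create_auron_ext_function` call site.
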