diff --git a/rust/ballista/.dockerignore b/rust/ballista/.dockerignore
new file mode 100644
index 0000000000000..3cde49e0a0c4c
--- /dev/null
+++ b/rust/ballista/.dockerignore
@@ -0,0 +1,18 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+rust/**/target
diff --git a/rust/ballista/rust/Cargo.toml b/rust/ballista/rust/Cargo.toml
index d1f588f3bd75c..5e344e004b838 100644
--- a/rust/ballista/rust/Cargo.toml
+++ b/rust/ballista/rust/Cargo.toml
@@ -25,6 +25,6 @@ members = [
   "scheduler",
 ]
 
-[profile.release]
-lto = true
-codegen-units = 1
+#[profile.release]
+#lto = true
+#codegen-units = 1
diff --git a/rust/ballista/rust/benchmarks/tpch/Cargo.toml b/rust/ballista/rust/benchmarks/tpch/Cargo.toml
index 55a0fe1330cfd..8c37f8898fca4 100644
--- a/rust/ballista/rust/benchmarks/tpch/Cargo.toml
+++ b/rust/ballista/rust/benchmarks/tpch/Cargo.toml
@@ -27,9 +27,13 @@ edition = "2018"
 
 [dependencies]
 ballista = { path="../../client" }
-arrow = { git = "https://github.com/apache/arrow", rev="46161d2" }
-datafusion = { git = "https://github.com/apache/arrow", rev="46161d2" }
-parquet = { git = "https://github.com/apache/arrow", rev="46161d2" }
+#arrow = { path = "../../../../arrow" }
+#datafusion = { path = "../../../../datafusion" }
+#parquet = { path = "../../../../parquet" }
+
+arrow = { git = "https://github.com/apache/arrow", rev="fe83dca" }
+datafusion = { git = "https://github.com/apache/arrow", rev="fe83dca" }
+parquet = { git = "https://github.com/apache/arrow", rev="fe83dca" }
 
 env_logger = "0.8"
diff --git a/rust/ballista/rust/client/Cargo.toml b/rust/ballista/rust/client/Cargo.toml
index 966e2dcbb31f5..8ee5d427baef1 100644
--- a/rust/ballista/rust/client/Cargo.toml
+++ b/rust/ballista/rust/client/Cargo.toml
@@ -30,5 +30,9 @@ ballista-core = { path = "../core" }
 futures = "0.3"
 log = "0.4"
 tokio = "1.0"
-arrow = { git = "https://github.com/apache/arrow", rev="46161d2" }
-datafusion = { git = "https://github.com/apache/arrow", rev="46161d2" }
+
+#arrow = { path = "../../../arrow" }
+#datafusion = { path = "../../../datafusion" }
+
+arrow = { git = "https://github.com/apache/arrow", rev="fe83dca" }
+datafusion = { git = "https://github.com/apache/arrow", rev="fe83dca" }
diff --git a/rust/ballista/rust/client/src/context.rs b/rust/ballista/rust/client/src/context.rs
index 8b2431f56c27c..0556c2948daad 100644
--- a/rust/ballista/rust/client/src/context.rs
+++ b/rust/ballista/rust/client/src/context.rs
@@ -36,6 +36,7 @@ use ballista_core::{
 };
 
 use arrow::datatypes::Schema;
+use datafusion::catalog::TableReference;
 use datafusion::execution::context::ExecutionContext;
 use datafusion::logical_plan::{DFSchema, Expr, LogicalPlan, Partitioning};
 use datafusion::physical_plan::csv::CsvReadOptions;
@@ -148,7 +149,10 @@ impl BallistaContext {
         for (name, plan) in &state.tables {
             let plan = ctx.optimize(plan)?;
             let execution_plan = ctx.create_physical_plan(&plan)?;
-            ctx.register_table(name, Arc::new(DFTableAdapter::new(plan, execution_plan)));
+            ctx.register_table(
+                TableReference::Bare { table: name },
+                Arc::new(DFTableAdapter::new(plan, execution_plan)),
+            )?;
         }
         let df = ctx.sql(sql)?;
         Ok(BallistaDataFrame::from(self.state.clone(), df))
@@ -267,7 +271,7 @@ impl BallistaDataFrame {
         ))
     }
 
-    pub fn select(&self, expr: &[Expr]) -> Result<BallistaDataFrame> {
+    pub fn select(&self, expr: Vec<Expr>) -> Result<BallistaDataFrame> {
         Ok(Self::from(
             self.state.clone(),
             self.df.select(expr).map_err(BallistaError::from)?,
@@ -283,8 +287,8 @@ impl BallistaDataFrame {
 
     pub fn aggregate(
         &self,
-        group_expr: &[Expr],
-        aggr_expr: &[Expr],
+        group_expr: Vec<Expr>,
+        aggr_expr: Vec<Expr>,
     ) -> Result<BallistaDataFrame> {
         Ok(Self::from(
             self.state.clone(),
@@ -301,7 +305,7 @@ impl BallistaDataFrame {
         ))
     }
 
-    pub fn sort(&self, expr: &[Expr]) -> Result<BallistaDataFrame> {
+    pub fn sort(&self, expr: Vec<Expr>) -> Result<BallistaDataFrame> {
         Ok(Self::from(
             self.state.clone(),
             self.df.sort(expr).map_err(BallistaError::from)?,
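For callers, the visible change above is that `register_table` now takes a `TableReference` and returns a `Result`, and the DataFrame combinators take owned `Vec<Expr>` instead of `&[Expr]`. A minimal sketch of the new calling convention, assuming the datafusion APIs at the revision pinned in this diff; the table name, provider, and column names are illustrative:

```rust
use std::sync::Arc;

use datafusion::catalog::TableReference;
use datafusion::datasource::TableProvider;
use datafusion::execution::context::ExecutionContext;
use datafusion::logical_plan::col;

fn register_and_sort(
    mut ctx: ExecutionContext,
    provider: Arc<dyn TableProvider>,
) -> datafusion::error::Result<()> {
    // register_table now takes a TableReference (a bare, unqualified name
    // here) and is fallible, hence the trailing `?`
    ctx.register_table(TableReference::Bare { table: "employees" }, provider)?;

    // select/aggregate/sort now take owned Vec<Expr> rather than &[Expr]
    let df = ctx.table("employees")?;
    let _sorted = df.sort(vec![col("salary")])?;
    Ok(())
}
```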
diff --git a/rust/ballista/rust/core/Cargo.toml b/rust/ballista/rust/core/Cargo.toml
index f5f6f8574b31e..60c38725bf787 100644
--- a/rust/ballista/rust/core/Cargo.toml
+++ b/rust/ballista/rust/core/Cargo.toml
@@ -39,10 +39,14 @@ sqlparser = "0.8"
 tokio = "1.0"
 tonic = "0.4"
 uuid = { version = "0.8", features = ["v4"] }
-arrow = { git = "https://github.com/apache/arrow", rev="46161d2" }
-arrow-flight = { git = "https://github.com/apache/arrow", rev="46161d2" }
-datafusion = { git = "https://github.com/apache/arrow", rev="46161d2" }
+
+#arrow = { path = "../../../arrow" }
+#arrow-flight = { path = "../../../arrow-flight" }
+#datafusion = { path = "../../../datafusion" }
+
+arrow = { git = "https://github.com/apache/arrow", rev="fe83dca" }
+arrow-flight = { git = "https://github.com/apache/arrow", rev="fe83dca" }
+datafusion = { git = "https://github.com/apache/arrow", rev="fe83dca" }
 
 [dev-dependencies]
diff --git a/rust/ballista/rust/core/proto/ballista.proto b/rust/ballista/rust/core/proto/ballista.proto
index ff0727b78875e..5733921bc92fb 100644
--- a/rust/ballista/rust/core/proto/ballista.proto
+++ b/rust/ballista/rust/core/proto/ballista.proto
@@ -59,6 +59,7 @@ message LogicalExprNode {
     InListNode in_list = 14;
     bool wildcard = 15;
     ScalarFunctionNode scalar_function = 16;
+    TryCastNode try_cast = 17;
   }
 }
 
@@ -172,6 +173,11 @@ message CastNode {
   ArrowType arrow_type = 2;
 }
 
+message TryCastNode {
+  LogicalExprNode expr = 1;
+  ArrowType arrow_type = 2;
+}
+
 message SortExprNode {
   LogicalExprNode expr = 1;
   bool asc = 2;
diff --git a/rust/ballista/rust/core/src/datasource.rs b/rust/ballista/rust/core/src/datasource.rs
index 531f63df40e4b..8ff0df44e4be4 100644
--- a/rust/ballista/rust/core/src/datasource.rs
+++ b/rust/ballista/rust/core/src/datasource.rs
@@ -57,6 +57,7 @@ impl TableProvider for DFTableAdapter {
         _projection: &Option<Vec<usize>>,
         _batch_size: usize,
         _filters: &[Expr],
+        _limit: Option<usize>,
     ) -> DFResult<Arc<dyn ExecutionPlan>> {
         Ok(self.plan.clone())
     }
diff --git a/rust/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/rust/ballista/rust/core/src/serde/logical_plan/from_proto.rs
index 087ebdbf507c0..93084260662f8 100644
--- a/rust/ballista/rust/core/src/serde/logical_plan/from_proto.rs
+++ b/rust/ballista/rust/core/src/serde/logical_plan/from_proto.rs
@@ -52,14 +52,13 @@ impl TryInto<LogicalPlan> for &protobuf::LogicalPlanNode {
         match plan {
             LogicalPlanType::Projection(projection) => {
                 let input: LogicalPlan = convert_box_required!(projection.input)?;
+                let x: Vec<Expr> = projection
+                    .expr
+                    .iter()
+                    .map(|expr| expr.try_into())
+                    .collect::<Result<Vec<_>, _>>()?;
                 LogicalPlanBuilder::from(&input)
-                    .project(
-                        &projection
-                            .expr
-                            .iter()
-                            .map(|expr| expr.try_into())
-                            .collect::<Result<Vec<_>, _>>()?,
-                    )?
+                    .project(x)?
                     .build()
                     .map_err(|e| e.into())
             }
@@ -89,7 +88,7 @@ impl TryInto<LogicalPlan> for &protobuf::LogicalPlanNode {
                     .map(|expr| expr.try_into())
                     .collect::<Result<Vec<_>, _>>()?;
                 LogicalPlanBuilder::from(&input)
-                    .aggregate(&group_expr, &aggr_expr)?
+                    .aggregate(group_expr, aggr_expr)?
                     .build()
                     .map_err(|e| e.into())
             }
@@ -148,7 +147,7 @@ impl TryInto<LogicalPlan> for &protobuf::LogicalPlanNode {
                     .map(|expr| expr.try_into())
                     .collect::<Result<Vec<_>, _>>()?;
                 LogicalPlanBuilder::from(&input)
-                    .sort(&sort_expr)?
+                    .sort(sort_expr)?
                     .build()
                     .map_err(|e| e.into())
             }
@@ -511,10 +510,10 @@ fn typechecked_scalar_value_conversion(
                     ScalarValue::Date32(Some(*v))
                 }
                 (Value::TimeMicrosecondValue(v), PrimitiveScalarType::TimeMicrosecond) => {
-                    ScalarValue::TimeMicrosecond(Some(*v))
+                    ScalarValue::TimestampMicrosecond(Some(*v))
                 }
                 (Value::TimeNanosecondValue(v), PrimitiveScalarType::TimeMicrosecond) => {
-                    ScalarValue::TimeNanosecond(Some(*v))
+                    ScalarValue::TimestampNanosecond(Some(*v))
                 }
                 (Value::Utf8Value(v), PrimitiveScalarType::Utf8) => {
                     ScalarValue::Utf8(Some(v.to_owned()))
@@ -547,10 +546,10 @@ fn typechecked_scalar_value_conversion(
                 PrimitiveScalarType::LargeUtf8 => ScalarValue::LargeUtf8(None),
                 PrimitiveScalarType::Date32 => ScalarValue::Date32(None),
                 PrimitiveScalarType::TimeMicrosecond => {
-                    ScalarValue::TimeMicrosecond(None)
+                    ScalarValue::TimestampMicrosecond(None)
                 }
                 PrimitiveScalarType::TimeNanosecond => {
-                    ScalarValue::TimeNanosecond(None)
+                    ScalarValue::TimestampNanosecond(None)
                 }
                 PrimitiveScalarType::Null => {
                     return Err(proto_error(
@@ -610,10 +609,10 @@ impl TryInto<ScalarValue> for &protobuf::scalar_value::Value
                     ScalarValue::Date32(Some(*v))
                 }
                 protobuf::scalar_value::Value::TimeMicrosecondValue(v) => {
-                    ScalarValue::TimeMicrosecond(Some(*v))
+                    ScalarValue::TimestampMicrosecond(Some(*v))
                 }
                 protobuf::scalar_value::Value::TimeNanosecondValue(v) => {
-                    ScalarValue::TimeNanosecond(Some(*v))
+                    ScalarValue::TimestampNanosecond(Some(*v))
                 }
                 protobuf::scalar_value::Value::ListValue(v) => v.try_into()?,
                 protobuf::scalar_value::Value::NullListValue(v) => {
@@ -776,10 +775,10 @@ impl TryInto<ScalarValue> for protobuf::PrimitiveScalarType
             protobuf::PrimitiveScalarType::LargeUtf8 => ScalarValue::LargeUtf8(None),
             protobuf::PrimitiveScalarType::Date32 => ScalarValue::Date32(None),
             protobuf::PrimitiveScalarType::TimeMicrosecond => {
-                ScalarValue::TimeMicrosecond(None)
+                ScalarValue::TimestampMicrosecond(None)
             }
             protobuf::PrimitiveScalarType::TimeNanosecond => {
-                ScalarValue::TimeNanosecond(None)
+                ScalarValue::TimestampNanosecond(None)
            }
        })
    }
 }
@@ -829,10 +828,10 @@ impl TryInto<ScalarValue> for &protobuf::ScalarValue {
                    ScalarValue::Date32(Some(*v))
                }
                protobuf::scalar_value::Value::TimeMicrosecondValue(v) => {
-                    ScalarValue::TimeMicrosecond(Some(*v))
+                    ScalarValue::TimestampMicrosecond(Some(*v))
                }
                protobuf::scalar_value::Value::TimeNanosecondValue(v) => {
-                    ScalarValue::TimeNanosecond(Some(*v))
+                    ScalarValue::TimestampNanosecond(Some(*v))
                }
                protobuf::scalar_value::Value::ListValue(scalar_list) => {
                    let protobuf::ScalarListValue {
@@ -962,6 +961,15 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
                 let data_type = arrow_type.try_into()?;
                 Ok(Expr::Cast { expr, data_type })
             }
+            ExprType::TryCast(cast) => {
+                let expr = Box::new(parse_required_expr(&cast.expr)?);
+                let arrow_type: &protobuf::ArrowType = cast
+                    .arrow_type
+                    .as_ref()
+                    .ok_or_else(|| proto_error("Protobuf deserialization error: CastNode message missing required field 'arrow_type'"))?;
+                let data_type = arrow_type.try_into()?;
+                Ok(Expr::TryCast { expr, data_type })
+            }
             ExprType::Sort(sort) => Ok(Expr::Sort {
                 expr: Box::new(parse_required_expr(&sort.expr)?),
                 asc: sort.asc,
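`TryCastNode` has the same wire shape as `CastNode`; the difference is the expression it deserializes into. `Expr::TryCast` yields NULL for values that fail to convert, where `Expr::Cast` returns an error. A small sketch of the logical expression the new node carries (the column name is illustrative):

```rust
use arrow::datatypes::DataType;
use datafusion::logical_plan::{col, Expr};

// The SQL `TRY_CAST(c1 AS INT)` becomes this logical expression; rows whose
// `c1` cannot be converted yield NULL instead of aborting the query.
fn try_cast_c1_to_int() -> Expr {
    Expr::TryCast {
        expr: Box::new(col("c1")),
        data_type: DataType::Int32,
    }
}
```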
diff --git a/rust/ballista/rust/core/src/serde/logical_plan/mod.rs b/rust/ballista/rust/core/src/serde/logical_plan/mod.rs
index 50a529b6fa1f6..48dd96c4d3f31 100644
--- a/rust/ballista/rust/core/src/serde/logical_plan/mod.rs
+++ b/rust/ballista/rust/core/src/serde/logical_plan/mod.rs
@@ -82,7 +82,7 @@ mod roundtrip_tests {
                 CsvReadOptions::new().schema(&schema).has_header(true),
                 Some(vec![3, 4]),
             )
-            .and_then(|plan| plan.sort(&[col("salary")]))
+            .and_then(|plan| plan.sort(vec![col("salary")]))
             .and_then(|plan| plan.build())
             .map_err(BallistaError::DataFusionError)?,
         );
@@ -212,8 +212,8 @@ mod roundtrip_tests {
             ScalarValue::LargeUtf8(None),
             ScalarValue::List(None, DataType::Boolean),
             ScalarValue::Date32(None),
-            ScalarValue::TimeMicrosecond(None),
-            ScalarValue::TimeNanosecond(None),
+            ScalarValue::TimestampMicrosecond(None),
+            ScalarValue::TimestampNanosecond(None),
             ScalarValue::Boolean(Some(true)),
             ScalarValue::Boolean(Some(false)),
             ScalarValue::Float32(Some(1.0)),
@@ -252,11 +252,11 @@ mod roundtrip_tests {
             ScalarValue::LargeUtf8(Some(String::from("Test Large utf8"))),
             ScalarValue::Date32(Some(0)),
             ScalarValue::Date32(Some(i32::MAX)),
-            ScalarValue::TimeNanosecond(Some(0)),
-            ScalarValue::TimeNanosecond(Some(i64::MAX)),
-            ScalarValue::TimeMicrosecond(Some(0)),
-            ScalarValue::TimeMicrosecond(Some(i64::MAX)),
-            ScalarValue::TimeMicrosecond(None),
+            ScalarValue::TimestampNanosecond(Some(0)),
+            ScalarValue::TimestampNanosecond(Some(i64::MAX)),
+            ScalarValue::TimestampMicrosecond(Some(0)),
+            ScalarValue::TimestampMicrosecond(Some(i64::MAX)),
+            ScalarValue::TimestampMicrosecond(None),
             ScalarValue::List(
                 Some(vec![
                     ScalarValue::Float32(Some(-213.1)),
@@ -610,8 +610,8 @@ mod roundtrip_tests {
             ScalarValue::Utf8(None),
             ScalarValue::LargeUtf8(None),
             ScalarValue::Date32(None),
-            ScalarValue::TimeMicrosecond(None),
-            ScalarValue::TimeNanosecond(None),
+            ScalarValue::TimestampMicrosecond(None),
+            ScalarValue::TimestampNanosecond(None),
             //ScalarValue::List(None, DataType::Boolean)
         ];
 
@@ -679,7 +679,7 @@ mod roundtrip_tests {
             CsvReadOptions::new().schema(&schema).has_header(true),
             Some(vec![3, 4]),
         )
-        .and_then(|plan| plan.sort(&[col("salary")]))
+        .and_then(|plan| plan.sort(vec![col("salary")]))
         .and_then(|plan| plan.explain(true))
         .and_then(|plan| plan.build())
         .map_err(BallistaError::DataFusionError)?;
@@ -689,7 +689,7 @@ mod roundtrip_tests {
             CsvReadOptions::new().schema(&schema).has_header(true),
             Some(vec![3, 4]),
         )
-        .and_then(|plan| plan.sort(&[col("salary")]))
+        .and_then(|plan| plan.sort(vec![col("salary")]))
         .and_then(|plan| plan.explain(false))
         .and_then(|plan| plan.build())
         .map_err(BallistaError::DataFusionError)?;
@@ -742,7 +742,7 @@ mod roundtrip_tests {
             CsvReadOptions::new().schema(&schema).has_header(true),
             Some(vec![3, 4]),
         )
-        .and_then(|plan| plan.sort(&[col("salary")]))
+        .and_then(|plan| plan.sort(vec![col("salary")]))
         .and_then(|plan| plan.build())
         .map_err(BallistaError::DataFusionError)?;
         roundtrip_test!(plan);
@@ -784,7 +784,7 @@ mod roundtrip_tests {
             CsvReadOptions::new().schema(&schema).has_header(true),
             Some(vec![3, 4]),
         )
-        .and_then(|plan| plan.aggregate(&[col("state")], &[max(col("salary"))]))
+        .and_then(|plan| plan.aggregate(vec![col("state")], vec![max(col("salary"))]))
        .and_then(|plan| plan.build())
        .map_err(BallistaError::DataFusionError)?;
diff --git a/rust/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/rust/ballista/rust/core/src/serde/logical_plan/to_proto.rs
index 69b53502fc9b9..a181f98b6eb6c 100644
--- a/rust/ballista/rust/core/src/serde/logical_plan/to_proto.rs
+++ b/rust/ballista/rust/core/src/serde/logical_plan/to_proto.rs
@@ -641,12 +641,12 @@ impl TryFrom<&datafusion::scalar::ScalarValue> for protobuf::ScalarValue {
             datafusion::scalar::ScalarValue::Date32(val) => {
                 create_proto_scalar(val, PrimitiveScalarType::Date32, |s| Value::Date32Value(*s))
             }
-            datafusion::scalar::ScalarValue::TimeMicrosecond(val) => {
+            datafusion::scalar::ScalarValue::TimestampMicrosecond(val) => {
                 create_proto_scalar(val, PrimitiveScalarType::TimeMicrosecond, |s| {
                     Value::TimeMicrosecondValue(*s)
                 })
             }
-            datafusion::scalar::ScalarValue::TimeNanosecond(val) => {
+            datafusion::scalar::ScalarValue::TimestampNanosecond(val) => {
                 create_proto_scalar(val, PrimitiveScalarType::TimeNanosecond, |s| {
                     Value::TimeNanosecondValue(*s)
                 })
             }
@@ -939,10 +939,7 @@ impl TryInto<protobuf::LogicalPlanNode> for &LogicalPlan {
                 })
             }
             LogicalPlan::Extension { .. } => unimplemented!(),
-            // _ => Err(BallistaError::General(format!(
-            //     "logical plan to_proto {:?}",
-            //     self
-            // ))),
+            LogicalPlan::Union { .. } => unimplemented!(),
         }
     }
 }
@@ -1161,10 +1158,7 @@ impl TryInto<protobuf::LogicalExprNode> for &Expr {
             Expr::Wildcard => Ok(protobuf::LogicalExprNode {
                 expr_type: Some(protobuf::logical_expr_node::ExprType::Wildcard(true)),
             }),
-            // _ => Err(BallistaError::General(format!(
-            //     "logical expr to_proto {:?}",
-            //     self
-            // ))),
+            Expr::TryCast { .. } => unimplemented!(),
         }
     }
 }
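Swapping the commented-out `_ =>` catch-alls for explicit `unimplemented!()` arms is a deliberate trade: with every variant listed, the next `LogicalPlan` or `Expr` variant added upstream (such as `Union` here) fails to compile in this serde code instead of silently falling through. A toy illustration of the pattern, not the real types:

```rust
enum Plan {
    Projection,
    Extension,
    Union,
}

fn plan_to_proto(plan: &Plan) -> u8 {
    // No `_ =>` arm: if Plan grows a variant, this match stops compiling,
    // which is exactly the reminder the serde code wants.
    match plan {
        Plan::Projection => 1,
        Plan::Extension => unimplemented!(),
        Plan::Union => unimplemented!(),
    }
}
```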
diff --git a/rust/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/rust/ballista/rust/core/src/serde/physical_plan/from_proto.rs
index cb04a3e819688..be0777dbb9a8f 100644
--- a/rust/ballista/rust/core/src/serde/physical_plan/from_proto.rs
+++ b/rust/ballista/rust/core/src/serde/physical_plan/from_proto.rs
@@ -30,11 +30,15 @@ use crate::serde::{proto_error, protobuf};
 use crate::{convert_box_required, convert_required};
 
 use arrow::datatypes::{DataType, Schema, SchemaRef};
+use datafusion::catalog::catalog::{
+    CatalogList, CatalogProvider, MemoryCatalogList, MemoryCatalogProvider,
+};
 use datafusion::execution::context::{ExecutionConfig, ExecutionContextState};
 use datafusion::logical_plan::{DFSchema, Expr};
 use datafusion::physical_plan::aggregates::{create_aggregate_expr, AggregateFunction};
 use datafusion::physical_plan::expressions::col;
 use datafusion::physical_plan::hash_aggregate::{AggregateMode, HashAggregateExec};
+use datafusion::physical_plan::hash_join::PartitionMode;
 use datafusion::physical_plan::merge::MergeExec;
 use datafusion::physical_plan::planner::DefaultPhysicalPlanner;
 use datafusion::physical_plan::{
@@ -102,15 +106,13 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
                     .file_extension(&scan.file_extension)
                     .delimiter(scan.delimiter.as_bytes()[0])
                     .schema(&schema);
-                // TODO we don't care what the DataFusion batch size was because Ballista will
-                // have its own configs. Hard-code for now.
-                let batch_size = 32768;
                 let projection = scan.projection.iter().map(|i| *i as usize).collect();
                 Ok(Arc::new(CsvExec::try_new(
                     &scan.path,
                     options,
                     Some(projection),
-                    batch_size,
+                    scan.batch_size as usize,
+                    None,
                 )?))
             }
             PhysicalPlanType::ParquetScan(scan) => {
@@ -123,6 +125,7 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
                     None,
                     scan.batch_size as usize,
                     scan.num_partitions as usize,
+                    None,
                 )?))
             }
             PhysicalPlanType::CoalesceBatches(coalesce_batches) => {
@@ -215,8 +218,10 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
                     .collect::<Result<Vec<_>, _>>()?;
 
                 let df_planner = DefaultPhysicalPlanner::default();
+                let catalog_list =
+                    Arc::new(MemoryCatalogList::new()) as Arc<dyn CatalogList>;
                 let ctx_state = ExecutionContextState {
-                    datasources: Default::default(),
+                    catalog_list,
                     scalar_functions: Default::default(),
                     var_provider: Default::default(),
                     aggregate_functions: Default::default(),
@@ -294,7 +299,11 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
                     protobuf::JoinType::Right => JoinType::Right,
                 };
                 Ok(Arc::new(HashJoinExec::try_new(
-                    left, right, &on, &join_type,
+                    left,
+                    right,
+                    &on,
+                    &join_type,
+                    PartitionMode::CollectLeft,
                 )?))
             }
             PhysicalPlanType::ShuffleReader(shuffle_reader) => {
@@ -374,8 +383,9 @@ fn compile_expr(
     schema: &Schema,
 ) -> Result<Arc<dyn PhysicalExpr>, BallistaError> {
     let df_planner = DefaultPhysicalPlanner::default();
+    let catalog_list = Arc::new(MemoryCatalogList::new()) as Arc<dyn CatalogList>;
     let state = ExecutionContextState {
-        datasources: HashMap::new(),
+        catalog_list,
         scalar_functions: HashMap::new(),
         var_provider: HashMap::new(),
         aggregate_functions: HashMap::new(),
diff --git a/rust/ballista/rust/core/src/serde/physical_plan/mod.rs b/rust/ballista/rust/core/src/serde/physical_plan/mod.rs
index a6f146c73841b..e7985cc84a9a7 100644
--- a/rust/ballista/rust/core/src/serde/physical_plan/mod.rs
+++ b/rust/ballista/rust/core/src/serde/physical_plan/mod.rs
@@ -40,6 +40,7 @@ mod roundtrip_tests {
 
     use super::super::super::error::Result;
     use super::super::protobuf;
+    use datafusion::physical_plan::hash_join::PartitionMode;
 
     fn roundtrip_test(exec_plan: Arc<dyn ExecutionPlan>) -> Result<()> {
         let proto: protobuf::PhysicalPlanNode = exec_plan.clone().try_into()?;
@@ -84,6 +85,7 @@ mod roundtrip_tests {
             Arc::new(EmptyExec::new(false, Arc::new(schema_right))),
             &[("col".to_string(), "col".to_string())],
             &JoinType::Inner,
+            PartitionMode::CollectLeft,
         )?))
     }
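`HashJoinExec::try_new` now requires a `PartitionMode`. `CollectLeft` buffers the whole left input into a single hash table that every right-side partition probes; the alternative, `Partitioned`, expects both inputs hash-partitioned on the join keys, which would need the repartitioning support Ballista does not have yet (see the optimizer-rule changes in the scheduler below). A sketch of the call, assuming `JoinType` lives in `logical_plan` at this rev, with `left`/`right` standing in for real plans:

```rust
use std::sync::Arc;

use datafusion::error::Result;
use datafusion::logical_plan::JoinType;
use datafusion::physical_plan::hash_join::{HashJoinExec, PartitionMode};
use datafusion::physical_plan::ExecutionPlan;

fn collect_left_join(
    left: Arc<dyn ExecutionPlan>,
    right: Arc<dyn ExecutionPlan>,
) -> Result<HashJoinExec> {
    // equi-join on matching column names from each side
    let on = vec![("id".to_string(), "id".to_string())];
    // CollectLeft: materialize the (ideally small) left side once and probe
    // it from every right-hand partition; no repartitioning required
    HashJoinExec::try_new(left, right, &on, &JoinType::Inner, PartitionMode::CollectLeft)
}
```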
diff --git a/rust/ballista/rust/core/src/serde/physical_plan/to_proto.rs b/rust/ballista/rust/core/src/serde/physical_plan/to_proto.rs
index 24c69c4692a2d..5352c1f777530 100644
--- a/rust/ballista/rust/core/src/serde/physical_plan/to_proto.rs
+++ b/rust/ballista/rust/core/src/serde/physical_plan/to_proto.rs
@@ -28,10 +28,10 @@ use std::{
 
 use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
 use datafusion::physical_plan::csv::CsvExec;
-use datafusion::physical_plan::expressions::CastExpr;
 use datafusion::physical_plan::expressions::{
     CaseExpr, InListExpr, IsNotNullExpr, IsNullExpr, NegativeExpr, NotExpr,
 };
+use datafusion::physical_plan::expressions::{CastExpr, TryCastExpr};
 use datafusion::physical_plan::filter::FilterExec;
 use datafusion::physical_plan::hash_aggregate::AggregateMode;
 use datafusion::physical_plan::hash_join::HashJoinExec;
@@ -236,7 +236,7 @@ impl TryInto<protobuf::PhysicalPlanNode> for Arc<dyn ExecutionPlan> {
                         schema: Some(exec.file_schema().as_ref().into()),
                         has_header: exec.has_header(),
                         delimiter: delimiter.to_string(),
-                        batch_size: 32768,
+                        batch_size: exec.batch_size() as u32,
                     },
                 )),
             })
@@ -510,6 +510,15 @@ impl TryFrom<Arc<dyn PhysicalExpr>> for protobuf::LogicalExprNode {
                 },
             ))),
         })
+    } else if let Some(cast) = expr.downcast_ref::<TryCastExpr>() {
+        Ok(protobuf::LogicalExprNode {
+            expr_type: Some(protobuf::logical_expr_node::ExprType::TryCast(
+                Box::new(protobuf::TryCastNode {
+                    expr: Some(Box::new(cast.expr().clone().try_into()?)),
+                    arrow_type: Some(cast.cast_type().into()),
+                }),
+            )),
+        })
     } else if let Some(expr) = expr.downcast_ref::<ScalarFunctionExpr>() {
         let fun: BuiltinScalarFunction =
             BuiltinScalarFunction::from_str(expr.name())?;
diff --git a/rust/ballista/rust/executor/Cargo.toml b/rust/ballista/rust/executor/Cargo.toml
index 743b62cc1001d..beed860fd9470 100644
--- a/rust/ballista/rust/executor/Cargo.toml
+++ b/rust/ballista/rust/executor/Cargo.toml
@@ -45,9 +45,13 @@ tokio-stream = "0.1"
 tonic = "0.4"
 uuid = { version = "0.8", features = ["v4"] }
 
-arrow = { git = "https://github.com/apache/arrow", rev="46161d2" }
-arrow-flight = { git = "https://github.com/apache/arrow", rev="46161d2" }
-datafusion = { git = "https://github.com/apache/arrow", rev="46161d2" }
+#arrow = { path = "../../../arrow" }
+#arrow-flight = { path = "../../../arrow-flight" }
+#datafusion = { path = "../../../datafusion" }
+
+arrow = { git = "https://github.com/apache/arrow", rev="fe83dca" }
+arrow-flight = { git = "https://github.com/apache/arrow", rev="fe83dca" }
+datafusion = { git = "https://github.com/apache/arrow", rev="fe83dca" }
 
 [dev-dependencies]
diff --git a/rust/ballista/rust/scheduler/Cargo.toml b/rust/ballista/rust/scheduler/Cargo.toml
index b0213d37bda13..57342dd633ec7 100644
--- a/rust/ballista/rust/scheduler/Cargo.toml
+++ b/rust/ballista/rust/scheduler/Cargo.toml
@@ -52,8 +52,11 @@ tonic = "0.4"
 tower = { version = "0.4" }
 warp = "0.3"
 
-arrow = { git = "https://github.com/apache/arrow", rev="46161d2" }
-datafusion = { git = "https://github.com/apache/arrow", rev="46161d2" }
+#arrow = { path = "../../../arrow" }
+#datafusion = { path = "../../../datafusion" }
+
+arrow = { git = "https://github.com/apache/arrow", rev="fe83dca" }
+datafusion = { git = "https://github.com/apache/arrow", rev="fe83dca" }
 
 [dev-dependencies]
 ballista-core = { path = "../core" }
diff --git a/rust/ballista/rust/scheduler/src/api/mod.rs b/rust/ballista/rust/scheduler/src/api/mod.rs
index 29c5cb1af6711..9e14378564acd 100644
--- a/rust/ballista/rust/scheduler/src/api/mod.rs
+++ b/rust/ballista/rust/scheduler/src/api/mod.rs
@@ -30,11 +30,11 @@ pub type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
 pub type HttpBody = dyn http_body::Body + 'static;
 
 impl<A, B> http_body::Body for EitherBody<A, B>
-    where
-        A: http_body::Body + Send + Unpin,
-        B: http_body::Body<Data = A::Data> + Send + Unpin,
-        A::Error: Into<Error>,
-        B::Error: Into<Error>,
+where
+    A: http_body::Body + Send + Unpin,
+    B: http_body::Body<Data = A::Data> + Send + Unpin,
+    A::Error: Into<Error>,
+    B::Error: Into<Error>,
 {
     type Data = A::Data;
     type Error = Error;
@@ -67,7 +67,9 @@ impl<A, B> http_body::Body for EitherBody<A, B>
     }
 }
 
-fn map_option_err<T, U: Into<Error>>(err: Option<Result<T, U>>) -> Option<Result<T, Error>> {
+fn map_option_err<T, U: Into<Error>>(
+    err: Option<Result<T, U>>,
+) -> Option<Result<T, Error>> {
     err.map(|e| e.map_err(Into::into))
 }
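`map_option_err` is the glue that lets the two halves of `EitherBody` share one error type: whichever concrete body error a branch produces is boxed into the common `Error` alias via `Into`. The same pattern in isolation, with a standard-library error standing in for the HTTP body errors:

```rust
use std::error::Error as StdError;

type Error = Box<dyn StdError + Send + Sync + 'static>;

fn map_option_err<T, U: Into<Error>>(
    err: Option<Result<T, U>>,
) -> Option<Result<T, Error>> {
    // preserve the Option/Result shape; type-erase only the error half
    err.map(|e| e.map_err(Into::into))
}

fn main() {
    let io_err: Option<Result<u8, std::io::Error>> =
        Some(Err(std::io::Error::new(std::io::ErrorKind::Other, "boom")));
    // io::Error is boxed into the shared trait-object Error
    assert!(map_option_err(io_err).unwrap().is_err());
}
```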
diff --git a/rust/ballista/rust/scheduler/src/lib.rs b/rust/ballista/rust/scheduler/src/lib.rs
index 6df6c9ac57cdc..1bd4722e5cb5f 100644
--- a/rust/ballista/rust/scheduler/src/lib.rs
+++ b/rust/ballista/rust/scheduler/src/lib.rs
@@ -201,12 +201,13 @@ impl SchedulerGrpc for SchedulerServer {
         match file_type {
             FileType::Parquet => {
-                let parquet_exec = ParquetExec::try_from_path(&path, None, None, 1024, 1)
-                    .map_err(|e| {
-                        let msg = format!("Error opening parquet files: {}", e);
-                        error!("{}", msg);
-                        tonic::Status::internal(msg)
-                    })?;
+                let parquet_exec =
+                    ParquetExec::try_from_path(&path, None, None, 1024, 1, None)
+                        .map_err(|e| {
+                            let msg = format!("Error opening parquet files: {}", e);
+                            error!("{}", msg);
+                            tonic::Status::internal(msg)
+                        })?;
 
                 //TODO include statistics and any other info needed to reconstruct ParquetExec
                 Ok(Response::new(GetFileMetadataResult {
diff --git a/rust/ballista/rust/scheduler/src/main.rs b/rust/ballista/rust/scheduler/src/main.rs
index c166fdc388d58..6f746292f659e 100644
--- a/rust/ballista/rust/scheduler/src/main.rs
+++ b/rust/ballista/rust/scheduler/src/main.rs
@@ -29,12 +29,12 @@ use ballista_core::BALLISTA_VERSION;
 use ballista_core::{
     print_version, serde::protobuf::scheduler_grpc_server::SchedulerGrpcServer,
 };
+use ballista_scheduler::api::{get_routes, EitherBody, Error};
 #[cfg(feature = "etcd")]
 use ballista_scheduler::state::EtcdClient;
 #[cfg(feature = "sled")]
 use ballista_scheduler::state::StandaloneClient;
 use ballista_scheduler::{state::ConfigBackendClient, ConfigBackend, SchedulerServer};
-use ballista_scheduler::api::{get_routes, EitherBody, Error};
 
 use log::info;
 
@@ -63,8 +63,10 @@ async fn start_server(
     );
     Ok(Server::bind(&addr)
         .serve(make_service_fn(move |_| {
-            let scheduler_server = SchedulerServer::new(config_backend.clone(), namespace.clone());
-            let scheduler_grpc_server = SchedulerGrpcServer::new(scheduler_server.clone());
+            let scheduler_server =
+                SchedulerServer::new(config_backend.clone(), namespace.clone());
+            let scheduler_grpc_server =
+                SchedulerGrpcServer::new(scheduler_server.clone());
 
             let mut tonic = TonicServer::builder()
                 .add_service(scheduler_grpc_server)
diff --git a/rust/ballista/rust/scheduler/src/planner.rs b/rust/ballista/rust/scheduler/src/planner.rs
index f06dcfdfcec85..e9f668a7d5f84 100644
--- a/rust/ballista/rust/scheduler/src/planner.rs
+++ b/rust/ballista/rust/scheduler/src/planner.rs
@@ -34,7 +34,10 @@ use ballista_core::{
     execution_plans::{QueryStageExec, ShuffleReaderExec, UnresolvedShuffleExec},
     serde::scheduler::PartitionLocation,
 };
-use datafusion::execution::context::ExecutionContext;
+use datafusion::execution::context::{ExecutionConfig, ExecutionContext};
+use datafusion::physical_optimizer::coalesce_batches::CoalesceBatches;
+use datafusion::physical_optimizer::merge_exec::AddMergeExec;
+use datafusion::physical_optimizer::optimizer::PhysicalOptimizerRule;
 use datafusion::physical_plan::hash_aggregate::{AggregateMode, HashAggregateExec};
 use datafusion::physical_plan::hash_join::HashJoinExec;
 use datafusion::physical_plan::merge::MergeExec;
@@ -136,7 +139,13 @@ impl DistributedPlanner {
         }
 
         if let Some(adapter) = execution_plan.as_any().downcast_ref::<DFTableAdapter>() {
-            let ctx = ExecutionContext::new();
+            // remove Repartition rule because that isn't supported yet
+            let rules: Vec<Arc<dyn PhysicalOptimizerRule + Send + Sync>> = vec![
+                Arc::new(CoalesceBatches::new()),
+                Arc::new(AddMergeExec::new()),
+            ];
+            let config = ExecutionConfig::new().with_physical_optimizer_rules(rules);
+            let ctx = ExecutionContext::with_config(config);
             Ok((ctx.create_physical_plan(&adapter.logical_plan)?, stages))
         } else if let Some(merge) = execution_plan.as_any().downcast_ref::<MergeExec>() {
             let query_stage = create_query_stage(
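The same trimmed optimizer list appears next in `test_utils.rs`: per the in-diff comment, the default rule set at this datafusion revision includes a `Repartition` rule, and Ballista cannot yet serialize or schedule the `RepartitionExec` nodes it introduces, so both call sites opt out of it. Condensed into a helper (a sketch, assuming the `ExecutionConfig` API at the pinned rev):

```rust
use std::sync::Arc;

use datafusion::execution::context::{ExecutionConfig, ExecutionContext};
use datafusion::physical_optimizer::coalesce_batches::CoalesceBatches;
use datafusion::physical_optimizer::merge_exec::AddMergeExec;
use datafusion::physical_optimizer::optimizer::PhysicalOptimizerRule;

fn ballista_compatible_context() -> ExecutionContext {
    // every default physical optimization except Repartition
    let rules: Vec<Arc<dyn PhysicalOptimizerRule + Send + Sync>> = vec![
        Arc::new(CoalesceBatches::new()),
        Arc::new(AddMergeExec::new()),
    ];
    let config = ExecutionConfig::new().with_physical_optimizer_rules(rules);
    ExecutionContext::with_config(config)
}
```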
diff --git a/rust/ballista/rust/scheduler/src/test_utils.rs b/rust/ballista/rust/scheduler/src/test_utils.rs
index 9439740477777..330cc9a9332cb 100644
--- a/rust/ballista/rust/scheduler/src/test_utils.rs
+++ b/rust/ballista/rust/scheduler/src/test_utils.rs
@@ -15,10 +15,15 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use std::sync::Arc;
+
 use ballista_core::error::Result;
 
 use arrow::datatypes::{DataType, Field, Schema};
-use datafusion::execution::context::ExecutionContext;
+use datafusion::execution::context::{ExecutionConfig, ExecutionContext};
+use datafusion::physical_optimizer::coalesce_batches::CoalesceBatches;
+use datafusion::physical_optimizer::merge_exec::AddMergeExec;
+use datafusion::physical_optimizer::optimizer::PhysicalOptimizerRule;
 use datafusion::physical_plan::csv::CsvReadOptions;
 
 pub const TPCH_TABLES: &[&str] = &[
@@ -26,7 +31,14 @@ pub const TPCH_TABLES: &[&str] = &[
 ];
 
 pub fn datafusion_test_context(path: &str) -> Result<ExecutionContext> {
-    let mut ctx = ExecutionContext::new();
+    // remove Repartition rule because that isn't supported yet
+    let rules: Vec<Arc<dyn PhysicalOptimizerRule + Send + Sync>> = vec![
+        Arc::new(CoalesceBatches::new()),
+        Arc::new(AddMergeExec::new()),
+    ];
+    let config = ExecutionConfig::new().with_physical_optimizer_rules(rules);
+    let mut ctx = ExecutionContext::with_config(config);
+
     for table in TPCH_TABLES {
         let schema = get_tpch_schema(table);
         let options = CsvReadOptions::new()