Skip to content

Commit

Permalink
ARROW-12335: [Rust] [Ballista] Use latest DataFusion
Browse files Browse the repository at this point in the history
Updates Ballista to use the most recent DataFusion version.

Changes made:

- Ballista overrides physical optimizer rules to remove `Repartition`
- Added serde support for new `TryCast` expression
- Updated DataFrame API usage to use `Vec<_>` instead of `&[_]`
- Renamed some timestamp scalar variants
- Updated `HashJoinExec` to take the new `CollectLeft` argument
- Removed hard-coded batch size from serde code for `CsvScanExec`

Closes apache#9991 from andygrove/ballista-bump-df-version

Authored-by: Andy Grove <andygrove73@gmail.com>
Signed-off-by: Krisztián Szűcs <szucs.krisztian@gmail.com>
  • Loading branch information
andygrove authored and michalursa committed Jun 10, 2021
1 parent 73f92ce commit 83c8efe
Show file tree
Hide file tree
Showing 21 changed files with 189 additions and 92 deletions.
18 changes: 18 additions & 0 deletions rust/ballista/.dockerignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

rust/**/target
6 changes: 3 additions & 3 deletions rust/ballista/rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,6 @@ members = [
"scheduler",
]

[profile.release]
lto = true
codegen-units = 1
#[profile.release]
#lto = true
#codegen-units = 1
10 changes: 7 additions & 3 deletions rust/ballista/rust/benchmarks/tpch/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,13 @@ edition = "2018"
[dependencies]
ballista = { path="../../client" }

arrow = { git = "https://github.com/apache/arrow", rev="46161d2" }
datafusion = { git = "https://github.com/apache/arrow", rev="46161d2" }
parquet = { git = "https://github.com/apache/arrow", rev="46161d2" }
#arrow = { path = "../../../../arrow" }
#datafusion = { path = "../../../../datafusion" }
#parquet = { path = "../../../../parquet" }

arrow = { git = "https://github.com/apache/arrow", rev="fe83dca" }
datafusion = { git = "https://github.com/apache/arrow", rev="fe83dca" }
parquet = { git = "https://github.com/apache/arrow", rev="fe83dca" }


env_logger = "0.8"
Expand Down
8 changes: 6 additions & 2 deletions rust/ballista/rust/client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,9 @@ ballista-core = { path = "../core" }
futures = "0.3"
log = "0.4"
tokio = "1.0"
arrow = { git = "https://github.com/apache/arrow", rev="46161d2" }
datafusion = { git = "https://github.com/apache/arrow", rev="46161d2" }

#arrow = { path = "../../../arrow" }
#datafusion = { path = "../../../datafusion" }

arrow = { git = "https://github.com/apache/arrow", rev="fe83dca" }
datafusion = { git = "https://github.com/apache/arrow", rev="fe83dca" }
14 changes: 9 additions & 5 deletions rust/ballista/rust/client/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ use ballista_core::{
};

use arrow::datatypes::Schema;
use datafusion::catalog::TableReference;
use datafusion::execution::context::ExecutionContext;
use datafusion::logical_plan::{DFSchema, Expr, LogicalPlan, Partitioning};
use datafusion::physical_plan::csv::CsvReadOptions;
Expand Down Expand Up @@ -148,7 +149,10 @@ impl BallistaContext {
for (name, plan) in &state.tables {
let plan = ctx.optimize(plan)?;
let execution_plan = ctx.create_physical_plan(&plan)?;
ctx.register_table(name, Arc::new(DFTableAdapter::new(plan, execution_plan)));
ctx.register_table(
TableReference::Bare { table: name },
Arc::new(DFTableAdapter::new(plan, execution_plan)),
)?;
}
let df = ctx.sql(sql)?;
Ok(BallistaDataFrame::from(self.state.clone(), df))
Expand Down Expand Up @@ -267,7 +271,7 @@ impl BallistaDataFrame {
))
}

pub fn select(&self, expr: &[Expr]) -> Result<BallistaDataFrame> {
pub fn select(&self, expr: Vec<Expr>) -> Result<BallistaDataFrame> {
Ok(Self::from(
self.state.clone(),
self.df.select(expr).map_err(BallistaError::from)?,
Expand All @@ -283,8 +287,8 @@ impl BallistaDataFrame {

pub fn aggregate(
&self,
group_expr: &[Expr],
aggr_expr: &[Expr],
group_expr: Vec<Expr>,
aggr_expr: Vec<Expr>,
) -> Result<BallistaDataFrame> {
Ok(Self::from(
self.state.clone(),
Expand All @@ -301,7 +305,7 @@ impl BallistaDataFrame {
))
}

pub fn sort(&self, expr: &[Expr]) -> Result<BallistaDataFrame> {
pub fn sort(&self, expr: Vec<Expr>) -> Result<BallistaDataFrame> {
Ok(Self::from(
self.state.clone(),
self.df.sort(expr).map_err(BallistaError::from)?,
Expand Down
10 changes: 7 additions & 3 deletions rust/ballista/rust/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,14 @@ sqlparser = "0.8"
tokio = "1.0"
tonic = "0.4"
uuid = { version = "0.8", features = ["v4"] }
arrow = { git = "https://github.com/apache/arrow", rev="46161d2" }
arrow-flight = { git = "https://github.com/apache/arrow", rev="46161d2" }
datafusion = { git = "https://github.com/apache/arrow", rev="46161d2" }

#arrow = { path = "../../../arrow" }
#arrow-flight = { path = "../../../arrow-flight" }
#datafusion = { path = "../../../datafusion" }

arrow = { git = "https://github.com/apache/arrow", rev="fe83dca" }
arrow-flight = { git = "https://github.com/apache/arrow", rev="fe83dca" }
datafusion = { git = "https://github.com/apache/arrow", rev="fe83dca" }

[dev-dependencies]

Expand Down
6 changes: 6 additions & 0 deletions rust/ballista/rust/core/proto/ballista.proto
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ message LogicalExprNode {
InListNode in_list = 14;
bool wildcard = 15;
ScalarFunctionNode scalar_function = 16;
TryCastNode try_cast = 17;
}
}

Expand Down Expand Up @@ -172,6 +173,11 @@ message CastNode {
ArrowType arrow_type = 2;
}

message TryCastNode {
LogicalExprNode expr = 1;
ArrowType arrow_type = 2;
}

message SortExprNode {
LogicalExprNode expr = 1;
bool asc = 2;
Expand Down
1 change: 1 addition & 0 deletions rust/ballista/rust/core/src/datasource.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ impl TableProvider for DFTableAdapter {
_projection: &Option<Vec<usize>>,
_batch_size: usize,
_filters: &[Expr],
_limit: Option<usize>,
) -> DFResult<Arc<dyn ExecutionPlan>> {
Ok(self.plan.clone())
}
Expand Down
46 changes: 27 additions & 19 deletions rust/ballista/rust/core/src/serde/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,13 @@ impl TryInto<LogicalPlan> for &protobuf::LogicalPlanNode {
match plan {
LogicalPlanType::Projection(projection) => {
let input: LogicalPlan = convert_box_required!(projection.input)?;
let x: Vec<Expr> = projection
.expr
.iter()
.map(|expr| expr.try_into())
.collect::<Result<Vec<_>, _>>()?;
LogicalPlanBuilder::from(&input)
.project(
&projection
.expr
.iter()
.map(|expr| expr.try_into())
.collect::<Result<Vec<_>, _>>()?,
)?
.project(x)?
.build()
.map_err(|e| e.into())
}
Expand Down Expand Up @@ -89,7 +88,7 @@ impl TryInto<LogicalPlan> for &protobuf::LogicalPlanNode {
.map(|expr| expr.try_into())
.collect::<Result<Vec<_>, _>>()?;
LogicalPlanBuilder::from(&input)
.aggregate(&group_expr, &aggr_expr)?
.aggregate(group_expr, aggr_expr)?
.build()
.map_err(|e| e.into())
}
Expand Down Expand Up @@ -148,7 +147,7 @@ impl TryInto<LogicalPlan> for &protobuf::LogicalPlanNode {
.map(|expr| expr.try_into())
.collect::<Result<Vec<Expr>, _>>()?;
LogicalPlanBuilder::from(&input)
.sort(&sort_expr)?
.sort(sort_expr)?
.build()
.map_err(|e| e.into())
}
Expand Down Expand Up @@ -511,10 +510,10 @@ fn typechecked_scalar_value_conversion(
ScalarValue::Date32(Some(*v))
}
(Value::TimeMicrosecondValue(v), PrimitiveScalarType::TimeMicrosecond) => {
ScalarValue::TimeMicrosecond(Some(*v))
ScalarValue::TimestampMicrosecond(Some(*v))
}
(Value::TimeNanosecondValue(v), PrimitiveScalarType::TimeMicrosecond) => {
ScalarValue::TimeNanosecond(Some(*v))
ScalarValue::TimestampNanosecond(Some(*v))
}
(Value::Utf8Value(v), PrimitiveScalarType::Utf8) => {
ScalarValue::Utf8(Some(v.to_owned()))
Expand Down Expand Up @@ -547,10 +546,10 @@ fn typechecked_scalar_value_conversion(
PrimitiveScalarType::LargeUtf8 => ScalarValue::LargeUtf8(None),
PrimitiveScalarType::Date32 => ScalarValue::Date32(None),
PrimitiveScalarType::TimeMicrosecond => {
ScalarValue::TimeMicrosecond(None)
ScalarValue::TimestampMicrosecond(None)
}
PrimitiveScalarType::TimeNanosecond => {
ScalarValue::TimeNanosecond(None)
ScalarValue::TimestampNanosecond(None)
}
PrimitiveScalarType::Null => {
return Err(proto_error(
Expand Down Expand Up @@ -610,10 +609,10 @@ impl TryInto<datafusion::scalar::ScalarValue> for &protobuf::scalar_value::Value
ScalarValue::Date32(Some(*v))
}
protobuf::scalar_value::Value::TimeMicrosecondValue(v) => {
ScalarValue::TimeMicrosecond(Some(*v))
ScalarValue::TimestampMicrosecond(Some(*v))
}
protobuf::scalar_value::Value::TimeNanosecondValue(v) => {
ScalarValue::TimeNanosecond(Some(*v))
ScalarValue::TimestampNanosecond(Some(*v))
}
protobuf::scalar_value::Value::ListValue(v) => v.try_into()?,
protobuf::scalar_value::Value::NullListValue(v) => {
Expand Down Expand Up @@ -776,10 +775,10 @@ impl TryInto<datafusion::scalar::ScalarValue> for protobuf::PrimitiveScalarType
protobuf::PrimitiveScalarType::LargeUtf8 => ScalarValue::LargeUtf8(None),
protobuf::PrimitiveScalarType::Date32 => ScalarValue::Date32(None),
protobuf::PrimitiveScalarType::TimeMicrosecond => {
ScalarValue::TimeMicrosecond(None)
ScalarValue::TimestampMicrosecond(None)
}
protobuf::PrimitiveScalarType::TimeNanosecond => {
ScalarValue::TimeNanosecond(None)
ScalarValue::TimestampNanosecond(None)
}
})
}
Expand Down Expand Up @@ -829,10 +828,10 @@ impl TryInto<datafusion::scalar::ScalarValue> for &protobuf::ScalarValue {
ScalarValue::Date32(Some(*v))
}
protobuf::scalar_value::Value::TimeMicrosecondValue(v) => {
ScalarValue::TimeMicrosecond(Some(*v))
ScalarValue::TimestampMicrosecond(Some(*v))
}
protobuf::scalar_value::Value::TimeNanosecondValue(v) => {
ScalarValue::TimeNanosecond(Some(*v))
ScalarValue::TimestampNanosecond(Some(*v))
}
protobuf::scalar_value::Value::ListValue(scalar_list) => {
let protobuf::ScalarListValue {
Expand Down Expand Up @@ -962,6 +961,15 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
let data_type = arrow_type.try_into()?;
Ok(Expr::Cast { expr, data_type })
}
ExprType::TryCast(cast) => {
let expr = Box::new(parse_required_expr(&cast.expr)?);
let arrow_type: &protobuf::ArrowType = cast
.arrow_type
.as_ref()
.ok_or_else(|| proto_error("Protobuf deserialization error: CastNode message missing required field 'arrow_type'"))?;
let data_type = arrow_type.try_into()?;
Ok(Expr::TryCast { expr, data_type })
}
ExprType::Sort(sort) => Ok(Expr::Sort {
expr: Box::new(parse_required_expr(&sort.expr)?),
asc: sort.asc,
Expand Down
28 changes: 14 additions & 14 deletions rust/ballista/rust/core/src/serde/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ mod roundtrip_tests {
CsvReadOptions::new().schema(&schema).has_header(true),
Some(vec![3, 4]),
)
.and_then(|plan| plan.sort(&[col("salary")]))
.and_then(|plan| plan.sort(vec![col("salary")]))
.and_then(|plan| plan.build())
.map_err(BallistaError::DataFusionError)?,
);
Expand Down Expand Up @@ -212,8 +212,8 @@ mod roundtrip_tests {
ScalarValue::LargeUtf8(None),
ScalarValue::List(None, DataType::Boolean),
ScalarValue::Date32(None),
ScalarValue::TimeMicrosecond(None),
ScalarValue::TimeNanosecond(None),
ScalarValue::TimestampMicrosecond(None),
ScalarValue::TimestampNanosecond(None),
ScalarValue::Boolean(Some(true)),
ScalarValue::Boolean(Some(false)),
ScalarValue::Float32(Some(1.0)),
Expand Down Expand Up @@ -252,11 +252,11 @@ mod roundtrip_tests {
ScalarValue::LargeUtf8(Some(String::from("Test Large utf8"))),
ScalarValue::Date32(Some(0)),
ScalarValue::Date32(Some(i32::MAX)),
ScalarValue::TimeNanosecond(Some(0)),
ScalarValue::TimeNanosecond(Some(i64::MAX)),
ScalarValue::TimeMicrosecond(Some(0)),
ScalarValue::TimeMicrosecond(Some(i64::MAX)),
ScalarValue::TimeMicrosecond(None),
ScalarValue::TimestampNanosecond(Some(0)),
ScalarValue::TimestampNanosecond(Some(i64::MAX)),
ScalarValue::TimestampMicrosecond(Some(0)),
ScalarValue::TimestampMicrosecond(Some(i64::MAX)),
ScalarValue::TimestampMicrosecond(None),
ScalarValue::List(
Some(vec![
ScalarValue::Float32(Some(-213.1)),
Expand Down Expand Up @@ -610,8 +610,8 @@ mod roundtrip_tests {
ScalarValue::Utf8(None),
ScalarValue::LargeUtf8(None),
ScalarValue::Date32(None),
ScalarValue::TimeMicrosecond(None),
ScalarValue::TimeNanosecond(None),
ScalarValue::TimestampMicrosecond(None),
ScalarValue::TimestampNanosecond(None),
//ScalarValue::List(None, DataType::Boolean)
];

Expand Down Expand Up @@ -679,7 +679,7 @@ mod roundtrip_tests {
CsvReadOptions::new().schema(&schema).has_header(true),
Some(vec![3, 4]),
)
.and_then(|plan| plan.sort(&[col("salary")]))
.and_then(|plan| plan.sort(vec![col("salary")]))
.and_then(|plan| plan.explain(true))
.and_then(|plan| plan.build())
.map_err(BallistaError::DataFusionError)?;
Expand All @@ -689,7 +689,7 @@ mod roundtrip_tests {
CsvReadOptions::new().schema(&schema).has_header(true),
Some(vec![3, 4]),
)
.and_then(|plan| plan.sort(&[col("salary")]))
.and_then(|plan| plan.sort(vec![col("salary")]))
.and_then(|plan| plan.explain(false))
.and_then(|plan| plan.build())
.map_err(BallistaError::DataFusionError)?;
Expand Down Expand Up @@ -742,7 +742,7 @@ mod roundtrip_tests {
CsvReadOptions::new().schema(&schema).has_header(true),
Some(vec![3, 4]),
)
.and_then(|plan| plan.sort(&[col("salary")]))
.and_then(|plan| plan.sort(vec![col("salary")]))
.and_then(|plan| plan.build())
.map_err(BallistaError::DataFusionError)?;
roundtrip_test!(plan);
Expand Down Expand Up @@ -784,7 +784,7 @@ mod roundtrip_tests {
CsvReadOptions::new().schema(&schema).has_header(true),
Some(vec![3, 4]),
)
.and_then(|plan| plan.aggregate(&[col("state")], &[max(col("salary"))]))
.and_then(|plan| plan.aggregate(vec![col("state")], vec![max(col("salary"))]))
.and_then(|plan| plan.build())
.map_err(BallistaError::DataFusionError)?;

Expand Down
14 changes: 4 additions & 10 deletions rust/ballista/rust/core/src/serde/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -641,12 +641,12 @@ impl TryFrom<&datafusion::scalar::ScalarValue> for protobuf::ScalarValue {
datafusion::scalar::ScalarValue::Date32(val) => {
create_proto_scalar(val, PrimitiveScalarType::Date32, |s| Value::Date32Value(*s))
}
datafusion::scalar::ScalarValue::TimeMicrosecond(val) => {
datafusion::scalar::ScalarValue::TimestampMicrosecond(val) => {
create_proto_scalar(val, PrimitiveScalarType::TimeMicrosecond, |s| {
Value::TimeMicrosecondValue(*s)
})
}
datafusion::scalar::ScalarValue::TimeNanosecond(val) => {
datafusion::scalar::ScalarValue::TimestampNanosecond(val) => {
create_proto_scalar(val, PrimitiveScalarType::TimeNanosecond, |s| {
Value::TimeNanosecondValue(*s)
})
Expand Down Expand Up @@ -939,10 +939,7 @@ impl TryInto<protobuf::LogicalPlanNode> for &LogicalPlan {
})
}
LogicalPlan::Extension { .. } => unimplemented!(),
// _ => Err(BallistaError::General(format!(
// "logical plan to_proto {:?}",
// self
// ))),
LogicalPlan::Union { .. } => unimplemented!(),
}
}
}
Expand Down Expand Up @@ -1161,10 +1158,7 @@ impl TryInto<protobuf::LogicalExprNode> for &Expr {
Expr::Wildcard => Ok(protobuf::LogicalExprNode {
expr_type: Some(protobuf::logical_expr_node::ExprType::Wildcard(true)),
}),
// _ => Err(BallistaError::General(format!(
// "logical expr to_proto {:?}",
// self
// ))),
Expr::TryCast { .. } => unimplemented!(),
}
}
}
Expand Down
Loading

0 comments on commit 83c8efe

Please sign in to comment.