Skip to content

Commit

Permalink
move more physical expr (#1892)
Browse files Browse the repository at this point in the history
  • Loading branch information
jimexist authored Feb 28, 2022
1 parent f6bbb62 commit eb5a18a
Show file tree
Hide file tree
Showing 28 changed files with 1,288 additions and 1,230 deletions.
14 changes: 14 additions & 0 deletions datafusion-physical-expr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ name = "datafusion_physical_expr"
path = "src/lib.rs"

[features]
default = ["crypto_expressions", "regex_expressions", "unicode_expressions"]
crypto_expressions = ["md-5", "sha2", "blake2", "blake3"]
regex_expressions = ["regex"]
unicode_expressions = ["unicode-segmentation"]

[dependencies]
datafusion-common = { path = "../datafusion-common", version = "7.0.0" }
Expand All @@ -41,3 +45,13 @@ arrow = { version = "9.0.0", features = ["prettyprint"] }
paste = "^1.0"
ahash = { version = "0.7", default-features = false }
ordered-float = "2.10"
lazy_static = { version = "^1.4.0" }
md-5 = { version = "^0.10.0", optional = true }
sha2 = { version = "^0.10.1", optional = true }
blake2 = { version = "^0.10.2", optional = true }
blake3 = { version = "1.0", optional = true }
rand = "0.8"
hashbrown = { version = "0.12", features = ["raw"] }
chrono = { version = "0.4", default-features = false }
regex = { version = "^1.4.3", optional = true }
unicode-segmentation = { version = "^1.7.1", optional = true }
4 changes: 1 addition & 3 deletions datafusion-physical-expr/src/aggregate_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,11 @@
// under the License.

use crate::PhysicalExpr;

use arrow::datatypes::Field;
use datafusion_common::Result;
use datafusion_expr::Accumulator;
use std::fmt::Debug;

use std::any::Any;
use std::fmt::Debug;
use std::sync::Arc;

/// An aggregate expression that:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,12 @@

//! Array expressions

use crate::error::{DataFusionError, Result};
use arrow::array::*;
use arrow::datatypes::DataType;
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::ColumnarValue;
use std::sync::Arc;

use super::ColumnarValue;

macro_rules! downcast_vec {
($ARGS:expr, $ARRAY_TYPE:ident) => {{
$ARGS
Expand Down
262 changes: 262 additions & 0 deletions datafusion-physical-expr/src/coercion_rule/aggregate_rule.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,262 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Support the coercion rule for aggregate function.

use crate::expressions::{
is_approx_percentile_cont_supported_arg_type, is_avg_support_arg_type,
is_correlation_support_arg_type, is_covariance_support_arg_type,
is_stddev_support_arg_type, is_sum_support_arg_type, is_variance_support_arg_type,
try_cast,
};
use crate::PhysicalExpr;
use arrow::datatypes::DataType;
use arrow::datatypes::Schema;
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::AggregateFunction;
use datafusion_expr::{Signature, TypeSignature};
use std::ops::Deref;
use std::sync::Arc;

/// Returns the coerced data type for each `input_types`.
/// Different aggregate function with different input data type will get corresponding coerced data type.
pub fn coerce_types(
agg_fun: &AggregateFunction,
input_types: &[DataType],
signature: &Signature,
) -> Result<Vec<DataType>> {
// Validate input_types matches (at least one of) the func signature.
check_arg_count(agg_fun, input_types, &signature.type_signature)?;

match agg_fun {
AggregateFunction::Count | AggregateFunction::ApproxDistinct => {
Ok(input_types.to_vec())
}
AggregateFunction::ArrayAgg => Ok(input_types.to_vec()),
AggregateFunction::Min | AggregateFunction::Max => {
// min and max support the dictionary data type
// unpack the dictionary to get the value
get_min_max_result_type(input_types)
}
AggregateFunction::Sum => {
// Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
// smallint, int, bigint, real, double precision, decimal, or interval.
if !is_sum_support_arg_type(&input_types[0]) {
return Err(DataFusionError::Plan(format!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun, input_types[0]
)));
}
Ok(input_types.to_vec())
}
AggregateFunction::Avg => {
// Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
// smallint, int, bigint, real, double precision, decimal, or interval
if !is_avg_support_arg_type(&input_types[0]) {
return Err(DataFusionError::Plan(format!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun, input_types[0]
)));
}
Ok(input_types.to_vec())
}
AggregateFunction::Variance => {
if !is_variance_support_arg_type(&input_types[0]) {
return Err(DataFusionError::Plan(format!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun, input_types[0]
)));
}
Ok(input_types.to_vec())
}
AggregateFunction::VariancePop => {
if !is_variance_support_arg_type(&input_types[0]) {
return Err(DataFusionError::Plan(format!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun, input_types[0]
)));
}
Ok(input_types.to_vec())
}
AggregateFunction::Covariance => {
if !is_covariance_support_arg_type(&input_types[0]) {
return Err(DataFusionError::Plan(format!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun, input_types[0]
)));
}
Ok(input_types.to_vec())
}
AggregateFunction::CovariancePop => {
if !is_covariance_support_arg_type(&input_types[0]) {
return Err(DataFusionError::Plan(format!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun, input_types[0]
)));
}
Ok(input_types.to_vec())
}
AggregateFunction::Stddev => {
if !is_stddev_support_arg_type(&input_types[0]) {
return Err(DataFusionError::Plan(format!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun, input_types[0]
)));
}
Ok(input_types.to_vec())
}
AggregateFunction::StddevPop => {
if !is_stddev_support_arg_type(&input_types[0]) {
return Err(DataFusionError::Plan(format!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun, input_types[0]
)));
}
Ok(input_types.to_vec())
}
AggregateFunction::Correlation => {
if !is_correlation_support_arg_type(&input_types[0]) {
return Err(DataFusionError::Plan(format!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun, input_types[0]
)));
}
Ok(input_types.to_vec())
}
AggregateFunction::ApproxPercentileCont => {
if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) {
return Err(DataFusionError::Plan(format!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun, input_types[0]
)));
}
if !matches!(input_types[1], DataType::Float64) {
return Err(DataFusionError::Plan(format!(
"The percentile argument for {:?} must be Float64, not {:?}.",
agg_fun, input_types[1]
)));
}
Ok(input_types.to_vec())
}
AggregateFunction::ApproxMedian => {
if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) {
return Err(DataFusionError::Plan(format!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun, input_types[0]
)));
}
Ok(input_types.to_vec())
}
}
}

/// Validate the length of `input_types` matches the `signature` for `agg_fun`.
///
/// This method DOES NOT validate the argument types - only that (at least one,
/// in the case of [`TypeSignature::OneOf`]) signature matches the desired
/// number of input types.
fn check_arg_count(
agg_fun: &AggregateFunction,
input_types: &[DataType],
signature: &TypeSignature,
) -> Result<()> {
match signature {
TypeSignature::Uniform(agg_count, _) | TypeSignature::Any(agg_count) => {
if input_types.len() != *agg_count {
return Err(DataFusionError::Plan(format!(
"The function {:?} expects {:?} arguments, but {:?} were provided",
agg_fun,
agg_count,
input_types.len()
)));
}
}
TypeSignature::Exact(types) => {
if types.len() != input_types.len() {
return Err(DataFusionError::Plan(format!(
"The function {:?} expects {:?} arguments, but {:?} were provided",
agg_fun,
types.len(),
input_types.len()
)));
}
}
TypeSignature::OneOf(variants) => {
let ok = variants
.iter()
.any(|v| check_arg_count(agg_fun, input_types, v).is_ok());
if !ok {
return Err(DataFusionError::Plan(format!(
"The function {:?} does not accept {:?} function arguments.",
agg_fun,
input_types.len()
)));
}
}
_ => {
return Err(DataFusionError::Internal(format!(
"Aggregate functions do not support this {:?}",
signature
)));
}
}
Ok(())
}

fn get_min_max_result_type(input_types: &[DataType]) -> Result<Vec<DataType>> {
// make sure that the input types only has one element.
assert_eq!(input_types.len(), 1);
// min and max support the dictionary data type
// unpack the dictionary to get the value
match &input_types[0] {
DataType::Dictionary(_, dict_value_type) => {
// TODO add checker, if the value type is complex data type
Ok(vec![dict_value_type.deref().clone()])
}
// TODO add checker for datatype which min and max supported
// For example, the `Struct` and `Map` type are not supported in the MIN and MAX function
_ => Ok(input_types.to_vec()),
}
}

/// Returns the coerced exprs for each `input_exprs`.
/// Get the coerced data type from `aggregate_rule::coerce_types` and add `try_cast` if the
/// data type of `input_exprs` need to be coerced.
pub fn coerce_exprs(
agg_fun: &AggregateFunction,
input_exprs: &[Arc<dyn PhysicalExpr>],
schema: &Schema,
signature: &Signature,
) -> Result<Vec<Arc<dyn PhysicalExpr>>> {
if input_exprs.is_empty() {
return Ok(vec![]);
}
let input_types = input_exprs
.iter()
.map(|e| e.data_type(schema))
.collect::<Result<Vec<_>>>()?;

// get the coerced data types
let coerced_types = coerce_types(agg_fun, &input_types, signature)?;

// try cast if need
input_exprs
.iter()
.zip(coerced_types.into_iter())
.map(|(expr, coerced_type)| try_cast(expr.clone(), schema, coerced_type))
.collect::<Result<Vec<_>>>()
}
1 change: 1 addition & 0 deletions datafusion-physical-expr/src/coercion_rule/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,5 @@
//! Aggregate function rule
//! Binary operation rule

pub mod aggregate_rule;
pub mod binary_rule;
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,7 @@
// under the License.

//! Crypto expressions
use super::ColumnarValue;
use crate::{
error::{DataFusionError, Result},
scalar::ScalarValue,
};

use arrow::{
array::{
Array, ArrayRef, BinaryArray, GenericStringArray, StringArray,
Expand All @@ -30,6 +26,9 @@ use arrow::{
};
use blake2::{Blake2b512, Blake2s256, Digest};
use blake3::Hasher as Blake3;
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::ColumnarValue;
use md5::Md5;
use sha2::{Sha224, Sha256, Sha384, Sha512};
use std::any::type_name;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,7 @@
// under the License.

//! DateTime expressions
use std::sync::Arc;

use super::ColumnarValue;
use crate::{
error::{DataFusionError, Result},
scalar::{ScalarType, ScalarValue},
};
use arrow::{
array::{Array, ArrayRef, GenericStringArray, PrimitiveArray, StringOffsetSizeTrait},
compute::kernels::cast_utils::string_to_timestamp_nanos,
Expand All @@ -42,7 +36,11 @@ use arrow::{
};
use chrono::prelude::*;
use chrono::Duration;
use datafusion_common::{DataFusionError, Result};
use datafusion_common::{ScalarType, ScalarValue};
use datafusion_expr::ColumnarValue;
use std::borrow::Borrow;
use std::sync::Arc;

/// given a function `op` that maps a `&str` to a Result of an arrow native type,
/// returns a `PrimitiveArray` after the application
Expand Down
2 changes: 1 addition & 1 deletion datafusion-physical-expr/src/expressions/cume_dist.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
//! Defines physical expression for `cume_dist` that can evaluated
//! at runtime during query execution

use crate::window::partition_evaluator::PartitionEvaluator;
use crate::window::BuiltInWindowFunctionExpr;
use crate::window::PartitionEvaluator;
use crate::PhysicalExpr;
use arrow::array::ArrayRef;
use arrow::array::Float64Array;
Expand Down
2 changes: 1 addition & 1 deletion datafusion-physical-expr/src/expressions/lead_lag.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
//! Defines physical expression for `lead` and `lag` that can evaluated
//! at runtime during query execution

use crate::window::partition_evaluator::PartitionEvaluator;
use crate::window::BuiltInWindowFunctionExpr;
use crate::window::PartitionEvaluator;
use crate::PhysicalExpr;
use arrow::array::ArrayRef;
use arrow::compute::cast;
Expand Down
Loading

0 comments on commit eb5a18a

Please sign in to comment.