-
Notifications
You must be signed in to change notification settings - Fork 1.7k
feat(spark): implement Spark conditional function if #16946
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
198430e
839f20f
043fd35
8d21457
2a09faa
84f7e21
2010cbe
6754098
92d42cf
d1b4ded
0a670db
b470389
dfc02f1
7c4833e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
// Licensed to the Apache Software Foundation (ASF) under one | ||
// or more contributor license agreements. See the NOTICE file | ||
// distributed with this work for additional information | ||
// regarding copyright ownership. The ASF licenses this file | ||
// to you under the Apache License, Version 2.0 (the | ||
// "License"); you may not use this file except in compliance | ||
// with the License. You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, | ||
// software distributed under the License is distributed on an | ||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
// KIND, either express or implied. See the License for the | ||
// specific language governing permissions and limitations | ||
// under the License. | ||
|
||
use arrow::datatypes::DataType; | ||
use datafusion_common::{internal_err, plan_err, Result}; | ||
use datafusion_expr::{ | ||
binary::try_type_union_resolution, simplify::ExprSimplifyResult, when, ColumnarValue, | ||
Expr, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, | ||
}; | ||
|
||
#[derive(Debug, PartialEq, Eq, Hash)] | ||
pub struct SparkIf { | ||
signature: Signature, | ||
} | ||
|
||
impl Default for SparkIf { | ||
fn default() -> Self { | ||
Self::new() | ||
} | ||
} | ||
|
||
impl SparkIf { | ||
pub fn new() -> Self { | ||
Self { | ||
signature: Signature::user_defined(Volatility::Immutable), | ||
} | ||
} | ||
} | ||
|
||
impl ScalarUDFImpl for SparkIf { | ||
fn as_any(&self) -> &dyn std::any::Any { | ||
self | ||
} | ||
|
||
fn name(&self) -> &str { | ||
"if" | ||
} | ||
|
||
fn signature(&self) -> &Signature { | ||
&self.signature | ||
} | ||
|
||
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> { | ||
if arg_types.len() != 3 { | ||
return plan_err!( | ||
"Function 'if' expects 3 arguments but received {}", | ||
arg_types.len() | ||
); | ||
} | ||
|
||
if arg_types[0] != DataType::Boolean && arg_types[0] != DataType::Null { | ||
return plan_err!( | ||
"For function 'if' {} is not a boolean or null", | ||
arg_types[0] | ||
); | ||
} | ||
|
||
let target_types = try_type_union_resolution(&arg_types[1..])?; | ||
let mut result = vec![DataType::Boolean]; | ||
result.extend(target_types); | ||
Ok(result) | ||
} | ||
|
||
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> { | ||
Ok(arg_types[1].clone()) | ||
} | ||
|
||
fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> { | ||
internal_err!("if should have been simplified to case") | ||
} | ||
|
||
fn simplify( | ||
&self, | ||
args: Vec<Expr>, | ||
_info: &dyn datafusion_expr::simplify::SimplifyInfo, | ||
) -> Result<ExprSimplifyResult> { | ||
let condition = args[0].clone(); | ||
let then_expr = args[1].clone(); | ||
let else_expr = args[2].clone(); | ||
|
||
// Convert IF(condition, then_expr, else_expr) to | ||
// CASE WHEN condition THEN then_expr ELSE else_expr END | ||
let case_expr = when(condition, then_expr).otherwise(else_expr)?; | ||
|
||
Ok(ExprSimplifyResult::Simplified(case_expr)) | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,7 +21,146 @@ | |
# For more information, please see: | ||
# https://github.com/apache/datafusion/issues/15914 | ||
|
||
## Original Query: SELECT if(1 < 2, 'a', 'b'); | ||
## PySpark 3.5.5 Result: {'(IF((1 < 2), a, b))': 'a', 'typeof((IF((1 < 2), a, b)))': 'string', 'typeof((1 < 2))': 'boolean', 'typeof(a)': 'string', 'typeof(b)': 'string'} | ||
#query | ||
#SELECT if((1 < 2)::boolean, 'a'::string, 'b'::string); | ||
## Basic IF function tests | ||
|
||
# Test basic true condition | ||
query T | ||
SELECT if(true, 'yes', 'no'); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you add tests that use more complex expressions, such as referencing columns in an input file? Also, could you add tests that demonstrate that the 2nd expression is only evaluated when the 1st argument evaluates to true? There should be some tests in By the way, in Comet, we just translate an impl IfExpr {
/// Create a new IF expression
pub fn new(
if_expr: Arc<dyn PhysicalExpr>,
true_expr: Arc<dyn PhysicalExpr>,
false_expr: Arc<dyn PhysicalExpr>,
) -> Self {
Self {
if_expr: Arc::clone(&if_expr),
true_expr: Arc::clone(&true_expr),
false_expr: Arc::clone(&false_expr),
case_expr: Arc::new(
CaseExpr::try_new(None, vec![(if_expr, true_expr)], Some(false_expr)).unwrap(),
),
}
}
} There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm sorry, I wasn't clear on my comment above. It could be that the 2nd or 3rd argument expressions could fail if evaluated on certain rows, and we would expect
if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here is a test that demonstrates the issue:
This fails with The version of this test in
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please also check if when(...).otherwise(
and got this:
but this:
works OK, maybe the problem with example above is only about eager constant evaluation in else_expr, not every expression Not sure if eager evaluation of else expression is OK in datafusion, but in spark it's definitely NOT OK, and there is a doctest checking this: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I created another PR #17311 |
||
---- | ||
yes | ||
|
||
# Test basic false condition | ||
query T | ||
SELECT if(false, 'yes', 'no'); | ||
---- | ||
no | ||
|
||
# Test with comparison operators | ||
query T | ||
SELECT if(1 < 2, 'a', 'b'); | ||
---- | ||
a | ||
|
||
query T | ||
SELECT if(1 > 2, 'a', 'b'); | ||
---- | ||
b | ||
|
||
|
||
## Numeric type tests | ||
|
||
# Test with integers | ||
query I | ||
SELECT if(true, 10, 20); | ||
---- | ||
10 | ||
|
||
query I | ||
SELECT if(false, 10, 20); | ||
---- | ||
20 | ||
|
||
# Test with larger integer literals | ||
query I | ||
SELECT if(true, 100, 200); | ||
---- | ||
100 | ||
|
||
## Float type tests | ||
|
||
# Test with floating point numbers | ||
query R | ||
SELECT if(true, 1.5, 2.5); | ||
---- | ||
1.5 | ||
|
||
query R | ||
SELECT if(false, 1.5, 2.5); | ||
---- | ||
2.5 | ||
|
||
## String type tests | ||
|
||
# Test with different string values | ||
query T | ||
SELECT if(true, 'hello', 'world'); | ||
---- | ||
hello | ||
|
||
query T | ||
SELECT if(false, 'hello', 'world'); | ||
---- | ||
world | ||
|
||
## NULL handling tests | ||
|
||
# Test with NULL condition | ||
query T | ||
SELECT if(NULL, 'yes', 'no'); | ||
---- | ||
no | ||
|
||
query T | ||
SELECT if(NOT NULL, 'yes', 'no'); | ||
---- | ||
no | ||
|
||
# Test with NULL true value | ||
query T | ||
SELECT if(true, NULL, 'no'); | ||
---- | ||
NULL | ||
|
||
# Test with NULL false value | ||
query T | ||
SELECT if(false, 'yes', NULL); | ||
---- | ||
NULL | ||
|
||
# Test with all NULL | ||
query ? | ||
SELECT if(true, NULL, NULL); | ||
---- | ||
NULL | ||
|
||
## Type coercion tests | ||
|
||
# Test integer to float coercion | ||
query R | ||
SELECT if(true, 10, 20.5); | ||
---- | ||
10 | ||
|
||
query R | ||
SELECT if(false, 10, 20.5); | ||
---- | ||
20.5 | ||
|
||
# Test integer to float coercion with the float argument first | ||
query R | ||
SELECT if(true, 10.5, 20); | ||
---- | ||
10.5 | ||
|
||
query R | ||
SELECT if(false, 10.5, 20); | ||
---- | ||
20 | ||
|
||
statement error Int64 is not a boolean or null | ||
SELECT if(1, 10.5, 20); | ||
|
||
|
||
statement error Utf8 is not a boolean or null | ||
SELECT if('x', 10.5, 20); | ||
|
||
query II | ||
SELECT v, IF(v < 0, 10/0, 1) FROM (VALUES (1), (2)) t(v) | ||
---- | ||
1 1 | ||
2 1 | ||
|
||
query I | ||
SELECT IF(true, 1 / 1, 1 / 0); | ||
---- | ||
1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we want to modify the second and third args? We can probably just do:
Also, I can't recall if `coerce_types` is actually called when `simplify` is used. Did you happen to test it out by any chance?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
query failed: DataFusion error: Optimizer rule 'simplify_expressions' failed
caused by
Error during planning: CASE expression 'then' values had multiple data types: {Float64, Int64}