Skip to content
Merged
101 changes: 101 additions & 0 deletions datafusion/spark/src/function/conditional/if.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::datatypes::DataType;
use datafusion_common::{internal_err, plan_err, Result};
use datafusion_expr::{
binary::try_type_union_resolution, simplify::ExprSimplifyResult, when, ColumnarValue,
Expr, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
};

#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkIf {
signature: Signature,
}

impl Default for SparkIf {
fn default() -> Self {
Self::new()
}
}

impl SparkIf {
pub fn new() -> Self {
Self {
signature: Signature::user_defined(Volatility::Immutable),
}
}
}

impl ScalarUDFImpl for SparkIf {
fn as_any(&self) -> &dyn std::any::Any {
self
}

fn name(&self) -> &str {
"if"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
if arg_types.len() != 3 {
return plan_err!(
"Function 'if' expects 3 arguments but received {}",
arg_types.len()
);
}

if arg_types[0] != DataType::Boolean && arg_types[0] != DataType::Null {
return plan_err!(
"For function 'if' {} is not a boolean or null",
arg_types[0]
);
}

let target_types = try_type_union_resolution(&arg_types[1..])?;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to modify the second and third args? We can probably just do:

let result = vec![DataType::Boolean, arg_types[1], arg_types[2]];

Also, I can't recall if coerce_types is actually called when simplify is used. Did you happen to test it out by any chance?

Copy link
Contributor Author

@chenkovsky chenkovsky Aug 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

query failed: DataFusion error: Optimizer rule 'simplify_expressions' failed
caused by
Error during planning: CASE expression 'then' values had multiple data types: {Float64, Int64}

let mut result = vec![DataType::Boolean];
result.extend(target_types);
Ok(result)
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
Ok(arg_types[1].clone())
}

fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
internal_err!("if should have been simplified to case")
}

fn simplify(
&self,
args: Vec<Expr>,
_info: &dyn datafusion_expr::simplify::SimplifyInfo,
) -> Result<ExprSimplifyResult> {
let condition = args[0].clone();
let then_expr = args[1].clone();
let else_expr = args[2].clone();

// Convert IF(condition, then_expr, else_expr) to
// CASE WHEN condition THEN then_expr ELSE else_expr END
let case_expr = when(condition, then_expr).otherwise(else_expr)?;

Ok(ExprSimplifyResult::Simplified(case_expr))
}
}
13 changes: 11 additions & 2 deletions datafusion/spark/src/function/conditional/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,19 @@
// under the License.

use datafusion_expr::ScalarUDF;
use datafusion_functions::make_udf_function;
use std::sync::Arc;

pub mod expr_fn {}
mod r#if;

make_udf_function!(r#if::SparkIf, r#if);

pub mod expr_fn {
use datafusion_functions::export_functions;

export_functions!((r#if, "If arg1 evaluates to true, then returns arg2; otherwise returns arg3", arg1 arg2 arg3));
}

pub fn functions() -> Vec<Arc<ScalarUDF>> {
vec![]
vec![r#if()]
}
147 changes: 143 additions & 4 deletions datafusion/sqllogictest/test_files/spark/conditional/if.slt
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,146 @@
# For more information, please see:
# https://github.com/apache/datafusion/issues/15914

## Original Query: SELECT if(1 < 2, 'a', 'b');
## PySpark 3.5.5 Result: {'(IF((1 < 2), a, b))': 'a', 'typeof((IF((1 < 2), a, b)))': 'string', 'typeof((1 < 2))': 'boolean', 'typeof(a)': 'string', 'typeof(b)': 'string'}
#query
#SELECT if((1 < 2)::boolean, 'a'::string, 'b'::string);
## Basic IF function tests

# Test basic true condition
query T
SELECT if(true, 'yes', 'no');
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add tests that use more complex expressions, such as referencing columns in an input file?

Also, could you add tests that demonstrate that the 2nd expression is only evaluated when the 1st argument evaluates to true? There should be some tests in case.slt that can be repurposed.

By the way, in Comet, we just translate an if condition to a case expression.

impl IfExpr {
    /// Create a new IF expression
    pub fn new(
        if_expr: Arc<dyn PhysicalExpr>,
        true_expr: Arc<dyn PhysicalExpr>,
        false_expr: Arc<dyn PhysicalExpr>,
    ) -> Self {
        Self {
            if_expr: Arc::clone(&if_expr),
            true_expr: Arc::clone(&true_expr),
            false_expr: Arc::clone(&false_expr),
            case_expr: Arc::new(
                CaseExpr::try_new(None, vec![(if_expr, true_expr)], Some(false_expr)).unwrap(),
            ),
        }
    }
}

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm sorry, I wasn't clear on my comment above. It could be that the 2nd or 3rd argument expressions could fail if evaluated on certain rows, and we would expect if to provide conditional evaluation. For example:

select if(a==0, 0, b/a) from tbl

if a==0 then we want to avoid evaluating b/a because it would cause a divide by zero error.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here is a test that demonstrates the issue:

query II
SELECT v, IF(v < 0, 10/0, 1) FROM (VALUES (1), (2)) t(v)
----
1 1
2 1

This fails with DataFusion error: Arrow error: Divide by zero error.

The version of this test in cast.slt works correctly:

query II
SELECT v, CASE WHEN v < 0 THEN 10/0 ELSE 1 END FROM (VALUES (1), (2)) t(v)
----
1 1
2 1

Copy link
Contributor

@SparkApplicationMaster SparkApplicationMaster Aug 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please also check if when(...).otherwise(else_expr) used here:
https://github.com/apache/datafusion/pull/16946/files#diff-8184a681e2d4b84411030426011f4e80cc4a79e2debd39f6290d0159d83a63a5R97
does not eagerly calculate else_expr.
because I faced it calculating else_expr when using when expression: lakehq/sail#648
and the fix was just removing else_expr from when expr and placing it to the last branch of when with true predicate:
https://github.com/lakehq/sail/pull/649/files
also checked this now by adding in datafusion if.slt:

query I
SELECT case when true then 1 / 1 else 1 / 0 end;
----
1

and got this:

1. query failed: DataFusion error: Arrow error: Divide by zero error
[SQL] SELECT case when true then 1 / 1 else 1 / 0 end;

but this:

query I
SELECT case when a then 1 / 1 else 1 / b end FROM (VALUES (false, 1), (true, 0)) t(a, b);
----
1
1

works OK, maybe the problem with example above is only about eager constant evaluation in else_expr, not every expression

Not sure if eager evaluation of else expression is OK in datafusion, but in spark it's definitely NOT OK, and there is a doctest checking this:
https://github.com/apache/spark/blob/326052ec8280d8bf8ee1904504be3b62a72d3d29/python/pyspark/sql/column.py#L1418-L1427
raise_error(literal(str)) is also constant expression, so it evaluates when not intented to be

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I created another PR #17311

----
yes

# Test basic false condition
query T
SELECT if(false, 'yes', 'no');
----
no

# Test with comparison operators
query T
SELECT if(1 < 2, 'a', 'b');
----
a

query T
SELECT if(1 > 2, 'a', 'b');
----
b


## Numeric type tests

# Test with integers
query I
SELECT if(true, 10, 20);
----
10

query I
SELECT if(false, 10, 20);
----
20

# Test with different integer types
query I
SELECT if(true, 100, 200);
----
100

## Float type tests

# Test with floating point numbers
query R
SELECT if(true, 1.5, 2.5);
----
1.5

query R
SELECT if(false, 1.5, 2.5);
----
2.5

## String type tests

# Test with different string values
query T
SELECT if(true, 'hello', 'world');
----
hello

query T
SELECT if(false, 'hello', 'world');
----
world

## NULL handling tests

# Test with NULL condition
query T
SELECT if(NULL, 'yes', 'no');
----
no

query T
SELECT if(NOT NULL, 'yes', 'no');
----
no

# Test with NULL true value
query T
SELECT if(true, NULL, 'no');
----
NULL

# Test with NULL false value
query T
SELECT if(false, 'yes', NULL);
----
NULL

# Test with all NULL
query ?
SELECT if(true, NULL, NULL);
----
NULL

## Type coercion tests

# Test integer to float coercion
query R
SELECT if(true, 10, 20.5);
----
10

query R
SELECT if(false, 10, 20.5);
----
20.5

# Test float to integer coercion
query R
SELECT if(true, 10.5, 20);
----
10.5

query R
SELECT if(false, 10.5, 20);
----
20

statement error Int64 is not a boolean or null
SELECT if(1, 10.5, 20);


statement error Utf8 is not a boolean or null
SELECT if('x', 10.5, 20);

query II
SELECT v, IF(v < 0, 10/0, 1) FROM (VALUES (1), (2)) t(v)
----
1 1
2 1

query I
SELECT IF(true, 1 / 1, 1 / 0);
----
1