From 198430ea38de79ae37b81100c6210626607e40a3 Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Mon, 28 Jul 2025 20:11:48 +0800 Subject: [PATCH 01/10] update --- .../spark/src/function/conditional/if.rs | 390 ++++++++++++++++++ .../spark/src/function/conditional/mod.rs | 13 +- 2 files changed, 401 insertions(+), 2 deletions(-) create mode 100644 datafusion/spark/src/function/conditional/if.rs diff --git a/datafusion/spark/src/function/conditional/if.rs b/datafusion/spark/src/function/conditional/if.rs new file mode 100644 index 000000000000..fb9d6d4848d6 --- /dev/null +++ b/datafusion/spark/src/function/conditional/if.rs @@ -0,0 +1,390 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::{array::{ArrayRef, BooleanArray}, compute::kernels::zip::zip, datatypes::DataType}; +use datafusion_common::{plan_err, utils::take_function_args, Result}; +use datafusion_expr::{binary::comparison_coercion_numeric, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; + + +#[derive(Debug)] +pub struct SparkIf { + signature: Signature, +} + +impl Default for SparkIf { + fn default() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for SparkIf { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "if" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + if arg_types.len() != 3 { + return plan_err!( + "Function 'if' expects 3 arguments but received {}", + arg_types.len() + ); + } + let Some(target_type) = comparison_coercion_numeric(&arg_types[1], &arg_types[2]) else { + return plan_err!("For function 'if' {} and {} is not comparable", arg_types[1], arg_types[2]); + }; + // Convert null to String type. 
+ if target_type.is_null() { + Ok(vec![DataType::Boolean, DataType::Utf8View, DataType::Utf8View]) + } else { + Ok(vec![DataType::Boolean, target_type.clone(), target_type]) + } + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[1].clone()) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + let args = ColumnarValue::values_to_arrays(&args.args)?; + let [expr1, expr2, expr3] = take_function_args::<3, ArrayRef>("if", args)?; + let expr1 = expr1.as_any().downcast_ref::().unwrap(); + let result = zip(&expr1, &expr2, &expr3)?; + Ok(ColumnarValue::Array(result)) + } +} + +#[cfg(test)] +mod tests { + use arrow::array::{Array, BooleanArray, Float64Array, Int32Array, Int64Array, StringArray}; + use std::sync::Arc; + + use super::*; + + #[test] + fn test_if_basic() { + let if_udf = SparkIf::default(); + + // Test basic functionality + let condition = BooleanArray::from(vec![true, false, true, false]); + let true_value = Int32Array::from(vec![10, 20, 30, 40]); + let false_value = Int32Array::from(vec![100, 200, 300, 400]); + + let args = vec![ + ColumnarValue::Array(Arc::new(condition)), + ColumnarValue::Array(Arc::new(true_value)), + ColumnarValue::Array(Arc::new(false_value)), + ]; + + let arg_fields = args.iter() + .enumerate() + .map(|(idx, arg)| arrow::datatypes::Field::new(format!("f_{idx}"), arg.data_type(), true).into()) + .collect::>(); + + let result = if_udf.invoke_with_args(ScalarFunctionArgs { + args, + arg_fields, + number_rows: 4, + return_field: arrow::datatypes::Field::new("result", DataType::Int32, true).into(), + }).unwrap(); + + let result_array = result.into_array(4).unwrap(); + let result_int32 = result_array.as_any().downcast_ref::().unwrap(); + + assert_eq!(result_int32.value(0), 10); // true -> 10 + assert_eq!(result_int32.value(1), 200); // false -> 200 + assert_eq!(result_int32.value(2), 30); // true -> 30 + assert_eq!(result_int32.value(3), 400); // false -> 400 + } + + #[test] + fn test_if_with_nulls() { + let if_udf = SparkIf::default(); + + // Test with NULL values in condition + let condition = BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]); + let true_value = Int32Array::from(vec![10, 20, 30, 40]); + let false_value = Int32Array::from(vec![100, 200, 300, 400]); + + let args = vec![ + ColumnarValue::Array(Arc::new(condition)), + ColumnarValue::Array(Arc::new(true_value)), + ColumnarValue::Array(Arc::new(false_value)), + ]; + + let arg_fields = args.iter() + .enumerate() + .map(|(idx, arg)| arrow::datatypes::Field::new(format!("f_{idx}"), arg.data_type(), true).into()) + .collect::>(); + + let result = if_udf.invoke_with_args(ScalarFunctionArgs { + args, + arg_fields, + number_rows: 4, + return_field: arrow::datatypes::Field::new("result", DataType::Int32, true).into(), + }).unwrap(); + + let result_array = result.into_array(4).unwrap(); + let result_int32 = result_array.as_any().downcast_ref::().unwrap(); + + assert_eq!(result_int32.value(0), 10); // true -> 10 + assert!(result_int32.is_null(1)); // NULL -> NULL + assert_eq!(result_int32.value(2), 300); // false -> 300 + assert_eq!(result_int32.value(3), 40); // true -> 40 + } + + #[test] + fn test_if_string_types() { + let if_udf = SparkIf::default(); + + // Test with string types + let condition = BooleanArray::from(vec![true, false, true]); + let true_value = StringArray::from(vec!["yes", "maybe", "yes"]); + let false_value = StringArray::from(vec!["no", "maybe", "no"]); + + let args = vec![ + ColumnarValue::Array(Arc::new(condition)), + 
ColumnarValue::Array(Arc::new(true_value)), + ColumnarValue::Array(Arc::new(false_value)), + ]; + + let arg_fields = args.iter() + .enumerate() + .map(|(idx, arg)| arrow::datatypes::Field::new(format!("f_{idx}"), arg.data_type(), true).into()) + .collect::>(); + + let result = if_udf.invoke_with_args(ScalarFunctionArgs { + args, + arg_fields, + number_rows: 3, + return_field: arrow::datatypes::Field::new("result", DataType::Utf8, true).into(), + }).unwrap(); + + let result_array = result.into_array(3).unwrap(); + let result_string = result_array.as_any().downcast_ref::().unwrap(); + + assert_eq!(result_string.value(0), "yes"); // true -> "yes" + assert_eq!(result_string.value(1), "maybe"); // false -> "maybe" + assert_eq!(result_string.value(2), "yes"); // true -> "yes" + } + + #[test] + fn test_if_float_types() { + let if_udf = SparkIf::default(); + + // Test with float types + let condition = BooleanArray::from(vec![true, false, true, false]); + let true_value = Float64Array::from(vec![1.5, 2.5, 3.5, 4.5]); + let false_value = Float64Array::from(vec![10.5, 20.5, 30.5, 40.5]); + + let args = vec![ + ColumnarValue::Array(Arc::new(condition)), + ColumnarValue::Array(Arc::new(true_value)), + ColumnarValue::Array(Arc::new(false_value)), + ]; + + let arg_fields = args.iter() + .enumerate() + .map(|(idx, arg)| arrow::datatypes::Field::new(format!("f_{idx}"), arg.data_type(), true).into()) + .collect::>(); + + let result = if_udf.invoke_with_args(ScalarFunctionArgs { + args, + arg_fields, + number_rows: 4, + return_field: arrow::datatypes::Field::new("result", DataType::Float64, true).into(), + }).unwrap(); + + let result_array = result.into_array(4).unwrap(); + let result_float = result_array.as_any().downcast_ref::().unwrap(); + + assert_eq!(result_float.value(0), 1.5); // true -> 1.5 + assert_eq!(result_float.value(1), 20.5); // false -> 20.5 + assert_eq!(result_float.value(2), 3.5); // true -> 3.5 + assert_eq!(result_float.value(3), 40.5); // false -> 40.5 + } + + #[test] + fn test_if_type_coercion() { + let if_udf = SparkIf::default(); + + // Test type coercion between Int32 and Int64 + let condition = BooleanArray::from(vec![true, false]); + let true_value = Int32Array::from(vec![10, 20]); + let false_value = Int64Array::from(vec![100, 200]); + + let args = vec![ + ColumnarValue::Array(Arc::new(condition)), + ColumnarValue::Array(Arc::new(true_value)), + ColumnarValue::Array(Arc::new(false_value)), + ]; + + let arg_fields = args.iter() + .enumerate() + .map(|(idx, arg)| arrow::datatypes::Field::new(format!("f_{idx}"), arg.data_type(), true).into()) + .collect::>(); + + let result = if_udf.invoke_with_args(ScalarFunctionArgs { + args, + arg_fields, + number_rows: 2, + return_field: arrow::datatypes::Field::new("result", DataType::Int64, true).into(), + }).unwrap(); + + let result_array = result.into_array(2).unwrap(); + let result_int64 = result_array.as_any().downcast_ref::().unwrap(); + + assert_eq!(result_int64.value(0), 10); // true -> 10 (coerced to Int64) + assert_eq!(result_int64.value(1), 200); // false -> 200 + } + + #[test] + fn test_if_all_true() { + let if_udf = SparkIf::default(); + + // Test when all conditions are true + let condition = BooleanArray::from(vec![true, true, true]); + let true_value = Int32Array::from(vec![1, 2, 3]); + let false_value = Int32Array::from(vec![100, 200, 300]); + + let args = vec![ + ColumnarValue::Array(Arc::new(condition)), + ColumnarValue::Array(Arc::new(true_value)), + ColumnarValue::Array(Arc::new(false_value)), + ]; + + let arg_fields = 
args.iter() + .enumerate() + .map(|(idx, arg)| arrow::datatypes::Field::new(format!("f_{idx}"), arg.data_type(), true).into()) + .collect::>(); + + let result = if_udf.invoke_with_args(ScalarFunctionArgs { + args, + arg_fields, + number_rows: 3, + return_field: arrow::datatypes::Field::new("result", DataType::Int32, true).into(), + }).unwrap(); + + let result_array = result.into_array(3).unwrap(); + let result_int32 = result_array.as_any().downcast_ref::().unwrap(); + + assert_eq!(result_int32.value(0), 1); + assert_eq!(result_int32.value(1), 2); + assert_eq!(result_int32.value(2), 3); + } + + #[test] + fn test_if_all_false() { + let if_udf = SparkIf::default(); + + // Test when all conditions are false + let condition = BooleanArray::from(vec![false, false, false]); + let true_value = Int32Array::from(vec![1, 2, 3]); + let false_value = Int32Array::from(vec![100, 200, 300]); + + let args = vec![ + ColumnarValue::Array(Arc::new(condition)), + ColumnarValue::Array(Arc::new(true_value)), + ColumnarValue::Array(Arc::new(false_value)), + ]; + + let arg_fields = args.iter() + .enumerate() + .map(|(idx, arg)| arrow::datatypes::Field::new(format!("f_{idx}"), arg.data_type(), true).into()) + .collect::>(); + + let result = if_udf.invoke_with_args(ScalarFunctionArgs { + args, + arg_fields, + number_rows: 3, + return_field: arrow::datatypes::Field::new("result", DataType::Int32, true).into(), + }).unwrap(); + + let result_array = result.into_array(3).unwrap(); + let result_int32 = result_array.as_any().downcast_ref::().unwrap(); + + assert_eq!(result_int32.value(0), 100); + assert_eq!(result_int32.value(1), 200); + assert_eq!(result_int32.value(2), 300); + } + + #[test] + fn test_if_empty_arrays() { + let if_udf = SparkIf::default(); + + // Test with empty arrays + let condition = BooleanArray::from(Vec::::new()); + let true_value = Int32Array::from(Vec::::new()); + let false_value = Int32Array::from(Vec::::new()); + + let args = vec![ + ColumnarValue::Array(Arc::new(condition)), + ColumnarValue::Array(Arc::new(true_value)), + ColumnarValue::Array(Arc::new(false_value)), + ]; + + let arg_fields = args.iter() + .enumerate() + .map(|(idx, arg)| arrow::datatypes::Field::new(format!("f_{idx}"), arg.data_type(), true).into()) + .collect::>(); + + let result = if_udf.invoke_with_args(ScalarFunctionArgs { + args, + arg_fields, + number_rows: 0, + return_field: arrow::datatypes::Field::new("result", DataType::Int32, true).into(), + }).unwrap(); + + let result_array = result.into_array(0).unwrap(); + let result_int32 = result_array.as_any().downcast_ref::().unwrap(); + + assert_eq!(result_int32.len(), 0); + } + + #[test] + fn test_if_wrong_argument_count() { + let if_udf = SparkIf::default(); + + // Test with wrong number of arguments + let arg_types = vec![DataType::Boolean, DataType::Int32]; + let result = if_udf.coerce_types(&arg_types); + + assert!(result.is_err()); + assert!(result.unwrap_err().message().contains("expects 3 arguments")); + } + + #[test] + fn test_if_incompatible_types() { + let if_udf = SparkIf::default(); + + // Test with incompatible types (Boolean and String) + let arg_types = vec![DataType::Boolean, DataType::Boolean, DataType::Utf8]; + let result = if_udf.coerce_types(&arg_types); + + // This should work as Boolean and String can be coerced + assert!(result.is_ok()); + } +} diff --git a/datafusion/spark/src/function/conditional/mod.rs b/datafusion/spark/src/function/conditional/mod.rs index a87df9a2c87a..4301d7642b41 100644 --- a/datafusion/spark/src/function/conditional/mod.rs +++ 
b/datafusion/spark/src/function/conditional/mod.rs @@ -16,10 +16,19 @@ // under the License. use datafusion_expr::ScalarUDF; +use datafusion_functions::make_udf_function; use std::sync::Arc; -pub mod expr_fn {} +mod r#if; + +make_udf_function!(r#if::SparkIf, r#if); + +pub mod expr_fn { + use datafusion_functions::export_functions; + + export_functions!((r#if, "If arg1 evaluates to true, then returns arg2; otherwise returns arg3", arg1 arg2 arg3)); +} pub fn functions() -> Vec> { - vec![] + vec![r#if()] } From 839f20f215c570b2f38a290c0d59538aa808f3da Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Mon, 28 Jul 2025 21:44:44 +0800 Subject: [PATCH 02/10] feat: spark if udf --- .../spark/src/function/conditional/if.rs | 382 +++++++----------- .../test_files/spark/conditional/if.slt | 129 +++++- 2 files changed, 271 insertions(+), 240 deletions(-) diff --git a/datafusion/spark/src/function/conditional/if.rs b/datafusion/spark/src/function/conditional/if.rs index fb9d6d4848d6..93de5fa371ca 100644 --- a/datafusion/spark/src/function/conditional/if.rs +++ b/datafusion/spark/src/function/conditional/if.rs @@ -15,10 +15,16 @@ // specific language governing permissions and limitations // under the License. -use arrow::{array::{ArrayRef, BooleanArray}, compute::kernels::zip::zip, datatypes::DataType}; +use arrow::{ + array::{ArrayRef, BooleanArray}, + compute::kernels::zip::zip, + datatypes::DataType, +}; use datafusion_common::{plan_err, utils::take_function_args, Result}; -use datafusion_expr::{binary::comparison_coercion_numeric, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; - +use datafusion_expr::{ + binary::comparison_coercion_numeric, ColumnarValue, ScalarFunctionArgs, + ScalarUDFImpl, Signature, Volatility, +}; #[derive(Debug)] pub struct SparkIf { @@ -27,6 +33,12 @@ pub struct SparkIf { impl Default for SparkIf { fn default() -> Self { + Self::new() + } +} + +impl SparkIf { + pub fn new() -> Self { Self { signature: Signature::user_defined(Volatility::Immutable), } @@ -53,12 +65,21 @@ impl ScalarUDFImpl for SparkIf { arg_types.len() ); } - let Some(target_type) = comparison_coercion_numeric(&arg_types[1], &arg_types[2]) else { - return plan_err!("For function 'if' {} and {} is not comparable", arg_types[1], arg_types[2]); + let Some(target_type) = comparison_coercion_numeric(&arg_types[1], &arg_types[2]) + else { + return plan_err!( + "For function 'if' {} and {} is not comparable", + arg_types[1], + arg_types[2] + ); }; // Convert null to String type. 
if target_type.is_null() { - Ok(vec![DataType::Boolean, DataType::Utf8View, DataType::Utf8View]) + Ok(vec![ + DataType::Boolean, + DataType::Utf8View, + DataType::Utf8View, + ]) } else { Ok(vec![DataType::Boolean, target_type.clone(), target_type]) } @@ -72,14 +93,16 @@ impl ScalarUDFImpl for SparkIf { let args = ColumnarValue::values_to_arrays(&args.args)?; let [expr1, expr2, expr3] = take_function_args::<3, ArrayRef>("if", args)?; let expr1 = expr1.as_any().downcast_ref::().unwrap(); - let result = zip(&expr1, &expr2, &expr3)?; + let result = zip(expr1, &expr2, &expr3)?; Ok(ColumnarValue::Array(result)) } } #[cfg(test)] mod tests { - use arrow::array::{Array, BooleanArray, Float64Array, Int32Array, Int64Array, StringArray}; + use arrow::array::{ + Array, BooleanArray, Float64Array, Int32Array, StringArray, + }; use std::sync::Arc; use super::*; @@ -87,304 +110,191 @@ mod tests { #[test] fn test_if_basic() { let if_udf = SparkIf::default(); - + // Test basic functionality let condition = BooleanArray::from(vec![true, false, true, false]); let true_value = Int32Array::from(vec![10, 20, 30, 40]); let false_value = Int32Array::from(vec![100, 200, 300, 400]); - + let args = vec![ ColumnarValue::Array(Arc::new(condition)), ColumnarValue::Array(Arc::new(true_value)), ColumnarValue::Array(Arc::new(false_value)), ]; - - let arg_fields = args.iter() + + let arg_fields = args + .iter() .enumerate() - .map(|(idx, arg)| arrow::datatypes::Field::new(format!("f_{idx}"), arg.data_type(), true).into()) + .map(|(idx, arg)| { + arrow::datatypes::Field::new(format!("f_{idx}"), arg.data_type(), true) + .into() + }) .collect::>(); - - let result = if_udf.invoke_with_args(ScalarFunctionArgs { - args, - arg_fields, - number_rows: 4, - return_field: arrow::datatypes::Field::new("result", DataType::Int32, true).into(), - }).unwrap(); - + + let result = if_udf + .invoke_with_args(ScalarFunctionArgs { + args, + arg_fields, + number_rows: 4, + return_field: arrow::datatypes::Field::new( + "result", + DataType::Int32, + true, + ) + .into(), + }) + .unwrap(); + let result_array = result.into_array(4).unwrap(); let result_int32 = result_array.as_any().downcast_ref::().unwrap(); - - assert_eq!(result_int32.value(0), 10); // true -> 10 + + assert_eq!(result_int32.value(0), 10); // true -> 10 assert_eq!(result_int32.value(1), 200); // false -> 200 - assert_eq!(result_int32.value(2), 30); // true -> 30 + assert_eq!(result_int32.value(2), 30); // true -> 30 assert_eq!(result_int32.value(3), 400); // false -> 400 } #[test] fn test_if_with_nulls() { let if_udf = SparkIf::default(); - + // Test with NULL values in condition - let condition = BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]); + let condition = + BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]); let true_value = Int32Array::from(vec![10, 20, 30, 40]); let false_value = Int32Array::from(vec![100, 200, 300, 400]); - + let args = vec![ ColumnarValue::Array(Arc::new(condition)), ColumnarValue::Array(Arc::new(true_value)), ColumnarValue::Array(Arc::new(false_value)), ]; - - let arg_fields = args.iter() + + let arg_fields = args + .iter() .enumerate() - .map(|(idx, arg)| arrow::datatypes::Field::new(format!("f_{idx}"), arg.data_type(), true).into()) + .map(|(idx, arg)| { + arrow::datatypes::Field::new(format!("f_{idx}"), arg.data_type(), true) + .into() + }) .collect::>(); - - let result = if_udf.invoke_with_args(ScalarFunctionArgs { - args, - arg_fields, - number_rows: 4, - return_field: arrow::datatypes::Field::new("result", 
DataType::Int32, true).into(), - }).unwrap(); - + + let result = if_udf + .invoke_with_args(ScalarFunctionArgs { + args, + arg_fields, + number_rows: 4, + return_field: arrow::datatypes::Field::new( + "result", + DataType::Int32, + true, + ) + .into(), + }) + .unwrap(); + let result_array = result.into_array(4).unwrap(); let result_int32 = result_array.as_any().downcast_ref::().unwrap(); - - assert_eq!(result_int32.value(0), 10); // true -> 10 - assert!(result_int32.is_null(1)); // NULL -> NULL - assert_eq!(result_int32.value(2), 300); // false -> 300 - assert_eq!(result_int32.value(3), 40); // true -> 40 + + assert_eq!(result_int32.value(0), 10); // true -> 10 + assert_eq!(result_int32.value(1), 200); // NULL -> 200 + assert_eq!(result_int32.value(2), 300); // false -> 300 + assert_eq!(result_int32.value(3), 40); // true -> 40 } #[test] fn test_if_string_types() { let if_udf = SparkIf::default(); - + // Test with string types let condition = BooleanArray::from(vec![true, false, true]); - let true_value = StringArray::from(vec!["yes", "maybe", "yes"]); + let true_value = StringArray::from(vec!["yes", "yes", "yes"]); let false_value = StringArray::from(vec!["no", "maybe", "no"]); - + let args = vec![ ColumnarValue::Array(Arc::new(condition)), ColumnarValue::Array(Arc::new(true_value)), ColumnarValue::Array(Arc::new(false_value)), ]; - - let arg_fields = args.iter() + + let arg_fields = args + .iter() .enumerate() - .map(|(idx, arg)| arrow::datatypes::Field::new(format!("f_{idx}"), arg.data_type(), true).into()) + .map(|(idx, arg)| { + arrow::datatypes::Field::new(format!("f_{idx}"), arg.data_type(), true) + .into() + }) .collect::>(); - - let result = if_udf.invoke_with_args(ScalarFunctionArgs { - args, - arg_fields, - number_rows: 3, - return_field: arrow::datatypes::Field::new("result", DataType::Utf8, true).into(), - }).unwrap(); - + + let result = if_udf + .invoke_with_args(ScalarFunctionArgs { + args, + arg_fields, + number_rows: 3, + return_field: arrow::datatypes::Field::new( + "result", + DataType::Utf8, + true, + ) + .into(), + }) + .unwrap(); + let result_array = result.into_array(3).unwrap(); let result_string = result_array.as_any().downcast_ref::().unwrap(); - - assert_eq!(result_string.value(0), "yes"); // true -> "yes" + + assert_eq!(result_string.value(0), "yes"); // true -> "yes" assert_eq!(result_string.value(1), "maybe"); // false -> "maybe" - assert_eq!(result_string.value(2), "yes"); // true -> "yes" + assert_eq!(result_string.value(2), "yes"); // true -> "yes" } #[test] fn test_if_float_types() { let if_udf = SparkIf::default(); - + // Test with float types let condition = BooleanArray::from(vec![true, false, true, false]); let true_value = Float64Array::from(vec![1.5, 2.5, 3.5, 4.5]); let false_value = Float64Array::from(vec![10.5, 20.5, 30.5, 40.5]); - - let args = vec![ - ColumnarValue::Array(Arc::new(condition)), - ColumnarValue::Array(Arc::new(true_value)), - ColumnarValue::Array(Arc::new(false_value)), - ]; - - let arg_fields = args.iter() - .enumerate() - .map(|(idx, arg)| arrow::datatypes::Field::new(format!("f_{idx}"), arg.data_type(), true).into()) - .collect::>(); - - let result = if_udf.invoke_with_args(ScalarFunctionArgs { - args, - arg_fields, - number_rows: 4, - return_field: arrow::datatypes::Field::new("result", DataType::Float64, true).into(), - }).unwrap(); - - let result_array = result.into_array(4).unwrap(); - let result_float = result_array.as_any().downcast_ref::().unwrap(); - - assert_eq!(result_float.value(0), 1.5); // true -> 1.5 - 
assert_eq!(result_float.value(1), 20.5); // false -> 20.5 - assert_eq!(result_float.value(2), 3.5); // true -> 3.5 - assert_eq!(result_float.value(3), 40.5); // false -> 40.5 - } - #[test] - fn test_if_type_coercion() { - let if_udf = SparkIf::default(); - - // Test type coercion between Int32 and Int64 - let condition = BooleanArray::from(vec![true, false]); - let true_value = Int32Array::from(vec![10, 20]); - let false_value = Int64Array::from(vec![100, 200]); - let args = vec![ ColumnarValue::Array(Arc::new(condition)), ColumnarValue::Array(Arc::new(true_value)), ColumnarValue::Array(Arc::new(false_value)), ]; - - let arg_fields = args.iter() - .enumerate() - .map(|(idx, arg)| arrow::datatypes::Field::new(format!("f_{idx}"), arg.data_type(), true).into()) - .collect::>(); - - let result = if_udf.invoke_with_args(ScalarFunctionArgs { - args, - arg_fields, - number_rows: 2, - return_field: arrow::datatypes::Field::new("result", DataType::Int64, true).into(), - }).unwrap(); - - let result_array = result.into_array(2).unwrap(); - let result_int64 = result_array.as_any().downcast_ref::().unwrap(); - - assert_eq!(result_int64.value(0), 10); // true -> 10 (coerced to Int64) - assert_eq!(result_int64.value(1), 200); // false -> 200 - } - #[test] - fn test_if_all_true() { - let if_udf = SparkIf::default(); - - // Test when all conditions are true - let condition = BooleanArray::from(vec![true, true, true]); - let true_value = Int32Array::from(vec![1, 2, 3]); - let false_value = Int32Array::from(vec![100, 200, 300]); - - let args = vec![ - ColumnarValue::Array(Arc::new(condition)), - ColumnarValue::Array(Arc::new(true_value)), - ColumnarValue::Array(Arc::new(false_value)), - ]; - - let arg_fields = args.iter() + let arg_fields = args + .iter() .enumerate() - .map(|(idx, arg)| arrow::datatypes::Field::new(format!("f_{idx}"), arg.data_type(), true).into()) + .map(|(idx, arg)| { + arrow::datatypes::Field::new(format!("f_{idx}"), arg.data_type(), true) + .into() + }) .collect::>(); - - let result = if_udf.invoke_with_args(ScalarFunctionArgs { - args, - arg_fields, - number_rows: 3, - return_field: arrow::datatypes::Field::new("result", DataType::Int32, true).into(), - }).unwrap(); - - let result_array = result.into_array(3).unwrap(); - let result_int32 = result_array.as_any().downcast_ref::().unwrap(); - - assert_eq!(result_int32.value(0), 1); - assert_eq!(result_int32.value(1), 2); - assert_eq!(result_int32.value(2), 3); - } - #[test] - fn test_if_all_false() { - let if_udf = SparkIf::default(); - - // Test when all conditions are false - let condition = BooleanArray::from(vec![false, false, false]); - let true_value = Int32Array::from(vec![1, 2, 3]); - let false_value = Int32Array::from(vec![100, 200, 300]); - - let args = vec![ - ColumnarValue::Array(Arc::new(condition)), - ColumnarValue::Array(Arc::new(true_value)), - ColumnarValue::Array(Arc::new(false_value)), - ]; - - let arg_fields = args.iter() - .enumerate() - .map(|(idx, arg)| arrow::datatypes::Field::new(format!("f_{idx}"), arg.data_type(), true).into()) - .collect::>(); - - let result = if_udf.invoke_with_args(ScalarFunctionArgs { - args, - arg_fields, - number_rows: 3, - return_field: arrow::datatypes::Field::new("result", DataType::Int32, true).into(), - }).unwrap(); - - let result_array = result.into_array(3).unwrap(); - let result_int32 = result_array.as_any().downcast_ref::().unwrap(); - - assert_eq!(result_int32.value(0), 100); - assert_eq!(result_int32.value(1), 200); - assert_eq!(result_int32.value(2), 300); - } + let result = 
if_udf + .invoke_with_args(ScalarFunctionArgs { + args, + arg_fields, + number_rows: 4, + return_field: arrow::datatypes::Field::new( + "result", + DataType::Float64, + true, + ) + .into(), + }) + .unwrap(); - #[test] - fn test_if_empty_arrays() { - let if_udf = SparkIf::default(); - - // Test with empty arrays - let condition = BooleanArray::from(Vec::::new()); - let true_value = Int32Array::from(Vec::::new()); - let false_value = Int32Array::from(Vec::::new()); - - let args = vec![ - ColumnarValue::Array(Arc::new(condition)), - ColumnarValue::Array(Arc::new(true_value)), - ColumnarValue::Array(Arc::new(false_value)), - ]; - - let arg_fields = args.iter() - .enumerate() - .map(|(idx, arg)| arrow::datatypes::Field::new(format!("f_{idx}"), arg.data_type(), true).into()) - .collect::>(); - - let result = if_udf.invoke_with_args(ScalarFunctionArgs { - args, - arg_fields, - number_rows: 0, - return_field: arrow::datatypes::Field::new("result", DataType::Int32, true).into(), - }).unwrap(); - - let result_array = result.into_array(0).unwrap(); - let result_int32 = result_array.as_any().downcast_ref::().unwrap(); - - assert_eq!(result_int32.len(), 0); - } - - #[test] - fn test_if_wrong_argument_count() { - let if_udf = SparkIf::default(); - - // Test with wrong number of arguments - let arg_types = vec![DataType::Boolean, DataType::Int32]; - let result = if_udf.coerce_types(&arg_types); - - assert!(result.is_err()); - assert!(result.unwrap_err().message().contains("expects 3 arguments")); - } + let result_array = result.into_array(4).unwrap(); + let result_float = result_array + .as_any() + .downcast_ref::() + .unwrap(); - #[test] - fn test_if_incompatible_types() { - let if_udf = SparkIf::default(); - - // Test with incompatible types (Boolean and String) - let arg_types = vec![DataType::Boolean, DataType::Boolean, DataType::Utf8]; - let result = if_udf.coerce_types(&arg_types); - - // This should work as Boolean and String can be coerced - assert!(result.is_ok()); + assert_eq!(result_float.value(0), 1.5); // true -> 1.5 + assert_eq!(result_float.value(1), 20.5); // false -> 20.5 + assert_eq!(result_float.value(2), 3.5); // true -> 3.5 + assert_eq!(result_float.value(3), 40.5); // false -> 40.5 } } diff --git a/datafusion/sqllogictest/test_files/spark/conditional/if.slt b/datafusion/sqllogictest/test_files/spark/conditional/if.slt index 7baedad7456e..6b49e18163dc 100644 --- a/datafusion/sqllogictest/test_files/spark/conditional/if.slt +++ b/datafusion/sqllogictest/test_files/spark/conditional/if.slt @@ -21,7 +21,128 @@ # For more information, please see: # https://github.com/apache/datafusion/issues/15914 -## Original Query: SELECT if(1 < 2, 'a', 'b'); -## PySpark 3.5.5 Result: {'(IF((1 < 2), a, b))': 'a', 'typeof((IF((1 < 2), a, b)))': 'string', 'typeof((1 < 2))': 'boolean', 'typeof(a)': 'string', 'typeof(b)': 'string'} -#query -#SELECT if((1 < 2)::boolean, 'a'::string, 'b'::string); +## Basic IF function tests + +# Test basic true condition +query T +SELECT if(true, 'yes', 'no'); +---- +yes + +# Test basic false condition +query T +SELECT if(false, 'yes', 'no'); +---- +no + +# Test with comparison operators +query T +SELECT if(1 < 2, 'a', 'b'); +---- +a + +query T +SELECT if(1 > 2, 'a', 'b'); +---- +b + + +## Numeric type tests + +# Test with integers +query I +SELECT if(true, 10, 20); +---- +10 + +query I +SELECT if(false, 10, 20); +---- +20 + +# Test with different integer types +query I +SELECT if(true, 100, 200); +---- +100 + +## Float type tests + +# Test with floating point numbers 
+query R +SELECT if(true, 1.5, 2.5); +---- +1.5 + +query R +SELECT if(false, 1.5, 2.5); +---- +2.5 + +## String type tests + +# Test with different string values +query T +SELECT if(true, 'hello', 'world'); +---- +hello + +query T +SELECT if(false, 'hello', 'world'); +---- +world + +## NULL handling tests + +# Test with NULL condition +query T +SELECT if(NULL, 'yes', 'no'); +---- +no + +query T +SELECT if(NOT NULL, 'yes', 'no'); +---- +no + +# Test with NULL true value +query T +SELECT if(true, NULL, 'no'); +---- +NULL + +# Test with NULL false value +query T +SELECT if(false, 'yes', NULL); +---- +NULL + +# Test with all NULL +query T +SELECT if(true, NULL, NULL); +---- +NULL + +## Type coercion tests + +# Test integer to float coercion +query R +SELECT if(true, 10, 20.5); +---- +10 + +query R +SELECT if(false, 10, 20.5); +---- +20.5 + +# Test float to integer coercion +query R +SELECT if(true, 10.5, 20); +---- +10.5 + +query R +SELECT if(false, 10.5, 20); +---- +20 From 043fd35f22805c50d0412913bbc0ebde266c3d99 Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Mon, 28 Jul 2025 23:00:58 +0800 Subject: [PATCH 03/10] fmt --- datafusion/spark/src/function/conditional/if.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/datafusion/spark/src/function/conditional/if.rs b/datafusion/spark/src/function/conditional/if.rs index 93de5fa371ca..4ab1edc9c1b5 100644 --- a/datafusion/spark/src/function/conditional/if.rs +++ b/datafusion/spark/src/function/conditional/if.rs @@ -100,9 +100,7 @@ impl ScalarUDFImpl for SparkIf { #[cfg(test)] mod tests { - use arrow::array::{ - Array, BooleanArray, Float64Array, Int32Array, StringArray, - }; + use arrow::array::{Array, BooleanArray, Float64Array, Int32Array, StringArray}; use std::sync::Arc; use super::*; From 8d2145738ce77759dc4267a598f1bae6ff10d56a Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Wed, 20 Aug 2025 07:28:02 +0800 Subject: [PATCH 04/10] update --- datafusion/spark/src/function/conditional/if.rs | 8 ++++++++ .../sqllogictest/test_files/spark/conditional/if.slt | 7 +++++++ 2 files changed, 15 insertions(+) diff --git a/datafusion/spark/src/function/conditional/if.rs b/datafusion/spark/src/function/conditional/if.rs index 4ab1edc9c1b5..3182f7b78712 100644 --- a/datafusion/spark/src/function/conditional/if.rs +++ b/datafusion/spark/src/function/conditional/if.rs @@ -65,6 +65,14 @@ impl ScalarUDFImpl for SparkIf { arg_types.len() ); } + + if arg_types[0] != DataType::Boolean && arg_types[0] != DataType::Null { + return plan_err!( + "For function 'if' {} is not a boolean or null", + arg_types[0] + ); + } + let Some(target_type) = comparison_coercion_numeric(&arg_types[1], &arg_types[2]) else { return plan_err!( diff --git a/datafusion/sqllogictest/test_files/spark/conditional/if.slt b/datafusion/sqllogictest/test_files/spark/conditional/if.slt index 6b49e18163dc..6360e81e34ec 100644 --- a/datafusion/sqllogictest/test_files/spark/conditional/if.slt +++ b/datafusion/sqllogictest/test_files/spark/conditional/if.slt @@ -146,3 +146,10 @@ query R SELECT if(false, 10.5, 20); ---- 20 + +statement error Int64 is not a boolean or null +SELECT if(1, 10.5, 20); + + +statement error Utf8 is not a boolean or null +SELECT if('x', 10.5, 20); From 84f7e21fc346abd8be6104ca9feafdbe20427bb7 Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Wed, 20 Aug 2025 08:46:36 +0800 Subject: [PATCH 05/10] update --- .../spark/src/function/conditional/if.rs | 199 ------------------ 1 file changed, 199 deletions(-) diff --git 
a/datafusion/spark/src/function/conditional/if.rs b/datafusion/spark/src/function/conditional/if.rs index 3182f7b78712..ff7102e1ce95 100644 --- a/datafusion/spark/src/function/conditional/if.rs +++ b/datafusion/spark/src/function/conditional/if.rs @@ -105,202 +105,3 @@ impl ScalarUDFImpl for SparkIf { Ok(ColumnarValue::Array(result)) } } - -#[cfg(test)] -mod tests { - use arrow::array::{Array, BooleanArray, Float64Array, Int32Array, StringArray}; - use std::sync::Arc; - - use super::*; - - #[test] - fn test_if_basic() { - let if_udf = SparkIf::default(); - - // Test basic functionality - let condition = BooleanArray::from(vec![true, false, true, false]); - let true_value = Int32Array::from(vec![10, 20, 30, 40]); - let false_value = Int32Array::from(vec![100, 200, 300, 400]); - - let args = vec![ - ColumnarValue::Array(Arc::new(condition)), - ColumnarValue::Array(Arc::new(true_value)), - ColumnarValue::Array(Arc::new(false_value)), - ]; - - let arg_fields = args - .iter() - .enumerate() - .map(|(idx, arg)| { - arrow::datatypes::Field::new(format!("f_{idx}"), arg.data_type(), true) - .into() - }) - .collect::>(); - - let result = if_udf - .invoke_with_args(ScalarFunctionArgs { - args, - arg_fields, - number_rows: 4, - return_field: arrow::datatypes::Field::new( - "result", - DataType::Int32, - true, - ) - .into(), - }) - .unwrap(); - - let result_array = result.into_array(4).unwrap(); - let result_int32 = result_array.as_any().downcast_ref::().unwrap(); - - assert_eq!(result_int32.value(0), 10); // true -> 10 - assert_eq!(result_int32.value(1), 200); // false -> 200 - assert_eq!(result_int32.value(2), 30); // true -> 30 - assert_eq!(result_int32.value(3), 400); // false -> 400 - } - - #[test] - fn test_if_with_nulls() { - let if_udf = SparkIf::default(); - - // Test with NULL values in condition - let condition = - BooleanArray::from(vec![Some(true), None, Some(false), Some(true)]); - let true_value = Int32Array::from(vec![10, 20, 30, 40]); - let false_value = Int32Array::from(vec![100, 200, 300, 400]); - - let args = vec![ - ColumnarValue::Array(Arc::new(condition)), - ColumnarValue::Array(Arc::new(true_value)), - ColumnarValue::Array(Arc::new(false_value)), - ]; - - let arg_fields = args - .iter() - .enumerate() - .map(|(idx, arg)| { - arrow::datatypes::Field::new(format!("f_{idx}"), arg.data_type(), true) - .into() - }) - .collect::>(); - - let result = if_udf - .invoke_with_args(ScalarFunctionArgs { - args, - arg_fields, - number_rows: 4, - return_field: arrow::datatypes::Field::new( - "result", - DataType::Int32, - true, - ) - .into(), - }) - .unwrap(); - - let result_array = result.into_array(4).unwrap(); - let result_int32 = result_array.as_any().downcast_ref::().unwrap(); - - assert_eq!(result_int32.value(0), 10); // true -> 10 - assert_eq!(result_int32.value(1), 200); // NULL -> 200 - assert_eq!(result_int32.value(2), 300); // false -> 300 - assert_eq!(result_int32.value(3), 40); // true -> 40 - } - - #[test] - fn test_if_string_types() { - let if_udf = SparkIf::default(); - - // Test with string types - let condition = BooleanArray::from(vec![true, false, true]); - let true_value = StringArray::from(vec!["yes", "yes", "yes"]); - let false_value = StringArray::from(vec!["no", "maybe", "no"]); - - let args = vec![ - ColumnarValue::Array(Arc::new(condition)), - ColumnarValue::Array(Arc::new(true_value)), - ColumnarValue::Array(Arc::new(false_value)), - ]; - - let arg_fields = args - .iter() - .enumerate() - .map(|(idx, arg)| { - arrow::datatypes::Field::new(format!("f_{idx}"), 
arg.data_type(), true) - .into() - }) - .collect::>(); - - let result = if_udf - .invoke_with_args(ScalarFunctionArgs { - args, - arg_fields, - number_rows: 3, - return_field: arrow::datatypes::Field::new( - "result", - DataType::Utf8, - true, - ) - .into(), - }) - .unwrap(); - - let result_array = result.into_array(3).unwrap(); - let result_string = result_array.as_any().downcast_ref::().unwrap(); - - assert_eq!(result_string.value(0), "yes"); // true -> "yes" - assert_eq!(result_string.value(1), "maybe"); // false -> "maybe" - assert_eq!(result_string.value(2), "yes"); // true -> "yes" - } - - #[test] - fn test_if_float_types() { - let if_udf = SparkIf::default(); - - // Test with float types - let condition = BooleanArray::from(vec![true, false, true, false]); - let true_value = Float64Array::from(vec![1.5, 2.5, 3.5, 4.5]); - let false_value = Float64Array::from(vec![10.5, 20.5, 30.5, 40.5]); - - let args = vec![ - ColumnarValue::Array(Arc::new(condition)), - ColumnarValue::Array(Arc::new(true_value)), - ColumnarValue::Array(Arc::new(false_value)), - ]; - - let arg_fields = args - .iter() - .enumerate() - .map(|(idx, arg)| { - arrow::datatypes::Field::new(format!("f_{idx}"), arg.data_type(), true) - .into() - }) - .collect::>(); - - let result = if_udf - .invoke_with_args(ScalarFunctionArgs { - args, - arg_fields, - number_rows: 4, - return_field: arrow::datatypes::Field::new( - "result", - DataType::Float64, - true, - ) - .into(), - }) - .unwrap(); - - let result_array = result.into_array(4).unwrap(); - let result_float = result_array - .as_any() - .downcast_ref::() - .unwrap(); - - assert_eq!(result_float.value(0), 1.5); // true -> 1.5 - assert_eq!(result_float.value(1), 20.5); // false -> 20.5 - assert_eq!(result_float.value(2), 3.5); // true -> 3.5 - assert_eq!(result_float.value(3), 40.5); // false -> 40.5 - } -} From 2010cbeeeff5da7ee3acd1596544ecc3f2395061 Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Wed, 20 Aug 2025 08:50:38 +0800 Subject: [PATCH 06/10] update --- datafusion/spark/src/function/conditional/if.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/spark/src/function/conditional/if.rs b/datafusion/spark/src/function/conditional/if.rs index ff7102e1ce95..631ef27c0676 100644 --- a/datafusion/spark/src/function/conditional/if.rs +++ b/datafusion/spark/src/function/conditional/if.rs @@ -26,7 +26,7 @@ use datafusion_expr::{ ScalarUDFImpl, Signature, Volatility, }; -#[derive(Debug)] +#[derive(Debug, Eq, Hash, PartialEq)] pub struct SparkIf { signature: Signature, } From 92d42cfe98e7ff8db618a75d4182f58ec84d6a6d Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Wed, 20 Aug 2025 08:54:53 +0800 Subject: [PATCH 07/10] update --- datafusion/spark/src/function/conditional/if.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/spark/src/function/conditional/if.rs b/datafusion/spark/src/function/conditional/if.rs index 631ef27c0676..bccbbee0ff4c 100644 --- a/datafusion/spark/src/function/conditional/if.rs +++ b/datafusion/spark/src/function/conditional/if.rs @@ -26,7 +26,7 @@ use datafusion_expr::{ ScalarUDFImpl, Signature, Volatility, }; -#[derive(Debug, Eq, Hash, PartialEq)] +#[derive(Debug, PartialEq, Eq, Hash)] pub struct SparkIf { signature: Signature, } From d1b4ded127a563a3af6bd3eb8749883c652ae183 Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Wed, 20 Aug 2025 15:43:47 +0800 Subject: [PATCH 08/10] update --- .../spark/src/function/conditional/if.rs | 26 +++++-------------- 
.../test_files/spark/conditional/if.slt | 2 +- 2 files changed, 7 insertions(+), 21 deletions(-) diff --git a/datafusion/spark/src/function/conditional/if.rs b/datafusion/spark/src/function/conditional/if.rs index bccbbee0ff4c..df89528281ed 100644 --- a/datafusion/spark/src/function/conditional/if.rs +++ b/datafusion/spark/src/function/conditional/if.rs @@ -22,8 +22,8 @@ use arrow::{ }; use datafusion_common::{plan_err, utils::take_function_args, Result}; use datafusion_expr::{ - binary::comparison_coercion_numeric, ColumnarValue, ScalarFunctionArgs, - ScalarUDFImpl, Signature, Volatility, + binary::try_type_union_resolution, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, + Signature, Volatility, }; #[derive(Debug, PartialEq, Eq, Hash)] @@ -73,24 +73,10 @@ impl ScalarUDFImpl for SparkIf { ); } - let Some(target_type) = comparison_coercion_numeric(&arg_types[1], &arg_types[2]) - else { - return plan_err!( - "For function 'if' {} and {} is not comparable", - arg_types[1], - arg_types[2] - ); - }; - // Convert null to String type. - if target_type.is_null() { - Ok(vec![ - DataType::Boolean, - DataType::Utf8View, - DataType::Utf8View, - ]) - } else { - Ok(vec![DataType::Boolean, target_type.clone(), target_type]) - } + let target_types = try_type_union_resolution(&arg_types[1..])?; + let mut result = vec![DataType::Boolean]; + result.extend(target_types); + Ok(result) } fn return_type(&self, arg_types: &[DataType]) -> Result { diff --git a/datafusion/sqllogictest/test_files/spark/conditional/if.slt b/datafusion/sqllogictest/test_files/spark/conditional/if.slt index 6360e81e34ec..2af8d94f7e9a 100644 --- a/datafusion/sqllogictest/test_files/spark/conditional/if.slt +++ b/datafusion/sqllogictest/test_files/spark/conditional/if.slt @@ -118,7 +118,7 @@ SELECT if(false, 'yes', NULL); NULL # Test with all NULL -query T +query ? SELECT if(true, NULL, NULL); ---- NULL From 0a670dbc65ccc7c3561f7e120949c3071a750a6a Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Fri, 22 Aug 2025 09:49:05 +0800 Subject: [PATCH 09/10] update --- .../spark/src/function/conditional/if.rs | 36 +++++++++++-------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/datafusion/spark/src/function/conditional/if.rs b/datafusion/spark/src/function/conditional/if.rs index df89528281ed..aee43dd8d0a5 100644 --- a/datafusion/spark/src/function/conditional/if.rs +++ b/datafusion/spark/src/function/conditional/if.rs @@ -15,15 +15,11 @@ // specific language governing permissions and limitations // under the License. 
-use arrow::{ - array::{ArrayRef, BooleanArray}, - compute::kernels::zip::zip, - datatypes::DataType, -}; -use datafusion_common::{plan_err, utils::take_function_args, Result}; +use arrow::datatypes::DataType; +use datafusion_common::{internal_err, plan_err, Result}; use datafusion_expr::{ - binary::try_type_union_resolution, ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, - Signature, Volatility, + binary::try_type_union_resolution, simplify::ExprSimplifyResult, when, ColumnarValue, + Expr, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; #[derive(Debug, PartialEq, Eq, Hash)] @@ -83,11 +79,23 @@ impl ScalarUDFImpl for SparkIf { Ok(arg_types[1].clone()) } - fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - let args = ColumnarValue::values_to_arrays(&args.args)?; - let [expr1, expr2, expr3] = take_function_args::<3, ArrayRef>("if", args)?; - let expr1 = expr1.as_any().downcast_ref::().unwrap(); - let result = zip(expr1, &expr2, &expr3)?; - Ok(ColumnarValue::Array(result)) + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + internal_err!("if should have been simplified to case") + } + + fn simplify( + &self, + args: Vec, + _info: &dyn datafusion_expr::simplify::SimplifyInfo, + ) -> Result { + let condition = args[0].clone(); + let then_expr = args[1].clone(); + let else_expr = args[2].clone(); + + // Convert IF(condition, then_expr, else_expr) to + // CASE WHEN condition THEN then_expr ELSE else_expr END + let case_expr = when(condition, then_expr).otherwise(else_expr)?; + + Ok(ExprSimplifyResult::Simplified(case_expr)) } } From 7c4833eb3def4f4d46b7afa9103f08ef045730f0 Mon Sep 17 00:00:00 2001 From: Chongchen Chen Date: Thu, 28 Aug 2025 08:29:53 +0800 Subject: [PATCH 10/10] update test --- .../sqllogictest/test_files/spark/conditional/if.slt | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/datafusion/sqllogictest/test_files/spark/conditional/if.slt b/datafusion/sqllogictest/test_files/spark/conditional/if.slt index 2af8d94f7e9a..b4380e065b98 100644 --- a/datafusion/sqllogictest/test_files/spark/conditional/if.slt +++ b/datafusion/sqllogictest/test_files/spark/conditional/if.slt @@ -153,3 +153,14 @@ SELECT if(1, 10.5, 20); statement error Utf8 is not a boolean or null SELECT if('x', 10.5, 20); + +query II +SELECT v, IF(v < 0, 10/0, 1) FROM (VALUES (1), (2)) t(v) +---- +1 1 +2 1 + +query I +SELECT IF(true, 1 / 1, 1 / 0); +---- +1
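
Note on the final design: the last revisions above stop evaluating `if` eagerly and instead rewrite it during expression simplification, so `if(cond, a, b)` becomes `CASE WHEN cond THEN a ELSE b END` via `when(...).otherwise(...)`. That is why the closing slt cases pass: the branch that is not selected (`10/0`, `1/0`) is never executed. Below is a minimal Rust sketch of the same rewrite, using the `when`/`otherwise` builders already imported in the patch; the helper names `if_as_case` and `example` are illustrative only and do not appear in the patches.

use datafusion_common::Result;
use datafusion_expr::{col, lit, when, Expr};

// Sketch of the rewrite performed in SparkIf::simplify:
// IF(cond, a, b) is turned into CASE WHEN cond THEN a ELSE b END,
// so the branch that is not taken (here 10/0) is never evaluated.
fn if_as_case(cond: Expr, then_expr: Expr, else_expr: Expr) -> Result<Expr> {
    when(cond, then_expr).otherwise(else_expr)
}

fn example() -> Result<Expr> {
    // Mirrors the slt query `SELECT v, IF(v < 0, 10/0, 1) FROM ... t(v)`
    if_as_case(col("v").lt(lit(0)), lit(10) / lit(0), lit(1))
}

Delegating to CASE also gives the expected null handling for free: a NULL condition is not true, so evaluation falls through to the ELSE branch, matching the `if(NULL, 'yes', 'no')` test above that returns `no`.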