diff --git a/datafusion/spark/src/function/math/mod.rs b/datafusion/spark/src/function/math/mod.rs
index 092335e4aa18d..4560a8e72ecc7 100644
--- a/datafusion/spark/src/function/math/mod.rs
+++ b/datafusion/spark/src/function/math/mod.rs
@@ -20,6 +20,7 @@ pub mod factorial;
 pub mod hex;
 pub mod modulus;
 pub mod rint;
+pub mod trigonometry;
 pub mod width_bucket;
 
 use datafusion_expr::ScalarUDF;
@@ -33,6 +34,7 @@ make_udf_function!(modulus::SparkMod, modulus);
 make_udf_function!(modulus::SparkPmod, pmod);
 make_udf_function!(rint::SparkRint, rint);
 make_udf_function!(width_bucket::SparkWidthBucket, width_bucket);
+make_udf_function!(trigonometry::SparkCsc, csc);
 
 pub mod expr_fn {
     use datafusion_functions::export_functions;
@@ -48,6 +50,7 @@ pub mod expr_fn {
     export_functions!((pmod, "Returns the positive remainder of division of the first argument by the second argument.", arg1 arg2));
     export_functions!((rint, "Returns the double value that is closest in value to the argument and is equal to a mathematical integer.", arg1));
     export_functions!((width_bucket, "Returns the bucket number into which the value of this expression would fall after being evaluated.", arg1 arg2 arg3 arg4));
+    export_functions!((csc, "Returns the cosecant of expr.", arg1));
 }
 
 pub fn functions() -> Vec<Arc<ScalarUDF>> {
@@ -59,5 +62,6 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> {
         pmod(),
         rint(),
         width_bucket(),
+        csc(),
     ]
 }
diff --git a/datafusion/spark/src/function/math/trigonometry.rs b/datafusion/spark/src/function/math/trigonometry.rs
new file mode 100644
index 0000000000000..a6bcf8044f5a9
--- /dev/null
+++ b/datafusion/spark/src/function/math/trigonometry.rs
@@ -0,0 +1,97 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::function::error_utils::unsupported_data_type_exec_err;
+use arrow::array::{ArrayRef, AsArray};
+use arrow::datatypes::{DataType, Float64Type};
+use datafusion_common::utils::take_function_args;
+use datafusion_common::{Result, ScalarValue};
+use datafusion_expr::{
+    ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
+};
+use std::any::Any;
+use std::sync::Arc;
+
+static CSC_FUNCTION_NAME: &str = "csc";
+
+/// Spark-compatible `csc` expression: returns the cosecant (1 / sin) of the argument.
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct SparkCsc {
+    signature: Signature,
+}
+
+impl Default for SparkCsc {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl SparkCsc {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
+        }
+    }
+}
+
+impl ScalarUDFImpl for SparkCsc {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        CSC_FUNCTION_NAME
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Float64)
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        let [arg] = take_function_args(self.name(), &args.args)?;
+        spark_csc(arg)
+    }
+}
+
+fn spark_csc(arg: &ColumnarValue) -> Result<ColumnarValue> {
+    match arg {
+        ColumnarValue::Scalar(ScalarValue::Float64(value)) => Ok(ColumnarValue::Scalar(
+            ScalarValue::Float64(value.map(|x| 1.0 / x.sin())),
+        )),
+        ColumnarValue::Array(array) => match array.data_type() {
+            DataType::Float64 => Ok(ColumnarValue::Array(Arc::new(
+                array
+                    .as_primitive::<Float64Type>()
+                    .unary::<_, Float64Type>(|x| 1.0 / x.sin()),
+            ) as ArrayRef)),
+            other => Err(unsupported_data_type_exec_err(
+                CSC_FUNCTION_NAME,
+                format!("{}", DataType::Float64).as_str(),
+                other,
+            )),
+        },
+        other => Err(unsupported_data_type_exec_err(
+            CSC_FUNCTION_NAME,
+            format!("{}", DataType::Float64).as_str(),
+            &other.data_type(),
+        )),
+    }
+}
diff --git a/datafusion/sqllogictest/test_files/spark/math/csc.slt b/datafusion/sqllogictest/test_files/spark/math/csc.slt
index b11986c3e1b9f..5eb9f44472807 100644
--- a/datafusion/sqllogictest/test_files/spark/math/csc.slt
+++ b/datafusion/sqllogictest/test_files/spark/math/csc.slt
@@ -23,5 +23,24 @@
 
 ## Original Query: SELECT csc(1);
 ## PySpark 3.5.5 Result: {'CSC(1)': 1.1883951057781212, 'typeof(CSC(1))': 'double', 'typeof(1)': 'int'}
-#query
-#SELECT csc(1::int);
+
+query R
+SELECT csc(1::INT);
+----
+1.188395105778121
+
+query R
+SELECT csc(a) FROM (VALUES (0::INT), (1::INT), (-1::INT), (null)) AS t(a);
+----
+Infinity
+1.188395105778121
+-1.188395105778121
+NULL
+
+query R
+SELECT csc(a) FROM (VALUES (pi()), (-pi()), (pi()/2), (arrow_cast('NAN','Float32'))) AS t(a);
+----
+8165619676597685
+-8165619676597685
+1
+NaN
\ No newline at end of file