Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datafusion-examples/examples/advanced_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ use datafusion::prelude::*;
/// the power of the second argument `a^b`.
///
/// To do so, we must implement the `ScalarUDFImpl` trait.
#[derive(Debug, Clone)]
#[derive(Debug, PartialEq, Eq, Hash)]
struct PowUdf {
signature: Signature,
aliases: Vec<String>,
Expand Down
2 changes: 1 addition & 1 deletion datafusion-examples/examples/async_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ fn animal() -> Result<RecordBatch> {
///
/// Since this is a simplified example, it does not call an LLM service, but
/// could be extended to do so in a real-world scenario.
#[derive(Debug)]
#[derive(Debug, PartialEq, Eq, Hash)]
struct AskLLM {
signature: Signature,
}
Expand Down
36 changes: 2 additions & 34 deletions datafusion-examples/examples/function_factory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use datafusion::logical_expr::{
ColumnarValue, CreateFunction, Expr, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
Signature, Volatility,
};
use std::hash::{DefaultHasher, Hash, Hasher};
use std::hash::Hash;
use std::result::Result as RResult;
use std::sync::Arc;

Expand Down Expand Up @@ -107,7 +107,7 @@ impl FunctionFactory for CustomFunctionFactory {
}

/// this function represents the newly created execution engine.
#[derive(Debug)]
#[derive(Debug, PartialEq, Eq, Hash)]
struct ScalarFunctionWrapper {
/// The text of the function body, `$1 + f1($2)` in our example
name: String,
Expand Down Expand Up @@ -154,38 +154,6 @@ impl ScalarUDFImpl for ScalarFunctionWrapper {
fn output_ordering(&self, _input: &[ExprProperties]) -> Result<SortProperties> {
Ok(SortProperties::Unordered)
}

fn equals(&self, other: &dyn ScalarUDFImpl) -> bool {
let Some(other) = other.as_any().downcast_ref::<Self>() else {
return false;
};
let Self {
name,
expr,
signature,
return_type,
} = self;
name == &other.name
&& expr == &other.expr
&& signature == &other.signature
&& return_type == &other.return_type
}

fn hash_value(&self) -> u64 {
let Self {
name,
expr,
signature,
return_type,
} = self;
let mut hasher = DefaultHasher::new();
std::any::type_name::<Self>().hash(&mut hasher);
name.hash(&mut hasher);
expr.hash(&mut hasher);
signature.hash(&mut hasher);
return_type.hash(&mut hasher);
hasher.finish()
}
}

impl ScalarFunctionWrapper {
Expand Down
10 changes: 2 additions & 8 deletions datafusion-examples/examples/json_shredding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,17 +282,15 @@ impl TableProvider for ExampleTableProvider {
}

/// Scalar UDF that uses serde_json to access json fields
#[derive(Debug)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct JsonGetStr {
signature: Signature,
aliases: [String; 1],
}

impl Default for JsonGetStr {
fn default() -> Self {
Self {
signature: Signature::variadic_any(Volatility::Immutable),
aliases: ["json_get_str".to_string()],
}
}
}
Expand All @@ -303,7 +301,7 @@ impl ScalarUDFImpl for JsonGetStr {
}

fn name(&self) -> &str {
self.aliases[0].as_str()
"json_get_str"
}

fn signature(&self) -> &Signature {
Expand Down Expand Up @@ -355,10 +353,6 @@ impl ScalarUDFImpl for JsonGetStr {
.collect::<StringArray>();
Ok(ColumnarValue::Array(Arc::new(values)))
}

fn aliases(&self) -> &[String] {
&self.aliases
}
}

/// Factory for creating ShreddedJsonRewriter instances
Expand Down
2 changes: 1 addition & 1 deletion datafusion-examples/examples/optimizer_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ fn is_lit_or_col(expr: &Expr) -> bool {
}

/// A simple user defined filter function
#[derive(Debug, Clone)]
#[derive(Debug, PartialEq, Eq, Hash, Clone)]
struct MyEq {
signature: Signature,
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/tests/fuzz_cases/equivalence/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ fn get_sort_columns(
.collect::<Result<Vec<_>>>()
}

#[derive(Debug, Clone)]
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct TestScalarUDF {
pub(crate) signature: Signature,
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ use insta::assert_snapshot;
use itertools::Itertools;

/// Mocked UDF
#[derive(Debug)]
#[derive(Debug, PartialEq, Eq, Hash)]
struct DummyUDF {
signature: Signature,
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ use datafusion_common::{
use datafusion_expr::expr::FieldMetadata;
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
use datafusion_expr::{
lit_with_metadata, udf_equals_hash, Accumulator, ColumnarValue, CreateFunction,
CreateFunctionBody, LogicalPlanBuilder, OperateFunctionArg, ReturnFieldArgs,
ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
lit_with_metadata, Accumulator, ColumnarValue, CreateFunction, CreateFunctionBody,
LogicalPlanBuilder, OperateFunctionArg, ReturnFieldArgs, ScalarFunctionArgs,
ScalarUDF, ScalarUDFImpl, Signature, Volatility,
};
use datafusion_functions_nested::range::range_udf;
use parking_lot::Mutex;
Expand Down Expand Up @@ -218,8 +218,6 @@ impl ScalarUDFImpl for Simple0ArgsScalarUDF {
fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(100))))
}

udf_equals_hash!(ScalarUDFImpl);
}

#[tokio::test]
Expand Down Expand Up @@ -560,8 +558,6 @@ impl ScalarUDFImpl for AddIndexToStringVolatileScalarUDF {
};
Ok(ColumnarValue::Array(Arc::new(StringArray::from(answer))))
}

udf_equals_hash!(ScalarUDFImpl);
}

#[tokio::test]
Expand Down Expand Up @@ -665,7 +661,7 @@ async fn volatile_scalar_udf_with_params() -> Result<()> {
Ok(())
}

#[derive(Debug)]
#[derive(Debug, PartialEq, Eq, Hash)]
struct CastToI64UDF {
signature: Signature,
}
Expand Down Expand Up @@ -787,7 +783,7 @@ async fn deregister_udf() -> Result<()> {
Ok(())
}

#[derive(Debug)]
#[derive(Debug, PartialEq, Eq, Hash)]
struct TakeUDF {
signature: Signature,
}
Expand Down Expand Up @@ -979,8 +975,6 @@ impl ScalarUDFImpl for ScalarFunctionWrapper {

Ok(ExprSimplifyResult::Simplified(replacement))
}

udf_equals_hash!(ScalarUDFImpl);
}

impl ScalarFunctionWrapper {
Expand Down Expand Up @@ -1282,8 +1276,6 @@ impl ScalarUDFImpl for MyRegexUdf {
_ => exec_err!("regex_udf only accepts a Utf8 arguments"),
}
}

udf_equals_hash!(ScalarUDFImpl);
}

#[tokio::test]
Expand Down Expand Up @@ -1471,8 +1463,6 @@ impl ScalarUDFImpl for MetadataBasedUdf {
}
}
}

udf_equals_hash!(ScalarUDFImpl);
}

#[tokio::test]
Expand Down Expand Up @@ -1611,7 +1601,7 @@ async fn test_metadata_based_udf_with_literal() -> Result<()> {
/// sides. For the input, we will handle the data differently if there is
/// the canonical extension type Bool8. For the output we will add a
/// user defined extension type.
#[derive(Debug)]
#[derive(Debug, PartialEq, Eq, Hash)]
struct ExtensionBasedUdf {
name: String,
signature: Signature,
Expand Down Expand Up @@ -1790,7 +1780,7 @@ async fn test_extension_based_udf() -> Result<()> {

#[tokio::test]
async fn test_config_options_work_for_scalar_func() -> Result<()> {
#[derive(Debug)]
#[derive(Debug, PartialEq, Eq, Hash)]
struct TestScalarUDF {
signature: Signature,
}
Expand Down
6 changes: 1 addition & 5 deletions datafusion/expr/src/async_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@
// under the License.

use crate::ptr_eq::{arc_ptr_eq, arc_ptr_hash};
use crate::{
udf_equals_hash, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl,
};
use crate::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl};
use arrow::datatypes::{DataType, FieldRef};
use async_trait::async_trait;
use datafusion_common::error::Result;
Expand Down Expand Up @@ -127,8 +125,6 @@ impl ScalarUDFImpl for AsyncScalarUDF {
fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result<ColumnarValue> {
internal_err!("async functions should not be called directly")
}

udf_equals_hash!(ScalarUDFImpl);
}

impl Display for AsyncScalarUDF {
Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3648,7 +3648,7 @@ mod test {
#[test]
fn test_is_volatile_scalar_func() {
// UDF
#[derive(Debug)]
#[derive(Debug, PartialEq, Eq, Hash)]
struct TestScalarUDF {
signature: Signature,
}
Expand Down
6 changes: 2 additions & 4 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ use crate::ptr_eq::PtrEq;
use crate::select_expr::SelectExpr;
use crate::{
conditional_expressions::CaseBuilder, expr::Sort, logical_plan::Subquery,
udf_equals_hash, AggregateUDF, Expr, LogicalPlan, Operator, PartitionEvaluator,
ScalarFunctionArgs, ScalarFunctionImplementation, ScalarUDF, Signature, Volatility,
AggregateUDF, Expr, LogicalPlan, Operator, PartitionEvaluator, ScalarFunctionArgs,
ScalarFunctionImplementation, ScalarUDF, Signature, Volatility,
};
use crate::{
AggregateUDFImpl, ColumnarValue, ScalarUDFImpl, WindowFrame, WindowUDF, WindowUDFImpl,
Expand Down Expand Up @@ -477,8 +477,6 @@ impl ScalarUDFImpl for SimpleScalarUDF {
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
(self.fun)(&args.args)
}

udf_equals_hash!(ScalarUDFImpl);
}

/// Creates a new UDAF with a specific signature, state type and return type.
Expand Down
8 changes: 8 additions & 0 deletions datafusion/expr/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1230,6 +1230,7 @@ mod test {
};
use std::any::Any;
use std::cmp::Ordering;
use std::hash::{DefaultHasher, Hash, Hasher};

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct AMeanUdf {
Expand Down Expand Up @@ -1319,6 +1320,7 @@ mod test {
let eq = a1 == a2;
assert!(eq);
assert_eq!(a1, a2);
assert_eq!(hash(a1), hash(a2));
}

#[test]
Expand All @@ -1333,4 +1335,10 @@ mod test {
assert!(a1 < b1);
assert!(!(a1 == b1));
}

fn hash<T: Hash>(value: T) -> u64 {
let hasher = &mut DefaultHasher::new();
value.hash(hasher);
hasher.finish()
}
}
Loading