diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs
index dcb942e65b91..f3cc62014486 100644
--- a/datafusion/expr/src/udf.rs
+++ b/datafusion/expr/src/udf.rs
@@ -67,14 +67,31 @@ impl PartialEq for ScalarUDF {
     }
 }
 
-// TODO (https://github.com/apache/datafusion/issues/17064) PartialOrd is not consistent with PartialEq for `ScalarUDF` and it should be
-// Manual implementation based on `ScalarUDFImpl::equals`
 impl PartialOrd for ScalarUDF {
     fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
-        match self.name().partial_cmp(other.name()) {
-            Some(Ordering::Equal) => self.signature().partial_cmp(other.signature()),
-            cmp => cmp,
+        let mut cmp = self.name().cmp(other.name());
+        if cmp == Ordering::Equal {
+            cmp = self.signature().partial_cmp(other.signature())?;
         }
+        if cmp == Ordering::Equal {
+            cmp = self.aliases().partial_cmp(other.aliases())?;
+        }
+        // Contract for PartialOrd and PartialEq consistency requires that
+        // a == b if and only if partial_cmp(a, b) == Some(Equal).
+        if cmp == Ordering::Equal && self != other {
+            // Functions may have other properties besides name and signature
+            // that differentiate two instances (e.g. type, or arbitrary parameters).
+            // We cannot return Some(Equal) in such case.
+            return None;
+        }
+        debug_assert!(
+            cmp == Ordering::Equal || self != other,
+            "Detected incorrect implementation of PartialEq when comparing functions: '{}' and '{}'. \
+            The functions compare as equal, but they are not equal based on general properties that \
+            the PartialOrd implementation observes,",
+            self.name(), other.name()
+        );
+        Some(cmp)
     }
 }
 
@@ -942,11 +959,14 @@ The following regular expression functions are supported:"#,
 #[cfg(test)]
 mod tests {
     use super::*;
+    use datafusion_expr_common::signature::Volatility;
     use std::hash::DefaultHasher;
 
     #[derive(Debug, PartialEq, Eq, Hash)]
     struct TestScalarUDFImpl {
+        name: &'static str,
         field: &'static str,
+        signature: Signature,
     }
     impl ScalarUDFImpl for TestScalarUDFImpl {
         fn as_any(&self) -> &dyn Any {
@@ -954,11 +974,11 @@ mod tests {
         }
 
         fn name(&self) -> &str {
-            "TestScalarUDFImpl"
+            self.name
         }
 
         fn signature(&self) -> &Signature {
-            unimplemented!()
+            &self.signature
         }
 
         fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
@@ -970,17 +990,43 @@ mod tests {
         }
     }
 
+    // PartialEq and Hash must be consistent, and also PartialEq and PartialOrd
+    // must be consistent, so they are tested together.
     #[test]
-    fn test_partial_eq() {
-        let a1 = ScalarUDF::from(TestScalarUDFImpl { field: "a" });
-        let a2 = ScalarUDF::from(TestScalarUDFImpl { field: "a" });
-        let b = ScalarUDF::from(TestScalarUDFImpl { field: "b" });
-        let eq = a1 == a2;
-        assert!(eq);
-        assert_eq!(a1, a2);
-        assert_eq!(hash(&a1), hash(&a2));
-        assert_ne!(a1, b);
-        assert_ne!(a2, b);
+    fn test_partial_eq_hash_and_partial_ord() {
+        // A parameterized function
+        let f = test_func("foo", "a");
+
+        // Same like `f`, different instance
+        let f2 = test_func("foo", "a");
+        assert_eq!(f, f2);
+        assert_eq!(hash(&f), hash(&f2));
+        assert_eq!(f.partial_cmp(&f2), Some(Ordering::Equal));
+
+        // Different parameter
+        let b = test_func("foo", "b");
+        assert_ne!(f, b);
+        assert_ne!(hash(&f), hash(&b)); // hash can collide for different values but does not collide in this test
+        assert_eq!(f.partial_cmp(&b), None);
+
+        // Different name
+        let o = test_func("other", "a");
+        assert_ne!(f, o);
+        assert_ne!(hash(&f), hash(&o)); // hash can collide for different values but does not collide in this test
+        assert_eq!(f.partial_cmp(&o), Some(Ordering::Less));
+
+        // Different name and parameter
+        assert_ne!(b, o);
+        assert_ne!(hash(&b), hash(&o)); // hash can collide for different values but does not collide in this test
+        assert_eq!(b.partial_cmp(&o), Some(Ordering::Less));
+    }
+
+    fn test_func(name: &'static str, parameter: &'static str) -> ScalarUDF {
+        ScalarUDF::from(TestScalarUDFImpl {
+            name,
+            field: parameter,
+            signature: Signature::any(1, Volatility::Immutable),
+        })
     }
 
     fn hash<T: Hash>(value: &T) -> u64 {