diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 6d4c57c5c4c0..a354fb666f99 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -2503,9 +2503,22 @@ impl Window { // Manual implementation needed because of `schema` field. Comparison excludes this field. impl PartialOrd for Window { fn partial_cmp(&self, other: &Self) -> Option { - match self.input.partial_cmp(&other.input) { - Some(Ordering::Equal) => self.window_expr.partial_cmp(&other.window_expr), - cmp => cmp, + match self.input.partial_cmp(&other.input)? { + Ordering::Equal => {} // continue + not_equal => return Some(not_equal), + } + + match self.window_expr.partial_cmp(&other.window_expr)? { + Ordering::Equal => {} // continue + not_equal => return Some(not_equal), + } + + // Contract for PartialOrd and PartialEq consistency requires that + // a == b if and only if partial_cmp(a, b) == Some(Equal). + if self == other { + Some(Ordering::Equal) + } else { + None } } } @@ -4268,22 +4281,20 @@ fn get_unnested_list_datatype_recursive( #[cfg(test)] mod tests { - use super::*; use crate::builder::LogicalTableSource; use crate::logical_plan::table_scan; + use crate::test::function_stub::{count, count_udaf}; use crate::{ binary_expr, col, exists, in_subquery, lit, placeholder, scalar_subquery, GroupingSet, }; - use datafusion_common::tree_node::{ TransformedResult, TreeNodeRewriter, TreeNodeVisitor, }; use datafusion_common::{not_impl_err, Constraint, ScalarValue}; use insta::{assert_debug_snapshot, assert_snapshot}; - - use crate::test::function_stub::count; + use std::hash::DefaultHasher; fn employee_schema() -> Schema { Schema::new(vec![ @@ -4687,6 +4698,63 @@ mod tests { ); } + #[test] + fn test_partial_eq_hash_and_partial_ord() { + let empty_values = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: true, + schema: Arc::new(DFSchema::empty()), + })); + + let count_window_function = |schema| { + Window::try_new_with_schema( + vec![Expr::WindowFunction(Box::new(WindowFunction::new( + WindowFunctionDefinition::AggregateUDF(count_udaf()), + vec![], + )))], + Arc::clone(&empty_values), + Arc::new(schema), + ) + .unwrap() + }; + + let schema_without_metadata = || { + DFSchema::from_unqualified_fields( + vec![Field::new("count", DataType::Int64, false)].into(), + HashMap::new(), + ) + .unwrap() + }; + + let schema_with_metadata = || { + DFSchema::from_unqualified_fields( + vec![Field::new("count", DataType::Int64, false)].into(), + [("key".to_string(), "value".to_string())].into(), + ) + .unwrap() + }; + + // A Window + let f = count_window_function(schema_without_metadata()); + + // Same like `f`, different instance + let f2 = count_window_function(schema_without_metadata()); + assert_eq!(f, f2); + assert_eq!(hash(&f), hash(&f2)); + assert_eq!(f.partial_cmp(&f2), Some(Ordering::Equal)); + + // Same like `f`, except for schema metadata + let o = count_window_function(schema_with_metadata()); + assert_ne!(f, o); + assert_ne!(hash(&f), hash(&o)); // hash can collide for different values but does not collide in this test + assert_eq!(f.partial_cmp(&o), None); + } + + fn hash(value: &T) -> u64 { + let hasher = &mut DefaultHasher::new(); + value.hash(hasher); + hasher.finish() + } + #[test] fn projection_expr_schema_mismatch() -> Result<()> { let empty_schema = Arc::new(DFSchema::empty());