Skip to content

Commit

Permalink
implement Hash for various types and replace PartialOrd (#1580)
Browse files Browse the repository at this point in the history
* implement Hash for various types

* change more partialOrd => hash

* update unit tests
  • Loading branch information
jimexist authored Jan 16, 2022
1 parent 438b417 commit 6f7b2d2
Show file tree
Hide file tree
Showing 8 changed files with 47 additions and 30 deletions.
26 changes: 22 additions & 4 deletions datafusion/src/logical_plan/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ use functions::{ReturnTypeFunction, ScalarFunctionImplementation, Signature};
use std::collections::{HashMap, HashSet};
use std::convert::Infallible;
use std::fmt;
use std::hash::{BuildHasher, Hash, Hasher};
use std::ops::Not;
use std::str::FromStr;
use std::sync::Arc;
Expand Down Expand Up @@ -221,7 +222,7 @@ impl fmt::Display for Column {
/// assert_eq!(op, Operator::Eq);
/// }
/// ```
#[derive(Clone, PartialEq, PartialOrd)]
#[derive(Clone, PartialEq, Hash)]
pub enum Expr {
/// An expression with a specific name.
Alias(Box<Expr>, String),
Expand Down Expand Up @@ -372,6 +373,23 @@ pub enum Expr {
Wildcard,
}

/// Fixed seed for the hashing so that Ords are consistent across runs
const SEED: ahash::RandomState = ahash::RandomState::with_seeds(0, 0, 0, 0);

impl PartialOrd for Expr {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
let mut hasher = SEED.build_hasher();
self.hash(&mut hasher);
let s = hasher.finish();

let mut hasher = SEED.build_hasher();
other.hash(&mut hasher);
let o = hasher.finish();

Some(s.cmp(&o))
}
}

impl Expr {
/// Returns the [arrow::datatypes::DataType] of the expression based on [arrow::datatypes::Schema].
///
Expand Down Expand Up @@ -2434,16 +2452,16 @@ mod tests {

#[test]
fn test_partial_ord() {
// Test validates that partial ord is defined for Expr, not
// Test validates that partial ord is defined for Expr using hashes, not
// intended to exhaustively test all possibilities
let exp1 = col("a") + lit(1);
let exp2 = col("a") + lit(2);
let exp3 = !(col("a") + lit(2));

assert!(exp1 < exp2);
assert!(exp2 > exp1);
assert!(exp2 < exp3);
assert!(exp3 > exp2);
assert!(exp2 > exp3);
assert!(exp3 < exp2);
}

#[test]
Expand Down
2 changes: 1 addition & 1 deletion datafusion/src/logical_plan/operators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use std::{fmt, ops};
use super::{binary_expr, Expr};

/// Operators applied to expressions
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub enum Operator {
/// Expressions are equal
Eq,
Expand Down
11 changes: 9 additions & 2 deletions datafusion/src/logical_plan/window_frames.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,14 @@ use sqlparser::ast;
use std::cmp::Ordering;
use std::convert::{From, TryFrom};
use std::fmt;
use std::hash::{Hash, Hasher};

/// The frame-spec determines which output rows are read by an aggregate window function.
///
/// The ending frame boundary can be omitted (if the BETWEEN and AND keywords that surround the
/// starting frame boundary are also omitted), in which case the ending frame boundary defaults to
/// CURRENT ROW.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash)]
pub struct WindowFrame {
/// A frame type - either ROWS, RANGE or GROUPS
pub units: WindowFrameUnits,
Expand Down Expand Up @@ -190,6 +191,12 @@ impl Ord for WindowFrameBound {
}
}

impl Hash for WindowFrameBound {
fn hash<H: Hasher>(&self, state: &mut H) {
self.get_rank().hash(state)
}
}

impl WindowFrameBound {
/// get the rank of this window frame bound.
///
Expand All @@ -211,7 +218,7 @@ impl WindowFrameBound {

/// There are three frame types: ROWS, GROUPS, and RANGE. The frame type determines how the
/// starting and ending boundaries of the frame are measured.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash)]
pub enum WindowFrameUnits {
/// The ROWS frame type means that the starting and ending boundaries for the frame are
/// determined by counting individual rows relative to the current row.
Expand Down
2 changes: 1 addition & 1 deletion datafusion/src/physical_plan/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ pub type StateTypeFunction =
Arc<dyn Fn(&DataType) -> Result<Arc<Vec<DataType>>> + Send + Sync>;

/// Enum of all built-in aggregate functions
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd)]
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
pub enum AggregateFunction {
/// count
Count,
Expand Down
8 changes: 4 additions & 4 deletions datafusion/src/physical_plan/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ use std::convert::From;
use std::{any::Any, fmt, str::FromStr, sync::Arc};

/// A function's type signature, which defines the function's supported argument types.
#[derive(Debug, Clone, PartialEq, PartialOrd)]
#[derive(Debug, Clone, PartialEq, Hash)]
pub enum TypeSignature {
/// arbitrary number of arguments of an common type out of a list of valid types
// A function such as `concat` is `Variadic(vec![DataType::Utf8, DataType::LargeUtf8])`
Expand All @@ -79,7 +79,7 @@ pub enum TypeSignature {
}

///The Signature of a function defines its supported input types as well as its volatility.
#[derive(Debug, Clone, PartialEq, PartialOrd)]
#[derive(Debug, Clone, PartialEq, Hash)]
pub struct Signature {
/// type_signature - The types that the function accepts. See [TypeSignature] for more information.
pub type_signature: TypeSignature,
Expand Down Expand Up @@ -144,7 +144,7 @@ impl Signature {
}

///A function's volatility, which defines the functions eligibility for certain optimizations
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)]
pub enum Volatility {
/// Immutable - An immutable function will always return the same output when given the same input. An example of this is [BuiltinScalarFunction::Cos].
Immutable,
Expand All @@ -170,7 +170,7 @@ pub type ReturnTypeFunction =
Arc<dyn Fn(&[DataType]) -> Result<Arc<DataType>> + Send + Sync>;

/// Enum of all built-in scalar functions
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum BuiltinScalarFunction {
// math functions
/// abs
Expand Down
12 changes: 4 additions & 8 deletions datafusion/src/physical_plan/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,10 @@ impl PartialEq for AggregateUDF {
}
}

impl PartialOrd for AggregateUDF {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
let c = self.name.partial_cmp(&other.name);
if matches!(c, Some(std::cmp::Ordering::Equal)) {
self.signature.partial_cmp(&other.signature)
} else {
c
}
impl std::hash::Hash for AggregateUDF {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.name.hash(state);
self.signature.hash(state);
}
}

Expand Down
12 changes: 4 additions & 8 deletions datafusion/src/physical_plan/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,10 @@ impl PartialEq for ScalarUDF {
}
}

impl PartialOrd for ScalarUDF {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
let c = self.name.partial_cmp(&other.name);
if matches!(c, Some(std::cmp::Ordering::Equal)) {
self.signature.partial_cmp(&other.signature)
} else {
c
}
impl std::hash::Hash for ScalarUDF {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.name.hash(state);
self.signature.hash(state);
}
}

Expand Down
4 changes: 2 additions & 2 deletions datafusion/src/physical_plan/window_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ use std::sync::Arc;
use std::{fmt, str::FromStr};

/// WindowFunction
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum WindowFunction {
/// window function that leverages an aggregate function
AggregateFunction(AggregateFunction),
Expand Down Expand Up @@ -91,7 +91,7 @@ impl fmt::Display for WindowFunction {
}

/// An aggregate function that is part of a built-in window function
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum BuiltInWindowFunction {
/// number of the current row within its partition, counting from 1
RowNumber,
Expand Down

0 comments on commit 6f7b2d2

Please sign in to comment.