diff --git a/src/scalar/equal.rs b/src/scalar/equal.rs index 90f345268ca..e2cf20ee4f7 100644 --- a/src/scalar/equal.rs +++ b/src/scalar/equal.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use super::*; -use crate::types::days_ms; +use crate::datatypes::PhysicalType; impl PartialEq for dyn Scalar + '_ { fn eq(&self, that: &dyn Scalar) -> bool { @@ -23,14 +23,8 @@ impl PartialEq for Box { macro_rules! dyn_eq { ($ty:ty, $lhs:expr, $rhs:expr) => {{ - let lhs = $lhs - .as_any() - .downcast_ref::>() - .unwrap(); - let rhs = $rhs - .as_any() - .downcast_ref::>() - .unwrap(); + let lhs = $lhs.as_any().downcast_ref::<$ty>().unwrap(); + let rhs = $rhs.as_any().downcast_ref::<$ty>().unwrap(); lhs == rhs }}; } @@ -40,112 +34,26 @@ fn equal(lhs: &dyn Scalar, rhs: &dyn Scalar) -> bool { return false; } - match lhs.data_type() { - DataType::Null => { - let lhs = lhs.as_any().downcast_ref::().unwrap(); - let rhs = rhs.as_any().downcast_ref::().unwrap(); - lhs == rhs - } - DataType::Boolean => { - let lhs = lhs.as_any().downcast_ref::().unwrap(); - let rhs = rhs.as_any().downcast_ref::().unwrap(); - lhs == rhs - } - DataType::UInt8 => { - dyn_eq!(u8, lhs, rhs) - } - DataType::UInt16 => { - dyn_eq!(u16, lhs, rhs) - } - DataType::UInt32 => { - dyn_eq!(u32, lhs, rhs) - } - DataType::UInt64 => { - dyn_eq!(u64, lhs, rhs) - } - DataType::Int8 => { - dyn_eq!(i8, lhs, rhs) - } - DataType::Int16 => { - dyn_eq!(i16, lhs, rhs) - } - DataType::Int32 - | DataType::Date32 - | DataType::Time32(_) - | DataType::Interval(IntervalUnit::YearMonth) => { - dyn_eq!(i32, lhs, rhs) - } - DataType::Int64 - | DataType::Date64 - | DataType::Time64(_) - | DataType::Timestamp(_, _) - | DataType::Duration(_) => { - dyn_eq!(i64, lhs, rhs) - } - DataType::Decimal(_, _) => { - dyn_eq!(i128, lhs, rhs) - } - DataType::Interval(IntervalUnit::DayTime) => { - dyn_eq!(days_ms, lhs, rhs) - } - DataType::Float16 => unreachable!(), - DataType::Float32 => { - dyn_eq!(f32, lhs, rhs) - } - DataType::Float64 => { - dyn_eq!(f64, lhs, rhs) - } - DataType::Utf8 => { - let lhs = lhs.as_any().downcast_ref::>().unwrap(); - let rhs = rhs.as_any().downcast_ref::>().unwrap(); - lhs == rhs - } - DataType::LargeUtf8 => { - let lhs = lhs.as_any().downcast_ref::>().unwrap(); - let rhs = rhs.as_any().downcast_ref::>().unwrap(); - lhs == rhs - } - DataType::Binary => { - let lhs = lhs.as_any().downcast_ref::>().unwrap(); - let rhs = rhs.as_any().downcast_ref::>().unwrap(); - lhs == rhs - } - DataType::LargeBinary => { - let lhs = lhs.as_any().downcast_ref::>().unwrap(); - let rhs = rhs.as_any().downcast_ref::>().unwrap(); - lhs == rhs - } - DataType::List(_) => { - let lhs = lhs.as_any().downcast_ref::>().unwrap(); - let rhs = rhs.as_any().downcast_ref::>().unwrap(); - lhs == rhs - } - DataType::LargeList(_) => { - let lhs = lhs.as_any().downcast_ref::>().unwrap(); - let rhs = rhs.as_any().downcast_ref::>().unwrap(); - lhs == rhs - } - DataType::Dictionary(key_type, _, _) => match_integer_type!(key_type, |$T| { - let lhs = lhs.as_any().downcast_ref::>().unwrap(); - let rhs = rhs.as_any().downcast_ref::>().unwrap(); - lhs == rhs + use PhysicalType::*; + match lhs.data_type().to_physical_type() { + Null => dyn_eq!(NullScalar, lhs, rhs), + Boolean => dyn_eq!(BooleanScalar, lhs, rhs), + Primitive(primitive) => with_match_primitive_type!(primitive, |$T| { + dyn_eq!(PrimitiveScalar<$T>, lhs, rhs) }), - DataType::Struct(_) => { - let lhs = lhs.as_any().downcast_ref::().unwrap(); - let rhs = rhs.as_any().downcast_ref::().unwrap(); - lhs == rhs - } - DataType::FixedSizeBinary(_) => { - let lhs = lhs - .as_any() - .downcast_ref::() - .unwrap(); - let rhs = rhs - .as_any() - .downcast_ref::() - .unwrap(); - lhs == rhs - } - other => unimplemented!("{:?}", other), + Utf8 => dyn_eq!(Utf8Scalar, lhs, rhs), + LargeUtf8 => dyn_eq!(Utf8Scalar, lhs, rhs), + Binary => dyn_eq!(BinaryScalar, lhs, rhs), + LargeBinary => dyn_eq!(BinaryScalar, lhs, rhs), + List => dyn_eq!(ListScalar, lhs, rhs), + LargeList => dyn_eq!(ListScalar, lhs, rhs), + Dictionary(key_type) => match_integer_type!(key_type, |$T| { + dyn_eq!(DictionaryScalar<$T>, lhs, rhs) + }), + Struct => dyn_eq!(StructScalar, lhs, rhs), + FixedSizeBinary => dyn_eq!(FixedSizeBinaryScalar, lhs, rhs), + FixedSizeList => dyn_eq!(FixedSizeListScalar, lhs, rhs), + Union => unimplemented!("{:?}", Union), + Map => unimplemented!("{:?}", Map), } }