Skip to content
This repository has been archived by the owner on Feb 18, 2024. It is now read-only.

Commit

Permalink
Simpler code (#901)
Browse files Browse the repository at this point in the history
  • Loading branch information
jorgecarleitao authored Mar 9, 2022
1 parent f59a68d commit 31c8ec6
Showing 1 changed file with 23 additions and 115 deletions.
138 changes: 23 additions & 115 deletions src/scalar/equal.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::sync::Arc;

use super::*;
use crate::types::days_ms;
use crate::datatypes::PhysicalType;

impl PartialEq for dyn Scalar + '_ {
fn eq(&self, that: &dyn Scalar) -> bool {
Expand All @@ -23,14 +23,8 @@ impl PartialEq<dyn Scalar> for Box<dyn Scalar + '_> {

macro_rules! dyn_eq {
($ty:ty, $lhs:expr, $rhs:expr) => {{
let lhs = $lhs
.as_any()
.downcast_ref::<PrimitiveScalar<$ty>>()
.unwrap();
let rhs = $rhs
.as_any()
.downcast_ref::<PrimitiveScalar<$ty>>()
.unwrap();
let lhs = $lhs.as_any().downcast_ref::<$ty>().unwrap();
let rhs = $rhs.as_any().downcast_ref::<$ty>().unwrap();
lhs == rhs
}};
}
Expand All @@ -40,112 +34,26 @@ fn equal(lhs: &dyn Scalar, rhs: &dyn Scalar) -> bool {
return false;
}

match lhs.data_type() {
DataType::Null => {
let lhs = lhs.as_any().downcast_ref::<NullScalar>().unwrap();
let rhs = rhs.as_any().downcast_ref::<NullScalar>().unwrap();
lhs == rhs
}
DataType::Boolean => {
let lhs = lhs.as_any().downcast_ref::<BooleanScalar>().unwrap();
let rhs = rhs.as_any().downcast_ref::<BooleanScalar>().unwrap();
lhs == rhs
}
DataType::UInt8 => {
dyn_eq!(u8, lhs, rhs)
}
DataType::UInt16 => {
dyn_eq!(u16, lhs, rhs)
}
DataType::UInt32 => {
dyn_eq!(u32, lhs, rhs)
}
DataType::UInt64 => {
dyn_eq!(u64, lhs, rhs)
}
DataType::Int8 => {
dyn_eq!(i8, lhs, rhs)
}
DataType::Int16 => {
dyn_eq!(i16, lhs, rhs)
}
DataType::Int32
| DataType::Date32
| DataType::Time32(_)
| DataType::Interval(IntervalUnit::YearMonth) => {
dyn_eq!(i32, lhs, rhs)
}
DataType::Int64
| DataType::Date64
| DataType::Time64(_)
| DataType::Timestamp(_, _)
| DataType::Duration(_) => {
dyn_eq!(i64, lhs, rhs)
}
DataType::Decimal(_, _) => {
dyn_eq!(i128, lhs, rhs)
}
DataType::Interval(IntervalUnit::DayTime) => {
dyn_eq!(days_ms, lhs, rhs)
}
DataType::Float16 => unreachable!(),
DataType::Float32 => {
dyn_eq!(f32, lhs, rhs)
}
DataType::Float64 => {
dyn_eq!(f64, lhs, rhs)
}
DataType::Utf8 => {
let lhs = lhs.as_any().downcast_ref::<Utf8Scalar<i32>>().unwrap();
let rhs = rhs.as_any().downcast_ref::<Utf8Scalar<i32>>().unwrap();
lhs == rhs
}
DataType::LargeUtf8 => {
let lhs = lhs.as_any().downcast_ref::<Utf8Scalar<i64>>().unwrap();
let rhs = rhs.as_any().downcast_ref::<Utf8Scalar<i64>>().unwrap();
lhs == rhs
}
DataType::Binary => {
let lhs = lhs.as_any().downcast_ref::<BinaryScalar<i32>>().unwrap();
let rhs = rhs.as_any().downcast_ref::<BinaryScalar<i32>>().unwrap();
lhs == rhs
}
DataType::LargeBinary => {
let lhs = lhs.as_any().downcast_ref::<BinaryScalar<i64>>().unwrap();
let rhs = rhs.as_any().downcast_ref::<BinaryScalar<i64>>().unwrap();
lhs == rhs
}
DataType::List(_) => {
let lhs = lhs.as_any().downcast_ref::<ListScalar<i32>>().unwrap();
let rhs = rhs.as_any().downcast_ref::<ListScalar<i32>>().unwrap();
lhs == rhs
}
DataType::LargeList(_) => {
let lhs = lhs.as_any().downcast_ref::<ListScalar<i64>>().unwrap();
let rhs = rhs.as_any().downcast_ref::<ListScalar<i64>>().unwrap();
lhs == rhs
}
DataType::Dictionary(key_type, _, _) => match_integer_type!(key_type, |$T| {
let lhs = lhs.as_any().downcast_ref::<DictionaryScalar<$T>>().unwrap();
let rhs = rhs.as_any().downcast_ref::<DictionaryScalar<$T>>().unwrap();
lhs == rhs
use PhysicalType::*;
match lhs.data_type().to_physical_type() {
Null => dyn_eq!(NullScalar, lhs, rhs),
Boolean => dyn_eq!(BooleanScalar, lhs, rhs),
Primitive(primitive) => with_match_primitive_type!(primitive, |$T| {
dyn_eq!(PrimitiveScalar<$T>, lhs, rhs)
}),
DataType::Struct(_) => {
let lhs = lhs.as_any().downcast_ref::<StructScalar>().unwrap();
let rhs = rhs.as_any().downcast_ref::<StructScalar>().unwrap();
lhs == rhs
}
DataType::FixedSizeBinary(_) => {
let lhs = lhs
.as_any()
.downcast_ref::<FixedSizeBinaryScalar>()
.unwrap();
let rhs = rhs
.as_any()
.downcast_ref::<FixedSizeBinaryScalar>()
.unwrap();
lhs == rhs
}
other => unimplemented!("{:?}", other),
Utf8 => dyn_eq!(Utf8Scalar<i32>, lhs, rhs),
LargeUtf8 => dyn_eq!(Utf8Scalar<i64>, lhs, rhs),
Binary => dyn_eq!(BinaryScalar<i32>, lhs, rhs),
LargeBinary => dyn_eq!(BinaryScalar<i64>, lhs, rhs),
List => dyn_eq!(ListScalar<i32>, lhs, rhs),
LargeList => dyn_eq!(ListScalar<i64>, lhs, rhs),
Dictionary(key_type) => match_integer_type!(key_type, |$T| {
dyn_eq!(DictionaryScalar<$T>, lhs, rhs)
}),
Struct => dyn_eq!(StructScalar, lhs, rhs),
FixedSizeBinary => dyn_eq!(FixedSizeBinaryScalar, lhs, rhs),
FixedSizeList => dyn_eq!(FixedSizeListScalar, lhs, rhs),
Union => unimplemented!("{:?}", Union),
Map => unimplemented!("{:?}", Map),
}
}

0 comments on commit 31c8ec6

Please sign in to comment.