From 6fd34311986c22da8968f5ad76077a6911280735 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Sun, 10 Sep 2023 19:03:22 +0100 Subject: [PATCH] Simplify ScalarValue::distance (#7517) --- datafusion/common/src/scalar.rs | 44 +++++++++++---------------------- 1 file changed, 15 insertions(+), 29 deletions(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 88e0f3a301ef..68f89020ce2d 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -1134,31 +1134,22 @@ impl ScalarValue { /// /// Note: the datatype itself must support subtraction. pub fn distance(&self, other: &ScalarValue) -> Option { - // Having an explicit null check here is important because the - // subtraction for scalar values will return a real value even - // if one side is null. - if self.is_null() || other.is_null() { - return None; - } - - let distance = if self > other { - self.sub_checked(other).ok()? - } else { - other.sub_checked(self).ok()? - }; - - match distance { - ScalarValue::Int8(Some(v)) => usize::try_from(v).ok(), - ScalarValue::Int16(Some(v)) => usize::try_from(v).ok(), - ScalarValue::Int32(Some(v)) => usize::try_from(v).ok(), - ScalarValue::Int64(Some(v)) => usize::try_from(v).ok(), - ScalarValue::UInt8(Some(v)) => Some(v as usize), - ScalarValue::UInt16(Some(v)) => Some(v as usize), - ScalarValue::UInt32(Some(v)) => usize::try_from(v).ok(), - ScalarValue::UInt64(Some(v)) => usize::try_from(v).ok(), + match (self, other) { + (Self::Int8(Some(l)), Self::Int8(Some(r))) => Some(l.abs_diff(*r) as _), + (Self::Int16(Some(l)), Self::Int16(Some(r))) => Some(l.abs_diff(*r) as _), + (Self::Int32(Some(l)), Self::Int32(Some(r))) => Some(l.abs_diff(*r) as _), + (Self::Int64(Some(l)), Self::Int64(Some(r))) => Some(l.abs_diff(*r) as _), + (Self::UInt8(Some(l)), Self::UInt8(Some(r))) => Some(l.abs_diff(*r) as _), + (Self::UInt16(Some(l)), Self::UInt16(Some(r))) => Some(l.abs_diff(*r) as _), + (Self::UInt32(Some(l)), Self::UInt32(Some(r))) => Some(l.abs_diff(*r) as _), + (Self::UInt64(Some(l)), Self::UInt64(Some(r))) => Some(l.abs_diff(*r) as _), // TODO: we might want to look into supporting ceil/floor here for floats. - ScalarValue::Float32(Some(v)) => Some(v.round() as usize), - ScalarValue::Float64(Some(v)) => Some(v.round() as usize), + (Self::Float32(Some(l)), Self::Float32(Some(r))) => { + Some((l - r).abs().round() as _) + } + (Self::Float64(Some(l)), Self::Float64(Some(r))) => { + Some((l - r).abs().round() as _) + } _ => None, } } @@ -4726,11 +4717,6 @@ mod tests { ScalarValue::Decimal128(Some(123), 5, 5), ScalarValue::Decimal128(Some(120), 5, 5), ), - // Overflows - ( - ScalarValue::Int8(Some(i8::MAX)), - ScalarValue::Int8(Some(i8::MIN)), - ), ]; for (lhs, rhs) in cases { let distance = lhs.distance(&rhs);