Skip to content

Commit

Permalink
Simplify ScalarValue::distance (apache#7517)
Browse files Browse the repository at this point in the history
  • Loading branch information
tustvold committed Sep 10, 2023
1 parent 7b9bb08 commit 6fd3431
Showing 1 changed file with 15 additions and 29 deletions.
44 changes: 15 additions & 29 deletions datafusion/common/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1134,31 +1134,22 @@ impl ScalarValue {
///
/// Note: the datatype itself must support subtraction.
pub fn distance(&self, other: &ScalarValue) -> Option<usize> {
// Having an explicit null check here is important because the
// subtraction for scalar values will return a real value even
// if one side is null.
if self.is_null() || other.is_null() {
return None;
}

let distance = if self > other {
self.sub_checked(other).ok()?
} else {
other.sub_checked(self).ok()?
};

match distance {
ScalarValue::Int8(Some(v)) => usize::try_from(v).ok(),
ScalarValue::Int16(Some(v)) => usize::try_from(v).ok(),
ScalarValue::Int32(Some(v)) => usize::try_from(v).ok(),
ScalarValue::Int64(Some(v)) => usize::try_from(v).ok(),
ScalarValue::UInt8(Some(v)) => Some(v as usize),
ScalarValue::UInt16(Some(v)) => Some(v as usize),
ScalarValue::UInt32(Some(v)) => usize::try_from(v).ok(),
ScalarValue::UInt64(Some(v)) => usize::try_from(v).ok(),
match (self, other) {
(Self::Int8(Some(l)), Self::Int8(Some(r))) => Some(l.abs_diff(*r) as _),
(Self::Int16(Some(l)), Self::Int16(Some(r))) => Some(l.abs_diff(*r) as _),
(Self::Int32(Some(l)), Self::Int32(Some(r))) => Some(l.abs_diff(*r) as _),
(Self::Int64(Some(l)), Self::Int64(Some(r))) => Some(l.abs_diff(*r) as _),
(Self::UInt8(Some(l)), Self::UInt8(Some(r))) => Some(l.abs_diff(*r) as _),
(Self::UInt16(Some(l)), Self::UInt16(Some(r))) => Some(l.abs_diff(*r) as _),
(Self::UInt32(Some(l)), Self::UInt32(Some(r))) => Some(l.abs_diff(*r) as _),
(Self::UInt64(Some(l)), Self::UInt64(Some(r))) => Some(l.abs_diff(*r) as _),
// TODO: we might want to look into supporting ceil/floor here for floats.
ScalarValue::Float32(Some(v)) => Some(v.round() as usize),
ScalarValue::Float64(Some(v)) => Some(v.round() as usize),
(Self::Float32(Some(l)), Self::Float32(Some(r))) => {
Some((l - r).abs().round() as _)
}
(Self::Float64(Some(l)), Self::Float64(Some(r))) => {
Some((l - r).abs().round() as _)
}
_ => None,
}
}
Expand Down Expand Up @@ -4726,11 +4717,6 @@ mod tests {
ScalarValue::Decimal128(Some(123), 5, 5),
ScalarValue::Decimal128(Some(120), 5, 5),
),
// Overflows
(
ScalarValue::Int8(Some(i8::MAX)),
ScalarValue::Int8(Some(i8::MIN)),
),
];
for (lhs, rhs) in cases {
let distance = lhs.distance(&rhs);
Expand Down

0 comments on commit 6fd3431

Please sign in to comment.