Skip to content

Commit b53d7cd

Browse files
committed
Share comparison code
1 parent 162eeeb commit b53d7cd

File tree

1 file changed

+33
-85
lines changed

1 file changed

+33
-85
lines changed

cpp/src/arrow/compare.cc

Lines changed: 33 additions & 85 deletions
Original file line number | Diff line number | Diff line change
@@ -257,37 +257,7 @@ class RangeEqualsVisitor {
257257
}
258258

259259
Status Visit(const DecimalArray& left) {
260-
const auto& right = static_cast<const DecimalArray&>(right_);
261-
262-
int32_t width = left.byte_width();
263-
264-
const uint8_t* left_data = nullptr;
265-
const uint8_t* right_data = nullptr;
266-
267-
if (left.values()) {
268-
left_data = left.raw_values();
269-
}
270-
271-
if (right.values()) {
272-
right_data = right.raw_values();
273-
}
274-
275-
for (int64_t i = left_start_idx_, o_i = right_start_idx_; i < left_end_idx_;
276-
++i, ++o_i) {
277-
const bool is_null = left.IsNull(i);
278-
if (is_null != right.IsNull(o_i)) {
279-
result_ = false;
280-
return Status::OK();
281-
}
282-
if (is_null) continue;
283-
284-
if (std::memcmp(left_data + width * i, right_data + width * o_i, width)) {
285-
result_ = false;
286-
return Status::OK();
287-
}
288-
}
289-
result_ = true;
290-
return Status::OK();
260+
return Visit(static_cast<const FixedSizeBinaryArray&>(left));
291261
}
292262

293263
Status Visit(const NullArray& left) {
@@ -341,7 +311,7 @@ class RangeEqualsVisitor {
341311

342312
static bool IsEqualPrimitive(const PrimitiveArray& left, const PrimitiveArray& right) {
343313
const auto& size_meta = dynamic_cast<const FixedWidthType&>(*left.type());
344-
const int byte_width = size_meta.bit_width() / 8;
314+
const int byte_width = size_meta.bit_width() / CHAR_BIT;
345315

346316
const uint8_t* left_data = nullptr;
347317
const uint8_t* right_data = nullptr;
@@ -372,6 +342,14 @@ static bool IsEqualPrimitive(const PrimitiveArray& left, const PrimitiveArray& r
372342
template <typename T>
373343
static inline bool CompareBuiltIn(const Array& left, const Array& right, const T* ldata,
374344
const T* rdata) {
345+
if (ldata == nullptr && rdata == nullptr) {
346+
return true;
347+
}
348+
349+
if (ldata == nullptr || rdata == nullptr) {
350+
return false;
351+
}
352+
375353
if (left.null_count() > 0) {
376354
for (int64_t i = 0; i < left.length(); ++i) {
377355
if (left.IsNull(i) != right.IsNull(i)) {
@@ -381,55 +359,9 @@ static inline bool CompareBuiltIn(const Array& left, const Array& right, const T
381359
}
382360
}
383361
return true;
384-
} else {
385-
return memcmp(ldata, rdata, sizeof(T) * left.length()) == 0;
386-
}
387-
}
388-
389-
static bool IsEqualDecimal(const DecimalArray& left, const DecimalArray& right) {
390-
const uint8_t* left_data = nullptr;
391-
const uint8_t* right_data = nullptr;
392-
393-
if (left.values() != nullptr) {
394-
left_data = left.raw_values();
395-
}
396-
397-
if (right.values() != nullptr) {
398-
right_data = right.raw_values();
399362
}
400363

401-
const int32_t byte_width = left.byte_width();
402-
if (byte_width == 4) {
403-
return CompareBuiltIn<int32_t>(left, right,
404-
reinterpret_cast<const int32_t*>(left_data),
405-
reinterpret_cast<const int32_t*>(right_data));
406-
}
407-
408-
if (byte_width == 8) {
409-
return CompareBuiltIn<int64_t>(left, right,
410-
reinterpret_cast<const int64_t*>(left_data),
411-
reinterpret_cast<const int64_t*>(right_data));
412-
}
413-
414-
// 128-bit
415-
for (int64_t i = 0; i < left.length(); ++i) {
416-
const bool left_null = left.IsNull(i);
417-
const bool right_null = right.IsNull(i);
418-
419-
// one of left or right value is not null
420-
if (left_null != right_null) {
421-
return false;
422-
}
423-
424-
// both are not null and their respective elements are byte for byte equal
425-
if (!left_null && !right_null && memcmp(left_data, right_data, byte_width) != 0) {
426-
return false;
427-
}
428-
429-
left_data += byte_width;
430-
right_data += byte_width;
431-
}
432-
return true;
364+
return memcmp(ldata, rdata, sizeof(T) * left.length()) == 0;
433365
}
434366

435367
class ArrayEqualsVisitor : public RangeEqualsVisitor {
@@ -474,11 +406,6 @@ class ArrayEqualsVisitor : public RangeEqualsVisitor {
474406
return Status::OK();
475407
}
476408

477-
Status Visit(const DecimalArray& left) {
478-
result_ = IsEqualDecimal(left, static_cast<const DecimalArray&>(right_));
479-
return Status::OK();
480-
}
481-
482409
template <typename ArrayType>
483410
bool ValueOffsetsEqual(const ArrayType& left) {
484411
const auto& right = static_cast<const ArrayType&>(right_);
@@ -580,6 +507,27 @@ class ArrayEqualsVisitor : public RangeEqualsVisitor {
580507
return Status::OK();
581508
}
582509

510+
Status Visit(const DecimalArray& left) {
511+
const int byte_width = left.byte_width();
512+
if (byte_width == 4) {
513+
result_ = CompareBuiltIn<int32_t>(
514+
left, right_, reinterpret_cast<const int32_t*>(left.raw_values()),
515+
reinterpret_cast<const int32_t*>(
516+
static_cast<const DecimalArray&>(right_).raw_values()));
517+
return Status::OK();
518+
}
519+
520+
if (byte_width == 8) {
521+
result_ = CompareBuiltIn<int64_t>(
522+
left, right_, reinterpret_cast<const int64_t*>(left.raw_values()),
523+
reinterpret_cast<const int64_t*>(
524+
static_cast<const DecimalArray&>(right_).raw_values()));
525+
return Status::OK();
526+
}
527+
528+
return RangeEqualsVisitor::Visit(left);
529+
}
530+
583531
template <typename T>
584532
typename std::enable_if<std::is_base_of<NestedType, typename T::TypeClass>::value,
585533
Status>::type
@@ -812,7 +760,7 @@ Status TensorEquals(const Tensor& left, const Tensor& right, bool* are_equal) {
812760
}
813761

814762
const auto& size_meta = dynamic_cast<const FixedWidthType&>(*left.type());
815-
const int byte_width = size_meta.bit_width() / 8;
763+
const int byte_width = size_meta.bit_width() / CHAR_BIT;
816764
DCHECK_GT(byte_width, 0);
817765

818766
const uint8_t* left_data = left.data()->data();

0 commit comments

Comments
 (0)