@@ -257,37 +257,7 @@ class RangeEqualsVisitor {
257257 }
258258
259259 Status Visit (const DecimalArray& left) {
260- const auto & right = static_cast <const DecimalArray&>(right_);
261-
262- int32_t width = left.byte_width ();
263-
264- const uint8_t * left_data = nullptr ;
265- const uint8_t * right_data = nullptr ;
266-
267- if (left.values ()) {
268- left_data = left.raw_values ();
269- }
270-
271- if (right.values ()) {
272- right_data = right.raw_values ();
273- }
274-
275- for (int64_t i = left_start_idx_, o_i = right_start_idx_; i < left_end_idx_;
276- ++i, ++o_i) {
277- const bool is_null = left.IsNull (i);
278- if (is_null != right.IsNull (o_i)) {
279- result_ = false ;
280- return Status::OK ();
281- }
282- if (is_null) continue ;
283-
284- if (std::memcmp (left_data + width * i, right_data + width * o_i, width)) {
285- result_ = false ;
286- return Status::OK ();
287- }
288- }
289- result_ = true ;
290- return Status::OK ();
260+ return Visit (static_cast <const FixedSizeBinaryArray&>(left));
291261 }
292262
293263 Status Visit (const NullArray& left) {
@@ -341,7 +311,7 @@ class RangeEqualsVisitor {
341311
342312static bool IsEqualPrimitive (const PrimitiveArray& left, const PrimitiveArray& right) {
343313 const auto & size_meta = dynamic_cast <const FixedWidthType&>(*left.type ());
344- const int byte_width = size_meta.bit_width () / 8 ;
314+ const int byte_width = size_meta.bit_width () / CHAR_BIT ;
345315
346316 const uint8_t * left_data = nullptr ;
347317 const uint8_t * right_data = nullptr ;
@@ -372,6 +342,14 @@ static bool IsEqualPrimitive(const PrimitiveArray& left, const PrimitiveArray& r
372342template <typename T>
373343static inline bool CompareBuiltIn (const Array& left, const Array& right, const T* ldata,
374344 const T* rdata) {
345+ if (ldata == nullptr && rdata == nullptr ) {
346+ return true ;
347+ }
348+
349+ if (ldata == nullptr || rdata == nullptr ) {
350+ return false ;
351+ }
352+
375353 if (left.null_count () > 0 ) {
376354 for (int64_t i = 0 ; i < left.length (); ++i) {
377355 if (left.IsNull (i) != right.IsNull (i)) {
@@ -381,55 +359,9 @@ static inline bool CompareBuiltIn(const Array& left, const Array& right, const T
381359 }
382360 }
383361 return true ;
384- } else {
385- return memcmp (ldata, rdata, sizeof (T) * left.length ()) == 0 ;
386- }
387- }
388-
389- static bool IsEqualDecimal (const DecimalArray& left, const DecimalArray& right) {
390- const uint8_t * left_data = nullptr ;
391- const uint8_t * right_data = nullptr ;
392-
393- if (left.values () != nullptr ) {
394- left_data = left.raw_values ();
395- }
396-
397- if (right.values () != nullptr ) {
398- right_data = right.raw_values ();
399362 }
400363
401- const int32_t byte_width = left.byte_width ();
402- if (byte_width == 4 ) {
403- return CompareBuiltIn<int32_t >(left, right,
404- reinterpret_cast <const int32_t *>(left_data),
405- reinterpret_cast <const int32_t *>(right_data));
406- }
407-
408- if (byte_width == 8 ) {
409- return CompareBuiltIn<int64_t >(left, right,
410- reinterpret_cast <const int64_t *>(left_data),
411- reinterpret_cast <const int64_t *>(right_data));
412- }
413-
414- // 128-bit
415- for (int64_t i = 0 ; i < left.length (); ++i) {
416- const bool left_null = left.IsNull (i);
417- const bool right_null = right.IsNull (i);
418-
419- // one of left or right value is not null
420- if (left_null != right_null) {
421- return false ;
422- }
423-
424- // both are not null and their respective elements are byte for byte equal
425- if (!left_null && !right_null && memcmp (left_data, right_data, byte_width) != 0 ) {
426- return false ;
427- }
428-
429- left_data += byte_width;
430- right_data += byte_width;
431- }
432- return true ;
364+ return memcmp (ldata, rdata, sizeof (T) * left.length ()) == 0 ;
433365}
434366
435367class ArrayEqualsVisitor : public RangeEqualsVisitor {
@@ -474,11 +406,6 @@ class ArrayEqualsVisitor : public RangeEqualsVisitor {
474406 return Status::OK ();
475407 }
476408
477- Status Visit (const DecimalArray& left) {
478- result_ = IsEqualDecimal (left, static_cast <const DecimalArray&>(right_));
479- return Status::OK ();
480- }
481-
482409 template <typename ArrayType>
483410 bool ValueOffsetsEqual (const ArrayType& left) {
484411 const auto & right = static_cast <const ArrayType&>(right_);
@@ -580,6 +507,27 @@ class ArrayEqualsVisitor : public RangeEqualsVisitor {
580507 return Status::OK ();
581508 }
582509
510+ Status Visit (const DecimalArray& left) {
511+ const int byte_width = left.byte_width ();
512+ if (byte_width == 4 ) {
513+ result_ = CompareBuiltIn<int32_t >(
514+ left, right_, reinterpret_cast <const int32_t *>(left.raw_values ()),
515+ reinterpret_cast <const int32_t *>(
516+ static_cast <const DecimalArray&>(right_).raw_values ()));
517+ return Status::OK ();
518+ }
519+
520+ if (byte_width == 8 ) {
521+ result_ = CompareBuiltIn<int64_t >(
522+ left, right_, reinterpret_cast <const int64_t *>(left.raw_values ()),
523+ reinterpret_cast <const int64_t *>(
524+ static_cast <const DecimalArray&>(right_).raw_values ()));
525+ return Status::OK ();
526+ }
527+
528+ return RangeEqualsVisitor::Visit (left);
529+ }
530+
583531 template <typename T>
584532 typename std::enable_if<std::is_base_of<NestedType, typename T::TypeClass>::value,
585533 Status>::type
@@ -812,7 +760,7 @@ Status TensorEquals(const Tensor& left, const Tensor& right, bool* are_equal) {
812760 }
813761
814762 const auto & size_meta = dynamic_cast <const FixedWidthType&>(*left.type ());
815- const int byte_width = size_meta.bit_width () / 8 ;
763+ const int byte_width = size_meta.bit_width () / CHAR_BIT ;
816764 DCHECK_GT (byte_width, 0 );
817765
818766 const uint8_t * left_data = left.data ()->data ();
0 commit comments