diff --git a/cpp/src/arrow/compute/kernels/hash_aggregate.cc b/cpp/src/arrow/compute/kernels/hash_aggregate.cc index 6c0f1791f6ece..912301e036f8f 100644 --- a/cpp/src/arrow/compute/kernels/hash_aggregate.cc +++ b/cpp/src/arrow/compute/kernels/hash_aggregate.cc @@ -470,9 +470,10 @@ struct GrouperFastImpl : Grouper { impl->key_types_[i] = key; } - impl->group_map_.init(arrow::exec::util::CPUInstructionSet::avx2, ctx->memory_pool(), - static_cast(keys.size()), impl->is_fixedlen_, - impl->col_widths_.data()); + RETURN_NOT_OK(impl->group_map_.init(arrow::exec::util::CPUInstructionSet::avx2, + ctx->memory_pool(), + static_cast(keys.size()), + impl->is_fixedlen_, impl->col_widths_.data())); return std::move(impl); } @@ -500,10 +501,10 @@ struct GrouperFastImpl : Grouper { TypedBufferBuilder group_ids_batch(ctx_->memory_pool()); RETURN_NOT_OK(group_ids_batch.Resize(batch.length)); - group_map_.push_input(static_cast(num_rows), - non_null_buffers_maybe_null_.data(), fixedlen_buffers_.data(), - varlen_buffer_maybe_null_.data(), - reinterpret_cast(group_ids->mutable_data())); + RETURN_NOT_OK(group_map_.push_input( + static_cast(num_rows), non_null_buffers_maybe_null_.data(), + fixedlen_buffers_.data(), varlen_buffer_maybe_null_.data(), + reinterpret_cast(group_ids->mutable_data()))); return Datum(UInt32Array(batch.length, std::move(group_ids))); } @@ -515,9 +516,10 @@ struct GrouperFastImpl : Grouper { Result GetUniques() override { uint64_t num_groups; bool is_row_fixedlen; - uint32_t num_columns = static_cast(col_widths_.size()); - group_map_.pull_output_prepare(num_groups, is_row_fixedlen); + group_map_.pull_output_prepare(&num_groups, &is_row_fixedlen); ExecBatch out({}, num_groups); + + auto num_columns = static_cast(col_widths_.size()); out.values.resize(num_columns); std::vector> non_null_bufs; diff --git a/cpp/src/arrow/exec/groupby.cc b/cpp/src/arrow/exec/groupby.cc index 369197a41b51c..5877bb2c62e3b 100644 --- a/cpp/src/arrow/exec/groupby.cc +++ b/cpp/src/arrow/exec/groupby.cc @@ -82,12 +82,12 @@ Status GroupMap::append_callback(int num_keys, const uint16_t* selection) { minibatch_rows.data(), minibatch_nulls.data()); } -void GroupMap::init(util::CPUInstructionSet instruction_set_in, MemoryPool* pool, - uint32_t num_columns, const std::vector& is_fixed_len_in, - const uint32_t* col_widths_in) { +Status GroupMap::init(util::CPUInstructionSet instruction_set_in, MemoryPool* pool, + uint32_t num_columns, const std::vector& is_fixed_len_in, + const uint32_t* col_widths_in) { memory_pool = pool; minibatch_size = minibatch_size_min; - temp_buffers.init(memory_pool, minibatch_size_max); + RETURN_NOT_OK(temp_buffers.init(memory_pool, minibatch_size_max)); instruction_set = instruction_set_in; bool is_row_fixed_len = true; @@ -105,8 +105,8 @@ void GroupMap::init(util::CPUInstructionSet instruction_set_in, MemoryPool* pool } } - key_store.init(instruction_set, memory_pool, num_columns, is_row_fixed_len, - fixed_len_size); + RETURN_NOT_OK(key_store.init(instruction_set, memory_pool, num_columns, + is_row_fixed_len, fixed_len_size)); auto equal_func = [this](int num_keys_to_compare, const uint16_t* selection_may_be_null /* may be null */, const uint32_t* group_ids, uint8_t* match_bitvector) { @@ -117,8 +117,8 @@ void GroupMap::init(util::CPUInstructionSet instruction_set_in, MemoryPool* pool return this->append_callback(num_keys, selection); }; - group_id_map.init(instruction_set, memory_pool, &temp_buffers, log_minibatch_max, - equal_func, append_func); + RETURN_NOT_OK(group_id_map.init(instruction_set, memory_pool, &temp_buffers, + log_minibatch_max, equal_func, append_func)); col_non_nulls.resize(key_store.get_num_cols()); col_offsets.resize(key_store.get_num_cols()); @@ -148,12 +148,14 @@ void GroupMap::init(util::CPUInstructionSet instruction_set_in, MemoryPool* pool KeyStore::padding_for_SIMD); minibatch_offsets.resize((minibatch_size_max + 1) + KeyStore::padding_for_SIMD / sizeof(uint32_t)); + return Status::OK(); } -void GroupMap::push_input(uint32_t num_rows, const uint8_t** non_null_buffers_maybe_null, - const uint8_t** fixedlen_values_buffers, - const uint8_t** varlen_buffers_maybe_null, - uint32_t* group_ids) { +Status GroupMap::push_input(uint32_t num_rows, + const uint8_t** non_null_buffers_maybe_null, + const uint8_t** fixedlen_values_buffers, + const uint8_t** varlen_buffers_maybe_null, + uint32_t* group_ids) { uint32_t num_columns = key_store.get_num_cols(); bool fixed_len_row = true; @@ -280,18 +282,21 @@ void GroupMap::push_input(uint32_t num_rows, const uint8_t** non_null_buffers_ma } // Map - group_id_map.map(curr_minibatch_size, minibatch_hashes.data(), group_ids + irow0); + RETURN_NOT_OK(group_id_map.map(curr_minibatch_size, minibatch_hashes.data(), + group_ids + irow0)); irow0 += curr_minibatch_size; if (minibatch_size * 2 <= minibatch_size_max) { minibatch_size *= 2; } } + + return Status::OK(); } -void GroupMap::pull_output_prepare(uint64_t& out_num_rows, bool& out_is_row_fixedlen) { - out_num_rows = key_store.get_num_keys(); - out_is_row_fixedlen = key_store.is_row_fixedlen(); +void GroupMap::pull_output_prepare(uint64_t* out_num_rows, bool* out_is_row_fixedlen) { + *out_num_rows = key_store.get_num_keys(); + *out_is_row_fixedlen = key_store.is_row_fixedlen(); } void GroupMap::pull_output_fixedlen_and_nulls(uint8_t** non_null_buffers, @@ -332,11 +337,11 @@ void GroupMap::pull_output_fixedlen_and_nulls(uint8_t** non_null_buffers, arrow::exec::KeyTranspose::row2col( instruction_set, false, key_store.get_num_cols(), curr_minibatch_size, col_widths.data(), out_col_non_nulls.data(), out_col_offsets.data(), - out_col_values.data(), + out_col_values.data(), key_store.is_row_fixedlen() ? &row_length : row_offsets + irow0, - key_store.is_row_fixedlen() ? row_vals + row_length * irow0 : row_vals, - row_nulls + (irow0 << log_row_null_bits) / 8, - minibatch_size_max, minibatch_temp.data(), + key_store.is_row_fixedlen() ? row_vals + row_length * irow0 : row_vals, + row_nulls + (irow0 << log_row_null_bits) / 8, minibatch_size_max, + minibatch_temp.data(), minibatch_temp.data() + minibatch_size_max + KeyStore::padding_for_SIMD / sizeof(uint32_t)); @@ -350,7 +355,8 @@ void GroupMap::pull_output_fixedlen_and_nulls(uint8_t** non_null_buffers, if (is_col_fixed_len[icol]) { out_varlen_buffer_sizes[icol] = 0; } else { - out_varlen_buffer_sizes[icol] = reinterpret_cast(fixedlen_buffers[icol])[num_rows]; + out_varlen_buffer_sizes[icol] = + reinterpret_cast(fixedlen_buffers[icol])[num_rows]; } } } @@ -373,7 +379,8 @@ void GroupMap::pull_output_varlen(uint8_t** non_null_buffers, uint8_t** fixedlen for (uint32_t icol = 0; icol < key_store.get_num_cols(); ++icol) { out_col_non_nulls[icol] = non_null_buffers[icol] + irow0 / 8; out_col_offsets[icol] = reinterpret_cast( - is_col_fixed_len[icol] ? nullptr : fixedlen_buffers[icol] + sizeof(uint32_t) * irow0); + is_col_fixed_len[icol] ? nullptr + : fixedlen_buffers[icol] + sizeof(uint32_t) * irow0); out_col_values[icol] = is_col_fixed_len[icol] ? nullptr : varlen_buffers_maybe_null[icol]; } @@ -381,11 +388,11 @@ void GroupMap::pull_output_varlen(uint8_t** non_null_buffers, uint8_t** fixedlen arrow::exec::KeyTranspose::row2col( instruction_set, true, key_store.get_num_cols(), curr_minibatch_size, col_widths.data(), out_col_non_nulls.data(), out_col_offsets.data(), - out_col_values.data(), + out_col_values.data(), key_store.is_row_fixedlen() ? &row_length : row_offsets + irow0, - key_store.is_row_fixedlen() ? row_vals + row_length * irow0 : row_vals, - row_nulls + (irow0 << log_row_null_bits) / 8, - minibatch_size_max, minibatch_temp.data(), + key_store.is_row_fixedlen() ? row_vals + row_length * irow0 : row_vals, + row_nulls + (irow0 << log_row_null_bits) / 8, minibatch_size_max, + minibatch_temp.data(), minibatch_temp.data() + minibatch_size_max + KeyStore::padding_for_SIMD / sizeof(uint32_t)); diff --git a/cpp/src/arrow/exec/groupby.h b/cpp/src/arrow/exec/groupby.h index e1a93ab917623..8084ca0702c02 100644 --- a/cpp/src/arrow/exec/groupby.h +++ b/cpp/src/arrow/exec/groupby.h @@ -16,14 +16,14 @@ // under the License. #pragma once -#include "arrow/memory_pool.h" -#include "arrow/result.h" -#include "arrow/status.h" #include "arrow/exec/common.h" #include "arrow/exec/groupby_hash.h" #include "arrow/exec/groupby_map.h" #include "arrow/exec/groupby_storage.h" #include "arrow/exec/util.h" +#include "arrow/memory_pool.h" +#include "arrow/result.h" +#include "arrow/status.h" namespace arrow { namespace exec { @@ -32,15 +32,15 @@ class GroupMap { public: ~GroupMap(); - void init(util::CPUInstructionSet instruction_set_in, MemoryPool* pool, - uint32_t num_columns, const std::vector& is_fixed_len_in, - const uint32_t* col_widths_in); + Status init(util::CPUInstructionSet instruction_set_in, MemoryPool* pool, + uint32_t num_columns, const std::vector& is_fixed_len_in, + const uint32_t* col_widths_in); - void push_input(uint32_t num_rows, const uint8_t** non_null_buffers_maybe_null, - const uint8_t** fixedlen_buffers, - const uint8_t** varlen_buffers_maybe_null, uint32_t* group_ids); + Status push_input(uint32_t num_rows, const uint8_t** non_null_buffers_maybe_null, + const uint8_t** fixedlen_buffers, + const uint8_t** varlen_buffers_maybe_null, uint32_t* group_ids); - void pull_output_prepare(uint64_t& out_num_rows, bool& is_row_fixedlen); + void pull_output_prepare(uint64_t* out_num_rows, bool* is_row_fixedlen); void pull_output_fixedlen_and_nulls(uint8_t** non_null_buffers, uint8_t** fixedlen_buffers, diff --git a/cpp/src/arrow/exec/groupby_map.cc b/cpp/src/arrow/exec/groupby_map.cc index b6a5d6202ad1c..42f737b7ccdcf 100644 --- a/cpp/src/arrow/exec/groupby_map.cc +++ b/cpp/src/arrow/exec/groupby_map.cc @@ -24,60 +24,44 @@ #include #include "arrow/exec/common.h" +#include "arrow/util/bit_util.h" #include "arrow/util/bitmap_ops.h" namespace arrow { + +using BitUtil::CountLeadingZeros; + namespace exec { -// Scan bytes in block in reverse and stop as soon -// as a position of interest is found, which is either of: -// a) slot with a matching stamp is encountered, -// b) first empty slot is encountered, -// c) we reach the end of the block. -// Return an index corresponding to this position (8 represents end of block). -// Return also an integer flag (0 or 1) indicating if we reached case a) -// (if we found a stamp match). +constexpr uint64_t kHighBitOfEachByte = 0x8080808080808080ULL; + // template -inline void SwissTable::search_block( - uint64_t block, // 8B block of hash table - int stamp, // 7-bits of hash used as a stamp - int start_slot, // Index of the first slot in the block to start search from. - // We assume that this index always points to a non-empty slot - // (comes before any empty slots). - // Used only by one template variant. - int& out_slot, // Returned index of a slot - int& out_match_found) { // Returned integer flag indicating match found +inline void SwissTable::search_block(uint64_t block, int stamp, int start_slot, + int* out_slot, int* out_match_found) { // Filled slot bytes have the highest bit set to 0 and empty slots are equal to 0x80. // Replicate 7-bit stamp to all non-empty slots: - uint64_t block_high_bits = block & UINT64_C(0x8080808080808080); - uint64_t stamp_pattern = - stamp * ((block_high_bits ^ UINT64_C(0x8080808080808080)) >> 7); + uint64_t block_high_bits = block & kHighBitOfEachByte; + uint64_t stamp_pattern = stamp * ((block_high_bits ^ kHighBitOfEachByte) >> 7); // If we xor this pattern with block bytes we get: // a) 0x00, for filled slots matching the stamp, // b) 0x00 < x < 0x80, for filled slots not matching the stamp, // c) 0x80, for empty slots. // If we then add 0x7f to every byte, negate the result and leave only the highest bits // in each byte, we get 0x00 for non-match slot and 0x80 for match slot. - uint64_t matches = ~((block ^ stamp_pattern) + UINT64_C(0x7f7f7f7f7f7f7f7f)); + uint64_t matches = ~((block ^ stamp_pattern) + ~kHighBitOfEachByte); if (use_start_slot) { - matches &= UINT64_C(0x8080808080808080) >> (8 * start_slot); + matches &= kHighBitOfEachByte >> (8 * start_slot); } else { - matches &= UINT64_C(0x8080808080808080); + matches &= kHighBitOfEachByte; } // We get 0 if there are no matches - out_match_found = (matches == 0 ? 0 : 1); + *out_match_found = (matches == 0 ? 0 : 1); // Now if we or with the highest bits of the block and scan zero bits in reverse, // we get 8x slot index that we were looking for. - out_slot = static_cast(LZCNT64(matches | block_high_bits) >> 3); + *out_slot = static_cast(CountLeadingZeros(matches | block_high_bits) >> 3); } -// Extract group id for a given slot in a given block. -// Group ids follow in memory after 64-bit block data. -// Maximum number of groups inserted is equal to the number -// of all slots in all blocks, which is 8 * the number of blocks. -// Group ids are bit packed using that maximum to determine the necessary number of bits. -// inline uint64_t SwissTable::extract_group_id(const uint8_t* block_ptr, int slot, uint64_t group_id_mask) { int num_bits_group_id = log_blocks_ + 3; @@ -127,7 +111,7 @@ void SwissTable::lookup_1(const uint16_t* selection, const int num_keys, int match_found; int islot_in_block; - search_block(block, stamp, 0, islot_in_block, match_found); + search_block(block, stamp, 0, &islot_in_block, &match_found); uint64_t groupid = extract_group_id(blockbase, islot_in_block, groupid_mask); ARROW_DCHECK(groupid < num_inserted_ || num_inserted_ == 0); uint64_t islot = next_slot_to_visit(iblock, islot_in_block, match_found); @@ -224,9 +208,9 @@ Status SwissTable::lookup_2(const uint32_t* hashes, int& inout_num_selected, } else { int new_match_found; int new_slot; - search_block(block, static_cast(stamp), start_slot, new_slot, - new_match_found); - uint32_t new_groupid = + search_block(block, static_cast(stamp), start_slot, &new_slot, + &new_match_found); + auto new_groupid = static_cast(extract_group_id(blockbase, new_slot, groupid_mask)); ARROW_DCHECK(new_groupid < num_inserted_ + num_ids[category_inserted]); new_slot = @@ -393,12 +377,14 @@ Status SwissTable::grow_double() { uint8_t* block_base = blocks_ + i * block_size_before; uint8_t* double_block_base_new = blocks_new + 2 * i * block_size_after; uint64_t block = *reinterpret_cast(block_base); - int full_slots = static_cast(LZCNT64(block & 0x8080808080808080ULL) >> 3); + + auto full_slots = + static_cast(CountLeadingZeros(block & kHighBitOfEachByte) >> 3); int full_slots_new[2]; full_slots_new[0] = full_slots_new[1] = 0; - *reinterpret_cast(double_block_base_new) = 0x8080808080808080ULL; + *reinterpret_cast(double_block_base_new) = kHighBitOfEachByte; *reinterpret_cast(double_block_base_new + block_size_after) = - 0x8080808080808080ULL; + kHighBitOfEachByte; for (int j = 0; j < full_slots; ++j) { uint64_t slot_id = i * 8 + j; @@ -435,7 +421,7 @@ Status SwissTable::grow_double() { // How many full slots in this block uint8_t* block_base = blocks_ + i * block_size_before; uint64_t block = *reinterpret_cast(block_base); - int full_slots = static_cast(LZCNT64(block & 0x8080808080808080ULL) >> 3); + int full_slots = static_cast(CountLeadingZeros(block & kHighBitOfEachByte) >> 3); for (int j = 0; j < full_slots; ++j) { uint64_t slot_id = i * 8 + j; @@ -457,13 +443,13 @@ Status SwissTable::grow_double() { uint8_t* block_base_new = blocks_new + block_id_new * block_size_after; uint64_t block_new = *reinterpret_cast(block_base_new); int full_slots_new = - static_cast(LZCNT64(block_new & 0x8080808080808080ULL) >> 3); + static_cast(CountLeadingZeros(block_new & kHighBitOfEachByte) >> 3); while (full_slots_new == 8) { block_id_new = (block_id_new + 1) & ((1 << log_blocks_after) - 1); block_base_new = blocks_new + block_id_new * block_size_after; block_new = *reinterpret_cast(block_base_new); full_slots_new = - static_cast(LZCNT64(block_new & 0x8080808080808080ULL) >> 3); + static_cast(CountLeadingZeros(block_new & kHighBitOfEachByte) >> 3); } hashes_new[block_id_new * 8 + full_slots_new] = hash; @@ -505,7 +491,7 @@ Status SwissTable::init(util::CPUInstructionSet cpu_instruction_set, MemoryPool* for (uint64_t i = 0; i < (static_cast(1) << log_blocks_); ++i) { *reinterpret_cast(blocks_ + i * (8 + num_groupid_bits)) = - UINT64_C(0x8080808080808080); + kHighBitOfEachByte; } const uint64_t cbhashes = (sizeof(uint32_t) << num_groupid_bits) + padding_; @@ -568,7 +554,8 @@ void SwissTable::precomputed_make() { for (int i = 0; i < (1 << log_blocks_); ++i) { uint8_t* block_base = blocks_ + i * (8 + num_group_id_bits); uint64_t block = *reinterpret_cast(block_base); - int full_slots = static_cast(LZCNT64(block & 0x8080808080808080ULL) >> 3); + int full_slots = static_cast( + CountLeadingZeros(static_cast(block & kHighBitOfEachByte)) >> 3); for (int j = 0; j < full_slots; ++j) { uint64_t group_id_bit_offs = j * num_group_id_bits; uint64_t group_id = (*reinterpret_cast( diff --git a/cpp/src/arrow/exec/groupby_map.h b/cpp/src/arrow/exec/groupby_map.h index dde16777518e9..2d1b197b1fe1d 100644 --- a/cpp/src/arrow/exec/groupby_map.h +++ b/cpp/src/arrow/exec/groupby_map.h @@ -19,11 +19,11 @@ #include +#include "arrow/exec/common.h" +#include "arrow/exec/util.h" #include "arrow/memory_pool.h" #include "arrow/result.h" #include "arrow/status.h" -#include "arrow/exec/common.h" -#include "arrow/exec/util.h" class SwissTableSimple; @@ -45,13 +45,7 @@ class SwissTable { friend class ::SwissTableSimple; public: - SwissTable() - : log_blocks_(0), - num_inserted_(0), - blocks_(), - hashes_(), - pool_(), - temp_buffers_() {} + SwissTable() = default; ~SwissTable() { cleanup(); } using EqualImpl = @@ -68,12 +62,41 @@ class SwissTable { private: // Lookup helpers + + /// \brief Scan bytes in block in reverse and stop as soon + /// as a position of interest is found. + /// + /// Positions of interest: + /// a) slot with a matching stamp is encountered, + /// b) first empty slot is encountered, + /// c) we reach the end of the block. + /// + /// \param[in] block 8 byte block of hash table + /// \param[in] stamp 7 bits of hash used as a stamp + /// \param[in] start_slot Index of the first slot in the block to start search from. We + /// assume that this index always points to a non-empty slot, equivalently + /// that it comes before any empty slots. (Used only by one template + /// variant.) + /// \param[out] out_slot index corresponding to the discovered position of interest (8 + /// represents end of block). + /// \param[out] out_match_found an integer flag (0 or 1) indicating if we found a + /// matching stamp. template - inline void search_block(uint64_t block, int stamp, int start_slot, int& out_slot, - int& out_match_found); + inline void search_block(uint64_t block, int stamp, int start_slot, int* out_slot, + int* out_match_found); + + /// \brief Extract group id for a given slot in a given block. + /// + /// Group ids follow in memory after 64-bit block data. + /// Maximum number of groups inserted is equal to the number + /// of all slots in all blocks, which is 8 * the number of blocks. + /// Group ids are bit packed using that maximum to determine the necessary number of + /// bits. inline uint64_t extract_group_id(const uint8_t* block_ptr, int slot, uint64_t group_id_mask); + inline uint64_t next_slot_to_visit(uint64_t block_index, int slot, int match_found); + inline void insert(uint8_t* block_base, uint64_t slot_id, uint32_t hash, uint8_t stamp, uint32_t group_id); @@ -120,9 +143,9 @@ class SwissTable { int log_minibatch_; // Base 2 log of the number of blocks - int log_blocks_; + int log_blocks_ = 0; // Number of keys inserted into hash table - uint32_t num_inserted_; + uint32_t num_inserted_ = 0; // Data for blocks. // Each block has 8x of onse byte stamp slots, followed by 8x of bit packed group ids. diff --git a/cpp/src/arrow/exec/util.cc b/cpp/src/arrow/exec/util.cc index 2ffa0377f0048..72cef693c6085 100644 --- a/cpp/src/arrow/exec/util.cc +++ b/cpp/src/arrow/exec/util.cc @@ -21,13 +21,16 @@ #include "arrow/util/bitmap_ops.h" namespace arrow { + +using BitUtil::CountTrailingZeros; + namespace exec { namespace util { inline void BitUtil::bits_to_indexes_helper(uint64_t word, uint16_t base_index, int& num_indexes, uint16_t* indexes) { while (word) { - indexes[num_indexes++] = base_index + static_cast(TZCNT64(word)); + indexes[num_indexes++] = base_index + static_cast(CountTrailingZeros(word)); word &= word - 1; } } @@ -36,7 +39,7 @@ inline void BitUtil::bits_filter_indexes_helper(uint64_t word, const uint16_t* input_indexes, int& num_indexes, uint16_t* indexes) { while (word) { - indexes[num_indexes++] = input_indexes[TZCNT64(word)]; + indexes[num_indexes++] = input_indexes[CountTrailingZeros(word)]; word &= word - 1; } } diff --git a/cpp/src/arrow/exec/util.h b/cpp/src/arrow/exec/util.h index 4f69f0f6a0a1d..c26bfe8dfb7ff 100644 --- a/cpp/src/arrow/exec/util.h +++ b/cpp/src/arrow/exec/util.h @@ -27,14 +27,10 @@ #include "arrow/util/logging.h" #if defined(__clang__) || defined(__GNUC__) -#define LZCNT64(x) __builtin_clzll(x) -#define TZCNT64(x) __builtin_ctzll(x) #define BYTESWAP(x) __builtin_bswap64(x) #define ROTL(x, n) (((x) << (n)) | ((x) >> (32 - (n)))) #elif defined(_MSC_VER) #include -#define LZCNT64(x) __lzcnt64(x) -#define TZCNT64(x) _tzcnt_u64(x) #define BYTESWAP(x) _byteswap_uint64(x) #define ROTL(x, n) _rotl((x), (n)) #endif