Skip to content

Commit

Permalink
Merge pull request #3 from bkietz/ARROW-12010-GroupIdentifier
Browse files Browse the repository at this point in the history
{trail,lead}ing zero count utils, moving doccomments, silencing warnings
  • Loading branch information
michalursa authored Mar 24, 2021
2 parents 4c0f7f3 + 4d8d98a commit 13ac291
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 107 deletions.
20 changes: 11 additions & 9 deletions cpp/src/arrow/compute/kernels/hash_aggregate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -470,9 +470,10 @@ struct GrouperFastImpl : Grouper {
impl->key_types_[i] = key;
}

impl->group_map_.init(arrow::exec::util::CPUInstructionSet::avx2, ctx->memory_pool(),
static_cast<uint32_t>(keys.size()), impl->is_fixedlen_,
impl->col_widths_.data());
RETURN_NOT_OK(impl->group_map_.init(arrow::exec::util::CPUInstructionSet::avx2,
ctx->memory_pool(),
static_cast<uint32_t>(keys.size()),
impl->is_fixedlen_, impl->col_widths_.data()));

return std::move(impl);
}
Expand Down Expand Up @@ -500,10 +501,10 @@ struct GrouperFastImpl : Grouper {
TypedBufferBuilder<uint32_t> group_ids_batch(ctx_->memory_pool());
RETURN_NOT_OK(group_ids_batch.Resize(batch.length));

group_map_.push_input(static_cast<uint32_t>(num_rows),
non_null_buffers_maybe_null_.data(), fixedlen_buffers_.data(),
varlen_buffer_maybe_null_.data(),
reinterpret_cast<uint32_t*>(group_ids->mutable_data()));
RETURN_NOT_OK(group_map_.push_input(
static_cast<uint32_t>(num_rows), non_null_buffers_maybe_null_.data(),
fixedlen_buffers_.data(), varlen_buffer_maybe_null_.data(),
reinterpret_cast<uint32_t*>(group_ids->mutable_data())));

return Datum(UInt32Array(batch.length, std::move(group_ids)));
}
Expand All @@ -515,9 +516,10 @@ struct GrouperFastImpl : Grouper {
Result<ExecBatch> GetUniques() override {
uint64_t num_groups;
bool is_row_fixedlen;
uint32_t num_columns = static_cast<uint32_t>(col_widths_.size());
group_map_.pull_output_prepare(num_groups, is_row_fixedlen);
group_map_.pull_output_prepare(&num_groups, &is_row_fixedlen);
ExecBatch out({}, num_groups);

auto num_columns = static_cast<uint32_t>(col_widths_.size());
out.values.resize(num_columns);

std::vector<std::shared_ptr<Buffer>> non_null_bufs;
Expand Down
59 changes: 33 additions & 26 deletions cpp/src/arrow/exec/groupby.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,12 @@ Status GroupMap::append_callback(int num_keys, const uint16_t* selection) {
minibatch_rows.data(), minibatch_nulls.data());
}

void GroupMap::init(util::CPUInstructionSet instruction_set_in, MemoryPool* pool,
uint32_t num_columns, const std::vector<bool>& is_fixed_len_in,
const uint32_t* col_widths_in) {
Status GroupMap::init(util::CPUInstructionSet instruction_set_in, MemoryPool* pool,
uint32_t num_columns, const std::vector<bool>& is_fixed_len_in,
const uint32_t* col_widths_in) {
memory_pool = pool;
minibatch_size = minibatch_size_min;
temp_buffers.init(memory_pool, minibatch_size_max);
RETURN_NOT_OK(temp_buffers.init(memory_pool, minibatch_size_max));
instruction_set = instruction_set_in;

bool is_row_fixed_len = true;
Expand All @@ -105,8 +105,8 @@ void GroupMap::init(util::CPUInstructionSet instruction_set_in, MemoryPool* pool
}
}

key_store.init(instruction_set, memory_pool, num_columns, is_row_fixed_len,
fixed_len_size);
RETURN_NOT_OK(key_store.init(instruction_set, memory_pool, num_columns,
is_row_fixed_len, fixed_len_size));
auto equal_func = [this](int num_keys_to_compare,
const uint16_t* selection_may_be_null /* may be null */,
const uint32_t* group_ids, uint8_t* match_bitvector) {
Expand All @@ -117,8 +117,8 @@ void GroupMap::init(util::CPUInstructionSet instruction_set_in, MemoryPool* pool
return this->append_callback(num_keys, selection);
};

group_id_map.init(instruction_set, memory_pool, &temp_buffers, log_minibatch_max,
equal_func, append_func);
RETURN_NOT_OK(group_id_map.init(instruction_set, memory_pool, &temp_buffers,
log_minibatch_max, equal_func, append_func));

col_non_nulls.resize(key_store.get_num_cols());
col_offsets.resize(key_store.get_num_cols());
Expand Down Expand Up @@ -148,12 +148,14 @@ void GroupMap::init(util::CPUInstructionSet instruction_set_in, MemoryPool* pool
KeyStore::padding_for_SIMD);
minibatch_offsets.resize((minibatch_size_max + 1) +
KeyStore::padding_for_SIMD / sizeof(uint32_t));
return Status::OK();
}

void GroupMap::push_input(uint32_t num_rows, const uint8_t** non_null_buffers_maybe_null,
const uint8_t** fixedlen_values_buffers,
const uint8_t** varlen_buffers_maybe_null,
uint32_t* group_ids) {
Status GroupMap::push_input(uint32_t num_rows,
const uint8_t** non_null_buffers_maybe_null,
const uint8_t** fixedlen_values_buffers,
const uint8_t** varlen_buffers_maybe_null,
uint32_t* group_ids) {
uint32_t num_columns = key_store.get_num_cols();

bool fixed_len_row = true;
Expand Down Expand Up @@ -280,18 +282,21 @@ void GroupMap::push_input(uint32_t num_rows, const uint8_t** non_null_buffers_ma
}

// Map
group_id_map.map(curr_minibatch_size, minibatch_hashes.data(), group_ids + irow0);
RETURN_NOT_OK(group_id_map.map(curr_minibatch_size, minibatch_hashes.data(),
group_ids + irow0));

irow0 += curr_minibatch_size;
if (minibatch_size * 2 <= minibatch_size_max) {
minibatch_size *= 2;
}
}

return Status::OK();
}

void GroupMap::pull_output_prepare(uint64_t& out_num_rows, bool& out_is_row_fixedlen) {
out_num_rows = key_store.get_num_keys();
out_is_row_fixedlen = key_store.is_row_fixedlen();
void GroupMap::pull_output_prepare(uint64_t* out_num_rows, bool* out_is_row_fixedlen) {
*out_num_rows = key_store.get_num_keys();
*out_is_row_fixedlen = key_store.is_row_fixedlen();
}

void GroupMap::pull_output_fixedlen_and_nulls(uint8_t** non_null_buffers,
Expand Down Expand Up @@ -332,11 +337,11 @@ void GroupMap::pull_output_fixedlen_and_nulls(uint8_t** non_null_buffers,
arrow::exec::KeyTranspose::row2col(
instruction_set, false, key_store.get_num_cols(), curr_minibatch_size,
col_widths.data(), out_col_non_nulls.data(), out_col_offsets.data(),
out_col_values.data(),
out_col_values.data(),
key_store.is_row_fixedlen() ? &row_length : row_offsets + irow0,
key_store.is_row_fixedlen() ? row_vals + row_length * irow0 : row_vals,
row_nulls + (irow0 << log_row_null_bits) / 8,
minibatch_size_max, minibatch_temp.data(),
key_store.is_row_fixedlen() ? row_vals + row_length * irow0 : row_vals,
row_nulls + (irow0 << log_row_null_bits) / 8, minibatch_size_max,
minibatch_temp.data(),
minibatch_temp.data() + minibatch_size_max +
KeyStore::padding_for_SIMD / sizeof(uint32_t));

Expand All @@ -350,7 +355,8 @@ void GroupMap::pull_output_fixedlen_and_nulls(uint8_t** non_null_buffers,
if (is_col_fixed_len[icol]) {
out_varlen_buffer_sizes[icol] = 0;
} else {
out_varlen_buffer_sizes[icol] = reinterpret_cast<const uint32_t*>(fixedlen_buffers[icol])[num_rows];
out_varlen_buffer_sizes[icol] =
reinterpret_cast<const uint32_t*>(fixedlen_buffers[icol])[num_rows];
}
}
}
Expand All @@ -373,19 +379,20 @@ void GroupMap::pull_output_varlen(uint8_t** non_null_buffers, uint8_t** fixedlen
for (uint32_t icol = 0; icol < key_store.get_num_cols(); ++icol) {
out_col_non_nulls[icol] = non_null_buffers[icol] + irow0 / 8;
out_col_offsets[icol] = reinterpret_cast<uint32_t*>(
is_col_fixed_len[icol] ? nullptr : fixedlen_buffers[icol] + sizeof(uint32_t) * irow0);
is_col_fixed_len[icol] ? nullptr
: fixedlen_buffers[icol] + sizeof(uint32_t) * irow0);
out_col_values[icol] =
is_col_fixed_len[icol] ? nullptr : varlen_buffers_maybe_null[icol];
}

arrow::exec::KeyTranspose::row2col(
instruction_set, true, key_store.get_num_cols(), curr_minibatch_size,
col_widths.data(), out_col_non_nulls.data(), out_col_offsets.data(),
out_col_values.data(),
out_col_values.data(),
key_store.is_row_fixedlen() ? &row_length : row_offsets + irow0,
key_store.is_row_fixedlen() ? row_vals + row_length * irow0 : row_vals,
row_nulls + (irow0 << log_row_null_bits) / 8,
minibatch_size_max, minibatch_temp.data(),
key_store.is_row_fixedlen() ? row_vals + row_length * irow0 : row_vals,
row_nulls + (irow0 << log_row_null_bits) / 8, minibatch_size_max,
minibatch_temp.data(),
minibatch_temp.data() + minibatch_size_max +
KeyStore::padding_for_SIMD / sizeof(uint32_t));

Expand Down
20 changes: 10 additions & 10 deletions cpp/src/arrow/exec/groupby.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@
// under the License.

#pragma once
#include "arrow/memory_pool.h"
#include "arrow/result.h"
#include "arrow/status.h"
#include "arrow/exec/common.h"
#include "arrow/exec/groupby_hash.h"
#include "arrow/exec/groupby_map.h"
#include "arrow/exec/groupby_storage.h"
#include "arrow/exec/util.h"
#include "arrow/memory_pool.h"
#include "arrow/result.h"
#include "arrow/status.h"

namespace arrow {
namespace exec {
Expand All @@ -32,15 +32,15 @@ class GroupMap {
public:
~GroupMap();

void init(util::CPUInstructionSet instruction_set_in, MemoryPool* pool,
uint32_t num_columns, const std::vector<bool>& is_fixed_len_in,
const uint32_t* col_widths_in);
Status init(util::CPUInstructionSet instruction_set_in, MemoryPool* pool,
uint32_t num_columns, const std::vector<bool>& is_fixed_len_in,
const uint32_t* col_widths_in);

void push_input(uint32_t num_rows, const uint8_t** non_null_buffers_maybe_null,
const uint8_t** fixedlen_buffers,
const uint8_t** varlen_buffers_maybe_null, uint32_t* group_ids);
Status push_input(uint32_t num_rows, const uint8_t** non_null_buffers_maybe_null,
const uint8_t** fixedlen_buffers,
const uint8_t** varlen_buffers_maybe_null, uint32_t* group_ids);

void pull_output_prepare(uint64_t& out_num_rows, bool& is_row_fixedlen);
void pull_output_prepare(uint64_t* out_num_rows, bool* is_row_fixedlen);

void pull_output_fixedlen_and_nulls(uint8_t** non_null_buffers,
uint8_t** fixedlen_buffers,
Expand Down
73 changes: 30 additions & 43 deletions cpp/src/arrow/exec/groupby_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,60 +24,44 @@
#include <cstdint>

#include "arrow/exec/common.h"
#include "arrow/util/bit_util.h"
#include "arrow/util/bitmap_ops.h"

namespace arrow {

using BitUtil::CountLeadingZeros;

namespace exec {

// Scan bytes in block in reverse and stop as soon
// as a position of interest is found, which is either of:
// a) slot with a matching stamp is encountered,
// b) first empty slot is encountered,
// c) we reach the end of the block.
// Return an index corresponding to this position (8 represents end of block).
// Return also an integer flag (0 or 1) indicating if we reached case a)
// (if we found a stamp match).
constexpr uint64_t kHighBitOfEachByte = 0x8080808080808080ULL;

//
template <bool use_start_slot>
inline void SwissTable::search_block(
uint64_t block, // 8B block of hash table
int stamp, // 7-bits of hash used as a stamp
int start_slot, // Index of the first slot in the block to start search from.
// We assume that this index always points to a non-empty slot
// (comes before any empty slots).
// Used only by one template variant.
int& out_slot, // Returned index of a slot
int& out_match_found) { // Returned integer flag indicating match found
inline void SwissTable::search_block(uint64_t block, int stamp, int start_slot,
int* out_slot, int* out_match_found) {
// Filled slot bytes have the highest bit set to 0 and empty slots are equal to 0x80.
// Replicate 7-bit stamp to all non-empty slots:
uint64_t block_high_bits = block & UINT64_C(0x8080808080808080);
uint64_t stamp_pattern =
stamp * ((block_high_bits ^ UINT64_C(0x8080808080808080)) >> 7);
uint64_t block_high_bits = block & kHighBitOfEachByte;
uint64_t stamp_pattern = stamp * ((block_high_bits ^ kHighBitOfEachByte) >> 7);
// If we xor this pattern with block bytes we get:
// a) 0x00, for filled slots matching the stamp,
// b) 0x00 < x < 0x80, for filled slots not matching the stamp,
// c) 0x80, for empty slots.
// If we then add 0x7f to every byte, negate the result and leave only the highest bits
// in each byte, we get 0x00 for non-match slot and 0x80 for match slot.
uint64_t matches = ~((block ^ stamp_pattern) + UINT64_C(0x7f7f7f7f7f7f7f7f));
uint64_t matches = ~((block ^ stamp_pattern) + ~kHighBitOfEachByte);
if (use_start_slot) {
matches &= UINT64_C(0x8080808080808080) >> (8 * start_slot);
matches &= kHighBitOfEachByte >> (8 * start_slot);
} else {
matches &= UINT64_C(0x8080808080808080);
matches &= kHighBitOfEachByte;
}
// We get 0 if there are no matches
out_match_found = (matches == 0 ? 0 : 1);
*out_match_found = (matches == 0 ? 0 : 1);
// Now if we or with the highest bits of the block and scan zero bits in reverse,
// we get 8x slot index that we were looking for.
out_slot = static_cast<int>(LZCNT64(matches | block_high_bits) >> 3);
*out_slot = static_cast<int>(CountLeadingZeros(matches | block_high_bits) >> 3);
}

// Extract group id for a given slot in a given block.
// Group ids follow in memory after 64-bit block data.
// Maximum number of groups inserted is equal to the number
// of all slots in all blocks, which is 8 * the number of blocks.
// Group ids are bit packed using that maximum to determine the necessary number of bits.
//
inline uint64_t SwissTable::extract_group_id(const uint8_t* block_ptr, int slot,
uint64_t group_id_mask) {
int num_bits_group_id = log_blocks_ + 3;
Expand Down Expand Up @@ -127,7 +111,7 @@ void SwissTable::lookup_1(const uint16_t* selection, const int num_keys,

int match_found;
int islot_in_block;
search_block<false>(block, stamp, 0, islot_in_block, match_found);
search_block<false>(block, stamp, 0, &islot_in_block, &match_found);
uint64_t groupid = extract_group_id(blockbase, islot_in_block, groupid_mask);
ARROW_DCHECK(groupid < num_inserted_ || num_inserted_ == 0);
uint64_t islot = next_slot_to_visit(iblock, islot_in_block, match_found);
Expand Down Expand Up @@ -224,9 +208,9 @@ Status SwissTable::lookup_2(const uint32_t* hashes, int& inout_num_selected,
} else {
int new_match_found;
int new_slot;
search_block<true>(block, static_cast<int>(stamp), start_slot, new_slot,
new_match_found);
uint32_t new_groupid =
search_block<true>(block, static_cast<int>(stamp), start_slot, &new_slot,
&new_match_found);
auto new_groupid =
static_cast<uint32_t>(extract_group_id(blockbase, new_slot, groupid_mask));
ARROW_DCHECK(new_groupid < num_inserted_ + num_ids[category_inserted]);
new_slot =
Expand Down Expand Up @@ -393,12 +377,14 @@ Status SwissTable::grow_double() {
uint8_t* block_base = blocks_ + i * block_size_before;
uint8_t* double_block_base_new = blocks_new + 2 * i * block_size_after;
uint64_t block = *reinterpret_cast<const uint64_t*>(block_base);
int full_slots = static_cast<int>(LZCNT64(block & 0x8080808080808080ULL) >> 3);

auto full_slots =
static_cast<int>(CountLeadingZeros(block & kHighBitOfEachByte) >> 3);
int full_slots_new[2];
full_slots_new[0] = full_slots_new[1] = 0;
*reinterpret_cast<uint64_t*>(double_block_base_new) = 0x8080808080808080ULL;
*reinterpret_cast<uint64_t*>(double_block_base_new) = kHighBitOfEachByte;
*reinterpret_cast<uint64_t*>(double_block_base_new + block_size_after) =
0x8080808080808080ULL;
kHighBitOfEachByte;

for (int j = 0; j < full_slots; ++j) {
uint64_t slot_id = i * 8 + j;
Expand Down Expand Up @@ -435,7 +421,7 @@ Status SwissTable::grow_double() {
// How many full slots in this block
uint8_t* block_base = blocks_ + i * block_size_before;
uint64_t block = *reinterpret_cast<const uint64_t*>(block_base);
int full_slots = static_cast<int>(LZCNT64(block & 0x8080808080808080ULL) >> 3);
int full_slots = static_cast<int>(CountLeadingZeros(block & kHighBitOfEachByte) >> 3);

for (int j = 0; j < full_slots; ++j) {
uint64_t slot_id = i * 8 + j;
Expand All @@ -457,13 +443,13 @@ Status SwissTable::grow_double() {
uint8_t* block_base_new = blocks_new + block_id_new * block_size_after;
uint64_t block_new = *reinterpret_cast<const uint64_t*>(block_base_new);
int full_slots_new =
static_cast<int>(LZCNT64(block_new & 0x8080808080808080ULL) >> 3);
static_cast<int>(CountLeadingZeros(block_new & kHighBitOfEachByte) >> 3);
while (full_slots_new == 8) {
block_id_new = (block_id_new + 1) & ((1 << log_blocks_after) - 1);
block_base_new = blocks_new + block_id_new * block_size_after;
block_new = *reinterpret_cast<const uint64_t*>(block_base_new);
full_slots_new =
static_cast<int>(LZCNT64(block_new & 0x8080808080808080ULL) >> 3);
static_cast<int>(CountLeadingZeros(block_new & kHighBitOfEachByte) >> 3);
}

hashes_new[block_id_new * 8 + full_slots_new] = hash;
Expand Down Expand Up @@ -505,7 +491,7 @@ Status SwissTable::init(util::CPUInstructionSet cpu_instruction_set, MemoryPool*

for (uint64_t i = 0; i < (static_cast<uint64_t>(1) << log_blocks_); ++i) {
*reinterpret_cast<uint64_t*>(blocks_ + i * (8 + num_groupid_bits)) =
UINT64_C(0x8080808080808080);
kHighBitOfEachByte;
}

const uint64_t cbhashes = (sizeof(uint32_t) << num_groupid_bits) + padding_;
Expand Down Expand Up @@ -568,7 +554,8 @@ void SwissTable::precomputed_make() {
for (int i = 0; i < (1 << log_blocks_); ++i) {
uint8_t* block_base = blocks_ + i * (8 + num_group_id_bits);
uint64_t block = *reinterpret_cast<const uint64_t*>(block_base);
int full_slots = static_cast<int>(LZCNT64(block & 0x8080808080808080ULL) >> 3);
int full_slots = static_cast<int>(
CountLeadingZeros(static_cast<uint64_t>(block & kHighBitOfEachByte)) >> 3);
for (int j = 0; j < full_slots; ++j) {
uint64_t group_id_bit_offs = j * num_group_id_bits;
uint64_t group_id = (*reinterpret_cast<const uint64_t*>(
Expand Down
Loading

0 comments on commit 13ac291

Please sign in to comment.