Skip to content

Commit

Permalink
Increase MAX_BATCH_SIZE
Browse files Browse the repository at this point in the history
Increasing the MAX_BATCH_SIZE to improve parallelism.
  • Loading branch information
dpetrov4 committed Dec 5, 2023
1 parent 93e54e7 commit 79e24b3
Showing 1 changed file with 14 additions and 22 deletions.
36 changes: 14 additions & 22 deletions table/multiget_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#pragma once
#include <algorithm>
#include <array>
#include <bitset>
#include <string>

#include "db/dbformat.h"
Expand Down Expand Up @@ -97,12 +98,11 @@ class MultiGetContext {
// there is negligible benefit for batches exceeding this. Keeping this < 32
// simplifies iteration, as well as reduces the amount of stack allocations
// that need to be performed
static const int MAX_BATCH_SIZE = 32;
static const int MAX_BATCH_SIZE = 1024;

// A bitmask of at least MAX_BATCH_SIZE - 1 bits, so that
// Mask{1} << MAX_BATCH_SIZE is well defined
using Mask = uint64_t;
static_assert(MAX_BATCH_SIZE < sizeof(Mask) * 8);
using Mask = std::bitset<MAX_BATCH_SIZE>;

MultiGetContext(autovector<KeyContext*, MAX_BATCH_SIZE>* sorted_keys,
size_t begin, size_t num_keys, SequenceNumber snapshot,
Expand Down Expand Up @@ -198,9 +198,8 @@ class MultiGetContext {
Iterator(const Range* range, size_t idx)
: range_(range), ctx_(range->ctx_), index_(idx) {
while (index_ < range_->end_ &&
(Mask{1} << index_) &
(range_->ctx_->value_mask_ | range_->skip_mask_ |
range_->invalid_mask_))
range_->invalid_mask_).test(index_))
index_++;
}

Expand All @@ -214,9 +213,8 @@ class MultiGetContext {

Iterator& operator++() {
while (++index_ < range_->end_ &&
(Mask{1} << index_) &
(range_->ctx_->value_mask_ | range_->skip_mask_ |
range_->invalid_mask_))
range_->invalid_mask_).test(index_))
;
return *this;
}
Expand Down Expand Up @@ -264,8 +262,6 @@ class MultiGetContext {
}
skip_mask_ = mget_range.skip_mask_;
invalid_mask_ = mget_range.invalid_mask_;
assert(start_ < 64);
assert(end_ < 64);
}

Range() = default;
Expand All @@ -274,27 +270,27 @@ class MultiGetContext {

Iterator end() const { return Iterator(this, end_); }

bool empty() const { return RemainingMask() == 0; }
bool empty() const { return RemainingMask().none(); }

void SkipIndex(size_t index) { skip_mask_ |= Mask{1} << index; }

void SkipKey(const Iterator& iter) { SkipIndex(iter.index_); }

bool IsKeySkipped(const Iterator& iter) const {
return skip_mask_ & (Mask{1} << iter.index_);
return skip_mask_.test(iter.index_);
}

// Update the value_mask_ in MultiGetContext so its
// immediately reflected in all the Range Iterators
void MarkKeyDone(Iterator& iter) {
ctx_->value_mask_ |= (Mask{1} << iter.index_);
ctx_->value_mask_.set(iter.index_);
}

bool CheckKeyDone(Iterator& iter) const {
return ctx_->value_mask_ & (Mask{1} << iter.index_);
return ctx_->value_mask_.test(iter.index_);
}

uint64_t KeysLeft() const { return BitsSetToOne(RemainingMask()); }
uint64_t KeysLeft() const { return RemainingMask().count(); }

void AddSkipsFrom(const Range& other) {
assert(ctx_ == other.ctx_);
Expand Down Expand Up @@ -335,8 +331,6 @@ class MultiGetContext {
skip_mask_ |= rhs.skip_mask_ & RangeMask(rhs.start_, rhs.end_);
invalid_mask_ |= (rhs.invalid_mask_ | rhs.skip_mask_) &
RangeMask(rhs.start_, rhs.end_);
assert(start_ < 64);
assert(end_ < 64);
return *this;
}

Expand Down Expand Up @@ -373,22 +367,20 @@ class MultiGetContext {
end_(num_keys),
skip_mask_(0),
invalid_mask_(0) {
assert(num_keys < 64);
}

static Mask RangeMask(size_t start, size_t end) {
return (((Mask{1} << (end - start)) - 1) << start);
return (Mask(0).flip() <<= (end - start)).flip() <<= start;
}

Mask RemainingMask() const {
return (((Mask{1} << end_) - 1) & ~((Mask{1} << start_) - 1) &
~(ctx_->value_mask_ | skip_mask_));
return RangeMask(start_, end_) & ~(ctx_->value_mask_ | skip_mask_);
}

size_t FindLastRemaining() const {
Mask mask = RemainingMask();
size_t index = (mask >>= start_) ? start_ : 0;
while (mask >>= 1) {
size_t index = (mask >>= start_).any() ? start_ : 0;
while ((mask >>= 1).any()) {
index++;
}
return index;
Expand Down

0 comments on commit 79e24b3

Please sign in to comment.