Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix: BITCOUNT/BITPOS negative handling fixing #2069

Merged
merged 9 commits into from
Jan 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 22 additions & 17 deletions src/types/redis_bitmap.cc
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ rocksdb::Status Bitmap::SetBit(const Slice &user_key, uint32_t offset, bool new_
uint32_t byte_index = (offset / 8) % kBitmapSegmentBytes;
uint64_t used_size = index + byte_index + 1;
uint64_t bitmap_size = std::max(used_size, metadata.size);
// NOTE: value.size() might be greater than metadata.size.
ExpandBitmapSegment(&value, byte_index + 1);
uint32_t bit_offset = offset % 8;
*old_bit = (value[byte_index] & (1 << bit_offset)) != 0;
Expand Down Expand Up @@ -223,10 +224,10 @@ rocksdb::Status Bitmap::BitCount(const Slice &user_key, int64_t start, int64_t s
return bitmap_string_db.BitCount(raw_value, start, stop, cnt);
}

if (start < 0) start += static_cast<int64_t>(metadata.size) + 1;
if (stop < 0) stop += static_cast<int64_t>(metadata.size) + 1;
if (stop > static_cast<int64_t>(metadata.size)) stop = static_cast<int64_t>(metadata.size);
if (start < 0 || stop <= 0 || start >= stop) return rocksdb::Status::OK();
// Counting bits in byte [start, stop].
std::tie(start, stop) = BitmapString::NormalizeRange(start, stop, static_cast<int64_t>(metadata.size));
// Always return 0 if start is greater than stop after normalization.
if (start > stop) return rocksdb::Status::OK();

auto u_start = static_cast<uint32_t>(start);
auto u_stop = static_cast<uint32_t>(stop);
Expand All @@ -244,12 +245,17 @@ rocksdb::Status Bitmap::BitCount(const Slice &user_key, int64_t start, int64_t s
.Encode();
s = storage_->Get(read_options, sub_key, &pin_value);
if (!s.ok() && !s.IsNotFound()) return s;
// NotFound means all bits in this segment are 0.
if (s.IsNotFound()) continue;
size_t j = 0;
if (i == start_index) j = u_start % kBitmapSegmentBytes;
auto k = static_cast<int64_t>(pin_value.size());
if (i == stop_index) k = u_stop % kBitmapSegmentBytes + 1;
*cnt += BitmapString::RawPopcount(reinterpret_cast<const uint8_t *>(pin_value.data()) + j, k);
// Counting bits in [start_in_segment, start_in_segment + length_in_segment)
size_t start_in_segment = 0;
if (i == start_index) start_in_segment = u_start % kBitmapSegmentBytes;
// Though `ExpandBitmapSegment` might generate a segment with logical size less than pin_value.size(),
// the `RawPopcount` will always return 0 on these padding bytes, so we don't need to worry about it.
auto length_in_segment = static_cast<int64_t>(pin_value.size());
if (i == stop_index) length_in_segment = u_stop % kBitmapSegmentBytes + 1;
*cnt += BitmapString::RawPopcount(reinterpret_cast<const uint8_t *>(pin_value.data()) + start_in_segment,
length_in_segment);
}
return rocksdb::Status::OK();
}
Expand All @@ -271,13 +277,7 @@ rocksdb::Status Bitmap::BitPos(const Slice &user_key, bool bit, int64_t start, i
redis::BitmapString bitmap_string_db(storage_, namespace_);
return bitmap_string_db.BitPos(raw_value, bit, start, stop, stop_given, pos);
}

if (start < 0) start += static_cast<int64_t>(metadata.size) + 1;
if (stop < 0) stop += static_cast<int64_t>(metadata.size) + 1;
if (start < 0 || stop < 0 || start > stop) {
*pos = -1;
return rocksdb::Status::OK();
}
std::tie(start, stop) = BitmapString::NormalizeRange(start, stop, static_cast<int64_t>(metadata.size));
auto u_start = static_cast<uint32_t>(start);
auto u_stop = static_cast<uint32_t>(stop);

Expand Down Expand Up @@ -319,7 +319,12 @@ rocksdb::Status Bitmap::BitPos(const Slice &user_key, bool bit, int64_t start, i
}
}
if (!bit && pin_value.size() < kBitmapSegmentBytes) {
*pos = static_cast<int64_t>(i * kBitmapSegmentBits + pin_value.size() * 8);
// ExpandBitmapSegment might generate a segment with size less than kBitmapSegmentBytes,
// but larger than logical size, so we need to align the last segment with `metadata.size`
// rather than `pin_value.size()`.
auto last_segment_bytes = metadata.size % kBitmapSegmentBytes;
DCHECK_LE(last_segment_bytes, pin_value.size());
*pos = static_cast<int64_t>(i * kBitmapSegmentBits + last_segment_bytes * 8);
return rocksdb::Status::OK();
}
pin_value.Reset();
Expand Down
25 changes: 13 additions & 12 deletions src/types/redis_bitmap_string.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,17 +70,13 @@ rocksdb::Status BitmapString::SetBit(const Slice &ns_key, std::string *raw_value

rocksdb::Status BitmapString::BitCount(const std::string &raw_value, int64_t start, int64_t stop, uint32_t *cnt) {
*cnt = 0;
auto string_value = raw_value.substr(Metadata::GetOffsetAfterExpire(raw_value[0]));
std::string_view string_value = std::string_view{raw_value}.substr(Metadata::GetOffsetAfterExpire(raw_value[0]));
/* Convert negative indexes */
if (start < 0 && stop < 0 && start > stop) {
return rocksdb::Status::OK();
}
auto strlen = static_cast<int64_t>(string_value.size());
if (start < 0) start = strlen + start;
if (stop < 0) stop = strlen + stop;
if (start < 0) start = 0;
if (stop < 0) stop = 0;
if (stop >= strlen) stop = strlen - 1;
std::tie(start, stop) = NormalizeRange(start, stop, strlen);

/* Precondition: end >= 0 && end < strlen, so the only condition where
* zero can be returned is: start > stop. */
Expand All @@ -95,12 +91,8 @@ rocksdb::Status BitmapString::BitPos(const std::string &raw_value, bool bit, int
bool stop_given, int64_t *pos) {
auto string_value = raw_value.substr(Metadata::GetOffsetAfterExpire(raw_value[0]));
auto strlen = static_cast<int64_t>(string_value.size());
/* Convert negative indexes */
if (start < 0) start = strlen + start;
if (stop < 0) stop = strlen + stop;
if (start < 0) start = 0;
if (stop < 0) stop = 0;
if (stop >= strlen) stop = strlen - 1;
/* Convert negative and out-of-bound indexes */
std::tie(start, stop) = NormalizeRange(start, stop, strlen);

if (start > stop) {
*pos = -1;
Expand Down Expand Up @@ -205,6 +197,15 @@ int64_t BitmapString::RawBitpos(const uint8_t *c, int64_t count, bool bit) {
return res;
}

std::pair<int64_t, int64_t> BitmapString::NormalizeRange(int64_t origin_start, int64_t origin_end, int64_t length) {
if (origin_start < 0) origin_start = length + origin_start;
if (origin_end < 0) origin_end = length + origin_end;
if (origin_start < 0) origin_start = 0;
if (origin_end < 0) origin_end = 0;
if (origin_end >= length) origin_end = length - 1;
return {origin_start, origin_end};
}

rocksdb::Status BitmapString::Bitfield(const Slice &ns_key, std::string *raw_value,
const std::vector<BitfieldOperation> &ops,
std::vector<std::optional<BitfieldValue>> *rets) {
Expand Down
10 changes: 10 additions & 0 deletions src/types/redis_bitmap_string.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,16 @@ class BitmapString : public Database {

static size_t RawPopcount(const uint8_t *p, int64_t count);
static int64_t RawBitpos(const uint8_t *c, int64_t count, bool bit);

// NormalizeRange converts a range to a normalized range, which is a range with start and stop in [0, length).
//
// If start/end is negative, it will be converted to positive by adding length to it, and if the result is still
// negative, it will be converted to 0.
// If start/end is larger than length, it will be converted to length - 1.
//
// Return:
// The normalized [start, end] range.
static std::pair<int64_t, int64_t> NormalizeRange(int64_t origin_start, int64_t origin_end, int64_t length);
};

} // namespace redis
71 changes: 71 additions & 0 deletions tests/cppunit/types/bitmap_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,51 @@ TEST_F(RedisBitmapTest, BitCount) {
auto s = bitmap_->Del(key_);
}

TEST_F(RedisBitmapTest, BitCountNegative) {
{
bool bit = false;
bitmap_->SetBit(key_, 0, true, &bit);
EXPECT_FALSE(bit);
}
uint32_t cnt = 0;
bitmap_->BitCount(key_, 0, 4 * 1024, &cnt);
EXPECT_EQ(cnt, 1);
bitmap_->BitCount(key_, 0, 0, &cnt);
EXPECT_EQ(cnt, 1);
bitmap_->BitCount(key_, 0, -1, &cnt);
EXPECT_EQ(cnt, 1);
bitmap_->BitCount(key_, -1, -1, &cnt);
EXPECT_EQ(cnt, 1);
bitmap_->BitCount(key_, 1, 1, &cnt);
EXPECT_EQ(cnt, 0);
bitmap_->BitCount(key_, -10000, -10000, &cnt);
EXPECT_EQ(cnt, 1);

{
bool bit = false;
bitmap_->SetBit(key_, 5, true, &bit);
EXPECT_FALSE(bit);
}
bitmap_->BitCount(key_, -10000, -10000, &cnt);
EXPECT_EQ(cnt, 2);

{
bool bit = false;
bitmap_->SetBit(key_, 8 * 1024 - 1, true, &bit);
EXPECT_FALSE(bit);
bitmap_->SetBit(key_, 8 * 1024, true, &bit);
EXPECT_FALSE(bit);
}

bitmap_->BitCount(key_, 0, 1024, &cnt);
EXPECT_EQ(cnt, 4);

bitmap_->BitCount(key_, 0, 1023, &cnt);
EXPECT_EQ(cnt, 3);

auto s = bitmap_->Del(key_);
}

TEST_F(RedisBitmapTest, BitPosClearBit) {
int64_t pos = 0;
bool old_bit = false;
Expand Down Expand Up @@ -95,6 +140,32 @@ TEST_F(RedisBitmapTest, BitPosSetBit) {
auto s = bitmap_->Del(key_);
}

TEST_F(RedisBitmapTest, BitPosNegative) {
{
bool bit = false;
bitmap_->SetBit(key_, 8 * 1024 - 1, true, &bit);
EXPECT_FALSE(bit);
}
int64_t pos = 0;
// First bit is negative
bitmap_->BitPos(key_, false, 0, -1, true, &pos);
EXPECT_EQ(0, pos);
// 8 * 1024 - 1 bit is positive
bitmap_->BitPos(key_, true, 0, -1, true, &pos);
EXPECT_EQ(8 * 1024 - 1, pos);
// First bit in 1023 byte is negative
bitmap_->BitPos(key_, false, -1, -1, true, &pos);
EXPECT_EQ(8 * 1023, pos);
// Last Bit in 1023 byte is positive
bitmap_->BitPos(key_, true, -1, -1, true, &pos);
EXPECT_EQ(8 * 1024 - 1, pos);
// Large negative number will be normalized.
bitmap_->BitPos(key_, false, -10000, -10000, true, &pos);
EXPECT_EQ(0, pos);

auto s = bitmap_->Del(key_);
}

TEST_F(RedisBitmapTest, BitfieldGetSetTest) {
constexpr uint32_t magic = 0xdeadbeef;

Expand Down
6 changes: 6 additions & 0 deletions tests/gocase/unit/type/strings/strings_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,12 @@ func TestString(t *testing.T) {
require.NoError(t, rdb.SetBit(ctx, "mykey", maxOffset, 1).Err())
require.EqualValues(t, 1, rdb.GetBit(ctx, "mykey", maxOffset).Val())
require.EqualValues(t, 1, rdb.BitCount(ctx, "mykey", &redis.BitCount{Start: 0, End: maxOffset / 8}).Val())
// Last byte should contain 1 bit.
require.EqualValues(t, 1, rdb.BitCount(ctx, "mykey", &redis.BitCount{Start: -1, End: -1}).Val())
// 0 - Last byte should contain 1 bit.
require.EqualValues(t, 1, rdb.BitCount(ctx, "mykey", &redis.BitCount{Start: -100, End: -1}).Val())
// The first byte shouldn't contain any bits
require.EqualValues(t, 0, rdb.BitCount(ctx, "mykey", &redis.BitCount{Start: -100, End: -100}).Val())
require.EqualValues(t, maxOffset, rdb.BitPos(ctx, "mykey", 1).Val())
})

Expand Down
Loading