Skip to content

Commit

Permalink
Fix boundary check in Bitmap::BitOp (#1727)
Browse files Browse the repository at this point in the history
  • Loading branch information
mapleFU authored Sep 2, 2023
1 parent b1afbe3 commit e917552
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 6 deletions.
1 change: 1 addition & 0 deletions src/commands/cmd_bit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ class CommandBitOp : public Commander {

Status Execute(Server *svr, Connection *conn, std::string *output) override {
std::vector<Slice> op_keys;
op_keys.reserve(args_.size() - 2);
for (uint64_t i = 3; i < args_.size(); i++) {
op_keys.emplace_back(args_[i]);
}
Expand Down
16 changes: 12 additions & 4 deletions src/types/redis_bitmap.cc
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ rocksdb::Status Bitmap::BitOp(BitOpFlags op_flag, const std::string &op_name, co
return rocksdb::Status::InvalidArgument(kErrMsgWrongType);
}
if (metadata.size > max_size) max_size = metadata.size;
meta_pairs.emplace_back(ns_op_key, metadata);
meta_pairs.emplace_back(std::move(ns_op_key), metadata);
}

auto batch = storage_->GetWriteBatchBase();
Expand All @@ -376,7 +376,8 @@ rocksdb::Status Bitmap::BitOp(BitOpFlags op_flag, const std::string &op_name, co

BitmapMetadata res_metadata;
if (num_keys == op_keys.size() || op_flag != kBitOpAnd) {
uint64_t frag_numkeys = num_keys, stop_index = (max_size - 1) / kBitmapSegmentBytes;
uint64_t frag_numkeys = num_keys;
uint64_t stop_index = (max_size - 1) / kBitmapSegmentBytes;
std::unique_ptr<unsigned char[]> frag_res(new unsigned char[kBitmapSegmentBytes]);
uint16_t frag_maxlen = 0, frag_minlen = 0;
std::string fragment;
Expand Down Expand Up @@ -404,7 +405,7 @@ rocksdb::Status Bitmap::BitOp(BitOpFlags op_flag, const std::string &op_name, co
} else {
if (frag_maxlen < fragment.size()) frag_maxlen = fragment.size();
if (fragment.size() < frag_minlen || frag_minlen == 0) frag_minlen = fragment.size();
fragments.emplace_back(fragment);
fragments.emplace_back(std::move(fragment));
}
}

Expand Down Expand Up @@ -502,7 +503,14 @@ rocksdb::Status Bitmap::BitOp(BitOpFlags op_flag, const std::string &op_name, co

if (op_flag == kBitOpNot) {
if (frag_index == stop_index) {
frag_maxlen = max_size % kBitmapSegmentBytes;
if (max_size == (frag_index + 1) * kBitmapSegmentBytes) {
// If the last fragment is full, `max_size % kBitmapSegmentBytes`
// would be 0. In this case, we should set `frag_maxlen` to
// `kBitmapSegmentBytes` to avoid writing an empty fragment.
frag_maxlen = kBitmapSegmentBytes;
} else {
frag_maxlen = max_size % kBitmapSegmentBytes;
}
} else {
frag_maxlen = kBitmapSegmentBytes;
}
Expand Down
40 changes: 40 additions & 0 deletions tests/cppunit/disk_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,46 @@ TEST_F(RedisDiskTest, BitmapDisk) {
bitmap->Del(key_);
}

TEST_F(RedisDiskTest, BitmapDisk2) {
const int64_t kGroupSize = 8192;
for (size_t num_bits : {8192, 16384}) {
for (bool set_op : {false, true}) {
std::unique_ptr<redis::Bitmap> bitmap = std::make_unique<redis::Bitmap>(storage_, "disk_ns_bitmap2");
std::unique_ptr<redis::Disk> disk = std::make_unique<redis::Disk>(storage_, "disk_ns_bitmap2");
key_ = "bitmapdisk_key2";
bitmap->Del(key_);
bool bit = false;

for (size_t i = 0; i < num_bits; i += kGroupSize) {
// Set all first bit of group to `!set_op`
EXPECT_TRUE(bitmap->SetBit(key_, i, !set_op, &bit).ok());
// Set all last bit of group to `set_op`.
EXPECT_TRUE(bitmap->SetBit(key_, i + kGroupSize - 1, set_op, &bit).ok());
}

auto bit_not_dest_key = "bit_op_not_dest_key";

int64_t len = 0;
EXPECT_TRUE(bitmap->BitOp(BitOpFlags::kBitOpNot, "NOT", bit_not_dest_key, {key_}, &len).ok());

for (size_t i = 0; i < num_bits; i += kGroupSize) {
bool result = false;
// Check all first bit of group is `set_op`
EXPECT_TRUE(bitmap->GetBit(bit_not_dest_key, i, &result).ok());
EXPECT_EQ(set_op, result);
// Check all last bit of group is `!set_op`
EXPECT_TRUE(bitmap->GetBit(bit_not_dest_key, i + kGroupSize - 1, &result).ok());
EXPECT_EQ(!set_op, result);
// Check bit in group between (first, last) is "1".
for (size_t j = i + 1; j < i + kGroupSize - 1; ++j) {
EXPECT_TRUE(bitmap->GetBit(bit_not_dest_key, j, &result).ok());
EXPECT_TRUE(result) << j << " is not true";
}
}
}
}
}

TEST_F(RedisDiskTest, SortedintDisk) {
std::unique_ptr<redis::Sortedint> sortedint = std::make_unique<redis::Sortedint>(storage_, "disk_ns_sortedint");
std::unique_ptr<redis::Disk> disk = std::make_unique<redis::Disk>(storage_, "disk_ns_sortedint");
Expand Down
8 changes: 8 additions & 0 deletions tests/gocase/unit/type/bitmap/bitmap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,14 @@ func TestBitmap(t *testing.T) {
}
})

t.Run("BITOP Boundary Check", func(t *testing.T) {
require.NoError(t, rdb.Del(ctx, "str").Err())
str := util.RandStringWithSeed(0, 1000, util.Binary, 2701)
Set2SetBit(t, rdb, ctx, "str", []byte(str))
require.NoError(t, rdb.BitOpNot(ctx, "target", "str").Err())
require.EqualValues(t, SimulateBitOp(NOT, []byte(str)), rdb.Get(ctx, "target").Val())
})

t.Run("BITOP with non string source key", func(t *testing.T) {
require.NoError(t, rdb.Del(ctx, "c").Err())
Set2SetBit(t, rdb, ctx, "a", []byte("\xaa\x00\xff\x55"))
Expand Down
8 changes: 6 additions & 2 deletions tests/gocase/util/random.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,11 @@ const (
)

func RandString(min, max int, typ RandStringType) string {
r := rand.New(rand.NewSource(time.Now().UnixNano()))
return RandStringWithSeed(min, max, typ, time.Now().UnixNano())
}

func RandStringWithSeed(min, max int, typ RandStringType, seed int64) string {
r := rand.New(rand.NewSource(seed))
length := min + r.Intn(max-min+1)

var minVal, maxVal int
Expand All @@ -71,7 +75,7 @@ func RandString(min, max int, typ RandStringType) string {

var sb strings.Builder
for ; length > 0; length-- {
s := fmt.Sprintf("%c", minVal+rand.Intn(maxVal-minVal+1))
s := fmt.Sprintf("%c", minVal+int(r.Int31n(int32(maxVal-minVal+1))))
sb.WriteString(s)
}
return sb.String()
Expand Down

0 comments on commit e917552

Please sign in to comment.