Skip to content

Commit

Permalink
RDB: style enhancement for rdb load (#1839)
Browse files Browse the repository at this point in the history
  • Loading branch information
mapleFU authored Oct 20, 2023
1 parent 31d62fd commit a618574
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 28 deletions.
22 changes: 13 additions & 9 deletions src/common/rdb_stream.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@
#include "fmt/format.h"
#include "vendor/crc64.h"

StatusOr<size_t> RdbStringStream::Read(char *buf, size_t n) {
Status RdbStringStream::Read(char *buf, size_t n) {
if (pos_ + n > input_.size()) {
return {Status::NotOK, "unexpected EOF"};
}
memcpy(buf, input_.data() + pos_, n);
pos_ += n;
return n;
return Status::OK();
}

StatusOr<uint64_t> RdbStringStream::GetCheckSum() const {
Expand All @@ -50,20 +50,24 @@ Status RdbFileStream::Open() {
return Status::OK();
}

StatusOr<size_t> RdbFileStream::Read(char *buf, size_t len) {
size_t n = 0;
Status RdbFileStream::Read(char *buf, size_t len) {
while (len) {
size_t read_bytes = std::min(max_read_chunk_size_, len);
ifs_.read(buf, static_cast<std::streamsize>(read_bytes));
if (!ifs_.good()) {
return Status(Status::NotOK, fmt::format("read failed: {}:", strerror(errno)));
if (!ifs_.eof()) {
return {Status::NotOK, fmt::format("read failed: {}:", strerror(errno))};
}
auto eof_read_bytes = static_cast<size_t>(ifs_.gcount());
if (read_bytes != eof_read_bytes) {
return {Status::NotOK, fmt::format("read failed: {}:", strerror(errno))};
}
}
check_sum_ = crc64(check_sum_, reinterpret_cast<const unsigned char *>(buf), read_bytes);
buf = (char *)buf + read_bytes;
buf = buf + read_bytes;
DCHECK(len >= read_bytes);
len -= read_bytes;
total_read_bytes_ += read_bytes;
n += read_bytes;
}

return n;
return Status::OK();
}
6 changes: 3 additions & 3 deletions src/common/rdb_stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class RdbStream {
RdbStream() = default;
virtual ~RdbStream() = default;

virtual StatusOr<size_t> Read(char *buf, size_t len) = 0;
virtual Status Read(char *buf, size_t len) = 0;
virtual StatusOr<uint64_t> GetCheckSum() const = 0;
StatusOr<uint8_t> ReadByte() {
uint8_t value = 0;
Expand All @@ -51,7 +51,7 @@ class RdbStringStream : public RdbStream {
RdbStringStream &operator=(const RdbStringStream &) = delete;
~RdbStringStream() override = default;

StatusOr<size_t> Read(char *buf, size_t len) override;
Status Read(char *buf, size_t len) override;
StatusOr<uint64_t> GetCheckSum() const override;

private:
Expand All @@ -68,7 +68,7 @@ class RdbFileStream : public RdbStream {
~RdbFileStream() override = default;

Status Open();
StatusOr<size_t> Read(char *buf, size_t len) override;
Status Read(char *buf, size_t len) override;
StatusOr<uint64_t> GetCheckSum() const override {
uint64_t crc = check_sum_;
memrev64ifbe(&crc);
Expand Down
24 changes: 14 additions & 10 deletions src/storage/rdb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,13 @@ StatusOr<std::string> RDB::loadEncodedString() {
}

// Normal string
std::vector<char> vec(len);
GET_OR_RET(stream_->Read(vec.data(), len));
return std::string(vec.data(), len);
if (len == 0) {
return "";
}
std::string read_string;
read_string.resize(len);
GET_OR_RET(stream_->Read(read_string.data(), len));
return read_string;
}

StatusOr<std::vector<std::string>> RDB::LoadListWithQuickList(int type) {
Expand All @@ -196,19 +200,19 @@ StatusOr<std::vector<std::string>> RDB::LoadListWithQuickList(int type) {

if (container == QuickListNodeContainerPlain) {
auto element = GET_OR_RET(loadEncodedString());
list.push_back(element);
list.push_back(std::move(element));
continue;
}

auto encoded_string = GET_OR_RET(loadEncodedString());
if (type == RDBTypeListQuickList2) {
ListPack lp(encoded_string);
auto elements = GET_OR_RET(lp.Entries());
list.insert(list.end(), elements.begin(), elements.end());
list.insert(list.end(), std::make_move_iterator(elements.begin()), std::make_move_iterator(elements.end()));
} else {
ZipList zip_list(encoded_string);
auto elements = GET_OR_RET(zip_list.Entries());
list.insert(list.end(), elements.begin(), elements.end());
list.insert(list.end(), std::make_move_iterator(elements.begin()), std::make_move_iterator(elements.end()));
}
}
return list;
Expand All @@ -222,7 +226,7 @@ StatusOr<std::vector<std::string>> RDB::LoadListObject() {
}
for (size_t i = 0; i < len; i++) {
auto element = GET_OR_RET(loadEncodedString());
list.push_back(element);
list.push_back(std::move(element));
}
return list;
}
Expand All @@ -241,7 +245,7 @@ StatusOr<std::vector<std::string>> RDB::LoadSetObject() {
}
for (size_t i = 0; i < len; i++) {
auto element = GET_OR_RET(LoadStringObject());
set.push_back(element);
set.push_back(std::move(element));
}
return set;
}
Expand All @@ -268,7 +272,7 @@ StatusOr<std::map<std::string, std::string>> RDB::LoadHashObject() {
for (size_t i = 0; i < len; i++) {
auto field = GET_OR_RET(LoadStringObject());
auto value = GET_OR_RET(LoadStringObject());
hash[field] = value;
hash[field] = std::move(value);
}
return hash;
}
Expand Down Expand Up @@ -471,7 +475,7 @@ Status RDB::saveRdbObject(int type, const std::string &key, const RedisObjValue
const auto &member_scores = std::get<std::vector<MemberScore>>(obj);
redis::ZSet zset_db(storage_, ns_);
uint64_t count = 0;
db_status = zset_db.Add(key, ZAddFlags(0), (redis::ZSet::MemberScores *)&member_scores, &count);
db_status = zset_db.Add(key, ZAddFlags(0), const_cast<std::vector<MemberScore> *>(&member_scores), &count);
} else if (type == RDBTypeHash || type == RDBTypeHashListPack || type == RDBTypeHashZipList ||
type == RDBTypeHashZipMap) {
const auto &entries = std::get<std::map<std::string, std::string>>(obj);
Expand Down
11 changes: 5 additions & 6 deletions tests/cppunit/rdb_stream_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
TEST(RdbFileStreamOpenTest, FileNotExist) {
RdbFileStream reader("not_exist.rdb");
ASSERT_FALSE(reader.Open().IsOK());
;
}

TEST(RdbFileStreamOpenTest, FileExist) {
Expand All @@ -51,18 +50,18 @@ TEST(RdbFileStreamReadTest, ReadRdb) {
ASSERT_TRUE(reader.Open().IsOK());

char buf[16] = {0};
ASSERT_EQ(reader.Read(buf, 5).GetValue(), 5);
ASSERT_TRUE(reader.Read(buf, 5).IsOK());
ASSERT_EQ(strncmp(buf, "REDIS", 5), 0);
size -= 5;

auto len = static_cast<std::streamsize>(sizeof(buf) / sizeof(buf[0]));
while (size >= len) {
ASSERT_EQ(reader.Read(buf, len).GetValue(), len);
ASSERT_TRUE(reader.Read(buf, len).IsOK());
size -= len;
}

if (size > 0) {
ASSERT_EQ(reader.Read(buf, size).GetValue(), size);
ASSERT_TRUE(reader.Read(buf, size).IsOK());
}
}

Expand All @@ -80,11 +79,11 @@ TEST(RdbFileStreamReadTest, ReadRdbLittleChunk) {
auto len = static_cast<std::streamsize>(sizeof(buf) / sizeof(buf[0]));

while (size >= len) {
ASSERT_EQ(reader.Read(buf, len).GetValue(), len);
ASSERT_TRUE(reader.Read(buf, len).IsOK());
size -= len;
}

if (size > 0) {
ASSERT_EQ(reader.Read(buf, size).GetValue(), size);
ASSERT_TRUE(reader.Read(buf, size).IsOK());
}
}

0 comments on commit a618574

Please sign in to comment.