Skip to content

Commit

Permalink
Merge pull request #11116 from reyoung/feature/faster_recordio
Browse files Browse the repository at this point in the history
Faster RecordIO Scanner
  • Loading branch information
reyoung authored Jun 5, 2018
2 parents c8d6c1d + c3632b8 commit 78afcbf
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 42 deletions.
58 changes: 37 additions & 21 deletions paddle/fluid/recordio/chunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,40 +119,56 @@ bool Chunk::Write(std::ostream& os, Compressor ct) const {
}

// Parses one chunk from `sin` into this Chunk.
//
// Builds a ChunkParser over the stream; if its Init() fails, returns false
// without touching the current contents. Otherwise clears this Chunk and
// appends every record the parser yields, then returns true.
//
// NOTE(review): this span looks like a rendered diff with removed and added
// lines interleaved (no +/- markers). The two lines declaring `hdr` and `ok`
// appear to be the removed side, since nothing below uses them -- confirm
// against the real file.
bool Chunk::Parse(std::istream& sin) {
Header hdr;
bool ok = hdr.Parse(sin);
ChunkParser parser(sin);
if (!parser.Init()) {
return false;
}
// Only discard the previous records once the new chunk header was accepted.
Clear();
while (parser.HasNext()) {
Add(parser.Next());
}
return true;
}

// Binds the parser to `sin` by reference. The constructor body is empty, so
// nothing is read from the stream here; the stream must outlive the parser.
ChunkParser::ChunkParser(std::istream& sin)
    : in_(sin) {
}
// Prepares the parser to iterate over the chunk that starts at the current
// position of in_.
//
// Resets the record cursor (pos_), parses the chunk header, checks the
// header's checksum against a CRC32 computed over the payload, seeks back to
// the start of the payload, and for Snappy-compressed chunks installs a
// decompressing stream in compressed_stream_.
//
// Returns false when the header cannot be parsed, true otherwise. A checksum
// mismatch or an unknown compressor goes through PADDLE_ENFORCE_EQ /
// PADDLE_THROW (presumably raising -- confirm against the macro definitions).
//
// NOTE(review): this span looks like a rendered diff with removed and added
// lines interleaved. The lines that use `sin`, `hdr` and the local
// `compressed_stream` appear to be the removed side (those names are not
// declared in this function) -- confirm against the real file.
// NOTE(review): the kNoCompress branch does not reset compressed_stream_, so
// if Init() is called again after a Snappy chunk, Next() would keep reading
// through the stale decompressor -- verify whether mixed chunks can occur.
bool ChunkParser::Init() {
pos_ = 0;
bool ok = header_.Parse(in_);
if (!ok) {
return ok;
}
auto beg_pos = sin.tellg();
uint32_t crc = Crc32Stream(sin, hdr.CompressSize());
PADDLE_ENFORCE_EQ(hdr.Checksum(), crc);
Clear();
sin.seekg(beg_pos, sin.beg);
std::unique_ptr<std::istream> compressed_stream;
switch (hdr.CompressType()) {
// Remember where the payload starts so the stream can be rewound after the
// checksum pass (presumably Crc32Stream consumes CompressSize() bytes).
auto beg_pos = in_.tellg();
uint32_t crc = Crc32Stream(in_, header_.CompressSize());
PADDLE_ENFORCE_EQ(header_.Checksum(), crc);
in_.seekg(beg_pos, in_.beg);

switch (header_.CompressType()) {
case Compressor::kNoCompress:
break;
case Compressor::kSnappy:
compressed_stream.reset(new snappy::iSnappyStream(sin));
compressed_stream_.reset(new snappy::iSnappyStream(in_));
break;
default:
PADDLE_THROW("Not implemented");
}
return true;
}

std::istream& stream = compressed_stream ? *compressed_stream : sin;
// Reports whether the current chunk still has unread records, i.e. the record
// cursor pos_ is still below the record count stored in the chunk header.
bool ChunkParser::HasNext() const {
  return header_.NumRecords() > pos_;
}

for (uint32_t i = 0; i < hdr.NumRecords(); ++i) {
uint32_t rec_len;
stream.read(reinterpret_cast<char*>(&rec_len), sizeof(uint32_t));
std::string buf;
buf.resize(rec_len);
stream.read(&buf[0], rec_len);
PADDLE_ENFORCE_EQ(rec_len, stream.gcount());
Add(buf);
// Returns the next record of the current chunk, or an empty string when
// HasNext() is false.
//
// Advances pos_, then reads a uint32_t length prefix (raw bytes copied
// straight into the integer) followed by that many payload bytes, from the
// decompressing stream when one is installed and from in_ otherwise.
//
// NOTE(review): an exhausted parser and a genuinely empty record both come
// back as "", so callers should consult HasNext() first.
// NOTE(review): the read of the length prefix is not checked; only the
// payload read is validated through gcount().
// NOTE(review): this span looks like a rendered diff; the `return true;`
// line appears to be a removed-side line (the function returns std::string)
// -- confirm against the real file.
std::string ChunkParser::Next() {
if (!HasNext()) {
return "";
}
return true;
++pos_;
std::istream& stream = compressed_stream_ ? *compressed_stream_ : in_;
uint32_t rec_len;
stream.read(reinterpret_cast<char*>(&rec_len), sizeof(uint32_t));
std::string buf;
buf.resize(rec_len);
stream.read(&buf[0], rec_len);
// The last read must have delivered exactly rec_len bytes.
PADDLE_ENFORCE_EQ(rec_len, stream.gcount());
return buf;
}

} // namespace recordio
} // namespace paddle
16 changes: 14 additions & 2 deletions paddle/fluid/recordio/chunk.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

#pragma once
#include <memory>
#include <string>
#include <vector>

Expand Down Expand Up @@ -53,9 +54,20 @@ class Chunk {
DISABLE_COPY_AND_ASSIGN(Chunk);
};

size_t CompressData(const char* in, size_t in_length, Compressor ct, char* out);
// Reads the records of a chunk one at a time from an input stream.
//
// Usage: construct over a stream, call Init() to read and validate the chunk
// header, then alternate HasNext() / Next() until HasNext() returns false.
// The parser keeps only a reference to the stream, so the stream must outlive
// the parser.
//
// NOTE(review): this span looks like a rendered diff; the DeflateData
// declaration in the middle of the class appears to be a removed-side line
// of the old header -- confirm against the real file.
class ChunkParser {
public:
explicit ChunkParser(std::istream& sin);

// Parses the chunk header at the stream's current position; returns false
// if the header cannot be parsed.
bool Init();
// Returns the next record, or "" when no record is left in the chunk.
std::string Next();
// True while fewer records have been read than the header announces.
bool HasNext() const;

void DeflateData(const char* in, size_t in_length, Compressor ct, char* out);
private:
// Header of the chunk currently being read.
Header header_;
// Number of records already returned from the current chunk.
uint32_t pos_{0};
// Underlying input stream (not owned).
std::istream& in_;
// Decompressing wrapper over in_; null when the data is read directly.
std::unique_ptr<std::istream> compressed_stream_;
};

} // namespace recordio
} // namespace paddle
26 changes: 12 additions & 14 deletions paddle/fluid/recordio/scanner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,35 +22,33 @@ namespace paddle {
namespace recordio {

// Constructs a scanner that takes ownership of `stream` and immediately
// positions itself at the first chunk via Reset().
//
// NOTE(review): this span looks like a rendered diff; the two initializer
// lists below appear to be the removed side and the added side of the same
// line. In the added one, parser_ is bound to *stream_, which relies on
// stream_ being declared before parser_ in the class (the scanner.h fragment
// on this page shows that order) -- confirm against the real file.
Scanner::Scanner(std::unique_ptr<std::istream> &&stream)
: stream_(std::move(stream)) {
: stream_(std::move(stream)), parser_(*stream_) {
Reset();
}

// Constructs a scanner over the file `filename` by opening it with
// std::ifstream, then positions itself at the first chunk via Reset().
//
// NOTE(review): this span looks like a rendered diff; the first signature
// with its stream_.reset(...) body line appears to be the removed side, and
// the second signature with the member-initializer list the added side --
// confirm against the real file.
// NOTE(review): no check that the file was opened successfully is visible
// here.
Scanner::Scanner(const std::string &filename) {
stream_.reset(new std::ifstream(filename));
Scanner::Scanner(const std::string &filename)
: stream_(new std::ifstream(filename)), parser_(*stream_) {
Reset();
}

// Rewinds the scanner to the beginning of the stream.
//
// Clears the stream's error/eof state first so that the seek back to offset
// 0 takes effect, then prepares reading of the first chunk.
//
// NOTE(review): this span looks like a rendered diff; ParseNextChunk()
// appears to be the removed-side call and parser_.Init() the added-side one
// -- confirm against the real file.
// NOTE(review): the boolean result of parser_.Init() is ignored here.
void Scanner::Reset() {
stream_->clear();
stream_->seekg(0, std::ios::beg);
ParseNextChunk();
parser_.Init();
}

// Returns the next record of the stream.
//
// NOTE(review): this span looks like a rendered diff in which the removed
// implementation (eof_ / cur_chunk_ / offset_ and the whole
// Scanner::ParseNextChunk definition) is interleaved with the added one --
// confirm against the real file. Read on its own, the added side does the
// following:
//   - returns "" when the underlying stream is already at eof;
//   - otherwise takes the next record from parser_;
//   - if that exhausted the current chunk and the stream is not at eof,
//     calls parser_.Init() to start on the following chunk;
//   - returns the record.
std::string Scanner::Next() {
PADDLE_ENFORCE(!eof_, "StopIteration");
auto rec = cur_chunk_.Record(offset_++);
if (offset_ == cur_chunk_.NumRecords()) {
ParseNextChunk();
if (stream_->eof()) {
return "";
}
return rec;
}

void Scanner::ParseNextChunk() {
eof_ = !cur_chunk_.Parse(*stream_);
offset_ = 0;
auto res = parser_.Next();
if (!parser_.HasNext() && HasNext()) {
parser_.Init();
}
return res;
}

// Reports whether more records may be available.
//
// NOTE(review): this span looks like a rendered diff; the first definition
// (based on the eof_ flag) appears to be the removed side and the second
// (based on the stream's eof state) the added side -- confirm against the
// real file.
// NOTE(review): the stream's eof bit is only set after a read has hit the
// end, so the stream-based version can still return true when no record is
// left -- verify how callers handle the "" returned by Next().
bool Scanner::HasNext() const { return !eof_; }
bool Scanner::HasNext() const { return !stream_->eof(); }
} // namespace recordio
} // namespace paddle
6 changes: 1 addition & 5 deletions paddle/fluid/recordio/scanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,7 @@ class Scanner {

private:
std::unique_ptr<std::istream> stream_;
Chunk cur_chunk_;
size_t offset_;
bool eof_;

void ParseNextChunk();
ChunkParser parser_;
};
} // namespace recordio
} // namespace paddle

0 comments on commit 78afcbf

Please sign in to comment.