Skip to content

Commit

Permalink
Merge pull request #8009 from JiayiFeng/dev_reader
Browse files Browse the repository at this point in the history
Fundamental Data Reading in C++
  • Loading branch information
JiayiFeng authored Feb 7, 2018
2 parents 83df277 + c1349d9 commit 812cf15
Show file tree
Hide file tree
Showing 17 changed files with 789 additions and 37 deletions.
2 changes: 2 additions & 0 deletions paddle/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto)
cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor paddle_memory)
nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor init)

cc_library(reader SRCS reader.cc DEPS lod_tensor ddim)

cc_test(variable_test SRCS variable_test.cc)

cc_library(threadpool SRCS threadpool.cc DEPS enforce)
Expand Down
7 changes: 5 additions & 2 deletions paddle/framework/executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/framework/lod_rank_table.h"
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/reader.h"
#include "paddle/platform/place.h"
#include "paddle/platform/profiler.h"

Expand Down Expand Up @@ -52,11 +53,13 @@ static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
var->GetMutable<LoDTensorArray>();
} else if (var_type == proto::VarDesc::PLACE_LIST) {
var->GetMutable<platform::PlaceList>();
} else if (var_type == proto::VarDesc::READER) {
var->GetMutable<ReaderHolder>();
} else {
PADDLE_THROW(
"Variable type %d is not in "
"[LoDTensor, SelectedRows, FEED_MINIBATCH, FETCH_LIST, LOD_RANK_TABLE,"
" PLACE_LIST]",
"[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
"LOD_RANK_TABLE, PLACE_LIST, READER]",
var_type);
}
}
Expand Down
4 changes: 2 additions & 2 deletions paddle/framework/framework.proto
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ message LoDTensorArrayDesc {
optional int32 lod_level = 2 [ default = 0 ];
}

message Reader { repeated LoDTensorDesc lod_tensor = 1; }
message ReaderDesc { repeated LoDTensorDesc lod_tensor = 1; }

message VarDesc {
enum VarType {
Expand All @@ -136,7 +136,7 @@ message VarDesc {
optional LoDTensorDesc lod_tensor = 4;
optional TensorDesc selected_rows = 5;
optional LoDTensorArrayDesc tensor_array = 6;
optional Reader reader = 7;
optional ReaderDesc reader = 7;
}

message BlockDesc {
Expand Down
42 changes: 36 additions & 6 deletions paddle/framework/op_desc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ class CompileTimeInferShapeContext : public InferShapeContext {

void SetDim(const std::string &name, const DDim &dim) override;

std::vector<DDim> GetRepeatedDims(const std::string &name) const override;

void SetRepeatedDims(const std::string &name,
const std::vector<DDim> &dims) override;

const OpDesc &op_;
const BlockDesc &block_;
};
Expand Down Expand Up @@ -457,23 +462,48 @@ const std::vector<std::string> &CompileTimeInferShapeContext::Outputs(
DDim CompileTimeInferShapeContext::GetDim(const std::string &name) const {
auto var = block_.FindVarRecursive(name);
PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name);
DDim res;
try {
auto shape = var->GetShape();
if (shape.empty()) {
return framework::make_ddim({0UL});
} else {
return framework::make_ddim(var->GetShape());
}
res = shape.empty() ? make_ddim({0UL}) : make_ddim(shape);
} catch (...) {
VLOG(5) << "GetDim of variable " << name << " error";
std::rethrow_exception(std::current_exception());
}
return res;
}

std::vector<DDim> CompileTimeInferShapeContext::GetRepeatedDims(
const std::string &name) const {
auto var = block_.FindVarRecursive(name);
PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name);
std::vector<DDim> res;
try {
auto shapes = var->GetShapes();
for (const auto &s : shapes) {
res.push_back(s.empty() ? make_ddim({0UL}) : make_ddim(s));
}
} catch (...) {
VLOG(5) << "GetRepeatedDim of variable " << name << " error.";
std::rethrow_exception(std::current_exception());
}
return res;
}

void CompileTimeInferShapeContext::SetDim(const std::string &name,
const DDim &dim) {
block_.FindVarRecursive(name)->SetShape(framework::vectorize(dim));
block_.FindVarRecursive(name)->SetShape(vectorize(dim));
}

void CompileTimeInferShapeContext::SetRepeatedDims(
const std::string &name, const std::vector<DDim> &dims) {
auto var = block_.FindVarRecursive(name);
PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name);
std::vector<std::vector<int64_t>> dim_vec(dims.size());
std::transform(dims.begin(), dims.end(), dim_vec.begin(), vectorize);
var->SetShapes(dim_vec);
}

bool CompileTimeInferShapeContext::IsRuntime() const { return false; }

proto::VarDesc::VarType CompileTimeInferShapeContext::GetVarType(
Expand Down
39 changes: 33 additions & 6 deletions paddle/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -320,8 +320,8 @@ class RuntimeInferShapeContext : public InferShapeContext {
if (length == 0) {
return false;
}
PADDLE_ENFORCE_EQ(length, 1UL, "Input %s should have more than one inputs",
name);
PADDLE_ENFORCE_EQ(length, 1UL,
"Input %s should not have more than one inputs", name);
auto ipt = ins[0];
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
return var != nullptr;
Expand All @@ -333,8 +333,8 @@ class RuntimeInferShapeContext : public InferShapeContext {
if (length == 0) {
return false;
}
PADDLE_ENFORCE_EQ(length, 1UL, "Output %s should have more than one inputs",
name);
PADDLE_ENFORCE_EQ(length, 1UL,
"Output %s should not have more than one inputs", name);
auto ipt = outs[0];
auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
return var != nullptr;
Expand Down Expand Up @@ -421,8 +421,22 @@ class RuntimeInferShapeContext : public InferShapeContext {
} else if (var->IsType<SelectedRows>()) {
return var->Get<SelectedRows>().GetCompleteDims();
} else {
PADDLE_THROW("Variable %s type_id %s, expect LoDTensor/SelectedRows.",
name, var->Type().name());
PADDLE_THROW(
"Only LoDTensor/SelectedRows support 'GetDim', but Variable %s's "
"type_id is %s.",
name, var->Type().name());
}
}

std::vector<DDim> GetRepeatedDims(const std::string& name) const override {
Variable* var = scope_.FindVar(name);
if (var->IsType<ReaderHolder>()) {
return var->Get<ReaderHolder>().shapes();
} else {
PADDLE_THROW(
"Only ReaderHolder support 'GetRepeatedDims', but Variable %s's "
"type_id is %s.",
name, var->Type().name());
}
}

Expand All @@ -438,6 +452,19 @@ class RuntimeInferShapeContext : public InferShapeContext {
}
}

void SetRepeatedDims(const std::string& name,
const std::vector<DDim>& dims) override {
Variable* var = scope_.FindVar(name);
if (var->IsType<ReaderHolder>()) {
var->GetMutable<ReaderHolder>()->set_shapes(dims);
} else {
PADDLE_THROW(
"Only ReaderHolder support 'SetRepeatedDims', but Variable %s's "
"type_id is %s.",
name, var->Type().name());
}
}

proto::VarDesc::VarType GetVarType(const std::string& name) const override {
auto* var = scope_.FindVar(name);
return ToVarType(var->Type());
Expand Down
122 changes: 122 additions & 0 deletions paddle/framework/reader.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/framework/reader.h"

namespace paddle {
namespace framework {

DDim ReaderBase::shape(size_t idx) const {
PADDLE_ENFORCE_LT(
idx, shapes_.size(),
"Cannot get the %d'th shape, 'shapes_' only has %d elements.", idx,
shapes_.size());
return shapes_[idx];
}

void ShuffleReader::ReadNext(std::vector<LoDTensor>* out) {
if (iteration_pos_ >= buffer_.size()) {
// Reload buffer with new data
buffer_.clear();
buffer_.reserve(buffer_size_);
for (int i = 0; i < buffer_size_; ++i) {
if (reader_->HasNext()) {
buffer_.push_back(std::vector<LoDTensor>());
reader_->ReadNext(&buffer_.back());
} else {
break;
}
}
// TODO(fengjiayi): 'std::random_shuffle' can be very slow. It needs to be
// optimize.
std::random_shuffle(buffer_.begin(), buffer_.end());
iteration_pos_ = 0;
}
out->clear();
if (!buffer_.empty()) {
std::swap(*out, buffer_[iteration_pos_++]);
}
// if buffer_ is empty, the 'out' will return as an empty vector.
}

void BatchReader::ReadNext(std::vector<LoDTensor>* out) {
buffer_.clear();
buffer_.reserve(batch_size_);
for (int i = 0; i < batch_size_; ++i) {
if (reader_->HasNext()) {
buffer_.push_back(std::vector<LoDTensor>());
reader_->ReadNext(&buffer_.back());
} else {
break;
}
}
// Concat instances
out->clear();
if (buffer_.empty()) {
// if buffer_ is empty, the 'out' will return as an empty vector.
return;
}
int out_num = buffer_[0].size();
out->reserve(out_num);
for (int j = 0; j < out_num; ++j) {
// Merge shape and check date type
std::type_index batch_type = buffer_[0][j].type();
DDim batch_shape = buffer_[0][j].dims();
for (size_t i = 1; i < buffer_.size(); ++i) {
std::type_index ins_type = buffer_[i][j].type();
DDim ins_shape = buffer_[i][j].dims();
PADDLE_ENFORCE_EQ(batch_type, ins_type);
PADDLE_ENFORCE_EQ(slice_ddim(batch_shape, 1, batch_shape.size()),
slice_ddim(ins_shape, 1, ins_shape.size()));
PADDLE_ENFORCE_GT(ins_shape[0], 0);
batch_shape[0] += ins_shape[0];
}

LoDTensor out_tensor;
out_tensor.Resize(batch_shape);
out_tensor.mutable_data(platform::CPUPlace(), batch_type);
int64_t dst_offset = 0;

// Merge lod and data
LoD batch_lod;
std::vector<size_t> top_level_lod({0});
for (size_t i = 0; i < buffer_.size(); ++i) {
DDim ins_shape = buffer_[i][j].dims();
LoD ins_lod = buffer_[i][j].lod();
if (i == 0) {
batch_lod = ins_lod;
} else {
PADDLE_ENFORCE_EQ(batch_lod.size(), ins_lod.size());
for (size_t level_idx = 0; level_idx < batch_lod.size(); ++level_idx) {
auto& lod_level = batch_lod[level_idx];
for (size_t k = 1; k < ins_lod[level_idx].size(); ++k) {
lod_level.push_back(ins_lod[level_idx][k] + lod_level.back());
}
}
}
top_level_lod.push_back(
top_level_lod.back() +
(ins_lod.empty() ? ins_shape[0] : (ins_lod[0].size() - 1)));

Tensor dst = out_tensor.Slice(dst_offset, dst_offset + ins_shape[0]);
Copy(buffer_[i][j], platform::CPUPlace(), &dst);
dst_offset += ins_shape[0];
}
batch_lod.insert(batch_lod.begin(), top_level_lod);
out_tensor.set_lod(batch_lod);
out->push_back(out_tensor);
}
}
} // namespace framework
} // namespace paddle
Loading

0 comments on commit 812cf15

Please sign in to comment.