Skip to content

Commit

Permalink
fix compile errors
Browse files Browse the repository at this point in the history
  • Loading branch information
JiayiFeng committed Feb 7, 2018
1 parent b00cae6 commit c1349d9
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 17 deletions.
2 changes: 2 additions & 0 deletions paddle/framework/reader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ void ShuffleReader::ReadNext(std::vector<LoDTensor>* out) {
break;
}
}
// TODO(fengjiayi): 'std::random_shuffle' can be very slow. It needs to be
// optimize.
std::random_shuffle(buffer_.begin(), buffer_.end());
iteration_pos_ = 0;
}
Expand Down
11 changes: 9 additions & 2 deletions paddle/framework/reader.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ class ReaderBase {
virtual void ReadNext(std::vector<LoDTensor>* out) = 0;
virtual bool HasNext() const = 0;

virtual void ReInit() = 0;

DDim shape(size_t idx) const;
std::vector<DDim> shapes() const { return shapes_; }
void set_shapes(const std::vector<DDim>& shapes) { shapes_ = shapes; }
Expand All @@ -52,16 +54,18 @@ class DecoratedReader : public ReaderBase {

bool HasNext() const override { return reader_->HasNext(); }

void ReInit() override { reader_->ReInit(); }

protected:
ReaderBase* reader_;
};

// file readers

template <typename T>
class RandomReader : public FileReader {
class RandomDataGenerator : public FileReader {
public:
RandomReader(const std::vector<DDim>& shapes, float min, float max)
RandomDataGenerator(const std::vector<DDim>& shapes, float min, float max)
: FileReader(shapes), min_(min), max_(max) {
PADDLE_ENFORCE_LE(
min, max, "'min' shouldn't be greater than 'max'.(%f vs %f)", min, max);
Expand Down Expand Up @@ -91,6 +95,8 @@ class RandomReader : public FileReader {

bool HasNext() const override { return true; }

void ReInit() override { return; }

private:
float min_;
float max_;
Expand Down Expand Up @@ -139,6 +145,7 @@ class ReaderHolder {

void ReadNext(std::vector<LoDTensor>* out) { reader_->ReadNext(out); }
bool HasNext() const { return reader_->HasNext(); }
void ReInit() { reader_->ReInit(); }

DDim shape(size_t idx) const { return reader_->shape(idx); }
std::vector<DDim> shapes() const { return reader_->shapes(); }
Expand Down
2 changes: 1 addition & 1 deletion paddle/operators/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
foreach(src ${GENERAL_OPS})
op_library(${src})
endforeach()
file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\nUSE_NO_KERNEL_OP(create_random_reader);\n")
file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\nUSE_NO_KERNEL_OP(create_random_data_generator);\n")

set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library")

Expand Down
22 changes: 12 additions & 10 deletions paddle/operators/create_reader_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
namespace paddle {
namespace operators {

std::vector<framework::DDim> RestoreShapes(const std::vector<int>& shape_concat,
const std::vector<int>& ranks) {
static std::vector<framework::DDim> RestoreShapes(
const std::vector<int>& shape_concat, const std::vector<int>& ranks) {
std::vector<framework::DDim> res;
int offset = 0;
for (int len : ranks) {
Expand Down Expand Up @@ -69,7 +69,7 @@ class CreateReaderInferVarType : public framework::VarTypeInference {
};

template <typename T>
class CreateRandomReaderOp : public framework::OperatorBase {
class CreateRandomDataGeneratorOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
void Run(const framework::Scope& scope,
Expand All @@ -84,14 +84,15 @@ class CreateRandomReaderOp : public framework::OperatorBase {
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks);
auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
out->Reset(new framework::RandomReader<T>(shapes, Attr<float>("min"),
Attr<float>("max")));
out->Reset(new framework::RandomDataGenerator<T>(shapes, Attr<float>("min"),
Attr<float>("max")));
}
};

class CreateRandomReaderOpMaker : public framework::OpProtoAndCheckerMaker {
class CreateRandomDataGeneratorOpMaker
: public framework::OpProtoAndCheckerMaker {
public:
CreateRandomReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
CreateRandomDataGeneratorOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(op_proto, op_checker) {
AddOutput("Out", "(ReaderHolder) The created random reader.");
AddAttr<std::vector<int>>("shape_concat",
Expand All @@ -107,7 +108,7 @@ class CreateRandomReaderOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<float>("min", "The lower bound of reader's uniform distribution.");
AddAttr<float>("max", "The upper bound of reader's uniform distribution.");
AddComment(R"DOC(
CreateRandomReader Operator
CreateRandomDataGenerator Operator
This Op creates a random reader.
The reader generates random data instead of really reading from files.
Expand Down Expand Up @@ -186,9 +187,10 @@ class CreateBatchReaderOpMaker : public framework::OpProtoAndCheckerMaker {
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(create_random_reader, ops::CreateRandomReaderOp<float>,
REGISTER_OPERATOR(create_random_data_generator,
ops::CreateRandomDataGeneratorOp<float>,
ops::CreateFileReaderInferShape,
ops::CreateRandomReaderOpMaker,
ops::CreateRandomDataGeneratorOpMaker,
paddle::framework::EmptyGradOpMaker,
ops::CreateReaderInferVarType);
REGISTER_OPERATOR(create_shuffle_reader, ops::CreateShuffleReaderOp,
Expand Down
5 changes: 4 additions & 1 deletion paddle/operators/read_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,10 @@ class ReadOp : public framework::OperatorBase {
framework::ReaderHolder* reader =
scope.FindVar(Input("Reader"))->GetMutable<framework::ReaderHolder>();
if (!reader->HasNext()) {
return;
reader->ReInit();
PADDLE_ENFORCE(
reader->HasNext(),
"Reader can not read the next data even it has been re-initialized.");
}
std::vector<std::string> out_arg_names = Outputs("Out");
std::vector<framework::LoDTensor> ins;
Expand Down
6 changes: 3 additions & 3 deletions python/paddle/v2/fluid/tests/test_cpp_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@
block = prog.current_block()

random_reader = block.create_var(
type=fluid.core.VarDesc.VarType.READER, name="RandomReader")
type=fluid.core.VarDesc.VarType.READER, name="RandomDataGenerator")
random_reader.desc.set_lod_levels([0, 0])

create_random_reader_op = block.append_op(
type="create_random_reader",
create_random_data_generator_op = block.append_op(
type="create_random_data_generator",
outputs={"Out": random_reader},
attrs={
"shape_concat": [1, 2, 1, 1],
Expand Down

0 comments on commit c1349d9

Please sign in to comment.