-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fundamental Data Reading in C++ #8009
Changes from all commits
1acad21
f32ca63
d8cc21d
93cab64
1696cb0
3dfd1da
53e697c
da8a56e
6e6f5c7
1010e39
0bb9c80
542bdef
b00cae6
c1349d9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/framework/reader.h"

#include <algorithm>
#include <random>
|
||
namespace paddle { | ||
namespace framework { | ||
|
||
// Returns the static shape of the idx-th output; aborts via enforce if
// idx is out of range.
DDim ReaderBase::shape(size_t idx) const {
  const size_t shape_num = shapes_.size();
  PADDLE_ENFORCE_LT(
      idx, shape_num,
      "Cannot get the %d'th shape, 'shapes_' only has %d elements.", idx,
      shape_num);
  return shapes_[idx];
}
|
||
void ShuffleReader::ReadNext(std::vector<LoDTensor>* out) { | ||
if (iteration_pos_ >= buffer_.size()) { | ||
// Reload buffer with new data | ||
buffer_.clear(); | ||
buffer_.reserve(buffer_size_); | ||
for (int i = 0; i < buffer_size_; ++i) { | ||
if (reader_->HasNext()) { | ||
buffer_.push_back(std::vector<LoDTensor>()); | ||
reader_->ReadNext(&buffer_.back()); | ||
} else { | ||
break; | ||
} | ||
} | ||
// TODO(fengjiayi): 'std::random_shuffle' can be very slow. It needs to be | ||
// optimize. | ||
std::random_shuffle(buffer_.begin(), buffer_.end()); | ||
iteration_pos_ = 0; | ||
} | ||
out->clear(); | ||
if (!buffer_.empty()) { | ||
std::swap(*out, buffer_[iteration_pos_++]); | ||
} | ||
// if buffer_ is empty, the 'out' will return as an empty vector. | ||
} | ||
|
||
// Gathers up to 'batch_size_' instances from the underlying reader and
// concatenates them column-wise: the j-th LoDTensor of every instance is
// merged (along dim 0) into the j-th output LoDTensor. LoD levels are
// merged accordingly, with an extra top level marking instance boundaries.
void BatchReader::ReadNext(std::vector<LoDTensor>* out) {
  buffer_.clear();
  buffer_.reserve(batch_size_);
  for (int i = 0; i < batch_size_; ++i) {
    if (reader_->HasNext()) {
      buffer_.push_back(std::vector<LoDTensor>());
      reader_->ReadNext(&buffer_.back());
    } else {
      // Underlying reader drained: the last batch may be smaller.
      break;
    }
  }
  // Concat instances
  out->clear();
  if (buffer_.empty()) {
    // if buffer_ is empty, the 'out' will return as an empty vector.
    return;
  }
  // Number of outputs per instance; assumes every instance has the same
  // number of LoDTensors as the first one.
  int out_num = buffer_[0].size();
  out->reserve(out_num);
  for (int j = 0; j < out_num; ++j) {
    // Merge shape and check data type: all instances must agree on type and
    // on every dim except dim 0, which is summed to form the batch dim.
    std::type_index batch_type = buffer_[0][j].type();
    DDim batch_shape = buffer_[0][j].dims();
    for (size_t i = 1; i < buffer_.size(); ++i) {
      std::type_index ins_type = buffer_[i][j].type();
      DDim ins_shape = buffer_[i][j].dims();
      PADDLE_ENFORCE_EQ(batch_type, ins_type);
      PADDLE_ENFORCE_EQ(slice_ddim(batch_shape, 1, batch_shape.size()),
                        slice_ddim(ins_shape, 1, ins_shape.size()));
      PADDLE_ENFORCE_GT(ins_shape[0], 0);
      batch_shape[0] += ins_shape[0];
    }

    LoDTensor out_tensor;
    out_tensor.Resize(batch_shape);
    // Allocation happens on CPU; copies below target CPUPlace as well.
    out_tensor.mutable_data(platform::CPUPlace(), batch_type);
    int64_t dst_offset = 0;

    // Merge lod and data
    LoD batch_lod;
    // Top level records one entry per instance (0-based cumulative counts).
    std::vector<size_t> top_level_lod({0});
    for (size_t i = 0; i < buffer_.size(); ++i) {
      DDim ins_shape = buffer_[i][j].dims();
      LoD ins_lod = buffer_[i][j].lod();
      if (i == 0) {
        batch_lod = ins_lod;
      } else {
        PADDLE_ENFORCE_EQ(batch_lod.size(), ins_lod.size());
        for (size_t level_idx = 0; level_idx < batch_lod.size(); ++level_idx) {
          auto& lod_level = batch_lod[level_idx];
          // Skip ins_lod[level_idx][0] (always 0) and shift the remaining
          // offsets by the accumulated length so far.
          for (size_t k = 1; k < ins_lod[level_idx].size(); ++k) {
            lod_level.push_back(ins_lod[level_idx][k] + lod_level.back());
          }
        }
      }
      // Without LoD an instance contributes ins_shape[0] rows; with LoD it
      // contributes its top-level sequence count.
      top_level_lod.push_back(
          top_level_lod.back() +
          (ins_lod.empty() ? ins_shape[0] : (ins_lod[0].size() - 1)));

      // Copy this instance's rows into its slice of the batch tensor.
      Tensor dst = out_tensor.Slice(dst_offset, dst_offset + ins_shape[0]);
      Copy(buffer_[i][j], platform::CPUPlace(), &dst);
      dst_offset += ins_shape[0];
    }
    batch_lod.insert(batch_lod.begin(), top_level_lod);
    out_tensor.set_lod(batch_lod);
    out->push_back(out_tensor);
  }
}
} // namespace framework | ||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#pragma once | ||
|
||
#include "paddle/framework/ddim.h" | ||
#include "paddle/framework/lod_tensor_array.h" | ||
|
||
namespace paddle { | ||
namespace framework { | ||
|
||
// Abstract base for all readers. On each ReadNext() call a reader yields a
// vector of LoDTensors (one per output) and advertises the static shapes of
// those outputs through shapes()/shape().
class ReaderBase {
 public:
  // 'shapes' supplies one DDim per output; it must not be empty.
  explicit ReaderBase(const std::vector<DDim>& shapes) : shapes_(shapes) {
    PADDLE_ENFORCE(!shapes_.empty());
  }
  // Fills 'out' with the next group of data (one LoDTensor per output).
  virtual void ReadNext(std::vector<LoDTensor>* out) = 0;
  // Whether another ReadNext() call would yield data.
  virtual bool HasNext() const = 0;

  // Resets the reader to the beginning of its data source.
  virtual void ReInit() = 0;

  // Bounds-checked accessor for a single output shape (defined in reader.cc).
  DDim shape(size_t idx) const;
  std::vector<DDim> shapes() const { return shapes_; }
  void set_shapes(const std::vector<DDim>& shapes) { shapes_ = shapes; }

  virtual ~ReaderBase() {}

 protected:
  // One shape per reader output.
  std::vector<DDim> shapes_;
};
|
||
// Base class for readers that produce data themselves (e.g. from files),
// as opposed to DecoratedReader, which wraps another reader.
class FileReader : public ReaderBase {
 public:
  explicit FileReader(const std::vector<DDim>& shapes) : ReaderBase(shapes) {}
};
|
||
class DecoratedReader : public ReaderBase { | ||
public: | ||
explicit DecoratedReader(ReaderBase* reader) | ||
: ReaderBase(reader->shapes()), reader_(reader) { | ||
PADDLE_ENFORCE_NOT_NULL(reader_); | ||
} | ||
|
||
bool HasNext() const override { return reader_->HasNext(); } | ||
|
||
void ReInit() override { reader_->ReInit(); } | ||
|
||
protected: | ||
ReaderBase* reader_; | ||
}; | ||
|
||
// file readers | ||
|
||
template <typename T> | ||
class RandomDataGenerator : public FileReader { | ||
public: | ||
RandomDataGenerator(const std::vector<DDim>& shapes, float min, float max) | ||
: FileReader(shapes), min_(min), max_(max) { | ||
PADDLE_ENFORCE_LE( | ||
min, max, "'min' shouldn't be greater than 'max'.(%f vs %f)", min, max); | ||
unsigned int seed = std::random_device()(); | ||
engine_.seed(seed); | ||
dist_ = std::uniform_real_distribution<float>(min_, max_); | ||
} | ||
|
||
void ReadNext(std::vector<LoDTensor>* out) override { | ||
out->clear(); | ||
out->reserve(shapes_.size()); | ||
for (const DDim& shape : shapes_) { | ||
PADDLE_ENFORCE_GE( | ||
shape.size(), 2, | ||
"The rank of reader's output data should be 2 at least.(Now it's %d)", | ||
shape.size()); | ||
LoDTensor out_tensor; | ||
out_tensor.Resize(shape); | ||
T* data = out_tensor.mutable_data<T>(platform::CPUPlace()); | ||
int64_t numel = product(shape); | ||
for (int64_t i = 0; i < numel; ++i) { | ||
data[i] = dist_(engine_); | ||
} | ||
out->push_back(out_tensor); | ||
} | ||
} | ||
|
||
bool HasNext() const override { return true; } | ||
|
||
void ReInit() override { return; } | ||
|
||
private: | ||
float min_; | ||
float max_; | ||
std::minstd_rand engine_; | ||
std::uniform_real_distribution<float> dist_; | ||
}; | ||
|
||
// decorated readers | ||
|
||
// Wraps another reader and yields its instances in shuffled order: up to
// 'buffer_size' instances are pre-read into 'buffer_', shuffled, and handed
// out one by one (implementation in reader.cc).
class ShuffleReader : public DecoratedReader {
 public:
  ShuffleReader(ReaderBase* reader, int buffer_size)
      : DecoratedReader(reader), buffer_size_(buffer_size), iteration_pos_(0) {
    buffer_.reserve(buffer_size);
  }

  void ReadNext(std::vector<LoDTensor>* out) override;

 private:
  // Maximum number of instances held for shuffling at once.
  int buffer_size_;
  // Pre-read, shuffled instances (each instance is a vector of LoDTensors).
  std::vector<std::vector<LoDTensor>> buffer_;
  // Index of the next instance to hand out from buffer_.
  size_t iteration_pos_;
};
|
||
// Wraps another reader, gathers up to 'batch_size' instances from it, and
// concatenates them into batched LoDTensors (implementation in reader.cc).
// The last batch may be smaller when the underlying reader runs out.
class BatchReader : public DecoratedReader {
 public:
  BatchReader(ReaderBase* reader, int batch_size)
      : DecoratedReader(reader), batch_size_(batch_size) {
    buffer_.reserve(batch_size_);
  }

  void ReadNext(std::vector<LoDTensor>* out) override;

 private:
  // Number of instances concatenated per batch.
  int batch_size_;
  // Instances collected for the current batch before concatenation.
  std::vector<std::vector<LoDTensor>> buffer_;
};
|
||
// The ReaderHolder is used as readers' unified wrapper, | ||
// making it easier to access different type readers in Variables. | ||
class ReaderHolder { | ||
public: | ||
void Reset(ReaderBase* reader) { reader_.reset(reader); } | ||
|
||
ReaderBase* Get() const { return reader_.get(); } | ||
|
||
void ReadNext(std::vector<LoDTensor>* out) { reader_->ReadNext(out); } | ||
bool HasNext() const { return reader_->HasNext(); } | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For reader, maybe we should add a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes. And the further discussion is needed about what shall we exactly do when a reader finishes one pass reading. And |
||
void ReInit() { reader_->ReInit(); } | ||
|
||
DDim shape(size_t idx) const { return reader_->shape(idx); } | ||
std::vector<DDim> shapes() const { return reader_->shapes(); } | ||
void set_shapes(const std::vector<DDim>& shapes) { | ||
reader_->set_shapes(shapes); | ||
} | ||
|
||
private: | ||
std::unique_ptr<ReaderBase> reader_; | ||
}; | ||
|
||
} // namespace framework | ||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,205 @@ | ||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/framework/op_registry.h" | ||
#include "paddle/framework/reader.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
static std::vector<framework::DDim> RestoreShapes( | ||
const std::vector<int>& shape_concat, const std::vector<int>& ranks) { | ||
std::vector<framework::DDim> res; | ||
int offset = 0; | ||
for (int len : ranks) { | ||
auto start_it = shape_concat.begin() + offset; | ||
auto end_it = start_it + len; | ||
res.push_back(framework::make_ddim(std::vector<int>(start_it, end_it))); | ||
offset += len; | ||
} | ||
return res; | ||
} | ||
|
||
// general infershape for file readers | ||
class CreateFileReaderInferShape : public framework::InferShapeBase { | ||
public: | ||
void operator()(framework::InferShapeContext* ctx) const override { | ||
PADDLE_ENFORCE(ctx->HasOutput("Out"), | ||
"The output file reader should not be null."); | ||
const auto shape_concat = | ||
ctx->Attrs().Get<std::vector<int>>("shape_concat"); | ||
const auto ranks = ctx->Attrs().Get<std::vector<int>>("ranks"); | ||
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks); | ||
ctx->SetReaderDims("Out", shapes); | ||
} | ||
}; | ||
|
||
// general infershape for decorated readers | ||
class CreateDecoratedReaderInferShape : public framework::InferShapeBase { | ||
public: | ||
void operator()(framework::InferShapeContext* ctx) const override { | ||
PADDLE_ENFORCE(ctx->HasInput("UnderlyingReader"), | ||
"Input(UnderlyingReader) should not be null."); | ||
PADDLE_ENFORCE(ctx->HasOutput("Out"), | ||
"The output decorated reader should not be null."); | ||
ctx->SetReaderDims("Out", ctx->GetReaderDims("UnderlyingReader")); | ||
} | ||
}; | ||
|
||
// general var type inference for all readers | ||
class CreateReaderInferVarType : public framework::VarTypeInference { | ||
public: | ||
void operator()(const framework::OpDesc& op_desc, | ||
framework::BlockDesc* block) const override { | ||
std::string reader_name = op_desc.Output("Out")[0]; | ||
framework::VarDesc* reader = block->FindVarRecursive(reader_name); | ||
reader->SetType(framework::proto::VarDesc::READER); | ||
} | ||
}; | ||
|
||
template <typename T> | ||
class CreateRandomDataGeneratorOp : public framework::OperatorBase { | ||
public: | ||
using framework::OperatorBase::OperatorBase; | ||
void Run(const framework::Scope& scope, | ||
const platform::Place& dev_place) const override { | ||
const auto& shape_concat = Attr<std::vector<int>>("shape_concat"); | ||
const auto& ranks = Attr<std::vector<int>>("ranks"); | ||
PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty()); | ||
PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0), | ||
int(shape_concat.size()), | ||
"The accumulate of all ranks should be equal to the " | ||
"shape concat's length."); | ||
std::vector<framework::DDim> shapes = RestoreShapes(shape_concat, ranks); | ||
auto* out = scope.FindVar(Output("Out")) | ||
->template GetMutable<framework::ReaderHolder>(); | ||
out->Reset(new framework::RandomDataGenerator<T>(shapes, Attr<float>("min"), | ||
Attr<float>("max"))); | ||
} | ||
}; | ||
|
||
// Proto/attribute definitions for the create_random_data_generator op.
class CreateRandomDataGeneratorOpMaker
    : public framework::OpProtoAndCheckerMaker {
 public:
  CreateRandomDataGeneratorOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(op_proto, op_checker) {
    AddOutput("Out", "(ReaderHolder) The created random reader.");
    AddAttr<std::vector<int>>("shape_concat",
                              "The concat of all data's shapes.");
    AddAttr<std::vector<int>>(
        "ranks",
        "The ranks of each data."
        "e.g."
        "shape_concat = [2,3,4,5,6]"
        "ranks = [3,2]"
        "It means the reader will generate two data each time,"
        "whose shapes are [2,3,4] and [5,6] respectively.");
    AddAttr<float>("min", "The lower bound of reader's uniform distribution.");
    AddAttr<float>("max", "The upper bound of reader's uniform distribution.");
    AddComment(R"DOC(
CreateRandomDataGenerator Operator
This Op creates a random reader.
The reader generates random data instead of really reading from files.
Generated data follow an uniform distribution between 'min' and 'max'.
)DOC");
  }
};
|
||
// Creates a ShuffleReader wrapping the input UnderlyingReader and installs
// it into the output ReaderHolder variable.
class CreateShuffleReaderOp : public framework::OperatorBase {
 public:
  using framework::OperatorBase::OperatorBase;
  void Run(const framework::Scope& scope,
           const platform::Place& dev_place) const override {
    const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
                                        ->Get<framework::ReaderHolder>();
    auto* out = scope.FindVar(Output("Out"))
                    ->template GetMutable<framework::ReaderHolder>();
    // The ShuffleReader borrows the underlying reader; ownership of the
    // ShuffleReader itself passes to the output holder.
    out->Reset(new framework::ShuffleReader(underlying_reader.Get(),
                                            Attr<int>("buffer_size")));
  }
};
|
||
// Proto/attribute definitions for the create_shuffle_reader op.
class CreateShuffleReaderOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  CreateShuffleReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(op_proto, op_checker) {
    AddInput(
        "UnderlyingReader",
        "(ReaderHolder) The underlying reader for creating a shuffle reader.");
    AddOutput("Out", "(ReaderHolder) The created shuffle reader.");
    AddAttr<int>("buffer_size", "The shuffle buffer size.").GreaterThan(0);
    AddComment(R"DOC(
CreateShuffleReader Operator
A shuffle reader takes another reader as its 'underlying reader'
and yields the underlying reader's outputs in a shuffled order.
)DOC");
  }
};
|
||
// Creates a BatchReader wrapping the input UnderlyingReader and installs it
// into the output ReaderHolder variable.
class CreateBatchReaderOp : public framework::OperatorBase {
 public:
  using framework::OperatorBase::OperatorBase;
  void Run(const framework::Scope& scope,
           const platform::Place& dev_place) const override {
    const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
                                        ->Get<framework::ReaderHolder>();
    auto* out = scope.FindVar(Output("Out"))
                    ->template GetMutable<framework::ReaderHolder>();
    // The BatchReader borrows the underlying reader; ownership of the
    // BatchReader itself passes to the output holder.
    out->Reset(new framework::BatchReader(underlying_reader.Get(),
                                          Attr<int>("batch_size")));
  }
};
|
||
// Proto/attribute definitions for the create_batch_reader op.
class CreateBatchReaderOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  CreateBatchReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(op_proto, op_checker) {
    AddInput(
        "UnderlyingReader",
        "(ReaderHolder) The underlying reader for creating a batch reader.");
    AddOutput("Out", "(ReaderHolder) The created batch reader.");
    AddAttr<int>("batch_size",
                 "How many instances the batch reader yields each time.")
        .GreaterThan(0);
    AddComment(R"DOC(
CreateBatchReader Operator
A batch reader takes another reader as its 'underlying reader',
gathers the underlying reader's outputs and then yields them in batches.
)DOC");
  }
};
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators;
// All reader-creation ops share the same pattern: no gradient
// (EmptyGradOpMaker) and an output variable typed as READER
// (CreateReaderInferVarType).
REGISTER_OPERATOR(create_random_data_generator,
                  ops::CreateRandomDataGeneratorOp<float>,
                  ops::CreateFileReaderInferShape,
                  ops::CreateRandomDataGeneratorOpMaker,
                  paddle::framework::EmptyGradOpMaker,
                  ops::CreateReaderInferVarType);
REGISTER_OPERATOR(create_shuffle_reader, ops::CreateShuffleReaderOp,
                  ops::CreateDecoratedReaderInferShape,
                  ops::CreateShuffleReaderOpMaker,
                  paddle::framework::EmptyGradOpMaker,
                  ops::CreateReaderInferVarType);
REGISTER_OPERATOR(create_batch_reader, ops::CreateBatchReaderOp,
                  ops::CreateDecoratedReaderInferShape,
                  ops::CreateBatchReaderOpMaker,
                  paddle::framework::EmptyGradOpMaker,
                  ops::CreateReaderInferVarType);
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/framework/op_registry.h" | ||
#include "paddle/framework/reader.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
// Infers the shapes of the read op's outputs from the input reader's dims:
// one output variable per reader output, shapes copied one-to-one.
class ReadInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("Reader"),
                   "The ReadOp must take a reader as input.");
    PADDLE_ENFORCE(ctx->HasOutputs("Out"),
                   "The ReadOp should be assigned with output.");
    std::vector<framework::DDim> reader_dims = ctx->GetReaderDims("Reader");
    std::vector<std::string> out_names = ctx->Outputs("Out");
    PADDLE_ENFORCE_EQ(
        reader_dims.size(), out_names.size(),
        "The reader's dim number doesn't match the output number.");
    ctx->SetOutputsDim("Out", reader_dims);
  }
};
|
||
// Propagates the reader variable's per-output data types onto the read op's
// output variables, typing each output as a LOD_TENSOR.
class ReadInferVarType : public framework::VarTypeInference {
 public:
  void operator()(const framework::OpDesc& op_desc,
                  framework::BlockDesc* block) const override {
    std::string reader_name = op_desc.Input("Reader")[0];
    std::vector<std::string> out_names = op_desc.Output("Out");
    framework::VarDesc* reader = block->FindVarRecursive(reader_name);
    auto dtypes = reader->GetDataTypes();
    // The reader must declare exactly one dtype per output variable.
    PADDLE_ENFORCE_EQ(dtypes.size(), out_names.size());
    for (size_t i = 0; i < dtypes.size(); ++i) {
      framework::VarDesc& out = block->FindRecursiveOrCreateVar(out_names[i]);
      out.SetType(framework::proto::VarDesc::LOD_TENSOR);
      out.SetDataType(dtypes[i]);
    }
  }
};
|
||
// Pulls one group of instances from the input reader and exposes each
// resulting LoDTensor through the corresponding "Out" variable.
class ReadOp : public framework::OperatorBase {
 public:
  using framework::OperatorBase::OperatorBase;
  void Run(const framework::Scope& scope,
           const platform::Place& dev_place) const override {
    framework::ReaderHolder* reader =
        scope.FindVar(Input("Reader"))->GetMutable<framework::ReaderHolder>();
    if (!reader->HasNext()) {
      // The reader is exhausted: rewind it and retry once. A reader that
      // still has no data after ReInit() is considered broken.
      reader->ReInit();
      PADDLE_ENFORCE(
          reader->HasNext(),
          "Reader can not read the next data even it has been re-initialized.");
    }
    std::vector<std::string> out_arg_names = Outputs("Out");
    std::vector<framework::LoDTensor> ins;
    reader->ReadNext(&ins);
    PADDLE_ENFORCE_EQ(ins.size(), out_arg_names.size());
    for (size_t i = 0; i < ins.size(); ++i) {
      auto* out =
          scope.FindVar(out_arg_names[i])->GetMutable<framework::LoDTensor>();
      // Share the underlying buffer instead of copying; propagate LoD too.
      out->ShareDataWith(ins[i]);
      out->set_lod(ins[i].lod());
    }
  }
};
|
||
// Proto/attribute definitions for the read op.
class ReadOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  ReadOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(op_proto, op_checker) {
    AddInput("Reader", "(ReaderHolder) The executed reader.");
    // "Out" is duplicable: one output variable per reader output.
    AddOutput("Out", "(LoDTensor) The output data.").AsDuplicable();
    AddComment(R"DOC(
Read Operator
Execute a given reader once and output data.
)DOC");
  }
};
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators;
// The read op has no gradient; output var types come from ReadInferVarType.
REGISTER_OPERATOR(read, ops::ReadOp, ops::ReadInferShape, ops::ReadOpMaker,
                  paddle::framework::EmptyGradOpMaker, ops::ReadInferVarType);
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
import numpy as np

# Build a bare program that wires a random-data reader to a 'read' op and
# checks (via the process exit code) that the read op actually yields data.
prog = fluid.framework.Program()
block = prog.current_block()

# A READER-type variable to hold the created RandomDataGenerator.
random_reader = block.create_var(
    type=fluid.core.VarDesc.VarType.READER, name="RandomDataGenerator")
random_reader.desc.set_lod_levels([0, 0])

# shape_concat=[1,2,1,1] with ranks=[2,2] means two outputs, shaped [1, 2]
# and [1, 1]; values are drawn uniformly from [0.0, 1.0].
create_random_data_generator_op = block.append_op(
    type="create_random_data_generator",
    outputs={"Out": random_reader},
    attrs={
        "shape_concat": [1, 2, 1, 1],
        "ranks": [2, 2],
        "min": 0.0,
        "max": 1.0
    })

# Destination LoDTensor variables for the two reader outputs.
out1 = block.create_var(
    type=fluid.core.VarDesc.VarType.LOD_TENSOR,
    name="Out1",
    shape=[10, 2],
    dtype="float32",
    lod_level=1)
out2 = block.create_var(
    type=fluid.core.VarDesc.VarType.LOD_TENSOR,
    name="Out2",
    shape=[10, 1],
    dtype="float32",
    lod_level=1)

# The read op pulls one group of instances from the reader into out1/out2.
read_op = block.append_op(
    type="read",
    inputs={"Reader": random_reader},
    outputs={"Out": [out1, out2]})

place = fluid.CPUPlace()
exe = fluid.Executor(place)

[res1, res2] = exe.run(prog, fetch_list=[out1, out2])

# Fail (non-zero exit) if either fetched output came back empty.
if len(res1) == 0 or len(res2) == 0:
    exit(1)

exit(0)
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

Maybe we can invoke `MergeTensorOp` here.

There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

The `MergeTensorOp` can only merge two `LoDTensor`s (the true-branch output and the false-branch output). However, in `BatchReader` we need to merge far more than two `LoDTensor`s.