Skip to content

Commit

Permalink
Add functions of restoring ProgramDescBind from ProgramDesc (#5109)
Browse files Browse the repository at this point in the history
* compelete restoring program_bind from program_desc

* Fix bugs

* fix compile errors

* fix errors and add unit tests

* rename some vars

* Follow comments
  • Loading branch information
JiayiFeng authored Oct 26, 2017
1 parent b1cbdf0 commit aa379cc
Show file tree
Hide file tree
Showing 11 changed files with 173 additions and 22 deletions.
11 changes: 11 additions & 0 deletions paddle/framework/block_desc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,17 @@ BlockDesc *BlockDescBind::Proto() {
Flush();
return desc_;
}

BlockDescBind::BlockDescBind(ProgramDescBind *prog, BlockDesc *desc)
: prog_(prog), desc_(desc), need_update_(false) {
for (const VarDesc &var_desc : desc_->vars()) {
vars_[var_desc.name()].reset(new VarDescBind(var_desc));
}
for (const OpDesc &op_desc : desc_->ops()) {
ops_.emplace_back(new OpDescBind(op_desc, prog));
}
}

BlockDescBind::BlockDescBind(const BlockDescBind &other, BlockDesc *desc,
ProgramDescBind *prog)
: prog_(prog), desc_(desc) {
Expand Down
3 changes: 1 addition & 2 deletions paddle/framework/block_desc.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ class ProgramDescBind;

class BlockDescBind {
public:
BlockDescBind(ProgramDescBind *prog, BlockDesc *desc)
: prog_(prog), desc_(desc), need_update_(false) {}
BlockDescBind(ProgramDescBind *prog, BlockDesc *desc);

BlockDescBind(const BlockDescBind &other, BlockDesc *desc,
ProgramDescBind *prog);
Expand Down
48 changes: 40 additions & 8 deletions paddle/framework/op_desc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,55 @@ limitations under the License. */
#include <unordered_map>
#include "paddle/framework/block_desc.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/program_desc.h"

namespace paddle {
namespace framework {

OpDescBind::OpDescBind(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs,
const AttributeMap &attrs) {
op_desc_.set_type(type);
desc_.set_type(type);
inputs_ = inputs;
outputs_ = outputs;
attrs_ = attrs;
need_update_ = true;
}

OpDescBind::OpDescBind(const OpDesc &desc, ProgramDescBind *prog)
: desc_(desc), need_update_(false) {
// restore inputs_
int input_size = desc_.inputs_size();
for (int i = 0; i < input_size; ++i) {
const OpDesc::Var &var = desc_.inputs(i);
std::vector<std::string> &args = inputs_[var.parameter()];
int argu_size = var.arguments_size();
args.reserve(argu_size);
for (int j = 0; j < argu_size; ++j) {
args.push_back(var.arguments(j));
}
}
// restore outputs_
int output_size = desc_.outputs_size();
for (int i = 0; i < output_size; ++i) {
const OpDesc::Var &var = desc_.outputs(i);
std::vector<std::string> &args = outputs_[var.parameter()];
int argu_size = var.arguments_size();
args.reserve(argu_size);
for (int j = 0; j < argu_size; ++j) {
args.push_back(var.arguments(j));
}
}
// restore attrs_
for (const OpDesc::Attr &attr : desc_.attrs()) {
std::string attr_name = attr.name();
attrs_[attr_name] = GetAttrValue(attr, prog->Proto());
}
}

OpDesc *OpDescBind::Proto() {
Flush();
return &op_desc_;
return &desc_;
}

const std::vector<std::string> &OpDescBind::Input(
Expand Down Expand Up @@ -167,23 +199,23 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> {

void OpDescBind::Flush() {
if (need_update_) {
this->op_desc_.mutable_inputs()->Clear();
this->desc_.mutable_inputs()->Clear();
for (auto &ipt : inputs_) {
auto *input = op_desc_.add_inputs();
auto *input = desc_.add_inputs();
input->set_parameter(ipt.first);
VectorToRepeated(ipt.second, input->mutable_arguments());
}

this->op_desc_.mutable_outputs()->Clear();
this->desc_.mutable_outputs()->Clear();
for (auto &opt : outputs_) {
auto *output = op_desc_.add_outputs();
auto *output = desc_.add_outputs();
output->set_parameter(opt.first);
VectorToRepeated(opt.second, output->mutable_arguments());
}

this->op_desc_.mutable_attrs()->Clear();
this->desc_.mutable_attrs()->Clear();
for (auto &attr : attrs_) {
auto *attr_desc = op_desc_.add_attrs();
auto *attr_desc = desc_.add_attrs();
attr_desc->set_name(attr.first);
attr_desc->set_type(
static_cast<framework::AttrType>(attr.second.which() - 1));
Expand Down
9 changes: 6 additions & 3 deletions paddle/framework/op_desc.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ namespace paddle {
namespace framework {

class BlockDescBind;
class ProgramDescBind;

class OpDescBind {
public:
Expand All @@ -32,11 +33,13 @@ class OpDescBind {
OpDescBind(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs);

OpDescBind(const OpDesc &desc, ProgramDescBind *prog);

OpDesc *Proto();

std::string Type() const { return op_desc_.type(); }
std::string Type() const { return desc_.type(); }

void SetType(const std::string &type) { op_desc_.set_type(type); }
void SetType(const std::string &type) { desc_.set_type(type); }

const std::vector<std::string> &Input(const std::string &name) const;

Expand Down Expand Up @@ -117,7 +120,7 @@ class OpDescBind {
return ret_val;
}

OpDesc op_desc_;
OpDesc desc_;
VariableNameMap inputs_;
VariableNameMap outputs_;
AttributeMap attrs_;
Expand Down
23 changes: 16 additions & 7 deletions paddle/framework/program_desc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ namespace paddle {
namespace framework {

BlockDescBind *ProgramDescBind::AppendBlock(const BlockDescBind &parent) {
auto *b = prog_.add_blocks();
auto *b = desc_.add_blocks();
b->set_parent_idx(parent.ID());
b->set_idx(prog_.blocks_size() - 1);
b->set_idx(desc_.blocks_size() - 1);
blocks_.emplace_back(new BlockDescBind(this, b));
return blocks_.back().get();
}
Expand All @@ -30,23 +30,32 @@ ProgramDesc *ProgramDescBind::Proto() {
for (auto &block : blocks_) {
block->Flush();
}
return &prog_;
return &desc_;
}

ProgramDescBind::ProgramDescBind() {
auto *block = prog_.mutable_blocks()->Add();
auto *block = desc_.mutable_blocks()->Add();
block->set_idx(kRootBlockIndex);
block->set_parent_idx(kNoneBlockIndex);
blocks_.emplace_back(new BlockDescBind(this, block));
}

ProgramDescBind::ProgramDescBind(const ProgramDescBind &o) {
prog_ = o.prog_;
desc_ = o.desc_;

for (int i = 0; i < prog_.blocks_size(); ++i) {
auto *block = prog_.mutable_blocks(i);
for (int i = 0; i < desc_.blocks_size(); ++i) {
auto *block = desc_.mutable_blocks(i);
blocks_.emplace_back(new BlockDescBind(*o.blocks_[i], block, this));
}
}

ProgramDescBind::ProgramDescBind(const std::string &binary_str) {
PADDLE_ENFORCE(desc_.ParseFromString(binary_str),
"Fail to parse program_desc from binary string.");
for (auto &block_desc : *desc_.mutable_blocks()) {
blocks_.emplace_back(new BlockDescBind(this, &block_desc));
}
}

} // namespace framework
} // namespace paddle
4 changes: 3 additions & 1 deletion paddle/framework/program_desc.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ class ProgramDescBind {

ProgramDescBind(const ProgramDescBind &o);

explicit ProgramDescBind(const std::string &binary_str);

BlockDescBind *AppendBlock(const BlockDescBind &parent);

BlockDescBind *Block(size_t idx) { return blocks_[idx].get(); }
Expand All @@ -40,7 +42,7 @@ class ProgramDescBind {
ProgramDesc *Proto();

private:
ProgramDesc prog_;
ProgramDesc desc_;

std::vector<std::unique_ptr<BlockDescBind>> blocks_;
};
Expand Down
64 changes: 63 additions & 1 deletion paddle/framework/program_desc_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ TEST(ProgramDesc, copy_ctor) {
};

ASSERT_EQ(global_block->LocalVarNames(), global_block_copy->LocalVarNames());
ASSERT_EQ(3, global_block_copy->LocalVarNames().size());
ASSERT_EQ(3UL, global_block_copy->LocalVarNames().size());
assert_same_var("X", x);
assert_same_var("Y", y);
assert_same_var("Out", out);
Expand All @@ -79,5 +79,67 @@ TEST(ProgramDesc, copy_ctor) {
// Not check block's protostr are same it because the order of vars could be
// different and it is correct.
}

TEST(ProgramDescBind, serialize_and_deserialize) {
ProgramDescBind program_origin;
auto* global_block = program_origin.Block(0);
auto* x = global_block->Var("X");
x->SetType(VarDesc_VarType_LOD_TENSOR);
x->SetLoDLevel(0);
x->SetDataType(FP32);
x->SetShape({1000, 784});

auto* y = global_block->Var("Y");
y->SetType(VarDesc_VarType_LOD_TENSOR);
y->SetLoDLevel(0);
y->SetDataType(FP32);
y->SetShape({784, 100});

auto* op = global_block->AppendOp();
op->SetType("mul");
op->SetInput("X", {x->Name()});
op->SetInput("Y", {y->Name()});

auto* out = global_block->Var("Out");
out->SetType(VarDesc_VarType_LOD_TENSOR);
op->SetOutput("Y", {out->Name()});

std::string binary_str;
program_origin.Proto()->SerializeToString(&binary_str);

ProgramDescBind program_restored(binary_str);
auto* global_block_restored = program_restored.Block(0);
ASSERT_NE(global_block, global_block_restored);

auto assert_same_var = [&](const std::string& name, VarDescBind* var_before) {
ASSERT_TRUE(global_block_restored->HasVar(name));
auto* restored = global_block_restored->Var(name);
ASSERT_NE(restored, var_before);
ASSERT_EQ(restored->Name(), var_before->Name());
ASSERT_EQ(restored->GetType(), var_before->GetType());
ASSERT_EQ(restored->Shape(), var_before->Shape());
ASSERT_EQ(restored->Proto()->SerializeAsString(),
var_before->Proto()->SerializeAsString());
};

ASSERT_EQ(global_block->LocalVarNames(),
global_block_restored->LocalVarNames());
ASSERT_EQ(3UL, global_block_restored->LocalVarNames().size());
assert_same_var("X", x);
assert_same_var("Y", y);
assert_same_var("Out", out);

for (size_t i = 0; i < global_block->OpSize(); ++i) {
auto op_origin = global_block->Op(i);
auto op_restored = global_block->Op(i);

ASSERT_EQ(op_origin->Type(), op_restored->Type());
ASSERT_EQ(op_origin->Inputs(), op_restored->Inputs());
ASSERT_EQ(op_origin->Outputs(), op_restored->Outputs());

ASSERT_EQ(op_restored->Proto()->SerializeAsString(),
op_origin->Proto()->SerializeAsString());
}
}
} // namespace framework
} // namespace paddle
2 changes: 2 additions & 0 deletions paddle/framework/var_desc.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ class VarDescBind {
desc_.set_type(VarDesc::LOD_TENSOR);
}

explicit VarDescBind(const VarDesc &desc) : desc_(desc) {}

VarDesc *Proto() { return &desc_; }

std::string Name() const { return desc_.name(); }
Expand Down
5 changes: 5 additions & 0 deletions paddle/pybind/protobuf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,11 @@ void BindProgramDesc(py::module &m) {
[](ProgramDescBind &self, const ProgramDescBind &other) {
new (&self) ProgramDescBind(other);
})
.def("__init__",
[](ProgramDescBind &self, const py::bytes &binary_str) {
std::string str(binary_str);
new (&self) ProgramDescBind(str);
})
.def("append_block", &ProgramDescBind::AppendBlock,
py::return_value_policy::reference)
.def("append_backward",
Expand Down
7 changes: 7 additions & 0 deletions python/paddle/v2/framework/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,13 @@ def clone(self):
p.sync_with_cpp()
return p

@staticmethod
def parse_from_string(binary_str):
p = Program()
p.desc = core.ProgramDesc(binary_str)
p.sync_with_cpp()
return p

def __repr__(self):
return str(self)

Expand Down
19 changes: 19 additions & 0 deletions python/paddle/v2/framework/tests/test_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,25 @@ def test_program_clone(self):
print prog
print prog.clone()

def test_parse_program_from_string(self):
prog = Program()

x = prog.global_block().create_var(
name='X', shape=[1000, 784], dtype='float32')

y = prog.global_block().create_var(
name='Y', shape=[784, 100], dtype='float32')
out = prog.global_block().create_var(name='Out', dtype='float32')
prog.global_block().append_op(
type="mul", inputs={'X': [x],
'Y': [y]}, outputs={'Out': [out]})

binary_str = prog.desc.serialize_to_string()
prog_restored = Program.parse_from_string(binary_str)

print prog
print prog_restored

def test_append_backward(self):
prog = Program()
block = prog.global_block()
Expand Down

0 comments on commit aa379cc

Please sign in to comment.