Skip to content

Commit

Permalink
Merge pull request #3322 from wangkuiyi/refactorize_framework_proto
Browse files Browse the repository at this point in the history
Refactorize framework/*.proto
  • Loading branch information
reyoung authored Aug 14, 2017
2 parents 8747d60 + 64a4dfe commit 81f5f86
Show file tree
Hide file tree
Showing 52 changed files with 821 additions and 1,346 deletions.
16 changes: 6 additions & 10 deletions paddle/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,19 @@ cc_test(variable_test SRCS variable_test.cc)
cc_library(scope SRCS scope.cc)
cc_test(scope_test SRCS scope_test.cc DEPS scope)

proto_library(attribute_proto SRCS attribute.proto)
proto_library(op_proto SRCS op_proto.proto DEPS attribute_proto)
proto_library(op_desc SRCS op_desc.proto DEPS attribute_proto)
cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf)
cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf)
proto_library(framework_proto SRCS framework.proto)

cc_library(attribute SRCS attribute.cc DEPS op_desc op_proto)
cc_library(attribute SRCS attribute.cc DEPS framework_proto)

cc_library(operator SRCS operator.cc DEPS op_desc device_context tensor scope attribute)
cc_library(operator SRCS operator.cc DEPS framework_proto device_context tensor scope attribute)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)

cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS op_proto operator)
cc_library(op_registry SRCS op_registry.cc DEPS op_desc grad_op_builder)
cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS operator)
cc_library(op_registry SRCS op_registry.cc DEPS grad_op_builder)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
cc_test(grad_op_builder_test SRCS grad_op_builder_test.cc DEPS grad_op_builder op_registry add_op)

py_proto_compile(framework_py_proto SRCS attribute.proto op_proto.proto op_desc.proto)
py_proto_compile(framework_py_proto SRCS framework.proto)
# Generate an empty __init__.py to make framework_py_proto as a valid python module.
add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
add_dependencies(framework_py_proto framework_py_proto_init)
Expand Down
2 changes: 1 addition & 1 deletion paddle/framework/attribute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ AttrType AttrTypeID<std::vector<std::string>>() {
return STRINGS;
}

Attribute GetAttrValue(const AttrDesc& attr_desc) {
Attribute GetAttrValue(const OpDesc::Attr& attr_desc) {
switch (attr_desc.type()) {
case paddle::framework::AttrType::INT: {
return attr_desc.i();
Expand Down
5 changes: 2 additions & 3 deletions paddle/framework/attribute.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ limitations under the License. */
#include <unordered_set>
#include <vector>

#include "paddle/framework/attribute.pb.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/variant.h"

Expand All @@ -37,7 +36,7 @@ typedef std::unordered_map<std::string, Attribute> AttributeMap;
template <typename T>
AttrType AttrTypeID();

Attribute GetAttrValue(const AttrDesc& attr_desc);
Attribute GetAttrValue(const OpDesc::Attr& attr_desc);

// check whether a value(attribute) fit a certain limit
template <typename T>
Expand Down
28 changes: 0 additions & 28 deletions paddle/framework/attribute.proto

This file was deleted.

65 changes: 41 additions & 24 deletions paddle/framework/backward.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,24 @@
namespace paddle {
namespace framework {

static bool AllInSet(const std::vector<std::string>& names,
const std::string& suffix,
const std::unordered_set<std::string>& set) {
template <typename Map, typename T>
static void ForEachVarName(Map& names, T callback) {
for (auto& name : names) {
if (set.find(name + suffix) == set.end()) {
return false;
for (auto& n : name.second) {
if (callback(n)) return;
}
}
return true;
}

static bool AllInSet(
const std::map<std::string, std::vector<std::string>>& names,
const std::string& suffix, const std::unordered_set<std::string>& set) {
bool all_in_set = true;
ForEachVarName(names, [&all_in_set, &set, &suffix](const std::string& n) {
all_in_set = set.find(n + suffix) != set.end();
return !all_in_set;
});
return all_in_set;
}

static std::shared_ptr<OperatorBase> NOP() {
Expand Down Expand Up @@ -68,10 +77,11 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
// Then all input gradients cannot be computed at all, and we put them into
// `no_grad_names` set. Return an NOP.
if (AllInSet(forwardOp.outputs_, kGradVarSuffix, no_grad_names)) {
for (auto& name : forwardOp.inputs_) {
// Mark all input is not need
no_grad_names.insert(name + kGradVarSuffix);
}
ForEachVarName(forwardOp.inputs_,
[&no_grad_names](const std::string& name) -> bool {
no_grad_names.insert(GradVarName(name));
return false;
});
return NOP();
}

Expand All @@ -93,9 +103,11 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
auto fwd = *it;
auto bwd = BackwardRecursive(*fwd, no_grad_names, uniq_id);
net->AddOp(bwd);
for (auto& out : bwd->outputs_) {
dup_output_ops[out].emplace_back(local_op_id);
}
ForEachVarName(bwd->outputs_,
[&dup_output_ops, local_op_id](const std::string& out) {
dup_output_ops[out].emplace_back(local_op_id);
return false;
});
}
// Get unique ID for this method.
auto uid = uniq_id++;
Expand All @@ -117,7 +129,7 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
insert_position.push_back(
{dup_op.back(),
OpRegistry::CreateOp(
"add", {dup_outputs}, {name},
"add", {{"X", {dup_outputs}}}, {{"Out", {name}}},
{{"input_format",
std::vector<int>{0, static_cast<int>(dup_outputs.size())}}})});
}
Expand All @@ -131,7 +143,9 @@ std::shared_ptr<OperatorBase> BackwardRecursive(

} else {
std::shared_ptr<OperatorBase> grad_op = OpRegistry::CreateGradOp(forwardOp);
for (std::string& grad_input : grad_op->inputs_) {

ForEachVarName(grad_op->inputs_, [&no_grad_names,
&net](std::string& grad_input) {
if (no_grad_names.count(grad_input)) {
// +1 for \0
std::string prefix = grad_input.substr(
Expand All @@ -140,16 +154,19 @@ std::shared_ptr<OperatorBase> BackwardRecursive(

// If part of input gradient of that operator is not calculated, fill
// zero variables to that input gradient.
net->AddOp(OpRegistry::CreateOp("fill_zeros_like", {prefix},
{grad_input}, {}));
net->AddOp(OpRegistry::CreateOp("fill_zeros_like", {{"Src", {prefix}}},
{{"Dst", {grad_input}}}, {}));
}
}

for (std::string& grad_output : grad_op->outputs_) {
if (no_grad_names.count(grad_output)) {
grad_output = kEmptyVarName;
}
}
return false;
});

ForEachVarName(grad_op->outputs_,
[&no_grad_names](std::string& grad_output) {
if (no_grad_names.count(grad_output)) {
grad_output = kEmptyVarName;
}
return false;
});

if (net->ops_.empty()) { // Current no aux op is added to network
return grad_op;
Expand Down
Loading

0 comments on commit 81f5f86

Please sign in to comment.