Skip to content

Commit

Permalink
Merge pull request #2936 from reyoung/feature/create_op_in_cpp_params
Browse files Browse the repository at this point in the history
Make CreateOp in Plain C++ params
  • Loading branch information
reyoung authored Jul 18, 2017
2 parents c85a323 + f6a51d9 commit fb48cb1
Showing 1 changed file with 37 additions and 27 deletions.
64 changes: 37 additions & 27 deletions paddle/framework/op_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ Add a mark to which output is temporary is helpful for future optimization.
class OpRegistry {
using OpCreator = std::function<OperatorBase*()>;
using VarIndexMap = std::unordered_map<std::string, int>;
using VarNameList = std::vector<std::string>;

public:
template <typename OpType, typename ProtoMakerType>
Expand Down Expand Up @@ -226,42 +227,51 @@ class OpRegistry {
}
}

static OperatorPtr CreateOp(const OpDesc& op_desc) {
//! Create a OpPtr by type.
std::string op_type = op_desc.type();
OperatorPtr op(creators().at(op_type)());
//! Fill op's data member. Not use constructor because it will be noising
//! for Op developer.
op->type_ = op_desc.type();
// set op's inputs_ from desc.
op->inputs_.reserve((size_t)op_desc.inputs_size());
std::copy(op_desc.inputs().begin(), op_desc.inputs().end(),
std::back_inserter(op->inputs_));
// set op's outputs_ from desc.
op->outputs_.reserve((size_t)op_desc.outputs_size());
std::copy(op_desc.outputs().begin(), op_desc.outputs().end(),
std::back_inserter(op->outputs_));
static OperatorPtr CreateOp(const std::string& type,
const VarNameList& inputs,
const VarNameList& outputs,
const AttributeMap& attrs) {
auto op_create_it = creators().find(type);
PADDLE_ENFORCE(op_create_it != creators().end(),
"Operator %s cannot be found", type);

//! Fill attrs, and validate attrs.
for (auto& attr : op_desc.attrs()) {
op->attrs_[attr.name()] = AttrTypeHelper::GetAttrValue(attr);
}
op_checkers().at(op_type).Check(op->attrs_);
auto op = op_create_it->second();
op->type_ = type;
op->inputs_ = inputs;
op->outputs_ = outputs;
op->attrs_ = attrs;
op_checkers().at(type).Check(op->attrs_);

//! Convert Temporary variable name to an unique variable name.
GenerateTempVariableName(op.get());
GenerateTempVariableName(op);

//! set argument offsets stored in op.
{
auto var_index_it = VarIndexMaps().find(op_type);
auto var_index_it = VarIndexMaps().find(type);
if (var_index_it != VarIndexMaps().end()) {
op->in_out_idxs_ = var_index_it->second;
}
}
//! Other op's custom Init for a complex Op. For simple Op, the Init
//! method do nothing.

op->Init();
return op;
return OperatorPtr(op);
}

static OperatorPtr CreateOp(const OpDesc& op_desc) {
std::vector<std::string> inputs;
inputs.reserve((size_t)op_desc.inputs_size());
std::copy(op_desc.inputs().begin(), op_desc.inputs().end(),
std::back_inserter(inputs));

std::vector<std::string> outputs;
outputs.reserve((size_t)op_desc.outputs_size());
std::copy(op_desc.outputs().begin(), op_desc.outputs().end(),
std::back_inserter(outputs));

AttributeMap attrs;
for (auto& attr : op_desc.attrs()) {
attrs[attr.name()] = AttrTypeHelper::GetAttrValue(attr);
}

return CreateOp(op_desc.type(), inputs, outputs, attrs);
}

static std::unordered_map<std::string, OpProto>& protos() {
Expand Down

0 comments on commit fb48cb1

Please sign in to comment.