Refactor Registry::CreateGradOp()
We put the forward Op's inputs, outputs, and output gradients into the Grad
Op's inputs, and put the forward Op's input gradients into the Grad Op's
outputs. So the Grad Op's `in_out_idx`, `input_format` and `output_format`
need to be rebuilt during Op creation.
JiayiFeng committed Jul 19, 2017
1 parent e786746 commit bf4da3d
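For illustration, take a hypothetical operator "mul" with inputs {X, W} and output {Out} (the op and its variable names are assumptions, not part of this commit). The rebuilt gradient operator is wired as:

// forward op:   inputs  = {X, W},  outputs = {Out}
// gradient op:  inputs  = {X, W, Out, Out@GRAD}
//               outputs = {X@GRAD, W@GRAD}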
Showing 2 changed files with 126 additions and 29 deletions.
150 changes: 121 additions & 29 deletions paddle/framework/op_registry.h
@@ -228,6 +228,11 @@ class OpRegistry {
}
}

template <typename OpType>
static void RegisterGradOp(const std::string& op_type) {
grad_creators()[op_type] = [] { return new OpType; };
}

static OperatorPtr CreateOp(const std::string& type,
const VarNameList& inputs,
const VarNameList& outputs,
@@ -240,6 +245,7 @@ class OpRegistry {
op->type_ = type;
op->inputs_ = inputs;
op->outputs_ = outputs;

op->attrs_ = attrs;
op_checkers().at(type).Check(op->attrs_);

@@ -256,11 +262,6 @@ class OpRegistry {
return OperatorPtr(op);
}

template <typename OpType>
static void RegisterGradOp(const std::string& op_type) {
grad_creators()[op_type] = [] { return new OpType; };
}

static OperatorPtr CreateOp(const OpDesc& op_desc) {
std::vector<std::string> inputs;
inputs.reserve((size_t)op_desc.inputs_size());
@@ -280,19 +281,16 @@ class OpRegistry {
return CreateOp(op_desc.type(), inputs, outputs, attrs);
}

static OperatorPtr CreateGradOp(std::shared_ptr<OperatorBase> op) {
OperatorPtr op_grad(grad_creators().at(op->type_)());
op_grad->type_ = op->type_;
op_grad->inputs_.reserve(op->inputs_.size());
for (auto& input : op->inputs_) {
op_grad->inputs_.emplace_back(input);
op_grad->outputs_.emplace_back(input + "@grad");
}
for (auto& output : op->outputs_) {
op_grad->inputs_.emplace_back(output);
op_grad->inputs_.emplace_back(output + "@grad");
}
return op_grad;
static OperatorPtr CreateGradOp(OperatorPtr op) {
OperatorPtr grad_op(grad_creators().at(op->type_)());
grad_op->type_ = op->type_;

AssembleGradInOut(op, grad_op);
GenerateGradArgOffset(op, grad_op);
GenerateGradAttr(op, grad_op);

grad_op->Init();
return grad_op;
}
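A minimal usage sketch of the new entry point, reusing the hypothetical "mul" operator from above and assuming it has been registered:

// create the forward op, then derive its gradient op
OperatorPtr fwd_op = OpRegistry::CreateOp("mul", {"X", "W"}, {"Out"}, {});
OperatorPtr grad_op = OpRegistry::CreateGradOp(fwd_op);
// grad_op's in/out lists, in_out_idxs_ and format attributes have all been
// rebuilt by the three helpers above, so grad_op is ready to Init() and Run()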

static std::unordered_map<std::string, OpProto>& protos() {
@@ -307,6 +305,21 @@ class OpRegistry {
return maps_;
}

static std::unordered_map<std::string, OpCreator>& creators() {
static std::unordered_map<std::string, OpCreator> creators_;
return creators_;
}

static std::unordered_map<std::string, OpAttrChecker>& op_checkers() {
static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
return op_checkers_;
}

static std::unordered_map<std::string, OpCreator>& grad_creators() {
static std::unordered_map<std::string, OpCreator> grad_creators_;
return grad_creators_;
}

static void GenerateTempVariableName(OperatorBase* op) {
static std::atomic<size_t> gUniqId(0UL);
for (auto& outname : op->outputs_) {
@@ -318,19 +331,98 @@ class OpRegistry {
}
}

static std::unordered_map<std::string, OpCreator>& creators() {
static std::unordered_map<std::string, OpCreator> creators_;
return creators_;
static void AssembleGradInOut(OperatorPtr op, OperatorPtr grad_op) {
size_t in_sz = op->inputs_.size() + op->outputs_.size() * 2;
grad_op->inputs_.reserve(in_sz);
size_t out_sz = op->inputs_.size();
grad_op->outputs_.reserve(out_sz);
// copy op->inputs_ to grad_op->inputs_
std::copy(op->inputs_.begin(), op->inputs_.end(),
std::back_inserter(grad_op->inputs_));
// copy op->outputs_ to grad_op->inputs_
std::copy(op->outputs_.begin(), op->outputs_.end(),
std::back_inserter(grad_op->inputs_));
// add gradients of op->outputs_ to grad_op->inputs_
for (const std::string& name : op->outputs_) {
grad_op->inputs_.emplace_back(name + OperatorBase::GRAD_VAR_SUFFIX());
}
// add gradients of op->inputs_ to grad_op->outputs_
for (const std::string& name : op->inputs_) {
grad_op->outputs_.emplace_back(name + OperatorBase::GRAD_VAR_SUFFIX());
}
}

static std::unordered_map<std::string, OpAttrChecker>& op_checkers() {
static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
return op_checkers_;
};
static void GenerateGradArgOffset(OperatorPtr op, OperatorPtr grad_op) {
VarIndexMap* grad_varmap = new VarIndexMap();
const OpProto& op_proto = protos()[op->type_];
int idx = 0;
// offset of op's inputs
for (const auto& var : op_proto.inputs()) {
(*grad_varmap)[var.name()] = idx++;
}
// offset of op's outputs
for (const auto& var : op_proto.outputs()) {
(*grad_varmap)[var.name()] = idx++;
}
// offset of gradients of op's outputs
for (const auto& var : op_proto.outputs()) {
(*grad_varmap)[var.name() + OperatorBase::GRAD_VAR_SUFFIX()] = idx++;
}
idx = 0;
// offset of gradients of op's inputs
for (const auto& var : op_proto.inputs()) {
(*grad_varmap)[var.name() + OperatorBase::GRAD_VAR_SUFFIX()] = idx++;
}
grad_op->in_out_idxs_.reset(grad_varmap);
}
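Continuing the hypothetical "mul" example, the rebuilt in_out_idxs_ maps:

// grad op inputs:   X -> 0, W -> 1, Out -> 2, Out@GRAD -> 3
// grad op outputs:  X@GRAD -> 0, W@GRAD -> 1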

static std::unordered_map<std::string, OpCreator>& grad_creators() {
static std::unordered_map<std::string, OpCreator> grad_creators_;
return grad_creators_;
static void GenerateGradAttr(OperatorPtr op, OperatorPtr grad_op) {
const OpProto& op_proto = protos()[op->type_];
grad_op->attrs_ = op->attrs_;
grad_op->attrs_.erase("input_format");
grad_op->attrs_.erase("output_format");
bool has_in_format = op->attrs_.count("input_format");
bool has_out_format = op->attrs_.count("output_format");
// grad_op's inputs_ contains op's inputs_, outputs_ and gradients of
// outputs_. So grad_op's input_format is necessary when op has
// either input_format or output_format.
if (has_in_format || has_out_format) {
std::vector<int> old_in_format;
if (has_in_format) {
old_in_format = op->GetAttr<std::vector<int>>("input_format");
} else {
// default layout: one variable per input slot
old_in_format.resize(op_proto.inputs_size());
std::iota(old_in_format.begin(), old_in_format.end(), 0);
}
std::vector<int> old_out_format;
if (has_out_format) {
old_out_format = op->GetAttr<std::vector<int>>("output_format");
} else {
// default layout: one variable per output slot
old_out_format.resize(op_proto.outputs_size());
std::iota(old_out_format.begin(), old_out_format.end(), 0);
}

std::vector<int> in_format;
in_format.reserve(old_in_format.size() + old_out_format.size() * 2);
int base = 0;
for (const int& idx : old_in_format) {
in_format.emplace_back(idx + base);
}
base += op->inputs_.size();
for (const int& idx : old_out_format) {
in_format.emplace_back(idx + base);
}
base += op->outputs_.size();
// gradients of op's outputs_ follow the same layout as outputs_
for (const int& idx : old_out_format) {
in_format.emplace_back(idx + base);
}
grad_op->attrs_["input_format"] = in_format;
// grad_op's outputs_ contains gradients of op's inputs_. So grad_op's
// output_format is necessary only when op has input_format.
if (has_in_format) {
std::vector<int> out_format;
out_format.reserve(op_proto.inputs_size());
std::copy(old_in_format.begin(), old_in_format.end(),
std::back_inserter(out_format));
grad_op->attrs_["output_format"] = out_format;
}
}
}
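A worked example of the attribute rebuild (the slot layouts are assumptions): suppose the forward op has three input variables {X0, X1, W} grouped into two slots by input_format = {0, 2}, and one output {Out} with output_format = {0}. The gradient op's flattened inputs are {X0, X1, W, Out, Out@GRAD}, so:

// in_format  = {0, 2}       (op's inputs, base 0)
//             + {0 + 3}     (op's outputs, base = 3 input vars)
//             + {0 + 4}     (output gradients, base = 3 + 1)
//             = {0, 2, 3, 4}
// out_format = {0, 2}       (gradients of inputs mirror input_format)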
};

@@ -370,7 +462,7 @@ class GradOpRegisterHelper {
int __op_register_##__op_type##_handle__() { return 0; }

/**
* Macro to Register Operator.
* Macro to Register Gradient Operator.
*/
#define REGISTER_GRADIENT_OP(__op_type, __op_class) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
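A hypothetical registration pairing a forward op with its gradient class (MulOpGrad is an assumed OperatorBase subclass, not part of this diff):

REGISTER_GRADIENT_OP(mul, MulOpGrad);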
5 changes: 5 additions & 0 deletions paddle/framework/operator.h
@@ -63,6 +63,11 @@ class OperatorBase {
/// but it will be converted to a unique name in scope after OpCreator.
static std::string TMP_VAR_NAME() { return "@TEMP@"; }

/// If a variable's name has a certain suffix, it means that the
/// variable is the gradient of another variable.
/// e.g. Variable "x@GRAD" is the gradient of variable "x".
static std::string GRAD_VAR_SUFFIX() { return "@GRAD"; }

virtual ~OperatorBase() {}

template <typename T>
