-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
optimize attr default value #33357
optimize attr default value #33357
Changes from 17 commits
dd3329d
c8b557b
8b1cdab
0cf6374
2facfca
911300c
3598f3c
2ff8504
e22ea5f
d44aeb6
0132959
32e2f60
eb9b97e
e681acd
ea470d7
f54c359
2d4c9dc
806564c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -219,6 +219,19 @@ class SingleGradOpMaker<imperative::OpBase> | |
public: | ||
using GradOpBaseMakerBase::GradOpBaseMakerBase; | ||
|
||
virtual const framework::Attribute& GetAttr(const std::string& name) const { | ||
auto it = Attrs().find(name); | ||
if (it == Attrs().end()) { | ||
it = this->DefaultAttrsMap().find(name); | ||
PADDLE_ENFORCE_EQ(it != this->DefaultAttrsMap().end(), true, | ||
platform::errors::NotFound( | ||
"Cannot find attribute [%s] in operator [%s]", name, | ||
this->ForwardOpType())); | ||
} | ||
|
||
return it->second; | ||
} | ||
|
||
std::shared_ptr<imperative::GradOpNode> operator()() const final { | ||
auto node = this->NewGradNode(); | ||
auto& inplace_map = this->GetInplaceMap(); | ||
|
@@ -228,6 +241,7 @@ class SingleGradOpMaker<imperative::OpBase> | |
{ | ||
imperative::TracedGradOp traced_grad_op(node); | ||
try { | ||
traced_grad_op.SetAttrDefaultMap(this->DefaultAttrsMap()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. SetDefaultAttrMap? |
||
this->Apply(&traced_grad_op); | ||
} catch (platform::EnforceNotMet& exception) { | ||
framework::AppendErrorOpHint(traced_grad_op.Type(), &exception); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -113,9 +113,18 @@ class GradOpBaseMakerBase { | |
return vec_temp; | ||
} | ||
|
||
// Only for dygraph | ||
void SetDygraphDefaultAttrsMap(const framework::AttributeMap& default_attrs) { | ||
default_attrs_ = &default_attrs; | ||
} | ||
|
||
const framework::AttributeMap& DefaultAttrsMap() const { | ||
return *default_attrs_; | ||
} | ||
|
||
const framework::AttributeMap& Attrs() const { return attrs_; } | ||
|
||
const framework::Attribute& GetAttr(const std::string& name) const { | ||
virtual const framework::Attribute& GetAttr(const std::string& name) const { | ||
auto it = attrs_.find(name); | ||
PADDLE_ENFORCE_EQ( | ||
it != attrs_.end(), true, | ||
|
@@ -199,6 +208,7 @@ class GradOpBaseMakerBase { | |
const NameVarBaseMap& var_base_map_in_; | ||
const NameVarBaseMap& var_base_map_out_; | ||
const framework::AttributeMap& attrs_; | ||
const framework::AttributeMap* default_attrs_; | ||
const std::map<std::string, std::string>& inplace_map_; | ||
}; | ||
|
||
|
@@ -285,6 +295,10 @@ class TracedGradOp { | |
return op_->SetAttrMap(attrs); | ||
} | ||
|
||
void SetAttrDefaultMap(const framework::AttributeMap& attrs) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. SetDefaultAttrMap? |
||
return op_->SetAttrDefaultMap(attrs); | ||
} | ||
|
||
void SetAttr(const std::string& name, const framework::Attribute& v) { | ||
op_->SetAttr(name, v); | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -35,10 +35,12 @@ class DygraphInferShapeContext : public framework::InferShapeContext { | |
DygraphInferShapeContext(const NameVarMap<VarType>* in, | ||
const NameVarMap<VarType>* out, | ||
const framework::AttributeMap* attr, | ||
const framework::AttributeMap* attr_default, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ? |
||
const std::string op_type) | ||
: var_base_map_in_(in), | ||
var_base_map_out_(out), | ||
attrs_(attr), | ||
default_attrs_(attr_default), | ||
op_type_(op_type) {} | ||
|
||
bool HasInput(const std::string& name) const override { | ||
|
@@ -101,7 +103,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext { | |
} | ||
|
||
framework::AttrReader Attrs() const override { | ||
return framework::AttrReader(*attrs_); | ||
return framework::AttrReader(*attrs_, *default_attrs_); | ||
} | ||
|
||
std::vector<std::string> Inputs(const std::string& name) const override { | ||
|
@@ -395,6 +397,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext { | |
const NameVarMap<VarType>* var_base_map_in_; | ||
const NameVarMap<VarType>* var_base_map_out_; | ||
const framework::AttributeMap* attrs_; | ||
const framework::AttributeMap* default_attrs_; | ||
const std::string op_type_; | ||
}; | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
GetDefaultAttrsValuesMap()?