Skip to content

Commit

Permalink
optimize attr default value (#33357)
Browse files Browse the repository at this point in the history
* optimize attr default value, test=develop

* refine, test=develop

* refine, test=develop

* refine, test=develop

* fix bug in AttrReader, test=develop

* fix bug, test=develop

* fix double_grad, test=develop

* refine, test=develop

* refine, test=develop

* fix checker null, test=develop

* for test, test=develop

* refine, test=develop

* refine, test=develop

* refine, test=develop

* refine, test=develop

* refine, test=develop

* refine, test=develop
  • Loading branch information
wanghuancoder authored Jun 23, 2021
1 parent 2133b45 commit 5d2eb67
Show file tree
Hide file tree
Showing 23 changed files with 239 additions and 102 deletions.
98 changes: 67 additions & 31 deletions paddle/fluid/framework/attribute.h
Original file line number Diff line number Diff line change
Expand Up @@ -208,22 +208,35 @@ Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc);

class AttrReader {
public:
explicit AttrReader(const AttributeMap& attrs) : attrs_(attrs) {}
explicit AttrReader(const AttributeMap& attrs)
: attrs_(attrs), default_attrs_(nullptr) {}

AttrReader(const AttributeMap& attrs, const AttributeMap& default_attrs)
: attrs_(attrs), default_attrs_(&default_attrs) {}

template <typename T>
inline const T& Get(const std::string& name) const {
PADDLE_ENFORCE_NE(attrs_.count(name), 0,
auto it = attrs_.find(name);
bool found = it != attrs_.end();
if (!found) {
if (default_attrs_ != nullptr) {
it = default_attrs_->find(name);
found = it != default_attrs_->end();
}
}
PADDLE_ENFORCE_EQ(found, true,
platform::errors::NotFound(
"Attribute (%s) should be in AttributeMap.", name));

Attribute& attr = const_cast<Attribute&>(attrs_.at(name));
Attribute& attr = const_cast<Attribute&>(it->second);
ExtractAttribute<T> extract_attr(name);
T* attr_value = extract_attr(attr);
return *attr_value;
}

private:
const AttributeMap& attrs_;
const AttributeMap* default_attrs_;
};

// check whether a value(attribute) fit a certain limit
Expand All @@ -234,8 +247,8 @@ class GreaterThanChecker {
void operator()(const T& value) const {
PADDLE_ENFORCE_GT(
value, lower_bound_,
platform::errors::OutOfRange(
"Check for attribute value greater than a certain value failed."));
platform::errors::OutOfRange("Check for attribute value greater than "
"a certain value failed."));
}

private:
Expand Down Expand Up @@ -332,9 +345,9 @@ class TypedAttrChecker {
TypedAttrChecker& SetDefault(const T& default_value) {
PADDLE_ENFORCE_EQ(
default_value_setter_.empty(), true,
platform::errors::AlreadyExists(
"Attribute (%s) has a default value and cannot be set repeatedly.",
attr_name_));
platform::errors::AlreadyExists("Attribute (%s) has a default value "
"and cannot be set repeatedly.",
attr_name_));
default_value_setter_.push_back(DefaultValueSetter<T>(default_value));
return *this;
}
Expand All @@ -345,30 +358,41 @@ class TypedAttrChecker {
return *this;
}

void operator()(AttributeMap* attr_map,
bool get_default_value_only = false) const {
void operator()(AttributeMap* attr_map, bool get_default_value_only = false,
bool only_check_exist_value = false) const {
if (get_default_value_only) {
if (!default_value_setter_.empty()) {
attr_map->emplace(attr_name_, default_value_setter_[0]());
}
return;
}

auto it = attr_map->find(attr_name_);
if (it == attr_map->end()) {
// user do not set this attr
PADDLE_ENFORCE_EQ(
default_value_setter_.empty(), false,
platform::errors::InvalidArgument(
"Attribute (%s) is not set correctly.", attr_name_));
// default_value_setter_ has no more than one element
attr_map->emplace(attr_name_, default_value_setter_[0]());
}
it = attr_map->find(attr_name_);
ExtractAttribute<T> extract_attr(attr_name_);
T* attr_value = extract_attr(it->second);
for (const auto& checker : value_checkers_) {
checker(*attr_value);
if (only_check_exist_value) {
auto it = attr_map->find(attr_name_);
if (it != attr_map->end()) {
ExtractAttribute<T> extract_attr(attr_name_);
T* attr_value = extract_attr(it->second);
for (const auto& checker : value_checkers_) {
checker(*attr_value);
}
}
} else {
auto it = attr_map->find(attr_name_);
if (it == attr_map->end()) {
// user do not set this attr
PADDLE_ENFORCE_EQ(
default_value_setter_.empty(), false,
platform::errors::InvalidArgument(
"Attribute (%s) is not set correctly.", attr_name_));
// default_value_setter_ has no more than one element
auto tmp = attr_map->emplace(attr_name_, default_value_setter_[0]());
it = tmp.first;
}
ExtractAttribute<T> extract_attr(attr_name_);
T* attr_value = extract_attr(it->second);
for (const auto& checker : value_checkers_) {
checker(*attr_value);
}
}
}

Expand All @@ -380,7 +404,7 @@ class TypedAttrChecker {

// check whether op's all attributes fit their own limits
class OpAttrChecker {
typedef std::function<void(AttributeMap*, bool)> AttrChecker;
typedef std::function<void(AttributeMap*, bool, bool)> AttrChecker;

public:
template <typename T>
Expand All @@ -390,18 +414,19 @@ class OpAttrChecker {
return *(checker.target<TypedAttrChecker<T>>());
}

void Check(AttributeMap* attr_map, bool explicit_only = false) const {
void Check(AttributeMap* attr_map, bool explicit_only = false,
bool only_check_exist_value = false) const {
auto checker_num = attr_checkers_.size();
if (explicit_only) checker_num = explicit_checker_num_;
for (size_t i = 0; i < checker_num; ++i) {
attr_checkers_[i](attr_map, false);
attr_checkers_[i](attr_map, false, only_check_exist_value);
}
}

AttributeMap GetAttrsDefaultValuesMap() const {
AttributeMap GetDefaultAttrsMap() const {
AttributeMap default_values_map;
for (const auto& checker : attr_checkers_) {
checker(&default_values_map, true);
checker(&default_values_map, true, false);
}
return default_values_map;
}
Expand All @@ -410,15 +435,26 @@ class OpAttrChecker {
explicit_checker_num_ = attr_checkers_.size();
}

void InitDefaultAttributeMap() {
for (const auto& checker : attr_checkers_) {
checker(&default_attrs_, true, false);
}
}

const AttributeMap& GetDefaultAttrMap() const { return default_attrs_; }

private:
std::vector<AttrChecker> attr_checkers_;

AttributeMap default_attrs_;

// in order to improve the efficiency of dynamic graph mode,
// we divede the attribute into explicit type and implicit type.
// for explicit attribute, we mean the attribute added in the customized
// op makers, usually it's defined in the overloaded Make method.
// for implicit attribute, we mean the attribute added outside of the Make
// method like "op_role", "op_role_var", and they are useless in dynamic graph
// method like "op_role", "op_role_var", and they are useless in dynamic
// graph
// mode
size_t explicit_checker_num_;
};
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/framework/custom_operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -781,10 +781,12 @@ void RegisterOperatorWithMetaInfo(
const imperative::NameVarBaseMap& var_base_map_in,
const imperative::NameVarBaseMap& var_base_map_out,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs,
const std::map<std::string, std::string>& inplace_map) {
CustomGradOpMaker<paddle::imperative::OpBase> maker(
type, var_base_map_in, var_base_map_out, attrs, inplace_map,
grad_op_name, grad_op_inputs, grad_op_outputs);
maker.SetDygraphDefaultAttrsMap(default_attrs);
return maker();
};

Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/framework/details/op_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,10 @@ struct OpInfoFiller<T, kGradOpBaseMaker> {
const imperative::NameVarBaseMap& var_base_map_in,
const imperative::NameVarBaseMap& var_base_map_out,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs,
const std::map<std::string, std::string>& inplace_map) {
T maker(type, var_base_map_in, var_base_map_out, attrs, inplace_map);
maker.SetDygraphDefaultAttrsMap(default_attrs);
return maker();
};
}
Expand Down
14 changes: 14 additions & 0 deletions paddle/fluid/framework/grad_op_desc_maker.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,19 @@ class SingleGradOpMaker<imperative::OpBase>
public:
using GradOpBaseMakerBase::GradOpBaseMakerBase;

virtual const framework::Attribute& GetAttr(const std::string& name) const {
auto it = Attrs().find(name);
if (it == Attrs().end()) {
it = this->DefaultAttrsMap().find(name);
PADDLE_ENFORCE_EQ(it != this->DefaultAttrsMap().end(), true,
platform::errors::NotFound(
"Cannot find attribute [%s] in operator [%s]", name,
this->ForwardOpType()));
}

return it->second;
}

std::shared_ptr<imperative::GradOpNode> operator()() const final {
auto node = this->NewGradNode();
auto& inplace_map = this->GetInplaceMap();
Expand All @@ -228,6 +241,7 @@ class SingleGradOpMaker<imperative::OpBase>
{
imperative::TracedGradOp traced_grad_op(node);
try {
traced_grad_op.SetDefaultAttrsMap(this->DefaultAttrsMap());
this->Apply(&traced_grad_op);
} catch (platform::EnforceNotMet& exception) {
framework::AppendErrorOpHint(traced_grad_op.Type(), &exception);
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/ir/op_compat_sensible_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ AttrCompat& AttrCompat::IsLeftDefault() {
return *this;
}
const OpInfo& op_info = OpInfoMap::Instance().Get(op_name);
const AttributeMap attrs = op_info.Checker()->GetAttrsDefaultValuesMap();
const AttributeMap attrs = op_info.Checker()->GetDefaultAttrsMap();
if (attrs.find(attr_name_) == attrs.end()) {
LOG(WARNING) << "Op (" << op_name << ") has no default attr:" << attr_name_;
conditions_.emplace_back([](const Attribute& attr) { return false; });
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/framework/op_proto_maker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
op_checker_ = attr_checker;
Make();
op_checker_->RecordExplicitCheckerNum();
op_checker_->InitDefaultAttributeMap();

AddAttr<int>(OpRoleAttrName(), "The role of this operator")
.InEnum(
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/framework/type_defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ using DygraphGradOpMakerFN =
const imperative::NameVarBaseMap& /*var_base_map_in*/,
const imperative::NameVarBaseMap& /*var_base_map_out*/,
const framework::AttributeMap& /*attributes*/,
const framework::AttributeMap& /*default attributes*/,
const std::map<std::string, std::string>& /*inplace_map*/)>;

using InferVarTypeFN =
Expand Down
5 changes: 3 additions & 2 deletions paddle/fluid/imperative/basic_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -474,10 +474,11 @@ void BasicEngine::Execute() {
try {
if (tmp_ins_ptr == nullptr) {
OpBase::Run(cur_op.InnerOp(), bwd_ins, tmp_outs, cur_op.Attrs(),
cur_op.place());
cur_op.DefaultAttrsMap(), cur_op.place());
} else {
OpBase::Run(cur_op.InnerOp(), *tmp_ins_ptr, tmp_outs,
cur_op.Attrs(), cur_op.place());
cur_op.Attrs(), cur_op.DefaultAttrsMap(),
cur_op.place());
}
} catch (platform::EnforceNotMet& exception) {
Clear();
Expand Down
16 changes: 15 additions & 1 deletion paddle/fluid/imperative/dygraph_grad_maker.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,18 @@ class GradOpBaseMakerBase {
return vec_temp;
}

// Only for dygraph
void SetDygraphDefaultAttrsMap(const framework::AttributeMap& default_attrs) {
default_attrs_ = &default_attrs;
}

const framework::AttributeMap& DefaultAttrsMap() const {
return *default_attrs_;
}

const framework::AttributeMap& Attrs() const { return attrs_; }

const framework::Attribute& GetAttr(const std::string& name) const {
virtual const framework::Attribute& GetAttr(const std::string& name) const {
auto it = attrs_.find(name);
PADDLE_ENFORCE_EQ(
it != attrs_.end(), true,
Expand Down Expand Up @@ -199,6 +208,7 @@ class GradOpBaseMakerBase {
const NameVarBaseMap& var_base_map_in_;
const NameVarBaseMap& var_base_map_out_;
const framework::AttributeMap& attrs_;
const framework::AttributeMap* default_attrs_;
const std::map<std::string, std::string>& inplace_map_;
};

Expand Down Expand Up @@ -285,6 +295,10 @@ class TracedGradOp {
return op_->SetAttrMap(attrs);
}

void SetDefaultAttrsMap(const framework::AttributeMap& attrs) {
return op_->SetDefaultAttrsMap(attrs);
}

void SetAttr(const std::string& name, const framework::Attribute& v) {
op_->SetAttr(name, v);
}
Expand Down
20 changes: 14 additions & 6 deletions paddle/fluid/imperative/execution_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,13 @@ class DygraphExecutionContext : public framework::ExecutionContext {
const framework::RuntimeContext& ctx,
const NameVarMap<VarType>& var_base_map_in,
const NameVarMap<VarType>& var_base_map_out,
const framework::AttributeMap& attrs)
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs)
: ExecutionContext(op, scope, device_context, ctx),
var_base_map_in_(var_base_map_in),
var_base_map_out_(var_base_map_out),
attrs_(attrs) {}
attrs_(attrs),
default_attrs_(default_attrs) {}

std::string InputName(const std::string& name) const override {
auto it = var_base_map_in_.find(name);
Expand Down Expand Up @@ -92,17 +94,22 @@ class DygraphExecutionContext : public framework::ExecutionContext {
}

bool HasAttr(const std::string& name) const override {
return attrs_.count(name) != 0;
return attrs_.count(name) != 0 || default_attrs_.count(name) != 0;
}

const framework::AttributeMap& Attrs() const override { return attrs_; }

const framework::Attribute& GetAttr(const std::string& name) const override {
auto it = attrs_.find(name);

PADDLE_ENFORCE_NE(
it, attrs_.end(),
platform::errors::NotFound("can not find [%s] in attrs", name));
if (it == attrs_.end()) {
it = default_attrs_.find(name);
if (it == default_attrs_.end()) {
PADDLE_THROW(platform::errors::NotFound(
"Can not find [%s] in attributes of op %s.", name,
this->GetOp().Type()));
}
}

return it->second;
}
Expand Down Expand Up @@ -192,6 +199,7 @@ class DygraphExecutionContext : public framework::ExecutionContext {
const NameVarMap<VarType>& var_base_map_in_;
const NameVarMap<VarType>& var_base_map_out_;
const framework::AttributeMap& attrs_;
const framework::AttributeMap& default_attrs_;
};

} // namespace imperative
Expand Down
5 changes: 4 additions & 1 deletion paddle/fluid/imperative/infer_shape_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
DygraphInferShapeContext(const NameVarMap<VarType>* in,
const NameVarMap<VarType>* out,
const framework::AttributeMap* attr,
const framework::AttributeMap* default_attr,
const std::string op_type)
: var_base_map_in_(in),
var_base_map_out_(out),
attrs_(attr),
default_attrs_(default_attr),
op_type_(op_type) {}

bool HasInput(const std::string& name) const override {
Expand Down Expand Up @@ -101,7 +103,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
}

framework::AttrReader Attrs() const override {
return framework::AttrReader(*attrs_);
return framework::AttrReader(*attrs_, *default_attrs_);
}

std::vector<std::string> Inputs(const std::string& name) const override {
Expand Down Expand Up @@ -395,6 +397,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
const NameVarMap<VarType>* var_base_map_in_;
const NameVarMap<VarType>* var_base_map_out_;
const framework::AttributeMap* attrs_;
const framework::AttributeMap* default_attrs_;
const std::string op_type_;
};

Expand Down
Loading

0 comments on commit 5d2eb67

Please sign in to comment.