Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

optimize attr default value #33357

Merged
merged 18 commits into from
Jun 23, 2021
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 66 additions & 30 deletions paddle/fluid/framework/attribute.h
Original file line number Diff line number Diff line change
Expand Up @@ -208,22 +208,35 @@ Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc);

class AttrReader {
public:
explicit AttrReader(const AttributeMap& attrs) : attrs_(attrs) {}
explicit AttrReader(const AttributeMap& attrs)
: attrs_(attrs), default_attrs_(nullptr) {}

AttrReader(const AttributeMap& attrs, const AttributeMap& default_attrs)
: attrs_(attrs), default_attrs_(&default_attrs) {}

template <typename T>
inline const T& Get(const std::string& name) const {
PADDLE_ENFORCE_NE(attrs_.count(name), 0,
auto it = attrs_.find(name);
bool found = it != attrs_.end();
if (!found) {
if (default_attrs_ != nullptr) {
it = default_attrs_->find(name);
found = it != default_attrs_->end();
}
}
PADDLE_ENFORCE_EQ(found, true,
platform::errors::NotFound(
"Attribute (%s) should be in AttributeMap.", name));

Attribute& attr = const_cast<Attribute&>(attrs_.at(name));
Attribute& attr = const_cast<Attribute&>(it->second);
ExtractAttribute<T> extract_attr(name);
T* attr_value = extract_attr(attr);
return *attr_value;
}

private:
const AttributeMap& attrs_;
const AttributeMap* default_attrs_;
};

// check whether a value(attribute) fit a certain limit
Expand All @@ -234,8 +247,8 @@ class GreaterThanChecker {
void operator()(const T& value) const {
PADDLE_ENFORCE_GT(
value, lower_bound_,
platform::errors::OutOfRange(
"Check for attribute value greater than a certain value failed."));
platform::errors::OutOfRange("Check for attribute value greater than "
"a certain value failed."));
}

private:
Expand Down Expand Up @@ -332,9 +345,9 @@ class TypedAttrChecker {
TypedAttrChecker& SetDefault(const T& default_value) {
PADDLE_ENFORCE_EQ(
default_value_setter_.empty(), true,
platform::errors::AlreadyExists(
"Attribute (%s) has a default value and cannot be set repeatedly.",
attr_name_));
platform::errors::AlreadyExists("Attribute (%s) has a default value "
"and cannot be set repeatedly.",
attr_name_));
default_value_setter_.push_back(DefaultValueSetter<T>(default_value));
return *this;
}
Expand All @@ -345,30 +358,41 @@ class TypedAttrChecker {
return *this;
}

void operator()(AttributeMap* attr_map,
bool get_default_value_only = false) const {
void operator()(AttributeMap* attr_map, bool get_default_value_only = false,
bool only_check_exist_value = false) const {
if (get_default_value_only) {
if (!default_value_setter_.empty()) {
attr_map->emplace(attr_name_, default_value_setter_[0]());
}
return;
}

auto it = attr_map->find(attr_name_);
if (it == attr_map->end()) {
// user do not set this attr
PADDLE_ENFORCE_EQ(
default_value_setter_.empty(), false,
platform::errors::InvalidArgument(
"Attribute (%s) is not set correctly.", attr_name_));
// default_value_setter_ has no more than one element
attr_map->emplace(attr_name_, default_value_setter_[0]());
}
it = attr_map->find(attr_name_);
ExtractAttribute<T> extract_attr(attr_name_);
T* attr_value = extract_attr(it->second);
for (const auto& checker : value_checkers_) {
checker(*attr_value);
if (only_check_exist_value) {
auto it = attr_map->find(attr_name_);
if (it != attr_map->end()) {
ExtractAttribute<T> extract_attr(attr_name_);
T* attr_value = extract_attr(it->second);
for (const auto& checker : value_checkers_) {
checker(*attr_value);
}
}
} else {
auto it = attr_map->find(attr_name_);
if (it == attr_map->end()) {
// user do not set this attr
PADDLE_ENFORCE_EQ(
default_value_setter_.empty(), false,
platform::errors::InvalidArgument(
"Attribute (%s) is not set correctly.", attr_name_));
// default_value_setter_ has no more than one element
auto tmp = attr_map->emplace(attr_name_, default_value_setter_[0]());
it = tmp.first;
}
ExtractAttribute<T> extract_attr(attr_name_);
T* attr_value = extract_attr(it->second);
for (const auto& checker : value_checkers_) {
checker(*attr_value);
}
}
}

Expand All @@ -380,7 +404,7 @@ class TypedAttrChecker {

// check whether op's all attributes fit their own limits
class OpAttrChecker {
typedef std::function<void(AttributeMap*, bool)> AttrChecker;
typedef std::function<void(AttributeMap*, bool, bool)> AttrChecker;

public:
template <typename T>
Expand All @@ -390,18 +414,19 @@ class OpAttrChecker {
return *(checker.target<TypedAttrChecker<T>>());
}

void Check(AttributeMap* attr_map, bool explicit_only = false) const {
void Check(AttributeMap* attr_map, bool explicit_only = false,
bool only_check_exist_value = false) const {
auto checker_num = attr_checkers_.size();
if (explicit_only) checker_num = explicit_checker_num_;
for (size_t i = 0; i < checker_num; ++i) {
attr_checkers_[i](attr_map, false);
attr_checkers_[i](attr_map, false, only_check_exist_value);
}
}

AttributeMap GetAttrsDefaultValuesMap() const {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

GetDefaultAttrsValuesMap()?

AttributeMap default_values_map;
for (const auto& checker : attr_checkers_) {
checker(&default_values_map, true);
checker(&default_values_map, true, false);
}
return default_values_map;
}
Expand All @@ -410,15 +435,26 @@ class OpAttrChecker {
explicit_checker_num_ = attr_checkers_.size();
}

void InitDefaultAttributeMap() {
for (const auto& checker : attr_checkers_) {
checker(&default_attrs_, true, false);
}
}

const AttributeMap& GetDefaultAttrMap() const { return default_attrs_; }

private:
std::vector<AttrChecker> attr_checkers_;

AttributeMap default_attrs_;

// in order to improve the efficiency of dynamic graph mode,
// we divede the attribute into explicit type and implicit type.
// for explicit attribute, we mean the attribute added in the customized
// op makers, usually it's defined in the overloaded Make method.
// for implicit attribute, we mean the attribute added outside of the Make
// method like "op_role", "op_role_var", and they are useless in dynamic graph
// method like "op_role", "op_role_var", and they are useless in dynamic
// graph
// mode
size_t explicit_checker_num_;
};
Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/framework/custom_operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -781,10 +781,12 @@ void RegisterOperatorWithMetaInfo(
const imperative::NameVarBaseMap& var_base_map_in,
const imperative::NameVarBaseMap& var_base_map_out,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs,
const std::map<std::string, std::string>& inplace_map) {
CustomGradOpMaker<paddle::imperative::OpBase> maker(
type, var_base_map_in, var_base_map_out, attrs, inplace_map,
grad_op_name, grad_op_inputs, grad_op_outputs);
maker.SetDygraphDefaultAttrsMap(default_attrs);
return maker();
};

Expand Down
2 changes: 2 additions & 0 deletions paddle/fluid/framework/details/op_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,10 @@ struct OpInfoFiller<T, kGradOpBaseMaker> {
const imperative::NameVarBaseMap& var_base_map_in,
const imperative::NameVarBaseMap& var_base_map_out,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs,
const std::map<std::string, std::string>& inplace_map) {
T maker(type, var_base_map_in, var_base_map_out, attrs, inplace_map);
maker.SetDygraphDefaultAttrsMap(default_attrs);
return maker();
};
}
Expand Down
14 changes: 14 additions & 0 deletions paddle/fluid/framework/grad_op_desc_maker.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,19 @@ class SingleGradOpMaker<imperative::OpBase>
public:
using GradOpBaseMakerBase::GradOpBaseMakerBase;

virtual const framework::Attribute& GetAttr(const std::string& name) const {
auto it = Attrs().find(name);
if (it == Attrs().end()) {
it = this->DefaultAttrsMap().find(name);
PADDLE_ENFORCE_EQ(it != this->DefaultAttrsMap().end(), true,
platform::errors::NotFound(
"Cannot find attribute [%s] in operator [%s]", name,
this->ForwardOpType()));
}

return it->second;
}

std::shared_ptr<imperative::GradOpNode> operator()() const final {
auto node = this->NewGradNode();
auto& inplace_map = this->GetInplaceMap();
Expand All @@ -228,6 +241,7 @@ class SingleGradOpMaker<imperative::OpBase>
{
imperative::TracedGradOp traced_grad_op(node);
try {
traced_grad_op.SetAttrDefaultMap(this->DefaultAttrsMap());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SetDefaultAttrMap?

this->Apply(&traced_grad_op);
} catch (platform::EnforceNotMet& exception) {
framework::AppendErrorOpHint(traced_grad_op.Type(), &exception);
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/framework/op_proto_maker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
op_checker_ = attr_checker;
Make();
op_checker_->RecordExplicitCheckerNum();
op_checker_->InitDefaultAttributeMap();

AddAttr<int>(OpRoleAttrName(), "The role of this operator")
.InEnum(
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/framework/type_defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ using DygraphGradOpMakerFN =
const imperative::NameVarBaseMap& /*var_base_map_in*/,
const imperative::NameVarBaseMap& /*var_base_map_out*/,
const framework::AttributeMap& /*attributes*/,
const framework::AttributeMap& /*attributes default*/,
const std::map<std::string, std::string>& /*inplace_map*/)>;

using InferVarTypeFN =
Expand Down
5 changes: 3 additions & 2 deletions paddle/fluid/imperative/basic_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -474,10 +474,11 @@ void BasicEngine::Execute() {
try {
if (tmp_ins_ptr == nullptr) {
OpBase::Run(cur_op.InnerOp(), bwd_ins, tmp_outs, cur_op.Attrs(),
cur_op.place());
cur_op.DefaultAttrsMap(), cur_op.place());
} else {
OpBase::Run(cur_op.InnerOp(), *tmp_ins_ptr, tmp_outs,
cur_op.Attrs(), cur_op.place());
cur_op.Attrs(), cur_op.DefaultAttrsMap(),
cur_op.place());
}
} catch (platform::EnforceNotMet& exception) {
Clear();
Expand Down
16 changes: 15 additions & 1 deletion paddle/fluid/imperative/dygraph_grad_maker.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,18 @@ class GradOpBaseMakerBase {
return vec_temp;
}

// Only for dygraph
void SetDygraphDefaultAttrsMap(const framework::AttributeMap& default_attrs) {
default_attrs_ = &default_attrs;
}

const framework::AttributeMap& DefaultAttrsMap() const {
return *default_attrs_;
}

const framework::AttributeMap& Attrs() const { return attrs_; }

const framework::Attribute& GetAttr(const std::string& name) const {
virtual const framework::Attribute& GetAttr(const std::string& name) const {
auto it = attrs_.find(name);
PADDLE_ENFORCE_EQ(
it != attrs_.end(), true,
Expand Down Expand Up @@ -199,6 +208,7 @@ class GradOpBaseMakerBase {
const NameVarBaseMap& var_base_map_in_;
const NameVarBaseMap& var_base_map_out_;
const framework::AttributeMap& attrs_;
const framework::AttributeMap* default_attrs_;
const std::map<std::string, std::string>& inplace_map_;
};

Expand Down Expand Up @@ -285,6 +295,10 @@ class TracedGradOp {
return op_->SetAttrMap(attrs);
}

void SetAttrDefaultMap(const framework::AttributeMap& attrs) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SetDefaultAttrMap?

return op_->SetAttrDefaultMap(attrs);
}

void SetAttr(const std::string& name, const framework::Attribute& v) {
op_->SetAttr(name, v);
}
Expand Down
20 changes: 14 additions & 6 deletions paddle/fluid/imperative/execution_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,13 @@ class DygraphExecutionContext : public framework::ExecutionContext {
const framework::RuntimeContext& ctx,
const NameVarMap<VarType>& var_base_map_in,
const NameVarMap<VarType>& var_base_map_out,
const framework::AttributeMap& attrs)
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs)
: ExecutionContext(op, scope, device_context, ctx),
var_base_map_in_(var_base_map_in),
var_base_map_out_(var_base_map_out),
attrs_(attrs) {}
attrs_(attrs),
default_attrs_(default_attrs) {}

std::string InputName(const std::string& name) const override {
auto it = var_base_map_in_.find(name);
Expand Down Expand Up @@ -92,17 +94,22 @@ class DygraphExecutionContext : public framework::ExecutionContext {
}

bool HasAttr(const std::string& name) const override {
return attrs_.count(name) != 0;
return attrs_.count(name) != 0 || default_attrs_.count(name) != 0;
}

const framework::AttributeMap& Attrs() const override { return attrs_; }

const framework::Attribute& GetAttr(const std::string& name) const override {
auto it = attrs_.find(name);

PADDLE_ENFORCE_NE(
it, attrs_.end(),
platform::errors::NotFound("can not find [%s] in attrs", name));
if (it == attrs_.end()) {
it = default_attrs_.find(name);
if (it == default_attrs_.end()) {
PADDLE_THROW(platform::errors::NotFound(
"Can not find [%s] in attributes of op %s.", name,
this->GetOp().Type()));
}
}

return it->second;
}
Expand Down Expand Up @@ -192,6 +199,7 @@ class DygraphExecutionContext : public framework::ExecutionContext {
const NameVarMap<VarType>& var_base_map_in_;
const NameVarMap<VarType>& var_base_map_out_;
const framework::AttributeMap& attrs_;
const framework::AttributeMap& default_attrs_;
};

} // namespace imperative
Expand Down
5 changes: 4 additions & 1 deletion paddle/fluid/imperative/infer_shape_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,12 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
DygraphInferShapeContext(const NameVarMap<VarType>* in,
const NameVarMap<VarType>* out,
const framework::AttributeMap* attr,
const framework::AttributeMap* attr_default,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

const std::string op_type)
: var_base_map_in_(in),
var_base_map_out_(out),
attrs_(attr),
default_attrs_(attr_default),
op_type_(op_type) {}

bool HasInput(const std::string& name) const override {
Expand Down Expand Up @@ -101,7 +103,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
}

framework::AttrReader Attrs() const override {
return framework::AttrReader(*attrs_);
return framework::AttrReader(*attrs_, *default_attrs_);
}

std::vector<std::string> Inputs(const std::string& name) const override {
Expand Down Expand Up @@ -395,6 +397,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
const NameVarMap<VarType>* var_base_map_in_;
const NameVarMap<VarType>* var_base_map_out_;
const framework::AttributeMap* attrs_;
const framework::AttributeMap* default_attrs_;
const std::string op_type_;
};

Expand Down
Loading