diff --git a/paddle/fluid/framework/attribute.h b/paddle/fluid/framework/attribute.h index 66b988ee1f1fb..e9e1875765633 100644 --- a/paddle/fluid/framework/attribute.h +++ b/paddle/fluid/framework/attribute.h @@ -208,15 +208,27 @@ Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc); class AttrReader { public: - explicit AttrReader(const AttributeMap& attrs) : attrs_(attrs) {} + explicit AttrReader(const AttributeMap& attrs) + : attrs_(attrs), default_attrs_(nullptr) {} + + AttrReader(const AttributeMap& attrs, const AttributeMap& default_attrs) + : attrs_(attrs), default_attrs_(&default_attrs) {} template inline const T& Get(const std::string& name) const { - PADDLE_ENFORCE_NE(attrs_.count(name), 0, + auto it = attrs_.find(name); + bool found = it != attrs_.end(); + if (!found) { + if (default_attrs_ != nullptr) { + it = default_attrs_->find(name); + found = it != default_attrs_->end(); + } + } + PADDLE_ENFORCE_EQ(found, true, platform::errors::NotFound( "Attribute (%s) should be in AttributeMap.", name)); - Attribute& attr = const_cast(attrs_.at(name)); + Attribute& attr = const_cast(it->second); ExtractAttribute extract_attr(name); T* attr_value = extract_attr(attr); return *attr_value; @@ -224,6 +236,7 @@ class AttrReader { private: const AttributeMap& attrs_; + const AttributeMap* default_attrs_; }; // check whether a value(attribute) fit a certain limit @@ -234,8 +247,8 @@ class GreaterThanChecker { void operator()(const T& value) const { PADDLE_ENFORCE_GT( value, lower_bound_, - platform::errors::OutOfRange( - "Check for attribute value greater than a certain value failed.")); + platform::errors::OutOfRange("Check for attribute value greater than " + "a certain value failed.")); } private: @@ -332,9 +345,9 @@ class TypedAttrChecker { TypedAttrChecker& SetDefault(const T& default_value) { PADDLE_ENFORCE_EQ( default_value_setter_.empty(), true, - platform::errors::AlreadyExists( - "Attribute (%s) has a default value and cannot be set repeatedly.", - attr_name_)); + platform::errors::AlreadyExists("Attribute (%s) has a default value " + "and cannot be set repeatedly.", + attr_name_)); default_value_setter_.push_back(DefaultValueSetter(default_value)); return *this; } @@ -345,8 +358,8 @@ class TypedAttrChecker { return *this; } - void operator()(AttributeMap* attr_map, - bool get_default_value_only = false) const { + void operator()(AttributeMap* attr_map, bool get_default_value_only = false, + bool only_check_exist_value = false) const { if (get_default_value_only) { if (!default_value_setter_.empty()) { attr_map->emplace(attr_name_, default_value_setter_[0]()); @@ -354,21 +367,32 @@ class TypedAttrChecker { return; } - auto it = attr_map->find(attr_name_); - if (it == attr_map->end()) { - // user do not set this attr - PADDLE_ENFORCE_EQ( - default_value_setter_.empty(), false, - platform::errors::InvalidArgument( - "Attribute (%s) is not set correctly.", attr_name_)); - // default_value_setter_ has no more than one element - attr_map->emplace(attr_name_, default_value_setter_[0]()); - } - it = attr_map->find(attr_name_); - ExtractAttribute extract_attr(attr_name_); - T* attr_value = extract_attr(it->second); - for (const auto& checker : value_checkers_) { - checker(*attr_value); + if (only_check_exist_value) { + auto it = attr_map->find(attr_name_); + if (it != attr_map->end()) { + ExtractAttribute extract_attr(attr_name_); + T* attr_value = extract_attr(it->second); + for (const auto& checker : value_checkers_) { + checker(*attr_value); + } + } + } else { + auto it = attr_map->find(attr_name_); + if (it == attr_map->end()) { + // user do not set this attr + PADDLE_ENFORCE_EQ( + default_value_setter_.empty(), false, + platform::errors::InvalidArgument( + "Attribute (%s) is not set correctly.", attr_name_)); + // default_value_setter_ has no more than one element + auto tmp = attr_map->emplace(attr_name_, default_value_setter_[0]()); + it = tmp.first; + } + ExtractAttribute extract_attr(attr_name_); + T* attr_value = extract_attr(it->second); + for (const auto& checker : value_checkers_) { + checker(*attr_value); + } } } @@ -380,7 +404,7 @@ class TypedAttrChecker { // check whether op's all attributes fit their own limits class OpAttrChecker { - typedef std::function AttrChecker; + typedef std::function AttrChecker; public: template @@ -390,18 +414,19 @@ class OpAttrChecker { return *(checker.target>()); } - void Check(AttributeMap* attr_map, bool explicit_only = false) const { + void Check(AttributeMap* attr_map, bool explicit_only = false, + bool only_check_exist_value = false) const { auto checker_num = attr_checkers_.size(); if (explicit_only) checker_num = explicit_checker_num_; for (size_t i = 0; i < checker_num; ++i) { - attr_checkers_[i](attr_map, false); + attr_checkers_[i](attr_map, false, only_check_exist_value); } } - AttributeMap GetAttrsDefaultValuesMap() const { + AttributeMap GetDefaultAttrsMap() const { AttributeMap default_values_map; for (const auto& checker : attr_checkers_) { - checker(&default_values_map, true); + checker(&default_values_map, true, false); } return default_values_map; } @@ -410,15 +435,26 @@ class OpAttrChecker { explicit_checker_num_ = attr_checkers_.size(); } + void InitDefaultAttributeMap() { + for (const auto& checker : attr_checkers_) { + checker(&default_attrs_, true, false); + } + } + + const AttributeMap& GetDefaultAttrMap() const { return default_attrs_; } + private: std::vector attr_checkers_; + AttributeMap default_attrs_; + // in order to improve the efficiency of dynamic graph mode, // we divede the attribute into explicit type and implicit type. // for explicit attribute, we mean the attribute added in the customized // op makers, usually it's defined in the overloaded Make method. // for implicit attribute, we mean the attribute added outside of the Make - // method like "op_role", "op_role_var", and they are useless in dynamic graph + // method like "op_role", "op_role_var", and they are useless in dynamic + // graph // mode size_t explicit_checker_num_; }; diff --git a/paddle/fluid/framework/custom_operator.cc b/paddle/fluid/framework/custom_operator.cc index c4b833ec94c29..b1c5ff86d1979 100644 --- a/paddle/fluid/framework/custom_operator.cc +++ b/paddle/fluid/framework/custom_operator.cc @@ -781,10 +781,12 @@ void RegisterOperatorWithMetaInfo( const imperative::NameVarBaseMap& var_base_map_in, const imperative::NameVarBaseMap& var_base_map_out, const framework::AttributeMap& attrs, + const framework::AttributeMap& default_attrs, const std::map& inplace_map) { CustomGradOpMaker maker( type, var_base_map_in, var_base_map_out, attrs, inplace_map, grad_op_name, grad_op_inputs, grad_op_outputs); + maker.SetDygraphDefaultAttrsMap(default_attrs); return maker(); }; diff --git a/paddle/fluid/framework/details/op_registry.h b/paddle/fluid/framework/details/op_registry.h index df5370e42ee9f..27f55e237f516 100644 --- a/paddle/fluid/framework/details/op_registry.h +++ b/paddle/fluid/framework/details/op_registry.h @@ -249,8 +249,10 @@ struct OpInfoFiller { const imperative::NameVarBaseMap& var_base_map_in, const imperative::NameVarBaseMap& var_base_map_out, const framework::AttributeMap& attrs, + const framework::AttributeMap& default_attrs, const std::map& inplace_map) { T maker(type, var_base_map_in, var_base_map_out, attrs, inplace_map); + maker.SetDygraphDefaultAttrsMap(default_attrs); return maker(); }; } diff --git a/paddle/fluid/framework/grad_op_desc_maker.h b/paddle/fluid/framework/grad_op_desc_maker.h index b0247fe795b3e..ebbfd446a03de 100644 --- a/paddle/fluid/framework/grad_op_desc_maker.h +++ b/paddle/fluid/framework/grad_op_desc_maker.h @@ -219,6 +219,19 @@ class SingleGradOpMaker public: using GradOpBaseMakerBase::GradOpBaseMakerBase; + virtual const framework::Attribute& GetAttr(const std::string& name) const { + auto it = Attrs().find(name); + if (it == Attrs().end()) { + it = this->DefaultAttrsMap().find(name); + PADDLE_ENFORCE_EQ(it != this->DefaultAttrsMap().end(), true, + platform::errors::NotFound( + "Cannot find attribute [%s] in operator [%s]", name, + this->ForwardOpType())); + } + + return it->second; + } + std::shared_ptr operator()() const final { auto node = this->NewGradNode(); auto& inplace_map = this->GetInplaceMap(); @@ -228,6 +241,7 @@ class SingleGradOpMaker { imperative::TracedGradOp traced_grad_op(node); try { + traced_grad_op.SetDefaultAttrsMap(this->DefaultAttrsMap()); this->Apply(&traced_grad_op); } catch (platform::EnforceNotMet& exception) { framework::AppendErrorOpHint(traced_grad_op.Type(), &exception); diff --git a/paddle/fluid/framework/ir/op_compat_sensible_pass.cc b/paddle/fluid/framework/ir/op_compat_sensible_pass.cc index cbb12839362f3..56637a6c7b2b3 100644 --- a/paddle/fluid/framework/ir/op_compat_sensible_pass.cc +++ b/paddle/fluid/framework/ir/op_compat_sensible_pass.cc @@ -61,7 +61,7 @@ AttrCompat& AttrCompat::IsLeftDefault() { return *this; } const OpInfo& op_info = OpInfoMap::Instance().Get(op_name); - const AttributeMap attrs = op_info.Checker()->GetAttrsDefaultValuesMap(); + const AttributeMap attrs = op_info.Checker()->GetDefaultAttrsMap(); if (attrs.find(attr_name_) == attrs.end()) { LOG(WARNING) << "Op (" << op_name << ") has no default attr:" << attr_name_; conditions_.emplace_back([](const Attribute& attr) { return false; }); diff --git a/paddle/fluid/framework/op_proto_maker.cc b/paddle/fluid/framework/op_proto_maker.cc index 0b9fd0a47e22c..8fbea51584d3c 100644 --- a/paddle/fluid/framework/op_proto_maker.cc +++ b/paddle/fluid/framework/op_proto_maker.cc @@ -66,6 +66,7 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto, op_checker_ = attr_checker; Make(); op_checker_->RecordExplicitCheckerNum(); + op_checker_->InitDefaultAttributeMap(); AddAttr(OpRoleAttrName(), "The role of this operator") .InEnum( diff --git a/paddle/fluid/framework/type_defs.h b/paddle/fluid/framework/type_defs.h index e43cccfe64816..951daea47bde3 100644 --- a/paddle/fluid/framework/type_defs.h +++ b/paddle/fluid/framework/type_defs.h @@ -71,6 +71,7 @@ using DygraphGradOpMakerFN = const imperative::NameVarBaseMap& /*var_base_map_in*/, const imperative::NameVarBaseMap& /*var_base_map_out*/, const framework::AttributeMap& /*attributes*/, + const framework::AttributeMap& /*default attributes*/, const std::map& /*inplace_map*/)>; using InferVarTypeFN = diff --git a/paddle/fluid/imperative/basic_engine.cc b/paddle/fluid/imperative/basic_engine.cc index 7bcc3d6c608c9..84ee1fbe5df96 100644 --- a/paddle/fluid/imperative/basic_engine.cc +++ b/paddle/fluid/imperative/basic_engine.cc @@ -474,10 +474,11 @@ void BasicEngine::Execute() { try { if (tmp_ins_ptr == nullptr) { OpBase::Run(cur_op.InnerOp(), bwd_ins, tmp_outs, cur_op.Attrs(), - cur_op.place()); + cur_op.DefaultAttrsMap(), cur_op.place()); } else { OpBase::Run(cur_op.InnerOp(), *tmp_ins_ptr, tmp_outs, - cur_op.Attrs(), cur_op.place()); + cur_op.Attrs(), cur_op.DefaultAttrsMap(), + cur_op.place()); } } catch (platform::EnforceNotMet& exception) { Clear(); diff --git a/paddle/fluid/imperative/dygraph_grad_maker.h b/paddle/fluid/imperative/dygraph_grad_maker.h index 7fefc9ccc67b5..f1eb8aa62c927 100644 --- a/paddle/fluid/imperative/dygraph_grad_maker.h +++ b/paddle/fluid/imperative/dygraph_grad_maker.h @@ -113,9 +113,18 @@ class GradOpBaseMakerBase { return vec_temp; } + // Only for dygraph + void SetDygraphDefaultAttrsMap(const framework::AttributeMap& default_attrs) { + default_attrs_ = &default_attrs; + } + + const framework::AttributeMap& DefaultAttrsMap() const { + return *default_attrs_; + } + const framework::AttributeMap& Attrs() const { return attrs_; } - const framework::Attribute& GetAttr(const std::string& name) const { + virtual const framework::Attribute& GetAttr(const std::string& name) const { auto it = attrs_.find(name); PADDLE_ENFORCE_EQ( it != attrs_.end(), true, @@ -199,6 +208,7 @@ class GradOpBaseMakerBase { const NameVarBaseMap& var_base_map_in_; const NameVarBaseMap& var_base_map_out_; const framework::AttributeMap& attrs_; + const framework::AttributeMap* default_attrs_; const std::map& inplace_map_; }; @@ -285,6 +295,10 @@ class TracedGradOp { return op_->SetAttrMap(attrs); } + void SetDefaultAttrsMap(const framework::AttributeMap& attrs) { + return op_->SetDefaultAttrsMap(attrs); + } + void SetAttr(const std::string& name, const framework::Attribute& v) { op_->SetAttr(name, v); } diff --git a/paddle/fluid/imperative/execution_context.h b/paddle/fluid/imperative/execution_context.h index 398b1292e2ffe..5446add86788b 100644 --- a/paddle/fluid/imperative/execution_context.h +++ b/paddle/fluid/imperative/execution_context.h @@ -35,11 +35,13 @@ class DygraphExecutionContext : public framework::ExecutionContext { const framework::RuntimeContext& ctx, const NameVarMap& var_base_map_in, const NameVarMap& var_base_map_out, - const framework::AttributeMap& attrs) + const framework::AttributeMap& attrs, + const framework::AttributeMap& default_attrs) : ExecutionContext(op, scope, device_context, ctx), var_base_map_in_(var_base_map_in), var_base_map_out_(var_base_map_out), - attrs_(attrs) {} + attrs_(attrs), + default_attrs_(default_attrs) {} std::string InputName(const std::string& name) const override { auto it = var_base_map_in_.find(name); @@ -92,7 +94,7 @@ class DygraphExecutionContext : public framework::ExecutionContext { } bool HasAttr(const std::string& name) const override { - return attrs_.count(name) != 0; + return attrs_.count(name) != 0 || default_attrs_.count(name) != 0; } const framework::AttributeMap& Attrs() const override { return attrs_; } @@ -100,9 +102,14 @@ class DygraphExecutionContext : public framework::ExecutionContext { const framework::Attribute& GetAttr(const std::string& name) const override { auto it = attrs_.find(name); - PADDLE_ENFORCE_NE( - it, attrs_.end(), - platform::errors::NotFound("can not find [%s] in attrs", name)); + if (it == attrs_.end()) { + it = default_attrs_.find(name); + if (it == default_attrs_.end()) { + PADDLE_THROW(platform::errors::NotFound( + "Can not find [%s] in attributes of op %s.", name, + this->GetOp().Type())); + } + } return it->second; } @@ -192,6 +199,7 @@ class DygraphExecutionContext : public framework::ExecutionContext { const NameVarMap& var_base_map_in_; const NameVarMap& var_base_map_out_; const framework::AttributeMap& attrs_; + const framework::AttributeMap& default_attrs_; }; } // namespace imperative diff --git a/paddle/fluid/imperative/infer_shape_context.h b/paddle/fluid/imperative/infer_shape_context.h index fcd4545a2c82d..7efe1177f5dc7 100644 --- a/paddle/fluid/imperative/infer_shape_context.h +++ b/paddle/fluid/imperative/infer_shape_context.h @@ -35,10 +35,12 @@ class DygraphInferShapeContext : public framework::InferShapeContext { DygraphInferShapeContext(const NameVarMap* in, const NameVarMap* out, const framework::AttributeMap* attr, + const framework::AttributeMap* default_attr, const std::string op_type) : var_base_map_in_(in), var_base_map_out_(out), attrs_(attr), + default_attrs_(default_attr), op_type_(op_type) {} bool HasInput(const std::string& name) const override { @@ -101,7 +103,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext { } framework::AttrReader Attrs() const override { - return framework::AttrReader(*attrs_); + return framework::AttrReader(*attrs_, *default_attrs_); } std::vector Inputs(const std::string& name) const override { @@ -395,6 +397,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext { const NameVarMap* var_base_map_in_; const NameVarMap* var_base_map_out_; const framework::AttributeMap* attrs_; + const framework::AttributeMap* default_attrs_; const std::string op_type_; }; diff --git a/paddle/fluid/imperative/infer_var_type_context.h b/paddle/fluid/imperative/infer_var_type_context.h index f740507fa5086..7defc339f4f81 100644 --- a/paddle/fluid/imperative/infer_var_type_context.h +++ b/paddle/fluid/imperative/infer_var_type_context.h @@ -32,20 +32,28 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext { public: RuntimeInferVarTypeContext(const NameVarMap& inputs, const NameVarMap& outputs, - const framework::AttributeMap& attrs_map) + const framework::AttributeMap& attrs_map, + const framework::AttributeMap& default_attrs_map) : InferVarTypeContext(nullptr, nullptr), inputs_(inputs), outputs_(outputs), - attrs_(attrs_map) {} + attrs_(attrs_map), + default_attrs_(default_attrs_map) {} virtual ~RuntimeInferVarTypeContext() {} framework::Attribute GetAttr(const std::string& name) const override { - auto iter = attrs_.find(name); - PADDLE_ENFORCE_EQ( - iter != attrs_.end(), true, - platform::errors::NotFound("Cannot find attribute %s", name)); - return iter->second; + auto it = attrs_.find(name); + + if (it == attrs_.end()) { + it = default_attrs_.find(name); + if (it == default_attrs_.end()) { + PADDLE_THROW(platform::errors::NotFound( + "Can not find [%s] in attributes.", name)); + } + } + + return it->second; } bool HasInput(const std::string& name) const override { @@ -233,6 +241,7 @@ class RuntimeInferVarTypeContext : public framework::InferVarTypeContext { const NameVarMap& inputs_; const NameVarMap& outputs_; const framework::AttributeMap& attrs_; + const framework::AttributeMap& default_attrs_; }; } // namespace imperative diff --git a/paddle/fluid/imperative/layer.cc b/paddle/fluid/imperative/layer.cc index a4af3117d3e32..6e28ecd9971ab 100644 --- a/paddle/fluid/imperative/layer.cc +++ b/paddle/fluid/imperative/layer.cc @@ -329,6 +329,7 @@ static void OpBaseRunImpl(const framework::OperatorBase& op, const NameVarMap& ins, const NameVarMap& outs, const framework::AttributeMap& attrs, + const framework::AttributeMap& default_attrs, const platform::Place& place) { auto* op_kernel = dynamic_cast(&op); PADDLE_ENFORCE_NOT_NULL( @@ -336,7 +337,8 @@ static void OpBaseRunImpl(const framework::OperatorBase& op, "Only support operator with kernel in Dygraph mode.")); auto& info = op.Info(); if (info.infer_var_type_) { - RuntimeInferVarTypeContext infer_var_type_ctx(ins, outs, attrs); + RuntimeInferVarTypeContext infer_var_type_ctx(ins, outs, attrs, + default_attrs); info.infer_var_type_(&infer_var_type_ctx); } @@ -369,13 +371,14 @@ static void OpBaseRunImpl(const framework::OperatorBase& op, * after the execution of op, but the original input is directly * overwritten in the previous dynamic graph implemention. */ - auto prepared_op = PreparedOp::Prepare(ins, outs, *op_kernel, place, attrs); + auto prepared_op = + PreparedOp::Prepare(ins, outs, *op_kernel, place, attrs, default_attrs); auto tmp_ins_ptr = PrepareData(*op_kernel, ins, prepared_op.kernel_type()); if (tmp_ins_ptr == nullptr) { - prepared_op.Run(ins, outs, attrs); + prepared_op.Run(ins, outs, attrs, default_attrs); } else { - prepared_op.Run(*tmp_ins_ptr, outs, attrs); + prepared_op.Run(*tmp_ins_ptr, outs, attrs, default_attrs); } VLOG(4) << LayerDebugString(op.Type(), ins, outs); @@ -395,16 +398,18 @@ void OpBase::Run(const framework::OperatorBase& op, const NameVarMap& ins, const NameVarMap& outs, const framework::AttributeMap& attrs, + const framework::AttributeMap& default_attrs, const platform::Place& place) { - OpBaseRunImpl(op, ins, outs, attrs, place); + OpBaseRunImpl(op, ins, outs, attrs, default_attrs, place); } void OpBase::Run(const framework::OperatorBase& op, const NameVarMap& ins, const NameVarMap& outs, const framework::AttributeMap& attrs, + const framework::AttributeMap& default_attrs, const platform::Place& place) { - OpBaseRunImpl(op, ins, outs, attrs, place); + OpBaseRunImpl(op, ins, outs, attrs, default_attrs, place); } void ClearNoNeedBufferInputs(OpBase* op) { @@ -446,15 +451,15 @@ void ClearNoNeedBufferInputs(OpBase* op) { std::shared_ptr CreateGradOpNode( const framework::OperatorBase& op, const NameVarBaseMap& ins, const NameVarBaseMap& outs, const framework::AttributeMap& attrs, - const platform::Place& place, + const framework::AttributeMap& default_attrs, const platform::Place& place, const std::map& inplace_map) { const auto& info = op.Info(); if (!info.dygraph_grad_op_maker_) { return nullptr; } - auto grad_node = - info.dygraph_grad_op_maker_(op.Type(), ins, outs, attrs, inplace_map); + auto grad_node = info.dygraph_grad_op_maker_(op.Type(), ins, outs, attrs, + default_attrs, inplace_map); if (grad_node && !grad_node->empty()) { for (auto& grad_op : *grad_node) { grad_op.SetId(OpBase::GenerateUniqueId()); diff --git a/paddle/fluid/imperative/layer.h b/paddle/fluid/imperative/layer.h index bbede47e36429..56e16ba199707 100644 --- a/paddle/fluid/imperative/layer.h +++ b/paddle/fluid/imperative/layer.h @@ -108,7 +108,7 @@ class VarBase { void ClearGradVarBase() { grad_var_ = nullptr; } - void SetGradVarBase(VarBase& grad_var) { + void SetGradVarBase(const VarBase& grad_var) { MutableGradVarBase()->CopyFrom(grad_var, true); } @@ -283,7 +283,7 @@ class Layer { std::shared_ptr CreateGradOpNode( const framework::OperatorBase& op, const NameVarBaseMap& ins, const NameVarBaseMap& outs, const framework::AttributeMap& attrs, - const platform::Place& place, + const framework::AttributeMap& default_attrs, const platform::Place& place, const std::map& inplace_map); void ClearNoNeedBufferInputs(OpBase* op); diff --git a/paddle/fluid/imperative/op_base.h b/paddle/fluid/imperative/op_base.h index 0164ff9313cdf..acb125a82925d 100644 --- a/paddle/fluid/imperative/op_base.h +++ b/paddle/fluid/imperative/op_base.h @@ -50,6 +50,10 @@ class OpBase { const framework::AttributeMap& Attrs() const { return attrs_; } + const framework::AttributeMap& DefaultAttrsMap() const { + return *default_attrs_; + } + const framework::OpInfo& Info() const { PADDLE_ENFORCE_NOT_NULL(op_, platform::errors::PreconditionNotMet( "OpBase::Info() should be called after " @@ -99,6 +103,10 @@ class OpBase { void SetAttrMap(const framework::AttributeMap& attrs) { attrs_ = attrs; } + void SetDefaultAttrsMap(const framework::AttributeMap& default_attrs) { + default_attrs_ = &default_attrs; + } + void SetAttr(const std::string& name, const framework::Attribute& v) { attrs_[name] = v; } @@ -110,14 +118,23 @@ class OpBase { const framework::AttributeMap& Attrs() { return attrs_; } - bool HasAttr(const std::string& name) const { return attrs_.count(name) > 0; } + const framework::AttributeMap& DefaultAttrsMap() { return *default_attrs_; } + + bool HasAttr(const std::string& name) const { + return attrs_.count(name) > 0 || default_attrs_->count(name) > 0; + } const framework::Attribute& GetAttr(const std::string& name) const { auto it = attrs_.find(name); - PADDLE_ENFORCE_NE( - it, attrs_.end(), - platform::errors::NotFound("can not find attribute [%s]", name)); - return it->second; + if (it != attrs_.end()) { + return it->second; + } else { + auto it_default = default_attrs_->find(name); + PADDLE_ENFORCE_NE( + it_default, default_attrs_->end(), + platform::errors::NotFound("can not find attribute [%s]", name)); + return it_default->second; + } } template @@ -156,12 +173,14 @@ class OpBase { const NameVarMap& ins, const NameVarMap& outs, const framework::AttributeMap& attrs, + const framework::AttributeMap& default_attrs, const platform::Place& place); static void Run(const framework::OperatorBase& op, const NameVarMap& ins, const NameVarMap& outs, const framework::AttributeMap& attrs, + const framework::AttributeMap& default_attrs, const platform::Place& place); private: @@ -174,6 +193,7 @@ class OpBase { NameVarMap ins_; NameVarMap outs_; framework::AttributeMap attrs_; + const framework::AttributeMap* default_attrs_; std::unique_ptr op_; platform::Place place_; size_t id_{-1UL}; diff --git a/paddle/fluid/imperative/partial_grad_engine.cc b/paddle/fluid/imperative/partial_grad_engine.cc index 3da3a05ed1071..d905b1350821c 100644 --- a/paddle/fluid/imperative/partial_grad_engine.cc +++ b/paddle/fluid/imperative/partial_grad_engine.cc @@ -884,11 +884,13 @@ void PartialGradTask::RunEachOp(OpBase *op) { } // Run op - OpBase::Run(op->InnerOp(), tmp_ins, tmp_outs, op->Attrs(), op->place()); + OpBase::Run(op->InnerOp(), tmp_ins, tmp_outs, op->Attrs(), + op->DefaultAttrsMap(), op->place()); if (create_graph_) { - auto double_grad_node = CreateGradOpNode(op->InnerOp(), tmp_ins, tmp_outs, - op->Attrs(), op->place(), {}); + auto double_grad_node = + CreateGradOpNode(op->InnerOp(), tmp_ins, tmp_outs, op->Attrs(), + op->DefaultAttrsMap(), op->place(), {}); PADDLE_ENFORCE_NOT_NULL( double_grad_node, platform::errors::NotFound("The Op %s doesn't have any grad op. If you " diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index 4a42751b1c4d5..6bdb042ebd557 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -91,7 +91,8 @@ PreparedOp PrepareImpl(const NameVarMap& ins, const NameVarMap& outs, const framework::OperatorWithKernel& op, const platform::Place& place, - const framework::AttributeMap& attrs) { + const framework::AttributeMap& attrs, + const framework::AttributeMap& default_attrs) { platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto* dev_ctx = pool.Get(place); @@ -108,9 +109,9 @@ PreparedOp PrepareImpl(const NameVarMap& ins, #endif // 1. get expected kernel key - auto expected_kernel_key = - op.GetExpectedKernelType(DygraphExecutionContext( - op, framework::Scope(), *dev_ctx, ctx, ins, outs, attrs)); + auto expected_kernel_key = op.GetExpectedKernelType( + DygraphExecutionContext(op, framework::Scope(), *dev_ctx, ctx, + ins, outs, attrs, default_attrs)); VLOG(3) << "expected_kernel_key:" << expected_kernel_key; // 2. check if op[type] has kernel registered. @@ -148,16 +149,19 @@ PreparedOp PreparedOp::Prepare(const NameVarMap& ins, const NameVarMap& outs, const framework::OperatorWithKernel& op, const platform::Place& place, - const framework::AttributeMap& attrs) { - return PrepareImpl(ins, outs, op, place, attrs); + const framework::AttributeMap& attrs, + const framework::AttributeMap& default_attrs) { + return PrepareImpl(ins, outs, op, place, attrs, default_attrs); } PreparedOp PreparedOp::Prepare(const NameVarMap& ins, const NameVarMap& outs, const framework::OperatorWithKernel& op, const platform::Place& place, - const framework::AttributeMap& attrs) { - return PrepareImpl(ins, outs, op, place, attrs); + const framework::AttributeMap& attrs, + const framework::AttributeMap& default_attrs) { + return PrepareImpl(ins, outs, op, place, attrs, + default_attrs); } template @@ -166,17 +170,18 @@ static void PreparedOpRunImpl( const framework::OpKernelType& kernel_type, const framework::OperatorWithKernel::OpKernelFunc& func, platform::DeviceContext* dev_ctx, const NameVarMap& ins, - const NameVarMap& outs, const framework::AttributeMap& attrs) { + const NameVarMap& outs, const framework::AttributeMap& attrs, + const framework::AttributeMap& default_attrs) { // TODO(zjl): remove scope in dygraph framework::Scope scope; DygraphInferShapeContext infer_shape_ctx(&ins, &outs, &attrs, - op.Type()); + &default_attrs, op.Type()); static_cast(op).InferShape( &infer_shape_ctx); func(DygraphExecutionContext(op, scope, *dev_ctx, ctx, ins, outs, - attrs)); + attrs, default_attrs)); if (FLAGS_check_nan_inf) { framework::details::CheckOpHasNanOrInfInDygraph( @@ -202,16 +207,18 @@ static void PreparedOpRunImpl( void PreparedOp::Run(const NameVarMap& ins, const NameVarMap& outs, - const framework::AttributeMap& attrs) { + const framework::AttributeMap& attrs, + const framework::AttributeMap& default_attrs) { PreparedOpRunImpl(op_, ctx_, kernel_type_, func_, dev_ctx_, ins, - outs, attrs); + outs, attrs, default_attrs); } void PreparedOp::Run(const NameVarMap& ins, const NameVarMap& outs, - const framework::AttributeMap& attrs) { + const framework::AttributeMap& attrs, + const framework::AttributeMap& default_attrs) { PreparedOpRunImpl(op_, ctx_, kernel_type_, func_, dev_ctx_, - ins, outs, attrs); + ins, outs, attrs, default_attrs); } } // namespace imperative diff --git a/paddle/fluid/imperative/prepared_operator.h b/paddle/fluid/imperative/prepared_operator.h index 1f6be5483be30..53f876c498cd0 100644 --- a/paddle/fluid/imperative/prepared_operator.h +++ b/paddle/fluid/imperative/prepared_operator.h @@ -151,20 +151,24 @@ class PreparedOp { const NameVarMap& outs, const framework::OperatorWithKernel& op, const platform::Place& place, - const framework::AttributeMap& attrs); + const framework::AttributeMap& attrs, + const framework::AttributeMap& default_attrs); static PreparedOp Prepare(const NameVarMap& ins, const NameVarMap& outs, const framework::OperatorWithKernel& op, const platform::Place& place, - const framework::AttributeMap& attrs); + const framework::AttributeMap& attrs, + const framework::AttributeMap& default_attrs); void Run(const NameVarMap& in, const NameVarMap& out, - const framework::AttributeMap& attrs); + const framework::AttributeMap& attrs, + const framework::AttributeMap& default_attrs); void Run(const NameVarMap& ins, const NameVarMap& outs, - const framework::AttributeMap& attrs); + const framework::AttributeMap& attrs, + const framework::AttributeMap& default_attrs); const framework::OpKernelType& kernel_type() const { return kernel_type_; } diff --git a/paddle/fluid/imperative/tests/test_layer.cc b/paddle/fluid/imperative/tests/test_layer.cc index 4a30ffb7e3d01..064f47f54979a 100644 --- a/paddle/fluid/imperative/tests/test_layer.cc +++ b/paddle/fluid/imperative/tests/test_layer.cc @@ -43,10 +43,12 @@ template class TestRuntimeInferVarTypeContext : public RuntimeInferVarTypeContext { public: - TestRuntimeInferVarTypeContext(const NameVarMap& inputs, - const NameVarMap& outputs, - const framework::AttributeMap& attrs_map) - : RuntimeInferVarTypeContext(inputs, outputs, attrs_map) {} + TestRuntimeInferVarTypeContext( + const NameVarMap& inputs, const NameVarMap& outputs, + const framework::AttributeMap& attrs_map, + const framework::AttributeMap& default_attrs_map) + : RuntimeInferVarTypeContext(inputs, outputs, attrs_map, + default_attrs_map) {} bool HasVar(const std::string& name) const { return RuntimeInferVarTypeContext::HasVar(name); @@ -125,7 +127,7 @@ TEST(test_layer, test_runtime_context) { auto* ctx = new imperative::TestRuntimeInferVarTypeContext( - ins, outs, attrs); + ins, outs, attrs, {}); ASSERT_TRUE(ctx->HasInput("X")); ASSERT_TRUE(ctx->HasOutput("Out")); @@ -358,7 +360,7 @@ TEST(test_layer, test_dygraph_execution_context) { framework::Scope scope; DygraphExecutionContext dy_exe_context( - *(op.get()), scope, *dev_ctx, ctx, ins, outs, concat_att_map); + *(op.get()), scope, *dev_ctx, ctx, ins, outs, concat_att_map, {}); ASSERT_EQ(dy_exe_context.InputSize("X"), 1u); ASSERT_EQ(dy_exe_context.InputName("X"), "vin"); @@ -386,7 +388,7 @@ TEST(test_layer, test_dygraph_infershape_context) { concat_att_map["axis"] = 1; DygraphInferShapeContext infer_shape_ctx( - &ins, &outs, &concat_att_map, "dummy"); + &ins, &outs, &concat_att_map, {}, "dummy"); bool have_x = infer_shape_ctx.HasOutputs("Out"); ASSERT_EQ(have_x, true); diff --git a/paddle/fluid/imperative/tests/test_prepare_op.cc b/paddle/fluid/imperative/tests/test_prepare_op.cc index 7d6882a4ee7d0..5e269d74044d2 100644 --- a/paddle/fluid/imperative/tests/test_prepare_op.cc +++ b/paddle/fluid/imperative/tests/test_prepare_op.cc @@ -93,7 +93,7 @@ TEST(test_prepare_op, test_prepare_op) { ASSERT_NO_FATAL_FAILURE(PreparedOp preparedOp = PreparedOp::Prepare( ins, outs, dynamic_cast(*op), - place, split_attr_map)); + place, split_attr_map, {})); } const framework::Tensor* GetTensorFromVar(const framework::Variable& var); @@ -144,7 +144,7 @@ TEST(test_prepare_op, test_prepare_data) { // test if it can be transformed to GPU place auto prepared_op = PreparedOp::Prepare( ins, outs, dynamic_cast(*op), gpu_place, - attr_map); + attr_map, {}); PrepareData( dynamic_cast(*op), ins, prepared_op.kernel_type()); @@ -193,7 +193,7 @@ void TestPrepareDataSamePlace(framework::AttributeMap attr_map) { // test if it never transferred on GPU place auto prepared_op = PreparedOp::Prepare( ins, outs, dynamic_cast(*op), cpu_place, - attr_map); + attr_map, {}); PrepareData( dynamic_cast(*op), ins, prepared_op.kernel_type()); diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc index 41ad70e5a5741..367f948ef63b2 100644 --- a/paddle/fluid/imperative/tracer.cc +++ b/paddle/fluid/imperative/tracer.cc @@ -154,9 +154,14 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins, const auto& op_info = op->Info(); auto* attr_checker = op_info.Checker(); if (attr_checker) { - attr_checker->Check(&attrs, true); + attr_checker->Check(&attrs, true, /*only_check_exist_value=*/true); } + static paddle::framework::AttributeMap empty_attrs_map = {}; + const paddle::framework::AttributeMap& default_attrs = + attr_checker == nullptr ? empty_attrs_map + : attr_checker->GetDefaultAttrMap(); + NameVarBaseMap new_ins = ins; if (enable_autocast_) { VLOG(5) << "Auto mixed precision run operator: " << type; @@ -181,7 +186,7 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins, #endif } - OpBase::Run(*op, new_ins, outs, attrs, place); + OpBase::Run(*op, new_ins, outs, attrs, default_attrs, place); } catch (platform::EnforceNotMet& exception) { framework::AppendErrorOpHint(type, &exception); throw std::move(exception); @@ -204,7 +209,8 @@ void Tracer::TraceOp(const std::string& type, const NameVarBaseMap& ins, } if (ComputeRequiredGrad(new_ins, outs, trace_backward)) { - CreateGradOpNode(*op, new_ins, outs, attrs, place, inplace_map); + CreateGradOpNode(*op, new_ins, outs, attrs, default_attrs, place, + inplace_map); } else { VLOG(3) << "No Grad to track for Op: " << type; } diff --git a/paddle/fluid/operators/test_common_infer_shape_functions.cc b/paddle/fluid/operators/test_common_infer_shape_functions.cc index ca8f6ce84fc57..60eeb66ae7d1e 100644 --- a/paddle/fluid/operators/test_common_infer_shape_functions.cc +++ b/paddle/fluid/operators/test_common_infer_shape_functions.cc @@ -48,7 +48,7 @@ class DygraphInferShapeTest { void SetOpType(const std::string& op_type) { op_type_ = op_type; } void Run(std::function infer_shape) { imperative::DygraphInferShapeContext ctx( - &ins_, &outs_, &attrs_, op_type_); + &ins_, &outs_, &attrs_, {}, op_type_); infer_shape(&ctx); for (const auto& pair : expected_dims_) { auto out = outs_[pair.first][0]; diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 86084297c4ae6..67f004e61cbfd 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -1308,7 +1308,7 @@ All parameter, weight, gradient are variables in Paddle. if (info != nullptr) { if (info->HasOpProtoAndChecker()) { auto op_checker = info->Checker(); - res = op_checker->GetAttrsDefaultValuesMap(); + res = op_checker->GetDefaultAttrsMap(); } } return res;