Skip to content

Commit

Permalink
[NewIR] add stop_gradient attribute for defining op (PaddlePaddle#55235)
Browse files Browse the repository at this point in the history
* add stop_gradient attribute for defining op

* modify by reviews

* fix
  • Loading branch information
kangguangli authored and cqulilujia committed Jul 24, 2023
1 parent ccf7067 commit 131014b
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 14 deletions.
9 changes: 4 additions & 5 deletions paddle/fluid/ir_adaptor/translator/op_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -330,10 +330,6 @@ std::vector<ir::OpResult> OpTranscriber::GenerateOperationInput(

std::set<std::string> yaml_input_set;
for (const auto& info : input_infos) {
if (auto special_handler = this->GetSpecialInputHandlers(info.name)) {
continue;
}

std::string legacy_input_name =
op_normalizer.GetLegacyArgName(op_desc.Type(), info.name);

Expand Down Expand Up @@ -381,7 +377,6 @@ std::vector<ir::OpResult> OpTranscriber::GenerateOperationInput(

std::vector<std::string> legacy_input_vars;
// return empty OpResult if this arg is optional and not shown in OpDesc
// TODO(lyk): HasInput does not consider variadic attribute
if (op_desc.HasInput(legacy_input_name, true)) {
legacy_input_vars = op_desc.Input(legacy_input_name, true);
}
Expand Down Expand Up @@ -436,6 +431,10 @@ std::vector<ir::OpResult> OpTranscriber::GenerateOperationInput(

// if src type is Tensor
if (!is_vector) {
IR_ENFORCE(legacy_input_vars.size() == 1u,
"Input %s not found when parsing op %s",
info.name,
op_desc.Type());
auto defining_info = (*param_map)[legacy_input_vars[0]];
op_inputs.push_back(defining_info.value);

Expand Down
48 changes: 43 additions & 5 deletions paddle/fluid/ir_adaptor/translator/program_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "paddle/fluid/ir_adaptor/translator/type_translator.h"
#include "paddle/ir/core/attribute.h"
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/enforce.h"
Expand All @@ -38,17 +39,19 @@ using ProgramDesc = ::paddle::framework::ProgramDesc;
using BlockDesc = ::paddle::framework::BlockDesc;
using VarDesc = ::paddle::framework::VarDesc;

const std::unordered_set<std::string> ProgramTranslator::no_cast_var_names = {
"feed",
"fetch",
};

constexpr char kAttrStopGradients[] = "stop_gradient";

// Builds a translator that converts `legacy_program` (ProgramDesc form) into
// the new IR `program`. Both pointers are borrowed, not owned; callers must
// keep them alive for the translator's lifetime.
ProgramTranslator::ProgramTranslator(const ProgramDesc* legacy_program,
                                     ir::Program* program)
    : legacy_program_(legacy_program), program_(program) {
  // Cache the process-wide IrContext; it is used later when constructing
  // attributes (e.g. BoolAttribute/ArrayAttribute) during translation.
  ctx_ = ir::IrContext::Instance();
}

const std::unordered_set<std::string> ProgramTranslator::no_cast_var_names = {
"feed",
"fetch",
};

void ProgramTranslator::Translate() {
PADDLE_ENFORCE_EQ(
legacy_program_->Size(),
Expand All @@ -71,6 +74,11 @@ void ProgramTranslator::Translate() {
const BlockDesc& block = legacy_program_->Block(block_idx);
SetParameterFromSingleBlock(block);
}

for (size_t block_idx = 0; block_idx < legacy_program_->Size(); block_idx++) {
const BlockDesc& block = legacy_program_->Block(block_idx);
SetStopGradientAttributeForAllValue(block);
}
}

inline ir::Operation* InsertGetParamaterOp(ir::IrContext* ctx,
Expand Down Expand Up @@ -198,5 +206,35 @@ void ProgramTranslator::SetParameterFromSingleBlock(const BlockDesc& block) {
}
}

void ProgramTranslator::SetStopGradientAttributeForAllValue(
const BlockDesc& block) {
// Currently we set stop gradient for operation that generated a value
// connected with VarDesc
for (const auto& [var_name, value_info] : param_map_) {
VLOG(10) << "[op translated][stop gradient]" << var_name;
VarDesc* var = block.FindVarRecursive(var_name);
if (var == nullptr) {
continue;
}
ir::OpResult value = value_info.value;
auto* defining_op = value.owner();
VLOG(8) << "[op translated][stop gradient]" << var_name
<< " from: " << defining_op->name();
std::vector<ir::Attribute> stop_gradients;
if (defining_op->HasAttribute(kAttrStopGradients)) {
stop_gradients = defining_op->attribute(kAttrStopGradients)
.dyn_cast<ir::ArrayAttribute>()
.data();
} else {
stop_gradients = std::vector<ir::Attribute>(
defining_op->num_results(), ir::BoolAttribute::get(ctx_, false));
}
stop_gradients[value.GetResultIndex()] =
ir::BoolAttribute::get(ctx_, var->StopGradient());
defining_op->set_attribute(kAttrStopGradients,
ir::ArrayAttribute::get(ctx_, stop_gradients));
}
}

} // namespace translator
} // namespace paddle
3 changes: 2 additions & 1 deletion paddle/fluid/ir_adaptor/translator/program_translator.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,13 @@ class ProgramTranslator {
/// 2. "fetch", the output variable of fetch op
/// However, new feed has no input and new fetch has no output
/// So we don't handle these two variables when
/// `ExtractParameterFromSingleBlock`
/// `Get/SetParameterFromSingleBlock`
static const std::unordered_set<std::string> no_cast_var_names;

void GetParameterForSingleBlock(const BlockDesc& block);
void InsertOperationToSingleBlock(const BlockDesc& block);
void SetParameterFromSingleBlock(const BlockDesc& block);
void SetStopGradientAttributeForAllValue(const BlockDesc& block);
};

} // namespace translator
Expand Down
5 changes: 5 additions & 0 deletions paddle/ir/core/operation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,11 @@ std::string Operation::name() const {
return p_name ? p_name : "";
}

/// Returns the attribute stored under `key`.
/// Enforces (IR_ENFORCE) that the attribute exists, so callers never
/// receive a null Attribute.
Attribute Operation::attribute(const std::string &key) const {
  // Single map lookup instead of HasAttribute(key) followed by at(key),
  // which would search the map twice. Failure message is unchanged.
  auto iter = attributes_.find(key);
  IR_ENFORCE(iter != attributes_.end(),
             "operation(%s): no attribute %s",
             name(),
             key);
  return iter->second;
}

Region *Operation::GetParentRegion() const {
return parent_ ? parent_->GetParent() : nullptr;
}
Expand Down
8 changes: 7 additions & 1 deletion paddle/ir/core/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,16 @@ class IR_API alignas(8) Operation final {

// Read-only view of all attributes currently attached to this operation.
const AttributeMap &attributes() const { return attributes_; }

void SetAttribute(const std::string &key, Attribute value) {
void set_attribute(const std::string &key, Attribute value) {
attributes_[key] = value;
}

Attribute attribute(const std::string &key) const;

// True iff an attribute named `key` is present on this operation.
bool HasAttribute(const std::string &key) const {
  return attributes_.count(key) != 0;
}

ir::OpInfo info() const { return info_; }

uint32_t num_results() const { return num_results_; }
Expand Down
4 changes: 2 additions & 2 deletions test/cpp/ir/core/ir_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,6 @@ TEST(op_test, module_op_death) {
EXPECT_EQ(program.module_op().program(), &program);
EXPECT_EQ(program.module_op().ir_context(), ctx);

program.module_op()->SetAttribute("program",
ir::PointerAttribute::get(ctx, &program));
program.module_op()->set_attribute("program",
ir::PointerAttribute::get(ctx, &program));
}

0 comments on commit 131014b

Please sign in to comment.