Skip to content

Commit

Permalink
[NewIR] register set_value in new ir (PaddlePaddle#56436)
Browse files Browse the repository at this point in the history
* register set_value in new ir

* fix

* register set_value_grad

* fix

* fix

* remove debug info

* add unittest

* fix

* fix

* fix

* fix

* fix

* resolve comments
  • Loading branch information
kangguangli authored and BeingGod committed Sep 9, 2023
1 parent 957d59d commit 75669e3
Show file tree
Hide file tree
Showing 14 changed files with 431 additions and 69 deletions.
5 changes: 5 additions & 0 deletions paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#pragma once

#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute_storage.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h"
#include "paddle/ir/core/attribute.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/phi/common/scalar.h"
Expand Down Expand Up @@ -49,6 +50,10 @@ class ScalarAttribute : public ir::Attribute {
(val.type_id() == ir::StrAttribute::type_id());
}

static ir::Attribute get(ir::IrContext *ctx, phi::Scalar scalar) {
return TransToIrAttribute(scalar, ctx);
}

phi::Scalar data();
};

Expand Down
70 changes: 69 additions & 1 deletion paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,6 @@
view: null
backward: null


- name: shadow_feed
inputs:
- typename: Tensor
Expand All @@ -355,3 +354,72 @@
force_backend: null
inplace: null
backward: null

- name : set_value
inputs:
- {typename: Tensor, name: x, optional: false, no_need_buffer: false, data_transform: {} }
attrs:
- {typename: 'int64_t[]', name: starts}
- {typename: 'int64_t[]', name: ends}
- {typename: 'int64_t[]', name: steps}
- {typename: 'int64_t[]', name: axes}
- {typename: 'int64_t[]', name: decrease_axes}
- {typename: 'int64_t[]', name: none_axes}
- {typename: 'int64_t[]', name: shape}
- {typename: 'Scalar[]', name: values}
outputs:
- {typename: Tensor, name: out, optional: false, intermediate: false}
infer_meta:
func: SetValueInferMeta
param: [x]
kernel:
func: [set_value]
param: [x, starts, ends, steps, axes, decrease_axes, none_axes, shape, values]
inplace: {out: x}
backward: set_value_grad

- name : set_value_with_tensor
inputs:
- {typename: Tensor, name: x, optional: false, no_need_buffer: false, data_transform: {} }
- {typename: Tensor, name: values, optional: false, no_need_buffer: false, data_transform: {} }
attrs:
- {typename: 'int64_t[]', name: starts}
- {typename: 'int64_t[]', name: ends}
- {typename: 'int64_t[]', name: steps}
- {typename: 'int64_t[]', name: axes}
- {typename: 'int64_t[]', name: decrease_axes}
- {typename: 'int64_t[]', name: none_axes}
outputs:
- {typename: Tensor, name: out, optional: false, intermediate: false}
infer_meta:
func: SetValueInferMeta
param: [x]
kernel:
func: [set_value_with_tensor]
param: [x, values, starts, ends, steps, axes, decrease_axes, none_axes]
inplace: {out: x}
backward: set_value_grad


- name : set_value_grad
inputs:
- {typename: Tensor, name: out_grad, optional: false, no_need_buffer: false, data_transform: {} }
- {typename: Tensor, name: values, optional: false, no_need_buffer: false, data_transform: {} }
attrs:
- {typename: 'int64_t[]', name: starts}
- {typename: 'int64_t[]', name: ends}
- {typename: 'int64_t[]', name: steps}
- {typename: 'int64_t[]', name: axes}
- {typename: 'int64_t[]', name: decrease_axes}
- {typename: 'int64_t[]', name: none_axes}
outputs:
- {typename: Tensor, name: x_grad, optional: false, intermediate: false}
- {typename: Tensor, name: values_grad, optional: false, intermediate: false}
infer_meta:
func: SetValueGradInferMeta
param: [out_grad, values]
kernel:
func: [set_value_grad]
param: [out_grad, starts, ends, steps, axes, decrease_axes, none_axes]
inplace: null
backward: null
1 change: 1 addition & 0 deletions paddle/fluid/ir/dialect/paddle_dialect/utils/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

#include "paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h"

namespace paddle {
namespace dialect {
Expand Down
1 change: 0 additions & 1 deletion paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

// #include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type_storage.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_type.h"
Expand Down
5 changes: 3 additions & 2 deletions paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include "paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_parser.h"
#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_attribute.h"
#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_type.h"
#include "paddle/ir/core/type_name.h"
#include "paddle/phi/core/infermeta_utils.h"

#include "glog/logging.h"
Expand Down Expand Up @@ -81,8 +82,8 @@ void BuildPhiContext(ir::Operation* op,
Context* ctx) {
paddle::framework::Scope* inner_scope =
local_scope != nullptr ? local_scope : scope;
VLOG(6) << "BuildPhiContext in scope[" << scope << "] inner_scope["
<< inner_scope << "]";
VLOG(6) << "Build " << get_type_name<Context>() << " in scope[" << scope
<< "] inner_scope[" << inner_scope << "]";

auto attr_map = op->attributes();

Expand Down
16 changes: 13 additions & 3 deletions paddle/fluid/ir_adaptor/translator/attribute_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ class AttributeVisitor {
}

virtual ir::Attribute operator()(const std::vector<int64_t>& i64s) {
VLOG(10) << "translating vector<int64>";
VLOG(10) << "translating vector<int64> size: " << i64s.size();
std::vector<ir::Attribute> attrs;
attrs.reserve(i64s.size());
for (const auto& v : i64s) {
Expand All @@ -135,8 +135,13 @@ class AttributeVisitor {
virtual ir::Attribute operator()(
const std::vector<paddle::experimental::Scalar>& ss) {
VLOG(10) << "translating vector<scalar>";
IR_THROW(
"not support translating std::vector<paddle::experimental::Scalar>");
std::vector<ir::Attribute> attrs;
attrs.reserve(ss.size());
for (const auto& v : ss) {
attrs.push_back(dialect::ScalarAttribute::get(ctx, v));
}
VLOG(10) << "translating vector<scalar> Done";
return ir::ArrayAttribute::get(ctx, attrs);
}

virtual ir::Attribute operator()(const paddle::blank& blank) {
Expand Down Expand Up @@ -164,6 +169,11 @@ class Int64ArrayAttributeVisitor : public AttributeVisitor {
}
return ir::ArrayAttribute::get(ctx, attrs);
}

ir::Attribute operator()(const paddle::blank& blank) override {
VLOG(10) << "translating paddle::blank to int64[]";
return ir::ArrayAttribute::get(ctx, {});
}
};

class IntArrayAttributeVisitor : public AttributeVisitor {
Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/ir_adaptor/translator/op_compat_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ def insert_new_mutable_attributes(
backward_op, op_compat_item["scalar"]
)

# special mapping list
op_arg_name_mappings["set_value_grad"]["values_grad"] = "ValueTensor@GRAD"

op_name_normailzer_template = env.get_template("op_compat_info.cc.j2")
with open(output_source_file, 'wt') as f:
op_compat_definition = op_name_normailzer_template.render(
Expand Down
75 changes: 50 additions & 25 deletions paddle/fluid/ir_adaptor/translator/op_compat_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

#include <functional>
#include <optional>
#include <string>
#include <unordered_map>
#include <unordered_set>
Expand Down Expand Up @@ -75,42 +76,66 @@ class OpNameNormalizer {
return op_mutable_attribute_infos.at(op_type).at(arg_name);
}

std::optional<std::string> GetDirectMapping(const std::string& op_type,
const std::string& arg_name) {
if (op_arg_name_mappings.find(op_type) == op_arg_name_mappings.end()) {
return {};
}
auto& arg_mappings = op_arg_name_mappings[op_type];
if (arg_mappings.find(arg_name) == arg_mappings.end()) {
return {};
}
return arg_mappings.at(arg_name);
}

std::optional<std::string> GetGradNameMapping(const std::string& op_type,
const std::string& arg_name) {
std::string target = kPhiGradSuffix;
std::string data = kFluidVarGradSuffix;

size_t first_grad_pos = arg_name.find(target);
size_t type_pos = op_type.find(target);
std::string legacy_name = arg_name.substr(0, first_grad_pos);
std::optional<std::string> ret =
this->GetDirectMapping(op_type.substr(0, type_pos), legacy_name);
if (ret) {
legacy_name = ret.value();
}
legacy_name = legacy_name + arg_name.substr(first_grad_pos);
for (size_t pos = 0;
legacy_name.npos != (pos = legacy_name.find(target, pos));
pos += data.length()) {
legacy_name.replace(pos, target.length(), data);
}
return legacy_name;
}

std::string GetLegacyArgName(const std::string& op_type,
const std::string& arg_name) {
if (auto ret = GetDirectMapping(op_type, arg_name)) {
VLOG(10) << "[" << op_type << "] found " << ret.value();
return ret.value();
}

bool is_grad_op = (op_type.find(kPhiGradSuffix) != std::string::npos);
bool is_grad_arg = (arg_name.find(kPhiGradSuffix) != std::string::npos);

if (is_grad_op && is_grad_arg) {
std::string target = kPhiGradSuffix;
std::string data = kFluidVarGradSuffix;

size_t first_grad_pos = arg_name.find(target);
size_t type_pos = op_type.find(target);
std::string legacy_name = this->GetLegacyArgName(
op_type.substr(0, type_pos), arg_name.substr(0, first_grad_pos));
legacy_name += arg_name.substr(first_grad_pos);
for (size_t pos = 0;
legacy_name.npos != (pos = legacy_name.find(target, pos));
pos += data.length()) {
legacy_name.replace(pos, target.length(), data);
if (auto ret = GetGradNameMapping(op_type, arg_name)) {
VLOG(10) << "[" << op_type << "] found " << ret.value();
return ret.value();
}
return legacy_name;
} else if (is_grad_op && !is_grad_arg) {
// backwward op using forward args: like trace_grad using forward input
size_t type_pos = op_type.find(kPhiGradSuffix);
std::string legacy_name =
this->GetLegacyArgName(op_type.substr(0, type_pos), arg_name);

return legacy_name;
}
if (op_arg_name_mappings.find(op_type) == op_arg_name_mappings.end()) {
return arg_name;
}
auto& arg_mappings = op_arg_name_mappings[op_type];
if (arg_mappings.find(arg_name) == arg_mappings.end()) {
return arg_name;
if (auto ret = GetDirectMapping(op_type.substr(0, type_pos), arg_name)) {
VLOG(10) << "[" << op_type << "] found " << ret.value();
return ret.value();
}
}
return arg_mappings.at(arg_name);

VLOG(10) << "[" << op_type << "] not found mapping for " << arg_name;
return arg_name;
}

std::string GetLegacyAttrName(const std::string& op_type,
Expand Down
Loading

0 comments on commit 75669e3

Please sign in to comment.