[NewIR] register set_value in new ir #56436

Merged · 17 commits · Aug 28, 2023
Changes from 16 commits
5 changes: 5 additions & 0 deletions paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h
@@ -15,6 +15,7 @@
#pragma once

#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute_storage.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h"
#include "paddle/ir/core/attribute.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/phi/common/scalar.h"
@@ -49,6 +50,10 @@ class ScalarAttribute : public ir::Attribute {
(val.type_id() == ir::StrAttribute::type_id());
}

static ir::Attribute get(ir::IrContext *ctx, phi::Scalar scalar) {
return TransToIrAttribute(scalar, ctx);
}

phi::Scalar data();
};

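The new static get gives ScalarAttribute a uniform construction path: TransToIrAttribute (from the newly included utils.h) picks the builtin attribute type that matches the scalar's dtype. A minimal usage sketch, not a verified Paddle snippet — it assumes ir::IrContext::Instance() is the global context accessor and that ir::Attribute supports dyn_cast<>, as elsewhere in paddle/ir/core:

// Sketch only; context accessor and dyn_cast<> are assumed APIs.
ir::IrContext *ctx = ir::IrContext::Instance();

// TransToIrAttribute selects the builtin attribute matching the scalar's
// dtype (e.g. a float scalar becomes a float attribute).
ir::Attribute attr =
    paddle::dialect::ScalarAttribute::get(ctx, phi::Scalar(1.5f));

// ScalarAttribute::data() (declared above) reads the value back.
phi::Scalar value =
    attr.dyn_cast<paddle::dialect::ScalarAttribute>().data();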
70 changes: 69 additions & 1 deletion paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op.yaml
@@ -329,7 +329,6 @@
view: null
backward: null


- name: shadow_feed
inputs:
- typename: Tensor
@@ -355,3 +354,72 @@
force_backend: null
inplace: null
backward: null

- name : set_value
inputs:
- {typename: Tensor, name: x, optional: false, no_need_buffer: false, data_transform: {} }
attrs:
- {typename: 'int64_t[]', name: starts}
- {typename: 'int64_t[]', name: ends}
- {typename: 'int64_t[]', name: steps}
- {typename: 'int64_t[]', name: axes}
- {typename: 'int64_t[]', name: decrease_axes}
- {typename: 'int64_t[]', name: none_axes}
- {typename: 'int64_t[]', name: shape}
- {typename: 'Scalar[]', name: values}
outputs:
- {typename: Tensor, name: out, optional: false, intermediate: false}
infer_meta:
func: SetValueInferMeta
param: [x]
kernel:
func: [set_value]
param: [x, starts, ends, steps, axes, decrease_axes, none_axes, shape, values]
inplace: null
backward: set_value_grad
Contributor:

The parameters of set_value_grad don't seem to match this op's.

Contributor (Author):

Right, they don't match; their kernels are inconsistent as well.

- name : set_value_with_tensor
inputs:
- {typename: Tensor, name: x, optional: false, no_need_buffer: false, data_transform: {} }
- {typename: Tensor, name: values, optional: false, no_need_buffer: false, data_transform: {} }
attrs:
- {typename: 'int64_t[]', name: starts}
- {typename: 'int64_t[]', name: ends}
- {typename: 'int64_t[]', name: steps}
- {typename: 'int64_t[]', name: axes}
- {typename: 'int64_t[]', name: decrease_axes}
- {typename: 'int64_t[]', name: none_axes}
outputs:
- {typename: Tensor, name: out, optional: false, intermediate: false}
infer_meta:
func: SetValueInferMeta
param: [x]
kernel:
func: [set_value_with_tensor]
param: [x, values, starts, ends, steps, axes, decrease_axes, none_axes]
inplace: null
backward: set_value_grad


- name : set_value_grad
inputs:
- {typename: Tensor, name: out_grad, optional: false, no_need_buffer: false, data_transform: {} }
- {typename: Tensor, name: values, optional: false, no_need_buffer: false, data_transform: {} }
attrs:
- {typename: 'int64_t[]', name: starts}
- {typename: 'int64_t[]', name: ends}
- {typename: 'int64_t[]', name: steps}
- {typename: 'int64_t[]', name: axes}
- {typename: 'int64_t[]', name: decrease_axes}
- {typename: 'int64_t[]', name: none_axes}
outputs:
- {typename: Tensor, name: x_grad, optional: false, intermediate: false}
- {typename: Tensor, name: values_grad, optional: false, intermediate: false}
infer_meta:
func: SetValueGradInferMeta
param: [out_grad, values]
kernel:
func: [set_value_grad]
param: [out_grad, starts, ends, steps, axes, decrease_axes, none_axes]
inplace: null
backward: null
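For context, set_value performs strided slice assignment: the region of x selected by starts/ends/steps along axes is overwritten with values, while decrease_axes and none_axes record dimensions removed by integer indexing or inserted by None indexing. A purely illustrative 1-D sketch of the update rule, not Paddle API:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // x[1:7:2] = 9, i.e. starts=1, ends=7, steps=2 on the only axis.
  std::vector<float> x = {0, 1, 2, 3, 4, 5, 6, 7};
  const int64_t start = 1, end = 7, step = 2;
  for (int64_t i = start; i < end; i += step) x[i] = 9.0f;
  for (float v : x) std::cout << v << " ";  // prints: 0 9 2 9 4 9 6 7
}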
1 change: 1 addition & 0 deletions paddle/fluid/ir/dialect/paddle_dialect/utils/utils.cc
@@ -13,6 +13,7 @@
// limitations under the License.

#include "paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h"

namespace paddle {
namespace dialect {
1 change: 0 additions & 1 deletion paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h
@@ -16,7 +16,6 @@

// #include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type_storage.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_type.h"
5 changes: 3 additions & 2 deletions paddle/fluid/ir/phi_kernel_adaptor/phi_kernel_util.h
@@ -37,6 +37,7 @@
#include "paddle/fluid/ir/dialect/paddle_dialect/utils/op_yaml_info_parser.h"
#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_attribute.h"
#include "paddle/fluid/ir/dialect/paddle_kernel_dialect/ir/kernel_type.h"
#include "paddle/ir/core/type_name.h"
#include "paddle/phi/core/infermeta_utils.h"

#include "glog/logging.h"
@@ -81,8 +82,8 @@ void BuildPhiContext(ir::Operation* op,
Context* ctx) {
paddle::framework::Scope* inner_scope =
local_scope != nullptr ? local_scope : scope;
VLOG(6) << "BuildPhiContext in scope[" << scope << "] inner_scope["
<< inner_scope << "]";
VLOG(6) << "Build " << get_type_name<Context>() << " in scope[" << scope
<< "] inner_scope[" << inner_scope << "]";

auto attr_map = op->attributes();

16 changes: 13 additions & 3 deletions paddle/fluid/ir_adaptor/translator/attribute_translator.cc
@@ -113,7 +113,7 @@ class AttributeVisitor {
}

virtual ir::Attribute operator()(const std::vector<int64_t>& i64s) {
VLOG(10) << "translating vector<int64>";
VLOG(10) << "translating vector<int64> size: " << i64s.size();
std::vector<ir::Attribute> attrs;
attrs.reserve(i64s.size());
for (const auto& v : i64s) {
@@ -135,8 +135,13 @@
virtual ir::Attribute operator()(
const std::vector<paddle::experimental::Scalar>& ss) {
VLOG(10) << "translating vector<scalar>";
IR_THROW(
"not support translating std::vector<paddle::experimental::Scalar>");
std::vector<ir::Attribute> attrs;
attrs.reserve(ss.size());
for (const auto& v : ss) {
attrs.push_back(dialect::ScalarAttribute::get(ctx, v));
}
VLOG(10) << "translating vector<scalar> Done";
return ir::ArrayAttribute::get(ctx, attrs);
}

virtual ir::Attribute operator()(const paddle::blank& blank) {
@@ -164,6 +169,11 @@ class Int64ArrayAttributeVisitor : public AttributeVisitor {
}
return ir::ArrayAttribute::get(ctx, attrs);
}

ir::Attribute operator()(const paddle::blank& blank) override {
VLOG(10) << "translating paddle::blank to int64[]";
return ir::ArrayAttribute::get(ctx, {});
}
};

class IntArrayAttributeVisitor : public AttributeVisitor {
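The visitor previously rejected std::vector<paddle::experimental::Scalar> outright with IR_THROW; it now builds an ir::ArrayAttribute element by element via ScalarAttribute::get. The dispatch itself is ordinary variant visitation; below is a simplified standalone analogue, with std::variant/std::visit standing in for paddle::variant and toy types in place of IR attributes:

#include <cstdint>
#include <iostream>
#include <variant>
#include <vector>

// Each operator() overload handles one legacy attribute type; visitation
// selects the overload for the type currently held by the variant.
struct Visitor {
  void operator()(float f) const { std::cout << "float: " << f << "\n"; }
  void operator()(const std::vector<int64_t>& v) const {
    std::cout << "int64[] of size " << v.size() << "\n";
  }
};

int main() {
  std::variant<float, std::vector<int64_t>> attr =
      std::vector<int64_t>{1, 2, 3};
  std::visit(Visitor{}, attr);  // prints: int64[] of size 3
}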
75 changes: 50 additions & 25 deletions paddle/fluid/ir_adaptor/translator/op_compat_info.h
@@ -13,6 +13,7 @@
// limitations under the License.

#include <functional>
#include <optional>
#include <string>
#include <unordered_map>
#include <unordered_set>
@@ -75,42 +76,66 @@ class OpNameNormalizer {
return op_mutable_attribute_infos.at(op_type).at(arg_name);
}

std::optional<std::string> GetDirectMapping(const std::string& op_type,
const std::string& arg_name) {
if (op_arg_name_mappings.find(op_type) == op_arg_name_mappings.end()) {
return {};
}
auto& arg_mappings = op_arg_name_mappings[op_type];
if (arg_mappings.find(arg_name) == arg_mappings.end()) {
return {};
}
return arg_mappings.at(arg_name);
}

std::optional<std::string> GetGradNameMapping(const std::string& op_type,
const std::string& arg_name) {
std::string target = kPhiGradSuffix;
std::string data = kFluidVarGradSuffix;

size_t first_grad_pos = arg_name.find(target);
size_t type_pos = op_type.find(target);
std::string legacy_name = arg_name.substr(0, first_grad_pos);
std::optional<std::string> ret =
this->GetDirectMapping(op_type.substr(0, type_pos), legacy_name);
if (ret) {
legacy_name = ret.value();
}
legacy_name = legacy_name + arg_name.substr(first_grad_pos);
for (size_t pos = 0;
legacy_name.npos != (pos = legacy_name.find(target, pos));
pos += data.length()) {
legacy_name.replace(pos, target.length(), data);
}
return legacy_name;
}

std::string GetLegacyArgName(const std::string& op_type,
const std::string& arg_name) {
if (auto ret = GetDirectMapping(op_type, arg_name)) {
VLOG(10) << "[" << op_type << "] found " << ret.value();
return ret.value();
}

bool is_grad_op = (op_type.find(kPhiGradSuffix) != std::string::npos);
bool is_grad_arg = (arg_name.find(kPhiGradSuffix) != std::string::npos);

if (is_grad_op && is_grad_arg) {
std::string target = kPhiGradSuffix;
std::string data = kFluidVarGradSuffix;

size_t first_grad_pos = arg_name.find(target);
size_t type_pos = op_type.find(target);
std::string legacy_name = this->GetLegacyArgName(
op_type.substr(0, type_pos), arg_name.substr(0, first_grad_pos));
legacy_name += arg_name.substr(first_grad_pos);
for (size_t pos = 0;
legacy_name.npos != (pos = legacy_name.find(target, pos));
pos += data.length()) {
legacy_name.replace(pos, target.length(), data);
if (auto ret = GetGradNameMapping(op_type, arg_name)) {
VLOG(10) << "[" << op_type << "] found " << ret.value();
return ret.value();
}
return legacy_name;
} else if (is_grad_op && !is_grad_arg) {
// backward op using forward args: like trace_grad using forward input
size_t type_pos = op_type.find(kPhiGradSuffix);
std::string legacy_name =
this->GetLegacyArgName(op_type.substr(0, type_pos), arg_name);

return legacy_name;
}
if (op_arg_name_mappings.find(op_type) == op_arg_name_mappings.end()) {
return arg_name;
}
auto& arg_mappings = op_arg_name_mappings[op_type];
if (arg_mappings.find(arg_name) == arg_mappings.end()) {
return arg_name;
if (auto ret = GetDirectMapping(op_type.substr(0, type_pos), arg_name)) {
VLOG(10) << "[" << op_type << "] found " << ret.value();
return ret.value();
}
}
return arg_mappings.at(arg_name);

VLOG(10) << "[" << op_type << "] not found mapping for " << arg_name;
return arg_name;
}

std::string GetLegacyAttrName(const std::string& op_type,
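This refactor splits the old monolithic lookup into GetDirectMapping (plain table lookup) and GetGradNameMapping (grad-suffix rewriting), so GetLegacyArgName tries each fallback once and logs the result. The heart of the grad path is the loop that rewrites every phi grad suffix into the fluid one; a standalone sketch of that loop, assuming kPhiGradSuffix == "_grad" and kFluidVarGradSuffix == "@GRAD" (assumed values, not verified against op_compat_info):

#include <iostream>
#include <string>

// Rewrites every occurrence of the phi grad suffix into the fluid one,
// mirroring the loop in GetGradNameMapping.
std::string RewriteGradSuffix(std::string name) {
  const std::string target = "_grad";  // assumed kPhiGradSuffix
  const std::string data = "@GRAD";    // assumed kFluidVarGradSuffix
  for (size_t pos = 0;
       (pos = name.find(target, pos)) != std::string::npos;
       pos += data.length()) {
    name.replace(pos, target.length(), data);
  }
  return name;
}

int main() {
  std::cout << RewriteGradSuffix("x_grad") << "\n";         // x@GRAD
  std::cout << RewriteGradSuffix("out_grad_grad") << "\n";  // out@GRAD@GRAD
}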