Skip to content

Commit

Permalink
[PIR] fix onednn layout transform yaml format (#60680)
Browse files Browse the repository at this point in the history
* fix onednn layout transform yaml format
  • Loading branch information
wanghuancoder authored Jan 11, 2024
1 parent 04ab9a6 commit ec174f3
Show file tree
Hide file tree
Showing 9 changed files with 69 additions and 86 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -276,25 +276,27 @@ OneDNNPhiKernelInstruction::OneDNNPhiKernelInstruction(
VLOG(6) << "finish process no need buffer";

// Step2: build layout_transform information
if (op_attributes.count("layout_transform_arg")) {
auto layout_transform_arg = op_attributes.at("layout_transform_arg")
.dyn_cast<pir::StrAttribute>()
.AsString();
auto data_layout = op_attributes.at(layout_transform_arg)
.dyn_cast<pir::StrAttribute>()
.AsString();
input_layout_ = common::StringToDataLayout(data_layout);
std::vector<pir::Attribute> layout_transform_inputs_attr =
if (op_attributes.count("data_format_tensors")) {
if (op_attributes.count("data_format")) {
auto data_layout = op_attributes.at("data_format")
.dyn_cast<pir::StrAttribute>()
.AsString();
input_layout_ = common::StringToDataLayout(data_layout);
} else {
input_layout_ = phi::OneDNNContext::tls().get_cur_paddle_data_layout();
}

std::vector<pir::Attribute> data_format_tensors_attr =
op->attributes()
.at("layout_transform_inputs")
.at("data_format_tensors")
.dyn_cast<pir::ArrayAttribute>()
.AsVector();
std::vector<std::string> layout_transform_inputs;
for (auto& attr : layout_transform_inputs_attr) {

for (auto& attr : data_format_tensors_attr) {
auto pair = kernel_context_.InputRangeAt(value_exec_info_->GetIdByName(
attr.dyn_cast<pir::StrAttribute>().AsString()));
for (int i = pair.first; i < pair.second; ++i) {
layout_transform_inputs_.insert(i);
data_format_tensors_.insert(i);
}
}
}
Expand Down Expand Up @@ -333,7 +335,7 @@ void OneDNNPhiKernelInstruction::Run() {

// Handle 'layout_transform' in
// ops_onednn_extra.yaml(GetKernelTypeForVar)
if (layout_transform_inputs_.count(i) &&
if (data_format_tensors_.count(i) &&
input_layout_ != phi::DataLayout::kAnyLayout) {
from_layout = input_layout_;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class OneDNNPhiKernelInstruction : public InstructionBase {

const ValueExecutionInfo* value_exec_info_; // not owned

std::set<int> layout_transform_inputs_{};
std::set<int> data_format_tensors_{};
phi::DataLayout input_layout_{phi::DataLayout::kAnyLayout};
std::map<std::string, phi::Attribute> extra_attr_{};
std::map<std::string, std::vector<std::string>> inputs_{};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,25 +202,27 @@ OneDNNLegacyKernelInstruction::OneDNNLegacyKernelInstruction(
VLOG(6) << "finish process no need buffer";

// Step2: build layout_transform information
if (op_attributes.count("layout_transform_arg")) {
auto layout_transform_arg = op_attributes.at("layout_transform_arg")
.dyn_cast<pir::StrAttribute>()
.AsString();
auto data_layout = op_attributes.at(layout_transform_arg)
.dyn_cast<pir::StrAttribute>()
.AsString();
input_layout_ = common::StringToDataLayout(data_layout);
std::vector<pir::Attribute> layout_transform_inputs_attr =
if (op_attributes.count("data_format_tensors")) {
if (op_attributes.count("data_format")) {
auto data_layout = op_attributes.at("data_format")
.dyn_cast<pir::StrAttribute>()
.AsString();
input_layout_ = common::StringToDataLayout(data_layout);
} else {
input_layout_ = phi::OneDNNContext::tls().get_cur_paddle_data_layout();
}

std::vector<pir::Attribute> data_format_tensors_attr =
op->attributes()
.at("layout_transform_inputs")
.at("data_format_tensors")
.dyn_cast<pir::ArrayAttribute>()
.AsVector();
std::vector<std::string> layout_transform_inputs;
auto& op_normalizer = paddle::translator::OpNameNormalizer::instance();
std::string fluid_op_name = yaml_info_parser.GetOriginOpName();
for (auto& attr : layout_transform_inputs_attr) {
for (auto& attr : data_format_tensors_attr) {
auto input_name = attr.dyn_cast<pir::StrAttribute>().AsString();
layout_transform_inputs_.insert(
data_format_tensors_.insert(
op_normalizer.GetLegacyArgName(fluid_op_name, input_name));
}
}
Expand Down Expand Up @@ -249,7 +251,7 @@ void OneDNNLegacyKernelInstruction::Run() {

// Handle 'layout_transform' in
// ops_onednn_extra.yaml(GetKernelTypeForVar)
if (layout_transform_inputs_.count(*input_name) &&
if (data_format_tensors_.count(*input_name) &&
input_layout_ != phi::DataLayout::kAnyLayout) {
from_layout = input_layout_;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class OneDNNLegacyKernelInstruction : public InstructionBase {

const ValueExecutionInfo* value_exec_info_; // not owned

std::set<std::string> layout_transform_inputs_{};
std::set<std::string> data_format_tensors_{};
phi::DataLayout input_layout_{phi::DataLayout::kAnyLayout};
};

Expand Down
41 changes: 19 additions & 22 deletions paddle/fluid/pir/dialect/op_generator/op_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from op_kerneltype_gen import gen_kernel_type_for_var_str
from op_member_func_gen import gen_op_get_inputs_outputs_str
from op_verify_gen import gen_verify_func_str
from ops_onednn_extra_parser import parse_extra_args, parse_layout_transform
from ops_onednn_extra_parser import parse_data_format_tensors, parse_extra_args
from parse_kernel_key_gen import gen_parse_kernel_key_str
from vjp_interface_black_list import vjp_interface_black_list

Expand Down Expand Up @@ -233,7 +233,7 @@ class {TEST_API} {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{
std::vector<paddle::dialect::OpInputInfo> inputs = {{ {inputs} }};
std::vector<paddle::dialect::OpAttributeInfo> attributes = {{ {attributes} }};
std::vector<paddle::dialect::OpOutputInfo> outputs = {{ {outputs} }};
paddle::dialect::OpRunTimeInfo run_time_info = paddle::dialect::OpRunTimeInfo("{infer_meta_func}", {{"{infer_meta_param}"}}, "{kernel_func}", {{"{kernel_param}"}}, {{{kernel_key_dtype}}}, {{{kernel_key_backend}}}, {{{inplace}}}, {{{view}}}, {{{extra_args}}}, "{layout_transform_arg}", {{{layout_transform_inputs}}}, {is_onednn_only}, {dynamic_fallback});
paddle::dialect::OpRunTimeInfo run_time_info = paddle::dialect::OpRunTimeInfo("{infer_meta_func}", {{"{infer_meta_param}"}}, "{kernel_func}", {{"{kernel_param}"}}, {{{kernel_key_dtype}}}, {{{kernel_key_backend}}}, {{{inplace}}}, {{{view}}}, {{{extra_args}}}, {{{data_format_tensors}}}, {is_onednn_only}, {dynamic_fallback});
return std::make_tuple(inputs, attributes, outputs, run_time_info, "{origin_op_name}");
}}
"""
Expand Down Expand Up @@ -490,12 +490,14 @@ def __init__(self, op_yaml_item, op_compat_item):
# OneDNN info
if "extra_args" in self.op_yaml_item:
self.onednn_extra_args = self.op_yaml_item["extra_args"]
self.onednn_layout_transform = self.op_yaml_item["layout_transform"]
self.onednn_data_format_tensors = self.op_yaml_item[
"data_format_tensors"
]
self.is_onednn_only = self.op_yaml_item["is_onednn_only"]
self.dynamic_fallback = self.op_yaml_item["dynamic_fallback"]
else:
self.onednn_extra_args = []
self.onednn_layout_transform = None
self.onednn_data_format_tensors = None
self.is_onednn_only = False
self.dynamic_fallback = False

Expand Down Expand Up @@ -1616,18 +1618,12 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name):
extra_args = '"' + '", "'.join(args_name) + '"'
else:
extra_args = ""
if op_info.onednn_layout_transform is None:
layout_transform_arg, layout_transform_inputs = (
"",
"",
)
if op_info.onednn_data_format_tensors is None:
data_format_tensors = ""
else:
(
layout_transform_arg,
layout_transform_inputs,
) = op_info.onednn_layout_transform
layout_transform_inputs = (
'"' + '", "'.join(layout_transform_inputs) + '"'
data_format_tensors = op_info.onednn_data_format_tensors
data_format_tensors = (
'"' + '", "'.join(data_format_tensors) + '"'
)

op_info_func_str = OP_INFO_ONEDNN_TEMPLATE.format(
Expand All @@ -1645,8 +1641,7 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name):
view=view_str,
origin_op_name=op_info.op_yaml_item['name'],
extra_args=extra_args,
layout_transform_arg=layout_transform_arg,
layout_transform_inputs=layout_transform_inputs,
data_format_tensors=data_format_tensors,
is_onednn_only="true"
if op_info.is_onednn_only
else "false",
Expand Down Expand Up @@ -1864,12 +1859,12 @@ def OpGenerator(
item = {}
item["is_onednn_only"] = False
item["extra_args"] = parse_extra_args(op_name, op['extra_args'])
if 'layout_transform' in op:
item["layout_transform"] = parse_layout_transform(
op_name, op['layout_transform']
if 'data_format_tensors' in op:
item["data_format_tensors"] = parse_data_format_tensors(
op_name, op['data_format_tensors']
)
else:
item["layout_transform"] = None
item["data_format_tensors"] = None
if 'dynamic_fallback' in op:
item["dynamic_fallback"] = op['dynamic_fallback']
else:
Expand Down Expand Up @@ -1924,7 +1919,9 @@ def OpGenerator(
onednn_item = ops_onednn_extra_map[op['name']]
op["is_onednn_only"] = onednn_item["is_onednn_only"]
op["extra_args"] = onednn_item["extra_args"]
op["layout_transform"] = onednn_item["layout_transform"]
op["data_format_tensors"] = onednn_item[
"data_format_tensors"
]
op["dynamic_fallback"] = onednn_item["dynamic_fallback"]
op["attrs"] = op["attrs"] + onednn_item["attrs"]
else:
Expand Down
12 changes: 5 additions & 7 deletions paddle/fluid/pir/dialect/op_generator/ops_onednn_extra_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import re
from typing import Any, Dict, List, Tuple
from typing import Dict, List, Tuple


def parse_plain_list(s: str, sep=",") -> List[str]:
Expand Down Expand Up @@ -76,11 +76,9 @@ def parse_extra_args(op_name: str, arguments: str) -> List:
return attrs


def parse_layout_transform(
op_name: str, layout_transform: Dict[str, Any]
def parse_data_format_tensors(
op_name: str, data_format_tensors: str
) -> Tuple[str, List]:
if layout_transform is None:
if data_format_tensors is None:
return "", []
return layout_transform["arg_name"], parse_plain_list(
layout_transform["tensors"]
)
return parse_plain_list(data_format_tensors)
16 changes: 4 additions & 12 deletions paddle/fluid/pir/dialect/operator/ir/ops_onednn_extra.yaml
Original file line number Diff line number Diff line change
@@ -1,27 +1,19 @@

- op : conv2d
extra_args : bool is_test=false
layout_transform :
arg_name: data_format
tensors: input
data_format_tensors : input

- op : conv2d_grad
extra_args : bool is_test=false
layout_transform :
arg_name: data_format
tensors: input, out_grad
data_format_tensors : input, out_grad

- op : lrn
extra_args : bool is_test=false
layout_transform :
arg_name: data_format
tensors: x
data_format_tensors : x

- op : lrn_grad
extra_args : bool is_test=false
layout_transform :
arg_name: data_format
tensors: x, out, mid_out, out_grad
data_format_tensors : x, out, mid_out, out_grad

# - op : matmul
# extra_args : str mkldnn_data_type="float32"
Expand Down
9 changes: 3 additions & 6 deletions paddle/fluid/pir/dialect/operator/utils/op_yaml_info_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,7 @@ struct OpRunTimeInfo {
std::vector<std::pair<std::string, std::string>> inplace;
std::vector<std::pair<std::string, std::string>> view;
std::vector<std::string> extra_args;
std::string layout_transform_arg;
std::vector<std::string> layout_transform_inputs;
std::vector<std::string> data_format_tensors;
bool is_onednn_only;
bool dynamic_fallback;

Expand All @@ -108,8 +107,7 @@ struct OpRunTimeInfo {
const std::vector<std::pair<std::string, std::string>>& inplace,
const std::vector<std::pair<std::string, std::string>>& view,
const std::vector<std::string>& extra_args = {},
const std::string& layout_transform_arg = "",
const std::vector<std::string>& layout_transform_inputs = {},
const std::vector<std::string>& data_format_tensors = {},
bool is_onednn_only = false,
bool dynamic_fallback = false)
: infer_meta_func(infer_meta_func),
Expand All @@ -121,8 +119,7 @@ struct OpRunTimeInfo {
inplace(inplace),
view(view),
extra_args(extra_args),
layout_transform_arg(layout_transform_arg),
layout_transform_inputs(layout_transform_inputs),
data_format_tensors(data_format_tensors),
is_onednn_only(is_onednn_only),
dynamic_fallback(dynamic_fallback) {}
};
Expand Down
15 changes: 5 additions & 10 deletions paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2090,18 +2090,13 @@ pir::Operation* BuildKernelOp(
op_attribute.emplace(
"extra_args",
pir::ArrayAttribute::get(pir::IrContext::Instance(), extra_args));
op_attribute.emplace(
"layout_transform_arg",
pir::StrAttribute::get(
ctx, op_info_parser->OpRuntimeInfo().layout_transform_arg));
std::vector<pir::Attribute> layout_transform_inputs;
for (auto& input :
op_info_parser->OpRuntimeInfo().layout_transform_inputs) {
layout_transform_inputs.push_back(pir::StrAttribute::get(ctx, input));
std::vector<pir::Attribute> data_format_tensors;
for (auto& input : op_info_parser->OpRuntimeInfo().data_format_tensors) {
data_format_tensors.push_back(pir::StrAttribute::get(ctx, input));
}
op_attribute.emplace("layout_transform_inputs",
op_attribute.emplace("data_format_tensors",
pir::ArrayAttribute::get(pir::IrContext::Instance(),
layout_transform_inputs));
data_format_tensors));
op_attribute.emplace(
"is_onednn_only",
pir::BoolAttribute::get(
Expand Down

0 comments on commit ec174f3

Please sign in to comment.