Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【NewIR】modify static_op_function vector mutable attr bug (Split op ) #57218

Closed
Closed
Show file tree
Hide file tree
Changes from 66 commits
Commits
Show all changes
72 commits
Select commit Hold shift + click to select a range
2c0166c
add reference of lbfgs
xiaoguoguo626807 Aug 11, 2023
37883b2
add reference of lbfgs
xiaoguoguo626807 Aug 11, 2023
185d30b
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Aug 18, 2023
92e5303
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Aug 24, 2023
221b70c
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Aug 25, 2023
aba6f0e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Aug 28, 2023
bce9b3b
tmp
xiaoguoguo626807 Aug 30, 2023
c2341a5
fix conflict
xiaoguoguo626807 Aug 30, 2023
4d30fdd
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Aug 31, 2023
c8f5864
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Sep 1, 2023
805ceb2
split gen modify
xiaoguoguo626807 Sep 1, 2023
008a8b2
fix conflict
xiaoguoguo626807 Sep 4, 2023
c619c4e
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Sep 4, 2023
23b37ed
add split
zhangbo9674 Sep 4, 2023
9640576
Merge branch 'develop', commit 'refs/pull/56924/head' of https://gith…
xiaoguoguo626807 Sep 4, 2023
000a8db
fix bug
zhangbo9674 Sep 4, 2023
03ec67b
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zhangbo9674 Sep 4, 2023
06def9f
fix bug
zhangbo9674 Sep 4, 2023
d9f20d0
test split
xiaoguoguo626807 Sep 5, 2023
c027bbf
Merge branch 'develop', commit 'refs/pull/56924/head' of https://gith…
xiaoguoguo626807 Sep 5, 2023
9df5940
add meta tensor
zhangbo9674 Sep 5, 2023
3924ce3
refine code
zhangbo9674 Sep 5, 2023
d3805c7
fix bug
zhangbo9674 Sep 5, 2023
4405f3c
fix bug
zhangbo9674 Sep 5, 2023
a2fa7be
fix comflict
xiaoguoguo626807 Sep 5, 2023
fa23c5d
Merge branch 'develop', commit 'refs/pull/56973/head' of https://gith…
xiaoguoguo626807 Sep 5, 2023
7273556
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Sep 5, 2023
24115d0
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Sep 6, 2023
b8e98c5
Call _C_ops.sum in new ir
0x45f Sep 6, 2023
a516693
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Sep 6, 2023
237c493
fix conflict
xiaoguoguo626807 Sep 6, 2023
fff986c
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Sep 6, 2023
49a0678
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Sep 6, 2023
9d3c92b
modify concat kernel choose
xiaoguoguo626807 Sep 6, 2023
addc342
modify ci
xiaoguoguo626807 Sep 7, 2023
697c11f
modify sum zero_dim optest
xiaoguoguo626807 Sep 7, 2023
ecc3d21
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Sep 7, 2023
1bc9720
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Sep 7, 2023
0848775
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
0x45f Sep 7, 2023
56c8666
modify split_with_num api
xiaoguoguo626807 Sep 7, 2023
f29b9d6
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Sep 7, 2023
3c1c6aa
fix conflict
xiaoguoguo626807 Sep 7, 2023
372f43f
fix conflict
xiaoguoguo626807 Sep 7, 2023
8338414
modify split -1
xiaoguoguo626807 Sep 8, 2023
4d2b8c7
fix conflict
xiaoguoguo626807 Sep 8, 2023
7fac313
modify split test
xiaoguoguo626807 Sep 11, 2023
224bc4a
fix conflict
xiaoguoguo626807 Sep 11, 2023
a6a9ee4
fix bug
xiaoguoguo626807 Sep 11, 2023
b10214f
xxx
xiaoguoguo626807 Sep 11, 2023
c3cbb6d
fix conflict
xiaoguoguo626807 Sep 11, 2023
050b58c
delete extra modify
xiaoguoguo626807 Sep 11, 2023
0bd42a8
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Sep 11, 2023
690b2bf
Merge branch 'develop', commit 'refs/pull/56873/head' of https://gith…
xiaoguoguo626807 Sep 12, 2023
c5c045c
modify ci shape error
xiaoguoguo626807 Sep 12, 2023
f613ec0
fix conflict
xiaoguoguo626807 Sep 12, 2023
5a8af83
fix conflict
xiaoguoguo626807 Sep 12, 2023
ebfeae8
fix conflict
xiaoguoguo626807 Sep 12, 2023
b6b5b2b
modify ci
xiaoguoguo626807 Sep 12, 2023
4bd1769
modify ci
xiaoguoguo626807 Sep 12, 2023
00ed7d0
modify ci
xiaoguoguo626807 Sep 12, 2023
7e3b684
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Sep 13, 2023
73f1556
modit mode
xiaoguoguo626807 Sep 13, 2023
dbcad3a
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Sep 13, 2023
c34c876
modify ir_backward
xiaoguoguo626807 Sep 13, 2023
24dc472
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Sep 13, 2023
78d52a6
modify ir_backward
xiaoguoguo626807 Sep 13, 2023
29cc298
fix conflict
xiaoguoguo626807 Sep 14, 2023
b6832b2
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
xiaoguoguo626807 Sep 14, 2023
4315201
comment modify
xiaoguoguo626807 Sep 14, 2023
810aff3
modify eager_utils
xiaoguoguo626807 Sep 14, 2023
f2c41b0
modify ci
xiaoguoguo626807 Sep 14, 2023
b560789
fix conflict
xiaoguoguo626807 Sep 14, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 88 additions & 18 deletions paddle/fluid/pir/dialect/op_generator/api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,19 +132,29 @@ def _gen_api_inputs(self, op_info):
ret.append(f'{self._type_map[type]} {name}')
return ', '.join(ret)

def _gen_api_attrs(self, op_info, with_default, is_mutable_attr):
def _gen_api_attrs(
self, op_info, with_default, is_mutable_attr, is_vector_mutable_sttr
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is_vector_mutable_sttr -> is_vector_mutable_attr

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is_vector_mutable_attr?

):
name_list = op_info.attribute_name_list
type_list = op_info.attribute_build_arg_type_list
default_value_list = op_info.attribute_default_value_list
mutable_name_list = op_info.mutable_attribute_name_list
mutable_type_list = op_info.mutable_attribute_type_list
assert len(name_list) == len(type_list) == len(default_value_list)
no_mutable_attr = []
mutable_attr = []
for name, type, default_value in zip(
name_list, type_list, default_value_list
):
if is_mutable_attr and name in mutable_name_list:
mutable_attr.append(f'{OP_RESULT} {name}')
if (
mutable_type_list[mutable_name_list.index(name)][0]
== "paddle::dialect::IntArrayAttribute"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

发现该pr有大量这样的写法,可以把"paddle::dialect::IntArrayAttribute"抽出来变成全局变量,或者是封装成函数,避免硬编码

and is_vector_mutable_sttr
):
mutable_attr.append(f'std::vector<{OP_RESULT}> {name}')
else:
mutable_attr.append(f'{OP_RESULT} {name}')
continue
if with_default and default_value is not None:
if type in ['float', 'double']:
Expand All @@ -158,9 +168,17 @@ def _gen_api_attrs(self, op_info, with_default, is_mutable_attr):
no_mutable_attr.append(f'{type} {name}')
return ', '.join(mutable_attr + no_mutable_attr)

def _gen_api_args(self, op_info, with_default_attr, is_mutable_attr):
def _gen_api_args(
self,
op_info,
with_default_attr,
is_mutable_attr,
is_vector_mutable_attr,
):
inputs = self._gen_api_inputs(op_info)
attrs = self._gen_api_attrs(op_info, with_default_attr, is_mutable_attr)
attrs = self._gen_api_attrs(
op_info, with_default_attr, is_mutable_attr, is_vector_mutable_attr
)
return (inputs + ', ' + attrs).strip(', ')

def _gen_ret_type(self, op_info):
Expand All @@ -187,11 +205,15 @@ def _gen_ret_type(self, op_info):
elif output_num == 0:
return 'void'

def _gen_one_declare(self, op_info, op_name, is_mutable_attr):
def _gen_one_declare(
self, op_info, op_name, is_mutable_attr, is_vector_mutable_attr
):
return API_DECLARE_TEMPLATE.format(
ret_type=self._gen_ret_type(op_info),
api_name=op_name,
args=self._gen_api_args(op_info, True, is_mutable_attr),
args=self._gen_api_args(
op_info, True, is_mutable_attr, is_vector_mutable_attr
),
)

def _gen_h_file(self, op_info_items, namespaces, h_file_path):
Expand All @@ -202,10 +224,19 @@ def _gen_h_file(self, op_info_items, namespaces, h_file_path):
# is wrong, so temporarily skip the automatic generation of these APIs
if self._need_skip(op_info, op_name):
continue
declare_str += self._gen_one_declare(op_info, op_name, False)
declare_str += self._gen_one_declare(
op_info, op_name, False, False
)
if len(op_info.mutable_attribute_name_list) > 0:
declare_str += self._gen_one_declare(op_info, op_name, True)

declare_str += self._gen_one_declare(
op_info, op_name, True, False
)
if "paddle::dialect::IntArrayAttribute" in {
type[0] for type in op_info.mutable_attribute_type_list
}:
declare_str += self._gen_one_declare(
op_info, op_name, True, True
)
body = declare_str
for namespace in reversed(namespaces):
body = NAMESPACE_TEMPLATE.format(namespace=namespace, body=body)
Expand All @@ -215,7 +246,7 @@ def _gen_h_file(self, op_info_items, namespaces, h_file_path):
# =====================================
# Gen impl functions
# =====================================
def _gen_in_combine(self, op_info):
def _gen_in_combine(self, op_info, is_mutable_attr, is_vector_mutable_attr):
name_list = op_info.input_name_list
type_list = op_info.input_type_list
assert len(name_list) == len(type_list)
Expand All @@ -230,6 +261,24 @@ def _gen_in_combine(self, op_info):
combine_op_list.append(op_name)
else:
combine_op_list.append(None)

if is_mutable_attr:
name_list = op_info.mutable_attribute_name_list
type_list = op_info.mutable_attribute_type_list
assert len(name_list) == len(type_list)
for name, type in zip(name_list, type_list):
if (
type[0] == "paddle::dialect::IntArrayAttribute"
and is_vector_mutable_attr
):
op_name = f'{name}_combine_op'
combine_op += COMBINE_OP_TEMPLATE.format(
op_name=op_name, in_name=name
)
combine_op_list.append(op_name)
else:
combine_op_list.append(None)

return combine_op, combine_op_list

def _gen_compute_op_args(
Expand All @@ -239,15 +288,22 @@ def _gen_compute_op_args(
all_attr_list = op_info.attribute_name_list
no_mutable_attr_list = op_info.non_mutable_attribute_name_list
mutable_attr_list = op_info.mutable_attribute_name_list
assert len(input_name_list) == len(in_combine_op_list)
assert len(input_name_list) + len(mutable_attr_list) == len(
in_combine_op_list
) or len(input_name_list) == len(in_combine_op_list)
ret = []
for input_name, combine_op in zip(input_name_list, in_combine_op_list):
if is_mutable_attr:
name_list = input_name_list + mutable_attr_list
else:
name_list = input_name_list

for input_name, combine_op in zip(name_list, in_combine_op_list):
if combine_op is None:
ret.append(input_name)
else:
ret.append(f'{combine_op}.out()')
if is_mutable_attr:
ret += list(mutable_attr_list + no_mutable_attr_list)
ret += list(no_mutable_attr_list)
else:
ret += list(all_attr_list)
return ', '.join(ret)
Expand Down Expand Up @@ -299,9 +355,13 @@ def _gen_return_result(self, ret_list):
elif len(ret_list) == 0:
return 'return;'

def _gen_one_impl(self, op_info, op_name, is_mutable_attr):
def _gen_one_impl(
self, op_info, op_name, is_mutable_attr, is_vector_mutable_attr
):
ret_type = self._gen_ret_type(op_info)
in_combine, in_combine_op_list = self._gen_in_combine(op_info)
in_combine, in_combine_op_list = self._gen_in_combine(
op_info, is_mutable_attr, is_vector_mutable_attr
)
compute_op, op_inst_name = self._gen_compute_op(
op_info, op_name, in_combine_op_list, is_mutable_attr
)
Expand All @@ -315,7 +375,9 @@ def _gen_one_impl(self, op_info, op_name, is_mutable_attr):
ret = API_IMPL_TEMPLATE.format(
ret_type=ret_type,
api_name=op_name,
args=self._gen_api_args(op_info, False, is_mutable_attr),
args=self._gen_api_args(
op_info, False, is_mutable_attr, is_vector_mutable_attr
),
in_combine=in_combine,
compute_op=compute_op,
out_split=out_split,
Expand All @@ -333,9 +395,17 @@ def _gen_cpp_file(self, op_info_items, namespaces, cpp_file_path):
# is wrong, so temporarily skip the automatic generation of these APIs
if self._need_skip(op_info, op_name):
continue
impl_str += self._gen_one_impl(op_info, op_name, False)
impl_str += self._gen_one_impl(op_info, op_name, False, False)
if len(op_info.mutable_attribute_name_list) > 0:
impl_str += self._gen_one_impl(op_info, op_name, True)
impl_str += self._gen_one_impl(
op_info, op_name, True, False
)
if "paddle::dialect::IntArrayAttribute" in {
type[0] for type in op_info.mutable_attribute_type_list
}:
impl_str += self._gen_one_impl(
op_info, op_name, True, True
)
body = impl_str
for namespace in reversed(namespaces):
body = NAMESPACE_TEMPLATE.format(namespace=namespace, body=body)
Expand Down
17 changes: 15 additions & 2 deletions paddle/fluid/pir/dialect/op_generator/op_build_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,22 @@
# limitations under the License.

# generator build function
_INFERMETA_NEED_META_CONFIG = {'SplitInferMeta'}
_INFERMETA_NEED_META_CONFIG = {
'SplitInferMeta',
'SumInferMeta',
'SplitWithNumInferMeta',
'ConcatInferMeta',
'ReduceIntArrayAxisInferMeta',
}

_PREPARE_DATA_WITH_UNKNOW_ATTRIBUTE = {
'SplitOp',
'SumOp',
'SplitWithNumOp',
'ConcatOp',
'MeanOp',
}

_PREPARE_DATA_WITH_UNKNOW_ATTRIBUTE = {'SplitOp'}

OP_BUILD_TEMPLATE = """
void {op_name}::Build({build_args}) {{
Expand Down
Loading