Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【PIR】modify comment of 57281 #57354

Merged
merged 1 commit into from
Sep 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 7 additions & 9 deletions paddle/fluid/pir/dialect/op_generator/api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@

OP_RESULT = 'pir::OpResult'
VECTOR_TYPE = 'pir::VectorType'
INTARRAY_ATTRIBUTE = "paddle::dialect::IntArrayAttribute"


def get_op_class_name(op_name):
Expand Down Expand Up @@ -133,7 +134,7 @@ def _gen_api_inputs(self, op_info):
return ', '.join(ret)

def _gen_api_attrs(
self, op_info, with_default, is_mutable_attr, is_vector_mutable_sttr
self, op_info, with_default, is_mutable_attr, is_vector_mutable_attr
):
name_list = op_info.attribute_name_list
type_list = op_info.attribute_build_arg_type_list
Expand All @@ -149,8 +150,8 @@ def _gen_api_attrs(
if is_mutable_attr and name in mutable_name_list:
if (
mutable_type_list[mutable_name_list.index(name)][0]
== "paddle::dialect::IntArrayAttribute"
and is_vector_mutable_sttr
== INTARRAY_ATTRIBUTE
and is_vector_mutable_attr
):
mutable_attr.append(f'std::vector<{OP_RESULT}> {name}')
else:
Expand Down Expand Up @@ -231,7 +232,7 @@ def _gen_h_file(self, op_info_items, namespaces, h_file_path):
declare_str += self._gen_one_declare(
op_info, op_name, True, False
)
if "paddle::dialect::IntArrayAttribute" in {
if INTARRAY_ATTRIBUTE in {
type[0] for type in op_info.mutable_attribute_type_list
}:
declare_str += self._gen_one_declare(
Expand Down Expand Up @@ -267,10 +268,7 @@ def _gen_in_combine(self, op_info, is_mutable_attr, is_vector_mutable_attr):
type_list = op_info.mutable_attribute_type_list
assert len(name_list) == len(type_list)
for name, type in zip(name_list, type_list):
if (
type[0] == "paddle::dialect::IntArrayAttribute"
and is_vector_mutable_attr
):
if type[0] == INTARRAY_ATTRIBUTE and is_vector_mutable_attr:
op_name = f'{name}_combine_op'
combine_op += COMBINE_OP_TEMPLATE.format(
op_name=op_name, in_name=name
Expand Down Expand Up @@ -400,7 +398,7 @@ def _gen_cpp_file(self, op_info_items, namespaces, cpp_file_path):
impl_str += self._gen_one_impl(
op_info, op_name, True, False
)
if "paddle::dialect::IntArrayAttribute" in {
if INTARRAY_ATTRIBUTE in {
type[0] for type in op_info.mutable_attribute_type_list
}:
impl_str += self._gen_one_impl(
Expand Down
12 changes: 9 additions & 3 deletions paddle/fluid/pir/dialect/op_generator/python_c_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@
import argparse
import re

from api_gen import NAMESPACE_TEMPLATE, OP_RESULT, VECTOR_TYPE, CodeGen
from api_gen import (
INTARRAY_ATTRIBUTE,
NAMESPACE_TEMPLATE,
OP_RESULT,
VECTOR_TYPE,
CodeGen,
)

H_FILE_TEMPLATE = """

Expand Down Expand Up @@ -300,7 +306,7 @@ def _gen_cast_attrs(self, op_info, op_name):
mutable_attr_type_list[mutable_attr_name_list.index(name)][
0
]
== "paddle::dialect::IntArrayAttribute"
== INTARRAY_ATTRIBUTE
):
mutable_cast_str = MUTABLE_ATTR_CAST_TEMPLATE.format(
type='std::vector<pir::OpResult>',
Expand Down Expand Up @@ -337,7 +343,7 @@ def _gen_cast_attrs(self, op_info, op_name):
mutable_attr_type_list[mutable_attr_name_list.index(name)][
0
]
== "paddle::dialect::IntArrayAttribute"
== INTARRAY_ATTRIBUTE
):
no_mutable_cast_str += FULL_INT_ARRAY_OP_TEMPLATE.format(
name=name,
Expand Down
13 changes: 13 additions & 0 deletions paddle/fluid/pybind/eager_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,19 @@ bool PyObject_CheckIRVectorOfOpResult(PyObject* obj) {
}
}
return true;
} else if (PyTuple_Check(obj)) {
Py_ssize_t len = PyTuple_Size(obj);
PyObject* item = nullptr;
if (len == 0) {
return false;
}
for (Py_ssize_t i = 0; i < len; i++) {
item = PyTuple_GetItem(obj, i);
if (!PyObject_CheckIROpResult(item)) {
return false;
}
}
return true;
} else {
return false;
}
Expand Down