Skip to content

Commit

Permalink
【Error Message No. 2】paddle/fluid/pir/dialect/op_generator/* (#62773)
Browse files Browse the repository at this point in the history
* fix

* fix
  • Loading branch information
enkilee authored Mar 18, 2024
1 parent e7084fb commit edc3bfc
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 70 deletions.
40 changes: 25 additions & 15 deletions paddle/fluid/pir/dialect/op_generator/op_build_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -838,36 +838,46 @@ def gen_build_func_str(
)

GET_ATTRIBUTES_FROM_MAP_TEMPLATE = """
IR_ENFORCE(
attributes.find("{attribute_name}") != attributes.end(),
"'{attribute_name}' Attribute is expected for {op_name}. ");
PADDLE_ENFORCE_NE(
attributes.find("{attribute_name}"),
attributes.end(),
phi::errors::InvalidArgument(
"'{attribute_name}' Attribute is expected for {op_name}. "));
{attr_type} {attribute_name} = attributes.at("{attribute_name}").dyn_cast<{attr_ir_type}>().data();
"""
GET_STR_ATTRIBUTES_FROM_MAP_TEMPLATE = """
IR_ENFORCE(
attributes.find("{attribute_name}") != attributes.end(),
"'{attribute_name}' Attribute is expected for {op_name}. ");
PADDLE_ENFORCE_NE(
attributes.find("{attribute_name}"),
attributes.end(),
phi::errors::InvalidArgument(
"'{attribute_name}' Attribute is expected for {op_name}. "));
{attr_type} {attribute_name} = attributes.at("{attribute_name}").dyn_cast<pir::StrAttribute>().AsString();
"""
GET_ARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE = """
IR_ENFORCE(
attributes.find("{attribute_name}") != attributes.end(),
"'{attribute_name}' Attribute is expected for {op_name}. ");
PADDLE_ENFORCE_NE(
attributes.find("{attribute_name}"),
attributes.end(),
phi::errors::InvalidArgument(
"'{attribute_name}' Attribute is expected for {op_name}. "));
{attr_type} {attribute_name};
for (size_t i = 0; i < attributes.at("{attribute_name}").dyn_cast<pir::ArrayAttribute>().size(); i++) {{
{attribute_name}.push_back(attributes.at("{attribute_name}").dyn_cast<pir::ArrayAttribute>().at(i).dyn_cast<{inner_type}>().{data_name}());
}}
"""
GET_INTARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE = """
IR_ENFORCE(
attributes.find("{attribute_name}") != attributes.end(),
"'{attribute_name}' Attribute is expected for {op_name}. ");
PADDLE_ENFORCE_NE(
attributes.find("{attribute_name}"),
attributes.end(),
phi::errors::InvalidArgument(
"'{attribute_name}' Attribute is expected for {op_name}. "));
{attr_type} {attribute_name} = attributes.at("{attribute_name}").dyn_cast<paddle::dialect::IntArrayAttribute>().data().GetData();
"""
GET_SCALAR_ATTRIBUTE_FROM_MAP_TEMPLATE = """
IR_ENFORCE(
attributes.find("{attribute_name}") != attributes.end(),
"'{attribute_name}' Attribute is expected for {op_name}. ");
PADDLE_ENFORCE_NE(
attributes.find("{attribute_name}"),
attributes.end(),
phi::errors::InvalidArgument(
"'{attribute_name}' Attribute is expected for {op_name}. "));
{attr_type} {attribute_name} = attributes.at("{attribute_name}").dyn_cast<paddle::dialect::ScalarAttribute>().data().to<{attr_type}>();
"""

Expand Down
44 changes: 27 additions & 17 deletions paddle/fluid/pir/dialect/op_generator/op_infermeta_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@
pir::Value {input_name}_ = input_values[{index}]; (void){input_name}_;"""

ENFORCE_INPUT_NUM_TEMPLATE = """
IR_ENFORCE(input_values.size() == {op_input_name_list_size},
"Num of inputs is expected to be {op_input_name_list_size} but got %d.", input_values.size());
PADDLE_ENFORCE_EQ(input_values.size() == {op_input_name_list_size}, true, phi::errors::InvalidArgument(
"Num of inputs is expected to be {op_input_name_list_size} but got %d.", input_values.size()));
"""

GET_INPUT_TYPE_TEMPLATE = """
Expand Down Expand Up @@ -492,36 +492,46 @@ def GetAttributes(
attr_args_is_map,
):
GET_ATTRIBUTES_FROM_MAP_TEMPLATE = """
IR_ENFORCE(
attributes.find("{attribute_name}") != attributes.end(),
"'{attribute_name}' Attribute is expected for {op_name}. ");
PADDLE_ENFORCE_NE(
attributes.find("{attribute_name}"),
attributes.end(),
phi::errors::InvalidArgument(
"'{attribute_name}' Attribute is expected for {op_name}. "));
{attr_type} {attribute_name} = attributes.at("{attribute_name}").dyn_cast<{attr_ir_type}>().data();
"""
GET_STR_ATTRIBUTES_FROM_MAP_TEMPLATE = """
IR_ENFORCE(
attributes.find("{attribute_name}") != attributes.end(),
"'{attribute_name}' Attribute is expected for {op_name}. ");
PADDLE_ENFORCE_NE(
attributes.find("{attribute_name}"),
attributes.end(),
phi::errors::InvalidArgument(
"'{attribute_name}' Attribute is expected for {op_name}. "));
{attr_type} {attribute_name} = attributes.at("{attribute_name}").dyn_cast<pir::StrAttribute>().AsString();
"""
GET_ARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE = """
IR_ENFORCE(
attributes.find("{attribute_name}") != attributes.end(),
"'{attribute_name}' Attribute is expected for {op_name}. ");
PADDLE_ENFORCE_NE(
attributes.find("{attribute_name}"),
attributes.end(),
phi::errors::InvalidArgument(
"'{attribute_name}' Attribute is expected for {op_name}. "));
{attr_type} {attribute_name};
for (size_t i = 0; i < attributes.at("{attribute_name}").dyn_cast<pir::ArrayAttribute>().size(); i++) {{
{attribute_name}.push_back(attributes.at("{attribute_name}").dyn_cast<pir::ArrayAttribute>().at(i).dyn_cast<{inner_type}>().{data_name}());
}}
"""
GET_INTARRAY_ATTRIBUTE_FROM_MAP_TEMPLATE = """
IR_ENFORCE(
attributes.find("{attribute_name}") != attributes.end(),
"'{attribute_name}' Attribute is expected for {op_name}. ");
PADDLE_ENFORCE_NE(
attributes.find("{attribute_name}"),
attributes.end(),
phi::errors::InvalidArgument(
"'{attribute_name}' Attribute is expected for {op_name}. "));
{attr_type} {attribute_name} = attributes.at("{attribute_name}").dyn_cast<paddle::dialect::IntArrayAttribute>().data().GetData();
"""
GET_SCALAR_ATTRIBUTE_FROM_MAP_TEMPLATE = """
IR_ENFORCE(
attributes.find("{attribute_name}") != attributes.end(),
"'{attribute_name}' Attribute is expected for {op_name}. ");
PADDLE_ENFORCE_NE(
attributes.find("{attribute_name}"),
attributes.end(),
phi::errors::InvalidArgument(
"'{attribute_name}' Attribute is expected for {op_name}. "));
{attr_type} {attribute_name} = attributes.at("{attribute_name}").dyn_cast<paddle::dialect::ScalarAttribute>().data().to<{attr_type}>();
"""

Expand Down
76 changes: 38 additions & 38 deletions paddle/fluid/pir/dialect/op_generator/op_verify_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,17 @@
VLOG(4) << "Verifying inputs:";
{{
auto input_size = num_operands();
IR_ENFORCE(input_size == {inputs_size}u,
"The size %d of inputs must be equal to {inputs_size}.", input_size);{inputs_type_check}
PADDLE_ENFORCE_EQ(input_size == {inputs_size}u, true, phi::errors::InvalidArgument(
"The size %d of inputs must be equal to {inputs_size}.", input_size));{inputs_type_check}
}}
VLOG(4) << "Verifying attributes:";
{{{attributes_check}
}}
VLOG(4) << "Verifying outputs:";
{{
auto output_size = num_results();
IR_ENFORCE(output_size == {outputs_size}u,
"The size %d of outputs must be equal to {outputs_size}.", output_size);{outputs_type_check}
PADDLE_ENFORCE_EQ(output_size == {outputs_size}u, true, phi::errors::InvalidArgument(
"The size %d of outputs must be equal to {outputs_size}.", output_size));{outputs_type_check}
}}
VLOG(4) << "End Verifying for: {op_name}.";
}}
Expand All @@ -40,83 +40,83 @@
"""

INPUT_TYPE_CHECK_TEMPLATE = """
IR_ENFORCE((*this)->operand_source({index}).type().isa<{standard}>(),
"Type validation failed for the {index}th input, got %s.", (*this)->operand_source({index}).type());"""
PADDLE_ENFORCE_EQ((*this)->operand_source({index}).type().isa<{standard}>(), true,
phi::errors::InvalidArgument("Type validation failed for the {index}th input, got %s.", (*this)->operand_source({index}).type()));"""
INPUT_VECTORTYPE_CHECK_TEMPLATE = """
if (auto vec_type = (*this)->operand_source({index}).type().dyn_cast<pir::VectorType>()) {{
for (size_t i = 0; i < vec_type.size(); ++i) {{
IR_ENFORCE(vec_type[i].isa<{standard}>(),
"Type validation failed for the {index}th input, got %s.", (*this)->operand_source({index}).type());
PADDLE_ENFORCE_EQ(vec_type[i].isa<{standard}>(), true, phi::errors::InvalidArgument(
"Type validation failed for the {index}th input, got %s.", (*this)->operand_source({index}).type()));
}}
}}
else {{
IR_ENFORCE((*this)->operand_source({index}).type().isa<{standard}>(),
"Type validation failed for the {index}th input, got %s.", (*this)->operand_source({index}).type());
PADDLE_ENFORCE_EQ((*this)->operand_source({index}).type().isa<{standard}>(), true, phi::errors::InvalidArgument(
"Type validation failed for the {index}th input, got %s.", (*this)->operand_source({index}).type()));
}}"""
INPUT_OPTIONAL_TYPE_CHECK_TEMPLATE = """
if (auto val = (*this)->operand({index})) {{
IR_ENFORCE(val.type().isa<{standard}>(),
"Type validation failed for the {index}th input, got %s.", (*this)->operand_source({index}).type());
PADDLE_ENFORCE_EQ(val.type().isa<{standard}>(), true, phi::errors::InvalidArgument(
"Type validation failed for the {index}th input, got %s.", (*this)->operand_source({index}).type()));
}}"""
INPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE = """
if (auto val = (*this)->operand({index})) {{
if (auto vec_type = val.type().dyn_cast<pir::VectorType>()) {{
for (size_t i = 0; i < vec_type.size(); i++) {{
IR_ENFORCE(vec_type[i].isa<{standard}>(),
"Type validation failed for the {index}th input, got %s.", (*this)->operand_source({index}).type());
PADDLE_ENFORCE_EQ(vec_type[i].isa<{standard}>(), true, phi::errors::InvalidArgument(
"Type validation failed for the {index}th input, got %s.", (*this)->operand_source({index}).type()));
}}
}}
else {{
IR_ENFORCE(val.type().isa<{standard}>(),
"Type validation failed for the {index}th input, got %s.", (*this)->operand_source({index}).type());
PADDLE_ENFORCE_EQ(val.type().isa<{standard}>(), true, phi::errors::InvalidArgument(
"Type validation failed for the {index}th input, got %s.", (*this)->operand_source({index}).type()));
}}
}}"""
ATTRIBUTE_CHECK_TEMPLATE = """
IR_ENFORCE(attributes.count("{attribute_name}")>0,
"{attribute_name} does not exist.");
IR_ENFORCE(attributes.at("{attribute_name}").isa<{standard}>(),
"Type of attribute: {attribute_name} is not {standard}.");
PADDLE_ENFORCE_GT(attributes.count("{attribute_name}"), 0, phi::errors::InvalidArgument(
"{attribute_name} does not exist."));
PADDLE_ENFORCE_EQ(attributes.at("{attribute_name}").isa<{standard}>(), true, phi::errors::InvalidArgument(
"Type of attribute: {attribute_name} is not {standard}."));
"""
ATTRIBUTE_VECTOR_CHECK_TEMPLATE = """
IR_ENFORCE(attributes.count("{attribute_name}")>0,
"{attribute_name} does not exist.");
IR_ENFORCE(attributes.at("{attribute_name}").isa<pir::ArrayAttribute>(),
"Type of attribute: {attribute_name} is not pir::ArrayAttribute.");
PADDLE_ENFORCE_GT(attributes.count("{attribute_name}"), 0, phi::errors::InvalidArgument(
"{attribute_name} does not exist."));
PADDLE_ENFORCE_EQ(attributes.at("{attribute_name}").isa<pir::ArrayAttribute>(), true, phi::errors::InvalidArgument(
"Type of attribute: {attribute_name} is not pir::ArrayAttribute."));
for (size_t i = 0; i < attributes.at("{attribute_name}").dyn_cast<pir::ArrayAttribute>().size(); i++) {{
IR_ENFORCE(attributes.at("{attribute_name}").dyn_cast<pir::ArrayAttribute>().at(i).isa<{standard}>(),
"Type of attribute: {attribute_name} is not right.");
PADDLE_ENFORCE_EQ(attributes.at("{attribute_name}").dyn_cast<pir::ArrayAttribute>().at(i).isa<{standard}>(), true, phi::errors::InvalidArgument(
"Type of attribute: {attribute_name} is not right."));
}}"""
OUTPUT_TYPE_CHECK_TEMPLATE = """
IR_ENFORCE((*this)->result({index}).type().isa<{standard}>(),
"Type validation failed for the {index}th output.");"""
PADDLE_ENFORCE_EQ((*this)->result({index}).type().isa<{standard}>(), true, phi::errors::InvalidArgument(
"Type validation failed for the {index}th output."));"""
OUTPUT_VECTORTYPE_CHECK_TEMPLATE = """
auto output_{index}_type = (*this)->result({index}).type();
if (auto vec_type = output_{index}_type.dyn_cast<pir::VectorType>()) {{
for (size_t i = 0; i < vec_type.size(); i++) {{
IR_ENFORCE(vec_type[i].isa<{standard}>(),
"Type validation failed for the {index}th output.");
PADDLE_ENFORCE_EQ(vec_type[i].isa<{standard}>(), true, phi::errors::InvalidArgument(
"Type validation failed for the {index}th output."));
}}
}}
else {{
IR_ENFORCE(output_{index}_type.isa<{standard}>(),
"Type validation failed for the {index}th output.");
PADDLE_ENFORCE_EQ(output_{index}_type.isa<{standard}>(), true, phi::errors::InvalidArgument(
"Type validation failed for the {index}th output."));
}}"""
OUTPUT_OPTIONAL_TYPE_CHECK_TEMPLATE = """
if (auto output_{index}_type = (*this)->result({index}).type()) {{
IR_ENFORCE(output_{index}_type.isa<{standard}>(),
"Type validation failed for the {index}th output.");
PADDLE_ENFORCE_EQ(output_{index}_type.isa<{standard}>(),true, phi::errors::InvalidArgument(
"Type validation failed for the {index}th output."));
}}"""
OUTPUT_OPTIONAL_VECTORTYPE_CHECK_TEMPLATE = """
if (auto output_{index}_type = (*this)->result({index}).type()) {{
if (auto vec_type = output_{index}_type.dyn_cast<pir::VectorType>()) {{
for (size_t i = 0; i < vec_type.size(); ++i) {{
IR_ENFORCE(vec_type[i].isa<{standard}>(),
"Type validation failed for the {index}th output.");
PADDLE_ENFORCE_EQ(vec_type[i].isa<{standard}>(), true, phi::errors::InvalidArgument(
"Type validation failed for the {index}th output."));
}}
}}
else {{
IR_ENFORCE(output_{index}_type.isa<{standard}>(),
"Type validation failed for the {index}th output.");
PADDLE_ENFORCE_EQ(output_{index}_type.isa<{standard}>(), true, phi::errors::InvalidArgument(
"Type validation failed for the {index}th output."));
}}
}}"""

Expand Down

0 comments on commit edc3bfc

Please sign in to comment.