diff --git a/python/tvm/relax/parser.py b/python/tvm/relax/parser.py index 594ce3b7d7e6..a757971098f8 100644 --- a/python/tvm/relax/parser.py +++ b/python/tvm/relax/parser.py @@ -871,16 +871,25 @@ def parse_call(self, expr: ast.Call) -> Union[tir.PrimExpr, rx.Expr]: self.report_error(f"unsupported function in call: {op}", expr.func_name.span) # parse call attributes if applicable - if isinstance(op, rx.ExternFunc) or (isinstance(op, tvm.ir.Op) and op.attrs_type_key != ""): - attrs_type_key = "DictAttrs" if isinstance(op, rx.ExternFunc) else op.attrs_type_key - kwargs = {} - for key, val in expr.keyword_params.items(): - assert isinstance(key, ast.Constant) and isinstance(key.value, str) - # TODO(@altanh): might need separate attribute parsing eventually - kwargs[key.value] = self.transform_expr(val) - attrs = tvm.ir.attrs.make_node(attrs_type_key, **kwargs) + kwargs = {} + for key, val in expr.keyword_params.items(): + assert isinstance(key, ast.Constant) and isinstance(key.value, str) + # TODO(@altanh): might need separate attribute parsing eventually + kwargs[key.value] = self.transform_expr(val) + + is_default = False + if "attrs_type_key" in kwargs: + attrs_type_key = kwargs["attrs_type_key"] + kwargs.pop("attrs_type_key") + elif isinstance(op, tvm.ir.Op) and op.attrs_type_key != "": + attrs_type_key = op.attrs_type_key else: - attrs = None + attrs_type_key = "DictAttrs" + is_default = True + + attrs = None + if kwargs or not is_default: + attrs = tvm.ir.attrs.make_node(attrs_type_key, **kwargs) return relay.Call(op, args, attrs=attrs, span=self.to_tvm_span(expr.span)) diff --git a/src/printer/relax_script_printer.cc b/src/printer/relax_script_printer.cc index a7359613c28e..dcf27f0d18b0 100644 --- a/src/printer/relax_script_printer.cc +++ b/src/printer/relax_script_printer.cc @@ -194,6 +194,9 @@ Doc RelaxScriptPrinter::VisitNode_(const relay::CallNode* op) { doc << "(" << Doc::Concat(args, Doc::Text(", ")); std::vector attrs = PrintAttrs(op->attrs); + if (op->attrs.defined()) { + attrs.push_back(Doc::Text("attrs_type_key=") << Doc::StrLiteral(op->attrs->GetTypeKey())); + } if (!attrs.empty()) { doc << ", " << Doc::Concat(attrs); } diff --git a/tests/python/relax/test_parser.py b/tests/python/relax/test_parser.py index 93392281ee43..2c503ba488d9 100644 --- a/tests/python/relax/test_parser.py +++ b/tests/python/relax/test_parser.py @@ -435,10 +435,13 @@ def test_call_packed(): def f(x: Tensor[(3, 3), "float32"]): # test that we can intro dim vars z: Tensor[(n, m), "float32"] = relax.call_packed("contrib.my_matmul", x, x, mp=False) + w = relax.call_packed( + "contrib.my_shape_of", x, dtype="int32", attrs_type_key="relay.attrs.ShapeOfAttrs" + ) return z x = f.params[0] - (z_bind,) = f.body.blocks[0].bindings + (z_bind, w_bind) = f.body.blocks[0].bindings check_tensor_var(z_bind.var, ("n", "m"), "float32") assert isinstance(z_bind.value.op, rx.ExternFunc) @@ -446,6 +449,8 @@ def f(x: Tensor[(3, 3), "float32"]): assert "mp" in z_bind.value.attrs and z_bind.value.attrs["mp"] == False assert structural_equal(z_bind.value.args, [x, x]) + assert isinstance(w_bind.value.attrs, relay.op.op_attrs.ShapeOfAttrs) + def test_primexpr_arithmetic(): @rx.script diff --git a/tests/python/relax/test_printer.py b/tests/python/relax/test_printer.py index aa6b3fd8c2b4..5dc9527035e4 100644 --- a/tests/python/relax/test_printer.py +++ b/tests/python/relax/test_printer.py @@ -145,6 +145,9 @@ def test_call_packed(): def foo(x: Tensor[(3, 3), "float32"]): # test that we can intro dim vars z: Tensor[(n, m), "float32"] = relax.call_packed("contrib.my_matmul", x, x, mp=False) + w = relax.call_packed( + "contrib.my_shape_of", x, dtype="int32", attrs_type_key="relay.attrs.ShapeOfAttrs" + ) return z check_roundtrip(foo)