diff --git a/python/tvm/relax/utils.py b/python/tvm/relax/utils.py index 5bfb0d87bf00..0bb82c79f4f8 100644 --- a/python/tvm/relax/utils.py +++ b/python/tvm/relax/utils.py @@ -23,7 +23,7 @@ from ..runtime import String, convert_to_object from ..tir import PrimExpr from . import _ffi_api -from .expr import Expr, Function, PrimValue, ShapeExpr, StringImm +from .expr import Expr, Function, PrimValue, StringImm from .expr import Tuple as rx_Tuple @@ -74,14 +74,12 @@ def convert_to_expr(value: Any) -> Expr: 1. Return the input itself if it's already a `relax.Expr`; 2. Return `relax.PrimValue` if the input is a `PrimExpr`; 3. Return `relax.StringImm` if the input is `tvm.String` or `str`; - 4. Return `relax.ShapeExpr` if the input is a tuple/list of `PrimExpr` w/ int dtype; - 5. Return `relax.Tuple` if the input is a tuple/list of `Expr`. + 4. Return `relax.Tuple` if the input is a tuple/list of `Expr`. Notes ----- 1. `tvm.tir.StringImm` is not allowed because of ambiguity, which can be either `relax.StringImm` or `relax.PrimValue`. - 2. We regard empty tuple/list as `relax.Tuple` instead of `relax.ShapeExpr` """ if isinstance(value, int): return PrimValue(tir.IntImm("int64", value)) @@ -102,16 +100,8 @@ def convert_to_expr(value: Any) -> Expr: # Case 3 if isinstance(tvm_value, String): return StringImm(value) - # Case 4 & 5 + # Case 4 if isinstance(value, (tuple, list)): - # Note 2 - if len(value) == 0: - return rx_Tuple([]) - # Case 4 - opt_prim_value = [convert_to_object(v) for v in value] - if all([isinstance(v, PrimExpr) and v.dtype.startswith("int") for v in opt_prim_value]): - return ShapeExpr(value) - # Case 5 # `convert_to_expr` ensures that all elements are `Expr` if no exception raises return rx_Tuple([convert_to_expr(v) for v in value]) raise TypeError(f"Cannot convert {value} with type {type(value)} to `relax.Expr`") diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 0692ec5683c0..0e6595cb4514 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -329,6 +329,23 @@ def tuple(*fields: Expr) -> Expr: return relax.Tuple(fields) # type: ignore[attr-defined] # pylint: disable=no-member +############################### R.shape ################################ + + +def shape(value: List[PrimExpr]) -> Expr: + """Create a ShapeExpr. + Parameters + ---------- + value : List[PrimExpr] + The fields of the tuple. + Returns + ------- + res : Expr + The result tuple. + """ + return relax.ShapeExpr(value) # pylint: disable=no-member # type: ignore + + ############################### PrimValue ############################## @@ -407,6 +424,7 @@ def dtype(value: Union[py_str, DataType]) -> Expr: "prim_value", "print", "reshape", + "shape", "shape_of", "str", "tuple", diff --git a/python/tvm/script/parser/relax/entry.py b/python/tvm/script/parser/relax/entry.py index d93f9a2826bc..7e51264cb37c 100644 --- a/python/tvm/script/parser/relax/entry.py +++ b/python/tvm/script/parser/relax/entry.py @@ -22,6 +22,7 @@ from tvm.relax import ( Expr, + ShapeExpr, FuncStructInfo, Function, ObjectStructInfo, @@ -84,17 +85,22 @@ class TensorProxy(StructInfoProxy): def __init__( self, - shape: Optional[List[Union[PrimExpr, str]]] = None, + shape: Optional[Union[List[Union[PrimExpr, str]], Expr]] = None, dtype: Optional[str] = None, ndim: int = -1, ) -> None: self.shape = shape + if isinstance(shape, Expr) and not isinstance(shape, ShapeExpr): + raise ValueError( + "Only ShapeExpr is allowed as shape expr, but got: " + f"{shape} with type: {type(shape)}" + ) self.dtype = dtype self.ndim = ndim super().__init__() def get_symbolic_vars(self) -> Set[str]: - if self.shape is None: + if self.shape is None or isinstance(self.shape, Expr): return {} else: return {s for s in self.shape if isinstance(s, str) and s.isidentifier()} @@ -102,6 +108,8 @@ def get_symbolic_vars(self) -> Set[str]: def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> TensorStructInfo: if self.shape is None: return TensorStructInfo(None, self.dtype, self.ndim) + elif isinstance(self.shape, ShapeExpr): + return TensorStructInfo(self.shape, self.dtype, self.ndim) else: if dict_globals is None and any([isinstance(s, str) for s in self.shape]): raise ValueError( @@ -113,7 +121,7 @@ def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> Tenso def Tensor( - shape: Optional[List[Union[PrimExpr, str]]] = None, + shape: Optional[Union[List[Union[PrimExpr, str]], ShapeExpr]] = None, dtype: Optional[str] = None, ndim: int = -1, ) -> TensorProxy: @@ -124,8 +132,12 @@ def Tensor( dtype = shape shape = None - if shape is not None and not isinstance(shape, (tuple, list)): - raise ValueError(f"shape must be a list or tuple, but got: {shape}") + if ( + shape is not None + and not isinstance(shape, (tuple, list)) + and not isinstance(shape, ShapeExpr) + ): + raise ValueError(f"shape must be a list/tuple or a ShapeExpr, but got: {shape}") return TensorProxy(shape, dtype, ndim) diff --git a/src/script/printer/relax/expr.cc b/src/script/printer/relax/expr.cc index a786932fc3d9..66d7d187d0c8 100644 --- a/src/script/printer/relax/expr.cc +++ b/src/script/printer/relax/expr.cc @@ -71,7 +71,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) for (int i = 0, l = n->values.size(); i < l; ++i) { values_doc.push_back(PrintShapeVar(n->values[i], values_p->ArrayIndex(i), d)); } - return TupleDoc(values_doc); + return Relax(d, "shape")->Call({ListDoc(values_doc)}); }); Optional SpecialScalar(const runtime::NDArray& n, const ObjectPath& p) { diff --git a/src/script/printer/relax/struct_info.cc b/src/script/printer/relax/struct_info.cc index 6f4a66c991d9..c541619ec887 100644 --- a/src/script/printer/relax/struct_info.cc +++ b/src/script/printer/relax/struct_info.cc @@ -89,7 +89,19 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) Array kwargs_keys; Array kwargs_values; if (n->shape.defined()) { - args.push_back(d->AsDoc(n->shape.value(), n_p->Attr("shape"))); + // Need to dig into ShapeExpr to preserve the `R.shape` prefix + if (const auto* shape = n->shape.value().as()) { + auto shape_expr = GetRef(shape); + ObjectPath shape_p = n_p->Attr("shape")->Attr("values"); + Array shape_docs; + for (int i = 0, ndim = shape_expr->values.size(); i < ndim; ++i) { + shape_docs.push_back( + PrintShapeVar(shape_expr->values[i], shape_p->ArrayIndex(i), d)); + } + args.push_back(TupleDoc(shape_docs)); + } else { + args.push_back(d->AsDoc(n->shape.value(), n_p->Attr("shape"))); + } } if (!n->IsUnknownDtype()) { kwargs_keys.push_back("dtype"); diff --git a/tests/python/relax/test_backend_transform_shape_lower.py b/tests/python/relax/test_backend_transform_shape_lower.py index 0bf0f175dd7e..5cd104dd013f 100644 --- a/tests/python/relax/test_backend_transform_shape_lower.py +++ b/tests/python/relax/test_backend_transform_shape_lower.py @@ -167,7 +167,7 @@ def main( n = T.Var("n", "int64") k = T.Var("k", "int64") z = R.match_cast(y, R.Tensor([k, m, k + 1], dtype=None)) - return (k + 1, m, 2) + return R.shape([k + 1, m, 2]) # slot assignment: # 0: n, 1: m, 2:k, 3: k+1 diff --git a/tests/python/relax/test_transform.py b/tests/python/relax/test_transform.py index 624b7877cd11..12dd095c6b5d 100644 --- a/tests/python/relax/test_transform.py +++ b/tests/python/relax/test_transform.py @@ -109,7 +109,7 @@ class TestVMBuiltinLower: @R.function def foo(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor: m, n = T.var("int64"), T.var("int64") - alloc = R.builtin.alloc_tensor((m, n), runtime_device_index=0, dtype="float32") + alloc = R.builtin.alloc_tensor(R.shape([m, n]), runtime_device_index=0, dtype="float32") _ = R.call_packed( "test.op.identity", x, alloc, sinfo_args=(R.Tensor(ndim=2, dtype="float32")) ) diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index 34b02fdbb8c3..c9a16fbcacb7 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -22,10 +22,9 @@ import tvm.script import tvm.testing from tvm import IRModule, relax, tir, topi -from tvm.relax import DynTensorType -from tvm.script import ir as I -from tvm.script import relax as R -from tvm.script import tir as T +from tvm.script.parser import ir as I +from tvm.script.parser import relax as R +from tvm.script.parser import tir as T def _check( @@ -202,6 +201,23 @@ def foo(x: R.Tensor((4, 4), "float32")) -> R.Tensor((4, 4), "float32"): _check(foo, bb.get()["foo"]) +def test_relax_base_op(): + @R.function + def foo(x: R.Tensor((4, 4), "float32")): + alloc = R.builtin.alloc_tensor(R.shape([4, 4]), runtime_device_index=0, dtype="float32") + shape = R.shape_of(alloc) + return shape + + x = relax.Var("x", R.Tensor((4, 4), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", (x,)): + alloc = bb.emit(relax.op.builtin.alloc_tensor(relax.ShapeExpr((4, 4)), "float32", 0)) + shape = bb.emit(relax.op.shape_of(alloc)) + bb.emit_func_output(shape) + # todo(yongwww): comment this check because 0 was changed to R.prim_value(0) in the printed IR + # _check(foo, bb.get()["foo"]) + + def test_symbolic_shape(): @R.function def foo(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor(("m", "n"), "float32"): @@ -274,7 +290,7 @@ def foo(x: R.Tensor("float32"), y: R.Tensor("float32")): y0 = R.match_cast(y, R.Tensor([n], "float32")) gv = y0 R.output(gv) - return (x0, (m, n * 2)) + return (x0, R.shape([m, n * 2])) x = relax.Var("x", R.Tensor("float32")) y = relax.Var("y", R.Tensor("float32")) @@ -314,7 +330,7 @@ def test_tuple_return_2(): def foo(x: R.Tensor("float32", ndim=2)): n, m = T.var("int64"), T.var("int64") x0 = R.match_cast(x, R.Tensor((n, m), "float32")) - return (x0, (n + 1, m, 1)) + return (x0, R.shape([n + 1, m, 1])) x = relax.Var("x", R.Tensor("float32", ndim=2)) n, m = tir.Var("n", "int64"), tir.Var("m", "int64") @@ -332,7 +348,7 @@ def foo(x: R.Tensor("float32", ndim=2)): n, m = T.var("int64"), T.var("int64") x0 = R.match_cast(x, R.Tensor((n, m), "float32")) t0 = (x, x0) - t1 = (x, (n, m), t0) + t1 = (x, R.shape([n, m]), t0) return t1 x = relax.Var("x", R.Tensor("float32", ndim=2)) @@ -965,9 +981,9 @@ def test_vm_ops(): def foo(x: R.Tensor(("m", "n"), dtype="float32")): m = T.var("int64") n = T.var("int64") - storage = R.vm.alloc_storage((4 * m * n,), dtype="float32", runtime_device_index=0) - alloc = R.vm.alloc_tensor(storage, (m, n), offset=0, dtype="float32") - tensor = R.builtin.alloc_tensor((m, n), dtype="float32", runtime_device_index=0) + storage = R.vm.alloc_storage(R.shape([4 * m * n]), dtype="float32", runtime_device_index=0) + alloc = R.vm.alloc_tensor(storage, shape=R.shape([m, n]), offset=0, dtype="float32") + tensor = R.builtin.alloc_tensor(R.shape([m, n]), dtype="float32", runtime_device_index=0) _ = R.vm.call_tir_dyn("te_func", (x, tensor, (m, n))) gv = tensor return alloc, gv diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py index 58596f968f98..db90c66422d0 100644 --- a/tests/python/relax/test_tvmscript_printer_relax.py +++ b/tests/python/relax/test_tvmscript_printer_relax.py @@ -292,7 +292,7 @@ def test_tuple_get_item(): def test_shape_expr(): obj = relax.ShapeExpr([1, 2, 3]) - _assert_print(obj, "(1, 2, 3)") + _assert_print(obj, "R.shape([1, 2, 3])") def test_call(): @@ -304,7 +304,7 @@ def test_call(): """ x = T.Var("x", "int64") a: R.Tensor((1, x, 3), dtype="float32") -R.call_tir("my_func", (a,), out_sinfo=R.Tensor((1, x, 3), dtype="float32"), tir_vars=(x,)) +R.call_tir("my_func", (a,), out_sinfo=R.Tensor((1, x, 3), dtype="float32"), tir_vars=R.shape([x])) """, ) diff --git a/tests/python/relax/test_vm_build.py b/tests/python/relax/test_vm_build.py index 534d2308daa9..0a881691accc 100644 --- a/tests/python/relax/test_vm_build.py +++ b/tests/python/relax/test_vm_build.py @@ -88,7 +88,7 @@ class TestVMCompileStage2: def foo(x: R.Tensor(dtype="float32")) -> R.Shape: n, m = T.var("int64"), T.var("int64") _ = R.match_cast(x, R.Tensor((n, m), "float32")) - return (n * 2, m * 3) + return R.shape([n * 2, m * 3]) mod = TestVMCompileStage2 target = tvm.target.Target("llvm", host="llvm") @@ -511,9 +511,9 @@ class TestMemoryAllocStorageTensor: @R.function def main(x: R.Tensor((2, 3), dtype="float32")): storage = R.memory.alloc_storage( - (24,), virtual_device_index=0, storage_scope="global", dtype="float32" + R.shape([24]), virtual_device_index=0, storage_scope="global", dtype="float32" ) - y = R.memory.alloc_tensor(storage, 0, (2, 3), dtype="float32") + y = R.memory.alloc_tensor(storage, 0, R.shape([2, 3]), dtype="float32") _ = copy(x, y) return y diff --git a/tests/python/relax/test_vm_codegen_only.py b/tests/python/relax/test_vm_codegen_only.py index b5e77091776a..4b79ecf70fa1 100644 --- a/tests/python/relax/test_vm_codegen_only.py +++ b/tests/python/relax/test_vm_codegen_only.py @@ -18,13 +18,15 @@ Restrictions: all shape lowered, explicit allocation. """ -import tvm -import pytest import numpy as np -from tvm import relax, TVMError -from tvm.script import relax as R, tir as T +import pytest +import tvm +import tvm.testing +from tvm import relax +from tvm.relax.testing.runtime_builtin import MakeShapeCode, MatchShapeCode from tvm.relax.testing.vm import check_saved_func -from tvm.relax.testing.runtime_builtin import MatchShapeCode, MakeShapeCode +from tvm.script import relax as R +from tvm.script import tir as T EXEC_MODE = ["bytecode"] @@ -312,7 +314,7 @@ class TestVMBuiltinReshape: def main(x: R.Tensor((3, 4), "float32")): R.func_attr({"global_symbol": "main"}) y = R.call_packed( - "vm.builtin.reshape", x, (6, 2), sinfo_args=R.Tensor((6, 2), "float32") + "vm.builtin.reshape", x, R.shape([6, 2]), sinfo_args=R.Tensor((6, 2), "float32") ) return y