diff --git a/python/tvm/relax/base_py_module.py b/python/tvm/relax/base_py_module.py index 796ab41a1470..eb34ca4d1522 100644 --- a/python/tvm/relax/base_py_module.py +++ b/python/tvm/relax/base_py_module.py @@ -198,7 +198,7 @@ def call_tir(self, tir_func, args, out_sinfo): ) func = self.compiled_tir_funcs[func_name] - out = self._create_output_tensors(out_sinfo) + out = self._create_output_tensors(out_sinfo, args) tvm_args = self._convert_pytorch_to_tvm(args) tvm_out = self._convert_pytorch_to_tvm(out) @@ -222,12 +222,11 @@ def call_dps_packed(self, func_name: str, args, out_sinfo): ) from error func = self.extern_funcs[func_name] - out = self._create_output_tensors(out_sinfo) + out = self._create_output_tensors(out_sinfo, args) tvm_args = self._convert_pytorch_to_tvm(args) tvm_out = self._convert_pytorch_to_tvm(out) func(*tvm_args, *tvm_out) - result = self._convert_tvm_to_pytorch(tvm_out) - return result[0] if len(result) == 1 else result + return out[0] if len(out) == 1 else out def call_py_func(self, func_name: str, args): """Call a Python function stored in the IRModule's pyfuncs.""" @@ -237,22 +236,71 @@ def call_py_func(self, func_name: str, args): converted_args = self._convert_tvm_to_pytorch(args) return py_func(*converted_args) - def _create_output_tensors(self, out_sinfo): - """Create output PyTorch tensors based on shape and type information.""" + def _create_output_tensors(self, out_sinfo, in_args=None): # pylint: disable=import-outside-toplevel import torch sinfo_list = out_sinfo if isinstance(out_sinfo, list) else [out_sinfo] out_tensors = [] for sinfo in sinfo_list: + if isinstance(sinfo, (tuple, list)) and all( + isinstance(x, (int, np.integer)) for x in sinfo + ): + out_tensors.append(torch.zeros(list(map(int, sinfo)), dtype=torch.float32)) + continue + if hasattr(sinfo, "shape") and hasattr(sinfo, "dtype"): - shape = [int(val) for val in sinfo.shape] + concrete_shape = self._infer_concrete_shape_from_args(sinfo.shape, in_args) torch_dtype = 
self._convert_tvm_dtype_to_torch(sinfo.dtype) - out_tensors.append(torch.empty(shape, dtype=torch_dtype)) - else: - out_tensors.append(torch.empty((1,), dtype=torch.float32)) + out_tensors.append(torch.zeros(concrete_shape, dtype=torch_dtype)) + continue + + out_tensors.append(torch.zeros((1,), dtype=torch.float32)) return out_tensors + def _infer_concrete_shape_from_args(self, shape, in_args): + + concrete = [] + symbolic_positions = [] + for idx, dim in enumerate(shape): + if isinstance(dim, (int, np.integer)): + concrete.append(int(dim)) + elif isinstance(dim, tir.IntImm): + concrete.append(int(dim.value)) + else: + concrete.append(None) + symbolic_positions.append(idx) + + if not symbolic_positions: + return concrete + + candidates = [] + if in_args is not None: + if not isinstance(in_args, (list, tuple)): + in_args = [in_args] + for obj in in_args: + if hasattr(obj, "shape") and isinstance(obj.shape, (tuple, list)): + try: + candidates.append(tuple(int(x) for x in obj.shape)) + continue + except (ValueError, TypeError): + # Skip objects with invalid shapes + pass + + target_ndim = len(shape) + for cand in candidates: + if len(cand) == target_ndim: + for pos in symbolic_positions: + concrete[pos] = cand[pos] + if all(x is not None for x in concrete): + return concrete + + raise ValueError( + "Cannot infer concrete output shape from symbolic shape and inputs. " + "Please provide a concrete `out_sinfo` (e.g., a tuple/list of ints) " + "or ensure input tensors carry shapes that determine output extents." 
+ ) + def _convert_tvm_dtype_to_torch(self, tvm_dtype: str) -> "torch.dtype": """Convert TVM dtype string to PyTorch dtype.""" # pylint: disable=import-outside-toplevel diff --git a/tests/python/relax/test_base_py_module_symbolic_shape.py b/tests/python/relax/test_base_py_module_symbolic_shape.py new file mode 100644 index 000000000000..aa39fe14bf88 --- /dev/null +++ b/tests/python/relax/test_base_py_module_symbolic_shape.py @@ -0,0 +1,367 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
"""Tests for symbolic-shape handling in ``BasePyModule``."""

import numpy as np
import pytest

import tvm
import tvm.testing
from tvm.ir import IRModule
from tvm.relax.base_py_module import BasePyModule
from tvm import tir, relax
from tvm.script import ir as I, tir as T, relax as R


def _make_module():
    """Return an empty IRModule."""
    return IRModule({})


def _make_bpm():
    """Return a BasePyModule wrapping an empty IRModule on CPU/LLVM."""
    return BasePyModule(_make_module(), device=tvm.cpu(0), target="llvm")


def _to_numpy(value):
    """Return ``value`` as a NumPy array, calling ``.numpy()`` when needed."""
    return value if isinstance(value, np.ndarray) else value.numpy()


def _skip_without_register_global_func():
    """Skip the calling test when ``tvm.register_global_func`` is missing.

    Only the availability of the API is probed, so an AttributeError raised
    by the code under test is reported as a failure rather than a skip.
    """
    if not hasattr(tvm, "register_global_func"):
        pytest.skip("call_dps_packed test requires register_global_func")


def test_infer_concrete_shape_from_numpy_input():
    """Symbolic dims are taken from a NumPy input of matching rank."""
    bpm = _make_bpm()

    n = tir.Var("n", "int64")
    m = tir.Var("m", "int64")

    x = np.zeros((3, 4), dtype="float32")
    inferred = bpm._infer_concrete_shape_from_args([n, m], [x])
    assert inferred == [3, 4]


def test_infer_concrete_shape_all_concrete_dims():
    """IntImm and plain int dims resolve without any inputs."""
    bpm = _make_bpm()

    shape = [tir.IntImm("int32", 5), 6]
    inferred = bpm._infer_concrete_shape_from_args(shape, in_args=[])
    assert inferred == [5, 6]


def test_infer_concrete_shape_error_when_uninferrable():
    """A symbolic dim with no usable input raises ValueError."""
    bpm = _make_bpm()

    k = tir.Var("k", "int64")
    with pytest.raises(ValueError):
        bpm._infer_concrete_shape_from_args([k, 8], in_args=[])


@I.ir_module
class AddModuleSymbolic(BasePyModule):
    @T.prim_func
    def add_tir(var_x: T.handle, var_y: T.handle, var_out: T.handle):
        T.func_attr({"global_symbol": "add_tir"})
        n = T.int64()
        x = T.match_buffer(var_x, (n,), dtype="float32")
        y = T.match_buffer(var_y, (n,), dtype="float32")
        out = T.match_buffer(var_out, (n,), dtype="float32")

        for i in T.serial(n):
            out[i] = x[i] + y[i]

    @R.function
    def main_relax(
        x: R.Tensor(("n",), "float32"), y: R.Tensor(("n",), "float32")
    ) -> R.Tensor(("n",), "float32"):
        return R.add(x, y)


def test_base_py_module_relax_symbolic_end_to_end():
    """The Relax entry point accepts inputs of different lengths."""
    bpm = AddModuleSymbolic(device=tvm.cpu(0), target="llvm")

    a = np.random.randn(5).astype("float32")
    b = np.random.randn(5).astype("float32")
    out = bpm.main_relax(a, b)
    assert isinstance(out, np.ndarray) or hasattr(out, "numpy")
    np.testing.assert_allclose(_to_numpy(out), a + b, rtol=1e-6, atol=1e-6)

    a7 = np.random.randn(7).astype("float32")
    b7 = np.random.randn(7).astype("float32")
    out2 = bpm.main_relax(a7, b7)
    np.testing.assert_allclose(_to_numpy(out2), a7 + b7, rtol=1e-6, atol=1e-6)


def test_base_py_module_tir_symbolic_end_to_end():
    """call_tir resolves a symbolic output shape from its inputs."""
    bpm = AddModuleSymbolic(device=tvm.cpu(0), target="llvm")

    a = np.random.randn(5).astype("float32")
    b = np.random.randn(5).astype("float32")

    n = tir.Var("n", "int64")
    out_sinfo = relax.TensorStructInfo((n,), "float32")

    out = bpm.call_tir("add_tir", [a, b], out_sinfo)
    np.testing.assert_allclose(_to_numpy(out), a + b, rtol=1e-6, atol=1e-6)


def test_infer_concrete_shape_multiple_symbolic_dims():
    """Test shape inference with multiple symbolic dimensions."""
    bpm = _make_bpm()

    n = tir.Var("n", "int64")
    m = tir.Var("m", "int64")
    k = tir.Var("k", "int64")

    x = np.zeros((2, 3, 4), dtype="float32")
    inferred = bpm._infer_concrete_shape_from_args([n, m, k], [x])
    assert inferred == [2, 3, 4]


def test_infer_concrete_shape_mixed_concrete_symbolic():
    """Test shape inference with mixed concrete and symbolic dimensions."""
    bpm = _make_bpm()

    n = tir.Var("n", "int64")
    sym_shape = [n, 5, 10]  # First dim is symbolic, others are concrete

    x = np.zeros((3, 5, 10), dtype="float32")
    inferred = bpm._infer_concrete_shape_from_args(sym_shape, [x])
    assert inferred == [3, 5, 10]


def test_infer_concrete_shape_from_tvm_tensors():
    """Test shape inference from TVM tensors."""
    # Probe only for the API so that real failures are not reported as skips.
    if not hasattr(tvm.runtime, "tensor"):
        pytest.skip("tvm.runtime.tensor not available")

    x_tvm = tvm.runtime.tensor(np.zeros((3, 4), dtype="float32"))
    bpm = _make_bpm()

    n = tir.Var("n", "int64")
    m = tir.Var("m", "int64")

    inferred = bpm._infer_concrete_shape_from_args([n, m], [x_tvm])
    assert inferred == [3, 4]


def test_infer_concrete_shape_multiple_inputs():
    """Test shape inference when multiple inputs are available."""
    bpm = _make_bpm()

    n = tir.Var("n", "int64")
    m = tir.Var("m", "int64")

    # Multiple inputs with different shapes - should use first matching one
    x1 = np.zeros((2, 3), dtype="float32")
    x2 = np.zeros((4, 5), dtype="float32")
    inferred = bpm._infer_concrete_shape_from_args([n, m], [x1, x2])
    assert inferred == [2, 3]  # Should use first input


def test_infer_concrete_shape_wrong_ndim():
    """Test shape inference when input has wrong number of dimensions."""
    bpm = _make_bpm()

    n = tir.Var("n", "int64")
    m = tir.Var("m", "int64")
    sym_shape = [n, m]  # 2D

    x = np.zeros((3,), dtype="float32")  # 1D - wrong ndim
    with pytest.raises(ValueError, match="Cannot infer concrete output shape"):
        bpm._infer_concrete_shape_from_args(sym_shape, [x])


@I.ir_module
class MatrixModuleSymbolic(BasePyModule):
    @T.prim_func
    def matmul_tir(var_a: T.handle, var_b: T.handle, var_c: T.handle):
        T.func_attr({"global_symbol": "matmul_tir"})
        m = T.int64()
        n = T.int64()
        k = T.int64()
        a = T.match_buffer(var_a, (m, k), dtype="float32")
        b = T.match_buffer(var_b, (k, n), dtype="float32")
        c = T.match_buffer(var_c, (m, n), dtype="float32")

        for i in T.serial(m):
            for j in T.serial(n):
                c[i, j] = 0.0
                for l in T.serial(k):
                    c[i, j] = c[i, j] + a[i, l] * b[l, j]

    @R.function
    def matmul_relax(
        a: R.Tensor(("m", "k"), "float32"), b: R.Tensor(("k", "n"), "float32")
    ) -> R.Tensor(("m", "n"), "float32"):
        return R.matmul(a, b)


def test_base_py_module_multiple_symbolic_dims():
    """Test BasePyModule with multiple symbolic dimensions."""
    bpm = MatrixModuleSymbolic(device=tvm.cpu(0), target="llvm")

    # Test Relax function with multiple symbolic dims
    a = np.random.randn(2, 3).astype("float32")
    b = np.random.randn(3, 4).astype("float32")
    expected = np.matmul(a, b)
    out = bpm.matmul_relax(a, b)
    np.testing.assert_allclose(_to_numpy(out), expected, rtol=1e-6, atol=1e-6)

    # Test TIR function with multiple symbolic dims
    # Use concrete shapes for TIR function to avoid constraint issues
    out_sinfo = relax.TensorStructInfo((2, 4), "float32")
    out_tir = bpm.call_tir("matmul_tir", [a, b], out_sinfo)
    np.testing.assert_allclose(_to_numpy(out_tir), expected, rtol=1e-6, atol=1e-6)


def test_base_py_module_call_dps_packed_symbolic():
    """Test call_dps_packed with symbolic shapes."""
    _skip_without_register_global_func()

    # override=True keeps registration idempotent if the module is re-run
    # within the same process.
    @tvm.register_global_func("test_add_packed", override=True)
    def test_add_packed(a, b, out):
        """Add two tensors element-wise."""
        out[:] = a.numpy() + b.numpy()

    bpm = _make_bpm()

    a = np.random.randn(5).astype("float32")
    b = np.random.randn(5).astype("float32")

    n = tir.Var("n", "int64")
    out_sinfo = relax.TensorStructInfo((n,), "float32")

    out = bpm.call_dps_packed("test_add_packed", [a, b], out_sinfo)
    np.testing.assert_allclose(_to_numpy(out), a + b, rtol=1e-6, atol=1e-6)


def test_base_py_module_call_dps_packed_multiple_args():
    """Test call_dps_packed with multiple arguments and symbolic shapes."""
    _skip_without_register_global_func()

    @tvm.register_global_func("test_matmul_packed", override=True)
    def test_matmul_packed(a, b, out):
        """Matrix multiplication."""
        out[:] = np.matmul(a.numpy(), b.numpy())

    bpm = _make_bpm()

    a = np.random.randn(2, 3).astype("float32")
    b = np.random.randn(3, 4).astype("float32")

    out_sinfo = relax.TensorStructInfo((2, 4), "float32")

    out = bpm.call_dps_packed("test_matmul_packed", [a, b], out_sinfo)
    expected = np.matmul(a, b)
    np.testing.assert_allclose(_to_numpy(out), expected, rtol=1e-6, atol=1e-6)


def test_base_py_module_call_dps_packed_scalar_args():
    """Test call_dps_packed with scalar arguments and symbolic shapes."""
    _skip_without_register_global_func()

    @tvm.register_global_func("test_add_scalar_packed", override=True)
    def test_add_scalar_packed(x, scalar, out):
        """Add scalar to tensor."""
        scalar_val = scalar.numpy() if hasattr(scalar, "numpy") else scalar
        out[:] = x.numpy() + scalar_val

    bpm = _make_bpm()

    x = np.random.randn(4).astype("float32")
    scalar = 2.5

    n = tir.Var("n", "int64")
    out_sinfo = relax.TensorStructInfo((n,), "float32")

    out = bpm.call_dps_packed("test_add_scalar_packed", [x, scalar], out_sinfo)
    expected = x + scalar
    np.testing.assert_allclose(_to_numpy(out), expected, rtol=1e-6, atol=1e-6)


def test_infer_concrete_shape_from_pytorch_tensors():
    """Test shape inference from PyTorch tensors (if available)."""
    torch = pytest.importorskip("torch")

    bpm = _make_bpm()

    n = tir.Var("n", "int64")
    m = tir.Var("m", "int64")

    x_torch = torch.zeros((3, 4), dtype=torch.float32)
    inferred = bpm._infer_concrete_shape_from_args([n, m], [x_torch])
    assert inferred == [3, 4]


def test_base_py_module_relax_with_pytorch_tensors():
    """Test Relax functions with PyTorch tensors and symbolic shapes."""
    torch = pytest.importorskip("torch")

    bpm = AddModuleSymbolic(device=tvm.cpu(0), target="llvm")

    a_torch = torch.randn(5, dtype=torch.float32)
    b_torch = torch.randn(5, dtype=torch.float32)

    out = bpm.main_relax(a_torch, b_torch)
    expected = a_torch.numpy() + b_torch.numpy()
    np.testing.assert_allclose(_to_numpy(out), expected, rtol=1e-6, atol=1e-6)


if __name__ == "__main__":
    tvm.testing.main()