diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index c05858fd887e..5e38d2ff6c16 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -447,24 +447,34 @@
         """Create relax input vars."""
         parameters_buffers_constants = OrderedDict()
         user_inputs = OrderedDict()
+        torch_symbol_to_relax_var: Dict[str, tvm.tir.Var] = {}
+
         for spec in exported_program.graph_signature.input_specs:
             name_hint = spec.arg.name
 
             if spec.kind is torch.export.graph_signature.InputKind.CONSTANT_TENSOR:
-                shape = exported_program.tensor_constants[spec.target].shape
+                torch_shape = exported_program.tensor_constants[spec.target].shape
                 torch_dtype = exported_program.tensor_constants[spec.target].dtype
             elif spec.kind is torch.export.graph_signature.InputKind.USER_INPUT:
                 for node in exported_program.graph.find_nodes(op="placeholder", target=spec.target):
                     if node.name == name_hint and "tensor_meta" in node.meta:
-                        shape = node.meta["tensor_meta"].shape
+                        torch_shape = node.meta["tensor_meta"].shape
                         torch_dtype = node.meta["tensor_meta"].dtype
                         break
             else:  # PARAMETER or BUFFER
-                shape = exported_program.state_dict[spec.target].shape
+                torch_shape = exported_program.state_dict[spec.target].shape
                 torch_dtype = exported_program.state_dict[spec.target].dtype
 
+            # TODO(mshr-h): Support range constraints
+            relax_shape = [
+                torch_symbol_to_relax_var.setdefault(str(s), tvm.tir.SizeVar(str(s), "int64"))
+                if isinstance(s, torch.SymInt)
+                else s
+                for s in torch_shape
+            ]
             dtype = self._convert_data_type(torch_dtype)
-            relax_var = relax.Var(name_hint, relax.TensorStructInfo(shape, dtype))
+
+            relax_var = relax.Var(name_hint, relax.TensorStructInfo(relax_shape, dtype))
             if spec.kind is torch.export.graph_signature.InputKind.USER_INPUT:
                 user_inputs[name_hint] = relax_var
             else:
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index dd4ead9e593e..a3c939fcb64e 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -32,8 +32,8 @@
 torch_version = torch.__version__
 
 
-def verify_model(torch_model, example_args, binding, expected):
-    exported_program = export(torch_model, args=example_args)
+def verify_model(torch_model, example_args, binding, expected, dynamic_shapes=None):
+    exported_program = export(torch_model, args=example_args, dynamic_shapes=dynamic_shapes)
     mod = from_exported_program(exported_program)
 
     binding = {k: tvm.nd.array(v) for k, v in binding.items()}
@@ -3961,5 +3961,32 @@
     verify_model(Topk(), example_args, {}, Expected)
 
 
+def test_dynamic_shape():
+    class DynamicModel(torch.nn.Module):
+        def forward(self, x1, x2):
+            return torch.ops.aten.add.Tensor(x1, x2)
+
+    B = tvm.tir.SizeVar("BatchSize", dtype="int64")
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            lhs: R.Tensor((B, 4), dtype="float32"),
+            rhs: R.Tensor((B, 4), dtype="float32"),
+        ) -> R.Tuple(R.Tensor((B, 4), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((B, 4), dtype="float32") = R.add(lhs, rhs)
+                gv: R.Tuple(R.Tensor((B, 4), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(2, 4), torch.randn(2, 4))
+    batch = torch.export.Dim("batch")
+    dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}
+
+    verify_model(DynamicModel(), example_args, {}, Expected, dynamic_shapes=dynamic_shapes)
+
+
 if __name__ == "__main__":
     tvm.testing.main()