diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 4f3132b8d8f2..2f02f8dfd0dc 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -1097,11 +1097,23 @@ def create_convert_map(
 
     def create_input_vars(
         self, exported_program: torch.export.ExportedProgram
-    ) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var]]:
+    ) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var], Dict[str, Tuple[int, int]]]:
         """Create relax input vars."""
         parameters_buffers_constants = OrderedDict()
         user_inputs = OrderedDict()
         torch_symbol_to_relax_var: Dict[str, tvm.tir.Var] = {}
+        range_constraints = {}
+
+        if hasattr(exported_program, "range_constraints"):
+            for symbol, value_range in exported_program.range_constraints.items():
+                symbol_name = str(symbol)
+                if hasattr(value_range, "lower") and hasattr(value_range, "upper"):
+                    try:
+                        lower = int(value_range.lower)
+                        upper = int(value_range.upper)
+                        range_constraints[symbol_name] = (lower, upper)
+                    except (OverflowError, AttributeError, TypeError):
+                        continue
 
         for spec in exported_program.graph_signature.input_specs:
             name_hint = spec.arg.name
@@ -1119,7 +1131,6 @@ def create_input_vars(
                 torch_shape = exported_program.state_dict[spec.target].shape
                 torch_dtype = exported_program.state_dict[spec.target].dtype
 
-            # TODO(mshr-h): Support range constraints
             relax_shape = [
                 torch_symbol_to_relax_var.setdefault(str(s), tvm.tir.SizeVar(str(s), "int64"))
                 if isinstance(s, torch.SymInt)
@@ -1134,7 +1145,7 @@ def create_input_vars(
             else:
                 parameters_buffers_constants[name_hint] = relax_var
 
-        return parameters_buffers_constants, user_inputs
+        return parameters_buffers_constants, user_inputs, range_constraints
 
     def from_exported_program(
         self,
@@ -1147,7 +1158,11 @@ def from_exported_program(
         from torch import fx  # type: ignore
 
         # Create input variables.
-        parameter_buffer_constant_vars, user_input_vars = self.create_input_vars(exported_program)
+        (
+            parameter_buffer_constant_vars,
+            user_input_vars,
+            range_constraints,
+        ) = self.create_input_vars(exported_program)
         inputs_vars = user_input_vars.copy()
         inputs_vars.update(parameter_buffer_constant_vars)
 
@@ -1155,6 +1170,13 @@ def from_exported_program(
         self.block_builder = relax.BlockBuilder()
         func_name = "main"
         func_attrs = {"num_input": len(user_input_vars)} if keep_params_as_input else None
+        if range_constraints:
+            if func_attrs is None:
+                func_attrs = {}
+            tir_var_upper_bound = {
+                var_name: upper for var_name, (_, upper) in range_constraints.items()
+            }
+            func_attrs["tir_var_upper_bound"] = tir_var_upper_bound
 
         nodes: List[fx.Node] = exported_program.graph.nodes
 
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 44248c1c59f4..b6df02c132fd 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -6527,5 +6527,33 @@ def forward(self, x):
     np.testing.assert_allclose(pytorch_output2.numpy(), tvm_output2_np, rtol=1e-4, atol=1e-5)
 
 
+def test_dynamic_shape_with_range_constraints():
+    class DynamicModel(torch.nn.Module):
+        def forward(self, x1, x2):
+            return torch.ops.aten.add.Tensor(x1, x2)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x1: R.Tensor(("s0", 4), dtype="float32"), x2: R.Tensor(("s0", 4), dtype="float32")
+        ) -> R.Tuple(R.Tensor(("s0", 4), dtype="float32")):
+            s0 = T.int64(is_size_var=True)
+            R.func_attr({"tir_var_upper_bound": {"s0": 64}})
+            with R.dataflow():
+                lv: R.Tensor((s0, 4), dtype="float32") = R.add(x1, x2)
+                gv: R.Tuple(R.Tensor((s0, 4), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(8, 4), torch.randn(8, 4))
+    batch = torch.export.Dim("batch", min=1, max=64)
+    dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}
+    exported_program = export(DynamicModel(), args=example_args, dynamic_shapes=dynamic_shapes)
+
+    mod = from_exported_program(exported_program, run_ep_decomposition=True)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
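
Note for reviewers (not part of the patch): the sketch below is a minimal way to observe the new `tir_var_upper_bound` attribute end to end, assuming this patch is applied on top of a TVM build with a PyTorch version that supports `torch.export`. The `ScaleModel` class is a hypothetical toy model for illustration only, and the symbolic dimension name (e.g. "s0") is assigned by torch.export, not by this change.

    import torch
    from tvm.relax.frontend.torch import from_exported_program

    class ScaleModel(torch.nn.Module):  # hypothetical toy model, illustration only
        def forward(self, x):
            return x * 2.0

    # Bounding the batch dimension makes torch.export record it in
    # exported_program.range_constraints, which create_input_vars now reads.
    batch = torch.export.Dim("batch", min=1, max=32)
    exported_program = torch.export.export(
        ScaleModel(),
        args=(torch.randn(8, 4),),
        dynamic_shapes={"x": {0: batch}},
    )

    mod = from_exported_program(exported_program)
    # Expected to print something like {"s0": 32}: only the upper bound of each
    # symbolic dimension is attached; the lower bound is collected but unused here.
    print(mod["main"].attrs["tir_var_upper_bound"])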