python/tvm/relax/frontend/torch/exported_program_translator.py (26 additions, 4 deletions)

@@ -1097,11 +1097,23 @@ def create_convert_map(

     def create_input_vars(
         self, exported_program: torch.export.ExportedProgram
-    ) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var]]:
+    ) -> Tuple[Dict[str, relax.Var], Dict[str, relax.Var], Dict[str, Tuple[int, int]]]:
         """Create relax input vars."""
         parameters_buffers_constants = OrderedDict()
         user_inputs = OrderedDict()
         torch_symbol_to_relax_var: Dict[str, tvm.tir.Var] = {}
+        range_constraints = {}
+
+        if hasattr(exported_program, "range_constraints"):
+            for symbol, value_range in exported_program.range_constraints.items():
+                symbol_name = str(symbol)
+                if hasattr(value_range, "lower") and hasattr(value_range, "upper"):
+                    try:
+                        lower = int(value_range.lower)
+                        upper = int(value_range.upper)
+                        range_constraints[symbol_name] = (lower, upper)
+                    except (OverflowError, AttributeError, TypeError):
+                        continue

         for spec in exported_program.graph_signature.input_specs:
             name_hint = spec.arg.name
@@ -1119,7 +1131,6 @@ def create_input_vars(
                 torch_shape = exported_program.state_dict[spec.target].shape
                 torch_dtype = exported_program.state_dict[spec.target].dtype

-            # TODO(mshr-h): Support range constraints
             relax_shape = [
                 torch_symbol_to_relax_var.setdefault(str(s), tvm.tir.SizeVar(str(s), "int64"))
                 if isinstance(s, torch.SymInt)
@@ -1134,7 +1145,7 @@ def create_input_vars(
             else:
                 parameters_buffers_constants[name_hint] = relax_var

-        return parameters_buffers_constants, user_inputs
+        return parameters_buffers_constants, user_inputs, range_constraints

     def from_exported_program(
         self,
@@ -1147,14 +1158,25 @@ def from_exported_program(
         from torch import fx  # type: ignore

         # Create input variables.
-        parameter_buffer_constant_vars, user_input_vars = self.create_input_vars(exported_program)
+        (
+            parameter_buffer_constant_vars,
+            user_input_vars,
+            range_constraints,
+        ) = self.create_input_vars(exported_program)
         inputs_vars = user_input_vars.copy()
         inputs_vars.update(parameter_buffer_constant_vars)

         # Initialize the block builder with a function and a dataflow block.
         self.block_builder = relax.BlockBuilder()
         func_name = "main"
         func_attrs = {"num_input": len(user_input_vars)} if keep_params_as_input else None
+        if range_constraints:
+            if func_attrs is None:
+                func_attrs = {}
+            tir_var_upper_bound = {
+                var_name: upper for var_name, (_, upper) in range_constraints.items()
+            }
+            func_attrs["tir_var_upper_bound"] = tir_var_upper_bound

         nodes: List[fx.Node] = exported_program.graph.nodes
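For reference, a minimal sketch of where these bounds originate on the PyTorch side (assumes a recent `torch.export`; the toy module and dimension name are illustrative, not part of this change):

```python
import torch


class Add(torch.nn.Module):
    def forward(self, x1, x2):
        return x1 + x2


# Constrain the symbolic batch dimension to [1, 64] at export time.
batch = torch.export.Dim("batch", min=1, max=64)
ep = torch.export.export(
    Add(),
    args=(torch.randn(8, 4), torch.randn(8, 4)),
    dynamic_shapes={"x1": {0: batch}, "x2": {0: batch}},
)

# ExportedProgram.range_constraints maps symbolic dims (e.g. s0) to value
# ranges with .lower/.upper, which create_input_vars() above collects as
# {"s0": (1, 64)} and turns into the "tir_var_upper_bound" func attr.
for symbol, value_range in ep.range_constraints.items():
    print(symbol, value_range.lower, value_range.upper)
```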
tests/python/relax/test_frontend_from_exported_program.py (28 additions, 0 deletions)

@@ -6527,5 +6527,33 @@ def forward(self, x):
     np.testing.assert_allclose(pytorch_output2.numpy(), tvm_output2_np, rtol=1e-4, atol=1e-5)


+def test_dynamic_shape_with_range_constraints():
+    class DynamicModel(torch.nn.Module):
+        def forward(self, x1, x2):
+            return torch.ops.aten.add.Tensor(x1, x2)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x1: R.Tensor(("s0", 4), dtype="float32"), x2: R.Tensor(("s0", 4), dtype="float32")
+        ) -> R.Tuple(R.Tensor(("s0", 4), dtype="float32")):
+            s0 = T.int64(is_size_var=True)
+            R.func_attr({"tir_var_upper_bound": {"s0": 64}})
+            with R.dataflow():
+                lv: R.Tensor((s0, 4), dtype="float32") = R.add(x1, x2)
+                gv: R.Tuple(R.Tensor((s0, 4), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    example_args = (torch.randn(8, 4), torch.randn(8, 4))
+    batch = torch.export.Dim("batch", min=1, max=64)
+    dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}
+    exported_program = export(DynamicModel(), args=example_args, dynamic_shapes=dynamic_shapes)
+
+    mod = from_exported_program(exported_program, run_ep_decomposition=True)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
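As a quick sanity check beyond structural equality, the attribute can be read back from the converted function (a sketch reusing `mod` from the test above; to my understanding the bound is consumed downstream by passes such as `relax.transform.StaticPlanBlockMemory`):

```python
# Read the attribute the translator attached to the "main" function.
func = mod["main"]
print(func.attrs["tir_var_upper_bound"])  # roughly: {"s0": 64}
```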