diff --git a/py/torch_tensorrt/fx/test/passes/test_fix_reshape_batch_dim.py b/py/torch_tensorrt/fx/test/passes/test_fix_reshape_batch_dim.py
index e34fcd280b..431eaf4469 100644
--- a/py/torch_tensorrt/fx/test/passes/test_fix_reshape_batch_dim.py
+++ b/py/torch_tensorrt/fx/test/passes/test_fix_reshape_batch_dim.py
@@ -2,6 +2,7 @@
 import logging
 from copy import deepcopy
 
+from packaging import version
 import torch
 import torch.fx as fx
 
@@ -42,6 +43,9 @@ def forward(self, x, y):
     %reshape : [num_users=1] = call_function[target=torch_tensorrt.fx.tracer.acc_tracer.acc_ops.reshape](args = (), kwargs = {input: %y, acc_out_ty: ((%getitem_1, -1, 3), None, None, None, None, None, None)})
     return reshape
 """
+        if version.parse(torch.__version__) < version.parse("2.1.0.dev20230620"):
+            expected_graph = expected_graph.replace("num_users", "#users")
+
         assert (
             str(mod_fixed.graph).strip() == expected_graph.strip()
         ), f"Unexpected fixed graph. \nActual: {str(mod_fixed.graph)} \nExpected: {expected_graph}"
diff --git a/py/torch_tensorrt/fx/test/passes/test_remove_duplicate_output_args.py b/py/torch_tensorrt/fx/test/passes/test_remove_duplicate_output_args.py
index 1411fdab32..5dc7d8572c 100644
--- a/py/torch_tensorrt/fx/test/passes/test_remove_duplicate_output_args.py
+++ b/py/torch_tensorrt/fx/test/passes/test_remove_duplicate_output_args.py
@@ -1,6 +1,8 @@
 # Owner(s): ["oncall: gpu_enablement"]
 import logging
 
+import torch
+from packaging import version
 import torch.fx as fx
 import torch.nn as nn
 
@@ -54,6 +56,10 @@ def is_leaf_module(self, m, qn):
     %add : [num_users=1] = call_function[target=operator.add](args = (%getitem, %getitem_1), kwargs = {})
     return add
 """.strip()
+
+        if version.parse(torch.__version__) < version.parse("2.1.0.dev20230620"):
+            ttop_graph_expected = ttop_graph_expected.replace("num_users", "#users")
+
         assert (
             ttop_graph_expected == ttop_graph_actual
         ), f"Unexpected ttop graph: {ttop_graph_actual}"
@@ -64,6 +70,10 @@ def is_leaf_module(self, m, qn):
     %x : [num_users=1] = placeholder[target=x]
     return (x,)
 """.strip()
+
+        if version.parse(torch.__version__) < version.parse("2.1.0.dev20230620"):
+            ttop_a_graph_expected = ttop_a_graph_expected.replace("num_users", "#users")
+
         assert (
             ttop_a_graph_expected == ttop_a_graph_actual
         ), f"Unexpected ttop.a graph: {ttop_a_graph_actual}"
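
For context, a minimal sketch of the version-gating pattern this patch applies three times. The helper name normalize_expected_graph and the standalone torch_version parameter are hypothetical (introduced so the snippet runs without torch installed); the tests above inline the same check against torch.__version__. It works because packaging.version orders dev releases chronologically and before the corresponding final release, so any build older than the 2023-06-20 nightly still prints FX node use counts as "#users=N" rather than "num_users=N", and the expected graph string must be rewritten to match.

# Minimal sketch, assuming a hypothetical helper; the actual tests inline
# this logic with torch.__version__ directly.
from packaging import version

def normalize_expected_graph(expected_graph: str, torch_version: str) -> str:
    # PyTorch builds from 2.1.0.dev20230620 onward print FX node use counts
    # as "num_users=N"; older builds print "#users=N".
    if version.parse(torch_version) < version.parse("2.1.0.dev20230620"):
        return expected_graph.replace("num_users", "#users")
    return expected_graph

# packaging orders 2.0.1 before 2.1.0.dev20230620, so an older release
# gets the legacy spelling...
assert normalize_expected_graph("[num_users=1]", "2.0.1") == "[#users=1]"
# ...while a nightly at or after the cutoff keeps the new one.
assert normalize_expected_graph("[num_users=1]", "2.1.0.dev20230701") == "[num_users=1]"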