diff --git a/torchao/testing/utils.py b/torchao/testing/utils.py index c73260eef..7fa4ba4a6 100644 --- a/torchao/testing/utils.py +++ b/torchao/testing/utils.py @@ -320,7 +320,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: y_d = dn_dist(up_dist(input_dtensor)) - if TORCH_VERSION_AT_LEAST_2_5: + if not TORCH_VERSION_AT_LEAST_2_5: # Need torch 2.5 to support compiled tensor parallelism return