diff --git a/benchmarks/intmm.py b/benchmarks/intmm.py
index 819ada35e..17f4bf00a 100644
--- a/benchmarks/intmm.py
+++ b/benchmarks/intmm.py
@@ -6,7 +6,7 @@
 import pathlib
 
 import torch
 
-from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4
+from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4, TORCH_VERSION_AFTER_2_2
 
 # Check if CUDA is available, if not, exit the script
@@ -14,13 +14,13 @@
     print("CUDA is not available. Exiting the script.")
     sys.exit(0)
 
-if not TORCH_VERSION_AFTER_2_4:
-    print("torch version must be 2.4 or higher")
+if not TORCH_VERSION_AFTER_2_2:
+    print("torch version must be 2.2 or higher")
     sys.exit(0)
 
 import torch.nn.functional as F
 import torch.utils.benchmark as benchmark
-from torchao.kernel.intmm_triton import int_matmul, int_scaled_matmul
+from torchao.kernel.intmm import int_matmul, int_scaled_matmul
 
 torch._dynamo.config.cache_size_limit = 128
 torch._dynamo.config.accumulated_cache_size_limit = 128
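
For context, a minimal sketch of how the relocated kernels might be exercised once imported from torchao.kernel.intmm rather than torchao.kernel.intmm_triton. The matrix sizes, the int8 operand dtypes, and the row-wise float scales of shape (M, 1) passed to int_scaled_matmul are illustrative assumptions, not taken from this diff.

import torch
from torchao.kernel.intmm import int_matmul, int_scaled_matmul

# Assumed problem size; the benchmark script sweeps many shapes.
M, K, N = 1024, 1024, 1024
a = torch.randint(-128, 127, (M, K), dtype=torch.int8, device="cuda")
b = torch.randint(-128, 127, (K, N), dtype=torch.int8, device="cuda")

# Plain int8 matmul; integer accumulation in the output (assumed).
c = int_matmul(a, b)

# Scaled variant; per-row scales with shape (M, 1) are an assumed layout.
scales = torch.rand(M, 1, device="cuda")
c_scaled = int_scaled_matmul(a, b, scales)

print(c.shape, c_scaled.shape)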