diff --git a/bench/kernels/test_int_mm.py b/bench/kernels/test_int_mm.py
index 688a28a8..b818ca08 100644
--- a/bench/kernels/test_int_mm.py
+++ b/bench/kernels/test_int_mm.py
@@ -24,7 +24,7 @@ def main():
     device = torch.device(args.device)
 
     def get_int_matmul(device):
-        if device.type == ("cuda"):
+        if device.type == ("cuda") or device.type == ("cpu"):
             return torch._int_mm
         return torch.matmul
 
diff --git a/examples/nlp/text-generation/quantize_causal_lm_model.py b/examples/nlp/text-generation/quantize_causal_lm_model.py
index 294535df..7fd0d35f 100644
--- a/examples/nlp/text-generation/quantize_causal_lm_model.py
+++ b/examples/nlp/text-generation/quantize_causal_lm_model.py
@@ -61,6 +61,13 @@ def main():
     parser.add_argument("--max_new_tokens", type=int, default=20, help="The maximum number of tokens to generate.")
     parser.add_argument("--batch_size", type=int, default=32, help="The batch_size for evaluation (and calibration).")
     parser.add_argument("--validation_batch", type=int, default=4, help="The number of batch to use for calibration.")
+    parser.add_argument(
+        "--load_dtype",
+        type=str,
+        default="float16",
+        choices=["float16", "float32", "bfloat16"],
+        help="Precision to load the initial model",
+    )
     parser.add_argument(
         "--weights",
         type=str,
@@ -96,7 +103,12 @@ def main():
     else:
         device = torch.device(args.device)
 
-    model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
+    torch_dtype = (
+        torch.float16
+        if args.load_dtype == "float16"
+        else torch.bfloat16 if args.load_dtype == "bfloat16" else torch.float32
+    )
+    model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=torch_dtype, low_cpu_mem_usage=True).to(
         device
     )
     tokenizer = AutoTokenizer.from_pretrained(args.model)
diff --git a/quanto/tensor/ops.py b/quanto/tensor/ops.py
index 368bd7a2..e8362253 100644
--- a/quanto/tensor/ops.py
+++ b/quanto/tensor/ops.py
@@ -3,6 +3,7 @@
 from typing import Callable, List
 
 import torch
+from packaging import version
 
 from .core import dtype_info
 from .qtensor import QTensor, qfallback
@@ -181,7 +182,13 @@ def mm(op, input, other):
     n, m = input.shape
     p = other.shape[-1]
     if (
-        input.device.type == "cuda"
+        (
+            input.device.type == "cuda"
+            or (
+                input.device.type == "cpu"
+                and version.parse(torch.__version__).release >= version.parse("2.4.0").release
+            )
+        )
         and input.qtype == qint8
         and other.qtype == qint8
         and n > 16
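
For reference, a minimal standalone sketch (not part of the diff) of the dispatch this patch enables: torch._int_mm is selected on CUDA, on CPU only from torch 2.4 onwards, and torch.matmul is the fallback otherwise. The helper mirrors get_int_matmul from the bench script; the demo shapes are arbitrary.

    import torch
    from packaging import version

    def get_int_matmul(device: torch.device):
        # torch._int_mm computes int8 x int8 -> int32; it is available on CUDA,
        # and on CPU starting with torch 2.4, hence the version gate.
        if device.type == "cuda" or (
            device.type == "cpu"
            and version.parse(torch.__version__).release >= version.parse("2.4.0").release
        ):
            return torch._int_mm
        return torch.matmul

    # Demo with arbitrary int8 operands; only call _int_mm when it was selected,
    # since torch.matmul does not accept int8 inputs on CPU.
    a = torch.randint(-128, 127, (32, 64), dtype=torch.int8)
    b = torch.randint(-128, 127, (64, 32), dtype=torch.int8)
    mm = get_int_matmul(torch.device("cpu"))
    if mm is torch._int_mm:  # True on CUDA, or on CPU with torch >= 2.4
        print(mm(a, b).dtype)  # torch.int32

Incidentally, since the --load_dtype choices match torch attribute names, the nested conditional in quantize_causal_lm_model.py could equivalently be written as torch_dtype = getattr(torch, args.load_dtype).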