
Commit

feat: enable cpu integer matmul
maktukmak authored and dacorvo committed Mar 24, 2024
1 parent 96871c1 commit 9192ef8
Showing 3 changed files with 22 additions and 3 deletions.
bench/kernels/test_int_mm.py (2 changes: 1 addition, 1 deletion)
@@ -24,7 +24,7 @@ def main():
     device = torch.device(args.device)
 
     def get_int_matmul(device):
-        if device.type == ("cuda"):
+        if device.type == ("cuda") or device.type == ("cpu"):
             return torch._int_mm
         return torch.matmul
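For reference, a minimal sketch (not part of the commit) of the kernel the benchmark now reaches on CPU: torch._int_mm multiplies two int8 matrices and accumulates into an int32 result. Shape constraints differ between backends and PyTorch versions, so the sizes below are only illustrative.

import torch

# int8 operands; torch._int_mm returns an int32 result tensor.
a = torch.randint(-128, 128, (32, 64), dtype=torch.int8)
b = torch.randint(-128, 128, (64, 32), dtype=torch.int8)
out = torch._int_mm(a, b)
assert out.dtype == torch.int32 and out.shape == (32, 32)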
examples/nlp/text-generation/quantize_causal_lm_model.py (14 changes: 13 additions, 1 deletion)
@@ -61,6 +61,13 @@ def main():
     parser.add_argument("--max_new_tokens", type=int, default=20, help="The maximum number of tokens to generate.")
     parser.add_argument("--batch_size", type=int, default=32, help="The batch_size for evaluation (and calibration).")
     parser.add_argument("--validation_batch", type=int, default=4, help="The number of batch to use for calibration.")
+    parser.add_argument(
+        "--load_dtype",
+        type=str,
+        default="float16",
+        choices=["float16", "float32", "bfloat16"],
+        help="Precision to load the initial model",
+    )
     parser.add_argument(
         "--weights",
         type=str,
@@ -96,7 +103,12 @@ def main():
     else:
         device = torch.device(args.device)
 
-    model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
+    torch_dtype = (
+        torch.float16
+        if args.load_dtype == "float16"
+        else torch.bfloat16 if args.load_dtype == "bfloat16" else torch.float32
+    )
+    model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=torch_dtype, low_cpu_mem_usage=True).to(
         device
     )
     tokenizer = AutoTokenizer.from_pretrained(args.model)
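Since the --load_dtype choices match torch attribute names exactly, the nested conditional above could also be written as a single attribute lookup; a sketch (not part of the commit, resolve_dtype is a hypothetical helper name):

import torch

def resolve_dtype(name: str) -> torch.dtype:
    # "float16" -> torch.float16, "bfloat16" -> torch.bfloat16, etc.
    return getattr(torch, name)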
quanto/tensor/ops.py (9 changes: 8 additions, 1 deletion)
@@ -3,6 +3,7 @@
 from typing import Callable, List
 
 import torch
+from packaging import version
 
 from .core import dtype_info
 from .qtensor import QTensor, qfallback
@@ -181,7 +182,13 @@ def mm(op, input, other):
     n, m = input.shape
     p = other.shape[-1]
     if (
-        input.device.type == "cuda"
+        (
+            input.device.type == "cuda"
+            or (
+                input.device.type == "cpu"
+                and version.parse(torch.__version__).release >= version.parse("2.4.0").release
+            )
+        )
         and input.qtype == qint8
         and other.qtype == qint8
         and n > 16
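The version gate compares .release tuples, which strips pre-release and local suffixes, so a nightly build such as 2.4.0.dev20240301 also qualifies for the CPU path. A standalone sketch of just the check (the surrounding quanto dispatch is elided, and the helper name is hypothetical):

import torch
from packaging import version

def cpu_int_mm_supported() -> bool:
    # .release is a tuple like (2, 4, 0); comparing release tuples
    # ignores ".devYYYYMMDD" / "rc" suffixes on nightly and RC builds.
    return version.parse(torch.__version__).release >= version.parse("2.4.0").release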
