-
Notifications
You must be signed in to change notification settings - Fork 169
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into aqt_refactor
- Loading branch information
Showing
27 changed files
with
2,592 additions
and
63 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import torch | ||
import pandas as pd | ||
from torchao.utils import benchmark_torch_function_in_microseconds | ||
from torchao.ops import marlin_qqq_gemm | ||
from torchao.quantization.marlin_qqq import marlin_qqq_workspace, pack_to_marlin_qqq | ||
from tqdm import tqdm | ||
|
||
|
||
def get_problem(m, n, k, groupsize=-1):
    """Build random operands for one (m, n, k) Marlin QQQ w4a8 GEMM case.

    Args:
        m: Rows of the activation matrix.
        n: Output features (columns of the weight matrix).
        k: Input features (rows of the weight matrix).
        groupsize: Quantization group size along k; -1 selects channel-wise
            quantization (one group spanning all of k).

    Returns:
        Tuple ``(A, B, A_ref, B_ref, s_tok, s_channel, s_group, qqq_workspace)``:
        quantized operands, fp16 reference operands, per-token / per-channel /
        per-group scale tensors, and the Marlin QQQ workspace buffer.
    """
    if groupsize == -1:
        groupsize = k
    dev = torch.device("cuda")
    # fp16 reference operands used for the baseline torch.matmul timing.
    A_ref = torch.randn((m, k), dtype=torch.half, device=dev)
    B_ref = torch.randn((k, n), dtype=torch.half, device=dev)

    # int8 activations and random int32-range weight values to be packed.
    A = torch.randint(-128, 127, (m, k), dtype=torch.int8, device=dev)
    B = torch.randint(low=-(2**31), high=2**31, size=(k, n), device=dev)
    s_tok = torch.ones((m, 1), dtype=torch.float, device=dev)
    if groupsize == k:
        # Channel-wise quantization: no per-group scales are used.
        s_group = torch.tensor([], dtype=torch.half, device=dev)
    else:
        s_group = torch.ones((k // groupsize, n), dtype=torch.half, device=dev)
    s_channel = torch.ones((1, n), dtype=torch.float, device=dev)
    # Bug fix: the original passed the undefined name `group_size` here
    # (the parameter is `groupsize`), which raised NameError at runtime.
    B, s_group, s_channel = pack_to_marlin_qqq(
        B, s_group, s_channel, num_bits=4, group_size=groupsize
    )
    qqq_workspace = marlin_qqq_workspace(n)
    return A, B, A_ref, B_ref, s_tok, s_channel, s_group, qqq_workspace
|
||
|
||
def benchmark(m: int, k: int, n: int, group_size: int):
    """Time fp16 matmul against the Marlin QQQ w4a8 GEMM for one shape.

    Returns one result row (a dict) holding the problem dimensions, both
    measured latencies, and their ratio.

    NOTE(review): the timing helper reports microseconds while the result
    keys are labelled "(ms)" — confirm the intended unit with the authors.
    """
    a_q, b_q, a_fp16, b_fp16, scale_tok, scale_ch, scale_grp, workspace = get_problem(
        m, n, k, group_size
    )

    # Baseline: dense fp16 matmul on the reference operands.
    baseline = benchmark_torch_function_in_microseconds(torch.matmul, a_fp16, b_fp16)
    # Candidate: quantized Marlin QQQ kernel on the packed operands.
    quantized = benchmark_torch_function_in_microseconds(
        marlin_qqq_gemm, a_q, b_q, scale_tok, scale_ch, scale_grp, workspace, m, n, k
    )

    row = {"m": m, "k": k, "n": n, "group_size": group_size}
    row["fp16_latency (ms)"] = baseline
    row["marlin_qqq_w4a8_latency (ms)"] = quantized
    row["speedup (d/s)"] = baseline / quantized
    return row
|
||
|
||
if __name__ == "__main__":
    # Representative LLM GEMM weight shapes (n = out features, k = in features).
    k_vals = (8192, 8192, 8192, 28672)
    n_vals = (8192, 10240, 57344, 8192)

    rows = []
    # Sweep channel-wise (-1) and 128-group quantization over batch sizes 1..512.
    for group_size in tqdm([-1, 128]):
        for m in tqdm([2**exp for exp in range(10)]):
            for n, k in zip(n_vals, k_vals):
                rows.append(benchmark(m, k, n, group_size))

    df = pd.DataFrame(rows)
    df.to_csv("marlin_qqq_w4a8_llm_benchmark_results.csv", index=False)
    print(df.to_markdown(index=False))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,8 @@ | ||
# Download Llama checkpoints from the Hugging Face Hub.
python scripts/download.py --repo_id meta-llama/Llama-2-7b-chat-hf
python scripts/download.py --repo_id meta-llama/Meta-Llama-3-8B
python scripts/download.py --repo_id meta-llama/Meta-Llama-3.1-8B
python scripts/download.py --repo_id meta-llama/Llama-3.2-3B
# Convert each downloaded Hugging Face checkpoint to the local format.
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Llama-2-7b-chat-hf
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3-8B
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3.1-8B
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Llama-3.2-3B
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
import copy | ||
|
||
import pytest | ||
import torch | ||
from torch import nn | ||
from torch.testing._internal.common_utils import TestCase, run_tests | ||
|
||
from torchao.dtypes import MarlinQQQLayout | ||
from torchao.quantization.marlin_qqq import ( | ||
pack_to_marlin_qqq, | ||
unpack_from_marlin_qqq, | ||
) | ||
from torchao.quantization.quant_api import ( | ||
int8_dynamic_activation_int4_weight, | ||
quantize_, | ||
) | ||
from torchao.quantization.quant_primitives import ( | ||
MappingType, | ||
choose_qparams_and_quantize_affine_qqq, | ||
) | ||
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 | ||
|
||
|
||
class MarlinQQQ(TestCase):
    """Tests for Marlin QQQ (w4a8) quantization.

    Covers end-to-end accuracy of a quantized model, compatibility with
    ``torch.compile(fullgraph=True)``, and round-tripping of weights and
    scales through pack/unpack.
    """

    def setUp(self):
        super().setUp()
        torch.manual_seed(0)

        # Shared fp16 CUDA input and a small stack of linear layers whose
        # dimensions match the Marlin kernel's supported shapes.
        self.input = torch.randn((64, 32, 8192), dtype=torch.float16, device="cuda")
        layers = [
            nn.Linear(8192, 21504),
            nn.Linear(21504, 8192),
            nn.ReLU(),
            nn.Linear(8192, 21504),
            nn.Linear(21504, 8192),
        ]
        self.model = nn.Sequential(*layers).half().cuda()

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
    def test_marlin_qqq(self):
        # Reference output from the unquantized fp16 model.
        expected = self.model(self.input)
        for group_size in [-1, 128]:
            quantized = copy.deepcopy(self.model)
            quantize_(
                quantized,
                int8_dynamic_activation_int4_weight(
                    group_size=group_size,
                    mapping_type=MappingType.SYMMETRIC,
                    act_mapping_type=MappingType.SYMMETRIC,
                    layout=MarlinQQQLayout(),
                ),
            )
            actual = quantized(self.input)

            assert torch.allclose(
                actual, expected, atol=1e-1
            ), "Results are not close"

    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5+")
    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
    def test_marlin_qqq_compile(self):
        # Compiled fp16 reference.
        reference = copy.deepcopy(self.model)
        reference.forward = torch.compile(reference.forward, fullgraph=True)
        expected = reference(self.input)

        for group_size in [-1, 128]:
            quantized = copy.deepcopy(self.model)
            quantize_(
                quantized,
                int8_dynamic_activation_int4_weight(
                    group_size=group_size,
                    mapping_type=MappingType.SYMMETRIC,
                    act_mapping_type=MappingType.SYMMETRIC,
                    layout=MarlinQQQLayout(),
                ),
            )
            # Compile the quantized model as well; fullgraph verifies no
            # graph breaks are introduced by the quantized layout.
            quantized.forward = torch.compile(quantized.forward, fullgraph=True)
            actual = quantized(self.input)

            assert torch.allclose(
                actual, expected, atol=1e-1
            ), "Results are not close"

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
    def test_pack_unpack_equivalence(self):
        num_bits = 4
        shape = (11008, 4096)

        w = torch.rand(shape, dtype=torch.float16, device="cuda")

        for group_size in [-1, 128]:
            # Quantize, then transpose weights and scales into the layout
            # expected by the Marlin packing routine.
            q_w, s_group, s_channel, _ = choose_qparams_and_quantize_affine_qqq(
                w, num_bits, group_size
            )
            q_w, s_group, s_channel = q_w.t(), s_group.t(), s_channel.t()

            # Round-trip through pack/unpack and require exact equality.
            packed_w, packed_group, packed_channel = pack_to_marlin_qqq(
                q_w, s_group, s_channel, num_bits, group_size
            )
            out_w, out_group, out_channel = unpack_from_marlin_qqq(
                packed_w,
                packed_group,
                packed_channel,
                q_w.shape,
                num_bits,
                group_size,
            )

            assert torch.equal(
                q_w, out_w
            ), "Unpacked weights do not match original weights"
            assert torch.equal(
                s_channel, out_channel
            ), "Unpacked s_channel do not match original s_channel"
            assert torch.equal(
                s_group, out_group
            ), "Unpacked s_group do not match original s_group"
|
||
|
||
if __name__ == "__main__":
    # Run the suite through torch's unittest-compatible test runner.
    run_tests()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.