
Commit

Merge branch 'main' into aqt_refactor
jainapurva committed Nov 14, 2024
2 parents 9c531e9 + 5effac5 commit 0315a14
Showing 27 changed files with 2,592 additions and 63 deletions.
64 changes: 64 additions & 0 deletions benchmarks/benchmark_marlin_qqq.py
@@ -0,0 +1,64 @@
import torch
import pandas as pd
from torchao.utils import benchmark_torch_function_in_microseconds
from torchao.ops import marlin_qqq_gemm
from torchao.quantization.marlin_qqq import marlin_qqq_workspace, pack_to_marlin_qqq
from tqdm import tqdm


def get_problem(m, n, k, groupsize=-1):
    if groupsize == -1:
        groupsize = k
    dev = torch.device("cuda")
    A_ref = torch.randn((m, k), dtype=torch.half, device=dev)
    B_ref = torch.randn((k, n), dtype=torch.half, device=dev)

    A = torch.randint(-128, 127, (m, k), dtype=torch.int8, device=dev)
    B = torch.randint(low=-(2**31), high=2**31, size=(k, n), device=dev)
    s_tok = torch.ones((m, 1), dtype=torch.float, device=dev)
    if groupsize == k:
        s_group = torch.tensor([], dtype=torch.half, device=dev)
    else:
        s_group = torch.ones((k // groupsize, n), dtype=torch.half, device=dev)
    s_channel = torch.ones((1, n), dtype=torch.float, device=dev)
    B, s_group, s_channel = pack_to_marlin_qqq(
        B, s_group, s_channel, num_bits=4, group_size=groupsize
    )
    qqq_workspace = marlin_qqq_workspace(n)
    return A, B, A_ref, B_ref, s_tok, s_channel, s_group, qqq_workspace


def benchmark(m: int, k: int, n: int, group_size: int):
    A, B, A_ref, B_ref, s_tok, s_channel, s_group, qqq_workspace = get_problem(
        m, n, k, group_size
    )

    fp16_time = benchmark_torch_function_in_microseconds(torch.matmul, A_ref, B_ref)
    marlin_qqq_w4a8_time = benchmark_torch_function_in_microseconds(
        marlin_qqq_gemm, A, B, s_tok, s_channel, s_group, qqq_workspace, m, n, k
    )

    return {
        "m": m,
        "k": k,
        "n": n,
        "group_size": group_size,
        "fp16_latency (ms)": fp16_time,
        "marlin_qqq_w4a8_latency (ms)": marlin_qqq_w4a8_time,
        "speedup (d/s)": fp16_time / marlin_qqq_w4a8_time,
    }


if __name__ == "__main__":
    k_vals = (8192, 8192, 8192, 28672)
    n_vals = (8192, 10240, 57344, 8192)

    results = []
    for group_size in tqdm([-1, 128]):
        for m in tqdm([1 << i for i in range(10)]):
            for n, k in zip(n_vals, k_vals):
                results.append(benchmark(m, k, n, group_size))

    df = pd.DataFrame(results)
    df.to_csv("marlin_qqq_w4a8_llm_benchmark_results.csv", index=False)
    print(df.to_markdown(index=False))
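For a quick smoke test of the script above, the benchmark function can be called directly for a single shape instead of running the full sweep. This is a minimal sketch, not part of the commit; it assumes a CUDA GPU, a torchao build with the Marlin QQQ CUDA kernels, and execution from the repository root so the benchmarks module resolves.

# Minimal sketch (assumptions noted above): benchmark a single GEMM shape.
from benchmarks.benchmark_marlin_qqq import benchmark

row = benchmark(m=16, k=8192, n=8192, group_size=128)
print(f"m=16, k=8192, n=8192, group_size=128 -> speedup {row['speedup (d/s)']:.2f}x")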
6 changes: 5 additions & 1 deletion scripts/convert_hf_checkpoint.py
@@ -85,6 +85,10 @@ def permute(w, n_head):
        else:
            state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
        merged_result.update(state_dict)

    if config.tie_word_embeddings:
        merged_result["lm_head.weight"] = merged_result["model.embed_tokens.weight"].clone()

    final_result = {}
    for key, value in merged_result.items():
        if "layers" in key:
@@ -112,7 +116,7 @@ def permute(w, n_head):
            del final_result[key.replace("wq", "wv")]
    print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
    torch.save(final_result, checkpoint_dir / "model.pth")
-    if 'llama-3-' in model_name.lower() or 'llama-3.1-' in model_name.lower():
+    if any([x in model_name.lower() for x in ["llama-3-", "llama-3.1-", "llama-3.2-"]]):
        if 'llama-3.1-405b' in model_name.lower():
            original_dir = checkpoint_dir / "original" / "mp16"
        else:
2 changes: 2 additions & 0 deletions scripts/prepare.sh
@@ -1,6 +1,8 @@
python scripts/download.py --repo_id meta-llama/Llama-2-7b-chat-hf
python scripts/download.py --repo_id meta-llama/Meta-Llama-3-8B
python scripts/download.py --repo_id meta-llama/Meta-Llama-3.1-8B
python scripts/download.py --repo_id meta-llama/Llama-3.2-3B
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Llama-2-7b-chat-hf
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3-8B
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3.1-8B
python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Llama-3.2-3B
3 changes: 1 addition & 2 deletions test/dtypes/test_floatx.py
@@ -10,12 +10,11 @@
    run_tests,
)
from torchao.dtypes.floatx import (
-    FloatxTensorCoreAQTTensorImpl,
    FloatxTensorCoreLayout,
    to_scaled_tc_floatx,
    from_scaled_tc_floatx,
)
-from torchao.dtypes.floatx.floatx_tensor_core_layout import _pack_tc_floatx, _pack_tc_fp6
+from torchao.dtypes.floatx.floatx_tensor_core_layout import _pack_tc_floatx, _pack_tc_fp6, FloatxTensorCoreAQTTensorImpl
from torchao.prototype.custom_fp_utils import _f32_to_floatx_unpacked, _floatx_unpacked_to_f32
from torchao.quantization import (
    quantize_,
2 changes: 1 addition & 1 deletion test/dtypes/test_uintx.py
@@ -4,7 +4,7 @@

import torch

-from torchao.dtypes.uintx import to_uintx
+from torchao.dtypes.uintx.uintx_layout import to_uintx
from torchao.quantization.quant_api import quantize_, uintx_weight_only
from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_3,
2 changes: 1 addition & 1 deletion test/float8/test_base.py
@@ -632,7 +632,7 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum):
        with pytest.raises(
            RuntimeError,
            match=re.escape(
-                "Expected trailing dimension of mat1 to be divisible by 16 but got mat1 shape: (16x41)."
+                "Expected trailing dimension of mat1 to be divisible by 16 but got mat1 shape: (16x41"
            ),
        ):
            a_fp8 @ b_fp8
129 changes: 129 additions & 0 deletions test/quantization/test_marlin_qqq.py
@@ -0,0 +1,129 @@
import copy

import pytest
import torch
from torch import nn
from torch.testing._internal.common_utils import TestCase, run_tests

from torchao.dtypes import MarlinQQQLayout
from torchao.quantization.marlin_qqq import (
    pack_to_marlin_qqq,
    unpack_from_marlin_qqq,
)
from torchao.quantization.quant_api import (
    int8_dynamic_activation_int4_weight,
    quantize_,
)
from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_and_quantize_affine_qqq,
)
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5


class MarlinQQQ(TestCase):
    def setUp(self):
        super().setUp()
        torch.manual_seed(0)

        self.input = torch.randn((64, 32, 8192), dtype=torch.float16, device="cuda")
        self.model = (
            nn.Sequential(
                nn.Linear(8192, 21504),
                nn.Linear(21504, 8192),
                nn.ReLU(),
                nn.Linear(8192, 21504),
                nn.Linear(21504, 8192),
            )
            .half()
            .cuda()
        )

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
    def test_marlin_qqq(self):
        output_ref = self.model(self.input)
        for group_size in [-1, 128]:
            modelq = copy.deepcopy(self.model)
            quantize_(
                modelq,
                int8_dynamic_activation_int4_weight(
                    group_size=group_size,
                    mapping_type=MappingType.SYMMETRIC,
                    act_mapping_type=MappingType.SYMMETRIC,
                    layout=MarlinQQQLayout(),
                ),
            )
            output = modelq(self.input)

            assert torch.allclose(
                output, output_ref, atol=1e-1
            ), "Results are not close"

    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5+")
    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
    def test_marlin_qqq_compile(self):
        model_copy = copy.deepcopy(self.model)
        model_copy.forward = torch.compile(model_copy.forward, fullgraph=True)
        output_ref = model_copy(self.input)

        for group_size in [-1, 128]:
            modelq = copy.deepcopy(self.model)
            quantize_(
                modelq,
                int8_dynamic_activation_int4_weight(
                    group_size=group_size,
                    mapping_type=MappingType.SYMMETRIC,
                    act_mapping_type=MappingType.SYMMETRIC,
                    layout=MarlinQQQLayout(),
                ),
            )
            modelq.forward = torch.compile(modelq.forward, fullgraph=True)
            output = modelq(self.input)

            assert torch.allclose(
                output, output_ref, atol=1e-1
            ), "Results are not close"

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
    def test_pack_unpack_equivalence(self):
        num_bits = 4
        shape = (11008, 4096)

        w = torch.rand(shape, dtype=torch.float16, device="cuda")

        for group_size in [-1, 128]:
            # Quantize weights
            q_w, s_group, s_channel, _ = choose_qparams_and_quantize_affine_qqq(
                w, num_bits, group_size
            )

            q_w = q_w.t()
            s_group = s_group.t()
            s_channel = s_channel.t()

            # Test pack/unpack equivalence
            q_w_comp, packed_s_group, packed_s_channel = pack_to_marlin_qqq(
                q_w, s_group, s_channel, num_bits, group_size
            )
            unpacked_q_w, unpacked_s_group, unpacked_s_channel = unpack_from_marlin_qqq(
                q_w_comp,
                packed_s_group,
                packed_s_channel,
                q_w.shape,
                num_bits,
                group_size,
            )

            assert torch.equal(
                q_w, unpacked_q_w
            ), "Unpacked weights do not match original weights"
            assert torch.equal(
                s_channel, unpacked_s_channel
            ), "Unpacked s_channel does not match original s_channel"
            assert torch.equal(
                s_group, unpacked_s_group
            ), "Unpacked s_group does not match original s_group"


if __name__ == "__main__":
    run_tests()
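Outside the test harness, the same primitives can be exercised as a standalone round trip. The following is a minimal sketch, not part of the committed file; it assumes a CUDA device and mirrors the argument order used by test_pack_unpack_equivalence above.

# Standalone QQQ quantize -> pack -> unpack round trip (sketch; assumes CUDA
# and the argument conventions exercised by the test above).
import torch
from torchao.quantization.marlin_qqq import pack_to_marlin_qqq, unpack_from_marlin_qqq
from torchao.quantization.quant_primitives import choose_qparams_and_quantize_affine_qqq

w = torch.rand((11008, 4096), dtype=torch.float16, device="cuda")
q_w, s_group, s_channel, _ = choose_qparams_and_quantize_affine_qqq(w, 4, 128)
q_w, s_group, s_channel = q_w.t(), s_group.t(), s_channel.t()

packed_w, packed_s_group, packed_s_channel = pack_to_marlin_qqq(q_w, s_group, s_channel, 4, 128)
unpacked_w, _, _ = unpack_from_marlin_qqq(
    packed_w, packed_s_group, packed_s_channel, q_w.shape, 4, 128
)
assert torch.equal(q_w, unpacked_w), "Round trip should preserve the quantized weights"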
109 changes: 109 additions & 0 deletions test/test_ops.py
@@ -13,6 +13,11 @@
from torchao.utils import is_fbcode, TORCH_VERSION_AT_LEAST_2_5, compute_max_diff
from torchao.dtypes.floatx import from_scaled_tc_floatx
from torchao.sparsity.marlin import marlin_24_workspace, pack_to_marlin_24, inject_24
from torchao.quantization.marlin_qqq import (
    marlin_qqq_workspace,
    pack_to_marlin_qqq,
)
from torchao.quantization.quant_primitives import choose_qparams_and_quantize_affine_qqq
import pytest

if is_fbcode():
@@ -426,5 +431,109 @@ def test_marlin_24(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors):
    )


MARLIN_QQQ_BATCH_SIZE = [1, 4, 8, 16, 32, 64]
MARLIN_QQQ_K_CHUNKS = [128]
MARLIN_QQQ_N_CHUNKS = [64, 128, 256]
MNK_FACTORS = [
    (1, 1, 1),
    (1, 4, 8),
    (1, 7, 5),
    (13, 17, 67),
    (26, 37, 13),
    (67, 13, 11),
]
MARLIN_QQQ_SUPPORTED_NUM_BITS = [4]
MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128]

MARLIN_TEST_PARAMS = list(
    itertools.product(
        MARLIN_QQQ_BATCH_SIZE,
        MARLIN_QQQ_K_CHUNKS,
        MARLIN_QQQ_N_CHUNKS,
        MARLIN_QQQ_SUPPORTED_NUM_BITS,
        MARLIN_QQQ_SUPPORTED_GROUP_SIZES,
        MNK_FACTORS,
    )
)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize(
    "batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors",
    MARLIN_TEST_PARAMS,
    ids=str,
)
def test_marlin_qqq(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors):
    int8_traits = torch.iinfo(torch.int8)
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    a_input = torch.randn(
        (batch_size, size_m, size_k), dtype=torch.float16, device="cuda"
    )
    b_weight = torch.rand((size_n, size_k), dtype=torch.float16, device="cuda")

    # Reshape input into 2D tensor
    input_2d = a_input.view(-1, a_input.shape[-1])
    a_input_in, a_input_out = input_2d.shape

    # Quantize activations
    s_a = (
        input_2d.abs()
        .max(dim=-1, keepdim=True)[0]
        .div(int8_traits.max)
        .to(torch.float32)
    )
    q_a = (
        (input_2d / s_a).round().clamp(int8_traits.min, int8_traits.max).to(torch.int8)
    )

    # Quantize weights
    q_w, s_group, s_channel, w_ref = choose_qparams_and_quantize_affine_qqq(
        b_weight, num_bits, group_size
    )
    q_w = q_w.t()
    s_group = s_group.t()
    s_channel = s_channel.t()
    w_ref = w_ref.t()
    marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel = pack_to_marlin_qqq(
        q_w, s_group, s_channel, num_bits, group_size
    )

    workspace = marlin_qqq_workspace(size_n)

    # Obtains reference output
    output_ref = torch.matmul(q_a.half() * s_a.half(), w_ref)
    output_ref = output_ref.reshape(a_input.shape[:-1] + (size_n,))

    fn_inputs = (
        q_a,
        marlin_qqq_q_w,
        s_a,
        marlin_qqq_s_channel,
        marlin_qqq_s_group,
        workspace,
        a_input_in,
        size_n,
        a_input_out,
    )
    output = torchao.ops.marlin_qqq_gemm(*fn_inputs)
    output = output.reshape(a_input.shape[:-1] + (size_n,))

    max_diff = compute_max_diff(output, output_ref)
    assert max_diff < 0.04

    # Performs opcheck
    test_utils = ["test_schema", "test_autograd_registration", "test_faketensor"]
    opcheck(
        torch.ops.torchao.marlin_qqq_gemm,
        fn_inputs,
        test_utils=test_utils,
    )


if __name__ == "__main__":
    run_tests()
