Skip to content

Commit

Permalink
Fused HQQ Quantization Gemm (pytorch#153)
Browse files Browse the repository at this point in the history
* add test / benchmark

* add kernels

* update readme

* more readme edits

* edit readme

* add transpose test

* transpose test pass

* refactor test

* add checks for CI

* add more comments for transpose kernel

* remove import in test

* clean up benchmark

* fix test import order

* minor README edits

* additional readme edits

* update readme

* update readme

* add note about cudamode

---------

Co-authored-by: Mark Saroufim <marksaroufim@meta.com>
  • Loading branch information
jeromeku and msaroufim authored Apr 25, 2024
1 parent a512fa3 commit af62e4c
Show file tree
Hide file tree
Showing 7 changed files with 1,013 additions and 0 deletions.
147 changes: 147 additions & 0 deletions benchmarks/benchmark_hqq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@

try:
import triton
import hqq
if int(triton.__version__.split(".")[0]) < 3:
raise "triton >= 3.0.0 is required to run this test"
except ImportError:
raise "triton and hqq required to run this benchmark"

import torch
from io import StringIO

import pandas as pd
from hqq.core.quantize import HQQLinear, BaseQuantizeConfig
from torchao.prototype.hqq.hqq_tinygemm_linear import HQQLinearTorchWeightOnlyInt4
from torchao.prototype.hqq import triton_mixed_mm, pack_2xint4

from triton.testing import do_bench


BASE_QUANT_CONFIG = {
"optimize": True,
"view_as_float": False,
"nbits": 4,
"bitpack": False,
"axis": 1,
}


def bench_custom_kernel(x, W_q, scales, zeros, group_size, kernel_type="max_autotune", fp8_fast_accum=False):
packed_w = pack_2xint4(W_q.T)

def fn():
_ = triton_mixed_mm(
x,
packed_w,
scales.T,
zeros.T,
group_size=group_size,
fp8_fast_accum=fp8_fast_accum,
kernel_type=kernel_type,
)

t = do_bench(fn)
return t


def bench_hqq(x, hqq_linear: HQQLinear):
def fn():
_ = hqq_linear.forward(x)

t = do_bench(fn)
return t


def run_benchmark(shape, group_size, dtype, axis=1, quant_dtype=torch.uint8):
qcfg = {
**BASE_QUANT_CONFIG,
**dict(group_size=group_size, axis=axis),
}
M, N, K = shape

x = torch.randn(M, K, dtype=dtype, device="cuda")
linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device="cuda")

quant_config = BaseQuantizeConfig(
quant_zero=False, quant_scale=False, offload_meta=False, view_as_float=False
)
quant_config.update({"weight_quant_params": qcfg})

hqq_linear = HQQLinear(linear, quant_config, compute_dtype=dtype, del_orig=False)

# Reference
ref_time = bench_hqq(x, hqq_linear)

# Custom kernel
W_q, meta = hqq_linear.W_q, hqq_linear.meta
scales, zeros = meta["scale"], meta["zero"]

W_q = (
W_q.reshape(meta["shape"])
if quant_config["weight_quant_params"]["bitpack"] == False
else W_q
)
W_q = W_q.to(dtype=quant_dtype)
scales = scales.reshape(N, -1)
zeros = zeros.reshape(N, -1)
tt_time = bench_custom_kernel(x, W_q, scales, zeros, group_size)

if dtype == torch.bfloat16:
_ = quant_config["weight_quant_params"].pop("bitpack")
hqq_int4mm = HQQLinearTorchWeightOnlyInt4(
linear, quant_config, compute_dtype=dtype, del_orig=False
)
int4_time = bench_hqq(x, hqq_int4mm)

print(f"{shape=} {group_size=} {dtype=}:")

print(
f"Ref: {ref_time:.4f}",
f"Triton: {tt_time:.4f}",
f"Torch int4mm: {int4_time:.4f}"
if dtype == torch.bfloat16
else "",
)
print()
return ref_time, tt_time, int4_time if dtype == torch.bfloat16 else None


SHAPES = [
[16, 4096, 4096],
[32, 4096, 4096],
[128, 4096, 4096],
[256, 4096, 4096],
[512, 4096, 4096],
[1024, 4096, 4096],
]

DTYPES = [torch.bfloat16] # , torch.float16]
GROUP_SIZES = [128]


HEADERS = [
"M",
"N",
"K",
"group_size",
"dtype",
"ref",
"triton",
"tinygemm",
]
data = []

if __name__ == "__main__":
print(torch.cuda.get_device_properties(0))

for shape in SHAPES:
for group_size in GROUP_SIZES:
for dtype in DTYPES:
timings = run_benchmark(shape, group_size, dtype)
data.append((*shape, group_size, dtype, *timings))

output = StringIO()
df = pd.DataFrame(data, columns=HEADERS)
df.to_csv(output, index=False)
print(output.getvalue())
104 changes: 104 additions & 0 deletions test/hqq/test_triton_mm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# Skip entire test if triton is not available, otherwise CI failure
import pytest
try:
import triton
import hqq
if int(triton.__version__.split(".")[0]) < 3:
pytest.skip("triton >= 3.0.0 is required to run this test", allow_module_level=True)
except ImportError:
pytest.skip("triton and hqq required to run this test", allow_module_level=True)

import itertools
import torch

from hqq.core.quantize import HQQLinear, BaseQuantizeConfig
from torchao.prototype.hqq import triton_mixed_mm, pack_2xint4


#Test configs
SHAPES = [
[16, 128, 128],
[16, 4096, 4096],
]

DTYPES = [torch.bfloat16, torch.float16]
GROUP_SIZES = [64, 128]
AXES = [1] #Only axis = 1 supported
TRANSPOSED = [True]
TRITON_KERNEL_TYPE = ["compute_bound"] #["max_autotune", "compute_bound"]

TEST_CONFIGS = list(itertools.product(SHAPES, GROUP_SIZES, AXES, DTYPES, TRANSPOSED, TRITON_KERNEL_TYPE))

BASE_QUANT_CONFIG = {
"optimize": True,
"view_as_float": False,
"nbits": 4,
"bitpack": False,
"axis": 1,
}


def check(expected, actual, msg="", max_diff=1e-3, verbose=False):
passed = torch.allclose(expected, actual, atol=max_diff, rtol=max_diff)
if verbose:
max_err = (expected - actual).abs().max()
if not passed:
print(f"{msg}: Failed! Max error: {max_err}")
else:
print(f"{msg}: Passed! Max error: {max_err}")

return passed

def _arg_to_id(arg):
if isinstance(arg, list):
return "x".join([str(x) for x in arg])
return str(arg)

@pytest.mark.parametrize("shape, group_size, axis, dtype, transposed, kernel_type", TEST_CONFIGS, ids=_arg_to_id)
def test_mixed_mm(shape, group_size, axis, dtype, transposed, kernel_type, quant_dtype=torch.uint8):
qcfg = {
**BASE_QUANT_CONFIG,
**dict(group_size=group_size, axis=axis),
}
M, N, K = shape

linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device="cuda")

quant_config = BaseQuantizeConfig(
quant_zero=False, quant_scale=False, offload_meta=False, view_as_float=False
)
quant_config.update({"weight_quant_params": qcfg})
hqq_linear = HQQLinear(linear, quant_config, compute_dtype=dtype, del_orig=False)
W_q, meta = hqq_linear.W_q, hqq_linear.meta
W_q = W_q.to(dtype=quant_dtype)
W_q = (
W_q.reshape(meta["shape"])
if quant_config["weight_quant_params"]["bitpack"] == False
else W_q
)
W_dq = hqq_linear.dequantize()

scales, zeros = meta["scale"], meta["zero"]
scales = scales.reshape(N, -1)
zeros = zeros.reshape(N, -1)

if transposed:
x = torch.randn(M, N, dtype=dtype, device="cuda")
hqq_out = x @ W_dq

#Pack uint8 W_q, then run fused dequant matmul
packed_w = pack_2xint4(W_q)
tt_out = triton_mixed_mm(
x, packed_w, scales, zeros, transposed=True, group_size=group_size, fp8_fast_accum=False, kernel_type=kernel_type
)
else:
x = torch.randn(M, K, dtype=dtype, device="cuda")
hqq_out = x @ W_dq.T

packed_w = pack_2xint4(W_q.T)
tt_out = triton_mixed_mm(
x, packed_w, scales.T, zeros.T, transposed=False, group_size=group_size, fp8_fast_accum=False, kernel_type=kernel_type
)

assert check(hqq_out, tt_out, max_diff=1e-2 if dtype == torch.bfloat16 else 1e-3)

55 changes: 55 additions & 0 deletions torchao/prototype/hqq/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
## Fused `int4 / fp16` Quant Matmul

Fused kernel that combines asymmetric dequantization and gemm. Useful primarily for compute-bound (M > 16) scenarios and not for memory-bound / inference scenarios.

The kernel fuses two ops:

- Dequantization: upcasts `u4 / s4` weights to `float16 / bfloat16`, followed by groupwise scaling and shifting by scales / zeropoints
- GEMM: matmul on dequantized weights and activations.

Tested and benchmarked for `HQQ` but could theoretically be used for any asymmetric quantization scheme.

> **NOTE**: Benchmark below is only indicative of performance on consumer-grade `Ampere` GPUs (`A6000` specifically). When tested on `H100`, the performance is on par / marginally worse than native / compiled `torch`.
> The intended use is thus for fine-tuning / training models on non-datacenter GPUs (`80 <= compute capability < 90`). If interested in optimizing the kernel for other architectures, please drop a note in the CUDA-MODE Discord channel.
### Implementation Details

- Bitpacking is simple row interleave, no need for extensive preprocessing (e.g., `tinygemm` or `fastertransformer`)
- Tested for `float16 / bfloat16` activations, scales, and zeros
- Autotuned for both compute-bound and memory-bound configs
- Assumes operand B of the `gemm` is is the quantized type.
- Requires quantization along `in-features`, i.e., the `K` dimension, or `axis=1`, of `torch.linear.weight`.
- Implementation handles both transposed and non-tranposed quantized weights, useful for forward / backward training passes.

### Performance

Initial benchmarking (on `A6000`) demonstrates promising results, scaling well for compute-bound workloads:

| | M | N | K | group_size | dtype | hqq_ref | triton | tinygemm |
| --- | ---- | ---- | ---- | ---------- | -------------- | ------- | ------ | -------- |
| 0 | 16 | 4096 | 4096 | 128 | torch.bfloat16 | 0.2675 | 0.0633 | 0.0382 |
| 1 | 32 | 4096 | 4096 | 128 | torch.bfloat16 | 0.2669 | 0.0704 | 0.0649 |
| 2 | 128 | 4096 | 4096 | 128 | torch.bfloat16 | 0.2689 | 0.0960 | 0.2523 |
| 3 | 256 | 4096 | 4096 | 128 | torch.bfloat16 | 0.3268 | 0.1355 | 0.5192 |
| 4 | 512 | 4096 | 4096 | 128 | torch.bfloat16 | 0.3628 | 0.2369 | 1.0892 |
| 5 | 1024 | 4096 | 4096 | 128 | torch.bfloat16 | 0.5133 | 0.4753 | 2.2016 |

- Times are in `ms`, see `benchmarks/benchmark_hqq.py`.
- `hqq_ref` is the base `HQQ_Linear` [module](https://github.com/mobiusml/hqq/blob/6d50eee4bcdd99cc10716f1297c5b2803d2b6da4/hqq/core/quantize.py#L349) that is unfused (dequantization followed by call to torch.matmul).
- `tinygemm` calls `torch.ops.aten._weight_int4pack_mm`. Implementation is a custom HQQLinear layer that wraps the preprocessing necessary for this kernel, adapted from a benchmark script posted by @mobicham from `CUDA-mode` Discord discussions.

GPU details:

```
_CudaDeviceProperties(name='NVIDIA RTX A6000', major=8, minor=6, total_memory=48676MB, multi_processor_count=84)
```

### NOTE

This implementation requires **`triton >= 3.0.0`**.

- Running tests / benchmarks requires installation of `hqq`:

```
pip install hqq
```
1 change: 1 addition & 0 deletions torchao/prototype/hqq/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .mixed_mm import triton_mixed_mm, pack_2xint4
Loading

0 comments on commit af62e4c

Please sign in to comment.