-
Notifications
You must be signed in to change notification settings - Fork 169
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fused HQQ Quantization Gemm #153
Changes from all commits
0d41668
497d8db
8db9f51
2a18357
f11a59f
2e76839
19c43c2
e0f3781
be718d6
793994b
5c585b9
c80b239
48a153f
b3a9ab8
bbe9083
38dbc3e
d89fa74
cd68d38
582fd8b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
|
||
try: | ||
import triton | ||
import hqq | ||
if int(triton.__version__.split(".")[0]) < 3: | ||
raise "triton >= 3.0.0 is required to run this test" | ||
except ImportError: | ||
raise "triton and hqq required to run this benchmark" | ||
|
||
import torch | ||
from io import StringIO | ||
|
||
import pandas as pd | ||
from hqq.core.quantize import HQQLinear, BaseQuantizeConfig | ||
from torchao.prototype.hqq.hqq_tinygemm_linear import HQQLinearTorchWeightOnlyInt4 | ||
from torchao.prototype.hqq import triton_mixed_mm, pack_2xint4 | ||
|
||
from triton.testing import do_bench | ||
|
||
|
||
BASE_QUANT_CONFIG = { | ||
"optimize": True, | ||
"view_as_float": False, | ||
"nbits": 4, | ||
"bitpack": False, | ||
"axis": 1, | ||
} | ||
|
||
|
||
def bench_custom_kernel(x, W_q, scales, zeros, group_size, kernel_type="max_autotune", fp8_fast_accum=False): | ||
packed_w = pack_2xint4(W_q.T) | ||
|
||
def fn(): | ||
_ = triton_mixed_mm( | ||
x, | ||
packed_w, | ||
scales.T, | ||
zeros.T, | ||
group_size=group_size, | ||
fp8_fast_accum=fp8_fast_accum, | ||
kernel_type=kernel_type, | ||
) | ||
|
||
t = do_bench(fn) | ||
return t | ||
|
||
|
||
def bench_hqq(x, hqq_linear: HQQLinear): | ||
def fn(): | ||
_ = hqq_linear.forward(x) | ||
|
||
t = do_bench(fn) | ||
return t | ||
|
||
|
||
def run_benchmark(shape, group_size, dtype, axis=1, quant_dtype=torch.uint8): | ||
qcfg = { | ||
**BASE_QUANT_CONFIG, | ||
**dict(group_size=group_size, axis=axis), | ||
} | ||
M, N, K = shape | ||
|
||
x = torch.randn(M, K, dtype=dtype, device="cuda") | ||
linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device="cuda") | ||
|
||
quant_config = BaseQuantizeConfig( | ||
quant_zero=False, quant_scale=False, offload_meta=False, view_as_float=False | ||
) | ||
quant_config.update({"weight_quant_params": qcfg}) | ||
|
||
hqq_linear = HQQLinear(linear, quant_config, compute_dtype=dtype, del_orig=False) | ||
|
||
# Reference | ||
ref_time = bench_hqq(x, hqq_linear) | ||
|
||
# Custom kernel | ||
W_q, meta = hqq_linear.W_q, hqq_linear.meta | ||
scales, zeros = meta["scale"], meta["zero"] | ||
|
||
W_q = ( | ||
W_q.reshape(meta["shape"]) | ||
if quant_config["weight_quant_params"]["bitpack"] == False | ||
else W_q | ||
) | ||
W_q = W_q.to(dtype=quant_dtype) | ||
scales = scales.reshape(N, -1) | ||
zeros = zeros.reshape(N, -1) | ||
tt_time = bench_custom_kernel(x, W_q, scales, zeros, group_size) | ||
|
||
if dtype == torch.bfloat16: | ||
_ = quant_config["weight_quant_params"].pop("bitpack") | ||
hqq_int4mm = HQQLinearTorchWeightOnlyInt4( | ||
linear, quant_config, compute_dtype=dtype, del_orig=False | ||
) | ||
int4_time = bench_hqq(x, hqq_int4mm) | ||
|
||
print(f"{shape=} {group_size=} {dtype=}:") | ||
|
||
print( | ||
f"Ref: {ref_time:.4f}", | ||
f"Triton: {tt_time:.4f}", | ||
f"Torch int4mm: {int4_time:.4f}" | ||
if dtype == torch.bfloat16 | ||
else "", | ||
) | ||
print() | ||
return ref_time, tt_time, int4_time if dtype == torch.bfloat16 else None | ||
|
||
|
||
SHAPES = [ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @cpuhrsch I guess these shapes are fine for now but are there some specific shapes we're more interested in tracking on an ongoing basis if so I wish we could just make them part of our benchmark or test utilities |
||
[16, 4096, 4096], | ||
[32, 4096, 4096], | ||
[128, 4096, 4096], | ||
[256, 4096, 4096], | ||
[512, 4096, 4096], | ||
[1024, 4096, 4096], | ||
] | ||
|
||
DTYPES = [torch.bfloat16] # , torch.float16] | ||
GROUP_SIZES = [128] | ||
|
||
|
||
HEADERS = [ | ||
"M", | ||
"N", | ||
"K", | ||
"group_size", | ||
"dtype", | ||
"ref", | ||
"triton", | ||
"tinygemm", | ||
] | ||
data = [] | ||
|
||
if __name__ == "__main__": | ||
print(torch.cuda.get_device_properties(0)) | ||
|
||
for shape in SHAPES: | ||
for group_size in GROUP_SIZES: | ||
for dtype in DTYPES: | ||
timings = run_benchmark(shape, group_size, dtype) | ||
data.append((*shape, group_size, dtype, *timings)) | ||
|
||
output = StringIO() | ||
df = pd.DataFrame(data, columns=HEADERS) | ||
df.to_csv(output, index=False) | ||
print(output.getvalue()) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
# Skip entire test if triton is not available, otherwise CI failure | ||
import pytest | ||
try: | ||
import triton | ||
import hqq | ||
if int(triton.__version__.split(".")[0]) < 3: | ||
pytest.skip("triton >= 3.0.0 is required to run this test", allow_module_level=True) | ||
except ImportError: | ||
pytest.skip("triton and hqq required to run this test", allow_module_level=True) | ||
|
||
import itertools | ||
import torch | ||
|
||
from hqq.core.quantize import HQQLinear, BaseQuantizeConfig | ||
from torchao.prototype.hqq import triton_mixed_mm, pack_2xint4 | ||
|
||
|
||
#Test configs | ||
SHAPES = [ | ||
[16, 128, 128], | ||
[16, 4096, 4096], | ||
] | ||
|
||
DTYPES = [torch.bfloat16, torch.float16] | ||
GROUP_SIZES = [64, 128] | ||
AXES = [1] #Only axis = 1 supported | ||
TRANSPOSED = [True] | ||
TRITON_KERNEL_TYPE = ["compute_bound"] #["max_autotune", "compute_bound"] | ||
|
||
TEST_CONFIGS = list(itertools.product(SHAPES, GROUP_SIZES, AXES, DTYPES, TRANSPOSED, TRITON_KERNEL_TYPE)) | ||
|
||
BASE_QUANT_CONFIG = { | ||
"optimize": True, | ||
"view_as_float": False, | ||
"nbits": 4, | ||
"bitpack": False, | ||
"axis": 1, | ||
} | ||
|
||
|
||
def check(expected, actual, msg="", max_diff=1e-3, verbose=False): | ||
passed = torch.allclose(expected, actual, atol=max_diff, rtol=max_diff) | ||
if verbose: | ||
max_err = (expected - actual).abs().max() | ||
if not passed: | ||
print(f"{msg}: Failed! Max error: {max_err}") | ||
else: | ||
print(f"{msg}: Passed! Max error: {max_err}") | ||
|
||
return passed | ||
|
||
def _arg_to_id(arg): | ||
if isinstance(arg, list): | ||
return "x".join([str(x) for x in arg]) | ||
return str(arg) | ||
|
||
@pytest.mark.parametrize("shape, group_size, axis, dtype, transposed, kernel_type", TEST_CONFIGS, ids=_arg_to_id) | ||
def test_mixed_mm(shape, group_size, axis, dtype, transposed, kernel_type, quant_dtype=torch.uint8): | ||
qcfg = { | ||
**BASE_QUANT_CONFIG, | ||
**dict(group_size=group_size, axis=axis), | ||
} | ||
M, N, K = shape | ||
|
||
linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device="cuda") | ||
|
||
quant_config = BaseQuantizeConfig( | ||
quant_zero=False, quant_scale=False, offload_meta=False, view_as_float=False | ||
) | ||
quant_config.update({"weight_quant_params": qcfg}) | ||
hqq_linear = HQQLinear(linear, quant_config, compute_dtype=dtype, del_orig=False) | ||
W_q, meta = hqq_linear.W_q, hqq_linear.meta | ||
W_q = W_q.to(dtype=quant_dtype) | ||
W_q = ( | ||
W_q.reshape(meta["shape"]) | ||
if quant_config["weight_quant_params"]["bitpack"] == False | ||
else W_q | ||
) | ||
W_dq = hqq_linear.dequantize() | ||
|
||
scales, zeros = meta["scale"], meta["zero"] | ||
scales = scales.reshape(N, -1) | ||
zeros = zeros.reshape(N, -1) | ||
|
||
if transposed: | ||
x = torch.randn(M, N, dtype=dtype, device="cuda") | ||
hqq_out = x @ W_dq | ||
|
||
#Pack uint8 W_q, then run fused dequant matmul | ||
packed_w = pack_2xint4(W_q) | ||
tt_out = triton_mixed_mm( | ||
x, packed_w, scales, zeros, transposed=True, group_size=group_size, fp8_fast_accum=False, kernel_type=kernel_type | ||
) | ||
else: | ||
x = torch.randn(M, K, dtype=dtype, device="cuda") | ||
hqq_out = x @ W_dq.T | ||
|
||
packed_w = pack_2xint4(W_q.T) | ||
tt_out = triton_mixed_mm( | ||
x, packed_w, scales.T, zeros.T, transposed=False, group_size=group_size, fp8_fast_accum=False, kernel_type=kernel_type | ||
) | ||
|
||
assert check(hqq_out, tt_out, max_diff=1e-2 if dtype == torch.bfloat16 else 1e-3) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
## Fused `int4 / fp16` Quant Matmul | ||
|
||
Fused kernel that combines asymmetric dequantization and gemm. Useful primarily for compute-bound (M > 16) scenarios and not for memory-bound / inference scenarios. | ||
|
||
The kernel fuses two ops: | ||
|
||
- Dequantization: upcasts `u4 / s4` weights to `float16 / bfloat16`, followed by groupwise scaling and shifting by scales / zeropoints | ||
- GEMM: matmul on dequantized weights and activations. | ||
|
||
Tested and benchmarked for `HQQ` but could theoretically be used for any asymmetric quantization scheme. | ||
|
||
> **NOTE**: Benchmark below is only indicative of performance on consumer-grade `Ampere` GPUs (`A6000` specifically). When tested on `H100`, the performance is on par / marginally worse than native / compiled `torch`. | ||
> The intended use is thus for fine-tuning / training models on non-datacenter GPUs (`80 <= compute capability < 90`). If interested in optimizing the kernel for other architectures, please drop a note in the CUDA-MODE Discord channel. | ||
|
||
### Implementation Details | ||
|
||
- Bitpacking is simple row interleave, no need for extensive preprocessing (e.g., `tinygemm` or `fastertransformer`) | ||
- Tested for `float16 / bfloat16` activations, scales, and zeros | ||
- Autotuned for both compute-bound and memory-bound configs | ||
- Assumes operand B of the `gemm` is is the quantized type. | ||
- Requires quantization along `in-features`, i.e., the `K` dimension, or `axis=1`, of `torch.linear.weight`. | ||
- Implementation handles both transposed and non-tranposed quantized weights, useful for forward / backward training passes. | ||
|
||
### Performance | ||
|
||
Initial benchmarking (on `A6000`) demonstrates promising results, scaling well for compute-bound workloads: | ||
|
||
| | M | N | K | group_size | dtype | hqq_ref | triton | tinygemm | | ||
| --- | ---- | ---- | ---- | ---------- | -------------- | ------- | ------ | -------- | | ||
| 0 | 16 | 4096 | 4096 | 128 | torch.bfloat16 | 0.2675 | 0.0633 | 0.0382 | | ||
| 1 | 32 | 4096 | 4096 | 128 | torch.bfloat16 | 0.2669 | 0.0704 | 0.0649 | | ||
| 2 | 128 | 4096 | 4096 | 128 | torch.bfloat16 | 0.2689 | 0.0960 | 0.2523 | | ||
| 3 | 256 | 4096 | 4096 | 128 | torch.bfloat16 | 0.3268 | 0.1355 | 0.5192 | | ||
| 4 | 512 | 4096 | 4096 | 128 | torch.bfloat16 | 0.3628 | 0.2369 | 1.0892 | | ||
| 5 | 1024 | 4096 | 4096 | 128 | torch.bfloat16 | 0.5133 | 0.4753 | 2.2016 | | ||
|
||
- Times are in `ms`, see `benchmarks/benchmark_hqq.py`. | ||
- `hqq_ref` is the base `HQQ_Linear` [module](https://github.com/mobiusml/hqq/blob/6d50eee4bcdd99cc10716f1297c5b2803d2b6da4/hqq/core/quantize.py#L349) that is unfused (dequantization followed by call to torch.matmul). | ||
- `tinygemm` calls `torch.ops.aten._weight_int4pack_mm`. Implementation is a custom HQQLinear layer that wraps the preprocessing necessary for this kernel, adapted from a benchmark script posted by @mobicham from `CUDA-mode` Discord discussions. | ||
|
||
GPU details: | ||
|
||
``` | ||
_CudaDeviceProperties(name='NVIDIA RTX A6000', major=8, minor=6, total_memory=48676MB, multi_processor_count=84) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. once we figure out the installation issues I'll check to see if results repro on an H100 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. apologies meant There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. When you run on H100, can you run once with |
||
``` | ||
|
||
### NOTE | ||
|
||
This implementation requires **`triton >= 3.0.0`**. | ||
|
||
- Running tests / benchmarks requires installation of `hqq`: | ||
|
||
``` | ||
pip install hqq | ||
``` |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .mixed_mm import triton_mixed_mm, pack_2xint4 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could you print your pip list here? I'm having a lot of trouble finding which triton version you used
I tried the nightlies shown on the openai repo and I also tried make triton from inside pytorch repo and keep getting errors like
To make testing easier assume we'll be using https://github.com/pytorch/pytorch/blob/main/Makefile#L35