diff --git a/benchmarks/benchmark_hqq.py b/benchmarks/benchmark_hqq.py new file mode 100644 index 000000000..393481e95 --- /dev/null +++ b/benchmarks/benchmark_hqq.py @@ -0,0 +1,147 @@ + +try: + import triton + import hqq + if int(triton.__version__.split(".")[0]) < 3: + raise "triton >= 3.0.0 is required to run this test" +except ImportError: + raise "triton and hqq required to run this benchmark" + +import torch +from io import StringIO + +import pandas as pd +from hqq.core.quantize import HQQLinear, BaseQuantizeConfig +from torchao.prototype.hqq.hqq_tinygemm_linear import HQQLinearTorchWeightOnlyInt4 +from torchao.prototype.hqq import triton_mixed_mm, pack_2xint4 + +from triton.testing import do_bench + + +BASE_QUANT_CONFIG = { + "optimize": True, + "view_as_float": False, + "nbits": 4, + "bitpack": False, + "axis": 1, +} + + +def bench_custom_kernel(x, W_q, scales, zeros, group_size, kernel_type="max_autotune", fp8_fast_accum=False): + packed_w = pack_2xint4(W_q.T) + + def fn(): + _ = triton_mixed_mm( + x, + packed_w, + scales.T, + zeros.T, + group_size=group_size, + fp8_fast_accum=fp8_fast_accum, + kernel_type=kernel_type, + ) + + t = do_bench(fn) + return t + + +def bench_hqq(x, hqq_linear: HQQLinear): + def fn(): + _ = hqq_linear.forward(x) + + t = do_bench(fn) + return t + + +def run_benchmark(shape, group_size, dtype, axis=1, quant_dtype=torch.uint8): + qcfg = { + **BASE_QUANT_CONFIG, + **dict(group_size=group_size, axis=axis), + } + M, N, K = shape + + x = torch.randn(M, K, dtype=dtype, device="cuda") + linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device="cuda") + + quant_config = BaseQuantizeConfig( + quant_zero=False, quant_scale=False, offload_meta=False, view_as_float=False + ) + quant_config.update({"weight_quant_params": qcfg}) + + hqq_linear = HQQLinear(linear, quant_config, compute_dtype=dtype, del_orig=False) + + # Reference + ref_time = bench_hqq(x, hqq_linear) + + # Custom kernel + W_q, meta = hqq_linear.W_q, hqq_linear.meta + scales, zeros = meta["scale"], meta["zero"] + + W_q = ( + W_q.reshape(meta["shape"]) + if quant_config["weight_quant_params"]["bitpack"] == False + else W_q + ) + W_q = W_q.to(dtype=quant_dtype) + scales = scales.reshape(N, -1) + zeros = zeros.reshape(N, -1) + tt_time = bench_custom_kernel(x, W_q, scales, zeros, group_size) + + if dtype == torch.bfloat16: + _ = quant_config["weight_quant_params"].pop("bitpack") + hqq_int4mm = HQQLinearTorchWeightOnlyInt4( + linear, quant_config, compute_dtype=dtype, del_orig=False + ) + int4_time = bench_hqq(x, hqq_int4mm) + + print(f"{shape=} {group_size=} {dtype=}:") + + print( + f"Ref: {ref_time:.4f}", + f"Triton: {tt_time:.4f}", + f"Torch int4mm: {int4_time:.4f}" + if dtype == torch.bfloat16 + else "", + ) + print() + return ref_time, tt_time, int4_time if dtype == torch.bfloat16 else None + + +SHAPES = [ + [16, 4096, 4096], + [32, 4096, 4096], + [128, 4096, 4096], + [256, 4096, 4096], + [512, 4096, 4096], + [1024, 4096, 4096], +] + +DTYPES = [torch.bfloat16] # , torch.float16] +GROUP_SIZES = [128] + + +HEADERS = [ + "M", + "N", + "K", + "group_size", + "dtype", + "ref", + "triton", + "tinygemm", +] +data = [] + +if __name__ == "__main__": + print(torch.cuda.get_device_properties(0)) + + for shape in SHAPES: + for group_size in GROUP_SIZES: + for dtype in DTYPES: + timings = run_benchmark(shape, group_size, dtype) + data.append((*shape, group_size, dtype, *timings)) + + output = StringIO() + df = pd.DataFrame(data, columns=HEADERS) + df.to_csv(output, index=False) + print(output.getvalue()) \ No newline at end of file diff --git a/test/hqq/test_triton_mm.py b/test/hqq/test_triton_mm.py new file mode 100644 index 000000000..23f6c60f7 --- /dev/null +++ b/test/hqq/test_triton_mm.py @@ -0,0 +1,104 @@ +# Skip entire test if triton is not available, otherwise CI failure +import pytest +try: + import triton + import hqq + if int(triton.__version__.split(".")[0]) < 3: + pytest.skip("triton >= 3.0.0 is required to run this test", allow_module_level=True) +except ImportError: + pytest.skip("triton and hqq required to run this test", allow_module_level=True) + +import itertools +import torch + +from hqq.core.quantize import HQQLinear, BaseQuantizeConfig +from torchao.prototype.hqq import triton_mixed_mm, pack_2xint4 + + +#Test configs +SHAPES = [ + [16, 128, 128], + [16, 4096, 4096], +] + +DTYPES = [torch.bfloat16, torch.float16] +GROUP_SIZES = [64, 128] +AXES = [1] #Only axis = 1 supported +TRANSPOSED = [True] +TRITON_KERNEL_TYPE = ["compute_bound"] #["max_autotune", "compute_bound"] + +TEST_CONFIGS = list(itertools.product(SHAPES, GROUP_SIZES, AXES, DTYPES, TRANSPOSED, TRITON_KERNEL_TYPE)) + +BASE_QUANT_CONFIG = { + "optimize": True, + "view_as_float": False, + "nbits": 4, + "bitpack": False, + "axis": 1, +} + + +def check(expected, actual, msg="", max_diff=1e-3, verbose=False): + passed = torch.allclose(expected, actual, atol=max_diff, rtol=max_diff) + if verbose: + max_err = (expected - actual).abs().max() + if not passed: + print(f"{msg}: Failed! Max error: {max_err}") + else: + print(f"{msg}: Passed! Max error: {max_err}") + + return passed + +def _arg_to_id(arg): + if isinstance(arg, list): + return "x".join([str(x) for x in arg]) + return str(arg) + +@pytest.mark.parametrize("shape, group_size, axis, dtype, transposed, kernel_type", TEST_CONFIGS, ids=_arg_to_id) +def test_mixed_mm(shape, group_size, axis, dtype, transposed, kernel_type, quant_dtype=torch.uint8): + qcfg = { + **BASE_QUANT_CONFIG, + **dict(group_size=group_size, axis=axis), + } + M, N, K = shape + + linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device="cuda") + + quant_config = BaseQuantizeConfig( + quant_zero=False, quant_scale=False, offload_meta=False, view_as_float=False + ) + quant_config.update({"weight_quant_params": qcfg}) + hqq_linear = HQQLinear(linear, quant_config, compute_dtype=dtype, del_orig=False) + W_q, meta = hqq_linear.W_q, hqq_linear.meta + W_q = W_q.to(dtype=quant_dtype) + W_q = ( + W_q.reshape(meta["shape"]) + if quant_config["weight_quant_params"]["bitpack"] == False + else W_q + ) + W_dq = hqq_linear.dequantize() + + scales, zeros = meta["scale"], meta["zero"] + scales = scales.reshape(N, -1) + zeros = zeros.reshape(N, -1) + + if transposed: + x = torch.randn(M, N, dtype=dtype, device="cuda") + hqq_out = x @ W_dq + + #Pack uint8 W_q, then run fused dequant matmul + packed_w = pack_2xint4(W_q) + tt_out = triton_mixed_mm( + x, packed_w, scales, zeros, transposed=True, group_size=group_size, fp8_fast_accum=False, kernel_type=kernel_type + ) + else: + x = torch.randn(M, K, dtype=dtype, device="cuda") + hqq_out = x @ W_dq.T + + packed_w = pack_2xint4(W_q.T) + tt_out = triton_mixed_mm( + x, packed_w, scales.T, zeros.T, transposed=False, group_size=group_size, fp8_fast_accum=False, kernel_type=kernel_type + ) + + assert check(hqq_out, tt_out, max_diff=1e-2 if dtype == torch.bfloat16 else 1e-3) + diff --git a/torchao/prototype/hqq/README.md b/torchao/prototype/hqq/README.md new file mode 100644 index 000000000..22c40fd24 --- /dev/null +++ b/torchao/prototype/hqq/README.md @@ -0,0 +1,55 @@ +## Fused `int4 / fp16` Quant Matmul + +Fused kernel that combines asymmetric dequantization and gemm. Useful primarily for compute-bound (M > 16) scenarios and not for memory-bound / inference scenarios. + +The kernel fuses two ops: + +- Dequantization: upcasts `u4 / s4` weights to `float16 / bfloat16`, followed by groupwise scaling and shifting by scales / zeropoints +- GEMM: matmul on dequantized weights and activations. + +Tested and benchmarked for `HQQ` but could theoretically be used for any asymmetric quantization scheme. + +> **NOTE**: Benchmark below is only indicative of performance on consumer-grade `Ampere` GPUs (`A6000` specifically). When tested on `H100`, the performance is on par / marginally worse than native / compiled `torch`. +> The intended use is thus for fine-tuning / training models on non-datacenter GPUs (`80 <= compute capability < 90`). If interested in optimizing the kernel for other architectures, please drop a note in the CUDA-MODE Discord channel. + +### Implementation Details + +- Bitpacking is simple row interleave, no need for extensive preprocessing (e.g., `tinygemm` or `fastertransformer`) +- Tested for `float16 / bfloat16` activations, scales, and zeros +- Autotuned for both compute-bound and memory-bound configs +- Assumes operand B of the `gemm` is is the quantized type. +- Requires quantization along `in-features`, i.e., the `K` dimension, or `axis=1`, of `torch.linear.weight`. +- Implementation handles both transposed and non-tranposed quantized weights, useful for forward / backward training passes. + +### Performance + +Initial benchmarking (on `A6000`) demonstrates promising results, scaling well for compute-bound workloads: + +| | M | N | K | group_size | dtype | hqq_ref | triton | tinygemm | +| --- | ---- | ---- | ---- | ---------- | -------------- | ------- | ------ | -------- | +| 0 | 16 | 4096 | 4096 | 128 | torch.bfloat16 | 0.2675 | 0.0633 | 0.0382 | +| 1 | 32 | 4096 | 4096 | 128 | torch.bfloat16 | 0.2669 | 0.0704 | 0.0649 | +| 2 | 128 | 4096 | 4096 | 128 | torch.bfloat16 | 0.2689 | 0.0960 | 0.2523 | +| 3 | 256 | 4096 | 4096 | 128 | torch.bfloat16 | 0.3268 | 0.1355 | 0.5192 | +| 4 | 512 | 4096 | 4096 | 128 | torch.bfloat16 | 0.3628 | 0.2369 | 1.0892 | +| 5 | 1024 | 4096 | 4096 | 128 | torch.bfloat16 | 0.5133 | 0.4753 | 2.2016 | + +- Times are in `ms`, see `benchmarks/benchmark_hqq.py`. +- `hqq_ref` is the base `HQQ_Linear` [module](https://github.com/mobiusml/hqq/blob/6d50eee4bcdd99cc10716f1297c5b2803d2b6da4/hqq/core/quantize.py#L349) that is unfused (dequantization followed by call to torch.matmul). +- `tinygemm` calls `torch.ops.aten._weight_int4pack_mm`. Implementation is a custom HQQLinear layer that wraps the preprocessing necessary for this kernel, adapted from a benchmark script posted by @mobicham from `CUDA-mode` Discord discussions. + +GPU details: + +``` +_CudaDeviceProperties(name='NVIDIA RTX A6000', major=8, minor=6, total_memory=48676MB, multi_processor_count=84) +``` + +### NOTE + +This implementation requires **`triton >= 3.0.0`**. + +- Running tests / benchmarks requires installation of `hqq`: + + ``` + pip install hqq + ``` diff --git a/torchao/prototype/hqq/__init__.py b/torchao/prototype/hqq/__init__.py new file mode 100644 index 000000000..c97591c47 --- /dev/null +++ b/torchao/prototype/hqq/__init__.py @@ -0,0 +1 @@ +from .mixed_mm import triton_mixed_mm, pack_2xint4 \ No newline at end of file diff --git a/torchao/prototype/hqq/hqq_tinygemm_linear.py b/torchao/prototype/hqq/hqq_tinygemm_linear.py new file mode 100644 index 000000000..0c4ae45c6 --- /dev/null +++ b/torchao/prototype/hqq/hqq_tinygemm_linear.py @@ -0,0 +1,256 @@ + +#mobicham's tinygemm hqq eval script +import torch + +device = "cuda" + + +import torch, copy +from torch import nn, Tensor + +from hqq.core.quantize import * +from hqq.core.utils import * + +import torch.nn.functional as F + + +class HQQLinearTorchWeightOnlyInt4(torch.nn.Module): + def __init__( + self, + linear_layer: nn.Module | None, + quant_config: dict, + del_orig: bool = True, + compute_dtype: torch.dtype = torch.bfloat16, + device: str = "cuda", + initialize: bool = True, + inner_k_tiles=8, + padding=True, + ): + super().__init__() + + self.ready = False + self.in_gpu = False + self.bias = None + self.device = device + self.compute_dtype = compute_dtype + self.quant_config = copy.deepcopy(quant_config) + self.del_orig = del_orig + + weight_quant_params = self.quant_config["weight_quant_params"] + self.groupsize = weight_quant_params["group_size"] + self.nbits = weight_quant_params["nbits"] + self.inner_k_tiles = inner_k_tiles + self.padding = padding + + assert self.nbits in [1, 2, 4], "Unsupported nbits" + assert self.groupsize in [None, 32, 64, 128, 256], "Unsupported groupsize" + assert self.inner_k_tiles in [2, 4, 8], "Unsupported tile" + + self.linear_layer = linear_layer + self.compute_dtype = compute_dtype + + if initialize: + self.initialize() + + ###################### Initializers ###################### + def initialize_with_hqq_quants(self, W_q, meta, bias=None): + self.padding = ( + False # Force padding off, a bit tricky to post-pad with grouping + ) + + self.set_shape(meta["shape"]) + self.process_hqq_quants(W_q, meta) + self.bias = bias + self.ready = True + self.in_gpu = True + torch.cuda.empty_cache() + + return self + + def initialize(self): + if self.linear_layer is not None: + W = self.linear_layer.weight.data + self.set_shape(W.shape) + + if self.in_features_diff > 0: + W = F.pad(W, pad=(0, self.in_features_diff), value=0) + + W_q, meta = self.quantize(W, **self.quant_config) + self.process_hqq_quants(W_q, meta) + del W_q, meta + + self.bias = ( + None + if (self.linear_layer.bias is None) + else self.linear_layer.bias.to( + dtype=self.compute_dtype, device=self.device + ) + ) + + if self.del_orig: + del self.linear_layer + + self.ready = True + self.in_gpu = True + torch.cuda.empty_cache() + + return self + + ###################### Quantize/packing ###################### + + def quantize( + self, + W: Tensor, + weight_quant_params: dict, + scale_quant_params=dict | None, + zero_quant_params=dict | None, + offload_meta=False, + ): + W_q, meta = Quantizer.quantize( + W, + **weight_quant_params, + device=self.device, + compute_dtype=self.compute_dtype, + bitpack=False, + ) + + # ToDO: meta quantization + + return W_q, meta + + # TODO: move these to utils + @torch.no_grad() + def reshape_meta_axis1(self, meta_tensor, new_group_size, shape): + meta_tensor = meta_tensor.repeat([1, shape[1]]).reshape(shape) + meta_tensor = torch.mean( + meta_tensor.reshape([-1, new_group_size]), axis=1, keepdim=True + ) + return meta_tensor + + def find_multiple(self, n: int, k: int) -> int: + if n % k == 0: + return n + return n + k - (n % k) + + def set_shape(self, shape): + self.shape = shape + self.in_features = shape[1] + self.out_features = shape[0] + + self.origin_in_features = self.in_features + if self.padding: + self.in_features = self.find_multiple(self.in_features, 1024) + + self.in_features_diff = self.in_features - self.origin_in_features + + @torch.no_grad() + def process_hqq_quants(self, W_q, meta): + scales = meta["scale"] + zeros = meta["zero"] + shape = meta["shape"] + + if meta["packing"] is not None: + W_q = Quantizer.unpack[meta["packing"]](W_q) + + if self.groupsize is None: + self.groupsize = 128 + W_q = W_q.reshape([-1, self.groupsize]) + scales = self.reshape_meta_axis1(scales, self.groupsize, shape) + zeros = self.reshape_meta_axis1(zeros, self.groupsize, shape) + + W_q_torch, scales_torch, zeros_torch = self.hqq_quants_to_torch_quants( + W_q=W_q, scales=scales, zeros=zeros, shape=shape, nbits=self.nbits + ) + self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack( + W_q_torch, self.inner_k_tiles + ) + self.scales_and_zeros = self.pack_scales_and_zeros(scales_torch, zeros_torch) + + del W_q_torch, scales_torch, zeros_torch + torch.cuda.empty_cache() + + @torch.no_grad() + def hqq_quants_to_torch_quants( + self, W_q: Tensor, scales: Tensor, zeros: Tensor, shape, nbits=4 + ): + W_q = W_q.to(dtype=self.compute_dtype, device=self.device) + scales = scales.to(dtype=self.compute_dtype, device=self.device) + zeros = zeros.to(dtype=self.compute_dtype, device=self.device) + + max_int = 2**nbits - 1 + min_int = 0 + dump = 2 ** (nbits - 1) + + # HQQ -> torch logic + new_zeros = (scales * dump) - zeros * scales + + min_val = new_zeros - scales * dump + + # group_quantize_tensor_from_qparams + W_r = (W_q - zeros) * scales + + W_q = ( + W_r.sub(min_val) + .div(scales) + .round() + .clamp_(min_int, max_int) + .to(torch.int32) + .reshape(shape) + .contiguous() + ) + + # group_dequantize_tensor_from_qparams + # W_r = W_q*scales + min_val + + scales = scales.contiguous().reshape(shape[0], -1) + new_zeros = new_zeros.contiguous().reshape(shape[0], -1) + + return W_q, scales, new_zeros + + def pack_scales_and_zeros(self, scales, zeros): + assert scales.shape == zeros.shape + assert scales.dtype == torch.bfloat16 + assert zeros.dtype == torch.bfloat16 + return ( + torch.cat( + [ + scales.reshape(scales.size(0), scales.size(1), 1), + zeros.reshape(zeros.size(0), zeros.size(1), 1), + ], + 2, + ) + .transpose(0, 1) + .contiguous() + ) + + ###################### Forward/matmul ###################### + + @torch.jit.ignore() + def matmul(self, x): + origin_x_size = x.size() + x = x.reshape(-1, origin_x_size[-1]) + c = torch.ops.aten._weight_int4pack_mm( + x, self.weight_int4pack, self.groupsize, self.scales_and_zeros + ) + new_shape = origin_x_size[:-1] + (self.out_features,) + c = c.reshape(new_shape) + return c + + # TODO without matmul + def dequantize(self): + return self.matmul( + torch.eye(self.in_features, dtype=self.compute_dtype, device=self.device) + )[: self.origin_in_features].t() + + # TODO: backward + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.to(self.compute_dtype) + + if self.in_features_diff > 0: + x = F.pad(x, pad=(0, self.in_features_diff)) + + out = self.matmul(x) + + if self.bias is not None: + out += self.bias + return out diff --git a/torchao/prototype/hqq/kernels.py b/torchao/prototype/hqq/kernels.py new file mode 100644 index 000000000..077fc9410 --- /dev/null +++ b/torchao/prototype/hqq/kernels.py @@ -0,0 +1,348 @@ +from triton import Config +import triton.language as tl +import triton + +#TODO: add early config prune and estimate_matmul_time to reduce autotuning time +# from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time + + +def get_configs_io_bound(): + configs = [] + for num_stages in [2, 3, 4, 5, 6]: + for block_m in [16, 32]: + for block_k in [32, 64]: + for block_n in [32, 64, 128, 256]: + num_warps = 2 if block_n <= 64 else 4 + configs.append( + Config( + { + "BLOCK_M": block_m, + "BLOCK_N": block_n, + "BLOCK_K": block_k, + "SPLIT_K": 1, + }, + num_stages=num_stages, + num_warps=num_warps, + ) + ) + # split_k + for split_k in [2, 4, 8, 16]: + configs.append( + Config( + { + "BLOCK_M": block_m, + "BLOCK_N": block_n, + "BLOCK_K": block_k, + "SPLIT_K": split_k, + }, + num_stages=num_stages, + num_warps=num_warps, + pre_hook=init_to_zero("C"), + ) + ) + return configs + + +def get_configs_compute_bound(): + configs = [ + # basic configs for compute-bound matmuls + Config( + {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=3, + num_warps=8, + ), + Config( + {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=3, + num_warps=8, + ), + Config( + {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "SPLIT_K": 1}, + num_stages=5, + num_warps=2, + ), + # good for int8 + Config( + {"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=3, + num_warps=8, + ), + Config( + {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=3, + num_warps=8, + ), + Config( + {"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=4, + num_warps=4, + ), + Config( + {"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "SPLIT_K": 1}, + num_stages=5, + num_warps=2, + ), + ] + return configs + + + +def init_to_zero(name): + return lambda nargs: nargs[name].zero_() + + +MIXED_MM_HEURISTICS = { + "EVEN_K": lambda args: args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0, + "BLOCK_K": lambda args: min(args["BLOCK_K"], args["QGROUP_SIZE"]) if not args["TRANSPOSED"] else args["BLOCK_K"], + "BLOCK_N": lambda args: min(args["BLOCK_N"], args["QGROUP_SIZE"]) if args["TRANSPOSED"] else args["BLOCK_N"], + "SPLIT_K": lambda args: 1 + if args["IS_BFLOAT16"] + else args["SPLIT_K"], # atomic add not supported for bfloat16 +} + + + +@triton.jit +def _mixed_mm_kernel( + # Operands + A, + B, + scales_ptr, + zeros_ptr, + C, + # Matrix dims. + M, + N, + K, + # a, b, c, scales / zeros strides + stride_am, + stride_ak, # + stride_bk, + stride_bn, # + stride_cm, + stride_cn, + stride_scale_k, + stride_scale_n, + # Meta-params + IS_BFLOAT16: tl.constexpr, + QGROUP_SIZE: tl.constexpr, + BLOCK_M: tl.constexpr, # = 32, + BLOCK_N: tl.constexpr, # = 32, + BLOCK_K: tl.constexpr, # = 16, # + SPLIT_K: tl.constexpr, # = 1, + EVEN_K: tl.constexpr, # = True, + TRANSPOSED: tl.constexpr = False, + GROUP_M: tl.constexpr = 8, # 32, + # tl.dot options + acc_dtype: tl.constexpr = tl.float32, + input_precision: tl.constexpr = "ieee", + fp8_fast_accum: tl.constexpr = False, +): + """Mixed matmul kernel + + A has shape (M, K) and is float16, bfloat16, or float32 + + B is i4 / s4 and has shape (K // 2, N) and is packed as uint8 / int8. See `packed_2xint4` for details. + + Scales and zeros are of shape (NUM_GROUPS, N) and are same dtype as A, where NUM_GROUPS = (K // QGROUP_SIZE) + QGROUP_SIZE should be a multiple of BLOCK_K such that a vector of scales / zeros is loaded and broadcasted to block shape + per mainloop iteration. + + NOTE: Assumes that the quantization grouping was done along the K dimension originally (i.e., QGROUP_SIZE consecutive elements + of original weight matrix in the K dimension were grouped together when calculating min / max scaling factors). + """ + + # tl.static_assert(B.dtype.element_ty == tl.int8 or B.dtype.element_ty == tl.uint8) + if not TRANSPOSED: + tl.static_assert(QGROUP_SIZE % BLOCK_K == 0) + else: + tl.static_assert(QGROUP_SIZE % BLOCK_N == 0) + + # Threadblock swizzling + pid = tl.program_id(0) + pid_z = tl.program_id(1) + + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + + width = GROUP_M * grid_n + group_id = pid // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (pid % group_size) + pid_n = (pid % width) // group_size + + rm = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M + rn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N + ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + + rak = pid_z * BLOCK_K + tl.arange(0, BLOCK_K) + + # BLOCK_K for b is effectively BLOCK_K // 2 + rbk = pid_z * BLOCK_K // 2 + tl.arange(0, BLOCK_K // 2) + + A = A + (ram[:, None] * stride_am + rak[None, :] * stride_ak) + B = B + (rbk[:, None] * stride_bk + rbn[None, :] * stride_bn) + + #In the forward pass, we have a K x N matrix + #In the transposed (backward) pass, we have an N x K matrix, where N and K refer to the how the weight was originally quantized + #note that N refers to offsets_scale_k and K refers to offsets_scale_n when it comes to the gemm indexing logic below + + #Grouping is along K, so in the forward pass, each block loads a row vector of BLK_K x BLK_N + #where grouping varies along N, hence the mainloop marches down the K dimension, where + #group idx is given by K // QGROUP_SIZE + + # For the transposed case, we load a column vector of BLK_N x BLK_K + # we march down the N dimension during the mainloop ("K" in gemm) + # Hence blocks now load K // QGROUP_SIZE along pid_n (slow varying) + # while each block now loads column vector of groups along "K" gemm dim on each main loop iteration + # scale offsets is thus a single idx along "N" and range along "K" for the transposed case + + if not TRANSPOSED: + # scale_offset_n = pid_n * stride_scale_n * BLOCK_N + offsets_scale_n = pid_n * stride_scale_n * BLOCK_N + tl.arange(0, BLOCK_N) * stride_scale_n + else: + offsets_scale_n = pid_n * stride_scale_n * BLOCK_N // QGROUP_SIZE + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=acc_dtype) + for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): + if EVEN_K: + a = tl.load(A) + qb = tl.load(B) + else: + k_remaining_a = K - k * (BLOCK_K * SPLIT_K) + k_remaining_b = K - k * (BLOCK_K * SPLIT_K) // 2 # Note the division by 2 + + _0 = tl.zeros((1, 1), dtype=C.dtype.element_ty) + a = tl.load(A, mask=rak[None, :] < k_remaining_a, other=_0) + qb = tl.load(B, mask=rbk[:, None] < k_remaining_b, other=_0) + + if not TRANSPOSED: + scale_offset_k = k * BLOCK_K * SPLIT_K * stride_scale_k // QGROUP_SIZE + else: + scale_offset_k = k * BLOCK_K * SPLIT_K * stride_scale_k + tl.arange(0, BLOCK_K) * stride_scale_k + + scales = tl.load(scales_ptr + offsets_scale_n + scale_offset_k) + zeros = tl.load(zeros_ptr + offsets_scale_n + scale_offset_k) + + # Unpack qweights -- h/t jlebar! + _4_i8 = tl.full((1,), 4, dtype=tl.int8) + qb_lo = (qb << _4_i8) >> _4_i8 + qb_hi = qb >> _4_i8 + + # Upcast to fp16 + # TODO: better bfloat16 conversion? compilation error if direct conversion from int8 to bfloat16 + if IS_BFLOAT16: + dq_b = ( + tl.join( + qb_lo.to(tl.float16).to(A.dtype.element_ty), + qb_hi.to(tl.float16).to(A.dtype.element_ty), + ) + .permute(0, 2, 1) + .reshape(BLOCK_K, BLOCK_N) + ) + else: + dq_b = ( + tl.join( + qb_lo.to(A.dtype.element_ty), + qb_hi.to(A.dtype.element_ty), + ) + .permute(0, 2, 1) + .reshape(BLOCK_K, BLOCK_N) + ) + + # Scale upcasted weights + # Note that we broadcast the scales --> the assumption is that all scales fall within a single QGROUP + # This condition is statically check (see assertions above) + if not TRANSPOSED: + zeros = zeros[None, :] + scales = scales[None, :] + else: + zeros = zeros[:, None] + scales = scales[:, None] + + dq_b = (dq_b - zeros) * scales + + # dq_b = (dq_b - zeros[None, :]) * scales[None, :] + + if fp8_fast_accum: + acc = tl.dot( + a, dq_b, acc, out_dtype=acc_dtype, input_precision=input_precision + ) + else: + acc += tl.dot(a, dq_b, out_dtype=acc_dtype, input_precision=input_precision) + A += BLOCK_K * SPLIT_K * stride_ak + # Advance by half the block size, since each block is unpacked and upcasted into two fp16 values + B += BLOCK_K * SPLIT_K * stride_bk // 2 + + acc = acc.to(C.dtype.element_ty) + + offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + C = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + + if SPLIT_K == 1: + tl.store(C, acc, mask=mask) + else: + tl.atomic_add(C, acc, mask=mask) + + +_mixed_mm = triton.heuristics(MIXED_MM_HEURISTICS)(_mixed_mm_kernel) +mixed_mm_kernel_max_autotune = triton.autotune(configs=get_configs_compute_bound() + get_configs_io_bound(), key=["M", "N", "K"])(_mixed_mm) +mixed_mm_kernel_compute_bound = triton.autotune(configs=get_configs_compute_bound(), key=["M", "N", "K"])(_mixed_mm) diff --git a/torchao/prototype/hqq/mixed_mm.py b/torchao/prototype/hqq/mixed_mm.py new file mode 100644 index 000000000..e3ccaeb46 --- /dev/null +++ b/torchao/prototype/hqq/mixed_mm.py @@ -0,0 +1,102 @@ +import torch +from triton import cdiv +import triton.language as tl +from .kernels import mixed_mm_kernel_compute_bound, mixed_mm_kernel_max_autotune + +#h/t jlebar for the bit packing / unpacking logic (source: Triton Slack thread) +#https://gist.github.com/jlebar/3435b2c00deea53258887ce37231e5e2 +def pack_2xint4(t): + """ + The packing format is such that consecutive rows are packed into a lower / upper bits + E.g., + Original, unpacked B (dtype i8): + [ + [0, 1, 2, 3] + [4, 5, 6, 7] + [8, 9, 10, 11] + [12, 13, 14, 15] + ] + Packed B: + [ + [0|4, 1|5, 2|6, 3|7] + [8|12, 9|13, 10|14, 11|15] + ] + (Note each entry in `Packed B` is shown lsb->msb) + """ + assert t.dtype == torch.int8 or t.dtype == torch.uint8 + t = t.reshape(t.shape[0] // 2, 2, t.shape[1]).permute(1, 0, 2) + return (t[0] & 0xF) | (t[1] << 4) + +def triton_mixed_mm( + a, + b, + scales, + zeros, + group_size, + transposed=False, + acc_dtype=None, + input_precision="ieee", + fp8_fast_accum=False, + kernel_type="compute_bound", +): + device = a.device + # handle non-contiguous inputs if necessary + if a.stride(0) > 1 and a.stride(1) > 1: + a = a.contiguous() + if b.stride(0) > 1 and b.stride(1) > 1: + b = b.contiguous() + # checks constraints + assert a.shape[1] == b.shape[0] * 2, "incompatible dimensions" + assert b.dtype == torch.int8 or b.dtype == torch.uint8, "b must be int8 or uint8" + assert scales.ndim == 2 + assert kernel_type in ["max_autotune", "compute_bound"] + + M, K = a.shape + _, N = b.shape + # N = b.shape[1] if not transposed else b.shape[0] + # assert scales.shape[1] == N if not transposed else scales.shape[0] == N + # assert scales.shape[0] == K // group_size if not transposed else scales.shape[1] == K // group_size + assert scales.dtype == a.dtype + assert scales.shape == zeros.shape + assert zeros.dtype == a.dtype + + # Assumes c is same type as a + c = torch.empty((M, N), device=device, dtype=a.dtype) + if acc_dtype is None: + acc_dtype = tl.float32 + + grid = lambda META: ( + cdiv(M, META["BLOCK_M"]) * cdiv(N, META["BLOCK_N"]), + META["SPLIT_K"], + ) + + if kernel_type == "max_autotune": + kernel = mixed_mm_kernel_max_autotune + else: + kernel = mixed_mm_kernel_compute_bound + + kernel[grid]( + a, + b, + scales, + zeros, + c, + M, + N, + K, # + a.stride(0), + a.stride(1), # + b.stride(0), + b.stride(1), # + c.stride(0), + c.stride(1), + scales.stride(0), + scales.stride(1), + TRANSPOSED=transposed, + IS_BFLOAT16=a.dtype == torch.bfloat16, + QGROUP_SIZE=group_size, + acc_dtype=acc_dtype, + input_precision=input_precision, + fp8_fast_accum=fp8_fast_accum, + ) + return c