diff --git a/benchmarks/benchmark_marlin_qqq.py b/benchmarks/benchmark_marlin_qqq.py new file mode 100644 index 0000000000..295d089680 --- /dev/null +++ b/benchmarks/benchmark_marlin_qqq.py @@ -0,0 +1,64 @@ +import torch +import pandas as pd +from torchao.utils import benchmark_torch_function_in_microseconds +from torchao.ops import marlin_qqq_gemm +from torchao.quantization.marlin_qqq import marlin_qqq_workspace, pack_to_marlin_qqq +from tqdm import tqdm + + +def get_problem(m, n, k, groupsize=-1): + if groupsize == -1: + groupsize = k + dev = torch.device("cuda") + A_ref = torch.randn((m, k), dtype=torch.half, device=dev) + B_ref = torch.randn((k, n), dtype=torch.half, device=dev) + + A = torch.randint(-128, 127, (m, k), dtype=torch.int8, device=dev) + B = torch.randint(low=-(2**31), high=2**31, size=(k, n), device=dev) + s_tok = torch.ones((m, 1), dtype=torch.float, device=dev) + if groupsize == k: + s_group = torch.tensor([], dtype=torch.half, device=dev) + else: + s_group = torch.ones((k // groupsize, n), dtype=torch.half, device=dev) + s_channel = torch.ones((1, n), dtype=torch.float, device=dev) + B, s_group, s_channel = pack_to_marlin_qqq( + B, s_group, s_channel, num_bits=4, group_size=group_size + ) + qqq_workspace = marlin_qqq_workspace(n) + return A, B, A_ref, B_ref, s_tok, s_channel, s_group, qqq_workspace + + +def benchmark(m: int, k: int, n: int, group_size: int): + A, B, A_ref, B_ref, s_tok, s_channel, s_group, qqq_workspace = get_problem( + m, n, k, group_size + ) + + fp16_time = benchmark_torch_function_in_microseconds(torch.matmul, A_ref, B_ref) + marlin_qqq_w4a8_time = benchmark_torch_function_in_microseconds( + marlin_qqq_gemm, A, B, s_tok, s_channel, s_group, qqq_workspace, m, n, k + ) + + return { + "m": m, + "k": k, + "n": n, + "group_size": group_size, + "fp16_latency (ms)": fp16_time, + "marlin_qqq_w4a8_latency (ms)": marlin_qqq_w4a8_time, + "speedup (d/s)": fp16_time / marlin_qqq_w4a8_time, + } + + +if __name__ == "__main__": + k_vals = (8192, 8192, 8192, 28672) + n_vals = (8192, 10240, 57344, 8192) + + results = [] + for group_size in tqdm([-1, 128]): + for m in tqdm([1 << i for i in range(10)]): + for n, k in zip(n_vals, k_vals): + results.append(benchmark(m, k, n, group_size)) + + df = pd.DataFrame(results) + df.to_csv("marlin_qqq_w4a8_llm_benchmark_results.csv", index=False) + print(df.to_markdown(index=False)) diff --git a/scripts/convert_hf_checkpoint.py b/scripts/convert_hf_checkpoint.py index 3098c818bb..11d425ceb2 100644 --- a/scripts/convert_hf_checkpoint.py +++ b/scripts/convert_hf_checkpoint.py @@ -85,6 +85,10 @@ def permute(w, n_head): else: state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True) merged_result.update(state_dict) + + if config.tie_word_embeddings: + merged_result["lm_head.weight"] = merged_result["model.embed_tokens.weight"].clone() + final_result = {} for key, value in merged_result.items(): if "layers" in key: @@ -112,7 +116,7 @@ def permute(w, n_head): del final_result[key.replace("wq", "wv")] print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}") torch.save(final_result, checkpoint_dir / "model.pth") - if 'llama-3-' in model_name.lower() or 'llama-3.1-' in model_name.lower(): + if any([x in model_name.lower() for x in ["llama-3-", "llama-3.1-", "llama-3.2-"]]): if 'llama-3.1-405b' in model_name.lower(): original_dir = checkpoint_dir / "original" / "mp16" else: diff --git a/scripts/prepare.sh b/scripts/prepare.sh index 8799388939..db426e3b11 100644 --- a/scripts/prepare.sh +++ b/scripts/prepare.sh @@ -1,6 +1,8 @@ python scripts/download.py --repo_id meta-llama/Llama-2-7b-chat-hf python scripts/download.py --repo_id meta-llama/Meta-Llama-3-8B python scripts/download.py --repo_id meta-llama/Meta-Llama-3.1-8B +python scripts/download.py --repo_id meta-llama/Llama-3.2-3B python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Llama-2-7b-chat-hf python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3-8B python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3.1-8B +python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Llama-3.2-3B diff --git a/test/dtypes/test_floatx.py b/test/dtypes/test_floatx.py index 13add69a0a..aa55164716 100644 --- a/test/dtypes/test_floatx.py +++ b/test/dtypes/test_floatx.py @@ -10,12 +10,11 @@ run_tests, ) from torchao.dtypes.floatx import ( - FloatxTensorCoreAQTTensorImpl, FloatxTensorCoreLayout, to_scaled_tc_floatx, from_scaled_tc_floatx, ) -from torchao.dtypes.floatx.floatx_tensor_core_layout import _pack_tc_floatx, _pack_tc_fp6 +from torchao.dtypes.floatx.floatx_tensor_core_layout import _pack_tc_floatx, _pack_tc_fp6, FloatxTensorCoreAQTTensorImpl from torchao.prototype.custom_fp_utils import _f32_to_floatx_unpacked, _floatx_unpacked_to_f32 from torchao.quantization import ( quantize_, diff --git a/test/dtypes/test_uintx.py b/test/dtypes/test_uintx.py index bb754135db..f4823c4d3b 100644 --- a/test/dtypes/test_uintx.py +++ b/test/dtypes/test_uintx.py @@ -4,7 +4,7 @@ import torch -from torchao.dtypes.uintx import to_uintx +from torchao.dtypes.uintx.uintx_layout import to_uintx from torchao.quantization.quant_api import quantize_, uintx_weight_only from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_3, diff --git a/test/float8/test_base.py b/test/float8/test_base.py index 7bc5a37882..a0ea96baae 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -632,7 +632,7 @@ def test_pad_inner_dim(self, base_dtype, use_fast_accum): with pytest.raises( RuntimeError, match=re.escape( - "Expected trailing dimension of mat1 to be divisible by 16 but got mat1 shape: (16x41)." + "Expected trailing dimension of mat1 to be divisible by 16 but got mat1 shape: (16x41" ), ): a_fp8 @ b_fp8 diff --git a/test/quantization/test_marlin_qqq.py b/test/quantization/test_marlin_qqq.py new file mode 100644 index 0000000000..c020b958f1 --- /dev/null +++ b/test/quantization/test_marlin_qqq.py @@ -0,0 +1,129 @@ +import copy + +import pytest +import torch +from torch import nn +from torch.testing._internal.common_utils import TestCase, run_tests + +from torchao.dtypes import MarlinQQQLayout +from torchao.quantization.marlin_qqq import ( + pack_to_marlin_qqq, + unpack_from_marlin_qqq, +) +from torchao.quantization.quant_api import ( + int8_dynamic_activation_int4_weight, + quantize_, +) +from torchao.quantization.quant_primitives import ( + MappingType, + choose_qparams_and_quantize_affine_qqq, +) +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 + + +class MarlinQQQ(TestCase): + def setUp(self): + super().setUp() + torch.manual_seed(0) + + self.input = torch.randn((64, 32, 8192), dtype=torch.float16, device="cuda") + self.model = ( + nn.Sequential( + nn.Linear(8192, 21504), + nn.Linear(21504, 8192), + nn.ReLU(), + nn.Linear(8192, 21504), + nn.Linear(21504, 8192), + ) + .half() + .cuda() + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") + def test_marlin_qqq(self): + output_ref = self.model(self.input) + for group_size in [-1, 128]: + modelq = copy.deepcopy(self.model) + quantize_( + modelq, + int8_dynamic_activation_int4_weight( + group_size=group_size, + mapping_type=MappingType.SYMMETRIC, + act_mapping_type=MappingType.SYMMETRIC, + layout=MarlinQQQLayout(), + ), + ) + output = modelq(self.input) + + assert torch.allclose( + output, output_ref, atol=1e-1 + ), "Results are not close" + + @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5+") + @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") + def test_marlin_qqq_compile(self): + model_copy = copy.deepcopy(self.model) + model_copy.forward = torch.compile(model_copy.forward, fullgraph=True) + output_ref = model_copy(self.input) + + for group_size in [-1, 128]: + modelq = copy.deepcopy(self.model) + quantize_( + modelq, + int8_dynamic_activation_int4_weight( + group_size=group_size, + mapping_type=MappingType.SYMMETRIC, + act_mapping_type=MappingType.SYMMETRIC, + layout=MarlinQQQLayout(), + ), + ) + modelq.forward = torch.compile(modelq.forward, fullgraph=True) + output = modelq(self.input) + + assert torch.allclose( + output, output_ref, atol=1e-1 + ), "Results are not close" + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available") + def test_pack_unpack_equivalence(self): + num_bits = 4 + shape = (11008, 4096) + + w = torch.rand(shape, dtype=torch.float16, device="cuda") + + for group_size in [-1, 128]: + # Quantize weights + q_w, s_group, s_channel, _ = choose_qparams_and_quantize_affine_qqq( + w, num_bits, group_size + ) + + q_w = q_w.t() + s_group = s_group.t() + s_channel = s_channel.t() + + # Test pack/unpack equivalence + q_w_comp, packed_s_group, packed_s_channel = pack_to_marlin_qqq( + q_w, s_group, s_channel, num_bits, group_size + ) + unpacked_q_w, unpacked_s_group, unpacked_s_channel = unpack_from_marlin_qqq( + q_w_comp, + packed_s_group, + packed_s_channel, + q_w.shape, + num_bits, + group_size, + ) + + assert torch.equal( + q_w, unpacked_q_w + ), "Unpacked weights do not match original weights" + assert torch.equal( + s_channel, unpacked_s_channel + ), "Unpacked s_channel do not match original s_channel" + assert torch.equal( + s_group, unpacked_s_group + ), "Unpacked s_group do not match original s_group" + + +if __name__ == "__main__": + run_tests() diff --git a/test/test_ops.py b/test/test_ops.py index 7802fdeaeb..4d8104c25b 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -13,6 +13,11 @@ from torchao.utils import is_fbcode, TORCH_VERSION_AT_LEAST_2_5, compute_max_diff from torchao.dtypes.floatx import from_scaled_tc_floatx from torchao.sparsity.marlin import marlin_24_workspace, pack_to_marlin_24, inject_24 +from torchao.quantization.marlin_qqq import ( + marlin_qqq_workspace, + pack_to_marlin_qqq, +) +from torchao.quantization.quant_primitives import choose_qparams_and_quantize_affine_qqq import pytest if is_fbcode(): @@ -426,5 +431,109 @@ def test_marlin_24(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_facto ) +MARLIN_QQQ_BATCH_SIZE = [1, 4, 8, 16, 32, 64] +MARLIN_QQQ_K_CHUNKS = [128] +MARLIN_QQQ_N_CHUNKS = [64, 128, 256] +MNK_FACTORS = [ + (1, 1, 1), + (1, 4, 8), + (1, 7, 5), + (13, 17, 67), + (26, 37, 13), + (67, 13, 11), +] +MARLIN_QQQ_SUPPORTED_NUM_BITS = [4] +MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128] + +MARLIN_TEST_PARAMS = list( + itertools.product( + MARLIN_QQQ_BATCH_SIZE, + MARLIN_QQQ_K_CHUNKS, + MARLIN_QQQ_N_CHUNKS, + MARLIN_QQQ_SUPPORTED_NUM_BITS, + MARLIN_QQQ_SUPPORTED_GROUP_SIZES, + MNK_FACTORS, + ) +) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize( + "batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors", + MARLIN_TEST_PARAMS, + ids=str, +) +def test_marlin_qqq(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors): + int8_traits = torch.iinfo(torch.int8) + m_factor, n_factor, k_factor = mnk_factors + + size_m = m_factor + size_k = k_chunk * k_factor + size_n = n_chunk * n_factor + + a_input = torch.randn( + (batch_size, size_m, size_k), dtype=torch.float16, device="cuda" + ) + b_weight = torch.rand((size_n, size_k), dtype=torch.float16, device="cuda") + + # Reshape input into 2D tensor + input_2d = a_input.view(-1, a_input.shape[-1]) + a_input_in, a_input_out = input_2d.shape + + # Quantize activations + s_a = ( + input_2d.abs() + .max(dim=-1, keepdim=True)[0] + .div(int8_traits.max) + .to(torch.float32) + ) + q_a = ( + (input_2d / s_a).round().clamp(int8_traits.min, int8_traits.max).to(torch.int8) + ) + + # Quantize weights + q_w, s_group, s_channel, w_ref = choose_qparams_and_quantize_affine_qqq( + b_weight, num_bits, group_size + ) + q_w = q_w.t() + s_group = s_group.t() + s_channel = s_channel.t() + w_ref = w_ref.t() + marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel = pack_to_marlin_qqq( + q_w, s_group, s_channel, num_bits, group_size + ) + + workspace = marlin_qqq_workspace(size_n) + + # Obtains reference output + output_ref = torch.matmul(q_a.half() * s_a.half(), w_ref) + output_ref = output_ref.reshape(a_input.shape[:-1] + (size_n,)) + + fn_inputs = ( + q_a, + marlin_qqq_q_w, + s_a, + marlin_qqq_s_channel, + marlin_qqq_s_group, + workspace, + a_input_in, + size_n, + a_input_out, + ) + output = torchao.ops.marlin_qqq_gemm(*fn_inputs) + output = output.reshape(a_input.shape[:-1] + (size_n,)) + + max_diff = compute_max_diff(output, output_ref) + assert max_diff < 0.04 + + # Performs opcheck + test_utils = ["test_schema", "test_autograd_registration", "test_faketensor"] + opcheck( + torch.ops.torchao.marlin_qqq_gemm, + fn_inputs, + test_utils=test_utils, + ) + + if __name__ == "__main__": run_tests() diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py index 973e0ba9a5..1efa6b04b3 100644 --- a/torchao/_models/llama/generate.py +++ b/torchao/_models/llama/generate.py @@ -14,6 +14,7 @@ import torch._dynamo.config import torch._inductor.config from torchao.utils import get_model_size_in_bytes +from torchao.quantization.quant_primitives import MappingType from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 def device_sync(device): @@ -211,6 +212,7 @@ def main( int8_weight_only, int8_dynamic_activation_int8_weight, int4_weight_only, + int8_dynamic_activation_int4_weight, fpx_weight_only, uintx_weight_only, autoquant, @@ -235,8 +237,20 @@ def main( assert group_size in [32,64,128,256], f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}" quantize_(model, int4_weight_only(group_size=group_size)) if "marlin" in quantization: - from torchao.dtypes import MarlinSparseLayout - quantize_(model, int4_weight_only(layout=MarlinSparseLayout())) + if "qqq" in quantization: + from torchao.dtypes import MarlinQQQLayout + quantize_( + model, + int8_dynamic_activation_int4_weight( + group_size=128, + mapping_type=MappingType.SYMMETRIC, + act_mapping_type=MappingType.SYMMETRIC, + layout=MarlinQQQLayout(), + ), + ) + else: + from torchao.dtypes import MarlinSparseLayout + quantize_(model, int4_weight_only(layout=MarlinSparseLayout())) if "fp6" in quantization: quantize_(model, fpx_weight_only(3, 2)) if "embed-int8wo" in quantization: @@ -474,7 +488,7 @@ def callback(x): help=( 'Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-, int4wo--hqq, autoquant, ' +'autoquant-int4, autoquant-float8, uintx--, uintx---hqq, sparse-marlin, spinquant, ' - +'embed-int8wo' + +'embed-int8wo, marlin_qqq' ) ) parser.add_argument("--calibration_limit", type=int, default=10, help="Number of calibration examples") diff --git a/torchao/_models/llama/model.py b/torchao/_models/llama/model.py index de1f311979..74cad30cbd 100644 --- a/torchao/_models/llama/model.py +++ b/torchao/_models/llama/model.py @@ -35,6 +35,7 @@ class ModelArgs: rope_base: float = 10000 norm_eps: float = 1e-5 use_scaled_rope: bool = False + tie_word_embeddings: bool = False def __post_init__(self): if self.n_local_heads == -1: @@ -79,6 +80,9 @@ def from_name(cls, name: str): "Llama-3.1-405B": dict(block_size=131072, n_layer=126, n_head=128, n_local_heads=8, dim=16384, intermediate_size=53248, vocab_size=128256, rope_base=500000, use_scaled_rope=True ), + "Llama-3.2-3B": dict(block_size=131072, n_layer=28, n_head=24, n_local_heads=8, dim=3072, intermediate_size=8192, vocab_size=128256, rope_base=500000, + use_scaled_rope=True, tie_word_embeddings=True + ), } # this is a model specific variable that controls whether index_put is used for the kv_cache update, diff --git a/torchao/csrc/cuda/marlin_qqq/base.h b/torchao/csrc/cuda/marlin_qqq/base.h new file mode 100644 index 0000000000..d184b65cce --- /dev/null +++ b/torchao/csrc/cuda/marlin_qqq/base.h @@ -0,0 +1,36 @@ +/* + * Modified by HandH1998 + * Modified by Neural Magic + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace torchao { + +constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; } + +// Instances of `Vec` are used to organize groups of >>registers<<, as needed +// for instance as inputs to tensor core operations. Consequently, all +// corresponding index accesses must be compile-time constants, which is why we +// extensively use `#pragma unroll` throughout the kernel code to guarantee +// this. +template +struct Vec { + T elems[n]; + __device__ T& operator[](int i) { return elems[i]; } +}; + +} // namespace torchao diff --git a/torchao/csrc/cuda/marlin_qqq/marlin_qqq_kernel.cu b/torchao/csrc/cuda/marlin_qqq/marlin_qqq_kernel.cu new file mode 100644 index 0000000000..7380f9aff2 --- /dev/null +++ b/torchao/csrc/cuda/marlin_qqq/marlin_qqq_kernel.cu @@ -0,0 +1,1248 @@ +/* + * Adapted from + * https://github.com/IST-DASLab/marlin/blob/master/marlin/marlin_cuda_kernel.cu + * https://github.com/IST-DASLab/marlin/blob/master/marlin/marlin_cuda.cpp + * Modified by HandH1998 + * Copyright (C) 2024 HandH1998 + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include +#include +#include +#include +#include + +#include + +#include "base.h" +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + #include "mem.h" +#endif + +template +inline std::string str(T x) { + return std::to_string(x); +} + +namespace torchao { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + +using I4 = Vec; +// Matrix fragments for tensor core instructions; their precise layout is +// documented here: +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-integer-type +using FragA = Vec; +using FragB = Vec; +using FragC = Vec; +using FragS_GROUP = Vec; // weight per-group quantization scales +using FragS_CHANNEL = + Vec; // weight per-channel quantization scales or activaton + // per-token quantization scales + +// NOTE(HandH1998): cp.async.cg only support BYTES = 16, however, +// cp.async.ca can support BYTES = 4, 8, 16; +// as s_tok's shape is equal to prob_m, we need set s_tok to float type, +// and cp_size = 1 float, i.e., 4 BYTES +// Asynchronous global->shared copy for activation quantizaton scales s_tok +__device__ inline void cp_async1(void* smem_ptr, const void* glob_ptr) { + const int BYTES = 4; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " cp.async.ca.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), "n"(BYTES)); +} + +// m16n8k16 tensor core mma instruction with int8 inputs and int32 +// output/accumulation. +__device__ inline void mma(const FragA& a_frag, const FragB& frag_b, + FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + int* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.satfinite.s32.s8.s8.s32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=r"(c[0]), "=r"(c[1]), "=r"(c[2]), "=r"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(b[0]), "r"(c[0]), "r"(c[1]), "r"(c[2]), + "r"(c[3])); +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared +// memory, directly in int8 tensor core layout. +__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" + : "=r"(a[0]), "=r"(a[1]) + : "r"(smem)); +} + +inline __device__ half2 float2_to_half2(float2 f) { + uint32_t res; + // NOTE(HandH1998): h0,h1 should be uint16_t, not half + uint16_t h0, h1; + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(h0) : "f"(f.x)); + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(h1) : "f"(f.y)); + asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(res) : "h"(h0), "h"(h1)); + return reinterpret_cast(res); +} + +inline __device__ float int32_to_float(int h) { + float res; + asm volatile("cvt.rn.f32.s32 %0, %1;\n" : "=f"(res) : "r"(h)); + return res; +} + +// Lookup-table based 3-input logical operation; explicitly used for +// dequantization as the compiler does not seem to automatically recognize it in +// all cases. +template +__device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(res) + : "r"(a), "r"(b), "r"(c), "n"(lut)); + return res; +} + +// Efficiently dequantize an int32 value into a full B-fragment of 4 int8 values +// for weight per channel dequant. +__device__ inline FragB dequant_per_channel(int q) { + static constexpr int MASK = 0xf0f0f0f0; + FragB frag_b; + frag_b[0] = (q & MASK); + return frag_b; +} + +// Efficiently dequantize an int32 value into a full B-fragment of 4 int8 values +// for weight per group dequant. +__device__ inline FragB dequant_per_group(int q, FragS_GROUP& frag_s, int i) { + static constexpr uint32_t LO = 0x000f000f; + static constexpr uint32_t HI = 0x00f000f0; + static constexpr uint32_t EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + uint32_t t0 = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + uint32_t t1 = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. + static constexpr uint32_t SUB = 0x64086408; + static constexpr uint32_t MUL = 0x2c002c00; + static constexpr uint32_t ADD = 0xd480d480; + *reinterpret_cast(&t0) = __hsub2( + *reinterpret_cast(&t0), *reinterpret_cast(&SUB)); + *reinterpret_cast(&t1) = __hfma2( + *reinterpret_cast(&t1), *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + + uint16_t s = reinterpret_cast(&frag_s)[i]; + uint32_t double_s; + // pack 2xfp16 to half2 + asm volatile("mov.b32 %0, {%1, %2};\n" : "=r"(double_s) : "h"(s), "h"(s)); + // dequant and convert 4 half to 4 uint8 (be placed at the low 8 bits of 4 + // half, respectively) + static constexpr uint32_t MAGIC_NUM = 0x64806480; + *reinterpret_cast(&t0) = __hfma2( + *reinterpret_cast(&t0), *reinterpret_cast(&double_s), + *reinterpret_cast(&MAGIC_NUM)); + *reinterpret_cast(&t1) = __hfma2( + *reinterpret_cast(&t1), *reinterpret_cast(&double_s), + *reinterpret_cast(&MAGIC_NUM)); + // take out the 4 uint8 from 4 half, then convert them to 4 int8 and pack 4 + // int8 into 1 uint32 + FragB frag_b; + uint32_t uint8s; + static constexpr uint32_t MASK_0246 = 0x6420; + static constexpr uint32_t UINT8s_TO_INT8s_MASK = 0x80808080; + asm volatile("prmt.b32 %0,%1,%2,%3;\n" + : "=r"(uint8s) + : "r"(t0), "r"(t1), "n"(MASK_0246)); + frag_b[0] = (uint8s ^ UINT8s_TO_INT8s_MASK); + return frag_b; +} + +template shared + // fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void Marlin_QQQ( + const int4* __restrict__ A, // int8 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // int32 global_reduce buffer of shape + // (max_par*16*4)xn, as int8 tensor core's output is + // int32 dtype + int4* __restrict__ D, // fp16 output buffer of shape mxn + const float* __restrict__ s_tok, // fp32 activation per-token quantization + // scales of shape mx1 + const int4* __restrict__ s_ch, // fp32 weight per-channel quantization + // scales of shape 1xn + const int4* __restrict__ s_group, // fp16 weight per-group quantization + // scales of shape (k/groupsize)xn, when + // group_blocks=-1, it should be nullptr + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization +) { + // Each threadblock processes one "stripe" of the B matrix with (roughly) the + // same size, which might involve multiple column "slices" (of width 16 * + // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM + // example: + // 0 1 3 + // 0 2 3 + // 1 2 4 + // While this kind of partitioning makes things somewhat more complicated, it + // ensures good utilization of all SMs for many kinds of shape and GPU + // configurations, while requiring as few slow global cross-threadblock + // reductions as possible. + + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a + // better partitioning with less reductions + int parallel = 1; + if (prob_m > 16 * thread_m_blocks) { + parallel = prob_m / (16 * thread_m_blocks); + prob_m = 16 * thread_m_blocks; + } + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts in + // the middle of group. + if constexpr (group_blocks != -1) + iters = (group_blocks / thread_k_blocks) * + ceildiv(iters, (group_blocks / thread_k_blocks)); + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = + 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top + + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers + if (slice_col_par >= n_tiles) { + A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 16; + C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 4; + D += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; + s_tok += (slice_col_par / n_tiles) * 16 * thread_m_blocks; + locks += (slice_col_par / n_tiles) * n_tiles; + slice_col = slice_col_par % n_tiles; + } + + // Compute all information about the current slice which is required for + // synchronization. + auto init_slice = [&]() { + slice_iters = + iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = ceildiv(k_tiles - col_off, iters); + if (col_off > 0) slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) slice_idx--; + } + } + if (slice_col == n_tiles) { + A += 16 * thread_m_blocks * prob_k / 16; + C += 16 * thread_m_blocks * prob_n / 4; + D += 16 * thread_m_blocks * prob_n / 8; + s_tok += 16 * thread_m_blocks; + locks += n_tiles; + slice_col = 0; + } + }; + init_slice(); + + int a_gl_stride = prob_k / 16; // stride of the A matrix in global memory + // We typically use `constexpr` to indicate that this value is a compile-time + // constant + constexpr int a_sh_stride = + 16 * thread_k_blocks / 16; // stride of an A matrix tile in shared memory + constexpr int a_gl_rd_delta_o = + 16 * thread_k_blocks / + 16; // delta between subsequent A tiles in global memory + int a_gl_rd_delta_i = + a_gl_stride * + (threads / a_gl_rd_delta_o); // between subsequent accesses within a tile + constexpr int a_sh_wr_delta = + a_sh_stride * + (threads / a_gl_rd_delta_o); // between shared memory writes + constexpr int a_sh_rd_delta_o = + 1 * ((threads / 32) / + (thread_n_blocks / 4)); // between shared memory tile reads + constexpr int a_sh_rd_delta_i = + a_sh_stride * 16; // within a shared memory tile + constexpr int a_sh_stage = + a_sh_stride * (16 * thread_m_blocks); // overall size of a tile + constexpr int a_sh_wr_iters = + ceildiv(a_sh_stage, + a_sh_wr_delta); // number of shared write iterations for a tile + + int b_gl_stride = 16 * prob_n / 32; + constexpr int b_sh_stride = 32 * thread_n_blocks / 4; + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride); + constexpr int b_sh_wr_delta = threads; + constexpr int b_sh_rd_delta = threads; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + constexpr int s_tok_sh_stride = 16 * thread_m_blocks; + + constexpr int s_ch_sh_stride = 16 * thread_n_blocks / 4; + + int s_group_gl_stride = prob_n / 8; + constexpr int s_group_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_group_sh_stage = s_group_sh_stride; + int s_group_gl_rd_delta = s_group_gl_stride; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + // NOTE(HandH1998): int8 input a only need 16 threads to load 16x16 matrix + int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % 16); + a_sh_rd += 1 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = + b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + int b_sh_wr = threadIdx.x; + int b_sh_rd = threadIdx.x; + + int s_tok_gl_rd = threadIdx.x; + // NOTE(HandH1998): activation scale s_tok need shuffle to [0, 8, 1, 9, 2, 10, + // 3, 11, 4, 12, 5, 13, 6, 14, 7, 15] for example, 0, 8 row scales serve for + // thread 0, 1, 2, 3. For more details, refer to mma operand A layout as + // s_tok's size is not fixed, we can not shuffle before inference we shuffle + // it when fetching s_tok from global memory to shared memory, that's why + // s_tok_sh_wr is like this + int s_tok_sh_wr = + (threadIdx.x / 16) * 16 + (threadIdx.x % 8) * 2 + (threadIdx.x % 16) / 8; + int s_tok_sh_rd = (threadIdx.x % 32) / 4; + bool s_tok_sh_wr_pred = threadIdx.x < prob_m; + + int s_ch_gl_rd = s_ch_sh_stride * slice_col + threadIdx.x; + int s_ch_sh_wr = threadIdx.x; + int s_ch_sh_rd = 16 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + 2 * ((threadIdx.x % 32) % 4); + bool s_ch_sh_wr_pred = threadIdx.x < s_ch_sh_stride; + + int s_group_gl_rd, s_group_sh_wr, s_group_sh_rd; + bool s_group_sh_wr_pred; + if constexpr (group_blocks != -1) { + s_group_gl_rd = + s_group_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_group_sh_stride * slice_col + threadIdx.x; + s_group_sh_wr = threadIdx.x; + // NOTE(HandH1998): s_group_sh_rd is related to mma output C + s_group_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + s_group_sh_wr_pred = threadIdx.x < s_group_sh_stride; + } + + // Precompute which thread should not read memory in which iterations; this is + // needed if there are more threads than required for a certain tilesize or + // when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? + auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < thread_m_blocks; j++) + a_sh_rd_trans[i][j] = + transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependencies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + // NOTE(HandH1998): stages need >= 4, otherwise, sh_s_tok = sh + max(stages * + // a_sh_stage + stages * b_sh_stage, 4 * stages * a_sh_stage) + int4* sh_a = sh; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_s_tok = sh_b + (stages * b_sh_stage); + int4* sh_s_ch = sh_s_tok + s_tok_sh_stride; + int4* sh_s_group = sh_s_ch + s_ch_sh_stride; + + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2]; + FragC frag_c[thread_m_blocks][4][2]; + FragS_GROUP frag_s_group[2][4]; + FragS_CHANNEL frag_s_tok[thread_m_blocks]; + FragS_CHANNEL frag_s_ch[2][4]; + + // Zero accumulators. + auto zero_accums = [&]() { + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + if (pred) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + cp_async4_pred( + &sh_a_stage[a_sh_wr_trans[i]], + &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], + a_sh_wr_pred[i]); + } + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); + B_ptr[i] += b_gl_rd_delta_o; + } + // Only fetch scales if this tile starts a new group + if constexpr (group_blocks != -1) { + if (pipe % (group_blocks / thread_k_blocks) == 0) { + int4* sh_s_group_stage = sh_s_group + s_group_sh_stage * pipe; + if (s_group_sh_wr_pred) + cp_async4(&sh_s_group_stage[s_group_sh_wr], + &s_group[s_group_gl_rd]); + s_group_gl_rd += s_group_gl_rd_delta; + } + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. + auto fetch_to_registers = [&](int k, int pipe) { + // It may seem inefficient that we reload the groups for every sub-tile; + // however, this does not seem to be a significant bottleneck, while some + // theoretically better attempts have lead to bad instruction ordering by + // the compiler and correspondingly a noticeable drop in performance. + if constexpr (group_blocks != -1) { + int4* sh_s_group_stage = + sh_s_group + + s_group_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s_group[k % 2])[0] = + sh_s_group_stage[s_group_sh_rd]; + } + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + frag_b_quant[k % 2] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); + }; + + // Execute the actual tensor core matmul of a sub-tile. + auto matmul = [&](int k) { + // We have the m dimension as the inner loop in order to encourage overlapping + // dequantization and matmul operations. + #pragma unroll + for (int j = 0; j < 4; j++) { + int b_quant = frag_b_quant[k % 2][j]; + // int b_quant_shift = b_quant << 4; + FragB frag_b0, frag_b1; + // If there are no groups, we can just scale the final output once and can + // avoid doing so for each weight. + if constexpr (group_blocks != -1) { + int b_quant_shift = b_quant >> 8; + frag_b0 = dequant_per_group(b_quant, frag_s_group[k % 2][j], 0); + frag_b1 = dequant_per_group(b_quant_shift, frag_s_group[k % 2][j], 1); + } else { + int b_quant_shift = b_quant << 4; + frag_b0 = dequant_per_channel(b_quant); + frag_b1 = dequant_per_channel(b_quant_shift); + } + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&]() { + constexpr int red_off = threads / b_sh_stride / 2; + if (red_off >= 1) { + int red_idx = threadIdx.x / b_sh_stride; + constexpr int red_sh_stride = b_sh_stride * 4 * 2; + constexpr int red_sh_delta = b_sh_stride; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + + (threadIdx.x % b_sh_stride); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. + + #pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { + #pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { + #pragma unroll + for (int j = 0; j < 4 * 2; j++) { + int red_sh_wr = + red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + int* c_rd = + reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); + int* c_wr = reinterpret_cast(&sh[red_sh_wr]); + #pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + c_rd[k] + c_wr[k]; + } + sh[red_sh_wr] = + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { + #pragma unroll + for (int i = 0; i < 4 * 2; i++) { + int* c_rd = + reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); + #pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. + // global_reduce works on INT32 elements, which are the results of INT8 GEMM. + // This is why we need another INT32 maxtrix `C` to reduce instead of the + // original half matrix `D`. + auto global_reduce = [&](bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). + constexpr int active_threads = 32 * thread_n_blocks / 4; + if (threadIdx.x < active_threads) { + int c_gl_stride = prob_n / 4; + int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_i = 8 * (active_threads / 32); + int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + + 8 * (threadIdx.x / 32) + (threadIdx.x % 4) * 2; + c_gl_wr += (4 * thread_n_blocks) * slice_col; + constexpr int c_sh_wr_delta = active_threads * 2; + int c_sh_wr = 2 * threadIdx.x; + + int row = (threadIdx.x % 32) / 4; + + if (!first) { + // Interestingly, doing direct global accesses here really seems to mess up + // the compiler and lead to slowdowns, hence we also use async-copies even + // though these fetches are not actually asynchronous. + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + cp_async4_pred( + &sh[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); + cp_async4_pred( + &sh[c_sh_wr + c_sh_wr_delta * i + 1], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + + c_gl_wr_delta_i * (i % 2) + 1], + i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); + } + cp_async_fence(); + cp_async_wait<0>(); + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { + if (!first) { + int4 d_red1 = sh[c_sh_wr + i * c_sh_wr_delta]; + int4 d_red2 = sh[c_sh_wr + i * c_sh_wr_delta + 1]; + #pragma unroll + for (int j = 0; j < 4; j++) { + reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += + reinterpret_cast(&d_red1)[j]; + } + #pragma unroll + for (int j = 0; j < 4; j++) { + reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * (j + 4) + (i % 4)] += + reinterpret_cast(&d_red2)[j]; + } + } + if (!last) { + int4 d1, d2; + #pragma unroll + for (int j = 0; j < 4; j++) { + reinterpret_cast(&d1)[j] = reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]; + } + #pragma unroll + for (int j = 0; j < 4; j++) { + reinterpret_cast(&d2)[j] = reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * (j + 4) + (i % 4)]; + } + C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = + d1; + C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2) + + 1] = d2; + } + } + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually + // reshuffle matrix fragments in this step, the reduction above is performed + // in fragment layout. + auto write_result = [&]() { + int d_gl_stride = prob_n / 8; + constexpr int d_sh_stride = 2 * thread_n_blocks + 1; + int d_gl_wr_delta = d_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int d_sh_rd_delta = + d_sh_stride * (threads / (2 * thread_n_blocks)); + + int d_gl_wr = d_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + d_gl_wr += (2 * thread_n_blocks) * slice_col; + int d_sh_wr = + (4 * d_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + d_sh_wr += 32 * (threadIdx.x / 32); + int d_sh_rd = d_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + + int d_gl_wr_end = d_gl_stride * prob_m; + + // We first reorder in shared memory to guarantee the most efficient final + // global write patterns + auto write = [&](int idx, int c0, int c1, float a_s, FragS_CHANNEL& w_s) { + float2 deq_res; + deq_res.x = int32_to_float(c0) * w_s[0] * a_s; + deq_res.y = int32_to_float(c1) * w_s[1] * a_s; + ((half2*)sh)[idx] = float2_to_half2(deq_res); + }; + + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + int wr = d_sh_wr + 8 * j; + write(wr + (4 * d_sh_stride) * 0 + 0, frag_c[i][j][0][0], + frag_c[i][j][0][1], frag_s_tok[i][0], + frag_s_ch[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * d_sh_stride) * 8 + 0, frag_c[i][j][0][2], + frag_c[i][j][0][3], frag_s_tok[i][1], + frag_s_ch[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * d_sh_stride) * 0 + 4, frag_c[i][j][1][0], + frag_c[i][j][1][1], frag_s_tok[i][0], + frag_s_ch[j / 2][2 * (j % 2) + 1]); + write(wr + (4 * d_sh_stride) * 8 + 4, frag_c[i][j][1][2], + frag_c[i][j][1][3], frag_s_tok[i][1], + frag_s_ch[j / 2][2 * (j % 2) + 1]); + } + d_sh_wr += 16 * (4 * d_sh_stride); + } + } + __syncthreads(); + + #pragma unroll + for (int i = 0; + i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); + i++) { + if (d_gl_wr < d_gl_wr_end) { + D[d_gl_wr] = sh[d_sh_rd]; + d_gl_wr += d_gl_wr_delta; + d_sh_rd += d_sh_rd_delta; + } + } + }; + + // Start global fetch and register load pipelines. + auto start_pipes = [&]() { + #pragma unroll + for (int i = 0; i < stages - 1; i++) fetch_to_shared(i, i, i < slice_iters); + zero_accums(); + wait_for_stage(); + fetch_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + }; + start_pipes(); + + // Main loop. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines have + // even length meaning that the next iteration will always start at index 0. + #pragma unroll + for (int pipe = 0; pipe < stages;) { + #pragma unroll + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, + slice_iters >= stages); + pipe++; + wait_for_stage(); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) break; + } + a_gl_rd += a_gl_rd_delta_o * stages; + + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compilation. + if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before + // write-out + if (last) { + if (s_tok_sh_wr_pred) { + cp_async1(&sh_s_tok[s_tok_sh_wr], &s_tok[s_tok_gl_rd]); + } + if (s_ch_sh_wr_pred) { + cp_async4(&sh_s_ch[s_ch_sh_wr], &s_ch[s_ch_gl_rd]); + } + cp_async_fence(); + } + thread_block_reduce(); + if (last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + frag_s_tok[i][0] = + *reinterpret_cast(&sh_s_tok[16 * i + 2 * s_tok_sh_rd]); + frag_s_tok[i][1] = *reinterpret_cast( + &sh_s_tok[16 * i + 2 * s_tok_sh_rd + 1]); + } + reinterpret_cast(&frag_s_ch)[0] = sh_s_ch[s_ch_sh_rd + 0]; + reinterpret_cast(&frag_s_ch)[1] = sh_s_ch[s_ch_sh_rd + 1]; + reinterpret_cast(&frag_s_ch)[2] = sh_s_ch[s_ch_sh_rd + 8]; + reinterpret_cast(&frag_s_ch)[3] = sh_s_ch[s_ch_sh_rd + 9]; + } + } + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice + barrier_acquire(&locks[slice_col], slice_idx); + global_reduce(slice_idx == 0, last); + barrier_release(&locks[slice_col], last); + } + if (last) // only the last block in a slice actually writes the result + write_result(); + slice_row = 0; + slice_col_par++; + slice_col++; + init_slice(); + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; + } + s_group_gl_rd = s_group_sh_stride * slice_col + threadIdx.x; + s_ch_gl_rd = s_ch_sh_stride * slice_col + threadIdx.x; + start_pipes(); + } + } + } +} + +#else + +template shared + // fetch pipeline + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void Marlin_QQQ( + const int4* __restrict__ A, // int8 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // int32 global_reduce buffer of shape + // (max_par*16*4)xn, as int8 tensor core's output is + // int32 dtype + int4* __restrict__ D, // fp16 output buffer of shape mxn + const float* __restrict__ s_tok, // fp32 activation per-token quantization + // scales of shape mx1 + const int4* __restrict__ s_ch, // fp32 weight per-channel quantization + // scales of shape 1xn + const int4* __restrict__ s_group, // fp16 weight per-group quantization + // scales of shape (k/groupsize)xn, when + // group_blocks=-1, it should be nullptr + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks // extra global storage for barrier synchronization +) { + // Marlin is not implemented yet for SM < 8.0 + TORCH_CHECK_NOT_IMPLEMENTED( + false, "marlin_qqq_gemm(..) requires CUDA_ARCH >= 8.0"); + return; +} + +#endif + +// 8 warps are a good choice since every SM has 4 schedulers and having more +// than 1 warp per schedule allows some more latency hiding. At the same time, +// we want relatively few warps to have many registers per warp and small tiles. +const int USER_THREADS = + 256; // Note: This is only used with user-provided thread_k/n +const int STAGES = 4; // 4 pipeline stages fit into shared memory + +static constexpr int min_thread_n = 64; +static constexpr int min_thread_k = 64; + +static constexpr int tile_size = 16; +static constexpr int max_par = 16; + +static constexpr int pack_factor_4bit = + 8; // We have 8 4-bit vals inside a 32 bit + +#define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ + GROUP_BLOCKS, NUM_THREADS) \ + else if (thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \ + cudaFuncSetAttribute(Marlin_QQQ, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + max_shared_mem); \ + Marlin_QQQ \ + <<>>( \ + A_ptr, B_ptr, C_ptr, D_ptr, s_tok_ptr, s_ch_ptr, s_group_ptr, \ + prob_m, prob_n, prob_k, locks); \ + } + +typedef struct { + int thread_k; + int thread_n; + int num_threads; +} thread_config_t; + +thread_config_t small_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {128, 128, 256}, // Default + {128, 64, 128}, // Reduce N 2X, same K + {64, 256, 256}, // Reduce K 2X, increase N 2X + {64, 128, 128}, // Reduce K 2X, same N +}; + +thread_config_t large_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {64, 256, 256}, // Default + {128, 128, 256}, // Reduce N 2X, increase K 2X + {64, 128, 128}, // Reduce N 2X, same K + {128, 64, 128}, // Reduce N 4X, increase K 2X +}; + +bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n, + int prob_k) { + // Sanity + if (th_config.thread_k == -1 || th_config.thread_n == -1 || + th_config.num_threads == -1) { + return false; + } + + // Verify K/N are divisible by thread K/N + if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { + return false; + } + + // thread_k can be only 128 or 64 (because it must be less than groupsize + // which is 128) + if (th_config.thread_k != 128 && th_config.thread_k != 64) { + return false; + } + + // Verify min for thread K/N + if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { + return false; + } + + // num_threads must be at least 128 (= 4 warps) + if (th_config.num_threads < 128) { + return false; + } + + return true; +} + +thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { + if (prob_m <= 16) { + for (auto th_config : small_batch_thread_configs) { + if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { + return th_config; + } + } + + } else { + for (auto th_config : large_batch_thread_configs) { + if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { + return th_config; + } + } + } + + return thread_config_t{-1, -1, -1}; +} + +#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(2, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(3, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) \ + __CALL_IF(4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \ + __CALL_IF(4, N_BLOCKS, K_BLOCKS, 8, NUM_THREADS) + +void marlin_qqq_cuda(const void* A, const void* B, void* C, void* D, + void* s_tok, void* s_ch, void* s_group, int prob_m, + int prob_n, int prob_k, void* workspace, + int groupsize = -1, int dev = 0, cudaStream_t stream = 0, + int thread_k = -1, int thread_n = -1, int sms = -1, + int max_par = 16) { + int tot_m = prob_m; + int tot_m_blocks = ceildiv(tot_m, 16); + int pad = 16 * tot_m_blocks - tot_m; + + if (sms == -1) + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + TORCH_CHECK(max_shared_mem > 0); + + // Set thread config + thread_config_t th_config; + if (thread_k != -1 && thread_n != -1) { + // User-defined config + th_config = thread_config_t{thread_k, thread_n, USER_THREADS}; + } else { + // Auto config + th_config = determine_thread_config(prob_m, prob_n, prob_k); + } + + if (!is_valid_config(th_config, prob_m, prob_n, prob_k)) { + throw std::runtime_error( + "Invalid thread config: thread_k = " + str(th_config.thread_k) + + ", thread_n = " + str(th_config.thread_n) + + ", num_threads = " + str(th_config.num_threads) + " for MKN = [" + + str(prob_m) + ", " + str(prob_k) + ", " + str(prob_n) + "]"); + } + + int num_threads = th_config.num_threads; + thread_k = th_config.thread_k; + thread_n = th_config.thread_n; + + int thread_k_blocks = thread_k / 16; + int thread_n_blocks = thread_n / 16; + int group_blocks = (groupsize == -1) ? -1 : groupsize / 16; + int blocks = sms; + + if (prob_m == 0 || prob_n == 0 || prob_k == 0) { + return; + } + + TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, + " is not divisible by thread_n = ", thread_n); + TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, + " is not divisible by thread_k = ", thread_k); + if (group_blocks != -1) { + TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, + " is not divisible by group_blocks = ", group_blocks); + } + + const int4* A_ptr = (const int4*)A; + const int4* B_ptr = (const int4*)B; + int4* C_ptr = (int4*)C; + int4* D_ptr = (int4*)D; + const float* s_tok_ptr = (const float*)s_tok; + const int4* s_ch_ptr = (const int4*)s_ch; + const int4* s_group_ptr = (const int4*)s_group; + + int* locks = (int*)workspace; + + for (int i = 0; i < tot_m_blocks; i += 4) { + int thread_m_blocks = tot_m_blocks - i; + prob_m = tot_m - 16 * i; + int par = 1; + if (thread_m_blocks > 4) { + // Note that parallel > 1 currently only works for inputs without any + // padding + par = (16 * thread_m_blocks - pad) / 64; + if (par > max_par) par = max_par; + prob_m = 64 * par; + i += 4 * (par - 1); + thread_m_blocks = 4; + } + + // For compilation speed, we only define the kernel configurations that have + // seemed useful (in terms of performance) in our testing, however many more + // are, in principle, possible. + if (false) { + } + CALL_IF(8, 8, 256) + CALL_IF(16, 4, 256) + CALL_IF(8, 4, 128) + CALL_IF(4, 8, 128) + else { + throw std::runtime_error("Unsupported shapes: MKN = [" + str(prob_m) + + ", " + str(prob_k) + ", " + str(prob_n) + "]" + + ", groupsize = " + str(groupsize) + + ", thread_m_blocks = " + str(thread_m_blocks) + + ", thread_n_blocks = " + str(thread_n_blocks) + + ", thread_k_blocks = " + str(thread_k_blocks)); + } + + A_ptr += 16 * thread_m_blocks * (prob_k / 16) * par; + D_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; + s_tok_ptr += 16 * thread_m_blocks * par; + } +} + +torch::Tensor marlin_qqq_gemm(torch::Tensor const& a, + torch::Tensor const& b_q_weight, + torch::Tensor const& s_tok, + torch::Tensor const& s_ch, + torch::Tensor const& s_group, + torch::Tensor& workspace, int64_t size_m, + int64_t size_n, int64_t size_k) { + // Verify M + TORCH_CHECK(size_m == a.size(0), + "Shape mismatch: a.size(0) = " + str(a.size(0)) + + ", size_m = " + str(size_m)); + TORCH_CHECK(size_m == s_tok.numel(), + "Shape mismatch: s_tok.numel() = " + str(s_tok.numel()) + + ", size_m = " + str(size_m)); + + // Verify K + TORCH_CHECK(size_k == a.size(1), + "Shape mismatch: a.size(1) = " + str(a.size(1)) + + ", size_k = " + str(size_k)); + TORCH_CHECK(size_k % tile_size == 0, + "size_k = " + str(size_k) + + " is not divisible by tile_size = " + str(tile_size)); + TORCH_CHECK( + (size_k / tile_size) == b_q_weight.size(0), + "Shape mismatch: b_q_weight.size(0) = " + str(b_q_weight.size(0)) + + ", size_k = " + str(size_k) + ", tile_size = " + str(tile_size)); + + int groupsize = (s_group.numel() == 0) ? -1 : size_k / s_group.size(0); + // Verify groupsize + TORCH_CHECK(groupsize == -1 || groupsize == 128, + "Unexpected groupsize = " + str(groupsize)); + + // Verify N + TORCH_CHECK(s_ch.numel() == size_n, + "Shape mismatch: s_ch.numel() = " + str(s_ch.numel()) + + ", size_n = " + str(size_n)); + TORCH_CHECK(b_q_weight.size(1) % tile_size == 0, + "b_q_weight.size(1) = " + str(b_q_weight.size(1)) + + " is not divisible by tile_size = " + str(tile_size)); + if (groupsize != -1) { + TORCH_CHECK(s_group.size(1) == size_n, + "Shape mismatch: s_group.size(1) = " + str(s_group.size(1)) + + ", size_n = " + str(size_n)); + TORCH_CHECK( + size_k % s_group.size(0) == 0, + "size_k = " + str(size_k) + + ", is not divisible by s_group.size(0) = " + str(s_group.size(0))); + } + + int actual_size_n = (b_q_weight.size(1) / tile_size) * pack_factor_4bit; + TORCH_CHECK(size_n == actual_size_n, + "Shape mismatch: size_n = " + str(size_n) + + ", actual_size_n = " + str(actual_size_n)); + + // Verify A device and strides + TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); + TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); + + // Verify B device and strides + TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); + TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); + + // Verify s_tok device, strides and dtype + TORCH_CHECK(s_tok.device().is_cuda(), "s_tok is not on GPU"); + TORCH_CHECK(s_tok.is_contiguous(), "s_tok is not contiguous"); + TORCH_CHECK(s_tok.dtype() == torch::kFloat32, "s_tok's dtype is not float32"); + + // Verify s_ch device, strides and dtype + TORCH_CHECK(s_ch.device().is_cuda(), "s_ch is not on GPU"); + TORCH_CHECK(s_ch.is_contiguous(), "s_ch is not contiguous"); + TORCH_CHECK(s_ch.dtype() == torch::kFloat32, "s_ch's dtype is not float32"); + + // Verify s_group device, strides and dtype + TORCH_CHECK(s_group.device().is_cuda(), "s_group is not on GPU"); + TORCH_CHECK(s_group.is_contiguous(), "s_group is not contiguous"); + TORCH_CHECK(s_group.dtype() == torch::kFloat16, + "s_group's dtype is not float16"); + + // Verify workspace size + TORCH_CHECK(size_n % min_thread_n == 0, + "size_n = " + str(size_n) + + ", is not divisible by min_thread_n = " + str(min_thread_n)); + int min_workspace_size = (size_n / min_thread_n) * max_par; + TORCH_CHECK(workspace.numel() >= min_workspace_size, + "workspace.numel = " + str(workspace.numel()) + + " is below min_workspace_size = " + str(min_workspace_size)); + + // Alloc C matrix + const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); + auto options_c = torch::TensorOptions().dtype(torch::kInt).device(a.device()); + torch::Tensor c = torch::empty({max_par * 64, size_n}, options_c); + + // Alloc D matrix + auto options_d = + torch::TensorOptions().dtype(torch::kFloat16).device(a.device()); + torch::Tensor d = torch::empty({size_m, size_n}, options_d); + + // thread_k: `k` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_k = -1; + // thread_n: `n` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_n = -1; + // sms: number of SMs to use for the kernel (can usually be left as auto -1) + int sms = -1; + + int dev = a.get_device(); + marlin_qqq_cuda( + a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), d.data_ptr(), + s_tok.data_ptr(), s_ch.data_ptr(), s_group.data_ptr(), size_m, size_n, + size_k, workspace.data_ptr(), groupsize, dev, + at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par); + + return d; +} + +TORCH_LIBRARY_IMPL(torchao, CUDA, m) { + m.impl("torchao::marlin_qqq_gemm", &marlin_qqq_gemm); +} + +} // namespace torchao diff --git a/torchao/csrc/cuda/marlin_qqq/mem.h b/torchao/csrc/cuda/marlin_qqq/mem.h new file mode 100644 index 0000000000..db88e2bc40 --- /dev/null +++ b/torchao/csrc/cuda/marlin_qqq/mem.h @@ -0,0 +1,91 @@ +/* + * Modified by HandH1998 + * Modified by Neural Magic + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace torchao { +// Predicated asynchronous global->shared copy; used for inputs A where we apply +// predication to handle batchsizes that are not multiples of 16. +__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, + bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +// Asynchronous global->shared copy +__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), "n"(BYTES)); +} + +// Async copy fence. +__device__ inline void cp_async_fence() { + asm volatile("cp.async.commit_group;\n" ::); +} + +// Wait until at most `n` async copy stages are still pending. +template +__device__ inline void cp_async_wait() { + asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int* lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" + : "=r"(state) + : "l"(lock)); + while (state != count); + } + __syncthreads(); +} + +// Release barrier and increment visitation count. +__device__ inline void barrier_release(int* lock, bool reset = false) { + __syncthreads(); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible + // globally, while releasing the barrier. + asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" + : + : "l"(lock), "r"(val)); + } +} +} // namespace torchao diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index a41fd83408..be1708be98 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -3,31 +3,26 @@ # from ..prototype.dtypes.uint2 import UInt2Tensor, BitnetTensor from .affine_quantized_tensor import ( AffineQuantizedTensor, + MarlinQQQTensor, to_affine_quantized_floatx, to_affine_quantized_floatx_static, # experimental, will be merged into floatx in the future to_affine_quantized_fpx, to_affine_quantized_intx, to_affine_quantized_intx_static, + to_marlinqqq_quantized_intx, ) from .floatx import ( - Float8AQTTensorImpl, Float8Layout, ) from .nf4tensor import NF4Tensor, to_nf4 from .uintx import ( - _BIT_WIDTH_TO_DTYPE, - _DTYPE_TO_BIT_WIDTH, BlockSparseLayout, + MarlinQQQLayout, MarlinSparseLayout, - PlainAQTTensorImpl, SemiSparseLayout, TensorCoreTiledLayout, - UInt4Tensor, - UintxAQTTensorImpl, UintxLayout, - UintxTensor, - to_uintx, ) from .utils import ( Layout, @@ -39,29 +34,22 @@ __all__ = [ "NF4Tensor", "to_nf4", - "UInt4Tensor", "AffineQuantizedTensor", "to_affine_quantized_intx", "to_affine_quantized_intx_static", "to_affine_quantized_fpx", "to_affine_quantized_floatx", "to_affine_quantized_floatx_static", + "to_marlinqqq_quantized_intx", "Layout", "PlainLayout", "SemiSparseLayout", "TensorCoreTiledLayout", "Float8Layout", - "Float8AQTTensorImpl", "MarlinSparseLayout", - "PlainAQTTensorImpl", "affine_quantized_tensor_ops", "BlockSparseLayout", - "to_uintx", - "UintxTensor", "UintxLayout", - "UintxAQTTensorImpl", - "_DTYPE_TO_BIT_WIDTH", - "_BIT_WIDTH_TO_DTYPE", - "Uint4Tensor", - "PlainAQTTensorImpl", + "MarlinQQQTensor", + "MarlinQQQLayout", ] diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 4bbb87ecee..6d39aaf4fe 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -16,8 +16,10 @@ choose_qparams_affine, choose_qparams_affine_floatx, choose_qparams_and_quantize_affine_hqq, + choose_qparams_and_quantize_affine_qqq, dequantize_affine, dequantize_affine_floatx, + dequantize_affine_qqq, quantize_affine, quantize_affine_floatx, ) @@ -445,22 +447,70 @@ def _apply_fn_to_data(self, fn): # 2 - we're given non-floats - quantizing long to int8 is crazy +class MarlinQQQTensor(AffineQuantizedTensor): + """ + MarlinQQQ quantized tensor subclass which inherits AffineQuantizedTensor class. + + To see what happens during choose_qparams_and_quantize_affine_qqq, quantization and dequantization for marlin qqq quantization, + please checkout https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_primitives.py + and check the two quant primitive ops: choose_qparams_and_quantize_affine_qqq and dequantize_affine_qqq + """ + + def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: + if output_dtype is None: + output_dtype = self.dtype + + int_data, s_group, s_channel = self.tensor_impl.get_plain() + nbits = int(math.log2(self.quant_max - self.quant_min + 1)) + group_size = max(self.block_size) + return dequantize_affine_qqq( + int_data, s_group, s_channel, nbits, group_size, output_dtype + ) + + @classmethod + def from_hp_to_intx( + cls, + input_float: torch.Tensor, + block_size: Tuple[int, ...], + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT, + _layout: Optional[Layout] = None, + ): + original_shape = input_float.shape + input_float = _layout.pre_process(input_float) + nbits = int(math.log2(quant_max - quant_min + 1)) + group_size = max(block_size) + data, s_group, s_channel, _ = choose_qparams_and_quantize_affine_qqq( + input_float, nbits, group_size + ) + data = _layout.post_process(data) + tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) + tensor_impl = tensor_impl_ctr(data, s_group, s_channel, _layout) + return cls( + tensor_impl, + block_size, + original_shape, + quant_min, + quant_max, + zero_point_domain, + dtype=input_float.dtype, + ) + + ###################################################### # Layout and TensorImpl Subclass Registration # ###################################################### register_layout = AffineQuantizedTensor.register_layout get_tensor_impl_constructor = AffineQuantizedTensor.get_tensor_impl_constructor -##################################################### -# torch functional and aten operator implementation # -##################################################### - to_affine_quantized_intx = AffineQuantizedTensor.from_hp_to_intx to_affine_quantized_intx_static = AffineQuantizedTensor.from_hp_to_intx_static to_affine_quantized_floatx = AffineQuantizedTensor.from_hp_to_floatx to_affine_quantized_floatx_static = AffineQuantizedTensor.from_hp_to_floatx_static # experimental will be merged in to floatx to_affine_quantized_fpx = AffineQuantizedTensor.from_hp_to_fpx +to_marlinqqq_quantized_intx = MarlinQQQTensor.from_hp_to_intx if TORCH_VERSION_AT_LEAST_2_5: # Allow a model with AffineQuantizedTensor weights to be loaded with `weights_only=True` diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index fe7644d922..c4c1e0ca37 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -23,6 +23,10 @@ _linear_int8_act_int8_weight_block_sparse_check, _linear_int8_act_int8_weight_block_sparse_impl, ) +from torchao.dtypes.uintx.marlin_qqq_layout import ( + _linear_int8_act_int4_weight_marlin_qqq_check, + _linear_int8_act_int4_weight_marlin_qqq_impl, +) from torchao.dtypes.uintx.marlin_sparse_layout import ( _linear_fp_act_int4_weight_sparse_marlin_check, _linear_fp_act_int4_weight_sparse_marlin_impl, @@ -129,6 +133,10 @@ def _register_aqt_quantized_linear_dispatches(): _linear_fp_act_int4_weight_sparse_marlin_check, _linear_fp_act_int4_weight_sparse_marlin_impl, ), + ( + _linear_int8_act_int4_weight_marlin_qqq_check, + _linear_int8_act_int4_weight_marlin_qqq_impl, + ), ]: register_aqt_quantized_linear_dispatch(dispatch_condition, impl) diff --git a/torchao/dtypes/floatx/__init__.py b/torchao/dtypes/floatx/__init__.py index 6e22186d7f..3f0a1ccd5c 100644 --- a/torchao/dtypes/floatx/__init__.py +++ b/torchao/dtypes/floatx/__init__.py @@ -1,7 +1,5 @@ -from .float8_layout import Float8AQTTensorImpl, Float8Layout +from .float8_layout import Float8Layout from .floatx_tensor_core_layout import ( - _SPLIT_K_MAP, - FloatxTensorCoreAQTTensorImpl, FloatxTensorCoreLayout, from_scaled_tc_floatx, to_scaled_tc_floatx, @@ -9,10 +7,7 @@ __all__ = [ "FloatxTensorCoreLayout", - "FloatxTensorCoreAQTTensorImpl", "to_scaled_tc_floatx", "from_scaled_tc_floatx", - "_SPLIT_K_MAP", - "Float8AQTTensorImpl", "Float8Layout", ] diff --git a/torchao/dtypes/uintx/__init__.py b/torchao/dtypes/uintx/__init__.py index 1d0d22c0d4..a6059f93a3 100644 --- a/torchao/dtypes/uintx/__init__.py +++ b/torchao/dtypes/uintx/__init__.py @@ -1,41 +1,27 @@ from .block_sparse_layout import ( BlockSparseLayout, ) +from .marlin_qqq_layout import ( + MarlinQQQLayout, +) from .marlin_sparse_layout import ( MarlinSparseLayout, ) -from .plain_layout import ( - PlainAQTTensorImpl, -) from .semi_sparse_layout import ( SemiSparseLayout, ) from .tensor_core_tiled_layout import ( TensorCoreTiledLayout, ) -from .uint4_layout import ( - UInt4Tensor, -) from .uintx_layout import ( - _BIT_WIDTH_TO_DTYPE, - _DTYPE_TO_BIT_WIDTH, - UintxAQTTensorImpl, UintxLayout, - UintxTensor, - to_uintx, ) __all__ = [ - "UintxTensor", "UintxLayout", - "UintxAQTTensorImpl", - "to_uintx", - "_DTYPE_TO_BIT_WIDTH", - "_BIT_WIDTH_TO_DTYPE", - "UInt4Tensor", - "PlainAQTTensorImpl", "BlockSparseLayout", "MarlinSparseLayout", "SemiSparseLayout", "TensorCoreTiledLayout", + "MarlinQQQLayout", ] diff --git a/torchao/dtypes/uintx/marlin_qqq_layout.py b/torchao/dtypes/uintx/marlin_qqq_layout.py new file mode 100644 index 0000000000..c3b2a78394 --- /dev/null +++ b/torchao/dtypes/uintx/marlin_qqq_layout.py @@ -0,0 +1,281 @@ +import logging +from dataclasses import dataclass + +import torch +from torch.utils._python_dispatch import ( + return_and_correct_aliasing, +) + +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + register_layout, +) +from torchao.dtypes.uintx.plain_layout import ( + _aqt_is_int8_reduced_range, +) +from torchao.dtypes.utils import AQTTensorImpl, Layout + +logger = logging.getLogger(__name__) + +aten = torch.ops.aten + + +@dataclass(frozen=True) +class MarlinQQQLayout(Layout): + pass + + +@register_layout(MarlinQQQLayout) +class MarlinQQQAQTTensorImpl(AQTTensorImpl): + """ + TensorImpl storage class for sparse_qqq layout for affine quantized tensor. + + Can only be used with 4 bits quantization for now. + + Original marlin documentation and information: + https://github.com/IST-DASLab/marlin/tree/master + + Marlin qqq information: + https://github.com/HandH1998/QQQ/tree/main + https://arxiv.org/pdf/2406.09904 + + fields: + original_shape (torch.Size): the original shape of the tensor. used to unpack the tensor to the original shape + group_size (int): the group size used to pack the tensor + num_bits (int): the number of bits used to quantize the tensor + """ + + @staticmethod + def __new__( + cls, + int_data: torch.Tensor, + s_group: torch.Tensor, + s_channel: torch.Tensor, + _layout: Layout, + original_shape: torch.Size, + group_size: int, + num_bits: int, + ): + kwargs = {} + kwargs["device"] = int_data.device + kwargs["layout"] = ( + kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout + ) + kwargs["dtype"] = int_data.dtype + kwargs["requires_grad"] = False + shape = int_data.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + int_data: torch.Tensor, + s_group: torch.Tensor, + s_channel: torch.Tensor, + _layout: Layout, + original_shape: torch.Size, + group_size: int, + num_bits: int, + ): + self.int_data = int_data + self.s_group = s_group + self.s_channel = s_channel + self._layout = _layout + self.original_shape = original_shape + self.group_size = group_size + self.num_bits = num_bits + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func is aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + raise NotImplementedError( + f"MarlinQQQAQTTensorImpl dispatch: attempting to run {func}, this is not supported" + ) + + def __tensor_flatten__(self): + return ["int_data", "s_group", "s_channel"], [ + self._layout, + self.original_shape, + self.group_size, + self.num_bits, + ] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + int_data = tensor_data_dict["int_data"] + s_group = tensor_data_dict["s_group"] + s_channel = tensor_data_dict["s_channel"] + _layout, original_shape, group_size, num_bits = tensor_attributes + return cls( + int_data, s_group, s_channel, _layout, original_shape, group_size, num_bits + ) + + def get_plain(self): + from torchao.quantization.marlin_qqq import ( + unpack_from_marlin_qqq, + ) # avoid circular import + + int_data_expanded, s_group_expanded, s_channel_expanded = ( + unpack_from_marlin_qqq( + self.int_data, + self.s_group, + self.s_channel, + self.original_shape, + self.num_bits, + self.group_size, + ) + ) + int_data_expanded_t = int_data_expanded.t() + s_group_expanded_t = s_group_expanded.t() + s_channel_expanded_t = s_channel_expanded.t() + return int_data_expanded_t, s_group_expanded_t, s_channel_expanded_t + + @classmethod + def from_plain( + cls, + int_data: torch.Tensor, + s_group: torch.Tensor, + s_channel: torch.Tensor, + _layout: Layout, + ): + from torchao.quantization.marlin_qqq import ( + const, + pack_to_marlin_qqq, + ) # avoid circular import + + assert isinstance(_layout, MarlinQQQLayout) + + # Linear layers are (in_features, out_features) but the int_data that is reaching this point + # is (out_features, in_features). We need to transpose it to match the expected shape in the marlin code. + q_w = int_data.t() + s_group_t = s_group.t() + s_channel_t = s_channel.t() + + if not torch.cuda.get_device_capability()[0] >= 8: + raise ValueError( + f"Can not use Marlin QQQ int4*int8 kernel with a device of compute capability {torch.cuda.get_device_capability()}, the minimum compute capability is 8.0 for Marlin kernel." + ) + + if q_w.dtype != torch.int32: + raise ValueError("Only `torch.int32` weights are supported.") + + in_features, out_features = q_w.shape + # (thread_k, thread_n) + thread_config = [(64, 256), (128, 128), (128, 64), (64, 128)] + if not any( + [ + in_features % thread_k == 0 and out_features % thread_n == 0 + for thread_k, thread_n in thread_config + ] + ): + raise ValueError( + "Not supported `in_features`: {} and `out_features`: {}.".format( + in_features, out_features + ) + ) + + num_bits = 4 if torch.max(q_w) - torch.min(q_w) < 16 else -1 + if num_bits not in [4]: + raise ValueError(f"Only {[4]} bits are supported, got {num_bits}.") + + if s_group.numel() == 0: + group_size = -1 + else: + group_size = in_features // s_group_t.shape[0] + assert ( + group_size <= in_features + ), "Group size must be less than or equal to in_features." + + if group_size not in const.SUPPORTED_GROUP_SIZES: + raise ValueError( + f"Only {const.SUPPORTED_GROUP_SIZES} group sizes are supported, got {group_size}." + ) + + # Compress quantized weight to marlin format + marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel = pack_to_marlin_qqq( + q_w, s_group_t, s_channel_t, num_bits, group_size + ) + + return cls( + marlin_qqq_q_w, + marlin_qqq_s_group, + marlin_qqq_s_channel, + _layout, + q_w.shape, + group_size, + num_bits, + ) + + def get_layout(self) -> Layout: + return self._layout + + def _apply_fn_to_data(self, fn): + self.int_data = fn(self.int_data) + self.s_group = fn(self.s_group) + self.s_channel = fn(self.s_channel) + return self + + +def _linear_int8_act_int4_weight_marlin_qqq_check(input_tensor, weight_tensor, bias): + return ( + isinstance(input_tensor, AffineQuantizedTensor) + and _aqt_is_int8_reduced_range(input_tensor) + and input_tensor.dtype == torch.float16 + and input_tensor.tensor_impl.scale.dtype == torch.float32 + and len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) - 1 + and isinstance(weight_tensor, AffineQuantizedTensor) + and weight_tensor.tensor_impl.dtype == torch.int32 + and len(weight_tensor.shape) == 2 + and isinstance(weight_tensor._layout, MarlinQQQLayout) + ) + + +def _linear_int8_act_int4_weight_marlin_qqq_impl(input_tensor, weight_tensor, bias): + from torchao.ops import marlin_qqq_gemm + from torchao.quantization.marlin_qqq import marlin_qqq_workspace + + assert isinstance(input_tensor, AffineQuantizedTensor) + assert isinstance(weight_tensor, AffineQuantizedTensor) + + input = input_tensor.tensor_impl.int_data + input_scale = input_tensor.tensor_impl.scale + + w_int4 = weight_tensor.tensor_impl.int_data + s_group = weight_tensor.tensor_impl.s_group + s_channel = weight_tensor.tensor_impl.s_channel + original_shape = weight_tensor.tensor_impl.original_shape + + # Folds batch dimension into the first dimension + input_2d = input.view(-1, input.shape[-1]) + input_scale = input_scale.view(1, -1) + + size_m = input_2d.shape[0] + size_n = s_channel.shape[1] + size_k = input_2d.shape[1] + workspace_qqq = marlin_qqq_workspace(original_shape[1]) + + out = marlin_qqq_gemm( + input_2d, + w_int4, + input_scale, + s_channel, + s_group, + workspace_qqq, + size_m, + size_n, + size_k, + ) + + # Unfold the batch dimension + out = out.reshape(input.shape[:-1] + (s_channel.shape[1],)) + + if bias is not None: + out += bias.to(out.dtype) + return out diff --git a/torchao/ops.py b/torchao/ops.py index fa8ad7fe89..9713f68eb2 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -9,6 +9,7 @@ lib.define("unpack_tensor_core_tiled_layout(Tensor packed_w, int inner_k_tiles) -> Tensor") lib.define("dequantize_tensor_core_tiled_layout(Tensor packed_w, Tensor scales_and_zeros, int group_size, int inner_k_tiles) -> Tensor") lib.define("marlin_24_gemm(Tensor x, Tensor weight_marlin, Tensor meta, Tensor s, Tensor workspace, int bits, int size_m, int size_n, int size_k) -> Tensor") +lib.define("marlin_qqq_gemm(Tensor x, Tensor weight_marlin, Tensor s_tok, Tensor s_ch, Tensor s_group, Tensor workspace, int size_m, int size_n, int size_k) -> Tensor") def register_custom_op(name): @@ -275,3 +276,145 @@ def _( torch._check(workspace.numel() >= min_workspace_size, lambda: f"workspace.numel = {workspace.numel()} is below min_workspace_size = {min_workspace_size}") return torch.empty((x.size(0), s.size(1)), dtype=x.dtype, device=x.device) + + +def marlin_qqq_gemm( + x: Tensor, + weight_marlin: Tensor, + s_tok: Tensor, + s_ch: Tensor, + s_group: Tensor, + workspace: Tensor, + size_m: int, + size_n: int, + size_k: int, +) -> Tensor: + """ + Marlin for W4A8 mixed precision matrix multiplication. + See https://arxiv.org/pdf/2406.09904 for more details. + Reference: https://github.com/HandH1998/QQQ/tree/main + Args: + x: `torch.int8` input matrix of shape `(m, k)` in standard row-major layout. + weight_marlin: `torch.int32` weight matrix of original shape `(k, n)` in the specified format. + s_tok: `torch.float32` activation per-token quantization scales of shape `(m, 1)`. + s_ch: `torch.float32` weight per-channel quantization scales of shape `(1, n)`. + s_group: `torch.half` weight per-group quantization scales of shape `(m / groupsize, n)`, it should be empty when group_size != -1. + workspace: `torch.int32` tensor with at least `n / 128 * max_par` entries that are all zero. + size_m: number of rows in input matrix. + size_n: number of columns in weight matrix. + size_k: number of columns in input matrix. + Returns: + `torch.half` out matrix of shape `(m, n)` in standard row-major layout. + """ + return torch.ops.torchao.marlin_qqq_gemm.default( + x, weight_marlin, s_tok, s_ch, s_group, workspace, size_m, size_n, size_k + ) + + +@register_custom_op("torchao::marlin_qqq_gemm") +def _( + x: Tensor, + weight_marlin: Tensor, + s_tok: Tensor, + s_ch: Tensor, + s_group: Tensor, + workspace: Tensor, + size_m: int, + size_n: int, + size_k: int, +) -> Tensor: + TILE_SIZE = 16 + MIN_THREAD_N = 64 + MAX_PARALLELISM = 16 + PACK_FACTOR = 32 // 4 + + # Verify M + torch._check( + size_m == x.size(0), + lambda: f"Shape mismatch: x.size(0) = {x.size(0)}, size_m = {size_m}", + ) + torch._check( + size_m == s_tok.numel(), + lambda: f"Shape mismatch: s_tok.numel() = {s_tok.numel()}, size_m = {size_m}", + ) + + # Verify K + torch._check( + size_k == x.size(1), + lambda: f"Shape mismatch: x.size(1) = {x.size(1)}, size_k = {size_k}", + ) + torch._check( + size_k % TILE_SIZE == 0, + lambda: f"size_k = {size_k} is not divisible by tile_size = {TILE_SIZE}", + ) + torch._check( + (size_k // TILE_SIZE) == weight_marlin.size(0), + lambda: f"Shape mismatch: weight_marlin.size(0) = {weight_marlin.size(0)}, size_k = {size_k}, tile_size = {TILE_SIZE}", + ) + + # Verify groupsize + groupsize = -1 if s_group.numel() == 0 else size_k // s_group.size(0) + torch._check(groupsize in [-1, 128], lambda: f"Unexpected groupsize = {groupsize}") + + # Verify N + torch._check( + s_ch.numel() == size_n, + lambda: f"Shape mismatch: s_ch.numel() = {s_ch.numel()}, size_n = {size_n}", + ) + torch._check( + weight_marlin.size(1) % TILE_SIZE == 0, + lambda: f"weight_marlin.size(1) = {weight_marlin.size(1)} is not divisible by tile_size = {TILE_SIZE}", + ) + if groupsize != -1: + torch._check( + s_group.size(1) == size_n, + lambda: f"Shape mismatch: s_group.size(1) = {s_group.size(1)}, size_n = {size_n}", + ) + torch._check( + size_k % s_group.size(0) == 0, + lambda: f"size_k = {size_k} is not divisible by s_group.size(0) = {s_group.size(0)}", + ) + + actual_size_n = (weight_marlin.size(1) // TILE_SIZE) * PACK_FACTOR + torch._check( + size_n == actual_size_n, + lambda: f"Shape mismatch: size_n = {size_n}, actual_size_n = {actual_size_n}", + ) + + # Verify A device and strides + torch._check(x.is_cuda, lambda: "x is not on GPU") + torch._check(x.is_contiguous(), lambda: "x is not contiguous") + + # Verify B device and strides + torch._check(weight_marlin.is_cuda, lambda: "weight_marlin is not on GPU") + torch._check( + weight_marlin.is_contiguous(), lambda: "weight_marlin is not contiguous" + ) + + # Verify s_tok device, strides and dtype + torch._check(s_tok.is_cuda, lambda: "s_tok is not on GPU") + torch._check(s_tok.is_contiguous(), lambda: "s_tok is not contiguous") + torch._check(s_tok.dtype == torch.float32, lambda: "s_tok's dtype is not float32") + + # Verify s_ch device, strides and dtype + torch._check(s_ch.is_cuda, lambda: "s_ch is not on GPU") + torch._check(s_ch.is_contiguous(), lambda: "s_ch is not contiguous") + torch._check(s_ch.dtype == torch.float32, lambda: "s_ch's dtype is not float32") + + # Verify s_group device, strides and dtype + torch._check(s_group.is_cuda, lambda: "s_group is not on GPU") + torch._check(s_group.is_contiguous(), lambda: "s_group is not contiguous") + torch._check(s_group.dtype == torch.float16, "s_group's dtype is not float16") + + # Verify workspace size + torch._check( + size_n % MIN_THREAD_N == 0, + lambda: f"size_n = {size_n} is not divisible by min_thread_n = {MIN_THREAD_N}", + ) + min_workspace_size = (size_n // MIN_THREAD_N) * MAX_PARALLELISM + torch._check( + workspace.numel() >= min_workspace_size, + lambda: f"workspace.numel() = {workspace.numel()} is below min_workspace_size = {min_workspace_size}", + ) + + return torch.empty((size_m, size_n), dtype=torch.float16, device=x.device) diff --git a/torchao/prototype/awq/api.py b/torchao/prototype/awq/api.py index 0a26ab98d3..d0f3ebc0d6 100644 --- a/torchao/prototype/awq/api.py +++ b/torchao/prototype/awq/api.py @@ -9,7 +9,7 @@ ) from torchao.quantization import to_weight_tensor_with_linear_activation_scale_metadata from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter -from torchao.dtypes.uintx import _DTYPE_TO_BIT_WIDTH, UintxLayout +from torchao.dtypes.uintx.uintx_layout import _DTYPE_TO_BIT_WIDTH, UintxLayout from torchao.dtypes import( to_affine_quantized_intx, TensorCoreTiledLayout, diff --git a/torchao/prototype/awq/core.py b/torchao/prototype/awq/core.py index 034d73639e..89f615e9ea 100644 --- a/torchao/prototype/awq/core.py +++ b/torchao/prototype/awq/core.py @@ -5,7 +5,7 @@ import torch.nn.functional as F from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.dtypes.uintx import _DTYPE_TO_BIT_WIDTH, UintxLayout +from torchao.dtypes.uintx.uintx_layout import _DTYPE_TO_BIT_WIDTH, UintxLayout from torchao.dtypes import to_affine_quantized_intx from torchao.quantization.granularity import Granularity from torchao.quantization.quant_primitives import ( diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md index 074a916d49..90e898debd 100644 --- a/torchao/quantization/README.md +++ b/torchao/quantization/README.md @@ -309,6 +309,16 @@ Sparse-Marlin 2:4 is an optimized GPU kernel that extends the Mixed Auto-Regress More details can be found [here](../sparsity/README.md) +### Marlin QQQ + +Marlin QQQ is an optimized GPU kernel that supports W4A8 mixed precision GEMM. For more details about Marlin QQQ, please refer to [paper](https://arxiv.org/pdf/2406.09904). + +| Model | Technique | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) | +| ----------- | ----------------------- | ------------- | ----------------------- | ---------------- | --------------- | +| Llama-2-7B | Base (float16) | 112.45 | 1486.00 | 13.93 | 13.21 | +| | w4a8 | 197.45 | 653.50 | 4.79 | 3.31 | +| | w4a8-g128 | 187.62 | 640.32 | 4.82 | 3.41 | + ### UINTx Quantization We're trying to develop kernels for low bit quantization for intx quantization formats. While the current performance is not ideal, we're hoping to continue to iterate on these kernels to improve their performance. diff --git a/torchao/quantization/marlin_qqq/README.md b/torchao/quantization/marlin_qqq/README.md new file mode 100644 index 0000000000..f3bbecaa2a --- /dev/null +++ b/torchao/quantization/marlin_qqq/README.md @@ -0,0 +1,14 @@ +# Marlin QQQ + +Marlin QQQ kernel is now compatible with GPUs for sm80 and above. +Marlin QQQ kernel and Marlin kernel mainly have the following differences: +1. Marlin QQQ kernel supports W4A8 mixed precision GEMM using INT8 Tensor Core, while the original Marlin kernel supports W4A16 mixed precision GEMM using FP16 Tensor Core. +2. Because the mma instruction requires that the data types of weight and activation be consistent, type conversion is required. Marlin QQQ needs to convert INT4 weight to INT8, while Marlin needs to convert INT4 weight to FP16. +3. Similar to W8A8, Marlin QQQ needs to dequant to FP16 before writing the final result because the calculation result is accumulated in INT32, while Marlin does not need this processing. + +For more details about Marlin QQQ, please refer to [paper](https://arxiv.org/pdf/2406.09904). + +Marlin QQQ implementation adapted from the two below sources: + +* [QQQ](https://github.com/HandH1998/QQQ/tree/main) +* [vllm](https://github.com/vllm-project/vllm/tree/main) diff --git a/torchao/quantization/marlin_qqq/__init__.py b/torchao/quantization/marlin_qqq/__init__.py new file mode 100644 index 0000000000..2cf98cc378 --- /dev/null +++ b/torchao/quantization/marlin_qqq/__init__.py @@ -0,0 +1,282 @@ +from typing import Tuple + +import torch + +from torchao.quantization.granularity import ( + PerAxis, + PerGroup, +) +from torchao.quantization.marlin_qqq.utils import ( + const, + get_pack_factor, + get_qqq_scale_perms, + get_qqq_scale_reverse_perms, + get_qqq_weight_perm, + get_qqq_weight_reverse_perm, + marlin_permute_weights, + reverse_marlin_permute_weights, +) + +__all__ = [ + "marlin_qqq_workspace", + "pack_to_marlin_qqq", + "unpack_from_marlin_qqq", +] + + +def marlin_qqq_workspace( + out_features: int, + min_thread_n: int = const.MIN_THREAD_N, + max_parallel: int = const.MAX_PARALLEL, +) -> torch.Tensor: + """Creates a workspace for marlin qqq. The workspace is used to coordinate the locks + during the execution of the kernel. + + Args: + out_features (int): The number of output features. + min_thread_n (int, optional): The minimum number of threads per block. Defaults to `MARLIN_QQQ_MIN_THREAD_N`. + max_parallel (int, optional): The maximum number of parallel threads. Defaults to `MARLIN_QQQ_MAX_PARALLEL`. + Returns: + torch.Tensor: The workspace tensor fully initialized with zeros. + """ + assert ( + out_features % min_thread_n == 0 + ), f"out_features = {out_features}, min_thread_n = {min_thread_n}" + max_workspace_size = (out_features // min_thread_n) * max_parallel + return torch.zeros(max_workspace_size, dtype=torch.int, device="cuda") + + +def pack_to_marlin_qqq( + q_w: torch.Tensor, + s_group: torch.Tensor, + s_channel: torch.Tensor, + num_bits: int, + group_size: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Pack the quantized weights and scales to the marlin format. + + Args: + q_w (torch.Tensor): The quantized weight. + s_group (torch.Tensor): The per-group quantization scale. + s_channel (torch.Tensor): The per-channel quantization scale. + num_bits (int): The number of bits used for quantization. + group_size (int): The group size of quantization. + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The packed quantized weight in marlin format, the packed per-group scale in marlin format, and the packed per-channel scale in marlin format. + """ + in_features, out_features = q_w.shape + + assert num_bits == 4, "Marlin QQQ only supports 4-bit for now." + + # Reformat to marlin + marlin_qqq_q_w = _to_marlin_weights( + q_w, in_features, out_features, num_bits, group_size + ) + marlin_qqq_s_group, marlin_qqq_s_channel = _to_marlin_scales( + s_group, s_channel, in_features, out_features, num_bits, group_size + ) + + return marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel + + +def _to_marlin_weights( + q_w: torch.Tensor, + size_k: int, + size_n: int, + num_bits: int, + group_size: int, +) -> torch.Tensor: + """Converts a quantized weight tensor to the marlin format. + + Args: + q_w (torch.Tensor): The quantized weight. + size_k (int): The number of input features. + size_n (int): The number of output features. + num_bits (int): The number of bits used for quantization. + group_size (int): The group size of quantization. + Returns: + torch.Tensor: The packed quantized weight in marlin format. + """ + if group_size == -1: + group_size = size_k + granularity = PerAxis(1) if group_size == size_k else PerGroup(group_size) + # Permute + perm = get_qqq_weight_perm(num_bits, granularity) + q_w = marlin_permute_weights(q_w, size_k, size_n, perm) + + # Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + # q_w is torch.uint32 originally, but torch does not support lshift_cuda or lshift_cpu, we have to + # convert it to torch.int64 + q_w = q_w.to(torch.int64) + q_packed = torch.zeros( + (q_w.shape[0], q_w.shape[1] // pack_factor), + dtype=torch.int64, + device=q_w.device, + ) + + if group_size == size_k: + for i in range(pack_factor): + q_packed |= (q_w[:, i::pack_factor] & 0xF) << (num_bits * i) + else: + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << (num_bits * i) + + q_packed = q_packed.to(torch.int32).to(orig_device) + return q_packed + + +def _to_marlin_scales( + s_group: torch.Tensor, + s_channel: torch.Tensor, + size_k: int, + size_n: int, + num_bits: int, + group_size: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Converts the per-group scale and the per-channel scale to the format necessary for marlin. + + Args: + s_group (torch.Tensor): The per-group quantization scale. + s_channel (torch.Tensor): The per-channel quantization scale. + size_k (int): The number of input features. + size_n (int): The number of output features. + num_bits (int): The number of bits used for quantization. + group_size (int): The group size of quantization. + Returns: + Tuple[torch.Tensor, torch.Tensor]: The scale tensors in the marlin format. + """ + if group_size == -1: + group_size = size_k + scale_perm, scale_perm_single = get_qqq_scale_perms(num_bits) + if group_size < size_k: + s_group = s_group.reshape((-1, len(scale_perm)))[:, scale_perm] + s_channel = s_channel.reshape((-1, len(scale_perm_single)))[ + :, scale_perm_single + ] + s_group = s_group.reshape((-1, size_n)).contiguous() + else: + s_channel = s_channel.reshape((-1, len(scale_perm_single)))[ + :, scale_perm_single + ] + s_channel = s_channel.reshape((-1, size_n)).contiguous() + + return s_group, s_channel + + +def unpack_from_marlin_qqq( + q_w: torch.Tensor, + s_group: torch.Tensor, + s_channel: torch.Tensor, + original_shape: torch.Size, + num_bits: int, + group_size: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Unpacks the quantized weights and scales from the marlin format. + Args: + q_w (torch.Tensor): The packed quantized weights. + s_group (torch.Tensor): The per-group quantization scale. + s_channel (torch.Tensor): The per-channel quantization scale. + original_shape (torch.Size): The original shape of the weight tensor. + num_bits (int): The number of bits used for quantization. + group_size (int): The group size of quantization. + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The unpacked quantized weights and scales. + """ + in_features, out_features = original_shape + + assert num_bits == 4, "Marlin QQQ only supports 4-bit for now." + + # Unpacks the scales + unpacked_s_group, unpacked_s_channel = _from_marlin_scales( + s_group, s_channel, in_features, out_features, num_bits, group_size + ) + + # Unpacks the weights + unpacked_q_w = _from_marlin_weights( + q_w, in_features, out_features, num_bits, group_size + ) + + return unpacked_q_w, unpacked_s_group, unpacked_s_channel + + +def _from_marlin_weights( + q_w: torch.Tensor, + size_k: int, + size_n: int, + num_bits: int, + group_size: int, +) -> torch.Tensor: + """Converts a weight tensor in the marlin format to a regular format. + Args: + q_w (torch.Tensor): The packed quantized weights. + size_k (int): The number of input features. + size_n (int): The number of output features. + num_bits (int): The number of bits used for quantization. + group_size (int): The group size of quantization. + Returns: + torch.Tensor: The unpacked quantized weights. + """ + if group_size == -1: + group_size = size_k + granularity = PerAxis(1) if group_size == size_k else PerGroup(group_size) + # Permute + perm = get_qqq_weight_reverse_perm(num_bits, granularity) + + orig_device = q_w.device + + wf = ( + torch.tensor(list(range(0, 32, num_bits)), dtype=torch.int32) + .unsqueeze(0) + .to(orig_device) + ) + # unpack weight + weight = torch.bitwise_right_shift( + torch.unsqueeze(q_w, 2).expand(-1, -1, 32 // num_bits), + wf.unsqueeze(0), + ) + weight = torch.bitwise_and(weight, (2**num_bits) - 1) + weight = weight.reshape(weight.shape[0], weight.shape[1] * weight.shape[2]) + q_w_comp = reverse_marlin_permute_weights(weight, size_k, size_n, perm) + + return q_w_comp + + +def _from_marlin_scales( + s_group: torch.Tensor, + s_channel: torch.Tensor, + size_k: int, + size_n: int, + num_bits: int, + group_size: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Converts the quantization scales from the marlin format to their original format. + Args: + s_group (torch.Tensor): The per-group quantization scale in marlin format. + s_channel (torch.Tensor): The per-channel quantization scale in marlin format. + size_k (int): The number of input features. + size_n (int): The number of output features. + num_bits (int): The number of bits used for quantization. + group_size (int): The group size of quantization. + Returns: + Tuple[torch.Tensor, torch.Tensor]: The per-group quantization scale in the original format and + the per-channel quantization scale in the original format. + """ + if group_size == -1: + group_size = size_k + scale_perm, scale_perm_single = get_qqq_scale_reverse_perms(num_bits) + if group_size < size_k: + s_group = s_group.reshape((-1, len(scale_perm)))[:, scale_perm] + s_channel = s_channel.reshape((-1, len(scale_perm_single)))[ + :, scale_perm_single + ] + s_group = s_group.reshape((-1, size_n)).contiguous() + else: + s_channel = s_channel.reshape((-1, len(scale_perm_single)))[ + :, scale_perm_single + ] + s_channel = s_channel.reshape((-1, size_n)).contiguous() + + return s_group, s_channel diff --git a/torchao/quantization/marlin_qqq/utils.py b/torchao/quantization/marlin_qqq/utils.py new file mode 100644 index 0000000000..e0bcc0d853 --- /dev/null +++ b/torchao/quantization/marlin_qqq/utils.py @@ -0,0 +1,193 @@ +from dataclasses import dataclass, field +from typing import List, Tuple + +import numpy +import torch + +from torchao.quantization.granularity import ( + Granularity, + PerAxis, +) + + +@dataclass(frozen=True) +class MarlinQQQConstants: + TILE: int = 16 + MIN_THREAD_N: int = 64 + MAX_PARALLEL: int = 16 + + SUPPORTED_NUM_BITS: List[int] = field(default_factory=lambda: [4]) + SUPPORTED_GROUP_SIZES: List[int] = field(default_factory=lambda: [-1, 128]) + + +const = MarlinQQQConstants() + + +def get_pack_factor(num_bits: int) -> int: + """Compute the packing factor for a given number of bits. + + Args: + num_bits (int): Number of bits to pack. + Returns: + int: The packing factor. + """ + assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}" + return 32 // num_bits + + +def marlin_permute_weights( + q_w: torch.Tensor, + size_k: int, + size_n: int, + perm: torch.Tensor, + tile: int = const.TILE, +) -> torch.Tensor: + """Permute weights to 16x64 Marlin tiles. + + Args: + q_w (torch.Tensor): Quantized weights. + size_k (int): Number of input features. + size_n (int): Number of output features. + perm (torch.Tensor): The computed permutation tensor to be applied. + tile (int, optional): Tile size. Defaults to `TILE`. + Returns: + torch.Tensor: Weight tensor permuted to Marlin tiles. + """ + assert q_w.shape == (size_k, size_n) + assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" + assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" + + # Permute weights to 16x64 marlin tiles + q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) + q_w = q_w.permute((0, 2, 1, 3)) + q_w = q_w.reshape((size_k // tile, size_n * tile)) + + q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape) + + return q_w + + +def reverse_marlin_permute_weights( + q_w_unpacked: torch.Tensor, + size_k: int, + size_n: int, + reverse_perm: torch.Tensor, + tile: int = const.TILE, +) -> torch.Tensor: + """Reverse permute weights from 16x64 Marlin tiles. + Args: + q_w_unpacked (torch.Tensor): Unpacked quantized weights. + size_k (int): Number of input features. + size_n (int): Number of output features. + reverse_perm (torch.Tensor): The computed reverse permutation tensor to be applied. + tile (int, optional): Tile size. Defaults to `TILE`. + Returns: + torch.Tensor: Weight tensor reverse permuted from Marlin tiles. + """ + + assert (q_w_unpacked.shape[0], size_n) == ( + size_k // tile, + q_w_unpacked.shape[1] // tile, + ) + assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" + assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" + + # Reverse permute weights to original shape + q_w_comp = q_w_unpacked.reshape((-1, reverse_perm.numel()))[ + :, reverse_perm + ].reshape(q_w_unpacked.shape) + q_w_comp = q_w_comp.reshape((size_k // tile, size_n // tile, tile, tile)) + q_w_comp = q_w_comp.permute((0, 2, 1, 3)) + q_w_comp = q_w_comp.reshape((size_k, size_n)) + + return q_w_comp + + +# NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. # noqa: E501 +def get_qqq_weight_perm(num_bits: int, granularity: Granularity) -> torch.Tensor: + """Precompute permutations for the marlin weight shuffling. + + Args: + num_bits (int): Number of bits to pack. + granularity (Granularity): The weight quantization granularity. + Returns: + torch.Tensor: The weight permutation tensor. + """ + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 4 * (i % 4), + 4 * (i % 4) + 1, + 4 * (i % 4) + 2, + 4 * (i % 4) + 3, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = numpy.array(perm_list) + + if num_bits == 4: + if isinstance(granularity, PerAxis): + interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3]) + else: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + else: + raise Exception("num_bits must be 4, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def get_qqq_scale_perms(num_bits: int) -> Tuple[List[int], List[int]]: + """Precompute permutations for the marlin scale shuffling. + Args: + num_bits (int): Number of bits to pack. + Returns: + Tuple[List[int], List[int]]: Scale permutation list and + scale permutation list for a single group. + """ + if num_bits != 4: + raise Exception("num_bits must be 4, got {}".format(num_bits)) + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: List[int] = [] + for i in range(4): + scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + +def get_qqq_weight_reverse_perm( + num_bits: int, granularity: Granularity +) -> torch.Tensor: + """Reverse permutation for Marlin weight shuffling from `get_qqq_weight_perm`. + Args: + num_bits (int): Number of bits to pack. + granularity (Granularity): The weight quantization granularity. + Returns: + torch.Tensor: The reversed weight permutation tensor. + """ + perm = get_qqq_weight_perm(num_bits, granularity) + perm = perm.argsort() + + return perm + + +def get_qqq_scale_reverse_perms(num_bits: int) -> Tuple[torch.Tensor, torch.Tensor]: + """Reverse permutation for Marlin scale shuffling from `get_qqq_scale_perms`. + Args: + num_bits (int): Number of bits to pack. + Returns: + Tuple[List[int], List[int]]: The reversed scale permutation list and + the reversed scale permutation list for a single group. + """ + scale_perm, scale_perm_single = get_qqq_scale_perms(num_bits) + scale_perm = torch.tensor(scale_perm).argsort() + scale_perm_single = torch.tensor(scale_perm_single).argsort() + + return scale_perm, scale_perm_single diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 07e5a269e6..0e6ebdc7e0 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -28,15 +28,17 @@ from torchao.dtypes import ( AffineQuantizedTensor, Float8Layout, + MarlinQQQLayout, MarlinSparseLayout, PlainLayout, SemiSparseLayout, TensorCoreTiledLayout, + UintxLayout, to_affine_quantized_floatx, to_affine_quantized_floatx_static, to_affine_quantized_intx, + to_marlinqqq_quantized_intx, ) -from torchao.dtypes.uintx.uintx_layout import UintxLayout from torchao.float8.inference import Float8MMConfig from torchao.quantization.linear_activation_weight_observed_tensor import ( LinearActivationWeightObservedTensor, @@ -525,10 +527,35 @@ def _int8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor: ) +def _int8_symm_per_token_quant(x: torch.Tensor) -> torch.Tensor: + mapping_type = MappingType.SYMMETRIC + target_dtype = torch.int8 + eps = 1e-5 + quant_min = -127 + quant_max = 127 + + return to_affine_quantized_intx( + x, + mapping_type, + _get_per_token_block_size(x), + target_dtype, + eps=eps, + quant_min=quant_min, + quant_max=quant_max, + scale_dtype=torch.float32, + ) + + def apply_int8_dynamic_activation_int4_weight_quant( - weight, group_size=32, mapping_type=MappingType.SYMMETRIC + weight, + group_size=32, + layout=PlainLayout(), + mapping_type=MappingType.SYMMETRIC, + act_mapping_type=MappingType.ASYMMETRIC, ): """This is defined here instead of local function to support serialization""" + if group_size is None or group_size == -1: + group_size = weight.shape[-1] if weight.shape[-1] % group_size != 0: return weight @@ -540,17 +567,37 @@ def apply_int8_dynamic_activation_int4_weight_quant( quant_max = 7 # input settings - input_quant_func = _int8_asymm_per_token_quant + if act_mapping_type == MappingType.ASYMMETRIC: + input_quant_func = _int8_asymm_per_token_quant + elif act_mapping_type == MappingType.SYMMETRIC: + input_quant_func = _int8_symm_per_token_quant + else: + assert False, f"Unsupported activation mapping type: {act_mapping_type}" - weight = to_affine_quantized_intx( - weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps - ) + if isinstance(layout, MarlinQQQLayout): + weight = to_marlinqqq_quantized_intx( + weight, block_size, quant_min, quant_max, _layout=layout + ) + else: + weight = to_affine_quantized_intx( + weight, + mapping_type, + block_size, + target_dtype, + quant_min, + quant_max, + eps, + _layout=layout, + ) weight = to_linear_activation_quantized(weight, input_quant_func) return weight def int8_dynamic_activation_int4_weight( - group_size=32, mapping_type=MappingType.SYMMETRIC + group_size=32, + layout=PlainLayout(), + mapping_type=MappingType.SYMMETRIC, + act_mapping_type=MappingType.ASYMMETRIC, ): """Applies int8 dynamic per token asymmetric activation quantization and int4 per group weight symmetric quantization to linear This is used to produce a model for executorch backend, but currently executorch did not @@ -559,11 +606,16 @@ def int8_dynamic_activation_int4_weight( Args: `group_size`: parameter for quantization, controls the granularity of quantization, smaller size is more fine grained + `layout`: layout type for quantized weight tensor, only supports `PlainLayout()` and `MarlinQQQLayout()` for now + `mapping_type`: quantization type for weight, controls the weight quantization is symmetric or asymmetric + `act_mapping_type`: quantization type for activation, controls the activation quantization is symmetric or asymmetric """ return _get_linear_subclass_inserter( apply_int8_dynamic_activation_int4_weight_quant, group_size=group_size, + layout=layout, mapping_type=mapping_type, + act_mapping_type=act_mapping_type, ) diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 665b3f7464..37aa609b9b 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -33,6 +33,8 @@ "fake_quantize_affine", "fake_quantize_affine_cachemask", "choose_qparams_and_quantize_affine_hqq", + "choose_qparams_and_quantize_affine_qqq", + "dequantize_affine_qqq", "MappingType", "ZeroPointDomain", "TorchAODType", @@ -916,6 +918,119 @@ def _choose_qparams_affine( return scale.to(dtype=scale_dtype), zero_point +def choose_qparams_and_quantize_affine_qqq( + w: torch.Tensor, + num_bits: int, + group_size: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + assert num_bits == 4, f"Unsupported num_bits = {num_bits}" + size_n, size_k = w.shape + assert group_size in [-1, 128, size_k], f"Unsupported groupsize = {group_size}" + orig_device = w.device + if group_size == -1: + group_size = size_k + + if group_size < size_k: + # Reshape to [-1, group_size] + w = w.reshape((-1, group_size)) + + max_q_val = 2**num_bits - 1 + half_q_val = (max_q_val + 1) // 2 + + # Compute scale for each group + s_group = torch.amax(torch.abs(w), -1, keepdim=True) + s_group *= 2 / max_q_val # 2 => symmetric + + # Quantize + q_w = torch.round(w / s_group).int() + q_w += half_q_val + q_w = torch.clamp(q_w, 0, max_q_val) + # Compute ref (dequantized) + w_ref = (q_w - half_q_val).half() * s_group + + # Restore original shapes + def reshape_w(w): + w = w.reshape((size_n, size_k)).contiguous() + return w + + q_w = reshape_w(q_w) + w_ref = reshape_w(w_ref) + + # Compute int8 quantization scale for each channel + s_channel = torch.amax(torch.abs(w_ref), -1, keepdim=True) + s_channel /= 127.0 + t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8) + w_ref = t_int8.half() * s_channel + s_channel = s_channel.reshape(-1, 1).to(dtype=torch.float) + + # Fuse scales + s_group = (s_group.reshape(size_n, -1).contiguous() / s_channel).to( + dtype=torch.half + ) + else: + max_q_val = 2 ** (num_bits - 1) - 1 + + # Compute scale for each channel + s_channel = torch.amax(torch.abs(w), -1, keepdim=True) + s_channel /= max_q_val + + # Quantize + q_w = torch.round(w / s_channel).int() + q_w = torch.clamp(q_w, -max_q_val, max_q_val) + # Compute ref (dequantized) + w_ref = q_w.half() * s_channel + + s_group = torch.tensor([], dtype=torch.half, device=orig_device) + # div 2 ** (8 - self.bits)) to offset right shift in unpacking + s_channel /= 2 ** (8 - num_bits) + s_channel = s_channel.reshape(size_n, -1).contiguous().to(torch.float) + + return q_w, s_group, s_channel, w_ref + + +def dequantize_affine_qqq( + w: torch.Tensor, + s_group: torch.Tensor, + s_channel: torch.Tensor, + num_bits: int, + group_size: int, + output_dtype: Optional[torch.dtype] = None, +): + assert num_bits == 4, f"Unsupported num_bits = {num_bits}" + size_n, size_k = w.shape + assert group_size in [-1, 128, size_k], f"Unsupported groupsize = {group_size}" + if group_size == -1: + group_size = size_k + + if group_size < size_k: + # Reshape to [-1, group_size] + w = w.reshape((-1, group_size)) + + max_q_val = 2**num_bits - 1 + half_q_val = (max_q_val + 1) // 2 + + s_group = s_group * s_channel.half() + w_dq = (w - half_q_val).half() * s_group.reshape(-1, 1) + + # Restore original shapes + def reshape_w(w): + w = w.reshape((size_n, size_k)).contiguous() + return w + + w_dq = reshape_w(w_dq) + + else: + s_channel = s_channel * (2 ** (8 - num_bits)) + w_dq = w.half() * s_channel + + if output_dtype is None: + w_dq = w_dq.to(torch.float16) + else: + w_dq = w_dq.to(output_dtype) + + return w_dq + + # HQQ ############################################################################ # Shrinking operator (proximal operator for the lp norm)