diff --git a/README.md b/README.md
index 736915463f..90fb105599 100644
--- a/README.md
+++ b/README.md
@@ -68,7 +68,7 @@ swap_linear_with_semi_sparse_linear(model, {"seq.0": SemiSparseLinear})
 * [MX](torchao/prototype/mx_formats) implementing training and inference support with tensors using the [OCP MX spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) data types, which can be described as groupwise scaled float8/float6/float4/int8, with the scales being constrained to powers of two. This work is a prototype as the hardware support is not available yet.
 * [nf4](torchao/dtypes/nf4tensor.py) which was used to [implement QLoRA](https://github.com/pytorch/torchtune/blob/main/docs/source/tutorials/qlora_finetune.rst), one of the most popular finetuning algorithms, without writing custom Triton or CUDA code. Accessible talk [here](https://x.com/HamelHusain/status/1800315287574847701)
-* [fp6](torchao/prototype/fp6_llm/) for 2x faster inference over fp16 with an easy to use wrapper api `convert_fp6_llm(model)`
+* [fp6](torchao/prototype/quant_llm/) for 2x faster inference over fp16 with an easy-to-use API `quantize(model, fp6_llm_weight_only())`
 
 ## Composability
 
@@ -104,7 +104,7 @@ python setup.py install
 * [GaLore](torchao/prototype/galore/) a drop-in replacement for the Adam Optimizer that allows you to finetune llama 7b on a single 4090 card with up to 70% speedups relative to eager PyTorch
 * [DoRA](torchao/prototype/dora) a newer replacement for QLoRA with more promising convergence characteristics
 * [Fused int4/fp16 Quant Matmul](torchao/prototype/hqq) which is particularly useful for compute-bound kernels, showing 4x speedups over tinygemm for larger batch sizes such as 512
-* [gau-nernst](https://github.com/gau-nernst) fp6 kernels that are 4x faster than fp16 [torchao/prototype/fp6_llm](torchao/prototype/fp6_llm)
+* [gau-nernst](https://github.com/gau-nernst) fp6 kernels that are 4x faster than fp16 [torchao/prototype/quant_llm](torchao/prototype/quant_llm)
 * [vayuda](https://github.com/vayuda) with generic bitpacking kernels that were code generated using pure PyTorch [prototype/common](torchao/prototype/common)
 * [andreaskopf](https://github.com/andreaskoepf) and [melvinebenezer](https://github.com/melvinebenezer) with [1 bit LLMs](torchao/prototype/dtypes) Bitnet 1.58 bitpacked into uint2 and fully code-generated with torch.compile
diff --git a/benchmarks/benchmark_fp6_llm.py b/benchmarks/benchmark_fp6_llm.py
index ae17764e68..b6b99c0ebe 100644
--- a/benchmarks/benchmark_fp6_llm.py
+++ b/benchmarks/benchmark_fp6_llm.py
@@ -1,25 +1,24 @@
 import torch
-from torch import nn
-from torchao.prototype.fp6_llm.fp6_llm import Fp6LlmLinear, from_tc_float6_e3m2
-from torch.utils.benchmark import Timer
 import pandas as pd
+import torch.nn.functional as F
+from torchao.prototype.quant_llm import QuantLlmLinearWeight
+from torchao.utils import benchmark_torch_function_in_microseconds
 from tqdm import tqdm
 
 
 def benchmark(m: int, k: int, n: int):
-    fp6_weight = torch.randint(256, size=(n, k * 3 // 4), dtype=torch.uint8, device="cuda")
-    scales = torch.rand(n, dtype=torch.half, device="cuda") + 0.5
-    fp6_linear = Fp6LlmLinear(fp6_weight, scales)
+    fp6_data = torch.randint(256, size=(n, k * 3 // 4), dtype=torch.uint8, device="cuda")
+    scale = torch.rand(n, dtype=torch.half, device="cuda") + 0.5
+    fp6_weight = QuantLlmLinearWeight(fp6_data, scale, 3, 2)
 
-    fp16_linear = nn.Linear(k, n, bias=True, dtype=torch.half, device="cuda")
-    fp16_linear.weight.data =
from_tc_float6_e3m2(fp6_weight, dtype=torch.half) * scales[:, None] + fp16_weight = fp6_weight.dequantize(torch.half) fp16_act = torch.randn(m, k, dtype=torch.half, device="cuda") - fp6_output = fp6_linear(fp16_act) - fp16_output = fp16_linear(fp16_act) + fp6_output = F.linear(fp16_act, fp6_weight) + fp16_output = F.linear(fp16_act, fp16_weight) - fp6_measurement = Timer(stmt="fp6_linear(fp16_act)", globals=locals()).blocked_autorange() - fp16_measurement = Timer(stmt="fp16_linear(fp16_act)", globals=locals()).blocked_autorange() + fp6_time = benchmark_torch_function_in_microseconds(F.linear, fp16_act, fp6_weight) + fp16_time = benchmark_torch_function_in_microseconds(F.linear, fp16_act, fp16_weight) # follow https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/tests/python/kernel_test.py # doesn't seem to be the right way to check for correctness @@ -29,9 +28,9 @@ def benchmark(m: int, k: int, n: int): "m": m, "k": k, "n": n, - "fp6_latency (ms)": fp6_measurement.median * 1000, - "fp16_latency (ms)": fp16_measurement.median * 1000, - "speedup (d/s)": fp16_measurement.median / fp6_measurement.median, + "fp6_latency (ms)": fp6_time, + "fp16_latency (ms)": fp16_time, + "speedup (d/s)": fp16_time / fp6_time, "correct": correct, } diff --git a/test/prototype/test_fp6_llm.py b/test/prototype/test_fp6_llm.py deleted file mode 100644 index 9ee3faae4a..0000000000 --- a/test/prototype/test_fp6_llm.py +++ /dev/null @@ -1,106 +0,0 @@ -import pytest -import torch -from torch import nn -from torch.testing._internal.common_utils import ( - TestCase, - instantiate_parametrized_tests, - parametrize, - run_tests, -) -from torchao.prototype.fp6_llm.fp6_llm import ( - to_tc_float6_e3m2, - from_tc_float6_e3m2, - _to_tc_float6_e3m2_ref, - Fp6LlmLinear, - convert_fp6_llm, -) -from torchao.prototype.mx_formats.custom_cast import f6_e3m2_unpacked_to_f32, f32_to_f6_e3m2_unpacked - - -_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) - - -class TestFp6LlmLinear(TestCase): - @parametrize("device", _DEVICES) - def test_to_tc_float6_e3m2_correctness(self, device): - x = torch.randn(256, 64, device=device) - - expected = _to_tc_float6_e3m2_ref(x) - actual = to_tc_float6_e3m2(x) - torch.testing.assert_close(actual, expected) - - @parametrize("device", _DEVICES) - def test_to_tc_float6_e3m2_compile(self, device): - x = torch.randn(256, 64, device=device) - - expected = to_tc_float6_e3m2(x) - actual = torch.compile(to_tc_float6_e3m2, fullgraph=True)(x) - torch.testing.assert_close(actual, expected) - - @parametrize("device", _DEVICES) - def test_from_tc_float6_e3m2_correctness(self, device): - x = torch.randn(256, 64, device=device) - - # quantize and dequantize so that the values are exactly representable in FP6 - x = f6_e3m2_unpacked_to_f32(f32_to_f6_e3m2_unpacked(x)) - - actual = from_tc_float6_e3m2(to_tc_float6_e3m2(x)) - torch.testing.assert_close(actual, x) - - @parametrize("device", _DEVICES) - def test_from_tc_float6_e3m2_compile(self, device): - M, N = 256, 64 - x = torch.randint(256, size=(M, N * 3 // 4), dtype=torch.uint8, device=device) - - expected = from_tc_float6_e3m2(x) - actual = torch.compile(from_tc_float6_e3m2, fullgraph=True)(x) - torch.testing.assert_close(actual, expected) - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - @parametrize("leading_dims", [(4,), (2, 4)]) - @parametrize("bias", [False, True]) - def test_fp6_llm_linear_forward(self, bias, leading_dims): - OC, IC = 256, 64 - device = "cuda" - - linear = 
torch.nn.Linear(IC, OC, bias=bias, device=device) - fp6_linear = Fp6LlmLinear.from_float(linear) - assert (fp6_linear.bias is not None) == bias - - x = torch.randn(*leading_dims, IC, device=device, dtype=torch.half) - fp6_linear(x) - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - @parametrize("bias", [False, True]) - def test_fp6_llm_linear_compile(self, bias): - N, OC, IC = 4, 256, 64 - device = "cuda" - - linear = torch.nn.Linear(IC, OC, bias=bias, device=device) - fp6_linear = Fp6LlmLinear.from_float(linear) - - x = torch.randn(N, IC, device=device, dtype=torch.half) - expected = fp6_linear(x) - actual = torch.compile(fp6_linear, fullgraph=True)(x) - torch.testing.assert_close(actual, expected) - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") - def test_convert_fp6_llm(self): - device = "cuda" - model = nn.Sequential(nn.Linear(64, 256, bias=False), nn.Linear(256, 256)).to(device) - convert_fp6_llm(model) - - assert isinstance(model[0], Fp6LlmLinear) - assert model[0].bias is None - assert isinstance(model[1], Fp6LlmLinear) - assert model[1].bias is not None - - x = torch.randn(4, 64, device=device) - model(x) - - -instantiate_parametrized_tests(TestFp6LlmLinear) - - -if __name__ == "__main__": - run_tests() diff --git a/test/prototype/test_quant_llm.py b/test/prototype/test_quant_llm.py new file mode 100644 index 0000000000..77eac6f69d --- /dev/null +++ b/test/prototype/test_quant_llm.py @@ -0,0 +1,106 @@ +import copy + +import pytest +import torch +from torch.testing._internal.common_utils import ( + TestCase, + instantiate_parametrized_tests, + parametrize, + run_tests, +) +from torchao.prototype.quant_llm import ( + QuantLlmLinearWeight, + quant_llm_fpx_weight_only, + to_scaled_tc_fpx, + from_scaled_tc_fpx, +) +from torchao.prototype.quant_llm.quant_llm import _pack_tc_fpx, _pack_tc_fp6 +from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32 +from torchao.quantization.quant_api import quantize + + +_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) +_FPx_DTYPES = [(3, 2), (2, 2)] + + +class TestQuantLlmLinearWeight(TestCase): + @parametrize("device", _DEVICES) + def test_pack_tc_fp6_correctness(self, device): + x = torch.randint(256, size=(256, 64), dtype=torch.uint8, device=device) + + expected = _pack_tc_fpx(x, 6) + actual = _pack_tc_fp6(x) + torch.testing.assert_close(actual, expected) + + @parametrize("ebits,mbits", _FPx_DTYPES) + @parametrize("device", _DEVICES) + def test_to_scaled_tc_fpx_compile(self, ebits, mbits, device): + x = torch.randn(256, 64, device=device) + + expected = to_scaled_tc_fpx(x, ebits, mbits) + actual = torch.compile(to_scaled_tc_fpx, fullgraph=True)(x, ebits, mbits) + torch.testing.assert_close(actual, expected) + + @parametrize("ebits,mbits", _FPx_DTYPES) + @parametrize("device", _DEVICES) + def test_from_tc_fpx_correctness(self, ebits, mbits, device): + x = torch.randn(256, 64, device=device) * 100 + + # quantize and dequantize so that the values are exactly representable in FPx + x = _fpx_unpacked_to_f32(_f32_to_fpx_unpacked(x, ebits, mbits), ebits, mbits) + + tc_fpx, scale = to_scaled_tc_fpx(x, ebits, mbits) + actual = from_scaled_tc_fpx(tc_fpx, ebits, mbits, scale=scale) + torch.testing.assert_close(actual, x) + + @parametrize("ebits,mbits", _FPx_DTYPES) + @parametrize("device", _DEVICES) + def test_from_scaled_tc_fpx_compile(self, ebits, mbits, device): + M, N = 256, 64 + nbits = 1 + ebits + mbits + x = torch.randint(256, 
size=(M, N // 8 * nbits), dtype=torch.uint8, device=device) + scale = torch.randn(M, device=device) + + expected = from_scaled_tc_fpx(x, ebits, mbits, scale) + actual = torch.compile(from_scaled_tc_fpx, fullgraph=True)(x, ebits, mbits, scale) + torch.testing.assert_close(actual, expected) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @parametrize("ebits,mbits", _FPx_DTYPES) + @parametrize("leading_dims", [(4,), (2, 4)]) + @parametrize("bias", [False, True]) + def test_quant_llm_linear_weight(self, ebits, mbits, bias, leading_dims): + OC, IC = 256, 64 + device = "cuda" + + fp16_weight = torch.randn(OC, IC, device=device, dtype=torch.half) + fp16_bias = torch.randn(OC, device=device, dtype=torch.half) if bias else None + + fpx_weight = QuantLlmLinearWeight.from_float(fp16_weight, ebits, mbits) + + x = torch.randn(*leading_dims, IC, device=device, dtype=torch.half) + out = torch.nn.functional.linear(x, fpx_weight, fp16_bias) + assert out.shape == leading_dims + (OC,) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @parametrize("ebits,mbits", _FPx_DTYPES) + @parametrize("bias", [False, True]) + def test_quant_llm_quantize(self, ebits, mbits, bias): + N, OC, IC = 4, 256, 64 + device = "cuda" + + linear = torch.nn.Linear(IC, OC, bias=bias, device=device) + fpx_linear = copy.deepcopy(linear) + quantize(fpx_linear, quant_llm_fpx_weight_only(ebits, mbits)) + + x = torch.randn(N, IC, device=device, dtype=torch.half) + expected = fpx_linear(x) + actual = torch.compile(fpx_linear, fullgraph=True)(x) + torch.testing.assert_close(actual, expected) + + +instantiate_parametrized_tests(TestQuantLlmLinearWeight) + + +if __name__ == "__main__": + run_tests() diff --git a/test/test_ops.py b/test/test_ops.py index 920b32c5f2..28e7437b66 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1,60 +1,69 @@ import torch -from torch.testing._internal.common_utils import TestCase, IS_FBCODE +from torch.testing._internal.common_utils import ( + TestCase, + instantiate_parametrized_tests, + parametrize, + run_tests, +) from torch.testing._internal.optests import opcheck -import torchao -from torchao.prototype.fp6_llm.fp6_llm import from_tc_float6_e3m2 -import unittest -from parameterized import parameterized +from torchao.utils import is_fbcode +from torchao.prototype.quant_llm import from_scaled_tc_fpx import pytest +if is_fbcode(): + pytest.skip("Skipping the test in fbcode since we don't have TARGET file for kernels") + try: import torchao.ops except RuntimeError: pytest.skip("torchao.ops not available") -# torch.testing._internal.optests.generate_tests.OpCheckError: opcheck(op, ...): -# test_faketensor failed with module 'torch' has no attribute '_custom_ops' (scroll up for stack trace) -@pytest.mark.filterwarnings("ignore:create_unbacked_symint is deprecated, please use new_dynamic_size instead:UserWarning") -@unittest.skipIf(IS_FBCODE, "Skipping the test in fbcode since we don't have TARGET file for kernels") class TestOps(TestCase): - def _create_fp6_inputs(self, BS: int, OC: int, IC: int, device): - # Randomly initialize each bytes. The highest value for randint() is set the the max value of uint32_t. 
- fp6_weight = torch.randint(4294967295, (OC, IC // 16 * 3)).to(torch.int) - fp16_scale = torch.rand(OC).half() + 0.5 - fp16_activation = torch.rand(BS, IC).half() + 0.5 - return fp6_weight.to(device), fp16_scale.to(device), fp16_activation.to(device) - - @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") - def test_fp6_llm_linear(self): + def _create_fpx_inputs(self, ebits: int, mbits: int, BS: int, OC: int, IC: int, device): + # Randomly initialize each byte + nbits = 1 + ebits + mbits + fpx_weight = torch.randint(256, (OC, IC // 8 * nbits), dtype=torch.uint8) + scale = torch.rand(OC).half() + 0.5 + fp16_act = torch.rand(BS, IC).half() + 0.5 + return fpx_weight.to(device), scale.to(device), fp16_act.to(device) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @parametrize("ebits,mbits", [(3, 2), (2, 2)]) + def test_quant_llm_linear(self, ebits, mbits): BS = 2 OC = 256 IC = 256 splitK = 1 - fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC, "cuda") + fpx_weight, scale, fp16_act = self._create_fpx_inputs(ebits, mbits, BS, OC, IC, "cuda") # smoke test - torchao.ops.fp6_llm_linear(fp16_activation, fp6_weight, fp16_scale, splitK) + torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, fpx_weight, scale, splitK) # comprehensive testing test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] - opcheck(torch.ops.torchao.fp6_llm_linear, (fp16_activation, fp6_weight, fp16_scale, splitK), test_utils=test_utils) + opcheck(torch.ops.torchao.quant_llm_linear, (ebits, mbits, fp16_act, fpx_weight, scale, splitK), test_utils=test_utils) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @parametrize("BS,OC,IC,splitK", [(1, 2048, 4096, 5), (2, 8192, 8192, 6)]) + @parametrize("ebits,mbits", [(3, 2), (2, 2)]) + def test_quant_llm_linear_correctness(self, ebits, mbits, BS, OC, IC, splitK): + # adapted from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/tests/python/kernel_test_fpx.py + fpx_weight, scale, fp16_act = self._create_fpx_inputs(ebits, mbits, BS, OC, IC, "cuda") + + results_fpx = torchao.ops.quant_llm_linear(ebits, mbits, fp16_act, fpx_weight, scale, splitK) - # adapted from https://github.com/usyd-fsalab/fp6_llm/blob/main/tests/python/kernel_test.py - @parameterized.expand([(1, 2048, 4096, 5), (2, 8192, 8192, 6)]) - @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") - def test_fp6_llm_linear_correctness(self, BS, OC, IC, splitK): - fp6_weight, fp16_scale, fp16_activation = self._create_fp6_inputs(BS, OC, IC, "cuda") + fp16_weight = from_scaled_tc_fpx(fpx_weight, ebits, mbits, scale).half() + results_fp16 = fp16_act @ fp16_weight.T - results_fp6 = torchao.ops.fp6_llm_linear(fp16_activation, fp6_weight, fp16_scale, splitK) + error = (results_fpx - results_fp16).abs().mean() + gt = results_fp16.abs().mean() + relative_error = error / gt + assert relative_error < 1e-3 - fp16_weight = from_tc_float6_e3m2(fp6_weight.view(torch.uint8), dtype=torch.float16) * fp16_scale[:, None] - results_fp16 = fp16_activation @ fp16_weight.T - error = (results_fp6 - results_fp16).abs() - relative_error = error / results_fp16.abs() - assert relative_error.mean() < 1e-2 +instantiate_parametrized_tests(TestOps) if __name__ == "__main__": - unittest.main() + run_tests() diff --git a/torchao/csrc/cuda/fp6_llm/configs.h b/torchao/csrc/cuda/fp6_llm/configs.h index 0a642fc805..60f6745048 100644 --- 
a/torchao/csrc/cuda/fp6_llm/configs.h +++ b/torchao/csrc/cuda/fp6_llm/configs.h @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. // -// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/configs.h +// This file is copied from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/include/configs.h #ifndef CONFIGS_H #define CONFIGS_H @@ -63,28 +63,6 @@ struct TilingConfig { static constexpr int SMEM_SIZE_C_TILE = TILE_N * (TILE_M + PADDING_BYTES_16) * 4; // sizeof(float)=4 }; -/************************ General Config for FP6-LLM **********************/ -#define WEIGHT_FRAG1_BIT_WIDTH 2 -#define WEIGHT_FRAG2_BIT_WIDTH 4 -#define WEIGHT_BIT_WIDTH (WEIGHT_FRAG1_BIT_WIDTH+WEIGHT_FRAG2_BIT_WIDTH) // 6 -//#define QUANT_GROUP_SIZE_DIVIDED_BY_64 4 // QuantGroupSize: 4*64 = 256 -/*************************** 64*64 Weghts of A WARP *************************/ -#define WEIGHT_PER_UNIT (WARP_M*WARP_K) // 64*64 -#define SMEM_SIZE_IN_BYTES_PER_WARP_A1 (WEIGHT_PER_UNIT*WEIGHT_FRAG1_BIT_WIDTH/8) // 1024 Bytes #doubleBuffer not takedn into consideration -#define SMEM_SIZE_IN_BYTES_PER_WARP_A2 (WEIGHT_PER_UNIT*WEIGHT_FRAG2_BIT_WIDTH/8) // 2048 Bytes #doubleBuffer not takedn into consideration -#define SMEM_SIZE_A1_TILE (SMEM_SIZE_IN_BYTES_PER_WARP_A1*4*PIPELINE_LEVEL_GMEM) // #WARP=4, #Trible-Buffer for 3-level pipeline for A = 12 KB; double buffer for 2-level pipeline A= 8 KB. -#define SMEM_SIZE_A2_TILE (SMEM_SIZE_IN_BYTES_PER_WARP_A2*4*PIPELINE_LEVEL_GMEM) // #WARP=4, #Trible-Buffer for 3-level pipeline for A = 24 KB; double buffer for 2-level pipeline A= 16 KB. -/******************** Gloabl Memory Layout For QUANTIZED DATA ******************/ -#define NUM_INT4_PER_UNIT_2BIT_FRAG (WEIGHT_PER_UNIT*WEIGHT_FRAG1_BIT_WIDTH/128) // 64 -#define NUM_INT4_PER_UNIT_4BIT_FRAG (WEIGHT_PER_UNIT*WEIGHT_FRAG2_BIT_WIDTH/128) // 128 -/******************** Register Allocation For QUANTIZED DATA ******************/ -#define WEIGHT_PER_THREAD (WEIGHT_PER_UNIT/WARP_SIZE) // 128 -#define REG_PER_THREAD_2BIT_FRAG (WEIGHT_PER_THREAD/REG_BIT_WIDTH*2) // 8 -#define REG_PER_THREAD_4BIT_FRAG (WEIGHT_PER_THREAD/REG_BIT_WIDTH*4) // 16 -/******************** Register Allocation For QUANT Scales ******************/ -#define WARP_REG_QUANT_SCALE 4 // 8 rows per thread -> 8 FP16 scales -> 4 registers -#define WARP_REG_QUANT_SCALE_DISTRIBUTED 1 // T0-T3, T4-T7, ..., T28-T31 share the same scales, using shfl to get all the scales for each thread - #endif // CONFIGS_H diff --git a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu index 8db5d44303..1d44acde08 100644 --- a/torchao/csrc/cuda/fp6_llm/fp6_linear.cu +++ b/torchao/csrc/cuda/fp6_llm/fp6_linear.cu @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. 
// -// This file is adapted from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/fp6_linear.cu +// This file is adapted from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/fp6_linear.cu #include "kernel_matmul.cuh" #include "kernel_reduction.cuh" @@ -20,7 +20,7 @@ #include #include -template +template static void Kernel_Ex(cudaStream_t stream, const uint4 *Weight, const half *Scales, @@ -37,8 +37,8 @@ static void Kernel_Ex(cudaStream_t stream, printf("M: %d, N: %d, K: %d, SplitK: %d\n", M_Global, N_Global, K_Global, Split_K); printf("TILE_M: %d, TILE_K: %d, TILE_N: %d\n", TilingConfig::TILE_M, TilingConfig::TILE_K, TilingConfig::TILE_N); #endif - static size_t SHMEM_SZ = max(TilingConfig::SMEM_SIZE_B_TILE+SMEM_SIZE_A1_TILE+SMEM_SIZE_A2_TILE, TilingConfig::SMEM_SIZE_C_TILE); - cudaFuncSetAttribute(QUANT_GEMM_Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SHMEM_SZ); + static size_t SHMEM_SZ = max(TilingConfig::SMEM_SIZE_B_TILE+SMEM_SIZE_PER_TB_A_TILE, TilingConfig::SMEM_SIZE_C_TILE); + cudaFuncSetAttribute(QUANT_GEMM_Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, SHMEM_SZ); size_t dimN = (N_Global-1) / TilingConfig::TILE_N + 1; size_t dimM = M_Global * Split_K / TilingConfig::TILE_M; dim3 GridDim(dimN, dimM, 1); @@ -49,14 +49,12 @@ static void Kernel_Ex(cudaStream_t stream, GridDim.x, GridDim.y, GridDim.z, BlockDim.x, BlockDim.y, BlockDim.z, SHMEM_SZ); printf("\n"); #endif - QUANT_GEMM_Kernel<<>> + QUANT_GEMM_Kernel<<>> (Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); } -/* - * - */ -cudaError_t fp6_linear_kernel(cudaStream_t stream, +template +cudaError_t fpx_linear_kernel(cudaStream_t stream, const uint4 *Weight, const half *Scales, const half *B, @@ -82,30 +80,30 @@ cudaError_t fp6_linear_kernel(cudaStream_t stream, if (Split_K == 1) { switch (N_PowerOf2) { - case 8: Kernel_Ex, half>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; - case 16: Kernel_Ex, half>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; - case 32: Kernel_Ex, half>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; - case 64: Kernel_Ex, half>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; - case 128: Kernel_Ex, half>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; + case 8: Kernel_Ex, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; + case 16: Kernel_Ex, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; + case 32: Kernel_Ex, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; + case 64: Kernel_Ex, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; + case 128: Kernel_Ex, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; default: if (N_PowerOf2 % 128 != 0) { printf("FP6LLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2); return cudaErrorUnknown; } - Kernel_Ex, half>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; + Kernel_Ex, half, EXPONENT, MANTISSA>(stream, Weight, Scales, B, C, M_Global, N_Global, K_Global, Split_K); break; } } else { switch (N_PowerOf2) { - case 8: Kernel_Ex, float>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, 
Split_K); break; - case 16: Kernel_Ex, float>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; - case 32: Kernel_Ex, float>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; - case 64: Kernel_Ex, float>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; - case 128: Kernel_Ex, float>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; + case 8: Kernel_Ex, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; + case 16: Kernel_Ex, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; + case 32: Kernel_Ex, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; + case 64: Kernel_Ex, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; + case 128: Kernel_Ex, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; default: if (N_PowerOf2 % 128 != 0) { printf("FP6LLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2); return cudaErrorUnknown; } - Kernel_Ex, float>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; + Kernel_Ex, float, EXPONENT, MANTISSA>(stream, Weight, Scales, B, Reduction_Workspace, M_Global, N_Global, K_Global, Split_K); break; } // Reduction for SplitK dim3 GridDim((M_Global * N_Global) / REDUCTION_ELEMENT_PER_THREADBLOCK, 1, 1); @@ -118,11 +116,13 @@ cudaError_t fp6_linear_kernel(cudaStream_t stream, #include #include +#include #include namespace torchao { +// MODIFICATION NOTE: dtype of _weights is changed to uint8 /* -Computes FP6-FP16 GEMM (PyTorch interface). +Computes FPx-FP16 GEMM (PyTorch interface). [Mathmatical Formula] Standard definition of linear layer: Out = In * trans(W), where In, Out, and W are stored in row-major. @@ -130,28 +130,32 @@ After Equivalent transformation : trans(Out) = W * trans(In). Note that we [Inputs] _in_feats: tensor of shape [B, IC]; // half - _weights: int tensor of shape [OC, IC // 16 * 3]; // 3 INT32 words contains 16 FP6 weights. + _weights: int tensor of shape [OC, IC // 8 * x]; // x UINT8 words contains 8 FPx weights. _scales: tensor of shape [OC]; // half splitK: spliting the MatMul problem along K dimension for higher GPU utilization, default 1. [Outputs] _out_feats: tensor of shape [B, OC]; // half */ -torch::Tensor fp6_linear_forward_cuda(torch::Tensor _in_feats, - torch::Tensor _weights, - torch::Tensor _scales, - int64_t splitK=1) +torch::Tensor fp_eXmY_linear_forward_cuda( + int64_t EXPONENT, + int64_t MANTISSA, + torch::Tensor _in_feats, + torch::Tensor _weights, + torch::Tensor _scales, + int64_t splitK=1) { + const int64_t NBITS = 1 + EXPONENT + MANTISSA; int num_in_feats = _in_feats.size(0); int num_in_channels = _in_feats.size(1); int num_out_channels = _weights.size(0); - TORCH_CHECK(num_in_channels%64 == 0, "Expected in_features to be a multiple of 64, but received ", num_in_channels); - TORCH_CHECK((num_in_channels/16*3) == _weights.size(1)); // Making sure the K dimension is matched. 
+ TORCH_CHECK(num_in_channels % 64 == 0, "Expected in_features to be a multiple of 64, but received ", num_in_channels); + TORCH_CHECK((num_in_channels / 8 * NBITS) == _weights.size(1)); // Making sure the K dimension is matched. // int M = num_out_channels; int K = num_in_channels; int N = num_in_feats; // Input Tensors - auto weight = reinterpret_cast(_weights.data_ptr()); // weights is [OC, IC] but in FP6. + auto weight = reinterpret_cast(_weights.data_ptr()); // weights is [OC, IC] but in FP6. auto in_feats = reinterpret_cast(_in_feats.data_ptr()); auto scales = reinterpret_cast(_scales.data_ptr()); // Output Tensors @@ -162,23 +166,37 @@ torch::Tensor fp6_linear_forward_cuda(torch::Tensor _in_feats, options = torch::TensorOptions().dtype(torch::kFloat32).device(_in_feats.device()); at::Tensor _workspace = torch::empty({splitK, num_in_feats, num_out_channels}, options); auto Reduction_Workspace = reinterpret_cast(_workspace.data_ptr()); // Reduction_Workspace_Size = Split_K * M_Global * N_Global * sizeof(fp32) - - fp6_linear_kernel(0, // Using default stream here. - weight, - scales, - in_feats, - out_feats, - M, - N, - K, - Reduction_Workspace, - splitK); + + // MODIFICATION NOTE: use at::cuda::getCurrentCUDAStream() instead of default stream (0) + // this fixes problem with CUDA graphs when used with torch.compile() + auto stream = at::cuda::getCurrentCUDAStream(); + + // officially supported in Quant-LLM + if (EXPONENT == 3 && MANTISSA == 2) + fpx_linear_kernel<3, 2>(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); + else if (EXPONENT == 2 && MANTISSA == 2) + fpx_linear_kernel<2, 2>(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); + + // experimental + else if (EXPONENT == 2 && MANTISSA == 3) + fpx_linear_kernel<2, 3>(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); + else if (EXPONENT == 3 && MANTISSA == 1) + fpx_linear_kernel<3, 1>(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); + // else if (EXPONENT == 2 && MANTISSA == 1) + // fpx_linear_kernel<2, 1>(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); + // else if (EXPONENT == 3 && MANTISSA == 0) + // fpx_linear_kernel<3, 0>(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); + // else if (EXPONENT == 2 && MANTISSA == 0) + // fpx_linear_kernel<2, 0>(stream, weight, scales, in_feats, out_feats, M, N, K, Reduction_Workspace, splitK); + + else + TORCH_CHECK(false, "FP", NBITS, " E", EXPONENT, "M", MANTISSA, " is not supported."); return _out_feats; } TORCH_LIBRARY_IMPL(torchao, CUDA, m) { - m.impl("torchao::fp6_llm_linear", &fp6_linear_forward_cuda); + m.impl("torchao::quant_llm_linear", &fp_eXmY_linear_forward_cuda); } } // namespace torchao diff --git a/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh b/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh index ed11fc8517..f2c137828d 100644 --- a/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh +++ b/torchao/csrc/cuda/fp6_llm/kernel_matmul.cuh @@ -12,36 +12,59 @@ // See the License for the specific language governing permissions and // limitations under the License. 
// -// This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/kernel_matmul.cuh +// This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/include/kernel_matmul.cuh #include "configs.h" #include "utils_gmem.cuh" #include "utils_core.cuh" +/************************** Bitwidth of Weight Segments ************************/ +#define BIT_WIDTH_1 1 +#define BIT_WIDTH_2 2 +#define BIT_WIDTH_4 4 +/*************************** 64*64 Weghts of Weight Matrix *********************/ +#define WEIGHT_PER_WARP (WARP_M*WARP_K) // 64*64 = 4096 +#define SMEM_SIZE_PER_WARP_1BIT (WEIGHT_PER_WARP*BIT_WIDTH_1/8) // 512 Bytes, doubleBuffer not taken into consideration +#define SMEM_SIZE_PER_WARP_2BIT (WEIGHT_PER_WARP*BIT_WIDTH_2/8) // 1024 Bytes, doubleBuffer not taken into consideration +#define SMEM_SIZE_PER_WARP_4BIT (WEIGHT_PER_WARP*BIT_WIDTH_4/8) // 2048 Bytes, doubleBuffer not taken into consideration +#define SMEM_SIZE_PER_TB_1BIT (SMEM_SIZE_PER_WARP_1BIT*TilingConfig::BLOCK_WARPS*PIPELINE_LEVEL_GMEM) // #WARP=4; Trible-Buffer for 3-level pipeline for A = 6 KB; double buffer for 2-level pipeline A= 4 KB. +#define SMEM_SIZE_PER_TB_2BIT (SMEM_SIZE_PER_WARP_2BIT*TilingConfig::BLOCK_WARPS*PIPELINE_LEVEL_GMEM) // #WARP=4; Trible-Buffer for 3-level pipeline for A = 12 KB; double buffer for 2-level pipeline A= 8 KB. +#define SMEM_SIZE_PER_TB_4BIT (SMEM_SIZE_PER_WARP_4BIT*TilingConfig::BLOCK_WARPS*PIPELINE_LEVEL_GMEM) // #WARP=4; Trible-Buffer for 3-level pipeline for A = 24 KB; double buffer for 2-level pipeline A= 16 KB. +#define SMEM_SIZE_PER_TB_A_TILE (SMEM_SIZE_PER_TB_1BIT+SMEM_SIZE_PER_TB_2BIT+SMEM_SIZE_PER_TB_4BIT) // used in fp6_linear.cu, Kernel_Ex(). +/******************** Gloabl Memory Layout For QUANTIZED DATA *******************/ +#define NUM_INT4_PER_WARP_1BIT (WEIGHT_PER_WARP*BIT_WIDTH_1/128) // 32 +#define NUM_INT4_PER_WARP_2BIT (WEIGHT_PER_WARP*BIT_WIDTH_2/128) // 64 +#define NUM_INT4_PER_WARP_4BIT (WEIGHT_PER_WARP*BIT_WIDTH_4/128) // 128 + /* * C = A*B * A: row major with ahead-of-time layout transformation, FP6 * B: col major, FP16 * C: col major, FP16 */ - template + template __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, const half *B, OutputDataType* C, const size_t M_Global, const size_t N_Global, const size_t K_Global, - int Split_K) + int Split_K) { #ifdef DEBUG_MODE assert(K_Global%TilingConfig::TILE_K==0); assert(M_Global%TilingConfig::TILE_M==0); assert( gridDim.y == Split_K * (M_Global/TilingConfig::TILE_M)); #endif - // 2+4 weight split - const uint4* Weight1 = Weight; - const uint4* Weight2 = Weight1 + M_Global*K_Global*2/128; + // 1+2+4 weight split + constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA; + constexpr int USE_SEG_1BIT = BIT_WIDTH & 1; + constexpr int USE_SEG_2BIT = BIT_WIDTH & 2; + constexpr int USE_SEG_4BIT = BIT_WIDTH & 4; + const uint4* Weight_1bit = Weight; + const uint4* Weight_2bit = Weight_1bit + (USE_SEG_1BIT ? M_Global*K_Global*BIT_WIDTH_1/128 : 0); + const uint4* Weight_4bit = Weight_2bit + (USE_SEG_2BIT ? 
M_Global*K_Global*BIT_WIDTH_2/128 : 0); // Dynamic shared memory for FP16 A tiles, 128 Bytes aligned extern __shared__ __align__(128) half smem[]; - half (*smem_array)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = reinterpret_cast ( smem + (SMEM_SIZE_A1_TILE+SMEM_SIZE_A2_TILE)/2 ); // Dynamic shared memory for FP16 B tiles + half (*smem_array)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = reinterpret_cast ( smem + SMEM_SIZE_PER_TB_A_TILE/2 ); // Dynamic shared memory for FP16 B tiles __shared__ half QuantScales[64*TilingConfig::BLOCK_WARPS]; // static shared memory for quantization scales, 64 row per warp * 4 warps = 512 Bytes // Thread Block Mapping, considering SplitK const size_t BatchID = blockIdx.y / (M_Global/TilingConfig::TILE_M); @@ -54,38 +77,48 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, const size_t AverageNumBlock_K = NumBlock_K/Split_K; const size_t ExtraNumBlock_K = NumBlock_K - AverageNumBlock_K * Split_K; size_t NumIter = AverageNumBlock_K; - if(BatchID(smem); - uint32_t* AFrag_4BIT_SPTR = AFrag_2BIT_SPTR+SMEM_SIZE_IN_BYTES_PER_WARP_A1/4*TilingConfig::BLOCK_WARPS*PIPELINE_LEVEL_GMEM; // 8 buffers including double buffers, 12 for trible buffers + uint32_t* AFrag_1BIT_SPTR = reinterpret_cast(smem); + uint32_t* AFrag_2BIT_SPTR = AFrag_1BIT_SPTR + SMEM_SIZE_PER_TB_1BIT/4; + uint32_t* AFrag_4BIT_SPTR = AFrag_2BIT_SPTR + SMEM_SIZE_PER_TB_2BIT/4; // 8 buffers including double buffers, 12 for trible buffers // StartSPTR for each WARP - AFrag_2BIT_SPTR += warpId * SMEM_SIZE_IN_BYTES_PER_WARP_A1/4; - AFrag_4BIT_SPTR += warpId * SMEM_SIZE_IN_BYTES_PER_WARP_A2/4; + AFrag_1BIT_SPTR += warpId * SMEM_SIZE_PER_WARP_1BIT/4; + AFrag_2BIT_SPTR += warpId * SMEM_SIZE_PER_WARP_2BIT/4; + AFrag_4BIT_SPTR += warpId * SMEM_SIZE_PER_WARP_4BIT/4; // Pre-fetch of A tile for(int i=0; i(AFrag_2BIT_SPTR+i*SMEM_SIZE_IN_BYTES_PER_WARP_A1/4*4, WARP_StartGPTR_A1); - CopyFromGlobalToShared_A(AFrag_4BIT_SPTR+i*SMEM_SIZE_IN_BYTES_PER_WARP_A2/4*4, WARP_StartGPTR_A2); - WARP_StartGPTR_A1 += SMEM_SIZE_IN_BYTES_PER_WARP_A1/16; - WARP_StartGPTR_A2 += SMEM_SIZE_IN_BYTES_PER_WARP_A2/16; + if(USE_SEG_1BIT) CopyFromGlobalToShared_A(AFrag_1BIT_SPTR+i*SMEM_SIZE_PER_WARP_1BIT/4*4, WARP_StartGPTR_A_1BIT); + if(USE_SEG_2BIT) CopyFromGlobalToShared_A(AFrag_2BIT_SPTR+i*SMEM_SIZE_PER_WARP_2BIT/4*4, WARP_StartGPTR_A_2BIT); + if(USE_SEG_4BIT) CopyFromGlobalToShared_A(AFrag_4BIT_SPTR+i*SMEM_SIZE_PER_WARP_4BIT/4*4, WARP_StartGPTR_A_4BIT); + WARP_StartGPTR_A_1BIT += SMEM_SIZE_PER_WARP_1BIT/16; + WARP_StartGPTR_A_2BIT += SMEM_SIZE_PER_WARP_2BIT/16; + WARP_StartGPTR_A_4BIT += SMEM_SIZE_PER_WARP_4BIT/16; } // Global Memory Address for Matrix A (QuantScale) ///////////////////////////////////////////////////////////////////// const half* TB_StartGPTR_A_Scale = Scales + (y*TilingConfig::BLOCK_ROW_WARPS) * 64; @@ -100,10 +133,8 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight, const half* Scales, // Register Allocation for A,B, and C, Initilazed to Zeros ///////////////////////////////////////////////////////////////////// constexpr int NumRegSets_a = WARP_ROW_MMA_TENSORS; // 1 set = 4 registers, containing a 16*16 MMA block constexpr int NumRegSets_b = (TilingConfig::WARP_COL_MMA_TENSORS==1) ? 
1 : TilingConfig::WARP_COL_MMA_TENSORS/2; // 1 set = 4 registers, containing a 16*16 MMA block -#ifdef PIPELINE_LEVEL_SMEM uint32_t a [NumRegSets_a * PIPELINE_LEVEL_SMEM][4]; // double/Trible buffer is used // Registers to store decompressed FP6 uint32_t b [NumRegSets_b * PIPELINE_LEVEL_SMEM][4]; // double/Triple buffer is used // Register to store FP16 B matrix (a slice) -#endif float c[NumRegSets_a * NumRegSets_b][REG_PER_THREAD_C_TENSOR_16_16]; for(int i=0; i(a, b, AFrag_2BIT_SPTR, AFrag_4BIT_SPTR, smem_array, Scales_RPTR); -#endif + initialize_mma_slice(a, b, AFrag_1BIT_SPTR, AFrag_2BIT_SPTR, AFrag_4BIT_SPTR, smem_array, Scales_RPTR); // The outer loop. ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// #pragma unroll(1) for (size_t tile_id_k = 0; tile_id_k < NumIter; tile_id_k++) { // Trible-Buffer for A Tile - uint32_t* __restrict__ read_SPTR_Frag1 = AFrag_2BIT_SPTR + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A1/4*4; // 1024 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 - uint32_t* __restrict__ read_SPTR_Frag2 = AFrag_4BIT_SPTR + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A2/4*4; // 2048 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 -#ifdef PIPELINE_LEVEL_SMEM - uint32_t* __restrict__ read2_SPTR_Frag1 = AFrag_2BIT_SPTR + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A1/4*4; - uint32_t* __restrict__ read2_SPTR_Frag2 = AFrag_4BIT_SPTR + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A2/4*4; -#endif - uint32_t* __restrict__ write_SPTR_Frag1 = AFrag_2BIT_SPTR + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A1/4*4; // 1024 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 - uint32_t* __restrict__ write_SPTR_Frag2 = AFrag_4BIT_SPTR + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_IN_BYTES_PER_WARP_A2/4*4; // 2048 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 + uint32_t* __restrict__ read_SPTR_Frag_1bit = AFrag_1BIT_SPTR + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_1BIT/4*4; // 512 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 + uint32_t* __restrict__ read_SPTR_Frag_2bit = AFrag_2BIT_SPTR + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_2BIT/4*4; // 1024 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 + uint32_t* __restrict__ read_SPTR_Frag_4bit = AFrag_4BIT_SPTR + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_4BIT/4*4; // 2048 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 + uint32_t* __restrict__ read2_SPTR_Frag_1bit = AFrag_1BIT_SPTR + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_1BIT/4*4; + uint32_t* __restrict__ read2_SPTR_Frag_2bit = AFrag_2BIT_SPTR + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_2BIT/4*4; + uint32_t* __restrict__ read2_SPTR_Frag_4bit = AFrag_4BIT_SPTR + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_4BIT/4*4; + uint32_t* __restrict__ write_SPTR_Frag_1bit = AFrag_1BIT_SPTR + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_1BIT/4*4; // 512 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 + uint32_t* __restrict__ write_SPTR_Frag_2bit = AFrag_2BIT_SPTR + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_2BIT/4*4; // 1024 (1)*4: 4 WARPs; (2)/4: int*+1 = char*+16 + uint32_t* __restrict__ write_SPTR_Frag_4bit = AFrag_4BIT_SPTR + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * SMEM_SIZE_PER_WARP_4BIT/4*4; // 2048 (1)*4: 4 WARPs; (2)/4: 
int*+1 = char*+16 // Trible-Buffer for B Tile // MODIFICATION NOTE: to support MSVC, half __restrict__ (*read_SPTR ) is changed to below. similarly for read2_SPTR and write_SPTR. half (* __restrict__ read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+0) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; -#ifdef PIPELINE_LEVEL_SMEM half (* __restrict__ read2_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+1) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; -#endif half (* __restrict__ write_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] = smem_array + ((tile_id_k+(PIPELINE_LEVEL_GMEM-1)) % PIPELINE_LEVEL_GMEM) * TilingConfig::TILE_N; // bool GlobalCopy = (tile_id_k+PIPELINE_LEVEL_GMEM-1) < NumIter; - // Copying A tile from Global to Register, Bypassing L1, using double-buffer - CopyFromGlobalToShared_A(write_SPTR_Frag1, WARP_StartGPTR_A1, GlobalCopy); - CopyFromGlobalToShared_A(write_SPTR_Frag2, WARP_StartGPTR_A2, GlobalCopy); + // Copying A tile from Global to Register, Bypassing L1, using double-buffer + if(USE_SEG_1BIT) CopyFromGlobalToShared_A(write_SPTR_Frag_1bit, WARP_StartGPTR_A_1BIT, GlobalCopy); + if(USE_SEG_2BIT) CopyFromGlobalToShared_A(write_SPTR_Frag_2bit, WARP_StartGPTR_A_2BIT, GlobalCopy); + if(USE_SEG_4BIT) CopyFromGlobalToShared_A(write_SPTR_Frag_4bit, WARP_StartGPTR_A_4BIT, GlobalCopy); // copying B tile from GlobalMemory to SharedMemory CopyFromGlobalToShared (write_SPTR, BTile_GPTR, K_Global, NumColumnToCopy, GlobalCopy); cp_async_group_commit(); - #ifdef PIPELINE_LEVEL_SMEM - core_mma_slice(c, a, b, read_SPTR_Frag1, read_SPTR_Frag2, read_SPTR, Scales_RPTR, 1); // read_SPTR_Frag1, read_SPTR_Frag2 are different for each WARP; read_SPTR is shared among WARPs - core_mma_slice(c, a, b, read_SPTR_Frag1, read_SPTR_Frag2, read_SPTR, Scales_RPTR, 2); - core_mma_slice(c, a, b, read_SPTR_Frag1, read_SPTR_Frag2, read_SPTR, Scales_RPTR, 3); + core_mma_slice(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 1); // read_SPTR_Frag_2bit, read_SPTR_Frag_4bit are different for each WARP; read_SPTR is shared among WARPs + core_mma_slice(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 2); + core_mma_slice(c, a, b, read_SPTR_Frag_1bit, read_SPTR_Frag_2bit, read_SPTR_Frag_4bit, read_SPTR, Scales_RPTR, 3); // Barriers and Synchronizations cp_async_wait_group(); __syncthreads(); - core_mma_slice(c, a, b, read2_SPTR_Frag1, read2_SPTR_Frag2, read2_SPTR, Scales_RPTR, 0); - // Updating global PTRs - WARP_StartGPTR_A1 += SMEM_SIZE_IN_BYTES_PER_WARP_A1/16; // 4KB/16=256 (1)/16: int4*+1 = char*+16 - WARP_StartGPTR_A2 += SMEM_SIZE_IN_BYTES_PER_WARP_A2/16; // 8KB/16=512 (1)/16: int4*+1 = char*+16 - BTile_GPTR += TilingConfig::TILE_K; - #else - PipelinedCoreLoop(c, read_SPTR, read_SPTR_Frag1, read_SPTR_Frag2, Scales_RPTR); // read_SPTR_Frag1, read_SPTR_Frag2 are different for each WARP; read_SPTR is shared among WARPs + core_mma_slice(c, a, b, read2_SPTR_Frag_1bit, read2_SPTR_Frag_2bit, read2_SPTR_Frag_4bit, read2_SPTR, Scales_RPTR, 0); // Updating global PTRs - WARP_StartGPTR_A1 += SMEM_SIZE_IN_BYTES_PER_WARP_A1/16; // 4KB/16=256 (1)/16: int4*+1 = char*+16 - WARP_StartGPTR_A2 += SMEM_SIZE_IN_BYTES_PER_WARP_A2/16; // 8KB/16=512 (1)/16: int4*+1 = char*+16 + WARP_StartGPTR_A_1BIT += SMEM_SIZE_PER_WARP_1BIT/16; // 2KB/16=128 (1)/16: int4*+1 = char*+16 + WARP_StartGPTR_A_2BIT += SMEM_SIZE_PER_WARP_2BIT/16; // 4KB/16=256 (1)/16: int4*+1 = char*+16 + WARP_StartGPTR_A_4BIT += 
SMEM_SIZE_PER_WARP_4BIT/16; // 8KB/16=512 (1)/16: int4*+1 = char*+16 BTile_GPTR += TilingConfig::TILE_K; - // Barriers and Synchronizations - cp_async_wait_group(); - __syncthreads(); - #endif } ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh b/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh index bafdd0b4e3..1658352ee5 100644 --- a/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh +++ b/torchao/csrc/cuda/fp6_llm/ptx_mma.cuh @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. // -// This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/ptx_mma.cuh +// This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/include/ptx_mma.cuh /*************************************************************************** * Copyright 2023 The FLash-LLM Authors. All rights reserved. @@ -39,7 +39,6 @@ // MODIFICATION NOTE: to support MSVC // - uint32_t __restrict__ Reg[][4] is changed to uint32_t (* __restrict__ Reg)[4] // - half __restrict__ (*read_SPTR) is changed to half (* __restrict__ read_SPTR) -#ifdef PIPELINE_LEVEL_SMEM template __device__ __forceinline__ void B_FromSharedToReg(uint32_t (* __restrict__ Reg)[4], half (* __restrict__ read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], @@ -75,45 +74,6 @@ __device__ __forceinline__ void B_FromSharedToReg(uint32_t (* __restrict__ Reg)[ } } } -#else -// Debug: Whether ldmatrix.trans is required??? -// B is in column-major -template -__device__ __forceinline__ void B_FromSharedToReg(uint32_t __restrict__ Reg[][4], - half __restrict__ (*read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], - int k_offset) { - #ifdef DEBUG_MODE - static_assert( (TilingConfig::WARP_COL_MMA_TENSORS==1) || (TilingConfig::WARP_COL_MMA_TENSORS%2==0) ); - #endif - - const int warpId = threadIdx.x / WARP_SIZE; - int lane_id = threadIdx.x % WARP_SIZE; - int WARP_j = warpId % TilingConfig::BLOCK_COL_WARPS; - int warp_start_col = TilingConfig::WARP_COL_MMA_TENSORS * MMA_8 * WARP_j; // each warp may start from reading warp_start_col'th column of the B tile in shared memory - #ifdef DEBUG_MODE - assert( warp_start_col==0 ); - #endif - - int col = (lane_id%8) + (lane_id/16)*8; - int row = (lane_id%16) / 8 * 8; - uint32_t smem_local_ptr = static_cast(__cvta_generic_to_shared(&read_SPTR[warp_start_col+col][k_offset + row])); - if(TilingConfig::WARP_COL_MMA_TENSORS==1) { - asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" - : "=r"(Reg[0][0]), "=r"(Reg[0][1]) - : "r"(smem_local_ptr)); - } - else { - #pragma unroll - for (int i = 0; i < TilingConfig::WARP_COL_MMA_TENSORS/2; i++) - { - asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" - : "=r"(Reg[i][0]), "=r"(Reg[i][1]), "=r"(Reg[i][2]), "=r"(Reg[i][3]) - : "r"(smem_local_ptr)); - smem_local_ptr += 16 * (WARP_K+PADDING_SHARED_MEM_FOR_B_8) * sizeof(half); - } - } -} -#endif // MODIFICATION NOTE: to support MSVC, the function signature is changed from // MMA_FP16_M16N8K16(uint32_t __restrict__ c[], uint32_t __restrict__ *a, uint32_t __restrict__ *b). 
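Note on the kernel changes above and below: the fixed 2+4 bit split of FP6 is generalized to a 1+2+4 split, where each segment is enabled when the corresponding bit of `BIT_WIDTH = 1 + EXPONENT + MANTISSA` is set, and the shared-memory A-tile buffers are sized per enabled segment. A rough Python sketch of that bookkeeping (the constants mirror the `SMEM_SIZE_PER_WARP_*` macros in `kernel_matmul.cuh`; the function name and the fixed pipeline depth are illustrative assumptions):

```python
# Rough sketch of the 1+2+4 bit-split bookkeeping (illustrative only; the real values
# are the SMEM_SIZE_PER_WARP_* / SMEM_SIZE_PER_TB_* macros in kernel_matmul.cuh).
WARP_M, WARP_K = 64, 64          # 64*64 weights handled per warp
BLOCK_WARPS = 4                  # warps per thread block
PIPELINE_LEVEL_GMEM = 2          # assumption: 2 = double buffer, 3 = triple buffer

def segment_smem_bytes(ebits: int, mbits: int) -> dict:
    bit_width = 1 + ebits + mbits        # FP6 (E3M2) -> 6 = 0b110, FP5 (E2M2) -> 5 = 0b101
    weights_per_warp = WARP_M * WARP_K   # 4096
    sizes = {}
    for seg_bits in (1, 2, 4):           # a segment is used iff its bit is set in bit_width
        if bit_width & seg_bits:
            per_warp = weights_per_warp * seg_bits // 8            # bytes, single buffer
            per_tb = per_warp * BLOCK_WARPS * PIPELINE_LEVEL_GMEM  # bytes for the whole A tile
            sizes[seg_bits] = (per_warp, per_tb)
    return sizes

print(segment_smem_bytes(3, 2))  # FP6 uses the 2-bit and 4-bit segments (1024 B and 2048 B per warp)
print(segment_smem_bytes(2, 2))  # FP5 uses the 1-bit and 4-bit segments (512 B and 2048 B per warp)
```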
diff --git a/torchao/csrc/cuda/fp6_llm/utils_core.cuh b/torchao/csrc/cuda/fp6_llm/utils_core.cuh index 07e37d85bc..7a6cd36a46 100644 --- a/torchao/csrc/cuda/fp6_llm/utils_core.cuh +++ b/torchao/csrc/cuda/fp6_llm/utils_core.cuh @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. // -// This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/utils_core.cuh +// This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/include/utils_core.cuh #ifndef UTILS_CORE_CUH #define UTILS_CORE_CUH @@ -24,7 +24,6 @@ #include "utils_parallel_dequant.cuh" -#ifdef PIPELINE_LEVEL_SMEM template __device__ __forceinline__ void CopyFromSharedToRegister_AFrag(uint32_t Reg[], uint32_t* SPTR, int slice_id) { SPTR += slice_id * (NUM_INT_PER_THREAD*WARP_SIZE); @@ -36,35 +35,50 @@ __device__ __forceinline__ void CopyFromSharedToRegister_AFrag(uint32_t Reg[], u } // MODIFICATION NOTE: to support MSVC, half __restrict__ (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] is changed to below. -template +template __device__ __forceinline__ void initialize_mma_slice(uint32_t (*a)[4], uint32_t (*b)[4], - uint32_t* __restrict__ A1_SPTR_read, - uint32_t* __restrict__ A2_SPTR_read, + uint32_t* __restrict__ A_1BIT_SPTR_read, + uint32_t* __restrict__ A_2BIT_SPTR_read, + uint32_t* __restrict__ A_4BIT_SPTR_read, half (* __restrict__ B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], uint32_t* RPTR_Scales) { + // 1+2+4 weight split + constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA; + constexpr int USE_SEG_1BIT = BIT_WIDTH & 1; + constexpr int USE_SEG_2BIT = BIT_WIDTH & 2; + constexpr int USE_SEG_4BIT = BIT_WIDTH & 4; // Writing registers // Registers to store FP6 fragments for a slice (64*16) of A matrix => 32 FP6 per thread => 6 register per thread; - uint32_t a_1[2]; // NO double buffer - uint32_t a_2[4]; // NO double buffer - CopyFromSharedToRegister_AFrag<2> (a_1, A1_SPTR_read, 0); - CopyFromSharedToRegister_AFrag<4> (a_2, A2_SPTR_read, 0); - Dequant_32FP6_4Way(a, a_1, a_2, RPTR_Scales); // SIMT Dequant: dequantizing FP6 to FP16 at register level, dequantizing a slice each time + uint32_t a_1bit[1]; // NO double buffer + uint32_t a_2bit[2]; // NO double buffer + uint32_t a_4bit[4]; // NO double buffer + if(USE_SEG_1BIT) CopyFromSharedToRegister_AFrag<1> (a_1bit, A_1BIT_SPTR_read, 0); + if(USE_SEG_2BIT) CopyFromSharedToRegister_AFrag<2> (a_2bit, A_2BIT_SPTR_read, 0); + if(USE_SEG_4BIT) CopyFromSharedToRegister_AFrag<4> (a_4bit, A_4BIT_SPTR_read, 0); + Dequant_32FP6_4Way(a, a_1bit, a_2bit, a_4bit, RPTR_Scales); // SIMT Dequant: dequantizing FPx to FP16 at register level, dequantizing a slice each time B_FromSharedToReg(b, B_SPTR_read, 0); // Loading B from shared to registers } // MODIFICATION NOTE: to support MSVC, half __restrict__ (*B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8] is changed to below. 
-template +template __device__ __forceinline__ void core_mma_slice(float c[][REG_PER_THREAD_C_TENSOR_16_16], uint32_t (*a)[4], uint32_t (*b)[4], - uint32_t* __restrict__ A1_SPTR_read, - uint32_t* __restrict__ A2_SPTR_read, + uint32_t* __restrict__ A_1bit_SPTR_read, + uint32_t* __restrict__ A_2bit_SPTR_read, + uint32_t* __restrict__ A_4bit_SPTR_read, half (* __restrict__ B_SPTR_read)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], uint32_t* RPTR_Scales, int slice_id) // writing slice[slice_id] to registers, k=0 -> slice_id=1 for prefetching { + // 1+2+4 weight split + constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA; + constexpr int USE_SEG_1BIT = BIT_WIDTH & 1; + constexpr int USE_SEG_2BIT = BIT_WIDTH & 2; + constexpr int USE_SEG_4BIT = BIT_WIDTH & 4; + #ifdef DEBUG_MODE assert((TilingConfig::WARP_COL_MMA_TENSORS==1) || (TilingConfig::WARP_COL_MMA_TENSORS%2==0)); // if WARP_COL_MMA_TENSORS == 1, B tile in registers is padded to a 16*16 MMA block #endif @@ -94,100 +108,18 @@ __device__ __forceinline__ void core_mma_slice(float c[][REG } } } - // Writing registers // Registers to store FP6 fragments for a slice (64*16) of A matrix => 32 FP6 per thread => 6 register per thread; - uint32_t a_1[2]; // NO double buffer - uint32_t a_2[4]; // NO double buffer - CopyFromSharedToRegister_AFrag<2> (a_1, A1_SPTR_read, slice_id); - CopyFromSharedToRegister_AFrag<4> (a_2, A2_SPTR_read, slice_id); - Dequant_32FP6_4Way(a_write, a_1, a_2, RPTR_Scales); // SIMT Dequant: dequantizing FP6 to FP16 at register level, dequantizing a slice each time + uint32_t a_1bit[1]; // NO double buffer + uint32_t a_2bit[2]; // NO double buffer + uint32_t a_4bit[4]; // NO double buffer + if(USE_SEG_1BIT) CopyFromSharedToRegister_AFrag<1> (a_1bit, A_1bit_SPTR_read, slice_id); + if(USE_SEG_2BIT) CopyFromSharedToRegister_AFrag<2> (a_2bit, A_2bit_SPTR_read, slice_id); + if(USE_SEG_4BIT) CopyFromSharedToRegister_AFrag<4> (a_4bit, A_4bit_SPTR_read, slice_id); + Dequant_32FP6_4Way(a_write, a_1bit, a_2bit, a_4bit, RPTR_Scales); // SIMT Dequant: dequantizing FP6 to FP16 at register level, dequantizing a slice each time B_FromSharedToReg (b_write, B_SPTR_read, slice_id); // Loading B from shared to registers } -#else -// Old version with naive pipeline design -template -__device__ __forceinline__ void CopyFromSharedToRegister_AFrag(uint32_t Reg[], uint32_t* SPTR) { - int lane_id = threadIdx.x % WARP_SIZE; - #pragma unroll - for(int i=0; i -__device__ __forceinline__ void PipelinedCoreLoop(float c[][REG_PER_THREAD_C_TENSOR_16_16], - half __restrict__ (*read_SPTR)[WARP_K+PADDING_SHARED_MEM_FOR_B_8], - uint32_t* __restrict__ read_SPTR_Frag1, - uint32_t* __restrict__ read_SPTR_Frag2, - uint32_t* RPTR_Scales) -{ - #ifdef DEBUG_MODE - assert((TilingConfig::WARP_COL_MMA_TENSORS==1) || (TilingConfig::WARP_COL_MMA_TENSORS%2==0)); // if WARP_COL_MMA_TENSORS == 1, B tile in registers is padded to a 16*16 MMA block - #endif - const int NumRegSets_a = WARP_ROW_MMA_TENSORS; // 1 set = 4 registers, containing a 16*16 MMA block - const int NumRegSets_b = (TilingConfig::WARP_COL_MMA_TENSORS==1) ? 
1 : TilingConfig::WARP_COL_MMA_TENSORS/2; // 1 set = 4 registers, containing a 16*16 MMA block - - // Reigsters to store FP32 results - uint32_t (*c_uint_ptr)[REG_PER_THREAD_C_TENSOR_16_16] = reinterpret_cast(c); - // Registers to store FP6 fragments for a slice (64*16) of A matrix => 32 FP6 per thread => 6 register per thread; - uint32_t a_1[2*2]; // double buffer is used - uint32_t a_2[4*2]; // double buffer is used - // Registers to store decompressed FP6 - uint32_t a [NumRegSets_a * 1][4]; // No double buffer - // Register to store FP16 B matrix (a slice) - uint32_t b [NumRegSets_b * 2][4]; // double buffer is used - - // Overlapped Smem and TC pipeline: pre-loading from shared to registers - CopyFromSharedToRegister_AFrag<2> (a_1, read_SPTR_Frag1); - CopyFromSharedToRegister_AFrag<4> (a_2, read_SPTR_Frag2); - B_FromSharedToReg (b, read_SPTR, 0); - - #pragma unroll - for (int k = 0; k < WARP_K_MMA_TENSORS; k++) { - uint32_t (*b_read)[4] = b; - uint32_t (*b_write)[4] = b; - uint32_t *a_1_read = a_1; - uint32_t *a_1_write = a_1; - uint32_t *a_2_read = a_2; - uint32_t *a_2_write = a_2; - if(k%2==0) { - b_write += NumRegSets_b; - a_1_write += 2; - a_2_write += 4; - } - else { - b_read += NumRegSets_b; - a_1_read += 2; - a_2_read += 4; - } - // data loading - if (k + 1 < WARP_K_MMA_TENSORS) { - // updating SPTR for fragment1 and fragment2 - read_SPTR_Frag1 += 2*WARP_SIZE; - read_SPTR_Frag2 += 4*WARP_SIZE; - CopyFromSharedToRegister_AFrag<2>(a_1_write, read_SPTR_Frag1); - CopyFromSharedToRegister_AFrag<4>(a_2_write, read_SPTR_Frag2); - B_FromSharedToReg(b_write, read_SPTR, (k+1)*MMA_16); - } - // SIMT Dequant + Tensor Core computations - Dequant_32FP6_4Way(a, a_1_read, a_2_read, RPTR_Scales); // Dequantizing FP6 to FP16 at register level, dequantizing a slice each time - #pragma unroll - for (int i = 0; i < WARP_ROW_MMA_TENSORS; i++) { - if(TilingConfig::WARP_COL_MMA_TENSORS==1) - MMA_FP16_M16N8K16( c_uint_ptr[i], a[i], b_read[0] ); - else { - #pragma unroll - for (int j = 0; j < TilingConfig::WARP_COL_MMA_TENSORS/2; j++) { - MMA_FP16_M16N8K16( c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS], a[i], b_read[j] ); - MMA_FP16_M16N8K16( c_uint_ptr[i + j * WARP_ROW_MMA_TENSORS] + 4, a[i], b_read[j] + 2 ); // c+4; b+2 - } - } - } - } -} -#endif // #ifdef PIPELINE_LEVEL_SMEM - template __device__ __forceinline__ void StoreToSharedMemoryFromRegister(float (*smem_CFrag)[TilingConfig::TILE_M + PADDING_SHARED_MEM_FOR_C_4], float c[][REG_PER_THREAD_C_TENSOR_16_16]) diff --git a/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh b/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh index 48b0f968bb..4c8c39603e 100644 --- a/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh +++ b/torchao/csrc/cuda/fp6_llm/utils_parallel_dequant.cuh @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. // -// This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/include/utils_parallel_dequant.cuh +// This file is modified from https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/include/utils_parallel_dequant.cuh // To support MSVC, all instances of u_int32_t are changed to uint32_t. #ifndef UTILS_PARALLELDEQUANT_CUH @@ -27,86 +27,90 @@ * Outputs: R1, R2 * Note: Simplified Exponent calculation is applied. 
 */
-__device__ __forceinline__ void FP6_FP16_Cast_4Way(uint32_t *R1, uint32_t *R2) {
-    *R2 = *R1 & 0x80808080;
-    *R1 = *R1 >> 2;
-    *R1 = *R1 & 0x1f1f1f1f;
-    *R2 = *R2 | *R1;
-    *R1 = *R2 & 0x9f009f00;
-    *R2 = *R2 & 0x009f009f;
-    *R2 = *R2 << 8;
-}
-
-/*
- * Input: R1
- * Outputs: R1, R2
- * Note: Simplified Exponent calculation is NOT applied.
- */
-__device__ __forceinline__ void FP6_FP16_Cast_4Way_Naive(uint32_t *R1, uint32_t *R2) {
-    //*R2 = *R1 & 0x80808080;
-    *R2 = *R1 & 0xc0c0c0c0;
-    *R1 = *R1 >> 2;
-    //*R1 = *R1 & 0x1f1f1f1f;
-    *R1 = *R1 & 0x0f0f0f0f;
-    *R2 = *R2 | *R1;
+template <int EXPONENT, int MANTISSA>
+__device__ __forceinline__ void FPx_FP16_Cast_4Way(uint32_t *In, uint32_t *Out1, uint32_t *Out2) {
+    //
+    constexpr int RIGHT_SHIFT = 5 - EXPONENT;
+    constexpr int MASK1 = 0x80000000;
+    constexpr int MASK2 = MASK1 >> EXPONENT + MANTISSA;
+    constexpr int MASK3 = MASK2 & 0x7fffffff;
+    constexpr int MASK = MASK3 | MASK3 >> 16;
     //
-    //*R1 = *R2 & 0x9f009f00;
-    //*R2 = *R2 & 0x009f009f;
-    *R1 = *R2 & 0xcf00cf00;
-    if( !(*R1 & 0x40000000) && (*R1 & 0x0c000000) ) *R1 = *R1 | 0x30000000;
-    if( !(*R1 & 0x00004000) && (*R1 & 0x00000c00) ) *R1 = *R1 | 0x00003000;
-    *R2 = *R2 & 0x00cf00cf;
-    if( !(*R2 & 0x00400000) && (*R2 & 0x000c0000) ) *R2 = *R2 | 0x00300000;
-    if( !(*R2 & 0x00000040) && (*R2 & 0x0000000c) ) *R2 = *R2 | 0x00000030;
+    *Out1 = *In & 0x80008000;
+    *Out1 |= ( (*In) & MASK ) >> RIGHT_SHIFT;
     //
-    *R2 = *R2 << 8;
-    //*R1 = 0x3c003c00;
-    //*R2 = 0x3c003c00;
+    *In = (*In) << 8;
+    *Out2 = *In & 0x80008000;
+    *Out2 |= ( (*In) & MASK ) >> RIGHT_SHIFT;
 }
 
+template <int EXPONENT, int MANTISSA>
 __device__ __forceinline__ uint32_t MultScale(uint32_t PackedFP16Pair, half Scale) {
+    constexpr int BIAS_OFFSET = (int(1) << (5-1)) - (int(1) << (EXPONENT-1));
+    constexpr int BIAS = int(1) << BIAS_OFFSET;
+    //
     half* FP16_1 = reinterpret_cast<half*>(&PackedFP16Pair);
     half* FP16_2 = FP16_1 + 1;
     uint32_t output;
     half* output_half_ptr = reinterpret_cast<half*>(&output);
-    output_half_ptr[0] = __hmul( __hmul(*FP16_1,__float2half(4096.0f)), Scale);
-    output_half_ptr[1] = __hmul( __hmul(*FP16_2,__float2half(4096.0f)), Scale);
+    output_half_ptr[0] = __hmul( __hmul(*FP16_1,__float2half(1.0f*BIAS)), Scale);
+    output_half_ptr[1] = __hmul( __hmul(*FP16_2,__float2half(1.0f*BIAS)), Scale);
     return output;
 }
 
 // MODIFICATION NOTE: to support MSVC
 // - u_int32_t __restrict__ Reg[][4] is changed to below.
-// - u_int32_t __restrict__ *read_RPTR_Frag1 is changed to below. similarly for read_RPTR_Frag2
+// - u_int32_t __restrict__ *read_RPTR_1bit is changed to below. similarly for read_RPTR_2bit and read_RPTR_4bit
+template <int EXPONENT, int MANTISSA>
 __device__ __forceinline__ void Dequant_32FP6_4Way(uint32_t (* __restrict__ Reg)[4],
-                                                    uint32_t * __restrict__ read_RPTR_Frag1,
-                                                    uint32_t * __restrict__ read_RPTR_Frag2,
+                                                    uint32_t * __restrict__ read_RPTR_1bit,
+                                                    uint32_t * __restrict__ read_RPTR_2bit,
+                                                    uint32_t * __restrict__ read_RPTR_4bit,
                                                     uint32_t * Scales) {
-    uint32_t *OutputRegs = reinterpret_cast<uint32_t*> (Reg);
-    uint32_t *Frag1_PTR = read_RPTR_Frag1;
-    uint32_t *Frag2_PTR = read_RPTR_Frag2;
-    half *Scale_RPTR = reinterpret_cast<half*>(Scales);
-    uint32_t Packed_FP6 = 0;
-    uint32_t tmp = 0;
+    // 1+2+4 weight split
+    constexpr int BIT_WIDTH = 1 + EXPONENT + MANTISSA;
+    constexpr int USE_SEG_1BIT = BIT_WIDTH & 1;
+    constexpr int USE_SEG_2BIT = BIT_WIDTH & 2;
+    constexpr int USE_SEG_4BIT = BIT_WIDTH & 4;
+    //
+    uint32_t *OutputRegs = reinterpret_cast<uint32_t*> (Reg);
+    uint32_t *Frag_PTR_1bit = read_RPTR_1bit;
+    uint32_t *Frag_PTR_2bit = read_RPTR_2bit;
+    uint32_t *Frag_PTR_4bit = read_RPTR_4bit;
+    half *Scale_RPTR = reinterpret_cast<half*>(Scales);
     // Dequantizing 32 FP6, each Loop dequantizing 4 FP6
     #pragma unroll(8)
     for(int i=0; i<8; i++) {
-        // Frag1
-        Packed_FP6 = (*Frag1_PTR) & 0xc0c0c0c0;
-        if(i%4==3) Frag1_PTR++;
-        else (*Frag1_PTR) = (*Frag1_PTR) << 2;
-        // Frag2
-        tmp = (*Frag2_PTR) & 0xf0f0f0f0;
-        tmp = tmp >> 2;
-        if(i%2==1) Frag2_PTR++;
-        else (*Frag2_PTR) = (*Frag2_PTR) << 4;
-        // Packed_FP6
-        Packed_FP6 = Packed_FP6 | tmp;
+        uint32_t Packed_FP6 = 0;
+        uint32_t tmp = 0;
+        // 1bit Frag
+        if(USE_SEG_1BIT) {
+            tmp = (*Frag_PTR_1bit) & 0x80808080;
+            Packed_FP6 |= tmp >> (BIT_WIDTH & 0);
+            if(i%8==7) Frag_PTR_1bit++;
+            else (*Frag_PTR_1bit) = (*Frag_PTR_1bit) << 1;
+        }
+        // 2bit Frag
+        if(USE_SEG_2BIT) {
+            tmp = (*Frag_PTR_2bit) & 0xc0c0c0c0;
+            Packed_FP6 |= tmp >> (BIT_WIDTH & 1);
+            if(i%4==3) Frag_PTR_2bit++;
+            else (*Frag_PTR_2bit) = (*Frag_PTR_2bit) << 2;
+        }
+        // 4bit Frag
+        if(USE_SEG_4BIT) {
+            tmp = (*Frag_PTR_4bit) & 0xf0f0f0f0;
+            Packed_FP6 |= tmp >> (BIT_WIDTH & 3);
+            if(i%2==1) Frag_PTR_4bit++;
+            else (*Frag_PTR_4bit) = (*Frag_PTR_4bit) << 4;
+        }
         //
-        FP6_FP16_Cast_4Way(&Packed_FP6, &tmp);
+        uint32_t out1, out2;
+        FPx_FP16_Cast_4Way<EXPONENT, MANTISSA>(&Packed_FP6, &out1, &out2);
         //
-        *OutputRegs = MultScale(Packed_FP6, Scale_RPTR[0] ); // Muliply FP16 scales
+        *OutputRegs = MultScale<EXPONENT, MANTISSA>(out1, Scale_RPTR[0] ); // Multiply FP16 scales
         OutputRegs += 1;
-        *OutputRegs = MultScale(tmp, Scale_RPTR[1]); // Muliply FP16 scales
+        *OutputRegs = MultScale<EXPONENT, MANTISSA>(out2, Scale_RPTR[1]); // Multiply FP16 scales
         OutputRegs += 1;
         // Updating offset for FP16 scales for every two iterations
         if(i%2==1) Scale_RPTR += 2;
diff --git a/torchao/csrc/fp6_llm.cpp b/torchao/csrc/fp6_llm.cpp
index bd787385c0..861cdbf6db 100644
--- a/torchao/csrc/fp6_llm.cpp
+++ b/torchao/csrc/fp6_llm.cpp
@@ -4,5 +4,5 @@
 TORCH_LIBRARY_FRAGMENT(torchao, m) {
   m.impl_abstract_pystub("torchao.ops");
-  m.def("fp6_llm_linear(Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor");
+  m.def("quant_llm_linear(int EXPONENT, int MANTISSA, Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor");
 }
diff --git a/torchao/ops.py b/torchao/ops.py
index 25cbfb5656..3145812a2f 100644
--- a/torchao/ops.py
+++ b/torchao/ops.py
@@ -12,34 +12,44 @@ def decorator(func):
         return decorator
 
-def fp6_llm_linear(_in_feats: Tensor, _weights: Tensor, _scales: Tensor, splitK: int = 1) -> Tensor:
+def quant_llm_linear(
+    EXPONENT: int,
+    MANTISSA: int,
+    _in_feats: Tensor,
+    _weights: Tensor,
+    _scales: Tensor,
+    splitK: int = 1,
+) -> Tensor:
     """
-    FP6-LLM
linear layer A @ W.T. See https://arxiv.org/abs/2401.14112 for more details. + Quant-LLM linear layer A @ W.T. See https://arxiv.org/abs/2401.14112 for more details. Arguments + EXPONENT: number of exponent bits + MANTISSA: number of mantissa bits _in_feats: input activations in FP16 - _weights: packed FP6 weights. See :func:prepack_fp6_weight and :func:fp16_to_fp6 + _weights: packed FPx weights _scales: scale splitK: split K Returns output of linear layer """ - return torch.ops.torchao.fp6_llm_linear.default(_in_feats, _weights, _scales, splitK) + return torch.ops.torchao.quant_llm_linear.default(EXPONENT, MANTISSA, _in_feats, _weights, _scales, splitK) -@register_custom_op("torchao::fp6_llm_linear") -def _(_in_feats, _weights, _scales, splitK = 1): +@register_custom_op("torchao::quant_llm_linear") +def _(EXPONENT, MANTISSA, _in_feats, _weights, _scales, splitK = 1): torch._check(_in_feats.dim() == 2, lambda: f"input should be a 2d tensor, got {_in_feats.dim()}D") torch._check(_in_feats.dtype is torch.float16, lambda: f"weight must be FP16, got {_in_feats.dtype}") torch._check(_weights.dim() == 2, lambda: f"weight should be a 2d tensor, got {_weights.dim()}D") - torch._check(_weights.dtype is torch.int32, lambda: f"weight must be INT32, got {_weights.dtype}") + torch._check(_weights.dtype is torch.uint8, lambda: f"weight must be UINT8, got {_weights.dtype}") torch._check(_scales.dim() == 1, lambda: f"scale should be a 2d tensor, got {_scales.dim()}D") torch._check(_scales.dtype is torch.float16, lambda: f"scale must be FP16, got {_scales.dtype}") BS, IC = _in_feats.shape OC, _ = _weights.shape - torch._check(IC / 16 * 3 == _weights.shape[1], lambda: "Dimensions mismatched") + N_BITS = 1 + EXPONENT + MANTISSA + torch._check(IC // 8 * N_BITS == _weights.shape[1], lambda: "Dimensions mismatched") torch._check(OC == _scales.shape[0], lambda: "Dimensions mismatched") return _in_feats.new_empty((BS, OC)) diff --git a/torchao/prototype/README.md b/torchao/prototype/README.md index 633099368a..65968ad3e5 100644 --- a/torchao/prototype/README.md +++ b/torchao/prototype/README.md @@ -9,7 +9,7 @@ - `galore` - fused kernels for memory-efficient pre-training / fine-tuning per the [GaLore algorithm](https://arxiv.org/abs/2403.03507) - `galore/kernels` - `triton` kernels that fuse various steps of the `GaLore` algorithm - `galore/docs` - implementation notes and discussion of issues faced in kernel design. 
-- [`fp6_llm`](fp6_llm) - FP16 x FP6 mixed matmul kernel per [FP6-LLM](https://arxiv.org/abs/2401.14112) +- [`quant_llm`](quant_llm) - FP16 x FPx mixed matmul kernel per [FP6-LLM](https://arxiv.org/abs/2401.14112) #### Roadmap diff --git a/torchao/prototype/custom_fp_utils.py b/torchao/prototype/custom_fp_utils.py index 1a3e9e34cb..3af11f1710 100644 --- a/torchao/prototype/custom_fp_utils.py +++ b/torchao/prototype/custom_fp_utils.py @@ -216,7 +216,9 @@ def _fpx_unpacked_to_f32(x: Tensor, ebits: int, mbits: int) -> Tensor: exp_biased_f32 = (denormal_exp_biased - left_shift) << MBITS_F32 # we can update this in-place since the values won't overlap - mantissa_lp_int32[mantissa_lp_int32 == mantissa_cmp] = exp_biased_f32 | mantissa_f32 + # torch.compile() may complain unsupported operand type(s) for |: 'SymInt' and 'int' + # thus we use + instead of | here + mantissa_lp_int32[mantissa_lp_int32 == mantissa_cmp] = exp_biased_f32 + mantissa_f32 result = torch.where(denormal_mask, mantissa_lp_int32, result) diff --git a/torchao/prototype/fp6_llm/README.md b/torchao/prototype/fp6_llm/README.md deleted file mode 100644 index 767785275b..0000000000 --- a/torchao/prototype/fp6_llm/README.md +++ /dev/null @@ -1,44 +0,0 @@ -# FP6-LLM - -This is a FP16 x FP6 mixed matmul kernel optimized for io bound workloads per [FP6-LLM](https://arxiv.org/abs/2401.14112). The actual CUDA kernel is located under [csrc/cuda/fp6_llm/](../../csrc/cuda/fp6_llm/). This module provides helper functions to quantize FP32 weights to FP6 and facility to convert existing models to FP6. - -## Usage - -```python -from torchao.prototype.fp6_llm import convert_fp6_llm - -model = ... -convert_fp6_llm(model) # convert model in-place, replacing nn.Linear modules with Fp6LlmLinear - -# fully compatible with torch.compile() -model.compile(mode="max-autotune", fullgraph=True) -``` - -It's also possible to pre-process the weight and call the kernel directly. - -```python -import torch -from torchao.prototype.fp6_llm import to_scaled_tc_float6_e3m2 -from torchao.ops import fp6_llm_linear - -fp32_weight = torch.randn(1024, 512).cuda() - -# pre-process the weight. this will quantize the weight to FP6 and pack it in a special -# layout for tensor cores. refer to paper for more details. 
-fp6_weight, scales = to_scaled_tc_float6_e3m2(fp32_weight) - -fp16_act = torch.randn(1, 512).cuda().half() -outputs = fp6_llm_linear(fp16_act, fp6_weight, scales) # shape (1, 1024) -``` - -## TODO - -- [ ] Compile CUDA kernel for Windows -- [ ] Merge FP5 from upstream - -## Credits - -Credits to FP6-LLM authors - -- Paper: https://arxiv.org/abs/2401.14112 -- Code: https://github.com/usyd-fsalab/fp6_llm diff --git a/torchao/prototype/fp6_llm/__init__.py b/torchao/prototype/fp6_llm/__init__.py deleted file mode 100644 index d1a46339bd..0000000000 --- a/torchao/prototype/fp6_llm/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .fp6_llm import Fp6LlmLinear, convert_fp6_llm, to_scaled_tc_float6_e3m2 diff --git a/torchao/prototype/fp6_llm/fp6_llm.py b/torchao/prototype/fp6_llm/fp6_llm.py deleted file mode 100644 index 570ea13546..0000000000 --- a/torchao/prototype/fp6_llm/fp6_llm.py +++ /dev/null @@ -1,307 +0,0 @@ -import math -from typing import List, Optional, Tuple - -import torch -from torch import nn, Tensor -from torchao.prototype.mx_formats.custom_cast import f32_to_f6_e3m2_unpacked, f6_e3m2_unpacked_to_f32 -from torchao.prototype.mx_formats.constants import F6_E3M2_MAX -from torchao.ops import fp6_llm_linear - - -def _pack_2bit(x: Tensor) -> Tensor: - return (x[..., ::4] << 6) | (x[..., 1::4] << 4) | (x[..., 2::4] << 2) | x[..., 3::4] - - -def _unpack_2bit(x: Tensor) -> Tensor: - return torch.stack([x >> 6, (x >> 4) & 0b11, (x >> 2) & 0b11, x & 0b11], dim=-1).flatten(-2) - - -def _pack_4bit(x: Tensor) -> Tensor: - return (x[..., ::2] << 4) | x[..., 1::2] - - -def _unpack_4bit(x: Tensor) -> Tensor: - return torch.stack([x >> 4, x & 0b1111], dim=-1).flatten(-2) - - -# this is a literal adaptation of FP6-LLM ahead-of-time bit-level pre-packing -# https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/utils/weight_prepacking.h -def _to_tc_float6_e3m2_ref(tensor: Tensor) -> Tensor: - assert tensor.ndim == 2 - M, N = tensor.shape - assert (M % 64 == 0) and (N % 64 == 0) - - tensor_fp6 = f32_to_f6_e3m2_unpacked(tensor.float()) - - # Pass 1 from original code - tensor_fp6 = tensor_fp6.view(M // 64, 4, 2, 8, N // 16, 2, 8) - tensor_fp6 = tensor_fp6.permute(0, 4, 1, 5, 2, 3, 6) - tensor_fp6 = tensor_fp6.reshape(-1, 32, 2) - tensor_fp6 = tensor_fp6.permute(1, 0, 2) - tensor_fp6 = tensor_fp6.flatten() - - tensor_2bit = _pack_2bit((tensor_fp6 >> 4) & 0b11) - tensor_4bit = _pack_4bit(tensor_fp6 & 0b1111) - - # Pass 2 from original code - tensor_2bit = tensor_2bit.view(32, -1, 4).permute(1, 0, 2).flip(2) - tensor_4bit = tensor_4bit.view(32, -1, 4).permute(1, 0, 2).flip(2) - - # Pass 3 from original code - # BitInterleaving_2bit - # the 1st and 3rd permutations are needed because the author unpacks/packs the values from/to uint32 - # while we still unpack/pack the values from/to uint8 - tensor_2bit = _unpack_2bit(tensor_2bit).view(-1, 16) - tensor_2bit = tensor_2bit[:, [12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3]] - tensor_2bit = tensor_2bit[:, [1, 5, 9, 13, 3, 7, 11, 15, 0, 4, 8, 12, 2, 6, 10, 14]] - tensor_2bit = tensor_2bit[:, [12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3]] - tensor_2bit = _pack_2bit(tensor_2bit).view(-1) - - # BitInterleaving_4bit - # the 1st and 3rd permutations are needed because the author unpacks/packs the values from/to uint32 - # while we still unpack/pack the values from/to uint8 - tensor_4bit = _unpack_4bit(tensor_4bit).view(-1, 8) - tensor_4bit = tensor_4bit[:, [4, 5, 6, 7, 0, 1, 2, 3]] - tensor_4bit = 
tensor_4bit[:, [1, 5, 3, 7, 0, 4, 2, 6]] - tensor_4bit = tensor_4bit[:, [4, 5, 6, 7, 0, 1, 2, 3]] - tensor_4bit = _pack_4bit(tensor_4bit).view(-1) - - return torch.cat([tensor_2bit, tensor_4bit], dim=0).view(M, -1) - - -# more optimized version of _to_tc_float6_e3m2_original() by merging ops -# https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/utils/weight_prepacking.h -def to_tc_float6_e3m2(tensor: Tensor) -> Tensor: - assert tensor.ndim == 2 - M, N = tensor.shape - assert (M % 64 == 0) and (N % 64 == 0) - - tensor_fp6 = f32_to_f6_e3m2_unpacked(tensor.float()) - tensor_fp6 = tensor_fp6.view(M // 64, 2, 2, 2, 8, N // 16, 2, 8) - tensor_fp6 = tensor_fp6.flip(3) - - tensor_2bit = (tensor_fp6 >> 4) & 0b11 - tensor_2bit = tensor_2bit.permute(0, 5, 1, 4, 7, 3, 2, 6) - tensor_2bit = _pack_2bit(tensor_2bit.flatten()) - - tensor_4bit = tensor_fp6 & 0b1111 - tensor_4bit = tensor_4bit.permute(0, 5, 1, 2, 4, 7, 3, 6) - tensor_4bit = _pack_4bit(tensor_4bit.flatten()) - - return torch.cat([tensor_2bit, tensor_4bit], dim=0).view(M, -1) - - -def to_scaled_tc_float6_e3m2(tensor: Tensor) -> Tuple[Tensor, Tensor]: - scale = F6_E3M2_MAX / tensor.abs().amax(1).clamp(min=1e-12) - tc_fp6_tensor = to_tc_float6_e3m2(tensor * scale.view(-1, 1)) - return tc_fp6_tensor, scale.reciprocal().half() - - -def from_tc_float6_e3m2(tensor: Tensor, dtype: torch.dtype = torch.float32) -> Tensor: - assert tensor.ndim == 2 and tensor.dtype == torch.uint8 - M = tensor.shape[0] - N = tensor.shape[1] // 3 * 4 - assert (M % 64 == 0) and (N % 64 == 0) - size_2bit = M * N // 4 - size_4bit = M * N // 2 - tensor = tensor.view(-1).view(torch.uint8) - assert tensor.numel() == size_2bit + size_4bit - - tensor_2bit, tensor_4bit = tensor.split([size_2bit, size_4bit]) - - tensor_2bit = _unpack_2bit(tensor_2bit) - tensor_2bit = tensor_2bit.view(M // 64, N // 16, 2, 8, 8, 2, 2, 2) - tensor_2bit = tensor_2bit.permute(0, 2, 6, 5, 3, 1, 7, 4) - - tensor_4bit = _unpack_4bit(tensor_4bit) - tensor_4bit = tensor_4bit.view(M // 64, N // 16, 2, 2, 8, 8, 2, 2) - tensor_4bit = tensor_4bit.permute(0, 2, 3, 6, 4, 1, 7, 5) - - tensor_fp6 = (tensor_2bit << 4) | tensor_4bit - tensor_fp6 = tensor_fp6.flip(3).reshape(M, N) - return f6_e3m2_unpacked_to_f32(tensor_fp6).to(dtype) - - -# https://github.com/microsoft/DeepSpeed/blob/3a3a6db3332e339cc9fd94efd4982f6d60635a3d/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py -_SPLIT_K_MAP = [ - { # tokens: [1, 64] - 3072: 18, - 4096: 13, - 5120: 10, - 6144: 9, - 8192: 6, - 10240: 5, - 14336: 7, - 28672: 7, - 57344: 7 - }, - { # tokens: [65:128] - 3072: 9, - 4096: 6, - 5120: 5, - 6144: 9, - 8192: 3, - 10240: 5, - 14336: 7, - 28672: 7, - 57344: 6 - }, - { # tokens: [129:192] - 3072: 6, - 4096: 4, - 5120: 7, - 6144: 3, - 8192: 2, - 10240: 5, - 14336: 5, - 28672: 5, - 57344: 4 - }, - { # tokens: [193:256] - 3072: 9, - 4096: 3, - 5120: 5, - 6144: 2, - 8192: 5, - 10240: 4, - 14336: 8, - 28672: 6, - 57344: 4 - }, - { # tokens: [257:320] - 3072: 7, - 4096: 5, - 5120: 2, - 6144: 5, - 8192: 4, - 10240: 1, - 14336: 3, - 28672: 3, - 57344: 4 - }, - { # tokens: [321:384] - 3072: 3, - 4096: 2, - 5120: 5, - 6144: 3, - 8192: 1, - 10240: 8, - 14336: 3, - 28672: 4, - 57344: 3 - }, - { # tokens: [385:448] - 3072: 5, - 4096: 7, - 5120: 3, - 6144: 5, - 8192: 7, - 10240: 3, - 14336: 1, - 28672: 1, - 57344: 3 - }, - { # tokens: [449:512] - 3072: 2, - 4096: 5, - 5120: 4, - 6144: 1, - 8192: 5, - 10240: 2, - 14336: 6, - 28672: 4, - 57344: 1 - }, - { # tokens: [513:576] - 3072: 2, - 
4096: 3, - 5120: 1, - 6144: 1, - 8192: 3, - 10240: 3, - 14336: 3, - 28672: 1, - 57344: 1 - }, - { # tokens: [577:640] - 3072: 5, - 4096: 4, - 5120: 1, - 6144: 4, - 8192: 2, - 10240: 1, - 14336: 1, - 28672: 1, - 57344: 1 - }, - { # tokens: [641:704] - 3072: 3, - 4096: 1, - 5120: 2, - 6144: 2, - 8192: 1, - 10240: 2, - 14336: 1, - 28672: 1, - 57344: 1 - }, - { # tokens: [705:768] - 3072: 3, - 4096: 1, - 5120: 3, - 6144: 2, - 8192: 1, - 10240: 1, - 14336: 1, - 28672: 1, - 57344: 1 - } -] - - -class Fp6LlmLinear(nn.Module): - """FP6-LLM Linear layer as described in https://arxiv.org/pdf/2401.14112. - """ - - def __init__(self, weight: Tensor, scales: Tensor, bias: Optional[Tensor] = None) -> None: - super().__init__() - self.register_buffer("weight", weight.view(torch.int32)) - self.register_buffer("scales", scales) - self.register_buffer("bias", bias) - self.out_features = weight.shape[0] - self.in_features = weight.shape[1] // 3 * 4 - - def forward(self, x: Tensor) -> Tensor: - splitK = self.get_split_k(math.prod(x.shape[:-1]), self.out_features) - out = fp6_llm_linear(x.view(-1, self.in_features).half(), self.weight, self.scales, splitK=splitK) - if self.bias is not None: - out = out + self.bias - return out.view(*x.shape[:-1], self.out_features).to(x.dtype) - - @staticmethod - def get_split_k(bsize: int, out_dim: int) -> int: - # https://github.com/microsoft/DeepSpeed/blob/3a3a6db3332e339cc9fd94efd4982f6d60635a3d/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py - return _SPLIT_K_MAP[(bsize - 1) // 64].get(out_dim, 1) if bsize <= 768 else 1 - - @classmethod - def from_float(cls, linear: nn.Linear): - assert (linear.in_features % 64 == 0) and (linear.out_features % 256 == 0) - - fp6_weight, scale = to_scaled_tc_float6_e3m2(linear.weight.detach()) - bias = linear.bias.detach().half() if linear.bias is not None else None - return cls(fp6_weight, scale, bias) - - def extra_repr(self) -> str: - return f'in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}' - - -def convert_fp6_llm(model: nn.Module, skip_fqn_list: Optional[List[str]] = None, cur_fqn: str = "") -> None: - for name, child in model.named_children(): - new_fqn = name if cur_fqn == "" else f"{cur_fqn}.{name}" - - if ((skip_fqn_list is None) or (new_fqn not in skip_fqn_list)) and (isinstance(child, nn.Linear)): - if (child.in_features % 64 == 0) and (child.out_features % 256 == 0): - new_child = Fp6LlmLinear.from_float(child) - setattr(model, name, new_child) - else: - convert_fp6_llm(child, skip_fqn_list, new_fqn) diff --git a/torchao/prototype/quant_llm/README.md b/torchao/prototype/quant_llm/README.md new file mode 100644 index 0000000000..631df30817 --- /dev/null +++ b/torchao/prototype/quant_llm/README.md @@ -0,0 +1,68 @@ +# Quant-LLM + +This is a FP16 x FPx mixed matmul kernel optimized for io bound workloads per [FP6-LLM](https://arxiv.org/abs/2401.14112). The actual CUDA kernel is located under [csrc/cuda/fp6_llm/](../../csrc/cuda/fp6_llm/). This module provides helper functions to quantize FP32/FP16/BF16 weights to FPx and integration with torchao API. + +## Usage + +```python +from torchao.quantization.quant_api import quantize +from torchao.prototype.quant_llm import fp6_llm_weight_only, quant_llm_fpx_weight_only + +model = ... 
+model.half() # not necessary, but recommended to maintain accuracy
+quantize(model, fp6_llm_weight_only()) # convert nn.Linear.weight to FP6 E3M2 in-place
+
+# for generic FPx EyMz where x = 1 + y + z
+# quantize(model, quant_llm_fpx_weight_only(2, 2)) # use FP5 E2M2 instead
+
+# fully compatible with torch.compile()
+model.compile(mode="max-autotune", fullgraph=True)
+```
+
+It's also possible to pre-process the weight and call the kernel directly.
+
+```python
+import torch
+from torchao.prototype.quant_llm import to_scaled_tc_fpx
+from torchao.ops import quant_llm_linear
+
+fp32_weight = torch.randn(1024, 512).cuda()
+ebits, mbits = 3, 2
+
+# pre-process the weight. this will quantize the weight to FP6 and pack it in a special
+# layout for tensor cores. refer to paper for more details.
+fp6_weight, scales = to_scaled_tc_fpx(fp32_weight, ebits, mbits)
+
+fp16_act = torch.randn(1, 512).cuda().half()
+outputs = quant_llm_linear(ebits, mbits, fp16_act, fp6_weight, scales) # shape (1, 1024)
+```
+
+**NOTE**:
+- Since this kernel's computation dtype is FP16, it is recommended to convert the model to FP16 (instead of BF16) before applying quantization, and to use FP16 for activations.
+- Only FP6 E3M2 and FP5 E2M2 are tested and enabled in the official repo. We additionally enable support for FP6 E2M3 and FP5 E3M1.
+- On most hardware, this kernel is faster than FP16 linear for batch sizes from 1 to 128, and slower for batch sizes of 256 or more. See https://github.com/usyd-fsalab/fp6_llm/issues/8 for a detailed discussion and https://github.com/pytorch/ao/pull/223 for some microbenchmark results.
+
+## End-to-End benchmarks
+
+Benchmarks are run on a machine with a single 4070Ti SUPER GPU using the scripts in [_models/llama](../../_models/llama). tokens/s is measured using [generate.py](../../_models/llama/generate.py), which generates text in a latency-optimized way (batchsize=1). wikitext perplexity is measured using [eval.py](../../_models/llama/eval.py), which uses [lm_eval](https://github.com/EleutherAI/lm-evaluation-harness). The model used is [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf).
+
+FPx quantization is run with `--precision float16`. The rest uses the default precision of `bfloat16`.
+ +Quantization | wikitext perplexity | tokens/s +--------------------|---------------------|---------- +INT8 | 12.21 | 87.45 +INT4-256 (tinygemm) | -- | 157.10 +FP6 E3M2 | 12.34 | 106.76 +FP6 E2M3 | 12.23 | 106.77 +FP5 E3M1 | 12.55 | 122.69 +FP5 E2M2 | 12.47 | 122.66 +FP4 E3M0 | 14.58 | 145.55 +FP4 E2M1 | 15.01 | 146.05 +FP3 E2M0 | 74625.18 | 164.49 + +## Credits + +Credits to FP6-LLM authors + +- Paper: https://arxiv.org/abs/2401.14112 +- Code: https://github.com/usyd-fsalab/fp6_llm diff --git a/torchao/prototype/quant_llm/__init__.py b/torchao/prototype/quant_llm/__init__.py new file mode 100644 index 0000000000..4f1479c401 --- /dev/null +++ b/torchao/prototype/quant_llm/__init__.py @@ -0,0 +1 @@ +from .quant_llm import QuantLlmLinearWeight, fp6_llm_weight_only, quant_llm_fpx_weight_only, to_scaled_tc_fpx, from_scaled_tc_fpx diff --git a/torchao/prototype/quant_llm/quant_llm.py b/torchao/prototype/quant_llm/quant_llm.py new file mode 100644 index 0000000000..8e4fae465d --- /dev/null +++ b/torchao/prototype/quant_llm/quant_llm.py @@ -0,0 +1,463 @@ +from functools import reduce +from typing import Tuple + +import torch +from torch import Tensor +from torch.utils._python_dispatch import return_and_correct_aliasing +from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32, _n_ones +from torchao.ops import quant_llm_linear +from torchao.dtypes.utils import _implements, _ATEN_OP_OR_TORCH_FN_TABLE + + +_ONES_TABLE = [_n_ones(i) for i in range(8)] + + +def _pack(x: Tensor, n_bits: int) -> Tensor: + return reduce(torch.bitwise_or, [x[..., i::(8 // n_bits)] << (8 - (i + 1) * n_bits) for i in range(8 // n_bits)]) + + +def _unpack(x: Tensor, n_bits: int) -> Tensor: + return torch.stack([(x >> (8 - (i + 1) * n_bits)) & ((1 << n_bits) - 1) for i in range(8 // n_bits)], dim=-1).flatten(-2) + + +# https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/utils/weight_prepacking.h#L87-L116 +def _bit_interleave(x: Tensor, n_bits: int, undo: bool = False) -> Tensor: + # the original code unpacks/packs the values from/to uint32 while we unpack/pack the values from/to uint8 + # thus, we need to reverse byte order within a uint32 word. + x = x.reshape(-1, 4).flip(1) + + x = _unpack(x, n_bits) + x = x.view(-1, 4 * (8 // n_bits)) + + if not undo: + bit_order = { + 1: [1, 5, 9, 13, 17, 21, 25, 29, 3, 7, 11, 15, 19, 23, 27, 31, + 0, 4, 8, 12, 16, 20, 24, 28, 2, 6, 10, 14, 18, 22, 26, 30], + 2: [1, 5, 9, 13, 3, 7, 11, 15, 0, 4, 8, 12, 2, 6, 10, 14], + 4: [1, 5, 3, 7, 0, 4, 2, 6], + }[n_bits] + + else: + # this is inverse of the above, obtained by running + # [v.index(i) for i in range(len(v))] + bit_order = { + 1: [16, 0, 24, 8, 17, 1, 25, 9, 18, 2, 26, 10, 19, 3, 27, 11, + 20, 4, 28, 12, 21, 5, 29, 13, 22, 6, 30, 14, 23, 7, 31, 15], + 2: [8, 0, 12, 4, 9, 1, 13, 5, 10, 2, 14, 6, 11, 3, 15, 7], + 4: [4, 0, 6, 2, 5, 1, 7, 3], + }[n_bits] + + x = x[:, bit_order] + x = _pack(x, n_bits) + + # reverse byte order within a uint32 word again. 
+ x = x.reshape(-1, 4).flip(1) + return x.flatten() + + +# this is a literal adaptation of FP6-LLM ahead-of-time bit-level pre-packing +# https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/utils/weight_prepacking.h +def _pack_tc_fpx(tensor: Tensor, nbits: int) -> Tensor: + assert tensor.ndim == 2, tensor.dtype == torch.uint8 + M, N = tensor.shape + assert (M % 64 == 0) and (N % 64 == 0) + + # Pass 1 from original code + tensor = tensor.view(M // 64, 4, 2, 8, N // 16, 2, 8) + tensor = tensor.permute(0, 4, 1, 5, 2, 3, 6) + tensor = tensor.reshape(-1, 32, 2) + tensor = tensor.permute(1, 0, 2) + tensor = tensor.flatten() + + used_bits = 0 + fragments = [] + + for y in [1, 2, 4]: + if nbits & y: + mask = (1 << y) - 1 + tensor_ybit = (tensor >> (nbits - used_bits - y)) & mask + tensor_ybit = _pack(tensor_ybit, y) + + tensor_ybit = tensor_ybit.view(32, -1, 4).permute(1, 0, 2).flip(2) # Pass 2 from original code + tensor_ybit = _bit_interleave(tensor_ybit.flatten(), y) # Pass 3 from original code + fragments.append(tensor_ybit) + used_bits += y + + return torch.cat(fragments, dim=0).view(M, -1) + + +# more optimized version of _pack_tc_fpx() for FP6 by merging ops +def _pack_tc_fp6(tensor: Tensor) -> Tensor: + assert tensor.ndim == 2, tensor.dtype == torch.uint8 + M, N = tensor.shape + assert (M % 64 == 0) and (N % 64 == 0) + + tensor = tensor.view(M // 64, 2, 2, 2, 8, N // 16, 2, 8) + tensor = tensor.flip(3) + + tensor_2bit = (tensor >> 4) & 0b11 + tensor_2bit = tensor_2bit.permute(0, 5, 1, 4, 7, 3, 2, 6) + tensor_2bit = _pack(tensor_2bit.flatten(), 2) + + tensor_4bit = tensor & 0b1111 + tensor_4bit = tensor_4bit.permute(0, 5, 1, 2, 4, 7, 3, 6) + tensor_4bit = _pack(tensor_4bit.flatten(), 4) + + return torch.cat([tensor_2bit, tensor_4bit], dim=0).view(M, -1) + + +# currently only optimize for TC-FP6 packing +def pack_tc_fpx(tensor: Tensor, nbits: int) -> Tensor: + if nbits == 6: + return _pack_tc_fp6(tensor) + return _pack_tc_fpx(tensor, nbits) + + +def to_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int) -> Tuple[Tensor, Tensor]: + # _n_ones() is not compatible with torch.compile() due to << operator + # https://github.com/pytorch/pytorch/issues/119152 + # exp_bias = _n_ones(ebits - 1) + # max_normal = 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / (2 ** mbits)) + + # workaround: global lookup table + exp_bias = _ONES_TABLE[ebits - 1] + max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * (_ONES_TABLE[mbits + 1] / (2 ** mbits)) + + tensor = tensor.float() + scale = tensor.abs().amax(1).clamp(min=1e-12) / max_normal + tensor_fpx = _f32_to_fpx_unpacked(tensor / scale.view(-1, 1), ebits, mbits) + tensor_tc_fpx = pack_tc_fpx(tensor_fpx, 1 + ebits + mbits) + return tensor_tc_fpx, scale.half() + + +# inverse of _pack_tc_fpx() +def _unpack_tc_fpx(tensor: Tensor, nbits: int) -> Tensor: + assert tensor.ndim == 2 and tensor.dtype == torch.uint8 + M = tensor.shape[0] + size = tensor.numel() + tensor = tensor.flatten() + offset = 0 + used_bits = 0 + + tensor_fpx = None + + for y in [1, 2, 4]: + if nbits & y: + size_ybit = size // nbits * y + tensor_ybit = tensor[offset : offset + size_ybit] + offset += size_ybit + + tensor_ybit = _bit_interleave(tensor_ybit, y, undo=True) # undo Pass 3 + tensor_ybit = tensor_ybit.view(-1, 32, 4).flip(2).permute(1, 0, 2) # undo Pass 2 + + tensor_ybit = _unpack(tensor_ybit.flatten(), y) + tensor_ybit = tensor_ybit << (nbits - used_bits - y) + used_bits += y + + if tensor_fpx is None: + tensor_fpx = tensor_ybit + 
else: + tensor_fpx |= tensor_ybit + + # undo Pass 1 + tensor_fpx = tensor_fpx.view(32, -1, 2).permute(1, 0, 2) + tensor_fpx = tensor_fpx.reshape(M // 64, -1, 4, 2, 2, 8, 8) + tensor_fpx = tensor_fpx.permute(0, 2, 4, 5, 1, 3, 6) + tensor_fpx = tensor_fpx.reshape(M, -1) + return tensor_fpx + + +# more optimized version of _unpack_tc_fpx() for FP6 by merging ops +# inverse of _unpack_tc_fp6() +def _unpack_tc_fp6(tensor: Tensor) -> Tensor: + assert tensor.ndim == 2 and tensor.dtype == torch.uint8 + M = tensor.shape[0] + N = tensor.shape[1] // 3 * 4 + assert (M % 64 == 0) and (N % 64 == 0) + size_2bit = M * N // 4 + size_4bit = M * N // 2 + tensor = tensor.view(-1) + assert tensor.numel() == size_2bit + size_4bit + + tensor_2bit, tensor_4bit = tensor.split([size_2bit, size_4bit]) + + tensor_2bit = _unpack(tensor_2bit, 2) + tensor_2bit = tensor_2bit.view(M // 64, N // 16, 2, 8, 8, 2, 2, 2) + tensor_2bit = tensor_2bit.permute(0, 2, 6, 5, 3, 1, 7, 4) + + tensor_4bit = _unpack(tensor_4bit, 4) + tensor_4bit = tensor_4bit.view(M // 64, N // 16, 2, 2, 8, 8, 2, 2) + tensor_4bit = tensor_4bit.permute(0, 2, 3, 6, 4, 1, 7, 5) + + tensor_fp6 = (tensor_2bit << 4) | tensor_4bit + tensor_fp6 = tensor_fp6.flip(3).reshape(M, N) + return tensor_fp6 + + +def unpack_tc_fpx(tensor: Tensor, nbits: int) -> Tensor: + if nbits == 6: + return _unpack_tc_fp6(tensor) + return _unpack_tc_fpx(tensor, nbits) + + +def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Tensor: + fpx_unpacked = unpack_tc_fpx(tensor, 1 + ebits + mbits) + tensor = _fpx_unpacked_to_f32(fpx_unpacked, ebits, mbits) + if scale is not None: + tensor = tensor * scale.float().view(-1, 1) + return tensor + + +# https://github.com/microsoft/DeepSpeed/blob/3a3a6db3332e339cc9fd94efd4982f6d60635a3d/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py +_SPLIT_K_MAP = [ + { # tokens: [1, 64] + 3072: 18, + 4096: 13, + 5120: 10, + 6144: 9, + 8192: 6, + 10240: 5, + 14336: 7, + 28672: 7, + 57344: 7 + }, + { # tokens: [65:128] + 3072: 9, + 4096: 6, + 5120: 5, + 6144: 9, + 8192: 3, + 10240: 5, + 14336: 7, + 28672: 7, + 57344: 6 + }, + { # tokens: [129:192] + 3072: 6, + 4096: 4, + 5120: 7, + 6144: 3, + 8192: 2, + 10240: 5, + 14336: 5, + 28672: 5, + 57344: 4 + }, + { # tokens: [193:256] + 3072: 9, + 4096: 3, + 5120: 5, + 6144: 2, + 8192: 5, + 10240: 4, + 14336: 8, + 28672: 6, + 57344: 4 + }, + { # tokens: [257:320] + 3072: 7, + 4096: 5, + 5120: 2, + 6144: 5, + 8192: 4, + 10240: 1, + 14336: 3, + 28672: 3, + 57344: 4 + }, + { # tokens: [321:384] + 3072: 3, + 4096: 2, + 5120: 5, + 6144: 3, + 8192: 1, + 10240: 8, + 14336: 3, + 28672: 4, + 57344: 3 + }, + { # tokens: [385:448] + 3072: 5, + 4096: 7, + 5120: 3, + 6144: 5, + 8192: 7, + 10240: 3, + 14336: 1, + 28672: 1, + 57344: 3 + }, + { # tokens: [449:512] + 3072: 2, + 4096: 5, + 5120: 4, + 6144: 1, + 8192: 5, + 10240: 2, + 14336: 6, + 28672: 4, + 57344: 1 + }, + { # tokens: [513:576] + 3072: 2, + 4096: 3, + 5120: 1, + 6144: 1, + 8192: 3, + 10240: 3, + 14336: 3, + 28672: 1, + 57344: 1 + }, + { # tokens: [577:640] + 3072: 5, + 4096: 4, + 5120: 1, + 6144: 4, + 8192: 2, + 10240: 1, + 14336: 1, + 28672: 1, + 57344: 1 + }, + { # tokens: [641:704] + 3072: 3, + 4096: 1, + 5120: 2, + 6144: 2, + 8192: 1, + 10240: 2, + 14336: 1, + 28672: 1, + 57344: 1 + }, + { # tokens: [705:768] + 3072: 3, + 4096: 1, + 5120: 3, + 6144: 2, + 8192: 1, + 10240: 1, + 14336: 1, + 28672: 1, + 57344: 1 + } +] + + +class QuantLlmLinearWeight(Tensor): + implements = classmethod(_implements) + + @staticmethod + def 
__new__(cls, fpx_data: Tensor, scale: Tensor, ebits: int, mbits: int): + assert fpx_data.ndim == 2 + assert fpx_data.dtype == torch.uint8 + shape = (fpx_data.shape[0], fpx_data.shape[1] // (1 + ebits + mbits) * 8) + + return Tensor._make_wrapper_subclass( + cls, + shape, + device=fpx_data.device, + requires_grad=False, + ) + + def __init__(self, fpx_data: Tensor, scale: Tensor, ebits: int, mbits: int): + self.fpx_data = fpx_data + self.scale = scale + self.ebits = ebits + self.mbits = mbits + + def __tensor_flatten__(self): + return ["fpx_data", "scale"], [self.ebits, self.mbits] + + @classmethod + def __tensor_unflatten__(cls, tensor_data_dict, tensor_attributes, outer_size=None, outer_stride=None): + return cls(tensor_data_dict["fpx_data"], tensor_data_dict["scale"], *tensor_attributes) + + @classmethod + def from_float(cls, input_float: Tensor, ebits: int, mbits: int): + fpx_data, scale = to_scaled_tc_fpx(input_float, ebits, mbits) + return cls(fpx_data, scale, ebits, mbits) + + def dequantize(self, output_dtype=None): + output_dtype = output_dtype or torch.get_default_dtype() + return from_scaled_tc_fpx(self.fpx_data, self.ebits, self.mbits, self.scale).to(output_dtype) + + def __repr__(self): + dtype = f"fp{1 + self.ebits + self.mbits}_e{self.ebits}m{self.mbits}" + return ( + f"{self.__class__.__name__}(dtype={dtype}, shape={self.shape}, " + f"device={self.device}, requires_grad={self.requires_grad})" + ) + + def _apply_fn_to_data(self, fn): + return self.__class__( + fn(self.fpx_data), + fn(self.scale), + self.ebits, + self.mbits, + ) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + kwargs = {} if kwargs is None else kwargs + + if func in _ATEN_OP_OR_TORCH_FN_TABLE[cls]: + return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](*args, **kwargs) + + with torch._C.DisableTorchFunctionSubclass(): + return func(*args, **kwargs) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + if func in _ATEN_OP_OR_TORCH_FN_TABLE[cls]: + return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](func, *args, **kwargs) + + raise NotImplementedError(f"{cls.name} dispatch: attempting to run {func}, this is not supported") + + +@QuantLlmLinearWeight.implements(torch.nn.functional.linear) +def _(*args, **kwargs): + act = args[0] + weight = args[1] + bias = args[2] if len(args) >= 3 else None + assert isinstance(weight, QuantLlmLinearWeight) + + out_dim, in_dim = weight.shape + act_reshaped = act.view(-1, in_dim).half() + + # https://github.com/microsoft/DeepSpeed/blob/3a3a6db3332e339cc9fd94efd4982f6d60635a3d/deepspeed/inference/v2/kernels/core_ops/cuda_linear/cuda_linear.py + bsize = act_reshaped.shape[0] + splitK = _SPLIT_K_MAP[(bsize - 1) // 64].get(out_dim, 1) if bsize <= 768 else 1 + + out = quant_llm_linear( + weight.ebits, + weight.mbits, + act_reshaped, + weight.fpx_data, + weight.scale, + splitK=splitK, + ) + + if bias is not None: + out += bias + + return out.view(*act.shape[:-1], out_dim).to(act.dtype) + + +@QuantLlmLinearWeight.implements(torch.ops.aten.detach.default) +def _(func, *args, **kwargs): + return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)) + + +def quant_llm_fpx_weight_only(ebits: int, mbits: int): + def apply_quant_llm(weight: Tensor) -> Tensor: + out_dim, in_dim = weight.shape + if (in_dim % 64 != 0) or (out_dim % 256 != 0): + return weight + return QuantLlmLinearWeight.from_float(weight, ebits, mbits) + return apply_quant_llm + + +def fp6_llm_weight_only(): + return quant_llm_fpx_weight_only(3, 2)
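
For a quick smoke test of the tensor subclass introduced above, here is a minimal round-trip sketch. It is not part of the patch; the shapes and tolerances are illustrative (both dimensions must be multiples of 64 for the tensor-core packing, and the error bound assumes FP6 E3M2 rounding on a standard-normal weight).

```python
# Round-trip sanity check for QuantLlmLinearWeight (no CUDA kernel needed for pack/unpack).
import torch
from torchao.prototype.quant_llm import QuantLlmLinearWeight

# (out_features, in_features); both multiples of 64, as asserted by the packing helpers
weight = torch.randn(256, 64)

# quantize to FP6 E3M2 with per-row scales and the tensor-core layout
fp6_weight = QuantLlmLinearWeight.from_float(weight, ebits=3, mbits=2)
print(fp6_weight)  # QuantLlmLinearWeight(dtype=fp6_e3m2, shape=torch.Size([256, 64]), ...)

# dequantize() reverses to_scaled_tc_fpx(); the difference is bounded by FP6 E3M2 rounding
torch.testing.assert_close(fp6_weight.dequantize(torch.float32), weight, atol=0.3, rtol=0.3)
```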