From 4ca3985be603e6496da7ec57adf1942c8b32a78e Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sun, 26 May 2024 02:17:18 +0800 Subject: [PATCH] Improve primitives for FP6 quant (#248) --- dev-requirements.txt | 3 + docs/source/api_ref_dtypes.rst | 2 + setup.py | 3 +- test/dtypes/test_float6_e3m2.py | 127 +++++++++ test/test_ops.py | 34 +-- torchao/__init__.py | 13 +- torchao/csrc/cuda/fp6_llm/weight_quant.cu | 69 +---- torchao/csrc/fp6_llm/float6_e3m2.cpp | 319 ++++++++++++++++++++++ torchao/csrc/fp6_llm/fp6_llm.cpp | 8 +- torchao/dtypes/__init__.py | 3 + torchao/dtypes/float6_e3m2.py | 178 ++++++++++++ torchao/ops.py | 37 +-- 12 files changed, 679 insertions(+), 117 deletions(-) create mode 100644 test/dtypes/test_float6_e3m2.py create mode 100644 torchao/csrc/fp6_llm/float6_e3m2.cpp create mode 100644 torchao/dtypes/float6_e3m2.py diff --git a/dev-requirements.txt b/dev-requirements.txt index 6dadb274a..156e8766d 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -12,3 +12,6 @@ pandas # Custom CUDA Extensions ninja + +# for FP6-LLM (can be removed once we remove fp16_to_fp6_original()) +qtorch diff --git a/docs/source/api_ref_dtypes.rst b/docs/source/api_ref_dtypes.rst index 4cb797beb..36c3c9b4e 100644 --- a/docs/source/api_ref_dtypes.rst +++ b/docs/source/api_ref_dtypes.rst @@ -12,6 +12,8 @@ torchao.dtypes to_nf4 UInt4Tensor + to_float6_e3m2 + from_float6_e3m2 .. _NF4Tensor - add after fixing torchao/dtypes/nf4tensor.py:docstring diff --git a/setup.py b/setup.py index 5d1f32da2..65ec21e15 100644 --- a/setup.py +++ b/setup.py @@ -46,11 +46,12 @@ def get_extensions(): use_cuda = torch.cuda.is_available() and CUDA_HOME is not None extension = CUDAExtension if use_cuda else CppExtension - extra_link_args = [] + extra_link_args = ["-fopenmp"] extra_compile_args = { "cxx": [ "-O3" if not debug_mode else "-O0", "-fdiagnostics-color=always", + "-fopenmp", ], "nvcc": [ "-O3" if not debug_mode else "-O0", diff --git a/test/dtypes/test_float6_e3m2.py b/test/dtypes/test_float6_e3m2.py new file mode 100644 index 000000000..b82150473 --- /dev/null +++ b/test/dtypes/test_float6_e3m2.py @@ -0,0 +1,127 @@ +import torch +from torch.testing._internal.common_utils import ( + TestCase, + instantiate_parametrized_tests, + parametrize, + run_tests, +) +from torchao.dtypes.float6_e3m2 import to_float6_e3m2, from_float6_e3m2 + + +_DTYPES = [torch.float32, torch.float16, torch.bfloat16] +_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) + + +class TestFp6(TestCase): + + @parametrize("device", _DEVICES) + @parametrize("dtype", _DTYPES) + @parametrize( + "input_output", + [ + (0.0, 0b000000), # exact values + (1.0, 0b001100), # normal numbers + (1.25, 0b001101), + (28.0, 0b011111), # max + (0.1875, 0b000011), # subnormal number + (0.0625, 0b000001), # min + (29.0, 0b011111), # normal round down + (26.0, 0b011110), # normal round to nearest even + (0.1251, 0b000010), # subnormal round down + (0.0314, 0b000001), # subnormal round up + (0.03, 0b000000), # underflow + ], + ) + def test_to_float6_e3m2_no_bit_packing_correctness(self, device, dtype, input_output): + input, output = input_output + input = torch.tensor(input, device=device, dtype=dtype) + assert to_float6_e3m2(input, no_bit_packing=True).item() == output + + @parametrize("device", _DEVICES) + @parametrize("dtype", _DTYPES) + def test_to_float6_e3m2_bit_packing_correctness(self, device, dtype): + x = torch.randn(128, 128, device=device, dtype=dtype) + results_unpacked = to_float6_e3m2(x, no_bit_packing=True) + results_packed 
= to_float6_e3m2(x) + + val0, val1, val2, val3 = results_unpacked.unflatten(-1, (-1, 4)).unbind(-1) + bits0 = (val0 << 2) | (val1 >> 4) # 0000 0011 + bits1 = (val1 << 4) | (val2 >> 2) # 1111 2222 + bits2 = (val2 << 6) | (val3); # 2233 3333 + + expected_packed = torch.stack([bits0, bits1, bits2], dim=-1).flatten(-2) + assert (results_packed == expected_packed).all() + + @parametrize("device", _DEVICES) + @parametrize("shape", [(), (0,), (10,), (20, 20)]) + def test_to_float6_e3m2_no_bit_packing_shape(self, device, shape): + x = torch.randn(shape, device=device) + result = to_float6_e3m2(x, no_bit_packing=True) + assert result.shape == shape + + @parametrize("device", _DEVICES) + @parametrize("shape", [(4,), (20, 20)]) + def test_to_float6_e3m2_bit_packing_shape(self, device, shape): + x = torch.randn(shape, device=device) + result = to_float6_e3m2(x) + assert result.shape == shape[:-1] + (shape[-1] // 4 * 3,) + + @parametrize("device", _DEVICES) + @parametrize("dtype", _DTYPES) + @parametrize("no_bit_packing", [False, True]) + def test_to_float6_e3m2_compile(self, device, dtype, no_bit_packing): + x = torch.randn(20, 20, device=device, dtype=dtype) + expected = to_float6_e3m2(x, no_bit_packing=no_bit_packing) + + to_float6_e3m2_compiled = torch.compile(to_float6_e3m2) + actual = to_float6_e3m2_compiled(x, no_bit_packing=no_bit_packing) + torch.testing.assert_close(actual, expected) + + @parametrize("device", _DEVICES) + @parametrize( + "input_output", + [ + (0b000000, 0.0), + (0b001100, 1.0), + (0b011111, 28.0), # max + (0b000001, 0.0625), # min + (0b001110, 1.5), + (0b000011, 0.1875), # subnormal + ], + ) + def test_from_float6_e3m2_no_bit_packing_correctness(self, device, input_output): + input, output = input_output + input = torch.tensor(input, device=device, dtype=torch.uint8) + assert from_float6_e3m2(input, no_bit_packing=True).item() == output + + @parametrize("device", _DEVICES) + def test_from_float6_e3m2_bit_packing_correctness(self, device): + x = torch.randint(256, (128, 128 // 4 * 3), device=device, dtype=torch.uint8) + actual = from_float6_e3m2(x) + + bits0, bits1, bits2 = x.unflatten(-1, (-1, 3)).unbind(-1) + x_unpacked0 = bits0 >> 2 + x_unpacked1 = ((bits0 & 0x3) << 4) | (bits1 >> 4) + x_unpacked2 = ((bits1 & 0xF) << 2) | (bits2 >> 6) + x_unpacked3 = bits2 & 0x3F + + x_unpacked = torch.stack([x_unpacked0, x_unpacked1, x_unpacked2, x_unpacked3], dim=-1).flatten(-2) + expected = from_float6_e3m2(x_unpacked, no_bit_packing=True) + torch.testing.assert_close(actual, expected) + + @parametrize("device", _DEVICES) + @parametrize("no_bit_packing", [False, True]) + def test_from_float6_e3m2_compile(self, device, no_bit_packing): + x = torch.randint(256, size=(20, 15), device=device, dtype=torch.uint8) + expected = from_float6_e3m2(x, no_bit_packing=no_bit_packing) + + from_float6_e3m2_compiled = torch.compile(from_float6_e3m2) + actual = from_float6_e3m2_compiled(x, no_bit_packing=no_bit_packing) + torch.testing.assert_close(actual, expected) + + +instantiate_parametrized_tests(TestFp6) + + +if __name__ == "__main__": + run_tests() diff --git a/test/test_ops.py b/test/test_ops.py index 6ce6a4afb..4e463b4e2 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -50,24 +50,21 @@ def test_prepack_fp6_weight(self): opcheck(torch.ops.torchao.prepack_fp6_weight, (fp6_weight,), test_utils=test_utils) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") - def test_fp16_to_fp6(self): + def test_fp16_to_fp6_original(self): OC = 256 IC = 256 - - # in this fp6, we use 3 bits 
for exponent and 2 bits for mantissa - # also, we don't have nan/inf - fp6_absmax = 28.0 # 2 ** (0b111 - 0b011) * (1 + 0.5 + 0.25), where E=111, M=11 - fp6_absmin = 0.0625 # 2 ** (-0b010) * 0.25, where E=000, M=01 (subnormal number) fp16_weight = torch.randn((OC, IC), dtype=torch.float16) - fp16_weight.clip_(-fp6_absmax, fp6_absmax) - fp16_weight[fp16_weight.abs() < fp6_absmin] = 0 + + # the original FP16->FP6 kernel checks for overflow/underflow + fp16_weight.clip_(-28.0, 28.0) + fp16_weight[fp16_weight.abs() < 0.0625] = 0.0 # smoke test - torchao.ops.fp16_to_fp6(fp16_weight) + torchao.ops.fp16_to_fp6_original(fp16_weight) # comprehensive testing test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] - opcheck(torch.ops.torchao.fp16_to_fp6, (fp16_weight,), test_utils=test_utils) + opcheck(torch.ops.torchao.fp16_to_fp6_original, (fp16_weight,), test_utils=test_utils) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def test_fp16act_fp6weight_linear(self): @@ -89,19 +86,6 @@ def test_fp16act_fp6weight_linear(self): test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] opcheck(torch.ops.torchao.fp16act_fp6weight_linear, (act_cuda, weight_cuda, scale_cuda, splitK), test_utils=test_utils) - @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") - def test_fp6_weight_dequant(self): - OC = 256 - IC = 256 - fp6_weight, fp16_scale, _ = self._create_fp6_inputs(0, OC, IC) - - # smoke test - torchao.ops.fp6_weight_dequant(fp6_weight, fp16_scale) - - # comprehensive testing - test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] - opcheck(torch.ops.torchao.fp6_weight_dequant, (fp6_weight, fp16_scale), test_utils=test_utils) - # adapted from https://github.com/usyd-fsalab/fp6_llm/blob/main/tests/python/kernel_test.py @parameterized.expand([(1, 2048, 4096, 5), (2, 8192, 8192, 6)]) @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") @@ -115,8 +99,8 @@ def test_fp6_matmul_correctness(self, BS, OC, IC, splitK): results_fp6 = torchao.ops.fp16act_fp6weight_linear(act_cuda, weight_cuda, scale_cuda, splitK) - fp16_weight = torchao.ops.fp6_weight_dequant(fp6_weight, fp16_scale).cuda() - results_fp16 = act_cuda @ fp16_weight.T + fp16_weight = torchao.dtypes.from_float6_e3m2(fp6_weight.view(torch.uint8), dtype=torch.float16) * fp16_scale[:, None] + results_fp16 = act_cuda @ fp16_weight.cuda().T error = (results_fp6 - results_fp16).abs() relative_error = error / results_fp16.abs() diff --git a/torchao/__init__.py b/torchao/__init__.py index c982e09a0..c8f04c1d9 100644 --- a/torchao/__init__.py +++ b/torchao/__init__.py @@ -1,9 +1,3 @@ -from torchao.quantization import ( - apply_weight_only_int8_quant, - apply_dynamic_quant, - autoquant, -) -from . import dtypes import torch _IS_FBCODE = ( hasattr(torch._utils_internal, "IS_FBSOURCE") and @@ -14,6 +8,13 @@ from . import _C from . import ops +from torchao.quantization import ( + apply_weight_only_int8_quant, + apply_dynamic_quant, + autoquant, +) +from . import dtypes + __all__ = [ "dtypes", "apply_dynamic_quant", diff --git a/torchao/csrc/cuda/fp6_llm/weight_quant.cu b/torchao/csrc/cuda/fp6_llm/weight_quant.cu index d29f70be0..b519cbfb0 100644 --- a/torchao/csrc/cuda/fp6_llm/weight_quant.cu +++ b/torchao/csrc/cuda/fp6_llm/weight_quant.cu @@ -13,7 +13,6 @@ // limitations under the License. 
// // This file is adapted from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/utils/weight_quant.h -// and https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/utils/weight_dequant.h #include #include @@ -120,41 +119,6 @@ void weight_prepacking_fp16_to_fp6(uint16_t* weight_16bit, } } -void DeQuantMatrix_FP6_To_FP16(half* A_16bit_h, unsigned char* A_6bit_h, size_t M, size_t K, half* scale) { - assert(M%64==0); // Currently, M must be a multiple of 64. - assert(K%64==0); // Currently, K must be a multiple of 64. - size_t TotalSizeInByte = M*K*6/8; - // - half* OutPTR = A_16bit_h; - for(size_t i=0; i>2)&0x1f); - unsigned char B2 = (A_6bit_h[i*3+0]<<6) | ((A_6bit_h[i*3+1]>>2)&0xfc); - B2 = (B2&0x80) | ((B2>>2)&0x1f); - unsigned char B3 = (A_6bit_h[i*3+1]<<4) | ((A_6bit_h[i*3+2]>>4)&0xfc); - B3 = (B3&0x80) | ((B3>>2)&0x1f); - unsigned char B4 = A_6bit_h[i*3+2]<<2; - B4 = (B4&0x80) | ((B4>>2)&0x1f); - half FP1, FP2, FP3, FP4; - unsigned char *PTR1, *PTR2, *PTR3, *PTR4; - PTR1 = reinterpret_cast(&FP1); - PTR2 = reinterpret_cast(&FP2); - PTR3 = reinterpret_cast(&FP3); - PTR4 = reinterpret_cast(&FP4); - PTR1[0] = 0; PTR1[1] = B1; // small endian for X86 CPU - PTR2[0] = 0; PTR2[1] = B2; - PTR3[0] = 0; PTR3[1] = B3; - PTR4[0] = 0; PTR4[1] = B4; - OutPTR[0] = __float2half_rn ( __half2float(FP1) * 4096.0f * __half2float(scale[(4*i)/K]) ); - OutPTR[1] = __float2half_rn ( __half2float(FP2) * 4096.0f * __half2float(scale[(4*i)/K]) ); - OutPTR[2] = __float2half_rn ( __half2float(FP3) * 4096.0f * __half2float(scale[(4*i)/K]) ); - OutPTR[3] = __float2half_rn ( __half2float(FP4) * 4096.0f * __half2float(scale[(4*i)/K]) ); - // - OutPTR +=4; - } -} - - #include #include #include @@ -162,7 +126,7 @@ void DeQuantMatrix_FP6_To_FP16(half* A_16bit_h, unsigned char* A_6bit_h, size_t namespace torchao { // https://github.com/microsoft/DeepSpeed/blob/0fc19b6a320cf8aa0a5f6c2b1fa310bae9a70d94/deepspeed/inference/v2/kernels/core_ops/cuda_linear/linear_kernels.cpp#L194 -at::Tensor fp16_to_fp6_cpu(at::Tensor fp16_tensor) +at::Tensor fp16_to_fp6_original_cpu(at::Tensor fp16_tensor) { TORCH_CHECK(fp16_tensor.dim() == 2, "weight must be 2-dimensional"); TORCH_CHECK(fp16_tensor.scalar_type() == torch::kFloat16, "weight must be FP16"); @@ -183,37 +147,8 @@ at::Tensor fp16_to_fp6_cpu(at::Tensor fp16_tensor) return packed_fp6_tensor; } -/* - * Dequant a FP6 matrix to a equivalent FP16 matrix using CPUs. - * A useful tool to construct input matrices for the FP16 GEMM baseline. - * [Input] - * fp6_tensor: int tensor of shape [OC, IC // 16 * 3]; // 3 INT32 words contains 16 FP6 weights. - * fp16_scale: half tensor of shape [OC]; // for row-wise quantization. - * [Output] - * fp16_tensor: half tensor of shape [OC, IC]. 
- */ -at::Tensor weight_matrix_dequant_cpu(at::Tensor fp6_tensor, at::Tensor fp16_scale) -{ - int OC = fp6_tensor.size(0); - TORCH_CHECK(fp6_tensor.size(1) % 3 == 0); - int IC = fp6_tensor.size(1) / 3 * 16; - TORCH_CHECK(fp16_scale.size(0) == OC); - // - auto fp6_tensor_ptr = reinterpret_cast(fp6_tensor.data_ptr()); - auto fp16_scale_ptr = reinterpret_cast(fp16_scale.data_ptr()); - // - auto options = at::TensorOptions().dtype(at::kHalf).device(fp16_scale.device()); - at::Tensor fp16_tensor = at::empty({OC, IC}, options); - auto fp16_tensor_ptr = reinterpret_cast(fp16_tensor.data_ptr()); - // - DeQuantMatrix_FP6_To_FP16(fp16_tensor_ptr, fp6_tensor_ptr, OC, IC, fp16_scale_ptr); - // - return fp16_tensor; -} - TORCH_LIBRARY_IMPL(torchao, CPU, m) { - m.impl("torchao::fp16_to_fp6", &fp16_to_fp6_cpu); - m.impl("torchao::fp6_weight_dequant", &weight_matrix_dequant_cpu); + m.impl("torchao::fp16_to_fp6_original", &fp16_to_fp6_original_cpu); } } diff --git a/torchao/csrc/fp6_llm/float6_e3m2.cpp b/torchao/csrc/fp6_llm/float6_e3m2.cpp new file mode 100644 index 000000000..16d71f51d --- /dev/null +++ b/torchao/csrc/fp6_llm/float6_e3m2.cpp @@ -0,0 +1,319 @@ +#include +#include +#include + +#include +#include +#include + + +class float6_e3m2_nan_inf : public std::invalid_argument { +public: + float6_e3m2_nan_inf() : std::invalid_argument("Encounter +/-inf or NaN, which is not representable in float6_e3m2.") { } +}; + +class float6_e3m2_overflow : public std::invalid_argument { +public: + float6_e3m2_overflow() : std::invalid_argument("float6_e3m2 overflow. float6_e3m2 cannot represent +/-inf. Make sure input < 30.0") { } +}; + +// we need to do this because C++17 does not allow using struct as template non-type parameter +// use the upper 16 bits for num exponent, lower 16 bits for num mantissa +static constexpr uint32_t encode_fp_spec(uint32_t n_exp, uint32_t n_man) { return (n_exp << 16u) | n_man; } +static constexpr uint32_t FP32_SPEC = encode_fp_spec(8u, 23u); +static constexpr uint32_t FP16_SPEC = encode_fp_spec(5u, 10u); +static constexpr uint32_t BF16_SPEC = encode_fp_spec(8u, 7u); + +// NOTE: only works for len < 32 +static constexpr uint32_t ones_mask(uint32_t len) { return (1u << len) - 1u; } + +// inspired by __internal_float2half() and float2half() from "cuda_fp16.hpp" +template +static uint8_t to_float6_e3m2_bits(T bits_) { + constexpr uint32_t N_EXP = FP_SPEC >> 16u; + constexpr uint32_t N_MAN = FP_SPEC & ones_mask(16u); + constexpr uint32_t N_EXP_MAN = N_EXP + N_MAN; + constexpr uint32_t EXP_BIAS_DIFF = ones_mask(N_EXP - 1u) - 3u; + + // sanity checks. will be removed in template instantiation. + // minimum 1 bit above FP6 (3 exponent bits and 2 mantissa bits) to avoid edge cases. 
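+    // (worked example with FP16_SPEC, i.e. N_EXP=5 and N_MAN=10, so EXP_BIAS_DIFF = 15 - 3 = 12:
+    //  FP16 1.25 = 0x3D00 -> sign=0, biased exponent 15 >= EXP_BIAS_DIFF + 1, so the normal path below
+    //  produces 0b001101, matching the expected values in test/dtypes/test_float6_e3m2.py.)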
+    static_assert(N_EXP >= 4, "Number of exponent bits must be >= 4.");
+    static_assert(N_MAN >= 3, "Number of mantissa bits must be >= 3.");
+
+    uint32_t bits = bits_;  // bit extension
+    uint32_t sign = bits >> N_EXP_MAN << 5u;
+    bits &= ones_mask(N_EXP_MAN);  // clear sign bit
+    uint32_t result, remainder;
+
+    // all exponent bits are 1s
+    if (bits >= (ones_mask(N_EXP) << N_MAN)) throw float6_e3m2_nan_inf();
+
+    // max FP6 (28) + half of least significand (2) = 30 (assume N_MAN >= 3)
+    if (bits >= (((EXP_BIAS_DIFF + 7u) << N_MAN) | (0x7u << (N_MAN - 3u)))) throw float6_e3m2_overflow();
+
+    // FP6 normal number (E>=001)
+    if (bits >= ((EXP_BIAS_DIFF + 1u) << N_MAN)) {
+        remainder = bits << (32u - (N_MAN - 2u));  // shift the truncated bits to most significant position
+        bits -= (EXP_BIAS_DIFF << N_MAN);          // update exponent
+        result = sign | (bits >> (N_MAN - 2u));
+    }
+    // FP6 subnormal number (more than half of min FP6 subnormal = 0.0625 * 0.5)
+    else if (bits > ((EXP_BIAS_DIFF - 2u) << N_MAN)) {
+        uint32_t exp = bits >> N_MAN;
+        uint32_t man = bits & ones_mask(N_MAN);
+
+        // to make subnormal FP6 from normal FP16
+        // step 1: add implicit 1 to mantissa
+        man |= (1u << N_MAN);
+
+        // step 2: shift mantissa right so that exponent value is equal to
+        // exponent value of FP6 subnormal, which is -2 (equivalent to E=001)
+        uint32_t shift = EXP_BIAS_DIFF + 1u - exp;
+        remainder = man << (32u - (N_MAN - 2u + shift));  // shift the truncated bits to most significant position
+        result = sign | (man >> (shift + (N_MAN - 2u)));  // implicit E=000
+    }
+    // FP6 underflow. E=000, M=00
+    else {
+        remainder = 0u;
+        result = sign;
+    }
+
+    // round to nearest even
+    if ((remainder > 0x8000'0000u) || ((remainder == 0x8000'0000u) && (result & 0x1u))) {
+        result += 1;
+    }
+    return result;
+}
+
+// assume the lower 6 bits contain the data.
+template <typename T, uint32_t FP_SPEC>
+static T from_float6_e3m2_bits(uint8_t a) {
+    constexpr uint32_t N_EXP = FP_SPEC >> 16u;
+    constexpr uint32_t N_MAN = FP_SPEC & ones_mask(16u);
+    constexpr uint32_t N_EXP_MAN = N_EXP + N_MAN;
+    constexpr uint32_t EXP_BIAS_DIFF = ones_mask(N_EXP - 1u) - 3u;
+
+    uint32_t bits = a;  // bit extension
+    uint32_t sign = bits >> 5u;
+    uint32_t exp = (bits >> 2u) & 0x7u;
+    uint32_t man = bits & 0x3u;
+
+    if (exp > 0u) {  // FP6 normal numbers
+        exp += EXP_BIAS_DIFF;
+    } else if (man > 0u) {  // FP6 denormal numbers
+        uint32_t shift = (man >= 0b10u) ? 1u : 2u;
+        man = (man << shift) & 0x3u;  // shift and remove explicit 1
+        exp = 1u + EXP_BIAS_DIFF - shift;
+    }
+    // don't need to handle zero, since E=000 and M=00
+
+    uint32_t result = (sign << N_EXP_MAN) | (exp << N_MAN) | (man << (N_MAN - 2u));
+    return static_cast<T>(result);
+}
+
+namespace torchao {
+
+template <typename T, uint32_t FP_SPEC> void to_float6_e3m2_unpacked_cpu_impl(const T *bits_ptr, uint8_t *fp6_ptr, int n) {
+    // exception within OpenMP parallel region must be caught.
+    // set a flag when exception occurs, then re-raise it.
+    bool found_nan_inf = false;
+    bool found_overflow = false;
+
+#pragma omp parallel for
+    for (int i = 0; i < n; i++) {
+        try { fp6_ptr[i] = to_float6_e3m2_bits<T, FP_SPEC>(bits_ptr[i]); }
+        catch (float6_e3m2_nan_inf const &) { found_nan_inf = true; }
+        catch (float6_e3m2_overflow const &) { found_overflow = true; }
+    }
+
+    if (found_nan_inf) throw float6_e3m2_nan_inf();
+    if (found_overflow) throw float6_e3m2_overflow();
+}
+
+// this is useful for debugging
+at::Tensor to_float6_e3m2_unpacked_cpu(at::Tensor fp_tensor) {
+    TORCH_CHECK(fp_tensor.is_contiguous());
+    TORCH_CHECK(fp_tensor.is_cpu());
+
+    at::TensorOptions options = at::TensorOptions().dtype(torch::kUInt8).device(fp_tensor.device());
+    at::Tensor fp6_tensor = at::empty(fp_tensor.sizes(), options);
+    uint8_t *fp6_ptr = fp6_tensor.data_ptr<uint8_t>();
+
+    int n = fp_tensor.numel();
+    auto dtype = fp_tensor.dtype();
+
+    if (dtype == torch::kFloat32) {
+        const uint32_t *fp32_ptr = reinterpret_cast<const uint32_t *>(fp_tensor.data_ptr());
+        to_float6_e3m2_unpacked_cpu_impl<uint32_t, FP32_SPEC>(fp32_ptr, fp6_ptr, n);
+
+    } else if (dtype == torch::kFloat16) {
+        const uint16_t *fp16_ptr = reinterpret_cast<const uint16_t *>(fp_tensor.data_ptr());
+        to_float6_e3m2_unpacked_cpu_impl<uint16_t, FP16_SPEC>(fp16_ptr, fp6_ptr, n);
+
+    } else if (dtype == torch::kBFloat16) {
+        const uint16_t *bf16_ptr = reinterpret_cast<const uint16_t *>(fp_tensor.data_ptr());
+        to_float6_e3m2_unpacked_cpu_impl<uint16_t, BF16_SPEC>(bf16_ptr, fp6_ptr, n);
+
+    } else {
+        throw std::invalid_argument("Only FP32, FP16, and BF16 inputs are accepted.");
+    }
+
+    return fp6_tensor;
+}
+
+template <typename T, uint32_t FP_SPEC> void to_float6_e3m2_packed_cpu_impl(const T *bits_ptr, uint8_t *fp6_ptr, int n) {
+    // exception within OpenMP parallel region must be caught.
+    // set a flag when exception occurs, then re-raise it.
+    bool found_nan_inf = false;
+    bool found_overflow = false;
+
+#pragma omp parallel for
+    for (int i = 0; i < n / 4; i++) {
+        try {
+            uint8_t val0 = to_float6_e3m2_bits<T, FP_SPEC>(bits_ptr[i * 4]);
+            uint8_t val1 = to_float6_e3m2_bits<T, FP_SPEC>(bits_ptr[i * 4 + 1]);
+            uint8_t val2 = to_float6_e3m2_bits<T, FP_SPEC>(bits_ptr[i * 4 + 2]);
+            uint8_t val3 = to_float6_e3m2_bits<T, FP_SPEC>(bits_ptr[i * 4 + 3]);
+
+            fp6_ptr[i * 3]     = (val0 << 2) | (val1 >> 4);  // 0000 0011
+            fp6_ptr[i * 3 + 1] = (val1 << 4) | (val2 >> 2);  // 1111 2222
+            fp6_ptr[i * 3 + 2] = (val2 << 6) | (val3);       // 2233 3333
+        }
+        catch (float6_e3m2_nan_inf const &) { found_nan_inf = true; }
+        catch (float6_e3m2_overflow const &) { found_overflow = true; }
+    }
+
+    if (found_nan_inf) throw float6_e3m2_nan_inf();
+    if (found_overflow) throw float6_e3m2_overflow();
+}
+
+at::Tensor to_float6_e3m2_packed_cpu(at::Tensor fp_tensor) {
+    TORCH_CHECK(fp_tensor.is_contiguous());
+    TORCH_CHECK(fp_tensor.is_cpu());
+    TORCH_CHECK(fp_tensor.ndimension() == 2);
+
+    int M = fp_tensor.size(0);
+    int N = fp_tensor.size(1);
+    TORCH_CHECK(N % 4 == 0, "Last dimension must be a multiple of 4, receives ", N);
+
+    at::TensorOptions options = at::TensorOptions().dtype(torch::kUInt8).device(fp_tensor.device());
+    at::Tensor fp6_tensor = at::empty({M, N * 3 / 4}, options);
+    uint8_t *fp6_ptr = fp6_tensor.data_ptr<uint8_t>();
+
+    int n = fp_tensor.numel();
+    auto dtype = fp_tensor.dtype();
+
+    if (dtype == torch::kFloat32) {
+        const uint32_t *fp32_ptr = reinterpret_cast<const uint32_t *>(fp_tensor.data_ptr());
+        to_float6_e3m2_packed_cpu_impl<uint32_t, FP32_SPEC>(fp32_ptr, fp6_ptr, n);
+
+    } else if (dtype == torch::kFloat16) {
+        const uint16_t *fp16_ptr = reinterpret_cast<const uint16_t *>(fp_tensor.data_ptr());
+        to_float6_e3m2_packed_cpu_impl<uint16_t, FP16_SPEC>(fp16_ptr, fp6_ptr, n);
+
+    } else if (dtype == torch::kBFloat16) {
+        const uint16_t *bf16_ptr = reinterpret_cast<const uint16_t *>(fp_tensor.data_ptr());
+        to_float6_e3m2_packed_cpu_impl<uint16_t, BF16_SPEC>(bf16_ptr, fp6_ptr, n);
+
+    } else {
+        throw std::invalid_argument("Only FP32, FP16, and BF16 inputs are accepted.");
+    }
+
+    return fp6_tensor;
+}
+
+template <typename T, uint32_t FP_SPEC>
+void from_float6_e3m2_unpacked_cpu_impl(const uint8_t *fp6_ptr, T *fp_ptr, int n) {
+#pragma omp parallel for
+    for (int i = 0; i < n; i++)
+        fp_ptr[i] = from_float6_e3m2_bits<T, FP_SPEC>(fp6_ptr[i]);
+}
+
+at::Tensor from_float6_e3m2_unpacked_cpu(at::Tensor fp6_tensor, c10::ScalarType dtype) {
+    TORCH_CHECK(fp6_tensor.dtype() == torch::kUInt8);
+    TORCH_CHECK(fp6_tensor.is_contiguous());
+    TORCH_CHECK(fp6_tensor.is_cpu());
+
+    at::TensorOptions options = at::TensorOptions().dtype(dtype).device(fp6_tensor.device());
+    at::Tensor fp_tensor = at::empty(fp6_tensor.sizes(), options);
+
+    const uint8_t *fp6_ptr = fp6_tensor.data_ptr<uint8_t>();
+    int n = fp6_tensor.numel();
+
+    if (dtype == torch::kFloat32) {
+        uint32_t *fp32_ptr = reinterpret_cast<uint32_t *>(fp_tensor.data_ptr());
+        from_float6_e3m2_unpacked_cpu_impl<uint32_t, FP32_SPEC>(fp6_ptr, fp32_ptr, n);
+
+    } else if (dtype == torch::kFloat16) {
+        uint16_t *fp16_ptr = reinterpret_cast<uint16_t *>(fp_tensor.data_ptr());
+        from_float6_e3m2_unpacked_cpu_impl<uint16_t, FP16_SPEC>(fp6_ptr, fp16_ptr, n);
+
+    } else if (dtype == torch::kBFloat16) {
+        uint16_t *bf16_ptr = reinterpret_cast<uint16_t *>(fp_tensor.data_ptr());
+        from_float6_e3m2_unpacked_cpu_impl<uint16_t, BF16_SPEC>(fp6_ptr, bf16_ptr, n);
+
+    } else {
+        throw std::invalid_argument("Only FP32, FP16, and BF16 inputs are accepted.");
+    }
+
+    return fp_tensor;
+}
+
+template <typename T, uint32_t FP_SPEC>
+void from_float6_e3m2_packed_cpu_impl(const uint8_t *fp6_ptr, T *fp_ptr, int n) {
+#pragma omp parallel for
+    for (int i = 0; i < n / 3; i++) {
+        uint8_t bits0 = fp6_ptr[i * 3];      // 0000 0011
+        uint8_t bits1 = fp6_ptr[i * 3 + 1];  // 1111 2222
+        uint8_t bits2 = fp6_ptr[i * 3 + 2];  // 2233 3333
+
+        fp_ptr[i * 4]     = from_float6_e3m2_bits<T, FP_SPEC>(bits0 >> 2);
+        fp_ptr[i * 4 + 1] = from_float6_e3m2_bits<T, FP_SPEC>(((bits0 & 0x3u) << 4) | (bits1 >> 4));
+        fp_ptr[i * 4 + 2] = from_float6_e3m2_bits<T, FP_SPEC>(((bits1 & 0xFu) << 2) | (bits2 >> 6));
+        fp_ptr[i * 4 + 3] = from_float6_e3m2_bits<T, FP_SPEC>(bits2 & 0x3Fu);
+    }
+}
+
+at::Tensor from_float6_e3m2_packed_cpu(at::Tensor fp6_tensor, c10::ScalarType dtype) {
+    TORCH_CHECK(fp6_tensor.dtype() == torch::kUInt8);
+    TORCH_CHECK(fp6_tensor.is_contiguous());
+    TORCH_CHECK(fp6_tensor.is_cpu());
+    TORCH_CHECK(fp6_tensor.ndimension() == 2);
+
+    int M = fp6_tensor.size(0);
+    int N = fp6_tensor.size(1);
+    TORCH_CHECK(N % 3 == 0, "Last dimension must be a multiple of 3, receives ", N);
+
+    at::TensorOptions options = at::TensorOptions().dtype(dtype).device(fp6_tensor.device());
+    at::Tensor fp_tensor = at::empty({M, N / 3 * 4}, options);
+
+    const uint8_t *fp6_ptr = fp6_tensor.data_ptr<uint8_t>();
+    int n = fp6_tensor.numel();
+
+    if (dtype == torch::kFloat32) {
+        uint32_t *fp32_ptr = reinterpret_cast<uint32_t *>(fp_tensor.data_ptr());
+        from_float6_e3m2_packed_cpu_impl<uint32_t, FP32_SPEC>(fp6_ptr, fp32_ptr, n);
+
+    } else if (dtype == torch::kFloat16) {
+        uint16_t *fp16_ptr = reinterpret_cast<uint16_t *>(fp_tensor.data_ptr());
+        from_float6_e3m2_packed_cpu_impl<uint16_t, FP16_SPEC>(fp6_ptr, fp16_ptr, n);
+
+    } else if (dtype == torch::kBFloat16) {
+        uint16_t *bf16_ptr = reinterpret_cast<uint16_t *>(fp_tensor.data_ptr());
+        from_float6_e3m2_packed_cpu_impl<uint16_t, BF16_SPEC>(fp6_ptr, bf16_ptr, n);
+
+    } else {
+        throw std::invalid_argument("Only FP32, FP16, and BF16 inputs are accepted.");
+    }
+
+    return fp_tensor;
+}
+
+TORCH_LIBRARY_IMPL(torchao, CPU, m) {
+    m.impl("torchao::to_float6_e3m2_unpacked_cpu", &to_float6_e3m2_unpacked_cpu);
+    m.impl("torchao::to_float6_e3m2_packed_cpu", &to_float6_e3m2_packed_cpu);
+    
m.impl("torchao::from_float6_e3m2_unpacked_cpu", &from_float6_e3m2_unpacked_cpu); + m.impl("torchao::from_float6_e3m2_packed_cpu", &from_float6_e3m2_packed_cpu); +} + +} diff --git a/torchao/csrc/fp6_llm/fp6_llm.cpp b/torchao/csrc/fp6_llm/fp6_llm.cpp index 794c79df1..5239593bb 100644 --- a/torchao/csrc/fp6_llm/fp6_llm.cpp +++ b/torchao/csrc/fp6_llm/fp6_llm.cpp @@ -6,6 +6,10 @@ TORCH_LIBRARY_FRAGMENT(torchao, m) { m.impl_abstract_pystub("torchao.ops"); m.def("fp16act_fp6weight_linear(Tensor _in_feats, Tensor _weights, Tensor _scales, int splitK) -> Tensor"); m.def("prepack_fp6_weight(Tensor fp6_tensor) -> Tensor"); - m.def("fp16_to_fp6(Tensor fp16_tensor) -> Tensor"); - m.def("fp6_weight_dequant(Tensor fp6_tensor, Tensor fp16_scale) -> Tensor"); + m.def("fp16_to_fp6_original(Tensor fp16_tensor) -> Tensor"); + + m.def("to_float6_e3m2_unpacked_cpu(Tensor tensor) -> Tensor"); + m.def("to_float6_e3m2_packed_cpu(Tensor tensor) -> Tensor"); + m.def("from_float6_e3m2_unpacked_cpu(Tensor tensor, ScalarType dtype) -> Tensor"); + m.def("from_float6_e3m2_packed_cpu(Tensor tensor, ScalarType dtype) -> Tensor"); } diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index dccd22f3d..d12a6da56 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -1,6 +1,7 @@ from .nf4tensor import NF4Tensor, to_nf4 from .uint4 import UInt4Tensor from .aqt import AffineQuantizedTensor, to_aq +from .float6_e3m2 import to_float6_e3m2, from_float6_e3m2 __all__ = [ "NF4Tensor", @@ -8,4 +9,6 @@ "UInt4Tensor" "AffineQuantizedTensor", "to_aq", + "to_float6_e3m2", + "from_float6_e3m2", ] diff --git a/torchao/dtypes/float6_e3m2.py b/torchao/dtypes/float6_e3m2.py new file mode 100644 index 000000000..0c27838d0 --- /dev/null +++ b/torchao/dtypes/float6_e3m2.py @@ -0,0 +1,178 @@ +import torch +from torch import Tensor +from torch.utils._triton import has_triton +from torchao.ops import to_float6_e3m2_packed_cpu, to_float6_e3m2_unpacked_cpu, from_float6_e3m2_packed_cpu, from_float6_e3m2_unpacked_cpu + + +# some useful constants +FLOAT6_E3M2_MAX = 28.0 +FLOAT6_E3M2_SMALLEST_SUBNORMAL = 0.0625 + + +if has_triton(): + import triton + from triton import language as tl + + # see _to_float6_e3m2_pt() for explanation + @triton.jit + def _triton_float32_to_float6_e3m2(x: tl.tensor): + x = x.to(tl.float32) + x = x * 2.0 ** (-127 + 3) + bits = x.to(tl.int32, bitcast=True) + + sign = ((bits >> 31) & 0x1) << 5 + exp_and_man = (bits >> 21) & 0x1F + result = sign | exp_and_man + + remainder = bits & 0x1F_FFFF + do_round_up = (remainder > 0x10_0000) | ((remainder == 0x10_0000) & ((result & 1) == 1)) + result = tl.where(do_round_up, result + 1, result) + return result.to(tl.uint8) + + @triton.jit + def _to_float6_e3m2_triton_kernel(in_ptr, out_ptr, n, BLOCK_SIZE: tl.constexpr): + offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n + + # strided memory read. there will be uncoalesced memory access + val0 = _triton_float32_to_float6_e3m2(tl.load(in_ptr + offsets * 4, mask)) + val1 = _triton_float32_to_float6_e3m2(tl.load(in_ptr + offsets * 4 + 1, mask)) + val2 = _triton_float32_to_float6_e3m2(tl.load(in_ptr + offsets * 4 + 2, mask)) + val3 = _triton_float32_to_float6_e3m2(tl.load(in_ptr + offsets * 4 + 3, mask)) + + # bit packing + bits0 = (val0 << 2) | (val1 >> 4) # 0000 0011 + bits1 = (val1 << 4) | (val2 >> 2) # 1111 2222 + bits2 = (val2 << 6) | (val3); # 2233 3333 + + # strided memory write. 
there will be uncoalesced memory access
+        tl.store(out_ptr + offsets * 3, bits0, mask)
+        tl.store(out_ptr + offsets * 3 + 1, bits1, mask)
+        tl.store(out_ptr + offsets * 3 + 2, bits2, mask)
+
+    def _to_float6_e3m2_triton(tensor: Tensor) -> Tensor:
+        out_shape = tensor.shape[:-1] + (tensor.shape[-1] // 4 * 3,)
+        output = torch.empty(out_shape, device=tensor.device, dtype=torch.uint8)
+
+        n = tensor.numel()
+        grid_size = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"] * 4),)
+        _to_float6_e3m2_triton_kernel[grid_size](tensor, output, n, BLOCK_SIZE=256)
+
+        return output
+
+else:
+    _to_float6_e3m2_triton = None
+
+
+# NOTE: This implementation requires FP32 denormal numbers to be handled correctly.
+# On CPU, denormal numbers might be flushed to zero for performance gain (FTZ and DAZ flags).
+def _to_float6_e3m2_pt(tensor: Tensor, no_bit_packing: bool = False) -> Tensor:
+    tensor = tensor.float()
+
+    # correct exponent bias. this also handles subnormal numbers correctly
+    tensor = tensor * 2.0 ** (-127 + 3)
+    bits = tensor.view(torch.int32)
+
+    sign = ((bits >> 31) & 0x1) << 5
+    exp_and_man = (bits >> 21) & 0x1F
+    result = sign | exp_and_man
+
+    # round to nearest even
+    remainder = bits & 0x1F_FFFF  # truncated mantissa bits
+    do_round_up = (remainder > 0x10_0000) | ((remainder == 0x10_0000) & ((result & 1) == 1))
+    result = torch.where(do_round_up, result + 1, result)
+    result = result.to(torch.uint8)
+
+    if no_bit_packing:
+        return result
+
+    # bit packing
+    val0, val1, val2, val3 = result.unflatten(-1, (-1, 4)).unbind(-1)
+    bits0 = (val0 << 2) | (val1 >> 4)  # 0000 0011
+    bits1 = (val1 << 4) | (val2 >> 2)  # 1111 2222
+    bits2 = (val2 << 6) | (val3)       # 2233 3333
+    return torch.stack([bits0, bits1, bits2], dim=-1).flatten(-2)
+
+
+def to_float6_e3m2(tensor: Tensor, no_bit_packing: bool = False) -> Tensor:
+    """Convert input tensor to FP6. This particular FP6 format has 3 exponent bits and 2 mantissa
+    bits. By default, bit packing is performed: every 4 FP6 values are packed as 3 uint8 values
+    (4 x 6 bits = 3 x 8 bits).
+
+    Args:
+        tensor: Input tensor. The last dimension must be divisible by 4 (unless ``no_bit_packing=True``).
+        no_bit_packing: Whether to skip bit packing. Setting this to ``True`` can be useful for
+            observing the bit patterns and debugging.
+
+    Returns:
+        :class:`torch.Tensor`: FP6 tensor, stored as uint8 data. If ``no_bit_packing=False``, the last
+        dimension of the output tensor is 3/4 of that of the input tensor.
+
+    Note:
+        This FP6 format does not represent +/-inf and NaN. Thus, make sure that the input tensor has
+        no +/-inf or NaN values, and no values with magnitude >= 30 (the largest FP6 value is 28;
+        values in [28, 30) are rounded down to 28, while values >= 30 overflow).
+
+    See also :func:`from_float6_e3m2`
+    """
+    if not no_bit_packing:
+        assert tensor.shape[-1] % 4 == 0, "Last dim must be divisible by 4"
+
+    if tensor.is_cpu:
+        if no_bit_packing:
+            return to_float6_e3m2_unpacked_cpu(tensor)
+
+        *leading_dims, last_dim = tensor.shape
+        return to_float6_e3m2_packed_cpu(tensor.view(-1, last_dim)).view(*leading_dims, -1)
+
+    # torch.compile() cannot generate a fused bit-packing triton kernel,
+    # thus we write a custom triton kernel for this specific case.
+    if tensor.is_cuda and not no_bit_packing and _to_float6_e3m2_triton is not None:
+        return _to_float6_e3m2_triton(tensor)
+
+    else:
+        return _to_float6_e3m2_pt(tensor, no_bit_packing=no_bit_packing)
+
+
+# NOTE: This implementation requires FP32 denormal numbers to be handled correctly.
+# On CPU, denormal numbers might be flushed to zero for performance gain (FTZ and DAZ flags). +def _pt_float6_e3m2_to_float32(tensor: Tensor) -> Tensor: + bits = tensor.to(torch.int32) # bit extension + sign = bits >> 5 << 31 + exp_and_man = (bits & 0x1F) << 21 + results = sign | exp_and_man + + results = results.view(torch.float32) + return results * 2.0 ** (127 - 3) # exponent bias correction + + +def from_float6_e3m2(tensor: Tensor, no_bit_packing: bool = False, dtype: torch.dtype = torch.float32) -> Tensor: + """Convert an FP6 tensor (created by :func:`to_float6_e3m2`) to FP32. + + Args: + tensor: FP6 tensor, stored as uint8 data. If ``no_bit_packing=False``, the last dimension must + be divisible by 3. + no_bit_packing: whether the input does not have bit packing. + dtype: returned dtype. + + Returns: + :class:`torch.Tensor`: FP32 tensor. If ``no_bit_packing=False``, the last dimension of output + tensor is 4/3 of that of input tensor. + """ + assert tensor.dtype == torch.uint8 + if no_bit_packing: + if tensor.is_cpu: + return from_float6_e3m2_unpacked_cpu(tensor, dtype) + + return _pt_float6_e3m2_to_float32(tensor).to(dtype) + + assert tensor.shape[-1] % 3 == 0, "Last dim must be divisible by 3" + if tensor.is_cpu: + return from_float6_e3m2_packed_cpu(tensor, dtype) + + bits0, bits1, bits2 = tensor.unflatten(-1, (-1, 3)).unbind(-1) + val0 = _pt_float6_e3m2_to_float32(bits0 >> 2).to(dtype) + val1 = _pt_float6_e3m2_to_float32(((bits0 & 0x3) << 4) | (bits1 >> 4)).to(dtype) + val2 = _pt_float6_e3m2_to_float32(((bits1 & 0xF) << 2) | (bits2 >> 6)).to(dtype) + val3 = _pt_float6_e3m2_to_float32(bits2 & 0x3F).to(dtype) + return torch.stack([val0, val1, val2, val3], dim=-1).flatten(-2) diff --git a/torchao/ops.py b/torchao/ops.py index 05a166839..7fce2de22 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -2,6 +2,7 @@ from torch import Tensor from torchao.quantization.utils import TORCH_VERSION_AFTER_2_4 + def register_custom_op(name): def decorator(func): if TORCH_VERSION_AFTER_2_4: @@ -11,7 +12,6 @@ def decorator(func): return decorator - def prepack_fp6_weight(fp6_weight: Tensor) -> Tensor: """ Pack FP6 tensor in a layout for use with FP6-LLM. See https://arxiv.org/abs/2401.14112 for more details. @@ -32,14 +32,20 @@ def _(fp6_weight): return torch.empty_like(fp6_weight) -def fp16_to_fp6(fp16_tensor: Tensor) -> Tensor: +def fp16_to_fp6_original(fp16_tensor: Tensor) -> Tensor: """ - Pack FP16 tensor (containing only FP6 values) into FP6 tensor. + Pack FP16 tensor to FP6 tensor. qtorch is required to use this function. 
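+
+    The input is first rounded to FP6 (E3M2) values with ``qtorch.quant.float_quantize`` and then
+    packed by the original FP6-LLM CPU kernel.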
""" - return torch.ops.torchao.fp16_to_fp6.default(fp16_tensor) + try: + from qtorch.quant import float_quantize + except ImportError as e: + raise RuntimeError("Please install qtorch to use this function") from e + + fp16_tensor = float_quantize(fp16_tensor.float(), 3, 2, rounding="nearest").half() + return torch.ops.torchao.fp16_to_fp6_original.default(fp16_tensor) -@register_custom_op("torchao::fp16_to_fp6") +@register_custom_op("torchao::fp16_to_fp6_original") def _(fp16_tensor): torch._check(fp16_tensor.dim() == 2, lambda: f"weight should be a 2d tensor, got {fp16_tensor.dim()}D") torch._check(fp16_tensor.dtype is torch.float16, lambda: f"weight must be FP16, got {fp16_tensor.dtype}") @@ -81,18 +87,17 @@ def _(_in_feats, _weights, _scales, splitK = 1): return _in_feats.new_empty((BS, OC)) -def fp6_weight_dequant(fp6_tensor: Tensor, fp16_scale: Tensor) -> Tensor: - return torch.ops.torchao.fp6_weight_dequant.default(fp6_tensor, fp16_scale) +def to_float6_e3m2_unpacked_cpu(tensor: Tensor) -> Tensor: + return torch.ops.torchao.to_float6_e3m2_unpacked_cpu.default(tensor) + + +def to_float6_e3m2_packed_cpu(tensor: Tensor) -> Tensor: + return torch.ops.torchao.to_float6_e3m2_packed_cpu.default(tensor) -@register_custom_op("torchao::fp6_weight_dequant") -def _(fp6_tensor, fp16_scale): - torch._check(fp6_tensor.dim() == 2, lambda: f"weight should be a 2d tensor, got {fp6_tensor.dim()}D") - torch._check(fp6_tensor.dtype is torch.int32, lambda: f"weight must be INT32, got {fp6_tensor.dtype}") - torch._check(fp16_scale.dim() == 1, lambda: f"scale should be a 2d tensor, got {fp16_scale.dim()}D") - torch._check(fp16_scale.dtype is torch.float16, lambda: f"scale must be FP16, got {fp16_scale.dtype}") +def from_float6_e3m2_unpacked_cpu(tensor: Tensor, dtype: torch.dtype) -> Tensor: + return torch.ops.torchao.from_float6_e3m2_unpacked_cpu.default(tensor, dtype) - OC, _IC = fp6_tensor.shape - torch._check(OC == fp16_scale.shape[0], lambda: "Dimensions mismatched") - return fp16_scale.new_empty((OC, _IC * 16 // 3)) +def from_float6_e3m2_packed_cpu(tensor: Tensor, dtype: torch.dtype) -> Tensor: + return torch.ops.torchao.from_float6_e3m2_packed_cpu.default(tensor, dtype)