Improve primitives for FP6 quant (pytorch#248)
gau-nernst authored May 25, 2024
1 parent a7bc592 commit 4ca3985
Showing 12 changed files with 679 additions and 117 deletions.
3 changes: 3 additions & 0 deletions dev-requirements.txt
@@ -12,3 +12,6 @@ pandas

# Custom CUDA Extensions
ninja

# for FP6-LLM (can be removed once we remove fp16_to_fp6_original())
qtorch
2 changes: 2 additions & 0 deletions docs/source/api_ref_dtypes.rst
@@ -12,6 +12,8 @@ torchao.dtypes

to_nf4
UInt4Tensor
to_float6_e3m2
from_float6_e3m2

..
_NF4Tensor - add after fixing torchao/dtypes/nf4tensor.py:docstring
3 changes: 2 additions & 1 deletion setup.py
@@ -46,11 +46,12 @@ def get_extensions():
use_cuda = torch.cuda.is_available() and CUDA_HOME is not None
extension = CUDAExtension if use_cuda else CppExtension

extra_link_args = []
extra_link_args = ["-fopenmp"]
extra_compile_args = {
"cxx": [
"-O3" if not debug_mode else "-O0",
"-fdiagnostics-color=always",
"-fopenmp",
],
"nvcc": [
"-O3" if not debug_mode else "-O0",
127 changes: 127 additions & 0 deletions test/dtypes/test_float6_e3m2.py
@@ -0,0 +1,127 @@
import torch
from torch.testing._internal.common_utils import (
TestCase,
instantiate_parametrized_tests,
parametrize,
run_tests,
)
from torchao.dtypes.float6_e3m2 import to_float6_e3m2, from_float6_e3m2


_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])


class TestFp6(TestCase):

@parametrize("device", _DEVICES)
@parametrize("dtype", _DTYPES)
@parametrize(
"input_output",
[
(0.0, 0b000000), # exact values
(1.0, 0b001100), # normal numbers
(1.25, 0b001101),
(28.0, 0b011111), # max
(0.1875, 0b000011), # subnormal number
(0.0625, 0b000001), # min
(29.0, 0b011111), # normal round down
(26.0, 0b011110), # normal round to nearest even
(0.1251, 0b000010), # subnormal round down
(0.0314, 0b000001), # subnormal round up
(0.03, 0b000000), # underflow
],
)
def test_to_float6_e3m2_no_bit_packing_correctness(self, device, dtype, input_output):
input, output = input_output
input = torch.tensor(input, device=device, dtype=dtype)
assert to_float6_e3m2(input, no_bit_packing=True).item() == output

@parametrize("device", _DEVICES)
@parametrize("dtype", _DTYPES)
def test_to_float6_e3m2_bit_packing_correctness(self, device, dtype):
x = torch.randn(128, 128, device=device, dtype=dtype)
results_unpacked = to_float6_e3m2(x, no_bit_packing=True)
results_packed = to_float6_e3m2(x)

val0, val1, val2, val3 = results_unpacked.unflatten(-1, (-1, 4)).unbind(-1)
bits0 = (val0 << 2) | (val1 >> 4) # 0000 0011
bits1 = (val1 << 4) | (val2 >> 2) # 1111 2222
bits2 = (val2 << 6) | (val3); # 2233 3333

expected_packed = torch.stack([bits0, bits1, bits2], dim=-1).flatten(-2)
assert (results_packed == expected_packed).all()

@parametrize("device", _DEVICES)
@parametrize("shape", [(), (0,), (10,), (20, 20)])
def test_to_float6_e3m2_no_bit_packing_shape(self, device, shape):
x = torch.randn(shape, device=device)
result = to_float6_e3m2(x, no_bit_packing=True)
assert result.shape == shape

@parametrize("device", _DEVICES)
@parametrize("shape", [(4,), (20, 20)])
def test_to_float6_e3m2_bit_packing_shape(self, device, shape):
x = torch.randn(shape, device=device)
result = to_float6_e3m2(x)
assert result.shape == shape[:-1] + (shape[-1] // 4 * 3,)

@parametrize("device", _DEVICES)
@parametrize("dtype", _DTYPES)
@parametrize("no_bit_packing", [False, True])
def test_to_float6_e3m2_compile(self, device, dtype, no_bit_packing):
x = torch.randn(20, 20, device=device, dtype=dtype)
expected = to_float6_e3m2(x, no_bit_packing=no_bit_packing)

to_float6_e3m2_compiled = torch.compile(to_float6_e3m2)
actual = to_float6_e3m2_compiled(x, no_bit_packing=no_bit_packing)
torch.testing.assert_close(actual, expected)

@parametrize("device", _DEVICES)
@parametrize(
"input_output",
[
(0b000000, 0.0),
(0b001100, 1.0),
(0b011111, 28.0), # max
(0b000001, 0.0625), # min
(0b001110, 1.5),
(0b000011, 0.1875), # subnormal
],
)
def test_from_float6_e3m2_no_bit_packing_correctness(self, device, input_output):
input, output = input_output
input = torch.tensor(input, device=device, dtype=torch.uint8)
assert from_float6_e3m2(input, no_bit_packing=True).item() == output

@parametrize("device", _DEVICES)
def test_from_float6_e3m2_bit_packing_correctness(self, device):
x = torch.randint(256, (128, 128 // 4 * 3), device=device, dtype=torch.uint8)
actual = from_float6_e3m2(x)

bits0, bits1, bits2 = x.unflatten(-1, (-1, 3)).unbind(-1)
x_unpacked0 = bits0 >> 2
x_unpacked1 = ((bits0 & 0x3) << 4) | (bits1 >> 4)
x_unpacked2 = ((bits1 & 0xF) << 2) | (bits2 >> 6)
x_unpacked3 = bits2 & 0x3F

x_unpacked = torch.stack([x_unpacked0, x_unpacked1, x_unpacked2, x_unpacked3], dim=-1).flatten(-2)
expected = from_float6_e3m2(x_unpacked, no_bit_packing=True)
torch.testing.assert_close(actual, expected)

@parametrize("device", _DEVICES)
@parametrize("no_bit_packing", [False, True])
def test_from_float6_e3m2_compile(self, device, no_bit_packing):
x = torch.randint(256, size=(20, 15), device=device, dtype=torch.uint8)
expected = from_float6_e3m2(x, no_bit_packing=no_bit_packing)

from_float6_e3m2_compiled = torch.compile(from_float6_e3m2)
actual = from_float6_e3m2_compiled(x, no_bit_packing=no_bit_packing)
torch.testing.assert_close(actual, expected)


instantiate_parametrized_tests(TestFp6)


if __name__ == "__main__":
run_tests()
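Illustration (editor's sketch, not part of this commit): the packing layout these tests check stores four 6-bit E3M2 codes in three bytes, laid out as 0000 0011 | 1111 2222 | 2233 3333. The snippet below reproduces the arithmetic from test_to_float6_e3m2_bit_packing_correctness, using explicit masks in place of the test's wrap-around uint8 shifts.

import torch
from torchao.dtypes.float6_e3m2 import to_float6_e3m2

x = torch.randn(8, 8)
codes = to_float6_e3m2(x, no_bit_packing=True)  # one 6-bit code per value (uint8 codes, per the tests above)
packed = to_float6_e3m2(x)                      # 3 bytes per 4 values -> shape (8, 6)

# Pack the codes by hand: byte0 = all of code0 + top 2 bits of code1, and so on.
c0, c1, c2, c3 = codes.unflatten(-1, (-1, 4)).unbind(-1)
b0 = (c0 << 2) | (c1 >> 4)          # 0000 0011
b1 = ((c1 & 0xF) << 4) | (c2 >> 2)  # 1111 2222
b2 = ((c2 & 0x3) << 6) | c3         # 2233 3333
assert (packed == torch.stack([b0, b1, b2], dim=-1).flatten(-2)).all()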
34 changes: 9 additions & 25 deletions test/test_ops.py
@@ -50,24 +50,21 @@ def test_prepack_fp6_weight(self):
opcheck(torch.ops.torchao.prepack_fp6_weight, (fp6_weight,), test_utils=test_utils)

@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_fp16_to_fp6(self):
def test_fp16_to_fp6_original(self):
OC = 256
IC = 256

# in this fp6, we use 3 bits for exponent and 2 bits for mantissa
# also, we don't have nan/inf
fp6_absmax = 28.0 # 2 ** (0b111 - 0b011) * (1 + 0.5 + 0.25), where E=111, M=11
fp6_absmin = 0.0625 # 2 ** (-0b010) * 0.25, where E=000, M=01 (subnormal number)
fp16_weight = torch.randn((OC, IC), dtype=torch.float16)
fp16_weight.clip_(-fp6_absmax, fp6_absmax)
fp16_weight[fp16_weight.abs() < fp6_absmin] = 0

# the original FP16->FP6 kernel checks for overflow/underflow
fp16_weight.clip_(-28.0, 28.0)
fp16_weight[fp16_weight.abs() < 0.0625] = 0.0

# smoke test
torchao.ops.fp16_to_fp6(fp16_weight)
torchao.ops.fp16_to_fp6_original(fp16_weight)

# comprehensive testing
test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
opcheck(torch.ops.torchao.fp16_to_fp6, (fp16_weight,), test_utils=test_utils)
opcheck(torch.ops.torchao.fp16_to_fp6_original, (fp16_weight,), test_utils=test_utils)

@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_fp16act_fp6weight_linear(self):
@@ -89,19 +86,6 @@ def test_fp16act_fp6weight_linear(self):
test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
opcheck(torch.ops.torchao.fp16act_fp6weight_linear, (act_cuda, weight_cuda, scale_cuda, splitK), test_utils=test_utils)

@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_fp6_weight_dequant(self):
OC = 256
IC = 256
fp6_weight, fp16_scale, _ = self._create_fp6_inputs(0, OC, IC)

# smoke test
torchao.ops.fp6_weight_dequant(fp6_weight, fp16_scale)

# comprehensive testing
test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"]
opcheck(torch.ops.torchao.fp6_weight_dequant, (fp6_weight, fp16_scale), test_utils=test_utils)

# adapted from https://github.com/usyd-fsalab/fp6_llm/blob/main/tests/python/kernel_test.py
@parameterized.expand([(1, 2048, 4096, 5), (2, 8192, 8192, 6)])
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@@ -115,8 +99,8 @@ def test_fp6_matmul_correctness(self, BS, OC, IC, splitK):

results_fp6 = torchao.ops.fp16act_fp6weight_linear(act_cuda, weight_cuda, scale_cuda, splitK)

fp16_weight = torchao.ops.fp6_weight_dequant(fp6_weight, fp16_scale).cuda()
results_fp16 = act_cuda @ fp16_weight.T
fp16_weight = torchao.dtypes.from_float6_e3m2(fp6_weight.view(torch.uint8), dtype=torch.float16) * fp16_scale[:, None]
results_fp16 = act_cuda @ fp16_weight.cuda().T

error = (results_fp6 - results_fp16).abs()
relative_error = error / results_fp16.abs()
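For reference (an editor's worked example, not part of the diff): the clip values 28.0 and 0.0625 used above follow from the E3M2 layout (1 sign bit, 3 exponent bits with bias 3, 2 mantissa bits, no inf/nan):

    max normal      0b0 111 11 -> 2**(7 - 3) * (1 + 0.5 + 0.25) = 28.0
    min subnormal   0b0 000 01 -> 2**(1 - 3) * 0.25             = 0.0625
    example         0b0 011 01 -> 2**(3 - 3) * (1 + 0.25)       = 1.25

These match the parametrized cases in test/dtypes/test_float6_e3m2.py above.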
13 changes: 7 additions & 6 deletions torchao/__init__.py
@@ -1,9 +1,3 @@
from torchao.quantization import (
apply_weight_only_int8_quant,
apply_dynamic_quant,
autoquant,
)
from . import dtypes
import torch
_IS_FBCODE = (
hasattr(torch._utils_internal, "IS_FBSOURCE") and
@@ -14,6 +8,13 @@
from . import _C
from . import ops

from torchao.quantization import (
apply_weight_only_int8_quant,
apply_dynamic_quant,
autoquant,
)
from . import dtypes

__all__ = [
"dtypes",
"apply_dynamic_quant",
69 changes: 2 additions & 67 deletions torchao/csrc/cuda/fp6_llm/weight_quant.cu
@@ -13,7 +13,6 @@
// limitations under the License.
//
// This file is adapted from https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/utils/weight_quant.h
// and https://github.com/usyd-fsalab/fp6_llm/blob/ce76774bcfc26b325c1b558abcf1935026d9abbc/fp6_llm/csrc/utils/weight_dequant.h

#include <cuda_fp16.h>
#include <iostream>
@@ -120,49 +119,14 @@ void weight_prepacking_fp16_to_fp6(uint16_t* weight_16bit,
}
}

void DeQuantMatrix_FP6_To_FP16(half* A_16bit_h, unsigned char* A_6bit_h, size_t M, size_t K, half* scale) {
assert(M%64==0); // Currently, M must be a multiple of 64.
assert(K%64==0); // Currently, K must be a multiple of 64.
size_t TotalSizeInByte = M*K*6/8;
//
half* OutPTR = A_16bit_h;
for(size_t i=0; i<TotalSizeInByte/3; i++) { // 4 FP6 = 3 Bytes for each Loop
unsigned char B1 = A_6bit_h[i*3+0] & 0xfc;
B1 = (B1&0x80) | ((B1>>2)&0x1f);
unsigned char B2 = (A_6bit_h[i*3+0]<<6) | ((A_6bit_h[i*3+1]>>2)&0xfc);
B2 = (B2&0x80) | ((B2>>2)&0x1f);
unsigned char B3 = (A_6bit_h[i*3+1]<<4) | ((A_6bit_h[i*3+2]>>4)&0xfc);
B3 = (B3&0x80) | ((B3>>2)&0x1f);
unsigned char B4 = A_6bit_h[i*3+2]<<2;
B4 = (B4&0x80) | ((B4>>2)&0x1f);
half FP1, FP2, FP3, FP4;
unsigned char *PTR1, *PTR2, *PTR3, *PTR4;
PTR1 = reinterpret_cast<unsigned char*>(&FP1);
PTR2 = reinterpret_cast<unsigned char*>(&FP2);
PTR3 = reinterpret_cast<unsigned char*>(&FP3);
PTR4 = reinterpret_cast<unsigned char*>(&FP4);
PTR1[0] = 0; PTR1[1] = B1; // small endian for X86 CPU
PTR2[0] = 0; PTR2[1] = B2;
PTR3[0] = 0; PTR3[1] = B3;
PTR4[0] = 0; PTR4[1] = B4;
OutPTR[0] = __float2half_rn ( __half2float(FP1) * 4096.0f * __half2float(scale[(4*i)/K]) );
OutPTR[1] = __float2half_rn ( __half2float(FP2) * 4096.0f * __half2float(scale[(4*i)/K]) );
OutPTR[2] = __float2half_rn ( __half2float(FP3) * 4096.0f * __half2float(scale[(4*i)/K]) );
OutPTR[3] = __float2half_rn ( __half2float(FP4) * 4096.0f * __half2float(scale[(4*i)/K]) );
//
OutPTR +=4;
}
}


#include <torch/extension.h>
#include <ATen/ATen.h>
#include <torch/library.h>

namespace torchao {

// https://github.com/microsoft/DeepSpeed/blob/0fc19b6a320cf8aa0a5f6c2b1fa310bae9a70d94/deepspeed/inference/v2/kernels/core_ops/cuda_linear/linear_kernels.cpp#L194
at::Tensor fp16_to_fp6_cpu(at::Tensor fp16_tensor)
at::Tensor fp16_to_fp6_original_cpu(at::Tensor fp16_tensor)
{
TORCH_CHECK(fp16_tensor.dim() == 2, "weight must be 2-dimensional");
TORCH_CHECK(fp16_tensor.scalar_type() == torch::kFloat16, "weight must be FP16");
@@ -183,37 +147,8 @@ at::Tensor fp16_to_fp6_cpu(at::Tensor fp16_tensor)
return packed_fp6_tensor;
}

/*
* Dequant a FP6 matrix to a equivalent FP16 matrix using CPUs.
* A useful tool to construct input matrices for the FP16 GEMM baseline.
* [Input]
* fp6_tensor: int tensor of shape [OC, IC // 16 * 3]; // 3 INT32 words contains 16 FP6 weights.
* fp16_scale: half tensor of shape [OC]; // for row-wise quantization.
* [Output]
* fp16_tensor: half tensor of shape [OC, IC].
*/
at::Tensor weight_matrix_dequant_cpu(at::Tensor fp6_tensor, at::Tensor fp16_scale)
{
int OC = fp6_tensor.size(0);
TORCH_CHECK(fp6_tensor.size(1) % 3 == 0);
int IC = fp6_tensor.size(1) / 3 * 16;
TORCH_CHECK(fp16_scale.size(0) == OC);
//
auto fp6_tensor_ptr = reinterpret_cast<unsigned char*>(fp6_tensor.data_ptr<int>());
auto fp16_scale_ptr = reinterpret_cast<half*>(fp16_scale.data_ptr<at::Half>());
//
auto options = at::TensorOptions().dtype(at::kHalf).device(fp16_scale.device());
at::Tensor fp16_tensor = at::empty({OC, IC}, options);
auto fp16_tensor_ptr = reinterpret_cast<half*>(fp16_tensor.data_ptr<at::Half>());
//
DeQuantMatrix_FP6_To_FP16(fp16_tensor_ptr, fp6_tensor_ptr, OC, IC, fp16_scale_ptr);
//
return fp16_tensor;
}

TORCH_LIBRARY_IMPL(torchao, CPU, m) {
m.impl("torchao::fp16_to_fp6", &fp16_to_fp6_cpu);
m.impl("torchao::fp6_weight_dequant", &weight_matrix_dequant_cpu);
m.impl("torchao::fp16_to_fp6_original", &fp16_to_fp6_original_cpu);
}

}
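Editor's note on the removed DeQuantMatrix_FP6_To_FP16 above: it rebuilds each FP16 value by copying the FP6 sign, exponent, and mantissa bits into the high byte of a half, which leaves the exponent interpreted with FP16's bias of 15 rather than E3M2's bias of 3; the constant 4096 compensates for that bias gap, since 2**(15 - 3) = 4096. FP6-to-FP16 dequantization is now handled by torchao.dtypes.from_float6_e3m2, with the row-wise scale applied separately (see the updated test_fp6_matmul_correctness above).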