add fp6->fp16
gau-nernst committed May 17, 2024
1 parent f61aa37 commit 1640bbf
Showing 1 changed file with 42 additions and 0 deletions.
42 changes: 42 additions & 0 deletions torchao/csrc/cuda/fp6_llm/weight_quant.cu
@@ -63,6 +63,27 @@ __device__ __host__ uint8_t fp16_to_fp6(const __half a) {
    return result;
}

// assume the lower 6 bits contain the data
__device__ __host__ __half fp6_to_fp16(const uint8_t a) {
    // shift the bits so that sign, exponent, and mantissa bits land in their
    // correct positions in FP16
    // FP6:     SE EEMM
    // FP16: S00E EEMM 0000 0000
    uint16_t sign = (a << 10u) & 0x8000u;
    uint16_t exp_and_man = (a & 0x1Fu) << 8u;
    uint16_t result_bits = sign | exp_and_man;
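    // worked example (illustrative comment, not in the original commit):
    //   a = 0b001101 encodes 2^(3-3) * (1 + 0.25) = 1.25
    //   sign        = (0b001101 << 10) & 0x8000 = 0
    //   exp_and_man = (0b001101 & 0x1F) << 8    = 0x0D00
    //   result_bits = 0x0D00 = FP16 for 2^(3-15) * 1.25, fixed up below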

    // the result is off by the difference in exponent bias
    // FP6:  Ebias = 011 (binary)   = 3
    // FP16: Ebias = 01111 (binary) = 15
    // correction factor = 2^(15-3) = 2^12 = 4096
    // we can apply it with a single FP16 multiplication; this also handles
    // FP6 subnormals, which map to FP16 subnormals scaled by the same factor
    __half result;
    std::memcpy(&result, &result_bits, sizeof(result));
    return result * __float2half(4096.0f);
}
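A round trip through fp16_to_fp6 (defined earlier in this file) is a quick way to verify the bit manipulation. The following host-side snippet is a hypothetical sanity check, not part of this commit; main, the includes, and the test value are assumptions for illustration:

// hypothetical sanity check, not part of the commit
#include <cassert>
#include <cstdint>
#include <cuda_fp16.h>

int main() {
    uint8_t fp6 = 0b001101;               // sign=0, exp=011, man=01 -> 1.25
    __half fp16 = fp6_to_fp16(fp6);
    assert(__half2float(fp16) == 1.25f);  // 1.25 is exact in FP6, so no rounding
    assert(fp16_to_fp6(fp16) == fp6);     // lossless round trip for exact values
    return 0;
}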

/*
 * Function to pack 4 fake-quantized FP16 values into 4 continuously stored FP6 values.
 */
@@ -373,11 +394,32 @@ at::Tensor fp16_to_fp6_packed_cuda(at::Tensor fp16_tensor) {
    return fp6_tensor;
}

at::Tensor fp6_unpacked_to_fp16_cpu(at::Tensor fp6_tensor) {
    TORCH_CHECK(fp6_tensor.dtype() == torch::kUInt8);
    TORCH_CHECK(fp6_tensor.is_contiguous());
    TORCH_CHECK(fp6_tensor.is_cpu());

    at::TensorOptions options = at::TensorOptions().dtype(torch::kFloat16).device(fp6_tensor.device());
    at::Tensor fp16_tensor = at::empty(fp6_tensor.sizes(), options);

    const uint8_t *fp6_ptr = fp6_tensor.data_ptr<uint8_t>();
    __half *fp16_ptr = reinterpret_cast<__half *>(fp16_tensor.data_ptr<at::Half>());
    int n = fp6_tensor.numel();

    // each input byte holds one FP6 value in its lower 6 bits
    #pragma omp parallel for num_threads(4)
    for (int i = 0; i < n; i++) {
        fp16_ptr[i] = fp6_to_fp16(fp6_ptr[i]);
    }

    return fp16_tensor;
}
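For reference, a hypothetical caller on the C++ side (the shape and the random payloads are illustrative; once registered below, the op is also callable from Python as torch.ops.torchao.fp6_unpacked_to_fp16):

// hypothetical usage, assuming this translation unit is linked in
at::Tensor fp6 = at::randint(0, 64, {256, 1024}, at::kByte);  // one FP6 payload per byte, low 6 bits
at::Tensor fp16 = fp6_unpacked_to_fp16_cpu(fp6);              // same shape, dtype float16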

TORCH_LIBRARY_IMPL(torchao, CPU, m) {
    m.impl("torchao::fp16_to_fp6_original", &fp16_to_fp6_original_cpu);
    m.impl("torchao::fp6_weight_dequant", &weight_matrix_dequant_cpu);
    m.impl("torchao::fp16_to_fp6_unpacked", &fp16_to_fp6_unpacked_cpu);
    m.impl("torchao::fp16_to_fp6_packed", &fp16_to_fp6_packed_cpu);
    m.impl("torchao::fp6_unpacked_to_fp16", &fp6_unpacked_to_fp16_cpu);
}

TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
