Skip to content

Commit

Permalink
Remove magic numbers from fbgemm/Types.h (pytorch#1629)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#1629

Replaces magic numbers with constexpr variables

Reviewed By: sryap

Differential Revision: D43776442

fbshipit-source-id: 5cef7566816f8730f5daa08948ee3260367787aa
  • Loading branch information
r-barnes authored and facebook-github-bot committed Mar 18, 2023
1 parent 64833b5 commit 54eeae2
Showing 1 changed file with 114 additions and 75 deletions.
189 changes: 114 additions & 75 deletions include/fbgemm/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,160 +15,199 @@ namespace fbgemm {
using float16 = std::uint16_t;
using bfloat16 = std::uint16_t;

// The IEEE754 standard species a binary16 as having the following format:
// SEEEEEMMMMMMMMMM
// 0432109876543210
// That is:
// * 1 sign bit
// * 5 exponent bits
// * 10 mantissa/significand bits (an 11th bit is implicit)
constexpr uint32_t f16_num_bits = 16;
constexpr uint32_t f16_num_exponent_bits = 5;
constexpr uint32_t f16_num_mantissa_bits = 10;
constexpr uint32_t f16_num_non_sign_bits =
f16_num_exponent_bits + f16_num_mantissa_bits;
constexpr uint32_t f16_exponent_mask = 0x1F; // 5 bits
constexpr uint32_t f16_sign_bit = 1u
<< (f16_num_exponent_bits + f16_num_mantissa_bits);
constexpr uint32_t f16_exponent_bits = f16_exponent_mask
<< f16_num_mantissa_bits;
constexpr uint32_t f16_mantissa_mask = 0x3FF; // 10 bits
constexpr uint32_t f16_exponent_bias = 15;
constexpr uint32_t f16_nan = 0x7FFF;

// The IEEE754 standard specifies a binary32 as having:
// SEEEEEEEEMMMMMMMMMMMMMMMMMMMMMMM
// That is:
// * 1 sign bit
// * 8 exponent bits
// * 23 mantissa/significand bits (a 24th bit is implicit)
constexpr uint32_t f32_num_exponent_bits = 8;
constexpr uint32_t f32_num_mantissa_bits = 23;
constexpr uint32_t f32_exponent_mask = 0xFF; // 8 bits
constexpr uint32_t f32_mantissa_mask = 0x7FFFFF; // 23 bits
constexpr uint32_t f32_exponent_bias = 127;
constexpr uint32_t f32_all_non_sign_mask = 0x7FFFFFFF; // 31 bits
constexpr uint32_t f32_most_significant_bit = 1u << 22; // Turn on 23rd bit
constexpr uint32_t f32_num_non_sign_bits =
f32_num_exponent_bits + f32_num_mantissa_bits;

// Round to nearest even
static inline float16 cpu_float2half_rn(float f) {
float16 ret;

static_assert(
sizeof(unsigned int) == sizeof(float),
"Programming error sizeof(unsigned int) != sizeof(float)");
sizeof(uint32_t) == sizeof(float),
"Programming error sizeof(uint32_t) != sizeof(float)");

unsigned* xp = reinterpret_cast<unsigned int*>(&f);
unsigned x = *xp;
unsigned u = (x & 0x7fffffff), remainder, shift, lsb, lsb_s1, lsb_m1;
unsigned sign, exponent, mantissa;
uint32_t* xp = reinterpret_cast<uint32_t*>(&f);
uint32_t x = *xp;
uint32_t u = (x & f32_all_non_sign_mask);

// Get rid of +NaN/-NaN case first.
if (u > 0x7f800000) {
ret = 0x7fffU;
return ret;
return static_cast<float16>(f16_nan);
}

sign = ((x >> 16) & 0x8000);
uint32_t sign = ((x >> f16_num_bits) & f16_sign_bit);

// Get rid of +Inf/-Inf, +0/-0.
if (u > 0x477fefff) {
ret = static_cast<float16>(sign | 0x7c00U);
return ret;
return static_cast<float16>(sign | f16_exponent_bits);
}
if (u < 0x33000001) {
ret = static_cast<float16>(sign | 0x0000);
return ret;
return static_cast<float16>(sign | 0x0000);
}

exponent = ((u >> 23) & 0xff);
mantissa = (u & 0x7fffff);
uint32_t exponent = ((u >> f32_num_mantissa_bits) & f32_exponent_mask);
uint32_t mantissa = (u & f32_mantissa_mask);

if (exponent > 0x70) {
shift = 13;
exponent -= 0x70;
uint32_t shift;
if (exponent > f32_exponent_bias - f16_exponent_bias) {
shift = f32_num_mantissa_bits - f16_num_mantissa_bits;
exponent -= f32_exponent_bias - f16_exponent_bias;
} else {
shift = 0x7e - exponent;
shift = (f32_exponent_bias - 1) - exponent;
exponent = 0;
mantissa |= 0x800000;
mantissa |=
(1u
<< f32_num_mantissa_bits); // Bump the least significant exponent bit
}
lsb = (1 << shift);
lsb_s1 = (lsb >> 1);
lsb_m1 = (lsb - 1);
const uint32_t lsb = (1u << shift);
const uint32_t lsb_s1 = (lsb >> 1);
const uint32_t lsb_m1 = (lsb - 1);

// Round to nearest even.
remainder = (mantissa & lsb_m1);
const uint32_t remainder = (mantissa & lsb_m1);
mantissa >>= shift;
if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) {
++mantissa;
if (!(mantissa & 0x3ff)) {
if (!(mantissa & f16_mantissa_mask)) {
++exponent;
mantissa = 0;
}
}

ret = static_cast<float16>(sign | (exponent << 10) | mantissa);

return ret;
return static_cast<float16>(
sign | (exponent << f16_num_mantissa_bits) | mantissa);
}

// Round to zero
static inline float16 cpu_float2half_rz(float f) {
float16 ret;

static_assert(
sizeof(unsigned int) == sizeof(float),
"Programming error sizeof(unsigned int) != sizeof(float)");
sizeof(uint32_t) == sizeof(float),
"Programming error sizeof(uint32_t) != sizeof(float)");

unsigned* xp = reinterpret_cast<unsigned int*>(&f);
unsigned x = *xp;
unsigned u = (x & 0x7fffffff);
unsigned shift, sign, exponent, mantissa;
const uint32_t* xp = reinterpret_cast<uint32_t*>(&f);
const uint32_t x = *xp;
const uint32_t u = (x & f32_all_non_sign_mask);

// Get rid of +NaN/-NaN case first.
if (u > 0x7f800000) {
ret = static_cast<float16>(0x7fffU);
return ret;
return static_cast<float16>(f16_nan);
}

sign = ((x >> 16) & 0x8000);
uint32_t sign = ((x >> f16_num_bits) & f16_sign_bit);

// Get rid of +Inf/-Inf, +0/-0.
if (u > 0x477fefff) {
ret = static_cast<float16>(sign | 0x7c00U);
return ret;
return static_cast<float16>(sign | f16_exponent_bits);
}
if (u < 0x33000001) {
ret = static_cast<float16>(sign | 0x0000);
return ret;
return static_cast<float16>(sign | 0x0000);
}

exponent = ((u >> 23) & 0xff);
mantissa = (u & 0x7fffff);
uint32_t exponent = ((u >> f32_num_mantissa_bits) & f32_exponent_mask);
uint32_t mantissa = (u & f32_mantissa_mask);

if (exponent > 0x70) {
shift = 13;
exponent -= 0x70;
uint32_t shift;
if (exponent > f32_exponent_bias - f16_exponent_bias) {
shift = f32_num_mantissa_bits - f16_num_mantissa_bits;
exponent -= f32_exponent_bias - f16_exponent_bias;
} else {
shift = 0x7e - exponent;
shift = (f32_exponent_bias - 1) - exponent;
exponent = 0;
mantissa |= 0x800000;
mantissa |=
(1u
<< f32_num_mantissa_bits); // Bump the least significant exponent bit
}

// Round to zero.
mantissa >>= shift;

ret = static_cast<float16>(sign | (exponent << 10) | mantissa);

return ret;
return static_cast<float16>(
sign | (exponent << f16_num_mantissa_bits) | mantissa);
}

static inline float cpu_half2float(float16 h) {
unsigned sign = ((h >> 15) & 1);
unsigned exponent = ((h >> 10) & 0x1f);
unsigned mantissa = ((h & 0x3ff) << 13);

if (exponent == 0x1f) { /* NaN or Inf */
mantissa = (mantissa ? (sign = 0, 0x7fffff) : 0);
exponent = 0xff;
} else if (!exponent) { /* Denorm or Zero */
// Converts a 16-bit unsigned integer representation of a IEEE754 half-precision
// float into an IEEE754 32-bit single-precision float
static inline float cpu_half2float(const float16 h) {
// Get sign and exponent alone by themselves
uint32_t sign_bit = (h >> f16_num_non_sign_bits) & 1;
uint32_t exponent = (h >> f16_num_mantissa_bits) & f16_exponent_mask;
// Shift mantissa so that it fills the most significant bits of a float32
uint32_t mantissa = (h & f16_mantissa_mask)
<< (f32_num_mantissa_bits - f16_num_mantissa_bits);

if (exponent == f16_exponent_mask) { // NaN or Inf
if (mantissa) {
unsigned int msb;
exponent = 0x71;
mantissa = f32_mantissa_mask;
sign_bit = 0;
}
exponent = f32_exponent_mask;
} else if (!exponent) { // Denorm or Zero
if (mantissa) {
uint32_t msb;
exponent = f32_exponent_bias - f16_exponent_bias + 1;
do {
msb = (mantissa & 0x400000);
mantissa <<= 1; /* normalize */
msb = mantissa & f32_most_significant_bit;
mantissa <<= 1; // normalize
--exponent;
} while (!msb);
mantissa &= 0x7fffff; /* 1.mantissa is implicit */
mantissa &= f32_mantissa_mask; // 1.mantissa is implicit
}
} else {
exponent += 0x70;
exponent += f32_exponent_bias - f16_exponent_bias;
}

unsigned i = ((sign << 31) | (exponent << 23) | mantissa);
const uint32_t i = (sign_bit << f32_num_non_sign_bits) |
(exponent << f32_num_mantissa_bits) | mantissa;

float ret;
memcpy(&ret, &i, sizeof(i));
std::memcpy(&ret, &i, sizeof(float));
return ret;
}

static inline float cpu_bf162float(bfloat16 src) {
float ret;
uint32_t val_fp32 =
static_cast<uint32_t>(reinterpret_cast<const uint16_t*>(&src)[0]) << 16;
memcpy(&ret, &val_fp32, sizeof(ret));
memcpy(&ret, &val_fp32, sizeof(float));
return ret;
}

static inline bfloat16 cpu_float2bfloat16(float src) {
uint32_t temp;
memcpy(&temp, &src, sizeof(temp));
return (temp + (1 << 15)) >> 16;
memcpy(&temp, &src, sizeof(uint32_t));
return (temp + (1u << 15)) >> 16;
}

} // namespace fbgemm

0 comments on commit 54eeae2

Please sign in to comment.