Skip to content
Merged
139 changes: 100 additions & 39 deletions paddle/phi/common/float16.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,22 +118,76 @@ struct PADDLE_ALIGN(2) float16 {
#elif defined(__F16C__) and defined(__PADDLE_x86__)
x = _cvtss_sh(val, 0);

#else
#elif defined(PADDLE_WITH_ARM)
// Conversion routine adapted from
// http://stackoverflow.com/questions/1659440/32-bit-to-16-bit-floating-point-conversion
Bits v, s;
v.f = val;
uint32_t sign = v.si & sigN;
uint32_t sign = v.si & (int32_t)sigN;
v.si ^= sign;
sign >>= shiftSign; // logical shift
s.si = mulN;
s.si = 0x52000000;
s.si = s.f * v.f; // correct subnormals
v.si ^= (s.si ^ v.si) & -(minN > v.si);
v.si ^= (infN ^ v.si) & -((infN > v.si) & (v.si > maxN));
v.si ^= (nanN ^ v.si) & -((nanN > v.si) & (v.si > infN));
v.si ^= (s.si ^ v.si) & -((int32_t)minN > v.si);
v.si ^= ((int32_t)infN ^ v.si) &
-(((int32_t)infN > v.si) & (v.si > (int32_t)maxN));
v.si ^= ((int32_t)nanN ^ v.si) &
-(((int32_t)nanN > v.si) & (v.si > (int32_t)infN));
v.ui >>= shift; // logical shift
v.si ^= ((v.si - maxD) ^ v.si) & -(v.si > maxC);
v.si ^= ((v.si - minD) ^ v.si) & -(v.si > subC);
v.si ^= ((v.si - (int32_t)maxD) ^ v.si) & -(v.si > (int32_t)maxC);
v.si ^= ((v.si - (int32_t)minD) ^ v.si) & -(v.si > (int32_t)subC);
x = v.ui | sign;

#else
Bits v;
v.f = val;

// 1. Extract sign bit and clear from value
const uint32_t sign = (v.ui & sigN) >> shiftSign;
v.ui &= ~sigN;

// 2. Handle special values: infinity and NaN
const uint32_t inf_cond =
(infN >= v.ui) && (v.ui >= minINF) ? 0xFFFFFFFF : 0;
const uint32_t nan_cond = (nanN > v.ui) && (v.ui > infN) ? 0xFFFFFFFF : 0;
v.ui ^= (infN ^ v.ui) & inf_cond;
v.ui ^= (nanN ^ v.ui) & nan_cond;

const bool is_subnormal = (v.ui < minN);
if (is_subnormal) {
// 3. Handle subnormal numbers
// 3.1 Extract FP32 exponent and mantissa
const uint32_t exp = (v.ui >> 23) & exp_mask;
const uint32_t mantissa = (v.ui & mantissa_mask) | implicit_bit;
// 3.2 Compute required shift
const uint32_t shift_amount = exp_bias_diff - exp;
// 3.3 64-bit mantissa
uint64_t normalized_mantissa = static_cast<uint64_t>(mantissa)
<< precision_shift;
normalized_mantissa >>= shift_amount;
// 3.4 Round to nearest even
// https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even
const uint32_t lsb = (normalized_mantissa >> mantissa_shift) & 0x1;
normalized_mantissa += rounding_bias + lsb;
v.ui = static_cast<uint32_t>(normalized_mantissa >> mantissa_shift);
} else {
// 4. Handle normal numbers
// Round to nearest even
const uint32_t lsb =
(v.ui >> shift) & 0x1; // Least significant retained bit
const uint32_t rounding =
(v.ui < infN) ? (0xFFF + lsb) : 0; // Round with overflow protection
v.ui += rounding;
// inf and nan
const uint32_t max_cond = (v.ui >= infN) ? 0xFFFFFFFF : 0;
// Align bits
v.ui >>= shift;
// Exponent adjustment for overflow
v.ui ^= ((v.ui - maxD) ^ v.ui) & max_cond;
// Exponent adjustment for normal numbers
v.ui ^= ((v.ui - minD) ^ v.ui);
}
// Combine sign and value bits
x = v.ui | sign;

#endif
Expand Down Expand Up @@ -258,18 +312,18 @@ struct PADDLE_ALIGN(2) float16 {
// http://stackoverflow.com/questions/1659440/32-bit-to-16-bit-floating-point-conversion
Bits v;
v.ui = this->x;
int32_t sign = v.si & sigC;
v.si ^= sign;
uint32_t sign = v.ui & sigC;
v.ui ^= sign;
sign <<= shiftSign;
v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
v.ui ^= ((v.ui + minD) ^ v.ui) & -(int32_t)(v.ui > subC);
v.ui ^= ((v.ui + maxD) ^ v.ui) & -(int32_t)(v.ui > maxC);
Bits s;
s.si = mulC;
s.ui = mulC;
s.f *= v.si;
int32_t mask = -(norC > v.si);
v.si <<= shift;
v.si ^= (s.si ^ v.si) & mask;
v.si |= sign;
int32_t mask = -(int32_t)(norC > v.ui);
v.ui <<= shift;
v.ui ^= (s.ui ^ v.ui) & mask;
v.ui |= sign;
return v.f;

#endif
Expand Down Expand Up @@ -320,28 +374,35 @@ struct PADDLE_ALIGN(2) float16 {
uint32_t ui;
};

static const int shift = 13;
static const int shiftSign = 16;

static const int32_t infN = 0x7F800000;
static const int32_t maxN = 0x477FE000; // max flt16 as flt32
static const int32_t minN = 0x38800000; // min flt16 normal as flt32
static const int32_t sigN = 0x80000000; // sign bit

static constexpr int32_t infC = infN >> shift;
static constexpr int32_t nanN = (infC + 1)
<< shift; // minimum flt16 nan as float32
static constexpr int32_t maxC = maxN >> shift;
static constexpr int32_t minC = minN >> shift;
static constexpr int32_t sigC = sigN >> shiftSign;

static const int32_t mulN = 0x52000000; // (1 << 23) / minN
static const int32_t mulC = 0x33800000; // minN / (1 << (23 - shift))
static const int32_t subC = 0x003FF; // max flt32 subnormal downshifted
static const int32_t norC = 0x00400; // min flt32 normal downshifted

static constexpr int32_t maxD = infC - maxC - 1;
static constexpr int32_t minD = minC - subC - 1;
static constexpr int shift = 13;
static constexpr int shiftSign = 16;

static constexpr uint32_t infN = 0x7F800000;
static constexpr uint32_t maxN = 0x477FE000; // max flt16 as flt32
static constexpr uint32_t minINF = 0x47800000; // min flt16 inf as flt32
static constexpr uint32_t minN = 0x38800000; // min flt16 normal as flt32
static constexpr uint32_t sigN = 0x80000000; // sign bit

static constexpr uint32_t infC = infN >> shift;
static constexpr uint32_t nanN = (infC + 1)
<< shift; // minimum flt16 nan as float32
static constexpr uint32_t maxC = maxN >> shift;
static constexpr uint32_t minC = minN >> shift;
static constexpr uint32_t sigC = sigN >> shiftSign;

static constexpr uint32_t mulC = 0x33800000; // minN / (1 << (23 - shift))
static constexpr uint32_t subC = 0x003FF; // max flt32 subnormal downshifted
static constexpr uint32_t norC = 0x00400; // min flt32 normal downshifted
static constexpr uint32_t maxD = infC - maxC - 1;
static constexpr uint32_t minD = minC - subC - 1;

static constexpr uint32_t exp_mask = 0xFF;
static constexpr uint32_t mantissa_mask = 0x7FFFFF;
static constexpr uint32_t implicit_bit = 0x800000;
static constexpr uint32_t exp_bias_diff = 113; // 127 - 14
static constexpr uint64_t precision_shift = 40;
static constexpr uint64_t rounding_bias = 0xFFFFFFFFFFFFF;
static constexpr int mantissa_shift = 53;
};

// Arithmetic operators on GPU
Expand Down