Skip to content

Commit f6030f4

Browse files
committed
fp32tofp16,all test pass!
1 parent ea6bcac commit f6030f4

File tree

1 file changed

+73
-54
lines changed

1 file changed

+73
-54
lines changed

paddle/phi/common/float16.h

Lines changed: 73 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -119,40 +119,53 @@ struct PADDLE_ALIGN(2) float16 {
119119
x = _cvtss_sh(val, 0);
120120

121121
#else
122-
// Conversion routine adapted from
123-
// http://stackoverflow.com/questions/1659440/32-bit-to-16-bit-floating-point-conversion
124-
Bits v, s;
122+
Bits v;
125123
v.f = val;
126-
// Extract sign bit and clear from value
127-
uint32_t sign = v.si & sigN;
128-
v.si ^= sign;
129-
sign >>= shiftSign;
130-
131-
// Handle subnormals: normalize using multiplication
132-
const uint32_t subnormal_mask = -(minN > v.si);
133-
s.si = mulN;
134-
s.si = s.f * v.f; // Extract the fraction of the subnormal number through
135-
// multiplication and conversion from float to int
136-
v.si ^= (s.si ^ v.si) & subnormal_mask;
137-
138-
// Handle special values: infinity and NaN
139-
v.si ^= (infN ^ v.si) & -((infN > v.si) & (v.si > maxN));
140-
v.si ^= (nanN ^ v.si) & -((nanN > v.si) & (v.si > infN));
141-
142-
// Rounding: round to nearest, ties to even
143-
// https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even
144-
const uint32_t lsb =
145-
(v.ui >> shift) & 0x1; // Least significant retained bit
146-
v.ui += (0xFFF + lsb) & -(v.ui < infN); // Round with overflow protection
147-
148-
v.ui >>= shift; // logical shift
149-
150-
// Exponent adjustment for overflow (max values)
151-
v.si ^= ((v.si - maxD) ^ v.si) & -(v.si > maxC);
152-
// Exponent adjustment for normal numbers
153-
const uint32_t normal_mask = ~subnormal_mask;
154-
v.si ^= ((v.si - minD) ^ v.si) & normal_mask;
155124

125+
// 1. Extract sign bit and clear from value
126+
const uint32_t sign = (v.ui & sigN) >> shiftSign;
127+
v.ui &= ~sigN;
128+
129+
// 2. Handle special values: infinity and NaN
130+
const int32_t inf_cond = -((infN >= v.si) & (v.si >= minINF));
131+
const int32_t nan_cond = -((nanN > v.si) & (v.si > infN));
132+
v.si ^= (infN ^ v.si) & inf_cond;
133+
v.si ^= (nanN ^ v.si) & nan_cond;
134+
135+
const bool is_subnormal = (v.ui < minN);
136+
if (is_subnormal) {
137+
// 3. Handle subnormal numbers
138+
// 3.1 Extract FP32 exponent and mantissa
139+
const uint32_t exp = (v.ui >> 23) & exp_mask;
140+
const uint32_t mantissa = (v.ui & mantissa_mask) | implicit_bit;
141+
// 3.2 Compute required shift
142+
const uint32_t shift_amount = exp_bias_diff - exp;
143+
// 3.3 64-bit mantissa
144+
uint64_t normalized_mantissa = static_cast<uint64_t>(mantissa)
145+
<< precision_shift;
146+
normalized_mantissa >>= shift_amount;
147+
// 3.4 Round to nearest even
148+
// https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even
149+
const uint32_t lsb = (normalized_mantissa >> mantissa_shift) & 0x1;
150+
normalized_mantissa += rounding_bias + lsb;
151+
v.ui = static_cast<uint32_t>(normalized_mantissa >> mantissa_shift);
152+
} else {
153+
// 4. Handle normal numbers
154+
// Round to nearest even
155+
const uint32_t lsb =
156+
(v.ui >> shift) & 0x1; // Least significant retained bit
157+
const uint32_t rounding =
158+
(0xFFF + lsb) & -(v.ui < infN); // Round with overflow protection
159+
v.ui += rounding;
160+
// inf and nan
161+
const int32_t max_cond = -(v.ui >= infN);
162+
// Align bits
163+
v.ui >>= shift;
164+
// Exponent adjustment for overflow
165+
v.si ^= ((v.si - maxD) ^ v.si) & max_cond;
166+
// Exponent adjustment for normal numbers
167+
v.si ^= ((v.si - minD) ^ v.si);
168+
}
156169
// Combine sign and value bits
157170
x = v.ui | sign;
158171

@@ -340,28 +353,34 @@ struct PADDLE_ALIGN(2) float16 {
340353
uint32_t ui;
341354
};
342355

343-
static const int shift = 13;
344-
static const int shiftSign = 16;
345-
346-
static const int32_t infN = 0x7F800000;
347-
static const int32_t maxN = 0x477FE000; // max flt16 as flt32
348-
static const int32_t minN = 0x38800000; // min flt16 normal as flt32
349-
static const int32_t sigN = 0x80000000; // sign bit
350-
351-
static constexpr int32_t infC = infN >> shift;
352-
static constexpr int32_t nanN = (infC + 1)
353-
<< shift; // minimum flt16 nan as float32
354-
static constexpr int32_t maxC = maxN >> shift;
355-
static constexpr int32_t minC = minN >> shift;
356-
static constexpr int32_t sigC = sigN >> shiftSign;
357-
358-
static const int32_t mulN = 0x52000000; // (1 << 23) / minN
359-
static const int32_t mulC = 0x33800000; // minN / (1 << (23 - shift))
360-
static const int32_t subC = 0x003FF; // max flt32 subnormal downshifted
361-
static const int32_t norC = 0x00400; // min flt32 normal downshifted
362-
363-
static constexpr int32_t maxD = infC - maxC - 1;
364-
static constexpr int32_t minD = minC - subC - 1;
356+
static constexpr int shift = 13;
357+
static constexpr int shiftSign = 16;
358+
359+
static constexpr uint32_t infN = 0x7F800000;
360+
static constexpr uint32_t maxN = 0x477FE000; // max flt16 as flt32
361+
static constexpr uint32_t minINF = 0x47800000; // min flt16 inf as flt32
362+
static constexpr uint32_t minN = 0x38800000; // min flt16 normal as flt32
363+
static constexpr uint32_t sigN = 0x80000000; // sign bit
364+
365+
static constexpr uint32_t infC = infN >> shift;
366+
static constexpr uint32_t nanN = (infC + 1)
367+
<< shift; // minimum flt16 nan as float32
368+
static constexpr uint32_t maxC = maxN >> shift;
369+
static constexpr uint32_t minC = minN >> shift;
370+
static constexpr uint32_t sigC = sigN >> shiftSign;
371+
372+
static constexpr uint32_t subC = 0x003FF; // max flt32 subnormal downshifted
373+
static constexpr uint32_t norC = 0x00400; // min flt32 normal downshifted
374+
static constexpr uint32_t maxD = infC - maxC - 1;
375+
static constexpr uint32_t minD = minC - subC - 1;
376+
377+
static constexpr uint32_t exp_mask = 0xFF;
378+
static constexpr uint32_t mantissa_mask = 0x7FFFFF;
379+
static constexpr uint32_t implicit_bit = 0x800000;
380+
static constexpr uint32_t exp_bias_diff = 113; // 127 - 14
381+
static constexpr uint64_t precision_shift = 40;
382+
static constexpr uint64_t rounding_bias = 0xFFFFFFFFFFFFF;
383+
static constexpr int mantissa_shift = 53;
365384
};
366385

367386
// Arithmetic operators on GPU

0 commit comments

Comments
 (0)