Skip to content

Commit ea6bcac

Browse files
committed
[fp32tofp16] Fix incorrect result when a subnormal rounds up to a normal number
1 parent 1bdd369 commit ea6bcac

File tree

1 file changed

+23
-10
lines changed

1 file changed

+23
-10
lines changed

paddle/phi/common/float16.h

Lines changed: 23 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -123,24 +123,37 @@ struct PADDLE_ALIGN(2) float16 {
123123
// http://stackoverflow.com/questions/1659440/32-bit-to-16-bit-floating-point-conversion
124124
Bits v, s;
125125
v.f = val;
126+
// Extract sign bit and clear from value
126127
uint32_t sign = v.si & sigN;
127128
v.si ^= sign;
128-
sign >>= shiftSign; // logical shift
129+
sign >>= shiftSign;
130+
131+
// Handle subnormals: normalize using multiplication
132+
const uint32_t subnormal_mask = -(minN > v.si);
129133
s.si = mulN;
130-
s.si = s.f * v.f; // correct subnormals
131-
v.si ^= (s.si ^ v.si) & -(minN > v.si);
134+
s.si = s.f * v.f; // Extract the fraction of the subnormal number through
135+
// multiplication and conversion from float to int
136+
v.si ^= (s.si ^ v.si) & subnormal_mask;
137+
138+
// Handle special values: infinity and NaN
132139
v.si ^= (infN ^ v.si) & -((infN > v.si) & (v.si > maxN));
133140
v.si ^= (nanN ^ v.si) & -((nanN > v.si) & (v.si > infN));
134-
// Rounding conditions (round to nearest, ties to even).
141+
142+
// Rounding: round to nearest, ties to even
135143
// https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even
136-
if (v.ui < infN) { // Skip special values (infinity and NaN)
137-
// Lowest significant bit of the retained part
138-
const uint32_t lsb = (v.ui >> shift) & 0x1;
139-
v.ui = (v.ui + 0xFFF + lsb); // rounding up
140-
}
144+
const uint32_t lsb =
145+
(v.ui >> shift) & 0x1; // Least significant retained bit
146+
v.ui += (0xFFF + lsb) & -(v.ui < infN); // Round with overflow protection
147+
141148
v.ui >>= shift; // logical shift
149+
150+
// Exponent adjustment for overflow (max values)
142151
v.si ^= ((v.si - maxD) ^ v.si) & -(v.si > maxC);
143-
v.si ^= ((v.si - minD) ^ v.si) & -(v.si > subC);
152+
// Exponent adjustment for normal numbers
153+
const uint32_t normal_mask = ~subnormal_mask;
154+
v.si ^= ((v.si - minD) ^ v.si) & normal_mask;
155+
156+
// Combine sign and value bits
144157
x = v.ui | sign;
145158

146159
#endif

0 commit comments

Comments
 (0)