@@ -119,40 +119,53 @@ struct PADDLE_ALIGN(2) float16 {
119119 x = _cvtss_sh (val, 0 );
120120
121121#else
122- // Conversion routine adapted from
123- // http://stackoverflow.com/questions/1659440/32-bit-to-16-bit-floating-point-conversion
124- Bits v, s;
122+ Bits v;
125123 v.f = val;
126- // Extract sign bit and clear from value
127- uint32_t sign = v.si & sigN;
128- v.si ^= sign;
129- sign >>= shiftSign;
130-
131- // Handle subnormals: normalize using multiplication
132- const uint32_t subnormal_mask = -(minN > v.si );
133- s.si = mulN;
134- s.si = s.f * v.f ; // Extract the fraction of the subnormal number through
135- // multiplication and conversion from float to int
136- v.si ^= (s.si ^ v.si ) & subnormal_mask;
137-
138- // Handle special values: infinity and NaN
139- v.si ^= (infN ^ v.si ) & -((infN > v.si ) & (v.si > maxN));
140- v.si ^= (nanN ^ v.si ) & -((nanN > v.si ) & (v.si > infN));
141-
142- // Rounding: round to nearest, ties to even
143- // https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even
144- const uint32_t lsb =
145- (v.ui >> shift) & 0x1 ; // Least significant retained bit
146- v.ui += (0xFFF + lsb) & -(v.ui < infN); // Round with overflow protection
147-
148- v.ui >>= shift; // logical shift
149-
150- // Exponent adjustment for overflow (max values)
151- v.si ^= ((v.si - maxD) ^ v.si ) & -(v.si > maxC);
152- // Exponent adjustment for normal numbers
153- const uint32_t normal_mask = ~subnormal_mask;
154- v.si ^= ((v.si - minD) ^ v.si ) & normal_mask;
155124
125+ // 1. Extract sign bit and clear from value
126+ const uint32_t sign = (v.ui & sigN) >> shiftSign;
127+ v.ui &= ~sigN;
128+
129+ // 2. Handle special values: infinity and NaN
130+ const int32_t inf_cond = -((infN >= v.si ) & (v.si >= minINF));
131+ const int32_t nan_cond = -((nanN > v.si ) & (v.si > infN));
132+ v.si ^= (infN ^ v.si ) & inf_cond;
133+ v.si ^= (nanN ^ v.si ) & nan_cond;
134+
135+ const bool is_subnormal = (v.ui < minN);
136+ if (is_subnormal) {
137+ // 3. Handle subnormal numbers
138+ // 3.1 Extract FP32 exponent and mantissa
139+ const uint32_t exp = (v.ui >> 23 ) & exp_mask;
140+ const uint32_t mantissa = (v.ui & mantissa_mask) | implicit_bit;
141+ // 3.2 Compute required shift
142+ const uint32_t shift_amount = exp_bias_diff - exp;
143+ // 3.3 64-bit mantissa
144+ uint64_t normalized_mantissa = static_cast <uint64_t >(mantissa)
145+ << precision_shift;
146+ normalized_mantissa >>= shift_amount;
147+ // 3.4 Round to nearest even
148+ // https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even
149+ const uint32_t lsb = (normalized_mantissa >> mantissa_shift) & 0x1 ;
150+ normalized_mantissa += rounding_bias + lsb;
151+ v.ui = static_cast <uint32_t >(normalized_mantissa >> mantissa_shift);
152+ } else {
153+ // 4. Handle normal numbers
154+ // Round to nearest even
155+ const uint32_t lsb =
156+ (v.ui >> shift) & 0x1 ; // Least significant retained bit
157+ const uint32_t rounding =
158+ (0xFFF + lsb) & -(v.ui < infN); // Round with overflow protection
159+ v.ui += rounding;
160+ // inf and nan
161+ const int32_t max_cond = -(v.ui >= infN);
162+ // Align bits
163+ v.ui >>= shift;
164+ // Exponent adjustment for overflow
165+ v.si ^= ((v.si - maxD) ^ v.si ) & max_cond;
166+ // Exponent adjustment for normal numbers
167+ v.si ^= ((v.si - minD) ^ v.si );
168+ }
156169 // Combine sign and value bits
157170 x = v.ui | sign;
158171
@@ -340,28 +353,34 @@ struct PADDLE_ALIGN(2) float16 {
340353 uint32_t ui;
341354 };
342355
343- static const int shift = 13 ;
344- static const int shiftSign = 16 ;
345-
346- static const int32_t infN = 0x7F800000 ;
347- static const int32_t maxN = 0x477FE000 ; // max flt16 as flt32
348- static const int32_t minN = 0x38800000 ; // min flt16 normal as flt32
349- static const int32_t sigN = 0x80000000 ; // sign bit
350-
351- static constexpr int32_t infC = infN >> shift;
352- static constexpr int32_t nanN = (infC + 1 )
353- << shift; // minimum flt16 nan as float32
354- static constexpr int32_t maxC = maxN >> shift;
355- static constexpr int32_t minC = minN >> shift;
356- static constexpr int32_t sigC = sigN >> shiftSign;
357-
358- static const int32_t mulN = 0x52000000 ; // (1 << 23) / minN
359- static const int32_t mulC = 0x33800000 ; // minN / (1 << (23 - shift))
360- static const int32_t subC = 0x003FF ; // max flt32 subnormal downshifted
361- static const int32_t norC = 0x00400 ; // min flt32 normal downshifted
362-
363- static constexpr int32_t maxD = infC - maxC - 1 ;
364- static constexpr int32_t minD = minC - subC - 1 ;
356+ static constexpr int shift = 13 ;
357+ static constexpr int shiftSign = 16 ;
358+
359+ static constexpr uint32_t infN = 0x7F800000 ;
360+ static constexpr uint32_t maxN = 0x477FE000 ; // max flt16 as flt32
361+ static constexpr uint32_t minINF = 0x47800000 ; // min flt16 inf as flt32
362+ static constexpr uint32_t minN = 0x38800000 ; // min flt16 normal as flt32
363+ static constexpr uint32_t sigN = 0x80000000 ; // sign bit
364+
365+ static constexpr uint32_t infC = infN >> shift;
366+ static constexpr uint32_t nanN = (infC + 1 )
367+ << shift; // minimum flt16 nan as float32
368+ static constexpr uint32_t maxC = maxN >> shift;
369+ static constexpr uint32_t minC = minN >> shift;
370+ static constexpr uint32_t sigC = sigN >> shiftSign;
371+
372+ static constexpr uint32_t subC = 0x003FF ; // max flt32 subnormal downshifted
373+ static constexpr uint32_t norC = 0x00400 ; // min flt32 normal downshifted
374+ static constexpr uint32_t maxD = infC - maxC - 1 ;
375+ static constexpr uint32_t minD = minC - subC - 1 ;
376+
377+ static constexpr uint32_t exp_mask = 0xFF ;
378+ static constexpr uint32_t mantissa_mask = 0x7FFFFF ;
379+ static constexpr uint32_t implicit_bit = 0x800000 ;
380+ static constexpr uint32_t exp_bias_diff = 113 ; // 127 - 14
381+ static constexpr uint64_t precision_shift = 40 ;
382+ static constexpr uint64_t rounding_bias = 0xFFFFFFFFFFFFF ;
383+ static constexpr int mantissa_shift = 53 ;
365384};
366385
367386// Arithmetic operators on GPU
0 commit comments