@@ -79,9 +79,20 @@ template <typename T> T __mul_hi(T a, T b) {
7979 return (mul >> (sizeof (T) * 8 ));
8080}
8181
82- // T is minimum of 64 bits- long or longlong
83- template <typename T> inline T __long_mul_hi (T a, T b) {
84- int halfsize = (sizeof (T) * 8 ) / 2 ;
// Helper for the long mul_hi built-ins.
//
// Takes the four partial products of the operands' half-words, where a0/b0
// are the low halves and a1/b1 the high halves of a and b, and folds them
// into the upper half of the double-width product a * b:
//   a1b1 contributes directly to the upper half,
//   a1b0 and a0b1 straddle the boundary, so only their upper bits count,
//   a0b0 contributes only through the carry out of its upper half-word.
// T is expected to be at least 64 bits wide.
template <typename T>
inline T __get_high_half(T a0b0, T a0b1, T a1b0, T a1b1) {
  constexpr int halfBits = (sizeof(T) * 8) / 2;
  // Upper half-word of the lowest partial product feeds the middle column.
  const T lowCarry = a0b0 >> halfBits;
  const T middleWithCarry = a0b1 + lowCarry;
  // The sum of the two middle terms may need one bit more than T provides.
  // __hadd is defined elsewhere in this file; presumably it yields
  // (x + y) >> 1 without losing that extra bit -- TODO confirm. Shifting by
  // one bit less than a half-word then completes the ">> halfBits".
  const T middleHigh = __hadd(a1b0, middleWithCarry) >> (halfBits - 1);
  return a1b1 + middleHigh;
}
91+
92+ // A helper function for mul_hi built-in for long
93+ template <typename T>
94+ inline void __get_half_products (T a, T b, T &a0b0, T &a0b1, T &a1b0, T &a1b1) {
95+ constexpr int halfsize = (sizeof (T) * 8 ) / 2 ;
8596 T a1 = a >> halfsize;
8697 T a0 = (a << halfsize) >> halfsize;
8798 T b1 = b >> halfsize;
@@ -90,26 +101,53 @@ template <typename T> inline T __long_mul_hi(T a, T b) {
90101 // a1b1 - for bits - [64-128)
91102 // a1b0 a0b1 for bits - [32-96)
92103 // a0b0 for bits - [0-64)
93- T a1b1 = a1 * b1;
94- T a0b1 = a0 * b1;
95- T a1b0 = a1 * b0;
96- T a0b0 = a0 * b0;
104+ a1b1 = a1 * b1;
105+ a0b1 = a0 * b1;
106+ a1b0 = a1 * b0;
107+ a0b0 = a0 * b0;
108+ }
109+
// mul_hi for the unsigned long built-ins: returns the upper half of the
// double-width product a * b.
// T must be at least 64 bits wide (long or long long).
template <typename T> inline T __u_long_mul_hi(T a, T b) {
  // Partial products of the half-words; names read <half of a><half of b>.
  T loLo, loHi, hiLo, hiHi;
  __get_half_products(a, b, loLo, loHi, hiLo, hiHi);
  return __get_high_half(loLo, loHi, hiLo, hiHi);
}
117+
// mul_hi for the signed long built-ins: returns the upper half of the
// double-width product a * b.
//
// The product is formed on the operands' magnitudes using the unsigned
// helpers, and the double-width result is negated afterwards when the
// operands have different signs.
// T must be a signed type of at least 64 bits (long or long long).
template <typename T> inline T __s_long_mul_hi(T a, T b) {
  using UT = typename std::make_unsigned<T>::type;

  // Magnitude computed entirely in the unsigned domain. std::abs must not be
  // used here: for the most negative value of T the magnitude does not fit
  // in T and std::abs has undefined behaviour, whereas the unsigned
  // subtraction below wraps to the correct value (2^(N-1)).
  const auto magnitude = [](T value) -> UT {
    const UT bits = static_cast<UT>(value);
    return value < 0 ? static_cast<UT>(static_cast<UT>(0) - bits) : bits;
  };
  const UT absA = magnitude(a);
  const UT absB = magnitude(b);

  UT a0b0, a0b1, a1b0, a1b1;
  __get_half_products(absA, absB, a0b0, a0b1, a1b0, a1b1);
  T result = static_cast<T>(__get_high_half(a0b0, a0b1, a1b0, a1b1));

  const bool isResultNegative = (a < 0) != (b < 0);
  if (isResultNegative) {
    // Two's-complement negation of the double-width value (high:low) is
    // (~high:~low) + 1. The +1 only carries into the high half when the low
    // half is zero.
    result = ~result;

    // Low half of the product, computed modulo 2^N in the unsigned type.
    constexpr int halfsize = (sizeof(T) * 8) / 2;
    const UT low = a0b0 + ((a0b1 + a1b0) << halfsize);
    if (low == 0)
      ++result;
  }

  return result;
}
106140
// mad_hi: upper half of the product a * b (via __mul_hi), plus c.
template <typename T> inline T __mad_hi(T a, T b, T c) {
  const T productHigh = __mul_hi(a, b);
  return productHigh + c;
}
110144
111- template <typename T> inline T __long_mad_hi (T a, T b, T c) {
112- return __long_mul_hi (a, b) + c;
// mad_hi for the unsigned long built-ins: upper half of a * b (via
// __u_long_mul_hi), plus c.
template <typename T> inline T __u_long_mad_hi(T a, T b, T c) {
  const T productHigh = __u_long_mul_hi(a, b);
  return productHigh + c;
}
148+
// mad_hi for the signed long built-ins: upper half of a * b (via
// __s_long_mul_hi), plus c.
template <typename T> inline T __s_long_mad_hi(T a, T b, T c) {
  const T productHigh = __s_long_mul_hi(a, b);
  return productHigh + c;
}
114152
115153template <typename T> inline T __s_mad_sat (T a, T b, T c) {
@@ -123,7 +161,7 @@ template <typename T> inline T __s_mad_sat(T a, T b, T c) {
123161
124162template <typename T> inline T __s_long_mad_sat (T a, T b, T c) {
125163 bool neg_prod = (a < 0 ) ^ (b < 0 );
126- T mulhi = __long_mul_hi (a, b);
164+ T mulhi = __s_long_mul_hi (a, b);
127165
128166 // check mul_hi. If it is any value != 0.
129167 // if prod is +ve, any value in mulhi means we need to saturate.
@@ -145,7 +183,7 @@ template <typename T> inline T __u_mad_sat(T a, T b, T c) {
145183}
146184
147185template <typename T> inline T __u_long_mad_sat (T a, T b, T c) {
148- T mulhi = __long_mul_hi (a, b);
186+ T mulhi = __u_long_mul_hi (a, b);
149187 // check mul_hi. If it is any value != 0.
150188 if (mulhi != 0 )
151189 return d::max_v<T>();
@@ -421,7 +459,7 @@ cl_char s_mul_hi(cl_char a, cl_char b) { return __mul_hi(a, b); }
421459cl_short s_mul_hi (cl_short a, cl_short b) { return __mul_hi (a, b); }
422460cl_int s_mul_hi (cl_int a, cl_int b) { return __mul_hi (a, b); }
423461cl_long s_mul_hi (s::cl_long x, s::cl_long y) __NOEXC {
424- return __long_mul_hi (x, y);
462+ return __s_long_mul_hi (x, y);
425463}
426464MAKE_1V_2V (s_mul_hi, s::cl_char, s::cl_char, s::cl_char)
427465MAKE_1V_2V (s_mul_hi, s::cl_short, s::cl_short, s::cl_short)
@@ -433,7 +471,7 @@ cl_uchar u_mul_hi(cl_uchar a, cl_uchar b) { return __mul_hi(a, b); }
433471cl_ushort u_mul_hi (cl_ushort a, cl_ushort b) { return __mul_hi (a, b); }
434472cl_uint u_mul_hi (cl_uint a, cl_uint b) { return __mul_hi (a, b); }
435473cl_ulong u_mul_hi (s::cl_ulong x, s::cl_ulong y) __NOEXC {
436- return __long_mul_hi (x, y);
474+ return __u_long_mul_hi (x, y);
437475}
438476MAKE_1V_2V (u_mul_hi, s::cl_uchar, s::cl_uchar, s::cl_uchar)
439477MAKE_1V_2V (u_mul_hi, s::cl_ushort, s::cl_ushort, s::cl_ushort)
@@ -452,7 +490,7 @@ cl_int s_mad_hi(s::cl_int x, s::cl_int minval, s::cl_int maxval) __NOEXC {
452490 return __mad_hi (x, minval, maxval);
453491}
454492cl_long s_mad_hi (s::cl_long x, s::cl_long minval, s::cl_long maxval) __NOEXC {
455- return __long_mad_hi (x, minval, maxval);
493+ return __s_long_mad_hi (x, minval, maxval);
456494}
457495MAKE_1V_2V_3V (s_mad_hi, s::cl_char, s::cl_char, s::cl_char, s::cl_char)
458496MAKE_1V_2V_3V (s_mad_hi, s::cl_short, s::cl_short, s::cl_short, s::cl_short)
@@ -473,7 +511,7 @@ cl_uint u_mad_hi(s::cl_uint x, s::cl_uint minval, s::cl_uint maxval) __NOEXC {
473511}
474512cl_ulong u_mad_hi (s::cl_ulong x, s::cl_ulong minval,
475513 s::cl_ulong maxval) __NOEXC {
476- return __long_mad_hi (x, minval, maxval);
514+ return __u_long_mad_hi (x, minval, maxval);
477515}
478516MAKE_1V_2V_3V (u_mad_hi, s::cl_uchar, s::cl_uchar, s::cl_uchar, s::cl_uchar)
479517MAKE_1V_2V_3V (u_mad_hi, s::cl_ushort, s::cl_ushort, s::cl_ushort, s::cl_ushort)
0 commit comments