11#include < iostream>
2- #include < sycl/ext/oneapi/experimental/ bfloat16.hpp>
2+ #include < sycl/ext/oneapi/bfloat16.hpp>
33#include < sycl/sycl.hpp>
44
55#include < cmath>
@@ -11,8 +11,7 @@ constexpr size_t N = 100;
1111template <typename T> void assert_close (const T &C, const float ref) {
1212 for (size_t i = 0 ; i < N; i++) {
1313 auto diff = C[i] - ref;
14- assert (std::fabs (static_cast <float >(diff)) <
15- std::numeric_limits<float >::epsilon ());
14+ assert (std::fabs (static_cast <float >(diff)) < 0.1 );
1615 }
1716}
1817
@@ -21,7 +20,7 @@ void verify_conv_implicit(queue &q, buffer<float, 1> &a, range<1> &r,
2120 q.submit ([&](handler &cgh) {
2221 auto A = a.get_access <access::mode::read_write>(cgh);
2322 cgh.parallel_for <class calc_conv >(r, [=](id<1 > index) {
24- sycl::ext::oneapi::experimental:: bfloat16 AVal{A[index]};
23+ sycl::ext::oneapi::bfloat16 AVal{A[index]};
2524 A[index] = AVal;
2625 });
2726 });
@@ -34,9 +33,8 @@ void verify_conv_explicit(queue &q, buffer<float, 1> &a, range<1> &r,
3433 q.submit ([&](handler &cgh) {
3534 auto A = a.get_access <access::mode::read_write>(cgh);
3635 cgh.parallel_for <class calc_conv_impl >(r, [=](id<1 > index) {
37- uint16_t AVal =
38- sycl::ext::oneapi::experimental::bfloat16::from_float (A[index]);
39- A[index] = sycl::ext::oneapi::experimental::bfloat16::to_float (AVal);
36+ sycl::ext::oneapi::bfloat16 AVal = A[index];
37+ A[index] = float (AVal);
4038 });
4139 });
4240
@@ -52,9 +50,9 @@ void verify_add(queue &q, buffer<float, 1> &a, buffer<float, 1> &b, range<1> &r,
5250 auto B = b.get_access <access::mode::read>(cgh);
5351 auto C = c.get_access <access::mode::write>(cgh);
5452 cgh.parallel_for <class calc_add_expl >(r, [=](id<1 > index) {
55- sycl::ext::oneapi::experimental:: bfloat16 AVal{A[index]};
56- sycl::ext::oneapi::experimental:: bfloat16 BVal{B[index]};
57- sycl::ext::oneapi::experimental:: bfloat16 CVal = AVal + BVal;
53+ sycl::ext::oneapi::bfloat16 AVal{A[index]};
54+ sycl::ext::oneapi::bfloat16 BVal{B[index]};
55+ sycl::ext::oneapi::bfloat16 CVal = AVal + BVal;
5856 C[index] = CVal;
5957 });
6058 });
@@ -71,9 +69,9 @@ void verify_sub(queue &q, buffer<float, 1> &a, buffer<float, 1> &b, range<1> &r,
7169 auto B = b.get_access <access::mode::read>(cgh);
7270 auto C = c.get_access <access::mode::write>(cgh);
7371 cgh.parallel_for <class calc_sub >(r, [=](id<1 > index) {
74- sycl::ext::oneapi::experimental:: bfloat16 AVal{A[index]};
75- sycl::ext::oneapi::experimental:: bfloat16 BVal{B[index]};
76- sycl::ext::oneapi::experimental:: bfloat16 CVal = AVal - BVal;
72+ sycl::ext::oneapi::bfloat16 AVal{A[index]};
73+ sycl::ext::oneapi::bfloat16 BVal{B[index]};
74+ sycl::ext::oneapi::bfloat16 CVal = AVal - BVal;
7775 C[index] = CVal;
7876 });
7977 });
@@ -88,8 +86,8 @@ void verify_minus(queue &q, buffer<float, 1> &a, range<1> &r, const float ref) {
8886 auto A = a.get_access <access::mode::read>(cgh);
8987 auto C = c.get_access <access::mode::write>(cgh);
9088 cgh.parallel_for <class calc_minus >(r, [=](id<1 > index) {
91- sycl::ext::oneapi::experimental:: bfloat16 AVal{A[index]};
92- sycl::ext::oneapi::experimental:: bfloat16 CVal = -AVal;
89+ sycl::ext::oneapi::bfloat16 AVal{A[index]};
90+ sycl::ext::oneapi::bfloat16 CVal = -AVal;
9391 C[index] = CVal;
9492 });
9593 });
@@ -106,9 +104,9 @@ void verify_mul(queue &q, buffer<float, 1> &a, buffer<float, 1> &b, range<1> &r,
106104 auto B = b.get_access <access::mode::read>(cgh);
107105 auto C = c.get_access <access::mode::write>(cgh);
108106 cgh.parallel_for <class calc_mul >(r, [=](id<1 > index) {
109- sycl::ext::oneapi::experimental:: bfloat16 AVal{A[index]};
110- sycl::ext::oneapi::experimental:: bfloat16 BVal{B[index]};
111- sycl::ext::oneapi::experimental:: bfloat16 CVal = AVal * BVal;
107+ sycl::ext::oneapi::bfloat16 AVal{A[index]};
108+ sycl::ext::oneapi::bfloat16 BVal{B[index]};
109+ sycl::ext::oneapi::bfloat16 CVal = AVal * BVal;
112110 C[index] = CVal;
113111 });
114112 });
@@ -125,9 +123,9 @@ void verify_div(queue &q, buffer<float, 1> &a, buffer<float, 1> &b, range<1> &r,
125123 auto B = b.get_access <access::mode::read>(cgh);
126124 auto C = c.get_access <access::mode::write>(cgh);
127125 cgh.parallel_for <class calc_div >(r, [=](id<1 > index) {
128- sycl::ext::oneapi::experimental:: bfloat16 AVal{A[index]};
129- sycl::ext::oneapi::experimental:: bfloat16 BVal{B[index]};
130- sycl::ext::oneapi::experimental:: bfloat16 CVal = AVal / BVal;
126+ sycl::ext::oneapi::bfloat16 AVal{A[index]};
127+ sycl::ext::oneapi::bfloat16 BVal{B[index]};
128+ sycl::ext::oneapi::bfloat16 CVal = AVal / BVal;
131129 C[index] = CVal;
132130 });
133131 });
@@ -144,19 +142,18 @@ void verify_logic(queue &q, buffer<float, 1> &a, buffer<float, 1> &b,
144142 auto B = b.get_access <access::mode::read>(cgh);
145143 auto C = c.get_access <access::mode::write>(cgh);
146144 cgh.parallel_for <class logic >(r, [=](id<1 > index) {
147- sycl::ext::oneapi::experimental:: bfloat16 AVal{A[index]};
148- sycl::ext::oneapi::experimental:: bfloat16 BVal{B[index]};
145+ sycl::ext::oneapi::bfloat16 AVal{A[index]};
146+ sycl::ext::oneapi::bfloat16 BVal{B[index]};
149147 if (AVal) {
150148 if (AVal > BVal || AVal >= BVal || AVal < BVal || AVal <= BVal ||
151149 !BVal) {
152- sycl::ext::oneapi::experimental::bfloat16 CVal =
153- AVal != BVal ? AVal : BVal;
150+ sycl::ext::oneapi::bfloat16 CVal = AVal != BVal ? AVal : BVal;
154151 CVal--;
155152 CVal++;
156153 if (AVal == BVal) {
157154 CVal -= AVal;
158- CVal *= 3.0 ;
159- CVal /= 2.0 ;
155+ CVal *= 3 .0f ;
156+ CVal /= 2 .0f ;
160157 } else
161158 CVal += BVal;
162159 C[index] = CVal;
@@ -179,9 +176,9 @@ int run_tests() {
179176 return 0 ;
180177 }
181178
182- std::vector<float > vec_a (N, 5.0 );
183- std::vector<float > vec_b (N, 2.0 );
184- std::vector<float > vec_b_neg (N, -2.0 );
179+ std::vector<float > vec_a (N, 5 .0f );
180+ std::vector<float > vec_b (N, 2 .0f );
181+ std::vector<float > vec_b_neg (N, -2 .0f );
185182
186183 range<1 > r (N);
187184 buffer<float , 1 > a{vec_a.data (), r};
@@ -190,19 +187,32 @@ int run_tests() {
190187
191188 queue q{dev};
192189
193- verify_conv_implicit (q, a, r, 5.0 );
194- verify_conv_explicit (q, a, r, 5.0 );
195- verify_add (q, a, b, r, 7.0 );
196- verify_sub (q, a, b, r, 3.0 );
197- verify_mul (q, a, b, r, 10.0 );
198- verify_div (q, a, b, r, 2.5 );
199- verify_logic (q, a, b, r, 7.0 );
200- verify_add (q, a, b_neg, r, 3.0 );
201- verify_sub (q, a, b_neg, r, 7.0 );
202- verify_minus (q, a, r, -5.0 );
203- verify_mul (q, a, b_neg, r, -10.0 );
204- verify_div (q, a, b_neg, r, -2.5 );
205- verify_logic (q, a, b_neg, r, 3.0 );
190+ verify_conv_implicit (q, a, r, 5 .0f );
191+ std::cout << " PASS verify_conv_implicit\n " ;
192+ verify_conv_explicit (q, a, r, 5 .0f );
193+ std::cout << " PASS verify_conv_explicit\n " ;
194+ verify_add (q, a, b, r, 7 .0f );
195+ std::cout << " PASS verify_add\n " ;
196+ verify_sub (q, a, b, r, 3 .0f );
197+ std::cout << " PASS verify_sub\n " ;
198+ verify_mul (q, a, b, r, 10 .0f );
199+ std::cout << " PASS verify_mul\n " ;
200+ verify_div (q, a, b, r, 2 .5f );
201+ std::cout << " PASS verify_div\n " ;
202+ verify_logic (q, a, b, r, 7 .0f );
203+ std::cout << " PASS verify_logic\n " ;
204+ verify_add (q, a, b_neg, r, 3 .0f );
205+ std::cout << " PASS verify_add\n " ;
206+ verify_sub (q, a, b_neg, r, 7 .0f );
207+ std::cout << " PASS verify_sub\n " ;
208+ verify_minus (q, a, r, -5 .0f );
209+ std::cout << " PASS verify_minus\n " ;
210+ verify_mul (q, a, b_neg, r, -10 .0f );
211+ std::cout << " PASS verify_mul\n " ;
212+ verify_div (q, a, b_neg, r, -2 .5f );
213+ std::cout << " PASS verify_div\n " ;
214+ verify_logic (q, a, b_neg, r, 3 .0f );
215+ std::cout << " PASS verify_logic\n " ;
206216
207217 return 0 ;
208218}
0 commit comments