@@ -54,10 +54,20 @@ specification.*
54
54
While {dpcpp} has support for `std::complex` in device code, it limits the
55
55
complex interface and operations to the existing C++ standard. This proposal
56
56
defines a SYCL complex extension based on but independent of the `std::complex`
57
- interface. This framework would allow for further development of complex math
58
- within oneAPI. Possible areas for deviation with `std::complex` include adding
59
- complex support for `marray` and `vec` and overloading mathematical
60
- functions to handle the element-wise operations.
57
+ interface.
58
+
59
+ The proposed framework not only encompasses complex support for traditional use
60
+ cases but also accommodate for advanced mathematical features and data
61
+ structures.
62
+
63
+ Specifically, we propose to incorporate complex support for `sycl::marray`.
64
+ This addition will empower developers to store complex numbers seamlessly
65
+ within arrays, opening up new possibilities for data manipulation and
66
+ computation.
67
+
68
+ Furthermore, this extension involves overloading existing mathematical
69
+ functions to facilitate scalar operation on complex numbers as well as
70
+ element-wise operations on complex marrays.
61
71
62
72
== Specification
63
73
@@ -211,17 +221,125 @@ namespace sycl::ext::oneapi::experimental {
211
221
} // namespace sycl::ext::oneapi::experimental
212
222
```
213
223
214
- === Mathematical operations
224
+ === Marray Complex Class Specialization
225
+
226
+ This proposal also introduces the specialization of the marray class to
227
+ support SYCL complex. The marray class undergoes slight modification for this
228
+ specialization, primarily involving the removal of operators that are
229
+ inapplicable. No new functions or operators are introduced to the marray class.
230
+
231
+ The marray complex specialization maintains the principles of trivial
232
+ copyability (as seen in the Complex class description), with the
233
+ `is_device_copyable` type trait resolving to `std::true_type`.
234
+
235
+ The marray definition used within this proposal assumes that any operator the
236
+ `sycl::marray` class defines is only implemented if the marray's value type
237
+ also implements the operator.
238
+
239
+ For instance,
240
+ `sycl::marray<sycl::ext::oneapi::experimental::complex<T>, NumElements>` does
241
+ not implement the modulus operator since
242
+ `sycl::ext::oneapi::experimental::complex<T>` does not support it.
243
+
244
+ ```C++
245
+ namespace sycl {
246
+
247
+ // Specialization of exiting `marray` class for `sycl::ext::oneapi::experimental::complex`
248
+ template <typename T, std::size_t NumElements>
249
+ class marray<sycl::ext::oneapi::experimental::complex<T>, NumElements> {
250
+ public:
251
+
252
+ /* ... */
253
+
254
+ friend marray operator %(const marray &lhs, const marray &rhs) = delete;
255
+ friend marray operator %(const marray &lhs, const value_type &rhs) = delete;
256
+ friend marray operator %(const value_type &lhs, const marray &rhs) = delete;
257
+
258
+ friend marray &operator %=(marray &lhs, const marray &rhs) = delete;
259
+ friend marray &operator %=(marray &lhs, const value_type &rhs) = delete;
260
+ friend marray &operator %=(value_type &lhs, const marray &rhs) = delete;
261
+
262
+ friend marray operator ++(marray &lhs, int) = delete;
263
+ friend marray &operator ++(marray & rhs) = delete;
264
+
265
+ friend marray operator --(marray &lhs, int) = delete;
266
+ friend marray &operator --(marray & rhs) = delete;
267
+
268
+ friend marray operator &(const marray &lhs, const marray &rhs) = delete;
269
+ friend marray operator &(const marray &lhs, const value_type &rhs) = delete;
270
+
271
+ friend marray operator |(const marray &lhs, const marray &rhs) = delete;
272
+ friend marray operator |(const marray &lhs, const value_type &rhs) = delete;
273
+
274
+ friend marray operator ^(const marray &lhs, const marray &rhs) = delete;
275
+ friend marray operator ^(const marray &lhs, const value_type &rhs) = delete;
276
+
277
+ friend marray &operator &=(marray & lhs, const marray & rhs) = delete;
278
+ friend marray &operator &=(marray & lhs, const value_type & rhs) = delete;
279
+ friend marray &operator &=(value_type & lhs, const marray & rhs) = delete;
280
+
281
+ friend marray &operator |=(marray & lhs, const marray & rhs) = delete;
282
+ friend marray &operator |=(marray & lhs, const value_type & rhs) = delete;
283
+ friend marray &operator |=(value_type & lhs, const marray & rhs) = delete;
284
+
285
+ friend marray &operator ^=(marray & lhs, const marray & rhs) = delete;
286
+ friend marray &operator ^=(marray & lhs, const value_type & rhs) = delete;
287
+ friend marray &operator ^=(value_type & lhs, const marray & rhs) = delete;
288
+
289
+ friend marray<bool, NumElements> operator <<(const marray & lhs, const marray & rhs) = delete;
290
+ friend marray<bool, NumElements> operator <<(const marray & lhs, const value_type & rhs) = delete;
291
+ friend marray<bool, NumElements> operator <<(const value_type & lhs, const marray & rhs) = delete;
292
+
293
+ friend marray<bool, NumElements> operator >>(const marray & lhs, const marray & rhs) = delete;
294
+ friend marray<bool, NumElements> operator >>(const marray & lhs, const value_type & rhs) = delete;
295
+ friend marray<bool, NumElements> operator >>(const value_type & lhs, const marray & rhs) = delete;
296
+
297
+ friend marray &operator <<=(marray & lhs, const marray & rhs) = delete;
298
+ friend marray &operator <<=(marray & lhs, const value_type & rhs) = delete;
299
+
300
+ friend marray &operator >>=(marray & lhs, const marray & rhs) = delete;
301
+ friend marray &operator >>=(marray & lhs, const value_type & rhs) = delete;
215
302
216
- This proposal adds to the `sycl::ext::oneapi::experimental` namespace, math
217
- functions accepting the complex types `complex<sycl::half>`, `complex<float>`,
218
- `complex<double>` as well as the scalar types `sycl::half`, `float` and `double`
219
- for the SYCL math functions, `abs`, `acos`, `asin`, `atan`, `acosh`, `asinh`,
220
- `atanh`, `arg`, `conj`, `cos`, `cosh`, `exp`, `log`, `log10`, `norm`, `polar`,
221
- `pow`, `proj`, `sin`, `sinh`, `sqrt`, `tan`, and `tanh`.
303
+ friend marray<bool, NumElements> operator <(const marray & lhs, const marray & rhs) = delete;
304
+ friend marray<bool, NumElements> operator <(const marray & lhs, const value_type & rhs) = delete;
305
+ friend marray<bool, NumElements> operator <(const value_type & lhs, const marray & rhs) = delete;
306
+
307
+ friend marray<bool, NumElements> operator >(const marray & lhs, const marray & rhs) = delete;
308
+ friend marray<bool, NumElements> operator >(const marray & lhs, const value_type & rhs) = delete;
309
+ friend marray<bool, NumElements> operator >(const value_type & lhs, const marray & rhs) = delete;
310
+
311
+ friend marray<bool, NumElements> operator <=(const marray & lhs, const marray & rhs) = delete;
312
+ friend marray<bool, NumElements> operator <=(const marray & lhs, const value_type & rhs) = delete;
313
+ friend marray<bool, NumElements> operator <=(const value_type & lhs, const marray & rhs) = delete;
314
+
315
+ friend marray<bool, NumElements> operator >=(const marray & lhs, const marray & rhs) = delete;
316
+ friend marray<bool, NumElements> operator >=(const marray & lhs, const value_type & rhs) = delete;
317
+ friend marray<bool, NumElements> operator >=(const value_type & lhs, const marray & rhs) = delete;
318
+
319
+ friend marray operator ~(const marray &v) = delete;
320
+
321
+ friend marray<bool, NumElements> operator !(const marray &v) = delete;
322
+ };
323
+
324
+ } // namespace sycl
325
+ ```
326
+
327
+ === Scalar Mathematical operations
328
+
329
+ This proposal extends the `sycl::ext::oneapi::experimental` namespace math
330
+ functions to accept `complex<sycl::half>`, `complex<float>`, `complex<double>`
331
+ as well as the scalar types `sycl::half`, `float` and `double` for a range of
332
+ SYCL math functions.
333
+
334
+ Specifically, it adds support for `abs`, `acos`, `asin`, `atan`, `acosh`,
335
+ `asinh`, `atanh`, `arg`, `conj`, `cos`, `cosh`, `exp`, `log`, `log10`, `norm`,
336
+ `polar`, `pow`, `proj`, `sin`, `sinh`, `sqrt`, `tan`, and `tanh`.
337
+
338
+ Additionally, this extension introduces support for the `real` and `imag` free
339
+ functions, which the real and imaginary component, respectively.
222
340
223
341
These functions are available in both host and device code, and each math
224
- function should follow the C++ standard for handling NaN's and Inf values.
342
+ function should follow the C++ standard for handling ` NaN` and ` Inf` values.
225
343
226
344
Note: In the case of the `pow` function, additional overloads have been added
227
345
to ensure that for their first argument `base` and second argument `exponent`:
@@ -319,6 +437,123 @@ namespace sycl::ext::oneapi::experimental {
319
437
} // namespace sycl::ext::oneapi::experimental
320
438
```
321
439
440
+ === Element-Wise Mathematical operations
441
+
442
+ In harmony with the complex scalar operations, this proposal extends
443
+ furthermore the `sycl::ext::oneapi::experimental`` namespace math functions
444
+ to accept `sycl::marray<complex<T>>` for a range of SYCL math functions.
445
+
446
+ Specifically, it adds support for `abs`, `acos`, `asin`, `atan`, `acosh`,
447
+ `asinh`, `atanh`, `arg`, `conj`, `cos`, `cosh`, `exp`, `log`, `log10`, `norm`,
448
+ `polar`, `pow`, `proj`, `sin`, `sinh`, `sqrt`, `tan`, and `tanh`.
449
+
450
+ Additionally, this extension introduces support for the `real` and `imag` free
451
+ functions, which return marrays of scalar values representing the real and
452
+ imaginary components, respectively.
453
+
454
+ In scenarios where mathematical functions involve both marray and scalar
455
+ parameters, two sets of overloads are introduced marray-scalar and
456
+ scalar-marray.
457
+
458
+ These mathematical operations are designed to execute element-wise across the
459
+ marray, ensuring that each operation is applied to every element within the
460
+ marray.
461
+
462
+ Moreover, this proposal includes overloads for mathematical functions between
463
+ marrays and scalar inputs. In these cases, the operations are executed across
464
+ the entire marray, with the scalar value held constant.
465
+
466
+ For consistency, these functions are available in both host and device code,
467
+ and each math function should follow the C++ standard for handling `NaN` and
468
+ `Inf` values.
469
+
470
+ ```C++
471
+ namespace sycl/ext/oneapi/experimental {
472
+
473
+ /// VALUES:
474
+ /// Returns an marray of real components from the marray x.
475
+ template <typename T, std::size_t NumElements>
476
+ sycl::marray<T, NumElements> real(const marray<complex<T>, NumElements> &x);
477
+ /// Returns an marray of imaginary components from the marray x.
478
+ template <typename T, std::size_t NumElements>
479
+ sycl::marray<T, NumElements> imag(const marray<complex<T>, NumElements> &x);
480
+
481
+ /// Compute the magnitude for each complex number in marray x.
482
+ template <typename T, std::size_t NumElements> marray<T, NumElements> abs(const marray<complex<T>, NumElements> &x);
483
+ /// Compute phase angle in radians for each complex number in marray x.
484
+ template <typename T, std::size_t NumElements> marray<T, NumElements> arg(const marray<complex<T>, NumElements> &x);
485
+ /// Compute the squared magnitude for each complex number in marray x.
486
+ template <typename T, std::size_t NumElements> marray<T, NumElements> norm(const marray<complex<T>, NumElements> &x);
487
+ /// Compute the conjugate for each complex number in marray x.
488
+ template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> conj(const marray<complex<T>, NumElements> &x);
489
+ /// Compute the projection for each complex number in marray x.
490
+ template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> proj(const marray<complex<T>, NumElements> &x);
491
+ /// Compute the projection for each real number in marray x.
492
+ template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> proj(const marray<T, NumElements> &x);
493
+ /// Construct an marray, elementwise, of complex numbers from each polar coordinate in marray rho and scalar theta.
494
+ template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> polar(const marray<T, NumElements> &rho, T theta = 0);
495
+ /// Construct an marray, elementwise, of complex numbers from each polar coordinate in marray rho and marray theta.
496
+ template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> polar(const marray<T, NumElements> &rho, const marray<T, NumElements> &theta);
497
+ /// Construct an marray, elementwise, of complex numbers from each polar coordinate in scalar rho and marray theta.
498
+ template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> polar(T rho, const marray<T, NumElements> &theta);
499
+
500
+ /// TRANSCENDENTALS:
501
+ /// Compute the natural log for each complex number in marray x.
502
+ template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> log(const marray<complex<T>, NumElements> &x);
503
+ /// Compute the base-10 log for each complex number in marray x.
504
+ template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> log10(const marray<complex<T>, NumElements> &x);
505
+ /// Compute the square root for each complex number in marray x.
506
+ template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> sqrt(const marray<complex<T>, NumElements> &x);
507
+ /// Compute the base-e exponent for each complex number in marray x.
508
+ template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> exp(const marray<complex<T>, NumElements> &x);
509
+
510
+ /// Raise each complex element in x to the power of the corresponding decimal element in y.
511
+ template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> pow(const marray<complex<T>, NumElements> &x, const marray<T, NumElements> &y);
512
+ /// Raise each complex element in x to the power of the decimal number y.
513
+ template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> pow(const marray<complex<T>, NumElements> &x, T y);
514
+ /// Raise complex number x to the power of each decimal element in y.
515
+ template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> pow(const marray<complex<T>, NumElements> &x, const marray<T, NumElements> &y);
516
+ /// Raise each complex element in x to the power of the corresponding complex element in y.
517
+ template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> pow(const marray<complex<T>, NumElements> &x, const marray<complex<T>, NumElements> &y);
518
+ /// Raise each complex element in x to the power of the complex number y.
519
+ template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> pow(const marray<complex<T>, NumElements> &x, const marray<complex<T>, NumElements> &y);
520
+ /// Raise complex number x to the power of each complex element in y.
521
+ template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> pow(const marray<complex<T>, NumElements> &x, const marray<complex<T>, NumElements> &y);
522
+ /// Raise each decimal element in x to the power of the corresponding complex element in y.
523
+ template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> pow(const marray<T, NumElements> &x, const marray<complex<T>, NumElements> &y);
524
+ /// Raise each decimal element in x to the power of the complex number y.
525
+ template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> pow(const marray<T, NumElements> &x, const marray<complex<T>, NumElements> &y);
526
+ /// Raise decimal number x to the power of each complex element in y.
527
+ template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> pow(T x, const marray<complex<T>, NumElements> &y);
528
+
529
+ /// Compute the inverse hyperbolic sine for each complex number in marray x.
530
+ template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> asinh(const marray<complex<T>, NumElements> &x);
531
+ /// Compute the inverse hyperbolic cosine for each complex number in marray x.
532
+ template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> acosh(const marray<complex<T>, NumElements> &x);
533
+ /// Compute the inverse hyperbolic tangent for each complex number in marray x.
534
+ template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> atanh(const marray<complex<T>, NumElements> &x);
535
+ /// Compute the hyperbolic sine for each complex number in marray x.
536
+ template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> sinh(const marray<complex<T>, NumElements> &x);
537
+ /// Compute the hyperbolic cosine for each complex number in marray x.
538
+ template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> cosh(const marray<complex<T>, NumElements> &x);
539
+ /// Compute the hyperbolic tangent for each complex number in marray x.
540
+ template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> tanh(const marray<complex<T>, NumElements> &x);
541
+ /// Compute the inverse sine for each complex number in marray x.
542
+ template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> asin(const marray<complex<T>, NumElements> &x);
543
+ /// Compute the inverse cosine for each complex number in marray x.
544
+ template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> acos(const marray<complex<T>, NumElements> &x);
545
+ /// Compute the inverse tangent for each complex number in marray x.
546
+ template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> atan(const marray<complex<T>, NumElements> &x);
547
+ /// Compute the sine for each complex number in marray x.
548
+ template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> sin(const marray<complex<T>, NumElements> &x);
549
+ /// Compute the cosine for each complex number in marray x.
550
+ template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> cos(const marray<complex<T>, NumElements> &x);
551
+ /// Compute the tangent for each complex number in marray x.
552
+ template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> tan(const marray<complex<T>, NumElements> &x);
553
+
554
+ } // namespace sycl::ext::oneapi::experimental
555
+ ```
556
+
322
557
== Implementation notes
323
558
324
559
The complex mathematical operations can all be defined using SYCL built-ins.
0 commit comments