Skip to content

Commit 261c638

Browse files
committed
Expand complex extension with complex support for sycl::marray
1 parent d38206c commit 261c638

File tree

1 file changed

+247
-12
lines changed

1 file changed

+247
-12
lines changed

sycl/doc/extensions/experimental/sycl_ext_oneapi_complex.asciidoc

+247-12
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,20 @@ specification.*
5454
While {dpcpp} has support for `std::complex` in device code, it limits the
5555
complex interface and operations to the existing C++ standard. This proposal
5656
defines a SYCL complex extension based on but independent of the `std::complex`
57-
interface. This framework would allow for further development of complex math
58-
within oneAPI. Possible areas for deviation with `std::complex` include adding
59-
complex support for `marray` and `vec` and overloading mathematical
60-
functions to handle the element-wise operations.
57+
interface.
58+
59+
The proposed framework not only encompasses complex support for traditional use
60+
cases but also accommodate for advanced mathematical features and data
61+
structures.
62+
63+
Specifically, we propose to incorporate complex support for `sycl::marray`.
64+
This addition will empower developers to store complex numbers seamlessly
65+
within arrays, opening up new possibilities for data manipulation and
66+
computation.
67+
68+
Furthermore, this extension involves overloading existing mathematical
69+
functions to facilitate scalar operation on complex numbers as well as
70+
element-wise operations on complex marrays.
6171

6272
== Specification
6373

@@ -211,17 +221,125 @@ namespace sycl::ext::oneapi::experimental {
211221
} // namespace sycl::ext::oneapi::experimental
212222
```
213223

214-
=== Mathematical operations
224+
=== Marray Complex Class Specialization
225+
226+
This proposal also introduces the specialization of the marray class to
227+
support SYCL complex. The marray class undergoes slight modification for this
228+
specialization, primarily involving the removal of operators that are
229+
inapplicable. No new functions or operators are introduced to the marray class.
230+
231+
The marray complex specialization maintains the principles of trivial
232+
copyability (as seen in the Complex class description), with the
233+
`is_device_copyable` type trait resolving to `std::true_type`.
234+
235+
The marray definition used within this proposal assumes that any operator the
236+
`sycl::marray` class defines is only implemented if the marray's value type
237+
also implements the operator.
238+
239+
For instance,
240+
`sycl::marray<sycl::ext::oneapi::experimental::complex<T>, NumElements>` does
241+
not implement the modulus operator since
242+
`sycl::ext::oneapi::experimental::complex<T>` does not support it.
243+
244+
```C++
245+
namespace sycl {
246+
247+
// Specialization of exiting `marray` class for `sycl::ext::oneapi::experimental::complex`
248+
template <typename T, std::size_t NumElements>
249+
class marray<sycl::ext::oneapi::experimental::complex<T>, NumElements> {
250+
public:
251+
252+
/* ... */
253+
254+
friend marray operator %(const marray &lhs, const marray &rhs) = delete;
255+
friend marray operator %(const marray &lhs, const value_type &rhs) = delete;
256+
friend marray operator %(const value_type &lhs, const marray &rhs) = delete;
257+
258+
friend marray &operator %=(marray &lhs, const marray &rhs) = delete;
259+
friend marray &operator %=(marray &lhs, const value_type &rhs) = delete;
260+
friend marray &operator %=(value_type &lhs, const marray &rhs) = delete;
261+
262+
friend marray operator ++(marray &lhs, int) = delete;
263+
friend marray &operator ++(marray & rhs) = delete;
264+
265+
friend marray operator --(marray &lhs, int) = delete;
266+
friend marray &operator --(marray & rhs) = delete;
267+
268+
friend marray operator &(const marray &lhs, const marray &rhs) = delete;
269+
friend marray operator &(const marray &lhs, const value_type &rhs) = delete;
270+
271+
friend marray operator |(const marray &lhs, const marray &rhs) = delete;
272+
friend marray operator |(const marray &lhs, const value_type &rhs) = delete;
273+
274+
friend marray operator ^(const marray &lhs, const marray &rhs) = delete;
275+
friend marray operator ^(const marray &lhs, const value_type &rhs) = delete;
276+
277+
friend marray &operator &=(marray & lhs, const marray & rhs) = delete;
278+
friend marray &operator &=(marray & lhs, const value_type & rhs) = delete;
279+
friend marray &operator &=(value_type & lhs, const marray & rhs) = delete;
280+
281+
friend marray &operator |=(marray & lhs, const marray & rhs) = delete;
282+
friend marray &operator |=(marray & lhs, const value_type & rhs) = delete;
283+
friend marray &operator |=(value_type & lhs, const marray & rhs) = delete;
284+
285+
friend marray &operator ^=(marray & lhs, const marray & rhs) = delete;
286+
friend marray &operator ^=(marray & lhs, const value_type & rhs) = delete;
287+
friend marray &operator ^=(value_type & lhs, const marray & rhs) = delete;
288+
289+
friend marray<bool, NumElements> operator <<(const marray & lhs, const marray & rhs) = delete;
290+
friend marray<bool, NumElements> operator <<(const marray & lhs, const value_type & rhs) = delete;
291+
friend marray<bool, NumElements> operator <<(const value_type & lhs, const marray & rhs) = delete;
292+
293+
friend marray<bool, NumElements> operator >>(const marray & lhs, const marray & rhs) = delete;
294+
friend marray<bool, NumElements> operator >>(const marray & lhs, const value_type & rhs) = delete;
295+
friend marray<bool, NumElements> operator >>(const value_type & lhs, const marray & rhs) = delete;
296+
297+
friend marray &operator <<=(marray & lhs, const marray & rhs) = delete;
298+
friend marray &operator <<=(marray & lhs, const value_type & rhs) = delete;
299+
300+
friend marray &operator >>=(marray & lhs, const marray & rhs) = delete;
301+
friend marray &operator >>=(marray & lhs, const value_type & rhs) = delete;
215302

216-
This proposal adds to the `sycl::ext::oneapi::experimental` namespace, math
217-
functions accepting the complex types `complex<sycl::half>`, `complex<float>`,
218-
`complex<double>` as well as the scalar types `sycl::half`, `float` and `double`
219-
for the SYCL math functions, `abs`, `acos`, `asin`, `atan`, `acosh`, `asinh`,
220-
`atanh`, `arg`, `conj`, `cos`, `cosh`, `exp`, `log`, `log10`, `norm`, `polar`,
221-
`pow`, `proj`, `sin`, `sinh`, `sqrt`, `tan`, and `tanh`.
303+
friend marray<bool, NumElements> operator <(const marray & lhs, const marray & rhs) = delete;
304+
friend marray<bool, NumElements> operator <(const marray & lhs, const value_type & rhs) = delete;
305+
friend marray<bool, NumElements> operator <(const value_type & lhs, const marray & rhs) = delete;
306+
307+
friend marray<bool, NumElements> operator >(const marray & lhs, const marray & rhs) = delete;
308+
friend marray<bool, NumElements> operator >(const marray & lhs, const value_type & rhs) = delete;
309+
friend marray<bool, NumElements> operator >(const value_type & lhs, const marray & rhs) = delete;
310+
311+
friend marray<bool, NumElements> operator <=(const marray & lhs, const marray & rhs) = delete;
312+
friend marray<bool, NumElements> operator <=(const marray & lhs, const value_type & rhs) = delete;
313+
friend marray<bool, NumElements> operator <=(const value_type & lhs, const marray & rhs) = delete;
314+
315+
friend marray<bool, NumElements> operator >=(const marray & lhs, const marray & rhs) = delete;
316+
friend marray<bool, NumElements> operator >=(const marray & lhs, const value_type & rhs) = delete;
317+
friend marray<bool, NumElements> operator >=(const value_type & lhs, const marray & rhs) = delete;
318+
319+
friend marray operator ~(const marray &v) = delete;
320+
321+
friend marray<bool, NumElements> operator !(const marray &v) = delete;
322+
};
323+
324+
} // namespace sycl
325+
```
326+
327+
=== Scalar Mathematical operations
328+
329+
This proposal extends the `sycl::ext::oneapi::experimental` namespace math
330+
functions to accept `complex<sycl::half>`, `complex<float>`, `complex<double>`
331+
as well as the scalar types `sycl::half`, `float` and `double` for a range of
332+
SYCL math functions.
333+
334+
Specifically, it adds support for `abs`, `acos`, `asin`, `atan`, `acosh`,
335+
`asinh`, `atanh`, `arg`, `conj`, `cos`, `cosh`, `exp`, `log`, `log10`, `norm`,
336+
`polar`, `pow`, `proj`, `sin`, `sinh`, `sqrt`, `tan`, and `tanh`.
337+
338+
Additionally, this extension introduces support for the `real` and `imag` free
339+
functions, which the real and imaginary component, respectively.
222340

223341
These functions are available in both host and device code, and each math
224-
function should follow the C++ standard for handling NaN's and Inf values.
342+
function should follow the C++ standard for handling `NaN` and `Inf` values.
225343

226344
Note: In the case of the `pow` function, additional overloads have been added
227345
to ensure that for their first argument `base` and second argument `exponent`:
@@ -319,6 +437,123 @@ namespace sycl::ext::oneapi::experimental {
319437
} // namespace sycl::ext::oneapi::experimental
320438
```
321439

440+
=== Element-Wise Mathematical operations
441+
442+
In harmony with the complex scalar operations, this proposal extends
443+
furthermore the `sycl::ext::oneapi::experimental`` namespace math functions
444+
to accept `sycl::marray<complex<T>>` for a range of SYCL math functions.
445+
446+
Specifically, it adds support for `abs`, `acos`, `asin`, `atan`, `acosh`,
447+
`asinh`, `atanh`, `arg`, `conj`, `cos`, `cosh`, `exp`, `log`, `log10`, `norm`,
448+
`polar`, `pow`, `proj`, `sin`, `sinh`, `sqrt`, `tan`, and `tanh`.
449+
450+
Additionally, this extension introduces support for the `real` and `imag` free
451+
functions, which return marrays of scalar values representing the real and
452+
imaginary components, respectively.
453+
454+
In scenarios where mathematical functions involve both marray and scalar
455+
parameters, two sets of overloads are introduced marray-scalar and
456+
scalar-marray.
457+
458+
These mathematical operations are designed to execute element-wise across the
459+
marray, ensuring that each operation is applied to every element within the
460+
marray.
461+
462+
Moreover, this proposal includes overloads for mathematical functions between
463+
marrays and scalar inputs. In these cases, the operations are executed across
464+
the entire marray, with the scalar value held constant.
465+
466+
For consistency, these functions are available in both host and device code,
467+
and each math function should follow the C++ standard for handling `NaN` and
468+
`Inf` values.
469+
470+
```C++
471+
namespace sycl/ext/oneapi/experimental {
472+
473+
/// VALUES:
474+
/// Returns an marray of real components from the marray x.
475+
template <typename T, std::size_t NumElements>
476+
sycl::marray<T, NumElements> real(const marray<complex<T>, NumElements> &x);
477+
/// Returns an marray of imaginary components from the marray x.
478+
template <typename T, std::size_t NumElements>
479+
sycl::marray<T, NumElements> imag(const marray<complex<T>, NumElements> &x);
480+
481+
/// Compute the magnitude for each complex number in marray x.
482+
template <typename T, std::size_t NumElements> marray<T, NumElements> abs(const marray<complex<T>, NumElements> &x);
483+
/// Compute phase angle in radians for each complex number in marray x.
484+
template <typename T, std::size_t NumElements> marray<T, NumElements> arg(const marray<complex<T>, NumElements> &x);
485+
/// Compute the squared magnitude for each complex number in marray x.
486+
template <typename T, std::size_t NumElements> marray<T, NumElements> norm(const marray<complex<T>, NumElements> &x);
487+
/// Compute the conjugate for each complex number in marray x.
488+
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> conj(const marray<complex<T>, NumElements> &x);
489+
/// Compute the projection for each complex number in marray x.
490+
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> proj(const marray<complex<T>, NumElements> &x);
491+
/// Compute the projection for each real number in marray x.
492+
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> proj(const marray<T, NumElements> &x);
493+
/// Construct an marray, elementwise, of complex numbers from each polar coordinate in marray rho and scalar theta.
494+
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> polar(const marray<T, NumElements> &rho, T theta = 0);
495+
/// Construct an marray, elementwise, of complex numbers from each polar coordinate in marray rho and marray theta.
496+
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> polar(const marray<T, NumElements> &rho, const marray<T, NumElements> &theta);
497+
/// Construct an marray, elementwise, of complex numbers from each polar coordinate in scalar rho and marray theta.
498+
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> polar(T rho, const marray<T, NumElements> &theta);
499+
500+
/// TRANSCENDENTALS:
501+
/// Compute the natural log for each complex number in marray x.
502+
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> log(const marray<complex<T>, NumElements> &x);
503+
/// Compute the base-10 log for each complex number in marray x.
504+
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> log10(const marray<complex<T>, NumElements> &x);
505+
/// Compute the square root for each complex number in marray x.
506+
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> sqrt(const marray<complex<T>, NumElements> &x);
507+
/// Compute the base-e exponent for each complex number in marray x.
508+
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> exp(const marray<complex<T>, NumElements> &x);
509+
510+
/// Raise each complex element in x to the power of the corresponding decimal element in y.
511+
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> pow(const marray<complex<T>, NumElements> &x, const marray<T, NumElements> &y);
512+
/// Raise each complex element in x to the power of the decimal number y.
513+
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> pow(const marray<complex<T>, NumElements> &x, T y);
514+
/// Raise complex number x to the power of each decimal element in y.
515+
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> pow(const marray<complex<T>, NumElements> &x, const marray<T, NumElements> &y);
516+
/// Raise each complex element in x to the power of the corresponding complex element in y.
517+
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> pow(const marray<complex<T>, NumElements> &x, const marray<complex<T>, NumElements> &y);
518+
/// Raise each complex element in x to the power of the complex number y.
519+
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> pow(const marray<complex<T>, NumElements> &x, const marray<complex<T>, NumElements> &y);
520+
/// Raise complex number x to the power of each complex element in y.
521+
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> pow(const marray<complex<T>, NumElements> &x, const marray<complex<T>, NumElements> &y);
522+
/// Raise each decimal element in x to the power of the corresponding complex element in y.
523+
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> pow(const marray<T, NumElements> &x, const marray<complex<T>, NumElements> &y);
524+
/// Raise each decimal element in x to the power of the complex number y.
525+
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> pow(const marray<T, NumElements> &x, const marray<complex<T>, NumElements> &y);
526+
/// Raise decimal number x to the power of each complex element in y.
527+
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> pow(T x, const marray<complex<T>, NumElements> &y);
528+
529+
/// Compute the inverse hyperbolic sine for each complex number in marray x.
530+
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> asinh(const marray<complex<T>, NumElements> &x);
531+
/// Compute the inverse hyperbolic cosine for each complex number in marray x.
532+
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> acosh(const marray<complex<T>, NumElements> &x);
533+
/// Compute the inverse hyperbolic tangent for each complex number in marray x.
534+
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> atanh(const marray<complex<T>, NumElements> &x);
535+
/// Compute the hyperbolic sine for each complex number in marray x.
536+
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> sinh(const marray<complex<T>, NumElements> &x);
537+
/// Compute the hyperbolic cosine for each complex number in marray x.
538+
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> cosh(const marray<complex<T>, NumElements> &x);
539+
/// Compute the hyperbolic tangent for each complex number in marray x.
540+
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> tanh(const marray<complex<T>, NumElements> &x);
541+
/// Compute the inverse sine for each complex number in marray x.
542+
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> asin(const marray<complex<T>, NumElements> &x);
543+
/// Compute the inverse cosine for each complex number in marray x.
544+
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> acos(const marray<complex<T>, NumElements> &x);
545+
/// Compute the inverse tangent for each complex number in marray x.
546+
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> atan(const marray<complex<T>, NumElements> &x);
547+
/// Compute the sine for each complex number in marray x.
548+
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> sin(const marray<complex<T>, NumElements> &x);
549+
/// Compute the cosine for each complex number in marray x.
550+
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> cos(const marray<complex<T>, NumElements> &x);
551+
/// Compute the tangent for each complex number in marray x.
552+
template <typename T, std::size_t NumElements> marray<complex<T>, NumElements> tan(const marray<complex<T>, NumElements> &x);
553+
554+
} // namespace sycl::ext::oneapi::experimental
555+
```
556+
322557
== Implementation notes
323558

324559
The complex mathematical operations can all be defined using SYCL built-ins.

0 commit comments

Comments
 (0)