Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[array API] add device argument to fftfreq & rfftfreq #22736

Merged
merged 1 commit into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion jax/_src/basearray.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,8 @@ class Array(abc.ABC):
@property
def traceback(self) -> Traceback: ...
def unsafe_buffer_pointer(self) -> int: ...
def to_device(self, device: Device | Sharding, *, stream: int | Any | None) -> Array: ...
def to_device(self, device: Device | Sharding, *,
stream: int | Any | None = ...) -> Array: ...


StaticScalar = Union[
Expand Down
71 changes: 56 additions & 15 deletions jax/_src/numpy/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
from jax._src.numpy.util import check_arraylike, implements, promote_dtypes_inexact
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy import ufuncs, reductions
from jax._src.typing import Array, ArrayLike
from jax._src.sharding import Sharding
from jax._src.typing import Array, ArrayLike, DTypeLike

Shape = Sequence[int]

Expand Down Expand Up @@ -716,12 +717,28 @@ def irfft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1),
norm=norm)


@implements(np.fft.fftfreq, extra_params="""
dtype : Optional
The dtype of the returned frequencies. If not specified, JAX's default
floating point dtype will be used.
""")
def fftfreq(n: int, d: ArrayLike = 1.0, *, dtype=None) -> Array:
def fftfreq(n: int, d: ArrayLike = 1.0, *, dtype: DTypeLike | None = None,
device: xla_client.Device | Sharding | None = None) -> Array:
"""Return sample frequencies for the discrete Fourier transform.

JAX implementation of :func:`numpy.fft.fftfreq`. Returns frequencies appropriate
for use with the outputs of :func:`~jax.numpy.fft` and :func:`~jax.numpy.ifft`.

Args:
n: length of the FFT window
d: optional scalar sample spacing (default: 1.0)
dtype: optional dtype of returned frequencies. If not specified, JAX's default
floating point dtype will be used.
device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
to which the created array will be committed.

Returns:
Array of sample frequencies, length ``n``.

See also:
- :func:`jax.numpy.fft.rfftfreq`: frequencies for use with :func:`~jax.numpy.rfft`
and :func:`~jax.numpy.irfft`.
"""
dtype = dtype or dtypes.canonicalize_dtype(jnp.float_)
if isinstance(n, (list, tuple)):
raise ValueError(
Expand All @@ -748,15 +765,35 @@ def fftfreq(n: int, d: ArrayLike = 1.0, *, dtype=None) -> Array:
# k[(n - 1) // 2 + 1:] = jnp.arange(-(n - 1) // 2, -1)
k = k.at[(n - 1) // 2 + 1:].set(jnp.arange(-(n - 1) // 2, 0, dtype=dtype))

return k / jnp.array(d * n, dtype=dtype)
result = k / jnp.array(d * n, dtype=dtype)

if device is not None:
return result.to_device(device)
return result


@implements(np.fft.rfftfreq, extra_params="""
dtype : Optional
The dtype of the returned frequencies. If not specified, JAX's default
floating point dtype will be used.
""")
def rfftfreq(n: int, d: ArrayLike = 1.0, *, dtype=None) -> Array:
def rfftfreq(n: int, d: ArrayLike = 1.0, *, dtype: DTypeLike | None = None,
device: xla_client.Device | Sharding | None = None) -> Array:
"""Return sample frequencies for the discrete Fourier transform.

JAX implementation of :func:`numpy.fft.fftfreq`. Returns frequencies appropriate
for use with the outputs of :func:`~jax.numpy.rfft` and :func:`~jax.numpy.irfft`.

Args:
n: length of the FFT window
d: optional scalar sample spacing (default: 1.0)
dtype: optional dtype of returned frequencies. If not specified, JAX's default
floating point dtype will be used.
device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding`
to which the created array will be committed.

Returns:
Array of sample frequencies, length ``n // 2 + 1``.

See also:
- :func:`jax.numpy.fft.rfftfreq`: frequencies for use with :func:`~jax.numpy.fft`
and :func:`~jax.numpy.ifft`.
"""
dtype = dtype or dtypes.canonicalize_dtype(jnp.float_)
if isinstance(n, (list, tuple)):
raise ValueError(
Expand All @@ -774,7 +811,11 @@ def rfftfreq(n: int, d: ArrayLike = 1.0, *, dtype=None) -> Array:
else:
k = jnp.arange(0, (n - 1) // 2 + 1, dtype=dtype)

return k / jnp.array(d * n, dtype=dtype)
result = k / jnp.array(d * n, dtype=dtype)

if device is not None:
return result.to_device(device)
return result


@implements(np.fft.fftshift)
Expand Down
25 changes: 0 additions & 25 deletions jax/experimental/array_api/_fft_functions.py

This file was deleted.

7 changes: 2 additions & 5 deletions jax/experimental/array_api/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from jax.numpy.fft import (
fft as fft,
fftfreq as fftfreq,
fftn as fftn,
fftshift as fftshift,
hfft as hfft,
Expand All @@ -24,10 +25,6 @@
irfft as irfft,
irfftn as irfftn,
rfft as rfft,
rfftn as rfftn,
)

from jax.experimental.array_api._fft_functions import (
fftfreq as fftfreq,
rfftfreq as rfftfreq,
rfftn as rfftn,
)