Skip to content

Commit

Permalink
Follow-up to jax-ml#22736
Browse files Browse the repository at this point in the history
On adding  device kwarg to jnp.fft.fftfreq and jnp.fft.rfftfreq
  • Loading branch information
vfdev-5 committed Jul 30, 2024
1 parent cc21245 commit bb1fb3b
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 12 deletions.
12 changes: 4 additions & 8 deletions jax/_src/numpy/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,13 +750,13 @@ def fftfreq(n: int, d: ArrayLike = 1.0, *, dtype: DTypeLike | None = None,
"The d argument of jax.numpy.fft.fftfreq only takes a single value. "
"Got d = %s." % list(d))

k = jnp.zeros(n, dtype=dtype)
k = jnp.zeros(n, dtype=dtype, device=device)
if n % 2 == 0:
# k[0: n // 2 - 1] = jnp.arange(0, n // 2 - 1)
k = k.at[0: n // 2].set( jnp.arange(0, n // 2, dtype=dtype))
k = k.at[0: n // 2].set(jnp.arange(0, n // 2, dtype=dtype))

# k[n // 2:] = jnp.arange(-n // 2, -1)
k = k.at[n // 2:].set( jnp.arange(-n // 2, 0, dtype=dtype))
k = k.at[n // 2:].set(jnp.arange(-n // 2, 0, dtype=dtype))

else:
# k[0: (n - 1) // 2] = jnp.arange(0, (n - 1) // 2)
Expand All @@ -765,11 +765,7 @@ def fftfreq(n: int, d: ArrayLike = 1.0, *, dtype: DTypeLike | None = None,
# k[(n - 1) // 2 + 1:] = jnp.arange(-(n - 1) // 2, -1)
k = k.at[(n - 1) // 2 + 1:].set(jnp.arange(-(n - 1) // 2, 0, dtype=dtype))

result = k / jnp.array(d * n, dtype=dtype)

if device is not None:
return result.to_device(device)
return result
return k / jnp.array(d * n, dtype=dtype, device=device)


def rfftfreq(n: int, d: ArrayLike = 1.0, *, dtype: DTypeLike | None = None,
Expand Down
22 changes: 18 additions & 4 deletions tests/fft_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,13 +346,16 @@ def testFft2Errors(self, inverse, real):
dtype=all_dtypes,
size=[9, 10, 101, 102],
d=[0.1, 2.],
device=[None, -1],
)
def testFftfreq(self, size, d, dtype):
def testFftfreq(self, size, d, dtype, device):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng([size], dtype),)
jnp_op = jnp.fft.fftfreq
np_op = np.fft.fftfreq
jnp_fn = lambda a: jnp_op(size, d=d)
if device is not None:
device = jax.devices()[device]
jnp_fn = lambda a: jnp_op(size, d=d, device=device)
np_fn = lambda a: np_op(size, d=d)
# Numpy promotes to complex128 aggressively.
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
Expand All @@ -362,6 +365,10 @@ def testFftfreq(self, size, d, dtype):
if dtype in inexact_dtypes:
tol = 0.15 # TODO(skye): can we be more precise?
jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol)
# Test device
if device is not None:
out = jnp_fn(args_maker())
self.assertEqual(out.devices(), {device})

@jtu.sample_product(n=[[0, 1, 2]])
def testFftfreqErrors(self, n):
Expand All @@ -384,13 +391,16 @@ def testFftfreqErrors(self, n):
dtype=all_dtypes,
size=[9, 10, 101, 102],
d=[0.1, 2.],
device=[None, -1],
)
def testRfftfreq(self, size, d, dtype):
def testRfftfreq(self, size, d, dtype, device):
rng = jtu.rand_default(self.rng())
args_maker = lambda: (rng([size], dtype),)
jnp_op = jnp.fft.rfftfreq
np_op = np.fft.rfftfreq
jnp_fn = lambda a: jnp_op(size, d=d)
if device is not None:
device = jax.devices()[device]
jnp_fn = lambda a: jnp_op(size, d=d, device=device)
np_fn = lambda a: np_op(size, d=d)
# Numpy promotes to complex128 aggressively.
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False,
Expand All @@ -400,6 +410,10 @@ def testRfftfreq(self, size, d, dtype):
if dtype in inexact_dtypes:
tol = 0.15 # TODO(skye): can we be more precise?
jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol)
# Test device
if device is not None:
out = jnp_fn(args_maker())
self.assertEqual(out.devices(), {device})

@jtu.sample_product(n=[[0, 1, 2]])
def testRfftfreqErrors(self, n):
Expand Down

0 comments on commit bb1fb3b

Please sign in to comment.