diff --git a/jax/_src/numpy/fft.py b/jax/_src/numpy/fft.py index 70581aa69b03..3a0d95a4de97 100644 --- a/jax/_src/numpy/fft.py +++ b/jax/_src/numpy/fft.py @@ -750,13 +750,13 @@ def fftfreq(n: int, d: ArrayLike = 1.0, *, dtype: DTypeLike | None = None, "The d argument of jax.numpy.fft.fftfreq only takes a single value. " "Got d = %s." % list(d)) - k = jnp.zeros(n, dtype=dtype) + k = jnp.zeros(n, dtype=dtype, device=device) if n % 2 == 0: # k[0: n // 2 - 1] = jnp.arange(0, n // 2 - 1) - k = k.at[0: n // 2].set( jnp.arange(0, n // 2, dtype=dtype)) + k = k.at[0: n // 2].set(jnp.arange(0, n // 2, dtype=dtype)) # k[n // 2:] = jnp.arange(-n // 2, -1) - k = k.at[n // 2:].set( jnp.arange(-n // 2, 0, dtype=dtype)) + k = k.at[n // 2:].set(jnp.arange(-n // 2, 0, dtype=dtype)) else: # k[0: (n - 1) // 2] = jnp.arange(0, (n - 1) // 2) @@ -765,11 +765,7 @@ def fftfreq(n: int, d: ArrayLike = 1.0, *, dtype: DTypeLike | None = None, # k[(n - 1) // 2 + 1:] = jnp.arange(-(n - 1) // 2, -1) k = k.at[(n - 1) // 2 + 1:].set(jnp.arange(-(n - 1) // 2, 0, dtype=dtype)) - result = k / jnp.array(d * n, dtype=dtype) - - if device is not None: - return result.to_device(device) - return result + return k / jnp.array(d * n, dtype=dtype, device=device) def rfftfreq(n: int, d: ArrayLike = 1.0, *, dtype: DTypeLike | None = None, diff --git a/tests/fft_test.py b/tests/fft_test.py index ce7455fdb4f8..05fa96a93fae 100644 --- a/tests/fft_test.py +++ b/tests/fft_test.py @@ -346,13 +346,16 @@ def testFft2Errors(self, inverse, real): dtype=all_dtypes, size=[9, 10, 101, 102], d=[0.1, 2.], + device=[None, -1], ) - def testFftfreq(self, size, d, dtype): + def testFftfreq(self, size, d, dtype, device): rng = jtu.rand_default(self.rng()) args_maker = lambda: (rng([size], dtype),) jnp_op = jnp.fft.fftfreq np_op = np.fft.fftfreq - jnp_fn = lambda a: jnp_op(size, d=d) + if device is not None: + device = jax.devices()[device] + jnp_fn = lambda a: jnp_op(size, d=d, device=device) np_fn = lambda a: np_op(size, d=d) # Numpy promotes to complex128 aggressively. self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False, @@ -362,6 +365,10 @@ def testFftfreq(self, size, d, dtype): if dtype in inexact_dtypes: tol = 0.15 # TODO(skye): can we be more precise? jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol) + # Test device + if device is not None: + out = jnp_fn(args_maker()) + self.assertEqual(out.devices(), {device}) @jtu.sample_product(n=[[0, 1, 2]]) def testFftfreqErrors(self, n): @@ -384,13 +391,16 @@ def testFftfreqErrors(self, n): dtype=all_dtypes, size=[9, 10, 101, 102], d=[0.1, 2.], + device=[None, -1], ) - def testRfftfreq(self, size, d, dtype): + def testRfftfreq(self, size, d, dtype, device): rng = jtu.rand_default(self.rng()) args_maker = lambda: (rng([size], dtype),) jnp_op = jnp.fft.rfftfreq np_op = np.fft.rfftfreq - jnp_fn = lambda a: jnp_op(size, d=d) + if device is not None: + device = jax.devices()[device] + jnp_fn = lambda a: jnp_op(size, d=d, device=device) np_fn = lambda a: np_op(size, d=d) # Numpy promotes to complex128 aggressively. self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=False, @@ -400,6 +410,10 @@ def testRfftfreq(self, size, d, dtype): if dtype in inexact_dtypes: tol = 0.15 # TODO(skye): can we be more precise? jtu.check_grads(jnp_fn, args_maker(), order=2, atol=tol, rtol=tol) + # Test device + if device is not None: + out = jnp_fn(args_maker()) + self.assertEqual(out.devices(), {device}) @jtu.sample_product(n=[[0, 1, 2]]) def testRfftfreqErrors(self, n):