Skip to content

Commit

Permalink
Merge pull request #409 from ubermag/fft_notebook
Browse files Browse the repository at this point in the history
fft notebook
  • Loading branch information
samjrholt authored Aug 22, 2023
2 parents aa43d3d + ca18d08 commit 86b4390
Show file tree
Hide file tree
Showing 4 changed files with 2,893 additions and 10 deletions.
5 changes: 3 additions & 2 deletions discretisedfield/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -3618,6 +3618,7 @@ def fftn(self, **kwargs):
"""
mesh = self.mesh.fftn()

# Use scipy as faster than numpy
axes = range(self.mesh.region.ndim)
ft = spfft.fftshift(
spfft.fftn(self.array, axes=axes, **kwargs),
Expand Down Expand Up @@ -3754,7 +3755,7 @@ def rfftn(self, **kwargs):
axes = range(self.mesh.region.ndim)
ft = spfft.fftshift(
spfft.rfftn(self.array, axes=axes, **kwargs),
axes=axes,
axes=axes[:-1],
)

return self._fftn(mesh=mesh, array=ft, ifftn=False)
Expand Down Expand Up @@ -3818,7 +3819,7 @@ def irfftn(self, shape=None, **kwargs):

axes = range(self.mesh.region.ndim)
ft = spfft.irfftn(
spfft.ifftshift(self.array, axes=axes),
spfft.ifftshift(self.array, axes=axes[:-1]),
axes=axes,
s=shape,
**kwargs,
Expand Down
1 change: 1 addition & 0 deletions discretisedfield/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -2247,6 +2247,7 @@ def fftn(self, rfft=False):
freqs = spfft.fftfreq(self.n[i], self.cell[i])
# Shift the region boundaries to get the correct coordinates of
# mesh cells.
# This effectively does the same as using fftshift
dfreq = (freqs[1] - freqs[0]) / 2
p1.append(min(freqs) - dfreq)
p2.append(max(freqs) + dfreq)
Expand Down
26 changes: 18 additions & 8 deletions discretisedfield/tests/test_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import matplotlib.pyplot as plt
import numpy as np
import pytest
import scipy.fft as spfft
import xarray as xr

import discretisedfield as df
Expand Down Expand Up @@ -2896,17 +2897,11 @@ def _init_random(p):
f = df.Field(mesh, nvdim=3, value=_init_random, norm=1)

# 3d fft
assert f.allclose(f.fftn().ifftn().real)
assert df.Field(mesh, nvdim=3).allclose(f.fftn().ifftn().imag)

assert f.allclose(f.rfftn().irfftn())

# 2d fft
for i in ["x", "y", "z"]:
plane = f.sel(i)
assert plane.allclose(plane.fftn().ifftn().real)
assert df.Field(mesh, nvdim=3).sel(i).allclose(plane.fftn().ifftn().imag)

assert plane.allclose(plane.rfftn().irfftn())

# Fourier slice theoreme
Expand Down Expand Up @@ -2939,11 +2934,26 @@ def _init_random(p):

f = df.Field(mesh, nvdim=1, value=np.random.rand(*mesh.n, 1), norm=1)

assert f.allclose(f.fftn().ifftn().real)
assert df.Field(mesh, nvdim=1).allclose(f.fftn().ifftn().imag)
assert f.allclose(f.rfftn().irfftn(shape=f.mesh.n))

# test 1d rfft
assert f.allclose(f.rfftn().irfftn(shape=f.mesh.n))

# test rfft no shift last dim
a = np.zeros((5, 5))
a[2, 3] = 1

p1 = (0, 0)
p2 = (10, 10)
cell = (2.0, 2.0)
mesh = df.Mesh(p1=p1, p2=p2, cell=cell)
f = df.Field(mesh, nvdim=1, value=a)

field_ft = f.rfftn()
ft = spfft.fftshift(spfft.rfftn(a), axes=[0])

assert np.array_equal(field_ft.array[..., 0], ft)


def test_mpl_scalar(test_field):
# No axes
Expand Down
Loading

0 comments on commit 86b4390

Please sign in to comment.