Skip to content

Commit

Permalink
FIX: Fix spherical harmonics
Browse files Browse the repository at this point in the history
  • Loading branch information
larsoner committed Jul 12, 2015
1 parent 13dd371 commit 15a4f17
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 39 deletions.
92 changes: 69 additions & 23 deletions mne/preprocessing/maxwell.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,56 @@ def maxwell_filter(raw, origin=(0, 0, 40), int_order=8, ext_order=3,
return raw_sss


def _sph_harm(order, degree, az, pol):
"""Evaluate point in specified multipolar moment. [1]_ Equation 4.
When using, pay close attention to inputs. Spherical harmonic notation for
order/degree, and theta/phi are both reversed in original SSS work compared
to many other sources. See mathworld.wolfram.com/SphericalHarmonic.html for
more discussion.
Note that scipy has ``scipy.special.sph_harm``, but that function is
too slow on old versions (< 0.15) and has a weird bug on newer versions.
At some point we should track it down and open a bug report...
Parameters
----------
order : int
Order of spherical harmonic. (Usually) corresponds to 'm'
degree : int
Degree of spherical harmonic. (Usually) corresponds to 'l'
az : float
Azimuthal (longitudinal) spherical coordinate [0, 2*pi]. 0 is aligned
with x-axis.
pol : float
Polar (or colatitudinal) spherical coordinate [0, pi]. 0 is aligned
with z-axis.
Returns
-------
base : complex float
The spherical harmonic value at the specified azimuth and polar angles
"""
from scipy.special import lpmv

# Error checks
if np.abs(order) > degree:
raise ValueError('Absolute value of expansion coefficient must be <= '
'degree')
# Ensure that polar and azimuth angles are arrays
az = np.asarray(az)
pol = np.asarray(pol)
if (az < -2 * np.pi).any() or (az > 2 * np.pi).any():
raise ValueError('Azimuth coords must lie in [-2*pi, 2*pi]')
if(pol < 0).any() or (pol > np.pi).any():
raise ValueError('Polar coords must lie in [0, pi]')

base = np.sqrt((2 * degree + 1) / (4 * np.pi) * factorial(degree - order) /
factorial(degree + order)) * \
lpmv(order, degree, np.cos(pol)) * np.exp(1j * order * az)
return base


def _sss_basis(origin, coils, int_order, ext_order):
"""Compute SSS basis for given conditions.
Expand Down Expand Up @@ -240,20 +290,18 @@ def _grad_in_components(degree, order, rad, az, pol):
Gradient of the spherical harmonic and vector specified in rectangular
coordinates
"""
from scipy.special import sph_harm
# Compute gradients for all spherical coordinates (Eq. 6)
g_rad = -(degree + 1) / rad ** (degree + 2) * sph_harm(order, degree,
az, pol)
g_rad = (-(degree + 1) / rad ** (degree + 2) *
_sph_harm(order, degree, az, pol))

g_az = 1 / (rad ** (degree + 2) * np.sin(pol)) * 1j * order * \
sph_harm(order, degree, az, pol)
g_az = (1 / (rad ** (degree + 2) * np.sin(pol)) * 1j * order *
_sph_harm(order, degree, az, pol))

g_pol = 1 / rad ** (degree + 2) * np.sqrt((2 * degree + 1) *
factorial(degree - order) /
(4 * np.pi *
factorial(degree + order))) * \
-np.sin(pol) * _alegendre_deriv(degree, order, np.cos(pol)) * \
np.exp(1j * order * az)
g_pol = (1 / rad ** (degree + 2) *
np.sqrt((2 * degree + 1) * factorial(degree - order) /
(4 * np.pi * factorial(degree + order))) *
-np.sin(pol) * _alegendre_deriv(degree, order, np.cos(pol)) *
np.exp(1j * order * az))

# Get real component of vectors, convert to cartesian coords, and return
real_grads = _get_real_grad(np.c_[g_rad, g_az, g_pol], order)
Expand Down Expand Up @@ -287,18 +335,16 @@ def _grad_out_components(degree, order, rad, az, pol):
coordinates
"""
# Compute gradients for all spherical coordinates (Eq. 7)
from scipy.special import sph_harm
g_rad = degree * rad ** (degree - 1) * sph_harm(order, degree, az, pol)
g_rad = degree * rad ** (degree - 1) * _sph_harm(order, degree, az, pol)

g_az = rad ** (degree - 1) / np.sin(pol) * 1j * order * \
sph_harm(order, degree, az, pol)
g_az = (rad ** (degree - 1) / np.sin(pol) * 1j * order *
_sph_harm(order, degree, az, pol))

g_pol = rad ** (degree - 1) * np.sqrt((2 * degree + 1) *
factorial(degree - order) /
(4 * np.pi *
factorial(degree + order))) * \
-np.sin(pol) * _alegendre_deriv(degree, order, np.cos(pol)) * \
np.exp(1j * order * az)
g_pol = (rad ** (degree - 1) *
np.sqrt((2 * degree + 1) * factorial(degree - order) /
(4 * np.pi * factorial(degree + order))) *
-np.sin(pol) * _alegendre_deriv(degree, order, np.cos(pol)) *
np.exp(1j * order * az))

# Get real component of vectors, convert to cartesian coords, and return
real_grads = _get_real_grad(np.c_[g_rad, g_az, g_pol], order)
Expand Down Expand Up @@ -432,7 +478,7 @@ def _make_coils(info, accurate=True, elekta_defs=False):
else:
accuracy = FIFF.FWD_COIL_ACCURACY_NORMAL
meg_info = None
megcoils = []
megcoils = list()

# MEG channels
picks = pick_types(info, meg=True, eeg=False, ref_meg=False,
Expand All @@ -448,7 +494,7 @@ def _make_coils(info, accurate=True, elekta_defs=False):

# Create coil descriptions with transformation to head or MRI frame
if elekta_defs:
elekta_coil_defs = op.join(op.split(__file__)[0], '..', 'data',
elekta_coil_defs = op.join(op.dirname(__file__), '..', 'data',
'coil_def_Elekta.dat')
templates = _read_coil_defs(elekta_coil_defs)

Expand Down
53 changes: 37 additions & 16 deletions mne/preprocessing/tests/test_maxwell.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,17 @@
from mne.preprocessing import maxwell
from mne.datasets import testing
from mne.io import Raw, proc_history
from mne.utils import slow_test, _TempDir
from mne.utils import _TempDir, run_tests_if_main

warnings.simplefilter('always') # Always throw warnings

data_path = op.join(testing.data_path(download=False))
raw_fname = op.join(data_path, 'SSS', 'test_move_anon_raw.fif')
sss_std_fname = op.join(data_path, 'SSS',
'test_move_anon_raw_simp_stdOrigin_sss.fif')
sss_nonstd_fname = op.join(data_path, 'SSS',
'test_move_anon_raw_simp_nonStdOrigin_sss.fif')


@testing.requires_testing_data
def test_maxwell_filter():
Expand All @@ -23,24 +31,14 @@ def test_maxwell_filter():
# TODO: Future tests integrate with mne/io/tests/test_proc_history

# Load testing data (raw, SSS std origin, SSS non-standard origin)
data_path = op.join(testing.data_path(download=False))

file_name = 'test_move_anon'

raw_fname = op.join(data_path, 'SSS', file_name + '_raw.fif')
sss_std_fname = op.join(data_path, 'SSS', file_name +
'_raw_simp_stdOrigin_sss.fif')
sss_nonStd_fname = op.join(data_path, 'SSS', file_name +
'_raw_simp_nonStdOrigin_sss.fif')

with warnings.catch_warnings(record=True): # maxshield
raw = Raw(raw_fname, preload=False, proj=False,
allow_maxshield=True).crop(0., 1., False)
raw.preload_data()
with warnings.catch_warnings(record=True): # maxshield, naming
sss_std = Raw(sss_std_fname, preload=True, proj=False,
allow_maxshield=True)
sss_nonStd = Raw(sss_nonStd_fname, preload=True, proj=False,
sss_nonStd = Raw(sss_nonstd_fname, preload=True, proj=False,
allow_maxshield=True)
raw_err = Raw(raw_fname, preload=False, proj=True,
allow_maxshield=True).crop(0., 0.1, False)
Expand Down Expand Up @@ -101,6 +99,29 @@ def test_maxwell_filter():
error_rms = np.sqrt(np.mean(error ** 2, axis=1))
assert_true(np.mean(bench_rms / error_rms) > 1000, 'SNR < 1000')

# Check against SSS functions from proc_history
sss_info = raw_sss.info['proc_history'][0]['max_info']
assert_equal(maxwell.get_num_moments(int_order, 0),
proc_history._get_sss_rank(sss_info))


@testing.requires_testing_data
def test_maxwell_filter_additional():
"""Test processing of Maxwell filtered data"""

# TODO: Future tests integrate with mne/io/tests/test_proc_history

# Load testing data (raw, SSS std origin, SSS non-standard origin)
data_path = op.join(testing.data_path(download=False))

file_name = 'test_move_anon'

raw_fname = op.join(data_path, 'SSS', file_name + '_raw.fif')
with warnings.catch_warnings(record=True): # maxshield
raw = Raw(raw_fname, preload=False, proj=False,
allow_maxshield=True).crop(0., 1., False)
raw_sss = maxwell.maxwell_filter(raw)

# Test io on processed data
tempdir = _TempDir()
test_outname = op.join(tempdir, 'test_raw_sss.fif')
Expand All @@ -111,7 +132,7 @@ def test_maxwell_filter():
assert_allclose(raw_sss_loaded._data[:, :], raw_sss._data[:, :],
rtol=1e-6, atol=1e-20)

# Check against SSS functions from proc_history
sss_info = raw_sss.info['proc_history'][0]['max_info']
assert_equal(maxwell.get_num_moments(int_order, 0),
proc_history._get_sss_rank(sss_info))
# Test covariance calculation XXX add this


run_tests_if_main()

0 comments on commit 15a4f17

Please sign in to comment.