Skip to content

Commit

Permalink
BUG: Fix bug with get_view
Browse files Browse the repository at this point in the history
  • Loading branch information
larsoner committed Sep 18, 2023
1 parent 966bdcc commit 1275b7f
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 64 deletions.
2 changes: 1 addition & 1 deletion mne/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,7 +793,7 @@ def _cart_to_sph(cart):
Array containing points in spherical coordinates (rad, azimuth, polar)
"""
cart = np.atleast_2d(cart)
assert cart.ndim == 2 and cart.shape[1] == 3
assert cart.ndim == 2 and cart.shape[1] == 3, cart.shape
out = np.empty((len(cart), 3))
out[:, 0] = np.sqrt(np.sum(cart * cart, axis=1))
norm = np.where(out[:, 0] > 0, out[:, 0], 1) # protect against / 0
Expand Down
12 changes: 7 additions & 5 deletions mne/viz/_brain/_brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ def _setup_canonical_rotation(self):
xfm = _estimate_talxfm_rigid(self._subject, self._subjects_dir)
except Exception:
logger.info(
"Could not estimate rigid Talairach alignment, " "using identity matrix"
"Could not estimate rigid Talairach alignment, using identity matrix"
)
else:
self._rigid[:] = xfm
Expand Down Expand Up @@ -3223,7 +3223,7 @@ def show(self):
_qt_app_exec(self._renderer.figure.store["app"])

@fill_doc
def get_view(self, row=0, col=0):
def get_view(self, row=0, col=0, *, align=True):
"""Get the camera orientation for a given subplot display.
Parameters
Expand All @@ -3232,6 +3232,7 @@ def get_view(self, row=0, col=0):
The row to use, default is the first one.
col : int
The column to check, the default is the first one.
%(align_view)s
Returns
-------
Expand All @@ -3243,10 +3244,11 @@ def get_view(self, row=0, col=0):
"""
row = _ensure_int(row, "row")
col = _ensure_int(col, "col")
rigid = self._rigid if align else None
for h in self._hemis:
for ri, ci, _ in self._iter_views(h):
if (row == ri) and (col == ci):
return self._renderer.get_camera()
return self._renderer.get_camera(rigid=rigid)
return (None,) * 5

@verbose
Expand Down Expand Up @@ -3352,14 +3354,14 @@ def show_view(
param: val for param, val in view_params.items() if val is not None
} # no overwriting with None
view_params = dict(views_dicts[hemi].get(view), **view_params)
xfm = self._rigid if align else None
rigid = self._rigid if align else None
for h in self._hemis:
for ri, ci, _ in self._iter_views(h):
if (row is None or row == ri) and (col is None or col == ci):
self._renderer.set_camera(
**view_params,
reset_camera=False,
rigid=xfm,
rigid=rigid,
update=False,
)
if update:
Expand Down
72 changes: 55 additions & 17 deletions mne/viz/_brain/tests/test_brain.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
)
from mne.channels import make_dig_montage
from mne.minimum_norm import apply_inverse, make_inverse_operator
from mne.source_estimate import _BaseSourceEstimate
from mne.source_space import read_source_spaces, setup_volume_source_space
from mne.datasets import testing
from mne.io import read_info
Expand Down Expand Up @@ -193,7 +194,6 @@ def test_brain_routines(renderer, brain_gc):
@testing.requires_testing_data
def test_brain_init(renderer_pyvistaqt, tmp_path, pixel_ratio, brain_gc):
"""Test initialization of the Brain instance."""
from mne.source_estimate import _BaseSourceEstimate

class FakeSTC(_BaseSourceEstimate):
def __init__(self):
Expand Down Expand Up @@ -546,28 +546,67 @@ def __init__(self):
)
for a, b, p, color in zip(annots, borders, alphas, colors):
brain.add_annotation(str(a), b, p, color=color)
brain.close()


view_args = dict(roll=1, distance=500, focalpoint=(1e-5, 1e-5, 1e-5))
def _assert_view_allclose(brain, roll, distance, azimuth, elevation, focalpoint, align):
__tracebackhide__ = True
r_, d_, a_, e_, f_ = brain.get_view(align=align)
azimuth, a_ = azimuth % 360, a_ % 360
assert_allclose(r_, roll, err_msg="Roll")
assert_allclose(d_, distance, rtol=0.01, err_msg="Distance")
assert_allclose(a_, azimuth, rtol=0.01, atol=1e-6, err_msg="Azimuth")
assert_allclose(e_, elevation, rtol=1e-5, atol=1e-6, err_msg="Elevation")
assert_allclose(f_, focalpoint, err_msg="Focal point")
cam = brain._renderer.figure.plotter.camera
previous_roll = cam.GetRoll()
assert_allclose(cam.GetFocalPoint(), focalpoint, err_msg="Camera focal point")
assert_allclose(cam.GetDistance(), distance, rtol=1e-4, err_msg="Camera distance")
assert_allclose(cam.GetRoll(), roll, atol=1e-5, err_msg="Camera roll")


@pytest.mark.parametrize("align", (True, False))
def test_view_round_trip(renderer_interactive_pyvistaqt, tmp_path, brain_gc, align):
"""Test get_view / set_view round-trip."""
brain = _create_testing_brain(hemi="lh")
img = brain.screenshot()
roll, distance, azimuth, elevation, focalpoint = brain.get_view(align=align)
brain.show_view(
azimuth=azimuth,
elevation=elevation,
focalpoint=focalpoint,
roll=roll,
distance=distance,
align=align,
)
img_1 = brain.screenshot()
assert_allclose(img, img_1)
_assert_view_allclose(brain, roll, distance, azimuth, elevation, focalpoint, align)

# Now with custom values
roll, distance, focalpoint = 1, 500, (1e-5, 1e-5, 1e-5)
view_args = dict(roll=roll, distance=distance, focalpoint=focalpoint, align=align)
brain.show_view(**view_args)
assert_allclose(cam.GetFocalPoint(), view_args["focalpoint"])
assert_allclose(cam.GetDistance(), view_args["distance"])
assert_allclose(cam.GetRoll(), previous_roll + view_args["roll"])
_assert_view_allclose(brain, roll, distance, azimuth, elevation, focalpoint, align)

# test get_view
azimuth, elevation = 180.0, 90.0
view_args.update(azimuth=azimuth, elevation=elevation)
brain.show_view(**view_args)
roll, distance, azimuth, elevation, focalpoint = brain.get_view()
assert_allclose(cam.GetRoll(), roll)
assert_allclose(cam.GetDistance(), distance)
assert_allclose(view_args["azimuth"] % 360, azimuth % 360)
assert_allclose(view_args["elevation"] % 180, elevation % 180)
assert_allclose(view_args["focalpoint"], focalpoint)
del view_args

# image and screenshot
_assert_view_allclose(brain, roll, distance, azimuth, elevation, focalpoint, align)
brain.close()


def test_image_screenshot(
renderer_interactive_pyvistaqt,
tmp_path,
pixel_ratio,
brain_gc,
):
"""Test screenshot and image saving."""
size = (300, 300)
brain = _create_testing_brain(hemi="lh", size=size)
cam = brain._renderer.figure.plotter.camera
azimuth, elevation = 180.0, 90.0
fname = tmp_path / "test.png"
assert not fname.is_file()
brain.save_image(fname)
Expand All @@ -581,8 +620,7 @@ def __init__(self):
brain.show_view(**view_args)
assert_allclose(brain._renderer.figure._azimuth % 360, azimuth % 360)
assert_allclose(brain._renderer.figure._elevation % 180, elevation % 180)
assert_allclose(cam.GetFocalPoint(), fp)
del view_args
assert_allclose(cam.GetFocalPoint(), fp, atol=1e-6)
img = brain.screenshot(mode="rgba")
want_size = np.array([size[0] * pixel_ratio, size[1] * pixel_ratio, 4])
# on macOS sometimes matplotlib is HiDPI and VTK is not...
Expand Down
70 changes: 29 additions & 41 deletions mne/viz/backends/_pyvista.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
_init_mne_qtapp,
)
from ...fixes import _compare_version
from ...transforms import apply_trans
from ...transforms import apply_trans, _cart_to_sph, _sph_to_cart
from ...utils import (
copy_base_doc_to_subclass_doc,
_check_option,
Expand Down Expand Up @@ -820,8 +820,8 @@ def show(self):
def close(self):
_close_3d_figure(figure=self.figure)

def get_camera(self):
return _get_3d_view(self.figure)
def get_camera(self, *, rigid=None):
return _get_3d_view(self.figure, rigid=rigid)

def set_camera(
self,
Expand Down Expand Up @@ -1110,14 +1110,6 @@ def _hide_testing_actor(actor):
actor.SetVisibility(False)


def _deg2rad(deg):
return deg * np.pi / 180.0


def _rad2deg(rad):
return rad * 180.0 / np.pi


def _to_pos(azimuth, elevation):
theta = azimuth * np.pi / 180.0
phi = (90.0 - elevation) * np.pi / 180.0
Expand Down Expand Up @@ -1154,18 +1146,17 @@ def _close_all():


def _get_camera_direction(focalpoint, position):
x, y, z = position - focalpoint
r = np.sqrt(x * x + y * y + z * z)
theta = np.arccos(z / r)
phi = np.arctan2(y, x)
return r, theta, phi


def _get_3d_view(figure):
position = np.array(figure.plotter.camera_position[0])
focalpoint = np.array(figure.plotter.camera_position[1])
_, theta, phi = _get_camera_direction(focalpoint, position)
azimuth, elevation = _rad2deg(phi), _rad2deg(theta)
return tuple(_cart_to_sph(position - focalpoint)[0])


def _get_3d_view(figure, *, rigid=None):
position = np.array(figure.plotter.camera.position, float)
focalpoint = np.array(figure.plotter.camera.focal_point, float)
rigid = np.eye(4) if rigid is None else np.linalg.inv(rigid)
position = apply_trans(rigid, position)
focalpoint = apply_trans(rigid, focalpoint)
_, phi, theta = _get_camera_direction(focalpoint, position)
azimuth, elevation = np.rad2deg(phi), np.rad2deg(theta)
return (
figure.plotter.camera.GetRoll(),
figure.plotter.camera.GetDistance(),
Expand All @@ -1186,8 +1177,9 @@ def _set_3d_view(
rigid=None,
update=True,
):
camera = figure.plotter.camera
rigid = np.eye(4) if rigid is None else rigid
position = np.array(figure.plotter.camera_position[0])
position = np.array(camera.position)
bounds = np.array(figure.plotter.renderer.ComputeVisiblePropBounds())
if reset_camera:
figure.plotter.reset_camera(render=False)
Expand All @@ -1199,19 +1191,18 @@ def _set_3d_view(
_check_option("focalpoint", focalpoint, ("auto",), extra="when a string")
focalpoint = (bounds[1::2] + bounds[::2]) * 0.5
elif focalpoint is None:
focalpoint = np.array(figure.plotter.camera_position[1])
else:
focalpoint = np.asarray(focalpoint)
focalpoint = camera.focal_point
focalpoint = np.array(focalpoint, float)

# work in the transformed space
position = apply_trans(rigid, position)
focalpoint = apply_trans(rigid, focalpoint)
_, theta, phi = _get_camera_direction(focalpoint, position)
_, phi, theta = _get_camera_direction(focalpoint, position)

if azimuth is not None:
phi = _deg2rad(azimuth)
phi = np.deg2rad(azimuth)
if elevation is not None:
theta = _deg2rad(elevation)
theta = np.deg2rad(elevation)

# set the distance
if distance is None:
Expand All @@ -1225,25 +1216,22 @@ def _set_3d_view(
else:
view_up = [0, 1, 0]

position = [
distance * np.cos(phi) * np.sin(theta),
distance * np.sin(phi) * np.sin(theta),
distance * np.cos(theta),
]
position = _sph_to_cart([distance, phi, theta])[0]

figure._azimuth = _rad2deg(phi)
figure._elevation = _rad2deg(theta)
# TODO: We should remove this hack and compute from the camera
figure._azimuth = np.rad2deg(phi)
figure._elevation = np.rad2deg(theta)

# restore to the original frame
rigid = np.linalg.inv(rigid)
position = apply_trans(rigid, position)
focalpoint = apply_trans(rigid, focalpoint)
view_up = apply_trans(rigid, view_up, move=False)
figure.plotter.camera_position = [position, focalpoint, view_up]
# We need to add the requested roll to the roll dictated by the
# transformed view_up
camera.position = position
camera.focal_point = focalpoint
camera.view_up = view_up
if roll is not None:
figure.plotter.camera.SetRoll(figure.plotter.camera.GetRoll() + roll)
camera.SetRoll(roll)

if update:
figure.plotter.update()
Expand Down

0 comments on commit 1275b7f

Please sign in to comment.