Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Plotting wrappers: Head Trajectory #394

Merged
merged 37 commits into from
Feb 12, 2025
Merged
Changes from 1 commit
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
444e3bf
add plot module with trajectory function
stellaprins Jan 24, 2025
e8ba011
Merge branch 'main' into sp/282-plot-wrappers
stellaprins Jan 27, 2025
e0369b6
update example
stellaprins Jan 28, 2025
dfe6be0
add plot kwargs and defaults
stellaprins Jan 29, 2025
4b67a7c
Merge branch 'main' into sp/282-plot-wrappers
stellaprins Jan 29, 2025
a865a3f
add tests for trajectory plotting
stellaprins Jan 30, 2025
7c3922f
parametrize test_trajectory
stellaprins Jan 31, 2025
4ceecc9
add docstring to trajectory, remove test code
stellaprins Jan 31, 2025
15dccd0
improve code coverage, make colorbar alpha resistant
stellaprins Jan 31, 2025
c290579
Add Niko's plot trajectory suggestion, fix tests
stellaprins Feb 4, 2025
03b7a65
update example
stellaprins Feb 4, 2025
cc506f9
Merge branch 'main' into sp/282-plot-wrappers
stellaprins Feb 4, 2025
02e1146
adjust example, make lines causing '\d' SyntaxWarning raw
stellaprins Feb 4, 2025
0087615
Merge branch 'sp/282-plot-wrappers' of https://github.com/neuroinform…
stellaprins Feb 4, 2025
8f282aa
change trajectory inputs, replace keypoint and individual with select…
stellaprins Feb 5, 2025
2a4a399
change trajectory inputs, replace keypoint and individual with select…
stellaprins Feb 5, 2025
e5a006a
Merge branch 'sp/282-plot-wrappers' of https://github.com/neuroinform…
stellaprins Feb 5, 2025
e906722
use selection dict for individuals and keypoints
stellaprins Feb 5, 2025
29ef1dd
update trajectory examples, allow user to set marker colour
stellaprins Feb 5, 2025
69a0a6e
fix examples, fix colorbar
stellaprins Feb 6, 2025
7701aa3
improve code coverage
stellaprins Feb 6, 2025
deca284
remove image_path from plot input, adjust example
stellaprins Feb 6, 2025
15d13ac
fix test after removing image_path input
stellaprins Feb 7, 2025
5f40c30
add more tests
stellaprins Feb 7, 2025
bf03285
test trajectory without individuals and/or keypoints dimension
stellaprins Feb 7, 2025
7501c69
fix logic test_trajectory_dropped_dim
stellaprins Feb 7, 2025
c63213f
add tests, update examples, change selection to individuals and keypo…
stellaprins Feb 7, 2025
5f9e49a
change trajectory input in examples as well
stellaprins Feb 7, 2025
2130e6e
fix input load and explore poses example
stellaprins Feb 8, 2025
d641a3f
change folder structure
stellaprins Feb 10, 2025
10b8b34
add init file
stellaprins Feb 10, 2025
be4ae20
fix test
stellaprins Feb 10, 2025
e22108f
deal with drop deprecation (-> drop_vars)
stellaprins Feb 10, 2025
16878f3
process review suggestions, adjust drop dimension test
stellaprins Feb 11, 2025
2d4998d
reorder plot_trajectory inputs
stellaprins Feb 12, 2025
d10f1f4
fix typo example
stellaprins Feb 12, 2025
cd2bc99
Merge branch 'main' into sp/282-plot-wrappers
stellaprins Feb 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add tests for trajectory plotting
stellaprins committed Jan 30, 2025

Verified

This commit was signed with the committer’s verified signature.
sagikazarmark Márk Sági-Kazár
commit a865a3fdb73f44338ed03a5ab560394841dd8d3f
6 changes: 5 additions & 1 deletion movement/plot.py
Original file line number Diff line number Diff line change
@@ -66,7 +66,11 @@ def trajectory(
ax.set_title(title)
else:
ax.set_title(f"{individual} trajectory of {plotting_point_name}")
fig.colorbar(sc, ax=ax, label=f"time ({ds.attrs['time_unit']})")

if ds.attrs.get("time_unit") is not None:
fig.colorbar(sc, ax=ax, label=f"time ({ds.attrs['time_unit']})")
else:
fig.colorbar(sc, ax=ax, label="time steps (frames)")

if frame_path is not None:
frame = plt.imread(frame_path)
112 changes: 112 additions & 0 deletions tests/test_unit/test_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
import numpy as np
import pytest
import xarray as xr
from matplotlib import pyplot as plt

from movement.plot import trajectory


@pytest.fixture
def sample_data():
"""Sample data for plot testing.

Data has three keypoints (left, centre, right) for one
individual that moves in a straight line along the y-axis with a
constant x-coordinate.

"""
time_steps = 4
individuals = ["individual_0"]
keypoints = ["left", "centre", "right"]
space = ["x", "y"]
positions = {
"left": {"x": -1, "y": np.arange(time_steps)},
"centre": {"x": 0, "y": np.arange(time_steps)},
"right": {"x": 1, "y": np.arange(time_steps)},
}

time = np.arange(time_steps)
position_data = np.zeros(
(time_steps, len(space), len(keypoints), len(individuals))
)

# Create x and y coordinates arrays
x_coords = np.array([positions[key]["x"] for key in keypoints])
y_coords = np.array([positions[key]["y"] for key in keypoints])

for i, _ in enumerate(keypoints):
position_data[:, 0, i, 0] = x_coords[i] # x-coordinates
position_data[:, 1, i, 0] = y_coords[i] # y-coordinates

confidence_data = np.full(
(time_steps, len(keypoints), len(individuals)), 0.90
)

ds = xr.Dataset(
{
"position": (
["time", "space", "keypoints", "individuals"],
position_data,
),
"confidence": (
["time", "keypoints", "individuals"],
confidence_data,
),
},
coords={
"time": time,
"space": space,
"keypoints": keypoints,
"individuals": individuals,
},
)
return ds


def test_trajectory(sample_data):
"""Test midpoint between left and right keypoints."""
plt.switch_backend("Agg") # to avoid pop-up window
fig_centre = trajectory(sample_data, keypoint="centre")
fig_left_right_midpoint = trajectory(
sample_data, keypoint=["left", "right"]
)

expected_data = np.array([[0, 0], [0, 1], [0, 2], [0, 3]])

# Retrieve data points from figures
ax_centre = fig_centre.axes[0]
centre_data = ax_centre.collections[0].get_offsets().data

ax_left_right = fig_left_right_midpoint.axes[0]
left_right_data = ax_left_right.collections[0].get_offsets().data

np.testing.assert_array_almost_equal(centre_data, left_right_data)
np.testing.assert_array_almost_equal(centre_data, expected_data)
np.testing.assert_array_almost_equal(left_right_data, expected_data)


def test_trajectory_with_frame(sample_data, tmp_path):
"""Test plot trajectory with frame."""
frame_path = tmp_path / "frame.png"
fig, ax = plt.subplots()
ax.imshow(np.zeros((10, 10)))
fig.savefig(frame_path)

fig_centre = trajectory(
sample_data, keypoint="centre", frame_path=frame_path
)
fig_left_right_midpoint = trajectory(
sample_data, keypoint=["left", "right"], frame_path=frame_path
)

# Retrieve data points from figures
ax_centre = fig_centre.axes[0]
centre_data = ax_centre.collections[0].get_offsets().data

ax_left_right = fig_left_right_midpoint.axes[0]
left_right_data = ax_left_right.collections[0].get_offsets().data

expected_data = np.array([[0, 0], [0, 1], [0, 2], [0, 3]])
np.testing.assert_array_almost_equal(centre_data, left_right_data)
np.testing.assert_array_almost_equal(centre_data, expected_data)
np.testing.assert_array_almost_equal(left_right_data, expected_data)