Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfixes and diagnostic plotting for model results with just a single syllable #24

Merged
merged 3 commits into from
Apr 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions keypoint_moseq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
plot_pcs,
plot_scree,
plot_progress,
plot_syllable_frequencies,
plot_duration_distribution,
generate_crowd_movies,
generate_grid_movies,
generate_trajectory_plots,
Expand Down
202 changes: 187 additions & 15 deletions keypoint_moseq/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,12 +202,152 @@ def plot_pcs(pca, *, use_bodyparts, skeleton, keypoint_colormap='autumn',
plt.savefig(os.path.join(project_dir,f'pcs-{name}.pdf'))
plt.show()


def plot_syllable_frequencies(results=None, path=None, project_dir=None,
                              name=None, use_reindexed=True, minlength=10,
                              min_frequency=0.005):
    """
    Plot a histogram showing the frequency of each syllable.

    Caller must provide a results dictionary, a path to a results .h5,
    or a project directory and model name, in which case the results are
    loaded from ``{project_dir}/{name}/results.h5``.

    Parameters
    ----------
    results : dict, default=None
        Dictionary containing modeling results for a dataset (see
        :py:func:`keypoint_moseq.fitting.apply_model`)

    name: str, default=None
        Name of the model. Required to load results if ``results`` is
        None and ``path`` is None.

    project_dir: str, default=None
        Project directory. Required to load results if ``results`` is
        None and ``path`` is None.

    path: str, default=None
        Path to a results file. If None, results will be loaded from
        ``{project_dir}/{name}/results.h5``.

    use_reindexed: bool, default=True
        Whether to label syllables by their frequency rank (True) or
        by their original label (False). When reindexing, "0" represents
        the most frequent syllable. Internally this selects the
        ``"syllables_reindexed"`` or the ``"syllables"`` entry of each
        result.

    minlength: int, default=10
        Minimum x-axis length of the histogram.

    min_frequency: float, default=0.005
        Syllables whose frequency does not exceed this value are left
        out of the histogram.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure containing the histogram.

    ax : matplotlib.axes.Axes
        Axes containing the histogram.
    """
    if results is None:
        results = load_results(path=path, name=name, project_dir=project_dir)

    syllable_key = 'syllables_reindexed' if use_reindexed else 'syllables'
    syllables = {k: res[syllable_key] for k, res in results.items()}
    frequencies = np.asarray(get_frequencies(syllables))

    # Keep every bar at its own syllable index instead of compacting the
    # array with a boolean mask: compaction silently shifts bars to the left
    # whenever a filtered-out syllable has a lower index than a kept one
    # (which can happen with the original, non-reindexed labels).
    # NOTE(review): assumes `get_frequencies` returns a 1D array indexed by
    # syllable label -- confirm against its definition.
    kept = np.flatnonzero(frequencies > min_frequency)
    xmax = max(minlength, kept[-1] + 1) if kept.size else minlength

    fig, ax = plt.subplots()
    ax.bar(kept, frequencies[kept], width=1)
    ax.set_ylabel('probability')
    # The x-axis only shows ranks when the reindexed labels are plotted.
    ax.set_xlabel('syllable rank' if use_reindexed else 'syllable label')
    ax.set_xlim(-1, xmax + 1)
    ax.set_title('Frequency distribution')
    ax.set_yticks([])
    return fig, ax


def plot_duration_distribution(results=None, path=None, project_dir=None,
                               name=None, use_reindexed=True, lim=None,
                               num_bins=30, fps=None):
    """
    Plot a histogram showing the distribution of syllable durations.

    Caller must provide a results dictionary, a path to a results .h5,
    or a project directory and model name, in which case the results are
    loaded from ``{project_dir}/{name}/results.h5``.

    Parameters
    ----------
    results : dict, default=None
        Dictionary containing modeling results for a dataset (see
        :py:func:`keypoint_moseq.fitting.apply_model`)

    name: str, default=None
        Name of the model. Required to load results if ``results`` is
        None and ``path`` is None.

    project_dir: str, default=None
        Project directory. Required to load results if ``results`` is
        None and ``path`` is None.

    path: str, default=None
        Path to a results file. If None, results will be loaded from
        ``{project_dir}/{name}/results.h5``.

    use_reindexed: bool, default=True
        Selects which label array of each result is read:
        ``"syllables_reindexed"`` (True) or ``"syllables"`` (False).

    lim: number or tuple, default=None
        x-axis limits in units of frames, either as a single upper limit
        (the lower limit is then 0) or as a ``(lower, upper)`` pair.
        If None, the limits are set to (0, 95th-percentile).

    num_bins: int, default=30
        Target number of bins in the histogram. The bin width is rounded
        down to a whole number of frames (at least one frame), so the
        actual number of bins can differ.

    fps: int, default=None
        Frames per second. Used to convert x-axis from frames to seconds.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure containing the histogram.

    ax : matplotlib.axes.Axes
        Axes containing the histogram.
    """
    if results is None:
        results = load_results(path=path, name=name, project_dir=project_dir)

    syllable_key = 'syllables_reindexed' if use_reindexed else 'syllables'
    syllables = {k: res[syllable_key] for k, res in results.items()}
    durations = get_durations(syllables)

    # Accept both the historical scalar upper limit and the documented
    # (lower, upper) pair; the pair previously failed with a TypeError.
    if lim is None:
        lower, upper = 0, int(np.percentile(durations, 95))
    elif np.ndim(lim) == 0:
        lower, upper = 0, lim
    else:
        lower, upper = lim

    # Derive the bin count while everything is still expressed in frames.
    # Computing it after dividing by `fps` relies on a float quotient that
    # can round just below the intended integer and drop a bin.
    span = upper - lower
    binsize = max(int(np.floor(span / num_bins)), 1)
    n_bins = max(int(span / binsize), 1)  # a histogram needs at least one bin

    if fps is not None:
        durations = durations / fps
        lower, upper = lower / fps, upper / fps
        xlabel = 'syllable duration (s)'
    else:
        xlabel = 'syllable duration (frames)'

    fig, ax = plt.subplots()
    ax.hist(durations, range=(lower, upper), bins=n_bins, density=True)
    ax.set_xlim([lower, upper])
    ax.set_xlabel(xlabel)
    ax.set_ylabel('probability')
    ax.set_title('Duration distribution')
    ax.set_yticks([])
    return fig, ax


def plot_progress(model, data, history, iteration, path=None,
project_dir=None, name=None, savefig=True,
fig_size=None, seq_length=600, min_frequency=.001,
**kwargs):
min_histogram_length=10, **kwargs):
"""
Plot the progress of the model during fitting.

Expand Down Expand Up @@ -253,9 +393,20 @@ def plot_progress(model, data, history, iteration, path=None,
Minimum frequency for including a state in the frequency
distribution plot.

min_histogram_length : int, default=10
Minimum x-axis length of the frequency distribution plot.

project_dir : str, default=None
name : str, default=None
path : str, default=None

Returns
-------
fig : matplotlib.figure.Figure
Figure containing the plots.

axs : list of matplotlib.axes.Axes
Axes containing the plots.
"""
z = np.array(model['states']['z'])
mask = np.array(data['mask'])
Expand All @@ -275,15 +426,16 @@ def plot_progress(model, data, history, iteration, path=None,
if fig_size is None: fig_size=(4,2.5)

frequencies = np.sort(frequencies[frequencies>min_frequency])[::-1]
xmax = max(len(frequencies),min_histogram_length)
axs[0].bar(range(len(frequencies)),frequencies,width=1)
axs[0].set_ylabel('probability')
axs[0].set_xlabel('syllable rank')
axs[0].set_xlim([-1,xmax+1])
axs[0].set_title('Frequency distribution')
axs[0].set_yticks([])

lim = int(np.percentile(durations, 95))
binsize = max(int(np.floor(lim/30)),1)
lim = lim-(lim%binsize)
axs[1].hist(durations, range=(1,lim), bins=(int(lim/binsize)), density=True)
axs[1].set_xlim([1,lim])
axs[1].set_xlabel('syllable duration (frames)')
Expand Down Expand Up @@ -328,6 +480,8 @@ def plot_progress(model, data, history, iteration, path=None,
plt.savefig(path)
plt.show()

return fig,axs


def write_video_clip(frames, path, fps=30, quality=7):
"""
Expand Down Expand Up @@ -562,10 +716,9 @@ def generate_grid_movies(
grid movie for that syllable.

use_reindexed: bool, default=True
Determines the naming of syllables (``results["syllables"]`` if
False, or ``results["syllables_reindexed"]`` if True). The
reindexed naming corresponds to the rank order of syllable
frequency (e.g. "0" for the most frequent syllable).
Whether to label syllables by their frequency rank (True) or
by their original label (False). When reindexing, "0" represents
the most frequent syllable.

sampling_options: dict, default={}
Dictionary of options for sampling syllable instances (see
Expand Down Expand Up @@ -651,6 +804,13 @@ def generate_grid_movies(
syllable_instances = get_syllable_instances(
syllables, pre=pre, post=post, min_duration=min_duration,
min_frequency=min_frequency, min_instances=rows*cols)

if len(syllable_instances) == 0:
warnings.warn(fill(
'No syllables with sufficient instances to make a grid movie. '
'This usually occurs when all frames have the same syllable label '
'(use `plot_syllable_frequencies` to check if this is the case)'))
return

sampled_instances = sample_instances(
syllable_instances, rows*cols, coordinates=coordinates,
Expand Down Expand Up @@ -977,10 +1137,9 @@ def generate_trajectory_plots(
trajectory average.

use_reindexed: bool, default=True
Determines the naming of syllables (``results["syllables"]`` if
False, or ``results["syllables_reindexed"]`` if True). The
reindexed naming corresponds to the rank order of syllable
frequency (e.g. "0" for the most frequent syllable).
Whether to label syllables by their frequency rank (True) or
by their original label (False). When reindexing, "0" represents
the most frequent syllable.

bodyparts: list of str, default=None
List of bodypart names in ``coordinates``.
Expand Down Expand Up @@ -1060,6 +1219,13 @@ def generate_trajectory_plots(
syllables, pre=pre, post=post, min_duration=min_duration,
min_frequency=min_frequency, min_instances=num_samples)

if len(syllable_instances) == 0:
warnings.warn(fill(
'No syllables with sufficient instances to make a trajectory plot. '
'This usually occurs when all frames have the same syllable label '
'(use `plot_syllable_frequencies` to check if this is the case)'))
return

sampling_options['n_neighbors'] = num_samples
sampled_instances = sample_instances(
syllable_instances, num_samples, coordinates=coordinates,
Expand Down Expand Up @@ -1473,11 +1639,10 @@ def generate_crowd_movies(
crowd movie for that syllable.

use_reindexed: bool, default=True
Determines the naming of syllables (``results["syllables"]`` if
False, or ``results["syllables_reindexed"]`` if True). The
reindexed naming corresponds to the rank order of syllable
frequency (e.g. "0" for the most frequent syllable).

Whether to label syllables by their frequency rank (True) or
by their original label (False). When reindexing, "0" represents
the most frequent syllable.

bodyparts: list of str, default=None
List of bodypart names in ``coordinates``.

Expand Down Expand Up @@ -1553,6 +1718,13 @@ def generate_crowd_movies(
syllable_instances = get_syllable_instances(
syllables, pre=pre, post=post, min_duration=min_duration,
min_frequency=min_frequency, min_instances=num_instances)

if len(syllable_instances) == 0:
warnings.warn(fill(
'No syllables with sufficient instances to make a crowd movie. '
'This usually occurs when all frames have the same syllable label '
'(use `plot_syllable_frequencies` to check if this is the case)'))
return

sampled_instances = sample_instances(
syllable_instances, num_instances, coordinates=coordinates,
Expand Down