
Restructure results/checkpoint + New features: (analysis tools, large dataset support, 3D viz and more) #79

Merged Aug 8, 2023 (179 commits)

Conversation

calebweinreb
Contributor

Summary

This PR introduces the following changes/features, which are explained in more detail below.

  • New logic for syllable indexing
  • Altered format for results and checkpoint files
  • New notebook and widgets for statistical analysis of syllables (thanks to @versey-sherry!)
  • Support for modeling large datasets (multiple GPUs + partial serialization)
  • Interactive visualizations for 3D data
  • Minor feature additions (NWB support, syllable similarity plot)

New logic for syllable indexing

Until now, the `extract_results` step of keypoint-MoSeq saved syllable sequences in their original indexing (as they were represented during modeling) along with a "reindexed" version in which syllables were re-labeled by frequency (so syllable "0" was the most frequent, and so on). But this approach had a fatal flaw: when a fitted model was applied to new data, the syllable frequencies could be different, which would lead to a slightly different re-labeling, so that e.g. syllable "0" would refer to one state in a subset of recordings and a different state in another subset.

To prevent this issue, we now reindex syllables directly inside the model object. That way, if the model is used later to generate syllables for new data, the resulting labels will always be consistent. See #72 for details. Concretely, this means that:

  1. The standard modeling pipeline now includes a new step after model fitting but before extracting results:
kpms.reindex_syllables_in_checkpoint(project_dir, model_name)
  2. The results files no longer include separate "syllables" and "syllables_reindexed" fields (see below for more details).
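For intuition, the frequency-based relabeling can be sketched in plain numpy (a toy illustration, not the actual kpms internals):

```python
import numpy as np

# Toy syllable sequence in the model's original indexing
z = np.array([2, 2, 0, 1, 2, 0])

# Count occurrences of each syllable
counts = np.bincount(z)           # syllable 0 -> 2, 1 -> 1, 2 -> 3

# Original labels ordered from most to least frequent
index = np.argsort(-counts)       # [2, 0, 1]

# Map each original label to its frequency rank, so "0" is most frequent
relabeled = np.argsort(index)[z]  # [0, 0, 1, 2, 0, 1]
```

The same `np.argsort(index)` lookup is used in the conversion code below to translate old labels into the new indexing.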

New format for results and checkpoint files

This PR introduces a new format for the results.h5 and checkpoint.p files saved during modeling. This is a breaking change, meaning that results/checkpoints generated with a previous version of the code will no longer work. Below we explain the changes and provide code for converting to the new format.

How the formats have changed

From a user perspective, the main change is that the results.h5 no longer contains separate syllables and syllables_reindexed fields; instead there is a single syllable field containing the frequency-reindexed labels.

For the results.h5 files, we have removed some fields and renamed others. Previously the format was

    results.h5
    ├──session_name1
    │  ├──estimated_coordinates  # denoised coordinates
    │  ├──syllables_reindexed    # syllables reindexed by frequency
    │  ├──syllables              # non-reindexed syllable labels (z)
    │  ├──latent_state           # inferred low-dim pose state (x)
    │  ├──centroid               # inferred centroid (v)
    │  └──heading                # inferred heading (h)
    ⋮

Now the format is

    results.h5
    ├──recording_name1
    │  ├──syllable      # syllable labels (z)
    │  ├──latent_state  # inferred low-dim pose state (x)
    │  ├──centroid      # inferred centroid (v)
    │  └──heading       # inferred heading (h)
    ⋮
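For reference, the new layout can be read with h5py as sketched below (the file name and recording name are placeholders; the group/dataset names match the tree above):

```python
import h5py
import numpy as np

# Write a tiny file in the new layout (for illustration only)
with h5py.File("results_demo.h5", "w") as f:
    rec = f.create_group("recording_name1")
    rec.create_dataset("syllable", data=np.array([0, 0, 1, 2]))
    rec.create_dataset("heading", data=np.zeros(4))

# Read it back: one group per recording, with the fields shown above
with h5py.File("results_demo.h5", "r") as f:
    for recording in f:
        syllables = f[recording]["syllable"][:]
        print(recording, syllables)
```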

The checkpoint.p files have changed more substantively. They are now saved as hdf5 files (rather than joblib) and their internal organization has changed.

Converting to the new format

The following code converts results and checkpoint files to the new format. Given a project directory and model name, a new project directory is generated with the updated files. As part of the reformatting, syllables are reindexed inside the model (see previous section) and a list of the resulting syllable name-changes is printed.

Make sure you are using the most up-to-date version of keypoint_moseq before running.

import keypoint_moseq as kpms
import numpy as np
import os, shutil
import joblib

def update_checkpoint_format(checkpoint):
    model = {k:checkpoint[k] for k in ['seed','noise_prior','params','states','hypparams']}
    model_snapshots = {str(checkpoint['iteration']): model}

    for i,hist in checkpoint['history'].items():
        model_snapshots[str(i)] = {
            'noise_prior': checkpoint['noise_prior'],
            'hypparams': checkpoint['hypparams'],
            'states': hist['states'],
            'params': hist['params'],
            'seed': hist['seed']
        }

    data = {'Y': checkpoint['Y'], 'conf':checkpoint['conf'], 'mask':checkpoint['mask']}
    keys = [l[0] for l in checkpoint['labels']]
    bounds = np.array([l[1:] for l in checkpoint['labels']])
    new_checkpoint = {'data':data, 'metadata':(keys, bounds), 'model_snapshots':model_snapshots}
    return new_checkpoint


def update_results_format(results, index=None):
    for k,v in results.items():
        if 'estimated_coordinates' in v:
            v['est_coords'] = v['estimated_coordinates']
            del v['estimated_coordinates']
       
        if 'syllables' in v:
            if index is None:
                v['syllable'] = v['syllables']
            else:
                v['syllable'] = np.argsort(index)[v['syllables']]
            del v['syllables']
           
        if 'syllables_reindexed' in v:
            del v['syllables_reindexed']  
    return results
  • Setup new project and model directories
old_project_dir = 'path/to/old/project_dir'
new_project_dir = 'path/to/new/project_dir'
model_name = 'name_of_model'

os.makedirs(new_project_dir)
os.makedirs(os.path.join(new_project_dir, model_name))

for filename in ['pcs-xy.pdf', 'pca_scree.pdf', 'config.yml', 'pca.p']:
    src_path = os.path.join(old_project_dir, filename)
    if os.path.exists(src_path):
        shutil.copy(src_path, new_project_dir)
  • Convert saved checkpoint to new format
old_checkpoint_path = os.path.join(old_project_dir, model_name, 'checkpoint.p')
new_checkpoint_path = os.path.join(new_project_dir, model_name, 'checkpoint.h5')

old_checkpoint = joblib.load(old_checkpoint_path)
new_checkpoint = update_checkpoint_format(old_checkpoint)
kpms.save_hdf5(new_checkpoint_path, new_checkpoint)
  • Reindex syllables in the model checkpoint
index = kpms.reindex_syllables_in_checkpoint(new_project_dir, model_name)
for i,j in enumerate(index):
    print(f'Syllable {j} is now labeled {i}')
  • Convert saved results to new format
old_results_path = os.path.join(old_project_dir, model_name, 'results.h5')
new_results_path = os.path.join(new_project_dir, model_name, 'results.h5')

old_results = kpms.load_hdf5(old_results_path)
new_results = update_results_format(old_results, index)
kpms.save_hdf5(new_results_path, new_results)
  • Regenerate visualizations
config = lambda: kpms.load_config(new_project_dir)
keypoint_data_path = 'path/to/data' # modify as needed
coordinates, confidences, bodyparts = kpms.load_keypoints(keypoint_data_path, 'deeplabcut')
results = kpms.load_results(new_project_dir, model_name)
kpms.save_results_as_csv(results, new_project_dir, model_name)
kpms.generate_trajectory_plots(coordinates, results, new_project_dir, model_name, **config())
kpms.generate_grid_movies(results, new_project_dir, model_name, coordinates=coordinates, **config())

New analysis tools

This PR introduces a new set of analysis widgets and a tutorial notebook (analysis.ipynb) for using them. These widgets ingest results in the updated format described above. So make sure to run the conversion code before applying the analysis pipeline to an existing project!

Support for large datasets

Currently it is not possible to model large datasets on a GPU without incurring out-of-memory (OOM) errors. To address this problem, we have created a framework for mixed serial/parallel computation and added multi-GPU support.

Partial serialization

By default, modeling is parallelized across the full dataset. Here we introduce a new option for mixed parallel/serial computation where the data is split into batches that are processed one at a time. To enable this option, run the following code before fitting the model (if you have already initiated model fitting, the kernel must be restarted):

from jax_moseq.utils import set_mixed_map_iters
set_mixed_map_iters(4) # adjust as needed

This will split the data into 4 batches, which should reduce the memory requirements about 4-fold but also result in a 4-fold slow-down. The number of batches can be adjusted as needed.
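As a rough analogy (plain numpy, not the jax_moseq implementation), batching trades peak memory for serial passes while leaving the result unchanged:

```python
import numpy as np

def step(batch):
    # Stand-in for one modeling computation over a batch of recordings
    return batch * 2

data = np.arange(12).reshape(6, 2)  # 6 toy "recordings"

# Fully parallel: one allocation covering all recordings at once
out_parallel = step(data)

# 4 batches processed one at a time: ~4x smaller working set, ~4x slower
out_batched = np.concatenate([step(b) for b in np.array_split(data, 4)])

assert np.array_equal(out_parallel, out_batched)  # results are identical
```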

Multi-GPU support

To use multiple GPUs, run the following code before fitting the model (if you have already initiated model fitting, the kernel must be restarted):

from jax_moseq.utils import set_mixed_map_gpus
set_mixed_map_gpus(2)

This will split the computation across two GPUs.

Additional info on implementation

Both of the above options (multi-GPU support and partial serialization) rely on a new utility called mixed_map that we added to the jax_moseq package. Below is a copy of its docstring:

def mixed_map(fun, in_axes=None, out_axes=None):
    """
    Combine jax.pmap, jax.vmap and jax.lax.map for parallelization.

    This function is similar to `jax.vmap`, except that it mixes together
    `jax.pmap`, `jax.vmap` and `jax.lax.map` to prevent OOM errors and allow
    for parallelization across multiple GPUs. The behavior is determined by
    the global variables `_MIXED_MAP_ITERS` and `_MIXED_MAP_GPUS`, which can be
    set using :py:func:`jax_moseq.utils.set_mixed_map_iters` and
    :py:func:`jax_moseq.utils.set_mixed_map_gpus`, respectively.

    Given an axis size of N to map, the data is padded such that the axis size
    is a multiple of `_MIXED_MAP_ITERS * _MIXED_MAP_GPUS`. The
    data is then processed serially in chunks, where the number of chunks is
    determined by `_MIXED_MAP_ITERS`. Each chunk is processed in parallel
    using jax.pmap to distribute across `_MIXED_MAP_GPUS` devices and jax.vmap
    to parallelize within each device.
    """

3D plotting tools

  • In addition to 2D projections of 3D keypoints, plot_pcs and generate_trajectory_plots now produce interactive 3D visualizations. These are rendered in the notebook and can also be viewed offline in a browser using the saved .html files.

  • It is now possible to generate grid movies for 3D keypoints, although they will only show 2D projections of the keypoints and not the underlying video. To generate grid movies from 3D data, include the flag keypoints_only=True and set the desired projection plane with the use_dims argument, e.g.

# generate grid movies in the x/y plane
kpms.generate_grid_movies(
   results, 
   project_dir, 
   name, 
   coordinates=coordinates, 
   keypoints_only=True, 
   use_dims=[0,1], 
   **config())

@bainro

bainro commented Feb 3, 2024

Is this the same issue @calebweinreb ? I'm running your lab's 3D dataset, but didn't realize I'd need multiple GPUs :(


@calebweinreb
Contributor Author

You shouldn't need multiple GPUs. Just use "mixed_map_iters" as described here: https://keypoint-moseq.readthedocs.io/en/latest/FAQs.html#troubleshooting

@bainro

bainro commented Feb 3, 2024 via email
