added support for uniquebodyparts in DLC
calebweinreb committed Feb 3, 2024
1 parent bf462f5 commit cf3ec56
Showing 1 changed file with 57 additions and 81 deletions.
138 changes: 57 additions & 81 deletions keypoint_moseq/io.py
@@ -31,9 +31,7 @@ def _build_yaml(sections, comments):
return "\n".join(text_blocks)


def _get_path(
project_dir, model_name, path, filename, pathname_for_error_msg="path"
):
def _get_path(project_dir, model_name, path, filename, pathname_for_error_msg="path"):
if path is None:
assert project_dir is not None and model_name is not None, fill(
f"`model_name` and `project_dir` are required if `{pathname_for_error_msg}` is None."
@@ -258,16 +256,10 @@ def load_config(project_dir, check_if_valid=True, build_indexes=True):

if build_indexes:
config["anterior_idxs"] = jnp.array(
[
config["use_bodyparts"].index(bp)
for bp in config["anterior_bodyparts"]
]
[config["use_bodyparts"].index(bp) for bp in config["anterior_bodyparts"]]
)
config["posterior_idxs"] = jnp.array(
[
config["use_bodyparts"].index(bp)
for bp in config["posterior_bodyparts"]
]
[config["use_bodyparts"].index(bp) for bp in config["posterior_bodyparts"]]
)

if not "skeleton" in config or config["skeleton"] is None:
@@ -297,9 +289,7 @@ def update_config(project_dir, **kwargs):
>>> print(load_config(project_dir)['trans_hypparams']['kappa'])
100
"""
config = load_config(
project_dir, check_if_valid=False, build_indexes=False
)
config = load_config(project_dir, check_if_valid=False, build_indexes=False)
config.update(kwargs)
generate_config(project_dir, **config)

@@ -364,14 +354,9 @@ def setup_project(
f"{deeplabcut_config} does not exists or is not a"
" valid yaml file"
)
if (
"multianimalproject" in dlc_config
and dlc_config["multianimalproject"]
):
if "multianimalproject" in dlc_config and dlc_config["multianimalproject"]:
dlc_options["bodyparts"] = dlc_config["multianimalbodyparts"]
dlc_options["use_bodyparts"] = dlc_config[
"multianimalbodyparts"
]
dlc_options["use_bodyparts"] = dlc_config["multianimalbodyparts"]
else:
dlc_options["bodyparts"] = dlc_config["bodyparts"]
dlc_options["use_bodyparts"] = dlc_config["bodyparts"]
@@ -392,15 +377,12 @@
)
skeleton = slp_file.skeletons[0]
node_names = skeleton.node_names
edge_names = [
[e.source.name, e.destination.name] for e in skeleton.edges
]
edge_names = [[e.source.name, e.destination.name] for e in skeleton.edges]
else:
with h5py.File(sleap_file, "r") as f:
node_names = [n.decode("utf-8") for n in f["node_names"]]
edge_names = [
[n.decode("utf-8") for n in edge]
for edge in f["edge_names"]
[n.decode("utf-8") for n in edge] for edge in f["edge_names"]
]
sleap_options["bodyparts"] = node_names
sleap_options["use_bodyparts"] = node_names
@@ -457,15 +439,11 @@ def load_pca(project_dir, pca_path=None):
"""
if pca_path is None:
pca_path = os.path.join(project_dir, "pca.p")
assert os.path.exists(pca_path), fill(
f"No PCA model found at {pca_path}"
)
assert os.path.exists(pca_path), fill(f"No PCA model found at {pca_path}")
return joblib.load(pca_path)


def load_checkpoint(
project_dir=None, model_name=None, path=None, iteration=None
):
def load_checkpoint(project_dir=None, model_name=None, path=None, iteration=None):
"""Load data and model snapshot from a saved checkpoint.
The checkpoint path can be specified directly via `path` or else it is
@@ -571,9 +549,7 @@ def reindex_syllables_in_checkpoint(
num_states = f[f"model_snapshots/{last_iter}/params/pi"].shape[0]
z = f[f"model_snapshots/{last_iter}/states/z"][()]
mask = f["data/mask"][()]
index = np.argsort(get_frequencies(z, mask, num_states, runlength))[
::-1
]
index = np.argsort(get_frequencies(z, mask, num_states, runlength))[::-1]

def _reindex(model):
model["params"]["betas"] = model["params"]["betas"][index]
@@ -651,7 +627,7 @@ def extract_results(

# extract syllables; repeat first syllable an extra `nlags` times
nlags = states["x"].shape[1] - states["z"].shape[1]
z = np.pad(states["z"], ((0,0),(nlags, 0)), mode="edge")
z = np.pad(states["z"], ((0, 0), (nlags, 0)), mode="edge")
syllables = unbatch(z, *metadata)

# extract latent state, centroid, and heading
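As the comment above notes, the first syllable is repeated an extra `nlags` times so the syllable sequence lines up with the longer latent trajectory; `np.pad` with `mode="edge"` does exactly that along the time axis. A small illustration with made-up values:

import numpy as np

z = np.array([[4, 7, 7, 2]])   # (n_recordings, n_frames)
nlags = 3
z_padded = np.pad(z, ((0, 0), (nlags, 0)), mode="edge")
# array([[4, 4, 4, 4, 7, 7, 2]]) -- the first column is repeated nlags times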
@@ -724,9 +700,7 @@ def save_results_as_csv(
If a path separator ("/" or "\") is present in the recording name, it
will be replaced with `path_sep` when saving the csv file.
"""
save_dir = _get_path(
project_dir, model_name, save_dir, "results", "save_dir"
)
save_dir = _get_path(project_dir, model_name, save_dir, "results", "save_dir")

if not os.path.exists(save_dir):
os.makedirs(save_dir)
@@ -749,15 +723,10 @@ def save_results_as_csv(

if "latent_state" in results[key].keys():
latent_dim = results[key]["latent_state"].shape[1]
column_names.append(
[f"latent_state {i}" for i in range(latent_dim)]
)
column_names.append([f"latent_state {i}" for i in range(latent_dim)])
data.append(results[key]["latent_state"])

dfs = [
pd.DataFrame(arr, columns=cols)
for arr, cols in zip(data, column_names)
]
dfs = [pd.DataFrame(arr, columns=cols) for arr, cols in zip(data, column_names)]
df = pd.concat(dfs, axis=1)

for col in df.select_dtypes(include=[np.floating]).columns:
@@ -790,6 +759,7 @@ def load_keypoints(
path_sep="-",
path_in_name=False,
remove_extension=True,
exclude_individuals=["single"],
):
"""
Load keypoint tracking results from one or more files. Several file
@@ -886,6 +856,10 @@ def load_keypoints(
Whether to remove the file extension when naming the tracking results
from each file.
exclude_individuals: list of str, default=["single"]
List of individuals to exclude from the results. This is only used for
multi-animal tracking with deeplabcut.
Returns
-------
coordinates: dict
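With the new `exclude_individuals` argument, a multi-animal DeepLabCut file yields one entry per individual, and the "single" individual that DeepLabCut uses for unique bodyparts is skipped by default. A usage sketch, assuming the standard `kpms` import alias and a hypothetical directory of DLC output files:

import keypoint_moseq as kpms

coordinates, confidences, bodyparts = kpms.load_keypoints(
    "dlc_project/videos",            # hypothetical path to DLC .h5/.csv outputs
    format="deeplabcut",
    exclude_individuals=["single"],  # default; pass [] to include unique bodyparts
)
# Multi-animal files contribute keys like "<recording>_mouse1", "<recording>_mouse2".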
@@ -933,21 +907,22 @@
"facemap": _facemap_loader,
}[format]

filepaths = list_files_with_exts(
filepath_pattern, extensions, recursive=recursive
)
if format == "deeplabcut":
additional_args = {"exclude_individuals": exclude_individuals}
else:
additional_args = {}

filepaths = list_files_with_exts(filepath_pattern, extensions, recursive=recursive)
assert len(filepaths) > 0, fill(
f"No files with extensions {extensions} found for {filepath_pattern}"
)

coordinates, confidences, bodyparts = {}, {}, None
for filepath in tqdm.tqdm(filepaths, desc=f"Loading keypoints", ncols=72):
try:
name = _name_from_path(
filepath, path_in_name, path_sep, remove_extension
)
name = _name_from_path(filepath, path_in_name, path_sep, remove_extension)
new_coordinates, new_confidences, bodyparts = loader(
filepath, name
filepath, name, **additional_args
)

if set(new_coordinates.keys()) & set(coordinates.keys()):
@@ -967,15 +942,13 @@
coordinates.update(new_coordinates)
confidences.update(new_confidences)

assert len(coordinates) > 0, fill(
f"No valid results found for {filepath_pattern}"
)
assert len(coordinates) > 0, fill(f"No valid results found for {filepath_pattern}")

check_nan_proportions(coordinates, bodyparts)
return coordinates, confidences, bodyparts


def _deeplabcut_loader(filepath, name):
def _deeplabcut_loader(filepath, name, exclude_individuals=["single"]):
"""Load tracking results from deeplabcut csv or hdf5 files."""
ext = os.path.splitext(filepath)[1]
if ext == ".h5":
@@ -990,14 +963,30 @@ def _deeplabcut_loader(filepath, name):
df = pd.read_csv(filepath, header=header, index_col=0)

coordinates, confidences = {}, {}
bodyparts = df.columns.get_level_values("bodyparts").unique().tolist()
if "individuals" in df.columns.names:
ind_bodyparts = {}
for ind in df.columns.get_level_values("individuals").unique():
ind_df = df.xs(ind, axis=1, level="individuals")
arr = ind_df.to_numpy().reshape(len(ind_df), -1, 3)
coordinates[f"{name}_{ind}"] = arr[:, :, :-1]
confidences[f"{name}_{ind}"] = arr[:, :, -1]
if ind in exclude_individuals:
print(
f'Excluding individual: "{ind}". Set `exclude_individuals=[]` to include.'
)
else:
ind_df = df.xs(ind, axis=1, level="individuals")
bps = ind_df.columns.get_level_values("bodyparts").unique().tolist()
ind_bodyparts[ind] = bps

arr = ind_df.to_numpy().reshape(len(ind_df), -1, 3)
coordinates[f"{name}_{ind}"] = arr[:, :, :-1]
confidences[f"{name}_{ind}"] = arr[:, :, -1]

bodyparts = set(ind_bodyparts[list(ind_bodyparts.keys())[0]])
assert all([set(bps) == bodyparts for bps in ind_bodyparts.values()]), (
f"Bodyparts are not consistent across individuals. The following bodyparts "
f"were found for each individual: {ind_bodyparts}. Use `exclude_individuals`"
"to exclude specific individuals."
)
else:
bodyparts = df.columns.get_level_values("bodyparts").unique().tolist()
arr = df.to_numpy().reshape(len(df), -1, 3)
coordinates[name] = arr[:, :, :-1]
confidences[name] = arr[:, :, -1]
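The multi-animal branch above works because multi-animal DeepLabCut tables carry an extra "individuals" column level, and `df.xs(..., level="individuals")` pulls out one animal's (x, y, likelihood) columns at a time. A self-contained sketch with a toy table (the scorer, individual, and bodypart names are made up):

import numpy as np
import pandas as pd

# Toy multi-animal DLC-style table: 2 frames, 2 individuals, 2 bodyparts each.
columns = pd.MultiIndex.from_product(
    [["scorer0"], ["mouse1", "mouse2"], ["nose", "tail"], ["x", "y", "likelihood"]],
    names=["scorer", "individuals", "bodyparts", "coords"],
)
df = pd.DataFrame(np.arange(24, dtype=float).reshape(2, 12), columns=columns)

for ind in df.columns.get_level_values("individuals").unique():
    ind_df = df.xs(ind, axis=1, level="individuals")
    arr = ind_df.to_numpy().reshape(len(ind_df), -1, 3)  # (frames, bodyparts, 3)
    coords, confs = arr[:, :, :-1], arr[:, :, -1]
    print(ind, coords.shape, confs.shape)                # mouse1 (2, 2, 2) (2, 2) ...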
@@ -1030,12 +1019,8 @@ def _sleap_loader(filepath, name):
coordinates = {name: coords[0].T}
confidences = {name: confs[0].T}
else:
coordinates = {
f"{name}_track{i}": coords[i].T for i in range(coords.shape[0])
}
confidences = {
f"{name}_track{i}": confs[i].T for i in range(coords.shape[0])
}
coordinates = {f"{name}_track{i}": coords[i].T for i in range(coords.shape[0])}
confidences = {f"{name}_track{i}": confs[i].T for i in range(coords.shape[0])}
return coordinates, confidences, bodyparts


@@ -1050,10 +1035,7 @@ def _anipose_loader(filepath, name):
df = pd.read_csv(filepath)
coordinates = {
name: np.stack(
[
df[[f"{bp}_x", f"{bp}_y", f"{bp}_z"]].to_numpy()
for bp in bodyparts
],
[df[[f"{bp}_x", f"{bp}_y", f"{bp}_z"]].to_numpy() for bp in bodyparts],
axis=1,
)
}
@@ -1075,8 +1057,7 @@ def _sleap_anipose_loader(filepath, name):
confidences = {name: confs[:, 0]}
else:
coordinates = {
f"{name}_track{i}": coords[:, i]
for i in range(coords.shape[1])
f"{name}_track{i}": coords[:, i] for i in range(coords.shape[1])
}
confidences = {
f"{name}_track{i}": confs[:, i] for i in range(coords.shape[1])
@@ -1088,9 +1069,7 @@ def _load_nwb_pose_obj(io, filepath):
"""Grab PoseEstimation object from an opened .nwb file."""
all_objs = io.read().all_children()
pose_objs = [o for o in all_objs if isinstance(o, PoseEstimation)]
assert len(pose_objs) > 0, fill(
f"No PoseEstimation objects found in {filepath}"
)
assert len(pose_objs) > 0, fill(f"No PoseEstimation objects found in {filepath}")
assert len(pose_objs) == 1, fill(
f"Found multiple PoseEstimation objects in {filepath}. "
"This is not currently supported. Please open a github "
@@ -1109,10 +1088,7 @@ def _nwb_loader(filepath, name):
[pose_obj.pose_estimation_series[bp].data[()] for bp in bodyparts],
axis=1,
)
if (
"confidence"
in pose_obj.pose_estimation_series[bodyparts[0]].fields
):
if "confidence" in pose_obj.pose_estimation_series[bodyparts[0]].fields:
confs = np.stack(
[
pose_obj.pose_estimation_series[bp].confidence[()]