diff --git a/src/spyglass/common/__init__.py b/src/spyglass/common/__init__.py index 270acdf3a..ae144398b 100644 --- a/src/spyglass/common/__init__.py +++ b/src/spyglass/common/__init__.py @@ -7,6 +7,7 @@ StateScriptFile, VideoFile, convert_epoch_interval_name_to_position_interval_name, + get_position_interval_epoch, ) from spyglass.common.common_device import ( CameraDevice, @@ -40,10 +41,7 @@ intervals_by_length, ) from spyglass.common.common_lab import Institution, Lab, LabMember, LabTeam -from spyglass.common.common_nwbfile import ( - AnalysisNwbfile, - Nwbfile, -) +from spyglass.common.common_nwbfile import AnalysisNwbfile, Nwbfile from spyglass.common.common_position import ( IntervalLinearizationSelection, IntervalLinearizedPosition, diff --git a/src/spyglass/common/common_behav.py b/src/spyglass/common/common_behav.py index 53f8b7dcd..49b439278 100644 --- a/src/spyglass/common/common_behav.py +++ b/src/spyglass/common/common_behav.py @@ -686,6 +686,30 @@ def get_interval_list_name_from_epoch(nwb_file_name: str, epoch: int) -> str: return interval_names[0] +def get_position_interval_epoch( + nwb_file_name: str, position_interval_name: str +) -> int: + """Return the epoch number for a given position interval name.""" + # look up the epoch + key = dict( + nwb_file_name=nwb_file_name, + position_interval_name=position_interval_name, + ) + query = PositionIntervalMap * TaskEpoch & key + if query: + return query.fetch1("epoch") + # if no match, make sure all epoch interval names are mapped + for epoch_key in (TaskEpoch() & key).fetch( + "nwb_file_name", "interval_list_name", as_dict=True + ): + convert_epoch_interval_name_to_position_interval_name(epoch_key) + # try again + query = PositionIntervalMap * TaskEpoch & key + if query: + return query.fetch1("epoch") + return None + + def populate_position_interval_map_session(nwb_file_name: str): """Populate PositionIntervalMap for all epochs in a given NWB file.""" # 1. remove redundancy in interval names diff --git a/src/spyglass/common/common_position.py b/src/spyglass/common/common_position.py index 0a47a4d19..5216540de 100644 --- a/src/spyglass/common/common_position.py +++ b/src/spyglass/common/common_position.py @@ -1,6 +1,5 @@ import bottleneck import datajoint as dj -import matplotlib.pyplot as plt import numpy as np import pandas as pd import pynwb @@ -15,7 +14,11 @@ from position_tools.core import gaussian_smooth from tqdm import tqdm_notebook as tqdm -from spyglass.common.common_behav import RawPosition, VideoFile +from spyglass.common.common_behav import ( + RawPosition, + VideoFile, + get_position_interval_epoch, +) from spyglass.common.common_interval import IntervalList # noqa F401 from spyglass.common.common_nwbfile import AnalysisNwbfile from spyglass.settings import raw_dir, test_mode, video_dir @@ -553,13 +556,8 @@ def make(self, key): ).fetch1_dataframe() logger.info("Loading video data...") - epoch = ( - int( - key["interval_list_name"] - .replace("pos ", "") - .replace(" valid times", "") - ) - + 1 + epoch = get_position_interval_epoch( + key["nwb_file_name"], key["interval_list_name"] ) video_info = (VideoFile() & {**nwb_dict, "epoch": epoch}).fetch1() io = pynwb.NWBHDF5IO(raw_dir + "/" + video_info["nwb_file_name"], "r") diff --git a/src/spyglass/position/v1/position_trodes_position.py b/src/spyglass/position/v1/position_trodes_position.py index af047dc20..efaa611b0 100644 --- a/src/spyglass/position/v1/position_trodes_position.py +++ b/src/spyglass/position/v1/position_trodes_position.py @@ -7,7 +7,7 @@ from datajoint.utils import to_camel_case from pandas import DataFrame -from spyglass.common import PositionIntervalMap, TaskEpoch +from spyglass.common import get_position_interval_epoch from spyglass.common.common_behav import RawPosition from spyglass.common.common_nwbfile import AnalysisNwbfile from spyglass.common.common_position import IntervalPositionInfo, _fix_col_names @@ -255,11 +255,13 @@ def fetch_pose_dataframe(self): raise NotImplementedError("No pose data for TrodesPosV1") def fetch_video_path(self): - nwb_file_name = self.fetch1("nwb_file_name") - epoch = ( - TaskEpoch() * PositionIntervalMap() & self.fetch1("KEY") - ).fetch("epoch")[0] - return get_video_path(nwb_file_name) + nwb_file_name, interval_list_name = self.fetch1( + "nwb_file_name", "interval_list_name" + ) + epoch = get_position_interval_epoch(nwb_file_name, interval_list_name) + return get_video_info({"nwb_file_name": nwb_file_name, "epoch": epoch})[ + 0 + ] @schema @@ -298,14 +300,9 @@ def make(self, key): pos_df = (TrodesPosV1() & key).fetch1_dataframe() logger.info("Loading video data...") - epoch = ( - int( - key["interval_list_name"] - .replace("pos ", "") - .replace(" valid times", "") - ) - + 1 - ) # TODO: Fix this hack + epoch = get_position_interval_epoch( + key["nwb_file_name"], key["interval_list_name"] + ) ( video_path,