Skip to content

Commit

Permalink
add get_position_interval_epoch
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelbray32 committed Dec 20, 2024
1 parent c3c7b84 commit e5b2c1b
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 27 deletions.
6 changes: 2 additions & 4 deletions src/spyglass/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
StateScriptFile,
VideoFile,
convert_epoch_interval_name_to_position_interval_name,
get_position_interval_epoch,
)
from spyglass.common.common_device import (
CameraDevice,
Expand Down Expand Up @@ -40,10 +41,7 @@
intervals_by_length,
)
from spyglass.common.common_lab import Institution, Lab, LabMember, LabTeam
from spyglass.common.common_nwbfile import (
AnalysisNwbfile,
Nwbfile,
)
from spyglass.common.common_nwbfile import AnalysisNwbfile, Nwbfile
from spyglass.common.common_position import (
IntervalLinearizationSelection,
IntervalLinearizedPosition,
Expand Down
24 changes: 24 additions & 0 deletions src/spyglass/common/common_behav.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,30 @@ def get_interval_list_name_from_epoch(nwb_file_name: str, epoch: int) -> str:
return interval_names[0]


def get_position_interval_epoch(
nwb_file_name: str, position_interval_name: str
) -> int:
"""Return the epoch number for a given position interval name."""
# look up the epoch
key = dict(
nwb_file_name=nwb_file_name,
position_interval_name=position_interval_name,
)
query = PositionIntervalMap * TaskEpoch & key
if query:
return query.fetch1("epoch")
# if no match, make sure all epoch interval names are mapped
for epoch_key in (TaskEpoch() & key).fetch(
"nwb_file_name", "interval_list_name", as_dict=True
):
convert_epoch_interval_name_to_position_interval_name(epoch_key)
# try again
query = PositionIntervalMap * TaskEpoch & key
if query:
return query.fetch1("epoch")
return None


def populate_position_interval_map_session(nwb_file_name: str):
"""Populate PositionIntervalMap for all epochs in a given NWB file."""
# 1. remove redundancy in interval names
Expand Down
16 changes: 7 additions & 9 deletions src/spyglass/common/common_position.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import bottleneck
import datajoint as dj
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pynwb
Expand All @@ -15,7 +14,11 @@
from position_tools.core import gaussian_smooth
from tqdm import tqdm_notebook as tqdm

from spyglass.common.common_behav import RawPosition, VideoFile
from spyglass.common.common_behav import (
RawPosition,
VideoFile,
get_position_interval_epoch,
)
from spyglass.common.common_interval import IntervalList # noqa F401
from spyglass.common.common_nwbfile import AnalysisNwbfile
from spyglass.settings import raw_dir, test_mode, video_dir
Expand Down Expand Up @@ -553,13 +556,8 @@ def make(self, key):
).fetch1_dataframe()

logger.info("Loading video data...")
epoch = (
int(
key["interval_list_name"]
.replace("pos ", "")
.replace(" valid times", "")
)
+ 1
epoch = get_position_interval_epoch(
key["nwb_file_name"], key["interval_list_name"]
)
video_info = (VideoFile() & {**nwb_dict, "epoch": epoch}).fetch1()
io = pynwb.NWBHDF5IO(raw_dir + "/" + video_info["nwb_file_name"], "r")
Expand Down
25 changes: 11 additions & 14 deletions src/spyglass/position/v1/position_trodes_position.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from datajoint.utils import to_camel_case
from pandas import DataFrame

from spyglass.common import PositionIntervalMap, TaskEpoch
from spyglass.common import get_position_interval_epoch
from spyglass.common.common_behav import RawPosition
from spyglass.common.common_nwbfile import AnalysisNwbfile
from spyglass.common.common_position import IntervalPositionInfo, _fix_col_names
Expand Down Expand Up @@ -255,11 +255,13 @@ def fetch_pose_dataframe(self):
raise NotImplementedError("No pose data for TrodesPosV1")

def fetch_video_path(self):
nwb_file_name = self.fetch1("nwb_file_name")
epoch = (
TaskEpoch() * PositionIntervalMap() & self.fetch1("KEY")
).fetch("epoch")[0]
return get_video_path(nwb_file_name)
nwb_file_name, interval_list_name = self.fetch1(
"nwb_file_name", "interval_list_name"
)
epoch = get_position_interval_epoch(nwb_file_name, interval_list_name)
return get_video_info({"nwb_file_name": nwb_file_name, "epoch": epoch})[
0
]


@schema
Expand Down Expand Up @@ -298,14 +300,9 @@ def make(self, key):
pos_df = (TrodesPosV1() & key).fetch1_dataframe()

logger.info("Loading video data...")
epoch = (
int(
key["interval_list_name"]
.replace("pos ", "")
.replace(" valid times", "")
)
+ 1
) # TODO: Fix this hack
epoch = get_position_interval_epoch(
key["nwb_file_name"], key["interval_list_name"]
)

(
video_path,
Expand Down

0 comments on commit e5b2c1b

Please sign in to comment.