Skip to content

Commit

Permalink
move moseq dir definition to spyglass config
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelbray32 committed Dec 27, 2024
1 parent f92320b commit d663e0e
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 15 deletions.
10 changes: 2 additions & 8 deletions src/spyglass/behavior/moseq.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from spyglass.common import AnalysisNwbfile
from spyglass.position.position_merge import PositionOutput
from spyglass.settings import moseq_project_dir, moseq_video_dir
from spyglass.utils import SpyglassMixin

from .core import PoseGroup, format_dataset_for_moseq, results_to_df
Expand Down Expand Up @@ -113,15 +114,8 @@ def make(self, key):
model_params = (MoseqModelParams & key).fetch1("model_params")

# set up the project and config
project_dir = (
"/home/sambray/Documents/moseq_test_proj3" # TODO: make this better
)
video_dir = (
"/home/sambray/Documents/moseq_test_vids3" # TODO: make this better
)
project_dir, video_dir = moseq_project_dir, moseq_video_dir
# make symlinks to the videos in a single directory
os.makedirs(video_dir, exist_ok=True)
# os.makedirs(project_dir, exist_ok=True)
video_paths = (PoseGroup & key).fetch_video_paths()
for video in video_paths:
destination = os.path.join(video_dir, os.path.basename(video))
Expand Down
51 changes: 44 additions & 7 deletions src/spyglass/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ def __init__(self, base_dir: str = None, **kwargs) -> None:
"video": "video",
"output": "output",
},
"moseq": {
"project": "projects",
"video": "video",
},
}
self.dj_defaults = {
"database.host": kwargs.get("database_host", "lmf-db.cin.ucsf.edu"),
Expand Down Expand Up @@ -139,6 +143,7 @@ def load_config(
dj_spyglass = dj_custom.get("spyglass_dirs", {})
dj_kachery = dj_custom.get("kachery_dirs", {})
dj_dlc = dj_custom.get("dlc_dirs", {})
dj_moseq = dj_custom.get("moseq_dirs", {})

self._debug_mode = dj_custom.get("debug_mode", False)
self._test_mode = kwargs.get("test_mode") or dj_custom.get(
Expand Down Expand Up @@ -174,9 +179,20 @@ def load_config(
)
Path(self._dlc_base).mkdir(exist_ok=True)

self._moseq_base = (
dj_moseq.get("base")
or os.environ.get("MOSEQ_BASE_DIR")
or str(Path(resolved_base) / "moseq")
)
Path(self._moseq_base).mkdir(exist_ok=True)

config_dirs = {"SPYGLASS_BASE_DIR": str(resolved_base)}
for prefix, dirs in self.relative_dirs.items():
this_base = self._dlc_base if prefix == "dlc" else resolved_base
this_base = (
self._dlc_base
if prefix == "dlc"
else (self._moseq_base if prefix == "moseq" else resolved_base)
)
for dir, dir_str in dirs.items():
dir_env_fmt = self.dir_to_var(dir=dir, dir_type=prefix)

Expand All @@ -185,12 +201,14 @@ def load_config(
if not self.supplied_base_dir
else None
)

source_config = (
dj_dlc
if prefix == "dlc"
else dj_kachery if prefix == "kachery" else dj_spyglass
)
if prefix == "dlc":
source_config = dj_dlc
elif prefix == "moseq":
source_config = dj_moseq
elif prefix == "kachery":
source_config = dj_kachery
else:
source_config = dj_spyglass
dir_location = (
source_config.get(dir)
or env_loc
Expand Down Expand Up @@ -482,6 +500,11 @@ def _dj_custom(self) -> dict:
"video": self.dlc_video_dir,
"output": self.dlc_output_dir,
},
"moseq_dirs": {
"base": self._moseq_base,
"project": self.moseq_project_dir,
"video": self.moseq_video_dir,
},
"kachery_zone": "franklab.default",
}
}
Expand Down Expand Up @@ -567,6 +590,16 @@ def dlc_output_dir(self) -> str:
"""DLC output directory as a string."""
return self.config.get(self.dir_to_var("output", "dlc"))

@property
def moseq_project_dir(self) -> str:
"""Moseq project directory as a string."""
return self.config.get(self.dir_to_var("project", "moseq"))

@property
def moseq_video_dir(self) -> str:
"""Moseq video directory as a string."""
return self.config.get(self.dir_to_var("video", "moseq"))


sg_config = SpyglassConfig()
sg_config.load_config(on_startup=True)
Expand All @@ -588,6 +621,8 @@ def dlc_output_dir(self) -> str:
dlc_project_dir = None
dlc_video_dir = None
dlc_output_dir = None
moseq_project_dir = None
moseq_video_dir = None
else:
config = sg_config.config
base_dir = sg_config.base_dir
Expand All @@ -605,3 +640,5 @@ def dlc_output_dir(self) -> str:
dlc_project_dir = sg_config.dlc_project_dir
dlc_video_dir = sg_config.dlc_video_dir
dlc_output_dir = sg_config.dlc_output_dir
moseq_project_dir = sg_config.moseq_project_dir
moseq_video_dir = sg_config.moseq_video_dir

0 comments on commit d663e0e

Please sign in to comment.