Skip to content

Commit

Permalink
revisions
Browse files Browse the repository at this point in the history
revisions
  • Loading branch information
vtsai881 committed Dec 14, 2023
1 parent a9cb496 commit 6d93f17
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 27 deletions.
4 changes: 2 additions & 2 deletions sleap/gui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,13 +511,13 @@ def add_submenu_choices(menu, title, options, key):
export_csv_menu,
"export_csv_current_all_frames",
"Current Video (all frames)...",
self.commands.exportCSVFile(all_frames=True),
lambda: self.commands.exportCSVFile(all_frames=True),
)
add_menu_item(
export_csv_menu,
"export_csv_all_all_frames",
"All Videos (all frames)...",
lambda: self.commands.exportCSVFile(all_frames=True),
lambda: self.commands.exportCSVFile(all_videos=True, all_frames=True),
)

add_menu_item(fileMenu, "export_nwb", "Export NWB...", self.commands.exportNWB)
Expand Down
34 changes: 17 additions & 17 deletions sleap/gui/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -1143,23 +1143,23 @@ def do_action(cls, context: CommandContext, params: dict):
else:
adaptor = SleapAnalysisAdaptor

if params['all_frames']:
adaptor.write(
filename=output_path,
all_frames=params.get('all_frames', False),
source_object=context.labels,
source_path=context.state["filename"],
video=video,
)
else:
adaptor.write(
filename=output_path,
all_frames=False,
source_object=context.labels,
source_path=context.state["filename"],
video=video,
)
if 'all_frames' in params and params['all_frames']:
adaptor.write(
filename=output_path,
all_frames=True,
source_object=context.labels,
source_path=context.state["filename"],
video=video,
)
else:
adaptor.write(
filename=output_path,
all_frames=False,
source_object=context.labels,
source_path=context.state["filename"],
video=video,
)

@staticmethod
def ask(context: CommandContext, params: dict) -> bool:
def ask_for_filename(default_name: str, csv: bool) -> str:
Expand Down
5 changes: 2 additions & 3 deletions sleap/info/write_tracking_h5.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def write_occupancy_file(
print(f"Saved as {output_path}")


def write_csv_file(output_path, data_dict):
def write_csv_file(output_path, data_dict, all_frames):

"""Write CSV file with data from given dictionary.
Expand Down Expand Up @@ -348,7 +348,6 @@ def write_csv_file(output_path, data_dict):
tracks.append(detection)

tracks = pd.DataFrame(tracks)
all_frames = globals().get('all_frames', False)

if all_frames:
tracks = tracks.set_index('frame_idx')
Expand Down Expand Up @@ -443,7 +442,7 @@ def main(
)

if csv:
write_csv_file(output_path, data_dict)
write_csv_file(output_path, data_dict, all_frames=all_frames)
else:
write_occupancy_file(output_path, data_dict, transpose=True)

Expand Down
9 changes: 4 additions & 5 deletions sleap/io/format/csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def write(
filename: str,
source_object: Labels,
source_path: str = None,
all_frames: bool= False,
all_frames: bool = False,
video: Video = None,
):
"""Writes csv file for :py:class:`Labels` `source_object`.
Expand All @@ -54,15 +54,14 @@ def write(
filename: The filename for the output file.
source_object: The :py:class:`Labels` from which to get data from.
source_path: Path for the labels object
all_frames: A boolean flag to determine whether to include all frames
or only those with tracking data in the export.
video: The :py:class:`Video` from which toget data from. If no `video` is
all_frames: A boolean flag to determine whether to include all frames or
only those with tracking data in the export.
video: The :py:class:`Video` from which to get data from. If no `video` is
specified, then the first video in `source_object` videos list will be
used. If there are no :py:class:`Labeled Frame`s in the `video`, then no
analysis file will be written.
"""
from sleap.info.write_tracking_h5 import main as write_analysis

write_analysis(
labels=source_object,
output_path=filename,
Expand Down
60 changes: 60 additions & 0 deletions tests/gui/test_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,66 @@ def assert_videos_written(num_videos: int, labels_path: str = None):
ExportAnalysisFile.do_action(context=context, params=params)
assert_videos_written(num_videos=2, labels_path=context.state["filename"])

# Test with all_videos True and all_frames True
params = {"all_videos": True, "all_frames": True, "csv": csv}
okay = ExportAnalysisFile_ask(context=context, params=params)
assert okay == True
ExportAnalysisFile.do_action(context=context, params=params)
assert_videos_written(num_videos=2, labels_path=context.state["filename"])

# Test with all_videos False and all_frames True
params = {"all_videos": False, "all_frames": True, "csv": csv}
okay = ExportAnalysisFile_ask(context=context, params=params)
assert okay == True
ExportAnalysisFile.do_action(context=context, params=params)
assert_videos_written(num_videos=1, labels_path=context.state["filename"])

# Test with all_videos False and all_frames False
params = {"all_videos": False, "all_frames": False, "csv": csv}
okay = ExportAnalysisFile_ask(context=context, params=params)
assert okay == True
ExportAnalysisFile.do_action(context=context, params=params)
assert_videos_written(num_videos=1, labels_path=context.state["filename"])

# Add labels path and test with all_videos True and all_frames True (single video)
context.state["filename"] = str(tmpdir.with_name("path.to.labels"))
params = {"all_videos": True, "all_frames": True, "csv": csv}
okay = ExportAnalysisFile_ask(context=context, params=params)
assert okay == True
ExportAnalysisFile.do_action(context=context, params=params)
assert_videos_written(num_videos=2, labels_path=context.state["filename"])

# Add a video (no labels) and test with all_videos True and all_frames True
labels.add_video(small_robot_mp4_vid)

params = {"all_videos": True, "all_frames": True, "csv": csv}
okay = ExportAnalysisFile_ask(context=context, params=params)
assert okay == True
ExportAnalysisFile.do_action(context=context, params=params)
assert_videos_written(num_videos=2, labels_path=context.state["filename"])

# Test with videos with the same filename
(tmpdir / "session1").mkdir()
(tmpdir / "session2").mkdir()
shutil.copy(
centered_pair_predictions.video.backend.filename,
tmpdir / "session1" / "video.mp4",
)
shutil.copy(small_robot_mp4_vid.backend.filename, tmpdir / "session2" / "video.mp4")
labels.videos[0].backend.filename = str(tmpdir / "session1" / "video.mp4")
labels.videos[1].backend.filename = str(tmpdir / "session2" / "video.mp4")
params = {"all_videos": True, "csv": csv}
okay = ExportAnalysisFile_ask(context=context, params=params)
assert okay == True
ExportAnalysisFile.do_action(context=context, params=params)
assert_videos_written(num_videos=2, labels_path=context.state["filename"])

# Remove all videos and test
all_videos = list(labels.videos)
for video in all_videos:
labels.remove_video(labels.videos[-1])

params = {"all_videos": True, "all_frames": True, "csv": csv}
# Test with videos with the same filename
(tmpdir / "session1").mkdir()
(tmpdir / "session2").mkdir()
Expand Down

0 comments on commit 6d93f17

Please sign in to comment.