From 3eeb9fdb35a72a7d22fd75900eba2c40dc7455b6 Mon Sep 17 00:00:00 2001 From: Elizabeth Berrigan Date: Mon, 8 Apr 2024 15:20:44 -0700 Subject: [PATCH] test `export_csv` method added to `Labels` class --- tests/io/test_dataset.py | 46 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/tests/io/test_dataset.py b/tests/io/test_dataset.py index 5592ae437..be5b71730 100644 --- a/tests/io/test_dataset.py +++ b/tests/io/test_dataset.py @@ -1,9 +1,11 @@ import os +import pandas as pd import pytest import numpy as np from pathlib import Path, PurePath import sleap +from sleap.info.write_tracking_h5 import get_nodes_as_np_strings from sleap.skeleton import Skeleton from sleap.instance import Instance, Point, LabeledFrame, PredictedInstance, Track from sleap.io.video import Video, MediaVideo @@ -1559,3 +1561,47 @@ def test_export_nwb(centered_pair_predictions: Labels, tmpdir): # Read from NWB file read_labels = NDXPoseAdaptor.read(NDXPoseAdaptor, filehandle.FileHandle(filename)) assert_read_labels_match(centered_pair_predictions, read_labels) + + +@pytest.mark.parametrize("labels_fixture_name", [ + "centered_pair_labels", + "centered_pair_predictions", + "min_labels", + "min_labels_slp", + "min_labels_robot" +]) +def test_export_csv(labels_fixture_name, tmpdir, request): + # Retrieve Labels fixture by name + labels_fixture = request.getfixturevalue(labels_fixture_name) + + # Generate the filename for the CSV file + csv_filename = Path(tmpdir) / (labels_fixture_name + "_export.csv") + + # Export to CSV file + labels_fixture.export_csv(str(csv_filename)) + + # Assert that the CSV file was created + assert csv_filename.is_file(), f"CSV file '{csv_filename}' was not created" + + +def test_exported_csv(tmpdir, min_labels_slp, minimal_instance_predictions_csv_path): + # Construct the filename for the CSV file + filename_csv = str(tmpdir + "\\analysis.csv") + labels = min_labels_slp + # Export to CSV file + labels.export_csv(filename_csv) + # Read the CSV file + labels_csv = pd.read_csv(filename_csv) + + # Read the csv file fixture + csv_predictions = pd.read_csv(minimal_instance_predictions_csv_path) + + assert labels_csv.equals(csv_predictions) + + # check number of cols + assert len(labels_csv.columns) - 3 == len(get_nodes_as_np_strings(labels)) * 3 + + + + +