Skip to content

Commit

Permalink
test export_csv method added to Labels class
Browse files Browse the repository at this point in the history
  • Loading branch information
eberrigan committed Apr 8, 2024
1 parent 821d114 commit 3eeb9fd
Showing 1 changed file with 46 additions and 0 deletions.
46 changes: 46 additions & 0 deletions tests/io/test_dataset.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import os
import pandas as pd
import pytest
import numpy as np
from pathlib import Path, PurePath

import sleap
from sleap.info.write_tracking_h5 import get_nodes_as_np_strings
from sleap.skeleton import Skeleton
from sleap.instance import Instance, Point, LabeledFrame, PredictedInstance, Track
from sleap.io.video import Video, MediaVideo
Expand Down Expand Up @@ -1559,3 +1561,47 @@ def test_export_nwb(centered_pair_predictions: Labels, tmpdir):
# Read from NWB file
read_labels = NDXPoseAdaptor.read(NDXPoseAdaptor, filehandle.FileHandle(filename))
assert_read_labels_match(centered_pair_predictions, read_labels)


@pytest.mark.parametrize("labels_fixture_name", [
"centered_pair_labels",
"centered_pair_predictions",
"min_labels",
"min_labels_slp",
"min_labels_robot"
])
def test_export_csv(labels_fixture_name, tmpdir, request):
# Retrieve Labels fixture by name
labels_fixture = request.getfixturevalue(labels_fixture_name)

# Generate the filename for the CSV file
csv_filename = Path(tmpdir) / (labels_fixture_name + "_export.csv")

# Export to CSV file
labels_fixture.export_csv(str(csv_filename))

# Assert that the CSV file was created
assert csv_filename.is_file(), f"CSV file '{csv_filename}' was not created"


def test_exported_csv(tmpdir, min_labels_slp, minimal_instance_predictions_csv_path):
# Construct the filename for the CSV file
filename_csv = str(tmpdir + "\\analysis.csv")
labels = min_labels_slp
# Export to CSV file
labels.export_csv(filename_csv)
# Read the CSV file
labels_csv = pd.read_csv(filename_csv)

# Read the csv file fixture
csv_predictions = pd.read_csv(minimal_instance_predictions_csv_path)

assert labels_csv.equals(csv_predictions)

# check number of cols
assert len(labels_csv.columns) - 3 == len(get_nodes_as_np_strings(labels)) * 3





0 comments on commit 3eeb9fd

Please sign in to comment.