diff --git a/sleap_nn/config/__init__.py b/sleap_nn/config/__init__.py deleted file mode 100644 index 16afead9..00000000 --- a/sleap_nn/config/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Modules relating to configuring data pipelines.""" diff --git a/sleap_nn/config/data.py b/sleap_nn/config/data.py deleted file mode 100644 index c8f81143..00000000 --- a/sleap_nn/config/data.py +++ /dev/null @@ -1,14 +0,0 @@ -"""This module implements base configurations for data pipelines.""" - -from omegaconf import OmegaConf - -# Base TopDownConfmapsPipeline data config. -base_topdown_data_config = OmegaConf.create( - { - "preprocessing": { - "crop_hw": (160, 160), - "conf_map_gen": {"sigma": 1.5, "output_stride": 2}, - }, - "augmentation_config": {"random_crop": 0.0, "random_crop_hw": (160, 160)}, - } -) diff --git a/tests/data/test_pipelines.py b/tests/data/test_pipelines.py index 023bafb9..feb6da6e 100644 --- a/tests/data/test_pipelines.py +++ b/tests/data/test_pipelines.py @@ -1,6 +1,6 @@ import torch +from omegaconf import OmegaConf -from sleap_nn.config.data import base_topdown_data_config from sleap_nn.data.confidence_maps import ConfidenceMapGenerator from sleap_nn.data.instance_centroids import InstanceCentroidFinder from sleap_nn.data.instance_cropping import InstanceCropper @@ -24,6 +24,16 @@ def test_sleap_dataset(minimal_instance): def test_topdownconfmapspipeline(minimal_instance): + base_topdown_data_config = OmegaConf.create( + { + "preprocessing": { + "crop_hw": (160, 160), + "conf_map_gen": {"sigma": 1.5, "output_stride": 2}, + }, + "augmentation_config": {"random_crop": 0.0, "random_crop_hw": (160, 160)}, + } + ) + pipeline = TopdownConfmapsPipeline(data_config=base_topdown_data_config) datapipe = pipeline.make_base_pipeline( data_provider=LabelsReader, filename=minimal_instance @@ -32,4 +42,4 @@ def test_topdownconfmapspipeline(minimal_instance): sample = next(iter(datapipe)) assert len(sample) == 2 assert sample[0].shape == (1, 160, 160) - assert sample[1].shape == (2, 80, 80) \ No newline at end of file + assert sample[1].shape == (2, 80, 80)