From 1385de5d497e93e9cb88bc68d95a49453fa912fc Mon Sep 17 00:00:00 2001 From: Talmo Pereira Date: Sat, 28 Sep 2024 18:57:08 -0700 Subject: [PATCH] Fix split calculation and allow for not embedding --- sleap_io/io/slp.py | 2 +- sleap_io/model/labels.py | 33 +++++++++++++++++++++++++----- tests/model/test_labels.py | 42 +++++++++++++++++++++++++++++++++++++- 3 files changed, 70 insertions(+), 7 deletions(-) diff --git a/sleap_io/io/slp.py b/sleap_io/io/slp.py index 777a3f98..5ae23025 100644 --- a/sleap_io/io/slp.py +++ b/sleap_io/io/slp.py @@ -1112,7 +1112,7 @@ def write_labels( """ if Path(labels_path).exists(): Path(labels_path).unlink() - if embed is not None: + if embed: embed_videos(labels_path, labels, embed) write_videos(labels_path, labels.videos, restore_source=(embed == "source")) write_tracks(labels_path, labels.tracks) diff --git a/sleap_io/model/labels.py b/sleap_io/model/labels.py index 67dea5cf..bf6a6795 100644 --- a/sleap_io/model/labels.py +++ b/sleap_io/model/labels.py @@ -622,6 +622,7 @@ def make_training_splits( n_test: int | float | None = None, save_dir: str | Path | None = None, seed: int | None = None, + embed: bool = True, ) -> tuple[Labels, Labels] | tuple[Labels, Labels, Labels]: """Make splits for training with embedded images. @@ -635,6 +636,10 @@ def make_training_splits( test split will not be saved. save_dir: If specified, save splits to SLP files with embedded images. seed: Optional integer seed to use for reproducibility. + embed: If `True` (the default), embed user labeled frame images in the saved + files, which is useful for portability but can be slow for large + projects. If `False`, labels are saved with references to the source + videos files. Returns: A tuple of `labels_train, labels_val` or @@ -652,6 +657,12 @@ def make_training_splits( - `{save_dir}/val.pkg.slp` - `{save_dir}/test.pkg.slp` (if `n_test` is specified) + If `embed` is `False`, the files will be saved without embedded images to: + + - `{save_dir}/train.slp` + - `{save_dir}/val.slp` + - `{save_dir}/test.slp` (if `n_test` is specified) + See also: `Labels.split` """ # Clean up labels. @@ -660,16 +671,23 @@ def make_training_splits( labels.suggestions = [] labels.clean() - # Make splits. + # Make train split. labels_train, labels_rest = labels.split(n_train, seed=seed) + + # Make test split. if n_test is not None: if n_test < 1: n_test = (n_test * len(labels)) / len(labels_rest) labels_test, labels_rest = labels_rest.split(n=n_test, seed=seed) + + # Make val split. if n_val is not None: if n_val < 1: n_val = (n_val * len(labels)) / len(labels_rest) - labels_val, _ = labels_rest.split(n=n_val, seed=seed) + if isinstance(n_val, float) and n_val == 1.0: + labels_val = labels_rest + else: + labels_val, _ = labels_rest.split(n=n_val, seed=seed) else: labels_val = labels_rest @@ -678,9 +696,14 @@ def make_training_splits( save_dir = Path(save_dir) save_dir.mkdir(exist_ok=True, parents=True) - labels_train.save(save_dir / "train.pkg.slp", embed="user") - labels_val.save(save_dir / "val.pkg.slp", embed="user") - labels_test.save(save_dir / "test.pkg.slp", embed="user") + if embed: + labels_train.save(save_dir / "train.pkg.slp", embed="user") + labels_val.save(save_dir / "val.pkg.slp", embed="user") + labels_test.save(save_dir / "test.pkg.slp", embed="user") + else: + labels_train.save(save_dir / "train.slp", embed=False) + labels_val.save(save_dir / "val.slp", embed=False) + labels_test.save(save_dir / "test.slp", embed=False) if n_test is None: return labels_train, labels_val diff --git a/tests/model/test_labels.py b/tests/model/test_labels.py index 067a1d30..f8403be7 100644 --- a/tests/model/test_labels.py +++ b/tests/model/test_labels.py @@ -537,7 +537,7 @@ def test_split(slp_real_data, tmp_path): ) -def test_make_training_splits(slp_real_data, tmp_path): +def test_make_training_splits(slp_real_data): labels = load_slp(slp_real_data) assert len(labels.user_labeled_frames) == 5 @@ -568,6 +568,11 @@ def test_make_training_splits(slp_real_data, tmp_path): assert len(val) == 1 assert len(test) == 1 + train, val, test = labels.make_training_splits(n_train=0.4, n_val=0.4, n_test=0.2) + assert len(train) == 2 + assert len(val) == 2 + assert len(test) == 1 + def test_make_training_splits_save(slp_real_data, tmp_path): labels = load_slp(slp_real_data) @@ -587,3 +592,38 @@ def test_make_training_splits_save(slp_real_data, tmp_path): assert train_.provenance["source_labels"] == slp_real_data assert val_.provenance["source_labels"] == slp_real_data assert test_.provenance["source_labels"] == slp_real_data + + +@pytest.mark.parametrize("embed", [True, False]) +def test_make_training_splits_save(slp_real_data, tmp_path, embed): + labels = load_slp(slp_real_data) + + train, val, test = labels.make_training_splits( + 0.6, 0.2, 0.2, save_dir=tmp_path, embed=embed + ) + + if embed: + train_, val_, test_ = ( + load_slp(tmp_path / "train.pkg.slp"), + load_slp(tmp_path / "val.pkg.slp"), + load_slp(tmp_path / "test.pkg.slp"), + ) + else: + train_, val_, test_ = ( + load_slp(tmp_path / "train.slp"), + load_slp(tmp_path / "val.slp"), + load_slp(tmp_path / "test.slp"), + ) + + assert len(train_) == len(train) + assert len(val_) == len(val) + assert len(test_) == len(test) + + if embed: + assert train_.provenance["source_labels"] == slp_real_data + assert val_.provenance["source_labels"] == slp_real_data + assert test_.provenance["source_labels"] == slp_real_data + else: + assert train_.video.filename == labels.video.filename + assert val_.video.filename == labels.video.filename + assert test_.video.filename == labels.video.filename