Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix split calculation and allow for not embedding #120

Merged
merged 1 commit into from
Sep 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sleap_io/io/slp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1112,7 +1112,7 @@ def write_labels(
"""
if Path(labels_path).exists():
Path(labels_path).unlink()
if embed is not None:
if embed:
embed_videos(labels_path, labels, embed)
write_videos(labels_path, labels.videos, restore_source=(embed == "source"))
write_tracks(labels_path, labels.tracks)
Expand Down
33 changes: 28 additions & 5 deletions sleap_io/model/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,7 @@ def make_training_splits(
n_test: int | float | None = None,
save_dir: str | Path | None = None,
seed: int | None = None,
embed: bool = True,
) -> tuple[Labels, Labels] | tuple[Labels, Labels, Labels]:
"""Make splits for training with embedded images.

Expand All @@ -635,6 +636,10 @@ def make_training_splits(
test split will not be saved.
save_dir: If specified, save splits to SLP files with embedded images.
seed: Optional integer seed to use for reproducibility.
embed: If `True` (the default), embed user labeled frame images in the saved
files, which is useful for portability but can be slow for large
projects. If `False`, labels are saved with references to the source
videos files.

Returns:
A tuple of `labels_train, labels_val` or
Expand All @@ -652,6 +657,12 @@ def make_training_splits(
- `{save_dir}/val.pkg.slp`
- `{save_dir}/test.pkg.slp` (if `n_test` is specified)

If `embed` is `False`, the files will be saved without embedded images to:

- `{save_dir}/train.slp`
- `{save_dir}/val.slp`
- `{save_dir}/test.slp` (if `n_test` is specified)

See also: `Labels.split`
"""
# Clean up labels.
Expand All @@ -660,16 +671,23 @@ def make_training_splits(
labels.suggestions = []
labels.clean()

# Make splits.
# Make train split.
labels_train, labels_rest = labels.split(n_train, seed=seed)

# Make test split.
if n_test is not None:
if n_test < 1:
n_test = (n_test * len(labels)) / len(labels_rest)
labels_test, labels_rest = labels_rest.split(n=n_test, seed=seed)

Comment on lines +677 to +682
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Correct the calculation of n_test when it's a fraction.

The current calculation of n_test when n_test < 1 may not produce the intended number of test samples. Dividing by len(labels_rest) can lead to incorrect scaling, especially after the training split has been removed.

Consider revising the calculation to correctly scale n_test based on the total number of labels:

if n_test < 1:
-    n_test = (n_test * len(labels)) / len(labels_rest)
+    n_test = n_test * len(labels)
n_test = int(n_test)
labels_test, labels_rest = labels_rest.split(n=n_test, seed=seed)

This adjustment ensures that n_test reflects the correct fraction of the total dataset.

Committable suggestion was skipped due to low confidence.

# Make val split.
if n_val is not None:
if n_val < 1:
n_val = (n_val * len(labels)) / len(labels_rest)
labels_val, _ = labels_rest.split(n=n_val, seed=seed)
if isinstance(n_val, float) and n_val == 1.0:
labels_val = labels_rest
else:
labels_val, _ = labels_rest.split(n=n_val, seed=seed)
Comment on lines +683 to +690
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Adjust the calculation of n_val when it's a fraction.

Similarly, the calculation of n_val may not yield the correct number of validation samples. Using len(labels_rest) in the scaling can cause inconsistencies, especially if n_test has altered the size of labels_rest.

Update the calculation as follows to base n_val on the total number of labels:

if n_val < 1:
-    n_val = (n_val * len(labels)) / len(labels_rest)
+    n_val = n_val * len(labels)
n_val = int(n_val)
if n_val >= len(labels_rest):
    labels_val = labels_rest
else:
    labels_val, _ = labels_rest.split(n=n_val, seed=seed)

This ensures that n_val represents the intended fraction of the entire dataset.

Committable suggestion was skipped due to low confidence.

else:
labels_val = labels_rest

Expand All @@ -678,9 +696,14 @@ def make_training_splits(
save_dir = Path(save_dir)
save_dir.mkdir(exist_ok=True, parents=True)

labels_train.save(save_dir / "train.pkg.slp", embed="user")
labels_val.save(save_dir / "val.pkg.slp", embed="user")
labels_test.save(save_dir / "test.pkg.slp", embed="user")
if embed:
labels_train.save(save_dir / "train.pkg.slp", embed="user")
labels_val.save(save_dir / "val.pkg.slp", embed="user")
labels_test.save(save_dir / "test.pkg.slp", embed="user")
else:
labels_train.save(save_dir / "train.slp", embed=False)
labels_val.save(save_dir / "val.slp", embed=False)
labels_test.save(save_dir / "test.slp", embed=False)
Comment on lines +699 to +706
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Ensure labels_test is defined before attempting to save.

If n_test is None, labels_test may not be defined, leading to a NameError when trying to save it. The code should check whether labels_test exists before saving.

Modify the saving logic to include a condition for labels_test:

if save_dir is not None:
    save_dir = Path(save_dir)
    save_dir.mkdir(exist_ok=True, parents=True)

    if embed:
        labels_train.save(save_dir / "train.pkg.slp", embed="user")
        labels_val.save(save_dir / "val.pkg.slp", embed="user")
+       if n_test is not None:
            labels_test.save(save_dir / "test.pkg.slp", embed="user")
    else:
        labels_train.save(save_dir / "train.slp", embed=False)
        labels_val.save(save_dir / "val.slp", embed=False)
+       if n_test is not None:
            labels_test.save(save_dir / "test.slp", embed=False)

This ensures that labels_test.save() is only called when labels_test is defined.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if embed:
labels_train.save(save_dir / "train.pkg.slp", embed="user")
labels_val.save(save_dir / "val.pkg.slp", embed="user")
labels_test.save(save_dir / "test.pkg.slp", embed="user")
else:
labels_train.save(save_dir / "train.slp", embed=False)
labels_val.save(save_dir / "val.slp", embed=False)
labels_test.save(save_dir / "test.slp", embed=False)
if embed:
labels_train.save(save_dir / "train.pkg.slp", embed="user")
labels_val.save(save_dir / "val.pkg.slp", embed="user")
if n_test is not None:
labels_test.save(save_dir / "test.pkg.slp", embed="user")
else:
labels_train.save(save_dir / "train.slp", embed=False)
labels_val.save(save_dir / "val.slp", embed=False)
if n_test is not None:
labels_test.save(save_dir / "test.slp", embed=False)


if n_test is None:
return labels_train, labels_val
Expand Down
42 changes: 41 additions & 1 deletion tests/model/test_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ def test_split(slp_real_data, tmp_path):
)


def test_make_training_splits(slp_real_data, tmp_path):
def test_make_training_splits(slp_real_data):
labels = load_slp(slp_real_data)
assert len(labels.user_labeled_frames) == 5

Expand Down Expand Up @@ -568,6 +568,11 @@ def test_make_training_splits(slp_real_data, tmp_path):
assert len(val) == 1
assert len(test) == 1

train, val, test = labels.make_training_splits(n_train=0.4, n_val=0.4, n_test=0.2)
assert len(train) == 2
assert len(val) == 2
assert len(test) == 1

Comment on lines +571 to +575
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Potential rounding issues with split proportions

Using floating-point numbers for n_train=0.4, n_val=0.4, and n_test=0.2 may lead to unexpected split sizes due to floating-point precision errors. Ensure that the sum of these proportions equals the total number of frames in the dataset.

Consider using integer values for n_train, n_val, and n_test or adjusting the split logic to handle floating-point precision.


def test_make_training_splits_save(slp_real_data, tmp_path):
labels = load_slp(slp_real_data)
Expand All @@ -587,3 +592,38 @@ def test_make_training_splits_save(slp_real_data, tmp_path):
assert train_.provenance["source_labels"] == slp_real_data
assert val_.provenance["source_labels"] == slp_real_data
assert test_.provenance["source_labels"] == slp_real_data


@pytest.mark.parametrize("embed", [True, False])
def test_make_training_splits_save(slp_real_data, tmp_path, embed):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Redefinition of function test_make_training_splits_save

The function test_make_training_splits_save is redefined at line 598, which duplicates the previous definition and will cause a NameError.

Apply the following diff to rename the second function and avoid the redefinition:

 @pytest.mark.parametrize("embed", [True, False])
-def test_make_training_splits_save(slp_real_data, tmp_path, embed):
+def test_make_training_splits_save_with_embed(slp_real_data, tmp_path, embed):
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def test_make_training_splits_save(slp_real_data, tmp_path, embed):
def test_make_training_splits_save_with_embed(slp_real_data, tmp_path, embed):
🧰 Tools
🪛 Ruff

598-598: Redefinition of unused test_make_training_splits_save from line 577

(F811)

labels = load_slp(slp_real_data)

train, val, test = labels.make_training_splits(
0.6, 0.2, 0.2, save_dir=tmp_path, embed=embed
)

if embed:
train_, val_, test_ = (
load_slp(tmp_path / "train.pkg.slp"),
load_slp(tmp_path / "val.pkg.slp"),
load_slp(tmp_path / "test.pkg.slp"),
)
else:
train_, val_, test_ = (
load_slp(tmp_path / "train.slp"),
load_slp(tmp_path / "val.slp"),
load_slp(tmp_path / "test.slp"),
)

assert len(train_) == len(train)
assert len(val_) == len(val)
assert len(test_) == len(test)

if embed:
assert train_.provenance["source_labels"] == slp_real_data
assert val_.provenance["source_labels"] == slp_real_data
assert test_.provenance["source_labels"] == slp_real_data
else:
assert train_.video.filename == labels.video.filename
assert val_.video.filename == labels.video.filename
assert test_.video.filename == labels.video.filename
Loading