-
Notifications
You must be signed in to change notification settings - Fork 11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix split calculation and allow for not embedding #120
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -622,6 +622,7 @@ def make_training_splits( | |||||||||||||||||||||||||||||||||||||
n_test: int | float | None = None, | ||||||||||||||||||||||||||||||||||||||
save_dir: str | Path | None = None, | ||||||||||||||||||||||||||||||||||||||
seed: int | None = None, | ||||||||||||||||||||||||||||||||||||||
embed: bool = True, | ||||||||||||||||||||||||||||||||||||||
) -> tuple[Labels, Labels] | tuple[Labels, Labels, Labels]: | ||||||||||||||||||||||||||||||||||||||
"""Make splits for training with embedded images. | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
|
@@ -635,6 +636,10 @@ def make_training_splits( | |||||||||||||||||||||||||||||||||||||
test split will not be saved. | ||||||||||||||||||||||||||||||||||||||
save_dir: If specified, save splits to SLP files with embedded images. | ||||||||||||||||||||||||||||||||||||||
seed: Optional integer seed to use for reproducibility. | ||||||||||||||||||||||||||||||||||||||
embed: If `True` (the default), embed user labeled frame images in the saved | ||||||||||||||||||||||||||||||||||||||
files, which is useful for portability but can be slow for large | ||||||||||||||||||||||||||||||||||||||
projects. If `False`, labels are saved with references to the source | ||||||||||||||||||||||||||||||||||||||
videos files. | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
Returns: | ||||||||||||||||||||||||||||||||||||||
A tuple of `labels_train, labels_val` or | ||||||||||||||||||||||||||||||||||||||
|
@@ -652,6 +657,12 @@ def make_training_splits( | |||||||||||||||||||||||||||||||||||||
- `{save_dir}/val.pkg.slp` | ||||||||||||||||||||||||||||||||||||||
- `{save_dir}/test.pkg.slp` (if `n_test` is specified) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
If `embed` is `False`, the files will be saved without embedded images to: | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
- `{save_dir}/train.slp` | ||||||||||||||||||||||||||||||||||||||
- `{save_dir}/val.slp` | ||||||||||||||||||||||||||||||||||||||
- `{save_dir}/test.slp` (if `n_test` is specified) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
See also: `Labels.split` | ||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||
# Clean up labels. | ||||||||||||||||||||||||||||||||||||||
|
@@ -660,16 +671,23 @@ def make_training_splits( | |||||||||||||||||||||||||||||||||||||
labels.suggestions = [] | ||||||||||||||||||||||||||||||||||||||
labels.clean() | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
# Make splits. | ||||||||||||||||||||||||||||||||||||||
# Make train split. | ||||||||||||||||||||||||||||||||||||||
labels_train, labels_rest = labels.split(n_train, seed=seed) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
# Make test split. | ||||||||||||||||||||||||||||||||||||||
if n_test is not None: | ||||||||||||||||||||||||||||||||||||||
if n_test < 1: | ||||||||||||||||||||||||||||||||||||||
n_test = (n_test * len(labels)) / len(labels_rest) | ||||||||||||||||||||||||||||||||||||||
labels_test, labels_rest = labels_rest.split(n=n_test, seed=seed) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
# Make val split. | ||||||||||||||||||||||||||||||||||||||
if n_val is not None: | ||||||||||||||||||||||||||||||||||||||
if n_val < 1: | ||||||||||||||||||||||||||||||||||||||
n_val = (n_val * len(labels)) / len(labels_rest) | ||||||||||||||||||||||||||||||||||||||
labels_val, _ = labels_rest.split(n=n_val, seed=seed) | ||||||||||||||||||||||||||||||||||||||
if isinstance(n_val, float) and n_val == 1.0: | ||||||||||||||||||||||||||||||||||||||
labels_val = labels_rest | ||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||
labels_val, _ = labels_rest.split(n=n_val, seed=seed) | ||||||||||||||||||||||||||||||||||||||
Comment on lines
+683
to
+690
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Adjust the calculation of Similarly, the calculation of Update the calculation as follows to base if n_val < 1:
- n_val = (n_val * len(labels)) / len(labels_rest)
+ n_val = n_val * len(labels)
n_val = int(n_val)
if n_val >= len(labels_rest):
labels_val = labels_rest
else:
labels_val, _ = labels_rest.split(n=n_val, seed=seed) This ensures that
|
||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||
labels_val = labels_rest | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
|
@@ -678,9 +696,14 @@ def make_training_splits( | |||||||||||||||||||||||||||||||||||||
save_dir = Path(save_dir) | ||||||||||||||||||||||||||||||||||||||
save_dir.mkdir(exist_ok=True, parents=True) | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
labels_train.save(save_dir / "train.pkg.slp", embed="user") | ||||||||||||||||||||||||||||||||||||||
labels_val.save(save_dir / "val.pkg.slp", embed="user") | ||||||||||||||||||||||||||||||||||||||
labels_test.save(save_dir / "test.pkg.slp", embed="user") | ||||||||||||||||||||||||||||||||||||||
if embed: | ||||||||||||||||||||||||||||||||||||||
labels_train.save(save_dir / "train.pkg.slp", embed="user") | ||||||||||||||||||||||||||||||||||||||
labels_val.save(save_dir / "val.pkg.slp", embed="user") | ||||||||||||||||||||||||||||||||||||||
labels_test.save(save_dir / "test.pkg.slp", embed="user") | ||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||
labels_train.save(save_dir / "train.slp", embed=False) | ||||||||||||||||||||||||||||||||||||||
labels_val.save(save_dir / "val.slp", embed=False) | ||||||||||||||||||||||||||||||||||||||
labels_test.save(save_dir / "test.slp", embed=False) | ||||||||||||||||||||||||||||||||||||||
Comment on lines
+699
to
+706
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ensure If Modify the saving logic to include a condition for if save_dir is not None:
save_dir = Path(save_dir)
save_dir.mkdir(exist_ok=True, parents=True)
if embed:
labels_train.save(save_dir / "train.pkg.slp", embed="user")
labels_val.save(save_dir / "val.pkg.slp", embed="user")
+ if n_test is not None:
labels_test.save(save_dir / "test.pkg.slp", embed="user")
else:
labels_train.save(save_dir / "train.slp", embed=False)
labels_val.save(save_dir / "val.slp", embed=False)
+ if n_test is not None:
labels_test.save(save_dir / "test.slp", embed=False) This ensures that 📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
if n_test is None: | ||||||||||||||||||||||||||||||||||||||
return labels_train, labels_val | ||||||||||||||||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -537,7 +537,7 @@ def test_split(slp_real_data, tmp_path): | |||||
) | ||||||
|
||||||
|
||||||
def test_make_training_splits(slp_real_data, tmp_path): | ||||||
def test_make_training_splits(slp_real_data): | ||||||
labels = load_slp(slp_real_data) | ||||||
assert len(labels.user_labeled_frames) == 5 | ||||||
|
||||||
|
@@ -568,6 +568,11 @@ def test_make_training_splits(slp_real_data, tmp_path): | |||||
assert len(val) == 1 | ||||||
assert len(test) == 1 | ||||||
|
||||||
train, val, test = labels.make_training_splits(n_train=0.4, n_val=0.4, n_test=0.2) | ||||||
assert len(train) == 2 | ||||||
assert len(val) == 2 | ||||||
assert len(test) == 1 | ||||||
|
||||||
Comment on lines
+571
to
+575
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Potential rounding issues with split proportions Using floating-point numbers for Consider using integer values for |
||||||
|
||||||
def test_make_training_splits_save(slp_real_data, tmp_path): | ||||||
labels = load_slp(slp_real_data) | ||||||
|
@@ -587,3 +592,38 @@ def test_make_training_splits_save(slp_real_data, tmp_path): | |||||
assert train_.provenance["source_labels"] == slp_real_data | ||||||
assert val_.provenance["source_labels"] == slp_real_data | ||||||
assert test_.provenance["source_labels"] == slp_real_data | ||||||
|
||||||
|
||||||
@pytest.mark.parametrize("embed", [True, False]) | ||||||
def test_make_training_splits_save(slp_real_data, tmp_path, embed): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Redefinition of function The function Apply the following diff to rename the second function and avoid the redefinition: @pytest.mark.parametrize("embed", [True, False])
-def test_make_training_splits_save(slp_real_data, tmp_path, embed):
+def test_make_training_splits_save_with_embed(slp_real_data, tmp_path, embed): 📝 Committable suggestion
Suggested change
🧰 Tools🪛 Ruff
|
||||||
labels = load_slp(slp_real_data) | ||||||
|
||||||
train, val, test = labels.make_training_splits( | ||||||
0.6, 0.2, 0.2, save_dir=tmp_path, embed=embed | ||||||
) | ||||||
|
||||||
if embed: | ||||||
train_, val_, test_ = ( | ||||||
load_slp(tmp_path / "train.pkg.slp"), | ||||||
load_slp(tmp_path / "val.pkg.slp"), | ||||||
load_slp(tmp_path / "test.pkg.slp"), | ||||||
) | ||||||
else: | ||||||
train_, val_, test_ = ( | ||||||
load_slp(tmp_path / "train.slp"), | ||||||
load_slp(tmp_path / "val.slp"), | ||||||
load_slp(tmp_path / "test.slp"), | ||||||
) | ||||||
|
||||||
assert len(train_) == len(train) | ||||||
assert len(val_) == len(val) | ||||||
assert len(test_) == len(test) | ||||||
|
||||||
if embed: | ||||||
assert train_.provenance["source_labels"] == slp_real_data | ||||||
assert val_.provenance["source_labels"] == slp_real_data | ||||||
assert test_.provenance["source_labels"] == slp_real_data | ||||||
else: | ||||||
assert train_.video.filename == labels.video.filename | ||||||
assert val_.video.filename == labels.video.filename | ||||||
assert test_.video.filename == labels.video.filename |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Correct the calculation of
n_test
when it's a fraction.The current calculation of
n_test
whenn_test < 1
may not produce the intended number of test samples. Dividing bylen(labels_rest)
can lead to incorrect scaling, especially after the training split has been removed.Consider revising the calculation to correctly scale
n_test
based on the total number of labels:This adjustment ensures that
n_test
reflects the correct fraction of the total dataset.