diff --git a/sleap_io/model/labels.py b/sleap_io/model/labels.py index 67dea5c..3328c44 100644 --- a/sleap_io/model/labels.py +++ b/sleap_io/model/labels.py @@ -660,13 +660,19 @@ def make_training_splits( labels.suggestions = [] labels.clean() - # Make splits. + # Make train split. labels_train, labels_rest = labels.split(n_train, seed=seed) + + # Make test split. if n_test is not None: if n_test < 1: n_test = (n_test * len(labels)) / len(labels_rest) labels_test, labels_rest = labels_rest.split(n=n_test, seed=seed) - if n_val is not None: + else: + labels_test = labels_rest + + # Make val split. + if n_val is not None or (isinstance(n_val, float) and n_val == 1.0): if n_val < 1: n_val = (n_val * len(labels)) / len(labels_rest) labels_val, _ = labels_rest.split(n=n_val, seed=seed)