Skip to content

Commit

Permalink
lint test, fix bug in test where TrackingDataset.prepare isn't called
Browse files Browse the repository at this point in the history
  • Loading branch information
aaprasad committed Aug 30, 2024
1 parent 0677ccd commit f8fdcec
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,7 @@ def test_train_test_split(two_flies):

splits = (0.7, 0.2, 0.1)
tracking_ds = TrackingDataset(train_ds=train_ds, splits=splits)
tracking_ds.setup("fit")

assert len(tracking_ds.train_ds) == int(ds_length * splits[0])
assert len(tracking_ds.val_ds) == int(ds_length * splits[1])
Expand All @@ -544,6 +545,8 @@ def test_train_test_split(two_flies):
assert tracking_ds.test_ds.dataset.augmentations is None

tracking_ds = TrackingDataset(train_ds=train_ds, splits=splits[:-1])
tracking_ds.setup("fit")

assert len(tracking_ds.train_ds) == int(ds_length * splits[0])
assert len(tracking_ds.val_ds) == int(ds_length * splits[1])
assert tracking_ds.test_ds is None
Expand All @@ -558,12 +561,14 @@ def test_train_test_split(two_flies):
ds_length = len(val_ds)

splits = (0.5, 0.5)

tracking_ds = TrackingDataset(train_ds=train_ds, val_ds=val_ds, splits=splits)
tracking_ds.setup("fit")

assert len(tracking_ds.train_ds) == len(train_ds)
assert len(tracking_ds.val_ds) == int(ds_length * splits[0])
assert len(tracking_ds.test_ds) == int(ds_length * splits[1])

assert tracking_ds.train_ds.augmentations is not None
assert tracking_ds.val_ds.dataset.augmentations is None
assert tracking_ds.test_ds.dataset.augmentations is None
assert tracking_ds.test_ds.dataset.augmentations is None

0 comments on commit f8fdcec

Please sign in to comment.