Skip to content

Commit

Permalink
Fix deterministic group_shuffle_split (#1839)
Browse files Browse the repository at this point in the history
* order sets

* suggestion

* add unit test

* fix

* updated test

* fix

* indices from file

* test util update

* path

* no file

* no file

* comment

* i cannot spell
  • Loading branch information
nilsleh authored Feb 28, 2024
1 parent e04e1a5 commit 54b4ddc
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 15 deletions.
34 changes: 21 additions & 13 deletions tests/datamodules/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,11 @@ def test_dataset_split() -> None:


def test_group_shuffle_split() -> None:
alphabet = np.array(list("abcdefghijklmnopqrstuvwxyz"))
groups = np.random.randint(0, 26, size=(1000))
train_indices = [0, 2, 5, 6, 7, 8, 9, 10, 11, 13, 14]
test_indices = [1, 3, 4, 12]
np.random.seed(0)
alphabet = np.array(list("abc"))
groups = np.random.randint(0, 3, size=(15))
groups = alphabet[groups]

with pytest.raises(ValueError, match="You must specify `train_size` *"):
Expand All @@ -43,16 +46,21 @@ def test_group_shuffle_split() -> None:
match=re.escape("`train_size` and `test_size` must be in the range (0,1)."),
):
group_shuffle_split(groups, train_size=-0.2, test_size=1.2)
with pytest.raises(ValueError, match="26 groups were found, however the current *"):
with pytest.raises(ValueError, match="3 groups were found, however the current *"):
group_shuffle_split(groups, train_size=None, test_size=0.999)

train_indices, test_indices = group_shuffle_split(
groups, train_size=None, test_size=0.2
)
assert len(set(train_indices) & set(test_indices)) == 0
assert len(set(groups[train_indices])) == 21
train_indices, test_indices = group_shuffle_split(
groups, train_size=0.8, test_size=None
)
assert len(set(train_indices) & set(test_indices)) == 0
assert len(set(groups[train_indices])) == 21
test_cases = [(None, 0.2, 42), (0.8, None, 42)]

for train_size, test_size, random_state in test_cases:
train_indices1, test_indices1 = group_shuffle_split(
groups,
train_size=train_size,
test_size=test_size,
random_state=random_state,
)
# Check that the results are the same as expected
assert np.array_equal(train_indices, train_indices1)
assert np.array_equal(test_indices, test_indices1)

assert len(set(train_indices1) & set(test_indices1)) == 0
assert len(set(groups[train_indices1])) == 2
4 changes: 2 additions & 2 deletions torchgeo/datamodules/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def group_shuffle_split(
if train_size <= 0 or train_size >= 1 or test_size <= 0 or test_size >= 1:
raise ValueError("`train_size` and `test_size` must be in the range (0,1).")

group_vals = set(groups)
group_vals = sorted(set(groups))
n_groups = len(group_vals)
n_test_groups = round(n_groups * test_size)
n_train_groups = n_groups - n_test_groups
Expand All @@ -198,7 +198,7 @@ def group_shuffle_split(

generator = np.random.default_rng(seed=random_state)
train_group_vals = set(
generator.choice(list(group_vals), size=n_train_groups, replace=False)
generator.choice(group_vals, size=n_train_groups, replace=False)
)

train_idxs = []
Expand Down

0 comments on commit 54b4ddc

Please sign in to comment.