Skip to content

Commit

Permalink
Update test_roc_auc.py (#2657)
Browse files Browse the repository at this point in the history
* Update test_roc_auc.py

Modify `_test_distrib_integration_binary_input` and `_test_distrib_binary_and_multilabel_inputs`

* Update test_roc_auc.py

* Update test_roc_auc.py

Change random seed to prevent generating same elements in list

Co-authored-by: vfdev <vfdev.5@gmail.com>
  • Loading branch information
puhuk and vfdev-5 committed Aug 22, 2022
1 parent 9b79ed0 commit 54e2406
Showing 1 changed file with 16 additions and 16 deletions.
32 changes: 16 additions & 16 deletions tests/ignite/contrib/metrics/test_roc_auc.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,14 +170,11 @@ def get_test_cases():
def _test_distrib_binary_and_multilabel_inputs(device):

rank = idist.get_rank()
torch.manual_seed(12)

def _test(y_pred, y, batch_size, metric_device):
metric_device = torch.device(metric_device)
roc_auc = ROC_AUC(device=metric_device)

torch.manual_seed(10 + rank)

roc_auc.reset()
if batch_size > 1:
n_iters = y.shape[0] // batch_size + 1
Expand Down Expand Up @@ -215,7 +212,8 @@ def get_test_cases():
]
return test_cases

for _ in range(5):
for i in range(5):
torch.manual_seed(12 + rank + i)
test_cases = get_test_cases()
for y_pred, y, batch_size in test_cases:
_test(y_pred, y, batch_size, "cpu")
Expand All @@ -226,11 +224,9 @@ def get_test_cases():
def _test_distrib_integration_binary_input(device):

rank = idist.get_rank()
torch.manual_seed(12)
n_iters = 80
s = 16
batch_size = 16
n_classes = 2
offset = n_iters * s

def _test(y_preds, y_true, n_epochs, metric_device, update_fn):
metric_device = torch.device(metric_device)
Expand All @@ -243,6 +239,9 @@ def _test(y_preds, y_true, n_epochs, metric_device, update_fn):
data = list(range(n_iters))
engine.run(data=data, max_epochs=n_epochs)

y_preds = idist.all_gather(y_preds)
y_true = idist.all_gather(y_true)

assert "roc_auc" in engine.state.metrics

res = engine.state.metrics["roc_auc"]
Expand All @@ -252,23 +251,23 @@ def _test(y_preds, y_true, n_epochs, metric_device, update_fn):

def get_tests(is_N):
if is_N:
y_true = torch.randint(0, n_classes, size=(offset * idist.get_world_size(),)).to(device)
y_preds = torch.rand(offset * idist.get_world_size()).to(device)
y_true = torch.randint(0, n_classes, size=(n_iters * batch_size,)).to(device)
y_preds = torch.rand(n_iters * batch_size).to(device)

def update_fn(engine, i):
return (
y_preds[i * s + rank * offset : (i + 1) * s + rank * offset],
y_true[i * s + rank * offset : (i + 1) * s + rank * offset],
y_preds[i * batch_size : (i + 1) * batch_size],
y_true[i * batch_size : (i + 1) * batch_size],
)

else:
y_true = torch.randint(0, n_classes, size=(offset * idist.get_world_size(), 10)).to(device)
y_preds = torch.rand(offset * idist.get_world_size(), 10).to(device)
y_true = torch.randint(0, n_classes, size=(n_iters * batch_size, 10)).to(device)
y_preds = torch.rand(n_iters * batch_size, 10).to(device)

def update_fn(engine, i):
return (
y_preds[i * s + rank * offset : (i + 1) * s + rank * offset, :],
y_true[i * s + rank * offset : (i + 1) * s + rank * offset, :],
y_preds[i * batch_size : (i + 1) * batch_size],
y_true[i * batch_size : (i + 1) * batch_size],
)

return y_preds, y_true, update_fn
Expand All @@ -277,7 +276,8 @@ def update_fn(engine, i):
if device.type != "xla":
metric_devices.append(idist.device())
for metric_device in metric_devices:
for _ in range(2):
for i in range(2):
torch.manual_seed(12 + rank + i)
# Binary input data of shape (N,)
y_preds, y_true, update_fn = get_tests(is_N=True)
_test(y_preds, y_true, n_epochs=1, metric_device=metric_device, update_fn=update_fn)
Expand Down

0 comments on commit 54e2406

Please sign in to comment.