diff --git a/tests/ignite/contrib/metrics/test_roc_auc.py b/tests/ignite/contrib/metrics/test_roc_auc.py index 9daf384939b1..215e7b01fddb 100644 --- a/tests/ignite/contrib/metrics/test_roc_auc.py +++ b/tests/ignite/contrib/metrics/test_roc_auc.py @@ -170,14 +170,11 @@ def get_test_cases(): def _test_distrib_binary_and_multilabel_inputs(device): rank = idist.get_rank() - torch.manual_seed(12) def _test(y_pred, y, batch_size, metric_device): metric_device = torch.device(metric_device) roc_auc = ROC_AUC(device=metric_device) - torch.manual_seed(10 + rank) - roc_auc.reset() if batch_size > 1: n_iters = y.shape[0] // batch_size + 1 @@ -215,7 +212,8 @@ def get_test_cases(): ] return test_cases - for _ in range(5): + for i in range(5): + torch.manual_seed(12 + rank + i) test_cases = get_test_cases() for y_pred, y, batch_size in test_cases: _test(y_pred, y, batch_size, "cpu") @@ -226,11 +224,9 @@ def get_test_cases(): def _test_distrib_integration_binary_input(device): rank = idist.get_rank() - torch.manual_seed(12) n_iters = 80 - s = 16 + batch_size = 16 n_classes = 2 - offset = n_iters * s def _test(y_preds, y_true, n_epochs, metric_device, update_fn): metric_device = torch.device(metric_device) @@ -243,6 +239,9 @@ def _test(y_preds, y_true, n_epochs, metric_device, update_fn): data = list(range(n_iters)) engine.run(data=data, max_epochs=n_epochs) + y_preds = idist.all_gather(y_preds) + y_true = idist.all_gather(y_true) + assert "roc_auc" in engine.state.metrics res = engine.state.metrics["roc_auc"] @@ -252,23 +251,23 @@ def _test(y_preds, y_true, n_epochs, metric_device, update_fn): def get_tests(is_N): if is_N: - y_true = torch.randint(0, n_classes, size=(offset * idist.get_world_size(),)).to(device) - y_preds = torch.rand(offset * idist.get_world_size()).to(device) + y_true = torch.randint(0, n_classes, size=(n_iters * batch_size,)).to(device) + y_preds = torch.rand(n_iters * batch_size).to(device) def update_fn(engine, i): return ( - y_preds[i * s + rank * offset : (i + 1) * s + rank * offset], - y_true[i * s + rank * offset : (i + 1) * s + rank * offset], + y_preds[i * batch_size : (i + 1) * batch_size], + y_true[i * batch_size : (i + 1) * batch_size], ) else: - y_true = torch.randint(0, n_classes, size=(offset * idist.get_world_size(), 10)).to(device) - y_preds = torch.rand(offset * idist.get_world_size(), 10).to(device) + y_true = torch.randint(0, n_classes, size=(n_iters * batch_size, 10)).to(device) + y_preds = torch.rand(n_iters * batch_size, 10).to(device) def update_fn(engine, i): return ( - y_preds[i * s + rank * offset : (i + 1) * s + rank * offset, :], - y_true[i * s + rank * offset : (i + 1) * s + rank * offset, :], + y_preds[i * batch_size : (i + 1) * batch_size], + y_true[i * batch_size : (i + 1) * batch_size], ) return y_preds, y_true, update_fn @@ -277,7 +276,8 @@ def update_fn(engine, i): if device.type != "xla": metric_devices.append(idist.device()) for metric_device in metric_devices: - for _ in range(2): + for i in range(2): + torch.manual_seed(12 + rank + i) # Binary input data of shape (N,) y_preds, y_true, update_fn = get_tests(is_N=True) _test(y_preds, y_true, n_epochs=1, metric_device=metric_device, update_fn=update_fn)