Skip to content

Commit

Permalink
Update canberra and fractional absolute error (#2659)
Browse files Browse the repository at this point in the history
Update canberra and fractional absolute error

Co-authored-by: vfdev <vfdev.5@gmail.com>
  • Loading branch information
puhuk and vfdev-5 committed Aug 22, 2022
1 parent e2c95ec commit 9b79ed0
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 23 deletions.
23 changes: 12 additions & 11 deletions tests/ignite/contrib/metrics/regression/test_canberra_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ def _test_distrib_compute(device):
def _test(metric_device):
metric_device = torch.device(metric_device)
m = CanberraMetric(device=metric_device)
torch.manual_seed(10 + rank)

y_pred = torch.randint(0, 10, size=(10,), device=device).float()
y = torch.randint(0, 10, size=(10,), device=device).float()
Expand All @@ -125,7 +124,8 @@ def _test(metric_device):
res = m.compute()
assert canberra.pairwise([np_y_pred, np_y])[0][1] == pytest.approx(res)

for _ in range(3):
for i in range(3):
torch.manual_seed(10 + rank + i)
_test("cpu")
if device.type != "xla":
_test(idist.device())
Expand All @@ -134,23 +134,20 @@ def _test(metric_device):
def _test_distrib_integration(device):

rank = idist.get_rank()
torch.manual_seed(12)
canberra = DistanceMetric.get_metric("canberra")

def _test(n_epochs, metric_device):
metric_device = torch.device(metric_device)
n_iters = 80
s = 16
n_classes = 2
batch_size = 16

offset = n_iters * s
y_true = torch.rand(size=(offset * idist.get_world_size(),)).to(device)
y_preds = torch.rand(size=(offset * idist.get_world_size(),)).to(device)
y_true = torch.rand(size=(n_iters * batch_size,)).to(device)
y_preds = torch.rand(size=(n_iters * batch_size,)).to(device)

def update(engine, i):
return (
y_preds[i * s + rank * offset : (i + 1) * s + rank * offset],
y_true[i * s + rank * offset : (i + 1) * s + rank * offset],
y_preds[i * batch_size : (i + 1) * batch_size],
y_true[i * batch_size : (i + 1) * batch_size],
)

engine = Engine(update)
Expand All @@ -161,6 +158,9 @@ def update(engine, i):
data = list(range(n_iters))
engine.run(data=data, max_epochs=n_epochs)

y_preds = idist.all_gather(y_preds)
y_true = idist.all_gather(y_true)

assert "cm" in engine.state.metrics

res = engine.state.metrics["cm"]
Expand All @@ -176,7 +176,8 @@ def update(engine, i):
if device.type != "xla":
metric_devices.append(idist.device())
for metric_device in metric_devices:
for _ in range(2):
for i in range(2):
torch.manual_seed(12 + rank + i)
_test(n_epochs=1, metric_device=metric_device)
_test(n_epochs=2, metric_device=metric_device)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,10 @@ def get_test_cases():
def _test_distrib_compute(device):

rank = idist.get_rank()
torch.manual_seed(12)

def _test(metric_device):
metric_device = torch.device(metric_device)
m = FractionalAbsoluteError(device=metric_device)
torch.manual_seed(10 + rank)

y_pred = torch.rand(size=(100,), device=device)
y = torch.rand(size=(100,), device=device)
Expand All @@ -129,7 +127,8 @@ def _test(metric_device):

assert m.compute() == pytest.approx(np_ans)

for _ in range(3):
for i in range(3):
torch.manual_seed(10 + rank + i)
_test("cpu")
if device.type != "xla":
_test(idist.device())
Expand All @@ -138,22 +137,19 @@ def _test(metric_device):
def _test_distrib_integration(device):

rank = idist.get_rank()
torch.manual_seed(12)

def _test(n_epochs, metric_device):
metric_device = torch.device(metric_device)
n_iters = 80
s = 16
n_classes = 2
batch_size = 16

offset = n_iters * s
y_true = torch.rand(size=(offset * idist.get_world_size(),)).to(device)
y_preds = torch.rand(size=(offset * idist.get_world_size(),)).to(device)
y_true = torch.rand(size=(n_iters * batch_size,)).to(device)
y_preds = torch.rand(size=(n_iters * batch_size,)).to(device)

def update(engine, i):
return (
y_preds[i * s + rank * offset : (i + 1) * s + rank * offset],
y_true[i * s + rank * offset : (i + 1) * s + rank * offset],
y_preds[i * batch_size : (i + 1) * batch_size],
y_true[i * batch_size : (i + 1) * batch_size],
)

engine = Engine(update)
Expand All @@ -164,6 +160,9 @@ def update(engine, i):
data = list(range(n_iters))
engine.run(data=data, max_epochs=n_epochs)

y_preds = idist.all_gather(y_preds)
y_true = idist.all_gather(y_true)

assert "fae" in engine.state.metrics

res = engine.state.metrics["fae"]
Expand All @@ -183,7 +182,8 @@ def update(engine, i):
if device.type != "xla":
metric_devices.append(idist.device())
for metric_device in metric_devices:
for _ in range(2):
for i in range(2):
torch.manual_seed(12 + rank + i)
_test(n_epochs=1, metric_device=metric_device)
_test(n_epochs=2, metric_device=metric_device)

Expand Down

0 comments on commit 9b79ed0

Please sign in to comment.