Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reconstruct test_psnr.py #2679

Merged
merged 8 commits into from
Sep 5, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 117 additions & 98 deletions tests/ignite/metrics/test_psnr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from skimage.metrics import peak_signal_noise_ratio as ski_psnr

import ignite.distributed as idist
from ignite.engine import Engine
from ignite.exceptions import NotComputableError
from ignite.metrics import PSNR
from ignite.utils import manual_seed
Expand Down Expand Up @@ -85,26 +86,27 @@ def _test(
data_range,
metric_device,
n_iters,
s,
offset,
rank,
batch_size,
atol,
output_transform=lambda x: x,
compute_y_channel=False,
):
from ignite.engine import Engine

def update(engine, i):
return (
y_pred[i * s + offset * rank : (i + 1) * s + offset * rank],
y[i * s + offset * rank : (i + 1) * s + offset * rank],
y_pred[i * batch_size : (i + 1) * batch_size],
y[i * batch_size : (i + 1) * batch_size],
)

engine = Engine(update)
PSNR(data_range=data_range, output_transform=output_transform, device=metric_device).attach(engine, "psnr")

psnr = PSNR(data_range=data_range, output_transform=output_transform, device=metric_device)
psnr.attach(engine, "psnr")
data = list(range(n_iters))

engine.run(data=data, max_epochs=1)

y = idist.all_gather(y)
y_pred = idist.all_gather(y_pred)

result = engine.state.metrics["psnr"]
assert result > 0.0
assert "psnr" in engine.state.metrics
Expand All @@ -123,100 +125,96 @@ def update(engine, i):
assert np.allclose(result, np_psnr / np_y.shape[0], atol=atol)


def _test_distrib_input_float(device, atol=1e-8):
    """Distributed PSNR check on float NCHW-style inputs.

    Builds random float predictions with a scaled-down target (so PSNR is
    finite and positive) and delegates the actual metric comparison against
    skimage to the shared ``_test`` helper, on CPU and — when not running on
    XLA — on the current distributed device.

    Args:
        device: torch device the input tensors are created on.
        atol: absolute tolerance forwarded to ``_test`` for the numpy/skimage
            comparison.
    """

    def get_test_cases():
        # y is a scaled copy of y_pred so the two tensors never match exactly.
        y_pred = torch.rand(n_iters * batch_size, 2, 2, device=device)
        y = y_pred * 0.65

        return y_pred, y

    n_iters = 100
    batch_size = 10

    rank = idist.get_rank()
    for i in range(3):
        # check multiple random inputs as random exact occurrences are rare
        torch.manual_seed(42 + rank + i)
        y_pred, y = get_test_cases()
        _test(y_pred, y, 1, "cpu", n_iters, batch_size, atol=atol)
        if device.type != "xla":
            _test(y_pred, y, 1, idist.device(), n_iters, batch_size, atol=atol)
def _test_distrib_multilabel_input_YCbCr(device, atol=1e-8):
    """Distributed PSNR check on uint8 YCbCr inputs, Y channel only.

    Concatenates a luma (Y) plane with chroma (CbCr) planes into a 3-channel
    tensor, then uses ``output_transform`` to restrict the metric to channel 0
    so PSNR is computed on the Y channel alone (``compute_y_channel=True``
    makes the skimage reference do the same). data_range is 220 (values drawn
    from [16, 236)).

    Args:
        device: torch device the input tensors are created on.
        atol: absolute tolerance forwarded to ``_test``.
    """

    def get_test_cases():
        # Y plane in [16, 236), CbCr planes in [16, 241) — standard video ranges.
        y_pred = torch.randint(16, 236, (n_iters * batch_size, 1, 12, 12), dtype=torch.uint8, device=device)
        cbcr_pred = torch.randint(16, 241, (n_iters * batch_size, 2, 12, 12), dtype=torch.uint8, device=device)
        y = torch.randint(16, 236, (n_iters * batch_size, 1, 12, 12), dtype=torch.uint8, device=device)
        cbcr = torch.randint(16, 241, (n_iters * batch_size, 2, 12, 12), dtype=torch.uint8, device=device)

        y_pred, y = torch.cat((y_pred, cbcr_pred), dim=1), torch.cat((y, cbcr), dim=1)

        return y_pred, y

    n_iters = 100
    batch_size = 10

    def out_fn(x):
        # Keep only channel 0 (the Y plane) of both prediction and target.
        return x[0][:, 0, ...], x[1][:, 0, ...]

    rank = idist.get_rank()
    for i in range(3):
        # check multiple random inputs as random exact occurrences are rare
        torch.manual_seed(42 + rank + i)
        y_pred, y = get_test_cases()
        _test(y_pred, y, 220, "cpu", n_iters, batch_size, atol, output_transform=out_fn, compute_y_channel=True)
        if device.type != "xla":
            dev = idist.device()
            _test(y_pred, y, 220, dev, n_iters, batch_size, atol, output_transform=out_fn, compute_y_channel=True)


def _test_distrib_multilabel_input_uint8(device, atol=1e-8):
    """Distributed PSNR check on uint8 3-channel inputs (data_range=100).

    The target is a scaled-down uint8 copy of the prediction; the shared
    ``_test`` helper compares the metric against the skimage reference on
    CPU and — when not on XLA — on the current distributed device.

    Args:
        device: torch device the input tensors are created on.
        atol: absolute tolerance forwarded to ``_test``.
    """

    def get_test_cases():
        y_pred = torch.randint(0, 256, (n_iters * batch_size, 3, 16, 16), device=device, dtype=torch.uint8)
        # Scale then truncate back to uint8 so prediction and target differ.
        y = (y_pred * 0.65).to(torch.uint8)

        return y_pred, y

    n_iters = 100
    batch_size = 10

    rank = idist.get_rank()
    for i in range(3):
        # check multiple random inputs as random exact occurrences are rare
        torch.manual_seed(42 + rank + i)
        y_pred, y = get_test_cases()
        _test(y_pred, y, 100, "cpu", n_iters, batch_size, atol)
        if device.type != "xla":
            _test(y_pred, y, 100, idist.device(), n_iters, batch_size, atol)


def _test_distrib_multilabel_input_NHW(device, atol=1e-8):
    """Distributed PSNR check on channel-less NHW float inputs (data_range=10).

    Exercises the metric on 3-D (batch, height, width) tensors, i.e. without
    a channel dimension, via the shared ``_test`` helper on CPU and — when
    not on XLA — on the current distributed device.

    Args:
        device: torch device the input tensors are created on.
        atol: absolute tolerance forwarded to ``_test``.
    """

    def get_test_cases():
        # NHW shape: no channel axis.
        y_pred = torch.rand(n_iters * batch_size, 28, 28, device=device)
        y = y_pred * 0.8

        return y_pred, y

    n_iters = 100
    batch_size = 10

    rank = idist.get_rank()
    for i in range(3):
        # check multiple random inputs as random exact occurrences are rare
        torch.manual_seed(42 + rank + i)
        y_pred, y = get_test_cases()
        _test(y_pred, y, 10, "cpu", n_iters, batch_size, atol)
        if device.type != "xla":
            _test(y_pred, y, 10, idist.device(), n_iters, batch_size, atol)


def _test_distrib_accumulator_device(device):
Expand All @@ -242,7 +240,10 @@ def _test_distrib_accumulator_device(device):
def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo):
    # Run the full distributed PSNR suite under a single-node gloo context.
    device = idist.device()
    _test_distrib_input_float(device)
    _test_distrib_multilabel_input_YCbCr(device)
    _test_distrib_multilabel_input_uint8(device)
    _test_distrib_multilabel_input_NHW(device)
    _test_distrib_accumulator_device(device)


Expand All @@ -252,7 +253,10 @@ def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo):
def test_distrib_nccl_gpu(distributed_context_single_node_nccl):
    # Run the full distributed PSNR suite under a single-node nccl context.
    device = idist.device()
    _test_distrib_input_float(device)
    _test_distrib_multilabel_input_YCbCr(device)
    _test_distrib_multilabel_input_uint8(device)
    _test_distrib_multilabel_input_NHW(device)
    _test_distrib_accumulator_device(device)


Expand All @@ -262,7 +266,10 @@ def test_distrib_nccl_gpu(distributed_context_single_node_nccl):
def test_multinode_distrib_gloo_cpu_or_gpu(distributed_context_multi_node_gloo):
    # Run the full distributed PSNR suite under a multi-node gloo context.
    device = idist.device()
    _test_distrib_input_float(device)
    _test_distrib_multilabel_input_YCbCr(device)
    _test_distrib_multilabel_input_uint8(device)
    _test_distrib_multilabel_input_NHW(device)
    _test_distrib_accumulator_device(device)


Expand All @@ -272,7 +279,10 @@ def test_multinode_distrib_gloo_cpu_or_gpu(distributed_context_multi_node_gloo):
def test_multinode_distrib_nccl_gpu(distributed_context_multi_node_nccl):
    # Run the full distributed PSNR suite under a multi-node nccl context.
    device = idist.device()
    _test_distrib_input_float(device)
    _test_distrib_multilabel_input_YCbCr(device)
    _test_distrib_multilabel_input_uint8(device)
    _test_distrib_multilabel_input_NHW(device)
    _test_distrib_accumulator_device(device)


Expand All @@ -282,13 +292,19 @@ def test_multinode_distrib_nccl_gpu(distributed_context_multi_node_nccl):
def test_distrib_single_device_xla():
    # Run the full distributed PSNR suite on a single XLA device.
    device = idist.device()
    _test_distrib_input_float(device)
    _test_distrib_multilabel_input_YCbCr(device)
    _test_distrib_multilabel_input_uint8(device)
    _test_distrib_multilabel_input_NHW(device)
    _test_distrib_accumulator_device(device)


def _test_distrib_xla_nprocs(index):
    # Per-process entry point for multi-process XLA runs: execute the full suite.
    device = idist.device()
    _test_distrib_input_float(device)
    _test_distrib_multilabel_input_YCbCr(device)
    _test_distrib_multilabel_input_uint8(device)
    _test_distrib_multilabel_input_NHW(device)
    _test_distrib_accumulator_device(device)


Expand All @@ -307,5 +323,8 @@ def test_distrib_hvd(gloo_hvd_executor):
device = "cpu" if not torch.cuda.is_available() else "cuda"
nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count()

gloo_hvd_executor(_test_distrib_integration, (device,), np=nproc, do_init=True)
gloo_hvd_executor(_test_distrib_input_float, (device,), np=nproc, do_init=True)
gloo_hvd_executor(_test_distrib_multilabel_input_YCbCr, (device,), np=nproc, do_init=True)
gloo_hvd_executor(_test_distrib_multilabel_input_uint8, (device,), np=nproc, do_init=True)
gloo_hvd_executor(_test_distrib_multilabel_input_NHW, (device,), np=nproc, do_init=True)
gloo_hvd_executor(_test_distrib_accumulator_device, (device,), np=nproc, do_init=True)