Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[metrics] Update SSIM #4566

Merged
merged 30 commits into from
Nov 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
51870e3
[metrics] Update SSIM
Nov 7, 2020
e582d33
Merge remote-tracking branch 'origin/master' into metrics/ssim
Nov 7, 2020
6e122df
Merge remote-tracking branch 'origin/master' into metrics/ssim
Nov 9, 2020
a800ab1
Merge remote-tracking branch 'origin/master' into metrics/ssim
Nov 9, 2020
12ede93
Merge remote-tracking branch 'origin/master' into metrics/ssim
Nov 9, 2020
f5fbd5f
[metrics] Update SSIM
Nov 9, 2020
5b3f2f0
Merge remote-tracking branch 'origin/master' into metrics/ssim
Nov 9, 2020
a06ffa2
[metrics] Update SSIM
Nov 9, 2020
89ca5ce
[metrics] Update SSIM
Nov 10, 2020
41c0917
Merge remote-tracking branch 'origin/master' into metrics/ssim
Nov 10, 2020
2acfae9
[metrics] update ssim
Nov 10, 2020
ed2a992
dist_sync_on_step True
Nov 10, 2020
f4b2f65
Merge remote-tracking branch 'origin/master' into metrics/ssim
Nov 10, 2020
72d7518
[metrics] update ssim
Nov 10, 2020
70c366c
Merge branch 'master' into metrics/ssim
Nov 12, 2020
2b4cbbb
Merge branch 'master' into metrics/ssim
Nov 12, 2020
3ea91a4
Merge branch 'master' into metrics/ssim
Nov 14, 2020
c89bfb8
Merge branch 'master' into metrics/ssim
tchaton Nov 16, 2020
4fd0e2e
Update tests/metrics/regression/test_ssim.py
Nov 16, 2020
543162d
Merge branch 'master' into metrics/ssim
tchaton Nov 16, 2020
5b8dcc8
Merge branch 'master' into metrics/ssim
SeanNaren Nov 16, 2020
af29e9b
Update pytorch_lightning/metrics/functional/ssim.py
Nov 17, 2020
d1f5eff
Merge branch 'master' into metrics/ssim
Nov 17, 2020
350faca
ddp=True
Nov 17, 2020
a653689
Update test_ssim.py
Nov 17, 2020
3b560fd
Merge branch 'master' into metrics/ssim
Nov 18, 2020
7f2721f
Merge branch 'master' into metrics/ssim
Nov 18, 2020
b9289c1
Merge branch 'master' into metrics/ssim
SkafteNicki Nov 18, 2020
65a361e
Merge branch 'master' into metrics/ssim
SeanNaren Nov 18, 2020
9394d81
Merge branch 'master' into metrics/ssim
Nov 19, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 19 additions & 12 deletions pytorch_lightning/metrics/functional/ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,17 @@
from torch.nn import functional as F


def _gaussian_kernel(channel, kernel_size, sigma, device):
def _gaussian(kernel_size, sigma, device):
gauss = torch.arange(
start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, step=1, dtype=torch.float32, device=device
)
gauss = torch.exp(-gauss.pow(2) / (2 * pow(sigma, 2)))
return (gauss / gauss.sum()).unsqueeze(dim=0) # (1, kernel_size)
def _gaussian(kernel_size: int, sigma: int, dtype: torch.dtype, device: torch.device):
dist = torch.arange(start=(1 - kernel_size) / 2, end=(1 + kernel_size) / 2, step=1, dtype=dtype, device=device)
gauss = torch.exp(-torch.pow(dist / sigma, 2) / 2)
return (gauss / gauss.sum()).unsqueeze(dim=0) # (1, kernel_size)


gaussian_kernel_x = _gaussian(kernel_size[0], sigma[0], device)
gaussian_kernel_y = _gaussian(kernel_size[1], sigma[1], device)
kernel = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y)
def _gaussian_kernel(channel: int, kernel_size: Sequence[int], sigma: Sequence[float],
                     dtype: torch.dtype, device: torch.device):
    """Build a separable 2D Gaussian kernel for grouped convolution.

    The kernel is the outer product of two 1D Gaussian windows (one per spatial
    axis) and is expanded — without copying — to shape
    ``(channel, 1, kernel_size[0], kernel_size[1])`` so it can be used with
    ``F.conv2d(..., groups=channel)``.
    """
    win_x = _gaussian(kernel_size[0], sigma[0], dtype, device)  # (1, kernel_size[0])
    win_y = _gaussian(kernel_size[1], sigma[1], dtype, device)  # (1, kernel_size[1])
    # Outer product: (kernel_size[0], 1) @ (1, kernel_size[1])
    kernel_2d = win_x.t().matmul(win_y)

    return kernel_2d.expand(channel, 1, kernel_size[0], kernel_size[1])

Expand Down Expand Up @@ -82,9 +82,15 @@ def _ssim_compute(
device = preds.device

channel = preds.size(1)
kernel = _gaussian_kernel(channel, kernel_size, sigma, device)
dtype = preds.dtype
kernel = _gaussian_kernel(channel, kernel_size, sigma, dtype, device)
pad_w = (kernel_size[0] - 1) // 2
pad_h = (kernel_size[1] - 1) // 2

preds = F.pad(preds, (pad_w, pad_w, pad_h, pad_h), mode='reflect')
target = F.pad(target, (pad_w, pad_w, pad_h, pad_h), mode='reflect')

input_list = torch.cat([preds, target, preds * preds, target * target, preds * target]) # (5 * B, C, H, W)
input_list = torch.cat((preds, target, preds * preds, target * target, preds * target)) # (5 * B, C, H, W)
outputs = F.conv2d(input_list, kernel, groups=channel)
output_list = [outputs[x * preds.size(0): (x + 1) * preds.size(0)] for x in range(len(outputs))]

Expand All @@ -100,6 +106,7 @@ def _ssim_compute(
lower = sigma_pred_sq + sigma_target_sq + c2

ssim_idx = ((2 * mu_pred_target + c1) * upper) / ((mu_pred_sq + mu_target_sq + c1) * lower)
ssim_idx = ssim_idx[..., pad_h:-pad_h, pad_w:-pad_w]

return reduce(ssim_idx, reduction)

Expand Down
17 changes: 9 additions & 8 deletions tests/metrics/regression/test_ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@


_inputs = []
for size, channel, coef, multichannel in [
(16, 1, 0.9, False),
(32, 3, 0.8, True),
(48, 4, 0.7, True),
(64, 5, 0.6, True),
for size, channel, coef, multichannel, dtype in [
(12, 3, 0.9, True, torch.float),
(13, 1, 0.8, False, torch.float32),
(14, 1, 0.7, False, torch.double),
(15, 3, 0.6, True, torch.float64),
]:
preds = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size)
preds = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, dtype=dtype)
_inputs.append(
Input(
preds=preds,
Expand All @@ -41,7 +41,8 @@ def _sk_metric(preds, target, data_range, multichannel):
sk_target = sk_target[:, :, :, 0]

return structural_similarity(
sk_target, sk_preds, data_range=data_range, multichannel=multichannel, gaussian_weights=True, win_size=11
sk_target, sk_preds, data_range=data_range, multichannel=multichannel,
gaussian_weights=True, win_size=11, sigma=1.5, use_sample_covariance=False
)


Expand All @@ -50,7 +51,7 @@ def _sk_metric(preds, target, data_range, multichannel):
[(i.preds, i.target, i.multichannel) for i in _inputs],
)
class TestSSIM(MetricTester):
atol = 1e-3 # TODO: ideally tests should pass with lower tolerance
atol = 6e-5

# TODO: for some reason this test hangs with ddp=True
# @pytest.mark.parametrize("ddp", [True, False])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do tests pass with ddp=True?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test hangs with ddp=True

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ydcjeff I just tried locally and the test does not hang for me when ddp=True.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When I tried before, the test hangs on drone, will try again.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay test still hangs on drone, so lets just disable ddp. Really not sure why it is not working.

Expand Down
2 changes: 1 addition & 1 deletion tests/metrics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def setup_ddp(rank, world_size):
os.environ["MASTER_ADDR"] = 'localhost'
os.environ['MASTER_PORT'] = '8088'

if torch.distributed.is_available() and sys.platform not in ['win32', 'cygwin']:
if torch.distributed.is_available() and sys.platform not in ('win32', 'cygwin'):
torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size)


Expand Down