Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

3D extension for SSIM #818

Merged
merged 125 commits into from
Mar 25, 2022
Merged
Show file tree
Hide file tree
Changes from 100 commits
Commits
Show all changes
125 commits
Select commit Hold shift + click to select a range
848142f
3D ssim, first try
Jan 30, 2022
6b726b4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 31, 2022
f802930
changelog
SkafteNicki Jan 31, 2022
2eee2ea
fix doctest
SkafteNicki Jan 31, 2022
a8e9d02
Update torchmetrics/functional/image/ssim.py
weningerleon Jan 31, 2022
207c5bf
Update torchmetrics/functional/image/ssim.py
weningerleon Jan 31, 2022
fc2d6ed
Update torchmetrics/functional/image/ssim.py
weningerleon Jan 31, 2022
b11817b
Update torchmetrics/functional/image/ssim.py
weningerleon Jan 31, 2022
30345d7
update ssim 3d
Feb 1, 2022
26e6799
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 1, 2022
967209d
adding 3d ssim tests
Feb 1, 2022
002c932
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 1, 2022
75d93fb
Merge branch 'PyTorchLightning:master' into master
weningerleon Feb 1, 2022
0a552a0
Merge branch 'master' into master
Borda Feb 5, 2022
95b78c9
Merge branch 'master' into weningerleon/master
Borda Feb 8, 2022
bacbecb
update
Borda Feb 8, 2022
68011c8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 8, 2022
df1b9e7
Merge branch 'master' into master
weningerleon Feb 8, 2022
8199e78
Merge branch 'master' into master
Borda Feb 8, 2022
50be3f3
fixed formatting errors
Feb 9, 2022
b12a922
bug fix ssim
Feb 9, 2022
ad285e7
_ssim_update
Feb 9, 2022
665c92c
Merge branch 'master' into master
justusschock Feb 9, 2022
57ae26e
Apply suggestions from code review
SkafteNicki Feb 10, 2022
769a7f4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 10, 2022
084fc4f
docs
SkafteNicki Feb 10, 2022
bc66a4c
Merge branch 'master' of https://github.com/weningerleon/metrics into…
SkafteNicki Feb 10, 2022
faf9900
Merge branch 'master' into master
Borda Feb 10, 2022
98d1490
Merge branch 'master' of https://github.com/weningerleon/metrics into…
SkafteNicki Feb 10, 2022
bf0fcb2
Merge branch 'master' into master
mergify[bot] Feb 10, 2022
5fe1410
Merge branch 'master' into master
Borda Feb 10, 2022
378a07e
Merge branch 'master' into master
mergify[bot] Feb 10, 2022
ca6b90f
Merge branch 'master' into master
mergify[bot] Feb 11, 2022
3ae6db8
Merge branch 'master' into master
mergify[bot] Feb 11, 2022
2d40262
update kernel_size default
Feb 11, 2022
d38d8a2
ms-ssim in 3d
Feb 11, 2022
c0098a3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 11, 2022
a7766ee
formatting
Feb 11, 2022
9d2dc0d
merge
Feb 11, 2022
e33877a
Merge branch 'master' into master
mergify[bot] Feb 14, 2022
b73b0cd
Merge branch 'master' into master
mergify[bot] Feb 14, 2022
248eb95
Use our own 3D reflection padding
stancld Feb 16, 2022
9f2c13c
pytorch implementation depending on version, user warning if deprecated
Feb 17, 2022
2cc5873
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 17, 2022
f45c908
update torch version checking
Feb 17, 2022
fb38f21
Merge branch 'master' of github.com:weningerleon/metrics
Feb 17, 2022
d72b5b2
Apply suggestions from code review
stancld Feb 17, 2022
3c57d36
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 17, 2022
ed68ec2
Use smaller batch size due to OOM
stancld Feb 17, 2022
7d256f6
Fix test according to weningerleon's suggestsion + apply a small batch
stancld Feb 17, 2022
afe6de8
Clean reference sk metric
stancld Feb 17, 2022
08c0f6f
updates and bug fixes
Feb 17, 2022
20f5336
merge
Feb 17, 2022
6e8476d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 17, 2022
8309a78
docs
Feb 17, 2022
709d4e7
adapt ssim
Feb 17, 2022
a2be8bc
Merge branch 'master' of github.com:weningerleon/metrics
Feb 17, 2022
b37a512
adapt ssim
Feb 17, 2022
24314df
Merge branch 'master' into master
mergify[bot] Feb 18, 2022
986048e
fix tests
Feb 18, 2022
28aded0
old atol ssim
Feb 18, 2022
c6d9f83
Merge branch 'master' of github.com:weningerleon/metrics
Feb 18, 2022
98fef1b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 18, 2022
5b22686
formatting
Feb 18, 2022
7618c8b
merge
Feb 18, 2022
11c73b6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 18, 2022
86bd470
fix ms ssim
Feb 18, 2022
8185fbd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 18, 2022
60f0719
docs
Feb 18, 2022
b4ffccc
num_batches
Feb 18, 2022
7dd1578
doctest
Feb 18, 2022
7e974ca
torch tensor
Feb 18, 2022
f7fdf88
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 18, 2022
bc3a781
Merge branch 'master' into master
mergify[bot] Feb 21, 2022
d88d02f
changelog
Feb 21, 2022
47adccf
changelog
Feb 21, 2022
ab09084
Merge branch 'master' into master
weningerleon Feb 21, 2022
40368c7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 21, 2022
15c0510
Merge branch 'master' into master
mergify[bot] Feb 21, 2022
b33e4e9
add dict
Feb 21, 2022
2082feb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 21, 2022
5a8a503
Merge branch 'master' into master
weningerleon Feb 23, 2022
69f8d56
typing
Feb 23, 2022
42d0aa1
merging
Feb 23, 2022
e5fc229
Merge branch 'master' into master
mergify[bot] Feb 23, 2022
0dc5a84
Merge branch 'master' into master
mergify[bot] Feb 23, 2022
1fafb1a
Merge branch 'master' into master
mergify[bot] Feb 24, 2022
dc2e79a
Update tests/image/test_ssim.py
weningerleon Feb 24, 2022
9f46a36
Update tests/image/test_ssim.py
weningerleon Feb 24, 2022
fcc2870
Update torchmetrics/functional/image/helper.py
weningerleon Feb 24, 2022
afbb379
Update torchmetrics/image/ssim.py
weningerleon Feb 24, 2022
a2ac937
removed kernel size parametrization
Feb 24, 2022
b93e36b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 24, 2022
371572f
add Raise to docstring
Feb 24, 2022
af9c82d
Merge branch 'master' of github.com:weningerleon/metrics
Feb 24, 2022
98675a1
Update torchmetrics/image/ssim.py
weningerleon Feb 24, 2022
8785563
Update torchmetrics/functional/image/ssim.py
weningerleon Feb 24, 2022
0923588
Merge branch 'master' into master
mergify[bot] Feb 25, 2022
0126a82
Apply suggestions from code review
Borda Feb 25, 2022
28363af
Merge branch 'master' into master
mergify[bot] Feb 28, 2022
fc897b6
Merge branch 'master' into master
mergify[bot] Mar 1, 2022
d35c88f
Merge branch 'master' into master
mergify[bot] Mar 1, 2022
cb960af
add reduce
Mar 2, 2022
1eb5dfa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 2, 2022
299ca82
Merge branch 'master' into master
mergify[bot] Mar 3, 2022
7f583dc
re fix tests
Mar 3, 2022
cb6332d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2022
84dccfc
reduce, other settings
Mar 3, 2022
74c3ebf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 3, 2022
9205d1c
Merge branch 'master' into master
mergify[bot] Mar 3, 2022
f732cd0
Merge branch 'master' into master
mergify[bot] Mar 7, 2022
d47be94
Merge branch 'master' into master
mergify[bot] Mar 11, 2022
20a2877
Merge branch 'master' into master
Borda Mar 20, 2022
4eed069
Merge branch 'master' into master
mergify[bot] Mar 21, 2022
e27b0ce
Merge branch 'master' into master
mergify[bot] Mar 21, 2022
6841417
missing docstrings
SkafteNicki Mar 21, 2022
5b455be
fix doctest
SkafteNicki Mar 21, 2022
cd0dbad
fix doc test
SkafteNicki Mar 21, 2022
6bcbee2
Merge branch 'master' into master
SkafteNicki Mar 21, 2022
daf5ae3
Merge branch 'master' into master
mergify[bot] Mar 22, 2022
7a8576b
fix memory issues
SkafteNicki Mar 22, 2022
c93b2c1
device placement
SkafteNicki Mar 22, 2022
1af87a8
lower memory
SkafteNicki Mar 22, 2022
eba0214
Merge branch 'master' into weningerleon/master
Borda Mar 24, 2022
b057212
Merge branch 'master' into master
mergify[bot] Mar 24, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added new image metric `UniversalImageQualityIndex` ([#824](https://github.com/PyTorchLightning/metrics/pull/824))


- Added support for 3D image and uniform kernel in `StructuralSimilarityIndexMeasure` ([#818](https://github.com/PyTorchLightning/metrics/pull/818))


- Added smart update of `MetricCollection` ([#709](https://github.com/PyTorchLightning/metrics/pull/709))


Expand Down
23 changes: 12 additions & 11 deletions tests/image/test_ms_ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
BATCH_SIZE = 1

_inputs = []
for size, coef in [(128, 0.9), (128, 0.7)]:
for size, coef in [(182, 0.9), (182, 0.7)]:
preds = torch.rand(NUM_BATCHES, BATCH_SIZE, 1, size, size)
_inputs.append(
Input(
Expand All @@ -40,40 +40,41 @@


def pytorch_ms_ssim(preds, target, data_range, kernel_size):
return ms_ssim(preds, target, data_range=data_range, win_size=kernel_size)
return ms_ssim(preds, target, data_range=data_range, win_size=kernel_size, size_average=False)


@pytest.mark.parametrize(
"preds, target",
[(i.preds, i.target) for i in _inputs],
)
@pytest.mark.parametrize("kernel_size", [5, 7])
class TestMultiScaleStructuralSimilarityIndexMeasure(MetricTester):
atol = 6e-3

# in the pytorch-msssim package, sigma is hardcoded to 1.5. We can thus only test this value, which corresponds
# to a kernel size of 11

@pytest.mark.parametrize("ddp", [False, True])
@pytest.mark.parametrize("dist_sync_on_step", [False, True])
def test_ms_ssim(self, preds, target, kernel_size, ddp, dist_sync_on_step):
def test_ms_ssim(self, preds, target, ddp, dist_sync_on_step):
self.run_class_metric_test(
ddp,
preds,
target,
MultiScaleStructuralSimilarityIndexMeasure,
partial(pytorch_ms_ssim, data_range=1.0, kernel_size=kernel_size),
metric_args={"data_range": 1.0, "kernel_size": (kernel_size, kernel_size)},
partial(pytorch_ms_ssim, data_range=1.0, kernel_size=11),
dist_sync_on_step=dist_sync_on_step,
)

def test_ms_ssim_functional(self, preds, target, kernel_size):
def test_ms_ssim_functional(self, preds, target):
self.run_functional_metric_test(
preds,
target,
multiscale_structural_similarity_index_measure,
partial(pytorch_ms_ssim, data_range=1.0, kernel_size=kernel_size),
metric_args={"data_range": 1.0, "kernel_size": (kernel_size, kernel_size)},
partial(pytorch_ms_ssim, data_range=1.0, kernel_size=11),
metric_args={"data_range": 1.0, "kernel_size": 11},
)

def test_ms_ssim_differentiability(self, preds, target, kernel_size):
def test_ms_ssim_differentiability(self, preds, target):
# We need to minimize this example to make the test tractable
single_beta = (1.0,)
_preds = preds[:, :, :, :16, :16]
Expand All @@ -86,7 +87,7 @@ def test_ms_ssim_differentiability(self, preds, target, kernel_size):
metric_module=MultiScaleStructuralSimilarityIndexMeasure,
metric_args={
"data_range": 1.0,
"kernel_size": (kernel_size, kernel_size),
"kernel_size": 11,
"betas": single_beta,
},
)
141 changes: 97 additions & 44 deletions tests/image/test_ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,84 +25,122 @@

seed_all(42)

Input = namedtuple("Input", ["preds", "target", "multichannel"])
Input = namedtuple("Input", ["preds", "target"])

_inputs = []
for size, channel, coef, multichannel, dtype in [
(12, 3, 0.9, True, torch.float),
(13, 1, 0.8, False, torch.float32),
(14, 1, 0.7, False, torch.double),
(15, 3, 0.6, True, torch.float64),
for size, channel, coef, dtype in [
(12, 3, 0.9, torch.float),
(13, 1, 0.8, torch.float32),
(14, 1, 0.7, torch.double),
(13, 3, 0.6, torch.float32),
]:
preds = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, dtype=dtype)
preds2d = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, dtype=dtype)
_inputs.append(
Input(
preds=preds,
target=preds * coef,
multichannel=multichannel,
preds=preds2d,
target=preds2d * coef,
)
)
preds3d = torch.rand(NUM_BATCHES, BATCH_SIZE, channel, size, size, size, dtype=dtype)
_inputs.append(
Input(
preds=preds3d,
target=preds3d * coef,
)
)


def _sk_ssim(preds, target, data_range, multichannel, kernel_size):
c, h, w = preds.shape[-3:]
sk_preds = preds.view(-1, c, h, w).permute(0, 2, 3, 1).numpy()
sk_target = target.view(-1, c, h, w).permute(0, 2, 3, 1).numpy()
if not multichannel:
sk_preds = sk_preds[:, :, :, 0]
sk_target = sk_target[:, :, :, 0]

return structural_similarity(
sk_target,
sk_preds,
data_range=data_range,
multichannel=multichannel,
gaussian_weights=True,
win_size=kernel_size,
sigma=1.5,
use_sample_covariance=False,
)
def _sk_ssim(preds, target, data_range, sigma, kernel_size=None, return_ssim_image=False, gaussian_weights=True):
if len(preds.shape) == 4:
c, h, w = preds.shape[-3:]
sk_preds = preds.view(-1, c, h, w).permute(0, 2, 3, 1).numpy()
sk_target = target.view(-1, c, h, w).permute(0, 2, 3, 1).numpy()
elif len(preds.shape) == 5:
c, d, h, w = preds.shape[-4:]
sk_preds = preds.view(-1, c, d, h, w).permute(0, 2, 3, 4, 1).numpy()
sk_target = target.view(-1, c, d, h, w).permute(0, 2, 3, 4, 1).numpy()

results = torch.zeros(sk_preds.shape[0], dtype=target.dtype)
if not return_ssim_image:
for i in range(sk_preds.shape[0]):
res = structural_similarity(
sk_target[i],
sk_preds[i],
data_range=data_range,
multichannel=True,
gaussian_weights=gaussian_weights,
win_size=kernel_size,
sigma=sigma,
use_sample_covariance=False,
full=return_ssim_image,
)
results[i] = torch.from_numpy(res).type(preds.dtype)
return results
else:
fullimages = torch.zeros(target.shape, dtype=target.dtype)
for i in range(sk_preds.shape[0]):
res, fullimage = structural_similarity(
sk_target[i],
sk_preds[i],
data_range=data_range,
multichannel=True,
gaussian_weights=gaussian_weights,
win_size=kernel_size,
sigma=sigma,
use_sample_covariance=False,
full=return_ssim_image,
)
results[i] = torch.from_numpy(res).type(preds.dtype)
fullimage = torch.from_numpy(fullimage).type(preds.dtype)
if len(preds.shape) == 4:
fullimages[i] = fullimage.permute(2, 0, 1)
elif len(preds.shape) == 5:
fullimages[i] = fullimage.permute(3, 0, 1, 2)
return results, fullimages


@pytest.mark.parametrize(
"preds, target, multichannel",
[(i.preds, i.target, i.multichannel) for i in _inputs],
"preds, target",
[(i.preds, i.target) for i in _inputs],
)
@pytest.mark.parametrize("kernel_size", [5, 11])
@pytest.mark.parametrize("sigma", [1.5, 0.5])
class TestSSIM(MetricTester):
stancld marked this conversation as resolved.
Show resolved Hide resolved
atol = 6e-3

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_ssim(self, preds, target, multichannel, kernel_size, ddp, dist_sync_on_step):
def test_ssim(self, preds, target, sigma, ddp, dist_sync_on_step):
self.run_class_metric_test(
ddp,
preds,
target,
StructuralSimilarityIndexMeasure,
partial(_sk_ssim, data_range=1.0, multichannel=multichannel, kernel_size=kernel_size),
metric_args={"data_range": 1.0, "kernel_size": (kernel_size, kernel_size)},
partial(_sk_ssim, data_range=1.0, sigma=sigma, kernel_size=None),
metric_args={
"data_range": 1.0,
"sigma": sigma,
},
dist_sync_on_step=dist_sync_on_step,
)

def test_ssim_functional(self, preds, target, multichannel, kernel_size):
def test_ssim_functional(self, preds, target, sigma):
self.run_functional_metric_test(
preds,
target,
structural_similarity_index_measure,
partial(_sk_ssim, data_range=1.0, multichannel=multichannel, kernel_size=kernel_size),
metric_args={"data_range": 1.0, "kernel_size": (kernel_size, kernel_size)},
partial(_sk_ssim, data_range=1.0, sigma=sigma, kernel_size=None),
metric_args={"data_range": 1.0, "sigma": sigma},
)

# SSIM half + cpu does not work due to missing support in torch.log
@pytest.mark.xfail(reason="SSIM metric does not support cpu + half precision")
def test_ssim_half_cpu(self, preds, target, multichannel, kernel_size):
def test_ssim_half_cpu(self, preds, target, sigma):
self.run_precision_test_cpu(
preds, target, StructuralSimilarityIndexMeasure, structural_similarity_index_measure, {"data_range": 1.0}
)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
def test_ssim_half_gpu(self, preds, target, multichannel, kernel_size):
def test_ssim_half_gpu(self, preds, target, sigma):
self.run_precision_test_gpu(
preds, target, StructuralSimilarityIndexMeasure, structural_similarity_index_measure, {"data_range": 1.0}
)
Expand All @@ -111,8 +149,8 @@ def test_ssim_half_gpu(self, preds, target, multichannel, kernel_size):
@pytest.mark.parametrize(
["pred", "target", "kernel", "sigma"],
[
([1, 16, 16], [1, 16, 16], [11, 11], [1.5, 1.5]), # len(shape)
([1, 1, 16, 16], [1, 1, 16, 16], [11, 11], [1.5]), # len(kernel), len(sigma)
([1, 16, 16], [1, 16, 16], [11, 11], [1.5, 1.5]), # len(shape)
([1, 1, 16, 16], [1, 1, 16, 16], [11], [1.5, 1.5]), # len(kernel), len(sigma)
([1, 1, 16, 16], [1, 1, 16, 16], [11], [1.5]), # len(kernel), len(sigma)
([1, 1, 16, 16], [1, 1, 16, 16], [11, 0], [1.5, 1.5]), # invalid kernel input
Expand All @@ -123,15 +161,15 @@ def test_ssim_half_gpu(self, preds, target, multichannel, kernel_size):
],
)
def test_ssim_invalid_inputs(pred, target, kernel, sigma):
pred_t = torch.rand(pred)
pred_t = torch.rand(pred, dtype=torch.float32)
target_t = torch.rand(target, dtype=torch.float64)
with pytest.raises(TypeError):
structural_similarity_index_measure(pred_t, target_t)

pred = torch.rand(pred)
target = torch.rand(target)
with pytest.raises(ValueError):
structural_similarity_index_measure(pred, target, kernel, sigma)
structural_similarity_index_measure(pred, target, kernel_size=kernel, sigma=sigma)


def test_ssim_unequal_kernel_size():
Expand Down Expand Up @@ -167,5 +205,20 @@ def test_ssim_unequal_kernel_size():
]
)
# kernel order matters
assert structural_similarity_index_measure(preds, target, kernel_size=(3, 5)) == torch.tensor(0.10814697)
assert structural_similarity_index_measure(preds, target, kernel_size=(5, 3)) != torch.tensor(0.10814697)
assert torch.isclose(
structural_similarity_index_measure(preds, target, gaussian_kernel=True, sigma=(0.25, 0.5)),
torch.tensor(0.08869550),
)
assert not torch.isclose(
structural_similarity_index_measure(preds, target, gaussian_kernel=True, sigma=(0.5, 0.25)),
torch.tensor(0.08869550),
)

assert torch.isclose(
structural_similarity_index_measure(preds, target, gaussian_kernel=False, kernel_size=(3, 5)),
torch.tensor(0.05131844),
)
assert not torch.isclose(
structural_similarity_index_measure(preds, target, gaussian_kernel=False, kernel_size=(5, 3)),
torch.tensor(0.05131844),
)
59 changes: 55 additions & 4 deletions torchmetrics/functional/image/helper.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from typing import Sequence
from typing import Sequence, Union

import torch
import torch.nn.functional as F
from torch import Tensor

from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_1_10


def _gaussian(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> Tensor:
"""Computes 1D gaussian kernel.
Expand All @@ -22,8 +26,12 @@ def _gaussian(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.
return (gauss / gauss.sum()).unsqueeze(dim=0) # (1, kernel_size)


def _gaussian_kernel(
channel: int, kernel_size: Sequence[int], sigma: Sequence[float], dtype: torch.dtype, device: torch.device
def _gaussian_kernel_2d(
channel: int,
kernel_size: Sequence[int],
sigma: Sequence[float],
dtype: torch.dtype,
device: Union[torch.device, str],
) -> Tensor:
"""Computes 2D gaussian kernel.

Expand All @@ -35,7 +43,7 @@ def _gaussian_kernel(
device: device of the output tensor

Example:
>>> _gaussian_kernel(1, (5,5), (1,1), torch.float, "cpu")
>>> _gaussian_kernel_2d(1, (5,5), (1,1), torch.float, "cpu")
tensor([[[[0.0030, 0.0133, 0.0219, 0.0133, 0.0030],
[0.0133, 0.0596, 0.0983, 0.0596, 0.0133],
[0.0219, 0.0983, 0.1621, 0.0983, 0.0219],
Expand All @@ -48,3 +56,46 @@ def _gaussian_kernel(
kernel = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y) # (kernel_size, 1) * (1, kernel_size)

return kernel.expand(channel, 1, kernel_size[0], kernel_size[1])


def _gaussian_kernel_3d(
channel: int, kernel_size: Sequence[int], sigma: Sequence[float], dtype: torch.dtype, device: torch.device
) -> Tensor:
"""Computes 3D gaussian kernel.

Args:
channel: number of channels in the image
kernel_size: size of the gaussian kernel as a tuple (h, w, d)
sigma: Standard deviation of the gaussian kernel
dtype: data type of the output tensor
device: device of the output tensor
"""

gaussian_kernel_x = _gaussian(kernel_size[0], sigma[0], dtype, device)
gaussian_kernel_y = _gaussian(kernel_size[1], sigma[1], dtype, device)
gaussian_kernel_z = _gaussian(kernel_size[2], sigma[2], dtype, device)
kernel_xy = torch.matmul(gaussian_kernel_x.t(), gaussian_kernel_y) # (kernel_size, 1) * (1, kernel_size)
kernel = torch.mul(
kernel_xy.unsqueeze(-1).repeat(1, 1, kernel_size[2]),
gaussian_kernel_z.expand(kernel_size[0], kernel_size[1], kernel_size[2]),
)
return kernel.expand(channel, 1, kernel_size[0], kernel_size[1], kernel_size[2])


def _single_dimension_pad(inputs: Tensor, dim: int, pad: int) -> Tensor:
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
_max = inputs.shape[dim] - 2
x = torch.index_select(inputs, dim, torch.arange(pad, 0, -1))
y = torch.index_select(inputs, dim, torch.arange(_max, _max - pad, -1))
return torch.cat((x, inputs, y), dim)


def _reflection_pad_3d(inputs: Tensor, pad_h: int, pad_w: int, pad_d: int) -> Tensor:
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved
if _TORCH_GREATER_EQUAL_1_10:
inputs = F.pad(inputs, (pad_h, pad_h, pad_w, pad_w, pad_d, pad_d), mode="reflect")
else:
rank_zero_warn(
"An older version of pyTorch is used. For optimal speed, please upgrade to at least pyTorch 1.10."
)
for dim, pad in enumerate([pad_h, pad_w, pad_d]):
inputs = _single_dimension_pad(inputs, dim + 2, pad)
return inputs
Loading