Skip to content

Commit

Permalink
🐛 Fix JITing of complex submodule
Browse files Browse the repository at this point in the history
The `complex` function confused the compiler.
  • Loading branch information
francois-rozet committed Sep 11, 2021
1 parent af274e6 commit d098f4d
Show file tree
Hide file tree
Showing 6 changed files with 12 additions and 15 deletions.
2 changes: 1 addition & 1 deletion piqa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
specific image quality assessement metric.
"""

__version__ = '1.1.6'
__version__ = '1.1.7'

from .tv import TV
from .psnr import PSNR
Expand Down
2 changes: 1 addition & 1 deletion piqa/fsim.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def fsim(
s_q = (2 * q_x * q_y + t4) / (q_x ** 2 + q_y ** 2 + t4)

s_iq = s_i * s_q
s_iq = cx.complex(s_iq, torch.zeros_like(s_iq))
s_iq = cx.complx(s_iq, torch.zeros_like(s_iq))
s_iq_lambda = cx.real(cx.pow(s_iq, lmbda))

s_l = s_l * s_iq_lambda
Expand Down
4 changes: 2 additions & 2 deletions piqa/mdsi.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ def mdsi(
cs = cs_num / cs_den

# Gradient-chromaticity similarity
gs = cx.complex(gs, torch.zeros_like(gs))
cs = cx.complex(cs, torch.zeros_like(cs))
gs = cx.complx(gs, torch.zeros_like(gs))
cs = cx.complx(cs, torch.zeros_like(cs))

if combination == 'prod':
gcs = cx.prod(cx.pow(gs, gamma), cx.pow(cs, beta))
Expand Down
10 changes: 5 additions & 5 deletions piqa/utils/complex.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch


def complex(real: torch.Tensor, imag: torch.Tensor) -> torch.Tensor:
def complx(real: torch.Tensor, imag: torch.Tensor) -> torch.Tensor:
r"""Returns a complex tensor with its real part equal to \(\Re\) and
its imaginary part equal to \(\Im\).
Expand All @@ -20,7 +20,7 @@ def complex(real: torch.Tensor, imag: torch.Tensor) -> torch.Tensor:
Example:
>>> x = torch.tensor([2., 0.7071])
>>> y = torch.tensor([0., 0.7071])
>>> complex(x, y)
>>> complx(x, y)
tensor([[2.0000, 0.0000],
[0.7071, 0.7071]])
"""
Expand Down Expand Up @@ -103,7 +103,7 @@ def turn(x: torch.Tensor) -> torch.Tensor:
[-0.7071, 0.7071]])
"""

return complex(-imag(x), real(x))
return complx(-imag(x), real(x))


def polar(r: torch.Tensor, phi: torch.Tensor) -> torch.Tensor:
Expand All @@ -127,7 +127,7 @@ def polar(r: torch.Tensor, phi: torch.Tensor) -> torch.Tensor:
[0.7071, 0.7071]])
"""

return complex(r * torch.cos(phi), r * torch.sin(phi))
return complx(r * torch.cos(phi), r * torch.sin(phi))


def mod(x: torch.Tensor, squared: bool = False) -> torch.Tensor:
Expand Down Expand Up @@ -200,7 +200,7 @@ def prod(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
x_r, x_i = x[..., 0], x[..., 1]
y_r, y_i = y[..., 0], y[..., 1]

return complex(x_r * y_r - x_i * y_i, x_i * y_r + x_r * y_i)
return complx(x_r * y_r - x_i * y_i, x_i * y_r + x_r * y_i)


def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
Expand Down
2 changes: 1 addition & 1 deletion piqa/vsi.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def vsi(
s_c = (2 * mn_x * mn_y + c3) / (mn_x ** 2 + mn_y ** 2 + c3)
s_c = s_c.prod(dim=1)

s_c = cx.complex(s_c, torch.zeros_like(s_c))
s_c = cx.complx(s_c, torch.zeros_like(s_c))
s_c_beta = cx.real(cx.pow(s_c, beta))

s_vs = s_vs * s_c_beta
Expand Down
7 changes: 2 additions & 5 deletions tests/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
'PSNR': (2, {
'sk.psnr-np': sk.peak_signal_noise_ratio,
'piq.psnr': piq.psnr,
'kornia.PSNR': kornia.PSNRLoss(max_val=1.),
'kornia.PSNR': kornia.PSNRLoss(1.),
'piqa.PSNR': piqa.PSNR(),
}),
'SSIM': (2, {
Expand All @@ -55,10 +55,7 @@
gaussian_weights=True,
),
'piq.ssim': lambda x, y: piq.ssim(x, y, downsample=False),
'kornia.SSIM-halfloss': kornia.SSIM(
window_size=11,
reduction='mean',
),
'kornia.SSIM-halfloss': kornia.SSIMLoss(11),
'IQA.SSIM-loss': IQA.SSIM(),
'vainf.SSIM': vainf.SSIM(data_range=1.),
'piqa.SSIM': piqa.SSIM(),
Expand Down

0 comments on commit d098f4d

Please sign in to comment.