Skip to content

Commit

Permalink
add psnr ssim pytorch version
Browse files Browse the repository at this point in the history
  • Loading branch information
xinntao committed May 1, 2022
1 parent 786c5fd commit 01a75fb
Show file tree
Hide file tree
Showing 8 changed files with 361 additions and 223 deletions.
3 changes: 1 addition & 2 deletions basicsr/data/paired_image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@

from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file
from basicsr.data.transforms import augment, paired_random_crop
from basicsr.utils import FileClient, imfrombytes, img2tensor
from basicsr.utils.matlab_functions import rgb2ycbcr
from basicsr.utils import FileClient, imfrombytes, img2tensor, rgb2ycbcr
from basicsr.utils.registry import DATASET_REGISTRY


Expand Down
3 changes: 1 addition & 2 deletions basicsr/data/single_image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
from torchvision.transforms.functional import normalize

from basicsr.data.data_util import paths_from_lmdb
from basicsr.utils import FileClient, imfrombytes, img2tensor, scandir
from basicsr.utils.matlab_functions import rgb2ycbcr
from basicsr.utils import FileClient, imfrombytes, img2tensor, rgb2ycbcr, scandir
from basicsr.utils.registry import DATASET_REGISTRY


Expand Down
2 changes: 1 addition & 1 deletion basicsr/metrics/metric_util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np

from basicsr.utils.matlab_functions import bgr2ycbcr
from basicsr.utils import bgr2ycbcr


def reorder_image(img, input_order='HWC'):
Expand Down
177 changes: 142 additions & 35 deletions basicsr/metrics/psnr_ssim.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import cv2
import numpy as np
import torch
import torch.nn.functional as F

from basicsr.metrics.metric_util import reorder_image, to_y_channel
from basicsr.utils.color_util import rgb2ycbcr_pt
from basicsr.utils.registry import METRIC_REGISTRY


Expand All @@ -14,23 +17,19 @@ def calculate_psnr(img, img2, crop_border, input_order='HWC', test_y_channel=Fal
Args:
img (ndarray): Images with range [0, 255].
img2 (ndarray): Images with range [0, 255].
crop_border (int): Cropped pixels in each edge of an image. These
pixels are not involved in the PSNR calculation.
input_order (str): Whether the input order is 'HWC' or 'CHW'.
Default: 'HWC'.
crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation.
input_order (str): Whether the input order is 'HWC' or 'CHW'. Default: 'HWC'.
test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
Returns:
float: psnr result.
float: PSNR result.
"""

assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
if input_order not in ['HWC', 'CHW']:
raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"')
img = reorder_image(img, input_order=input_order)
img2 = reorder_image(img2, input_order=input_order)
img = img.astype(np.float64)
img2 = img2.astype(np.float64)

if crop_border != 0:
img = img[crop_border:-crop_border, crop_border:-crop_border, ...]
Expand All @@ -40,44 +39,48 @@ def calculate_psnr(img, img2, crop_border, input_order='HWC', test_y_channel=Fal
img = to_y_channel(img)
img2 = to_y_channel(img2)

img = img.astype(np.float64)
img2 = img2.astype(np.float64)

mse = np.mean((img - img2)**2)
if mse == 0:
return float('inf')
return 20. * np.log10(255. / np.sqrt(mse))
return 10. * np.log10(255. * 255. / mse)


def _ssim(img, img2):
"""Calculate SSIM (structural similarity) for one channel images.
@METRIC_REGISTRY.register()
def calculate_psnr_pt(img, img2, crop_border, test_y_channel=False, **kwargs):
"""Calculate PSNR (Peak Signal-to-Noise Ratio) (PyTorch version).
It is called by func:`calculate_ssim`.
Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
Args:
img (ndarray): Images with range [0, 255] with order 'HWC'.
img2 (ndarray): Images with range [0, 255] with order 'HWC'.
img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation.
test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
Returns:
float: ssim result.
float: PSNR result.
"""

c1 = (0.01 * 255)**2
c2 = (0.03 * 255)**2
assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')

img = img.astype(np.float64)
img2 = img2.astype(np.float64)
kernel = cv2.getGaussianKernel(11, 1.5)
window = np.outer(kernel, kernel.transpose())
if crop_border != 0:
img = img[:, :, crop_border:-crop_border, crop_border:-crop_border]
img2 = img2[:, :, crop_border:-crop_border, crop_border:-crop_border]

mu1 = cv2.filter2D(img, -1, window)[5:-5, 5:-5]
mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
mu1_sq = mu1**2
mu2_sq = mu2**2
mu1_mu2 = mu1 * mu2
sigma1_sq = cv2.filter2D(img**2, -1, window)[5:-5, 5:-5] - mu1_sq
sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
sigma12 = cv2.filter2D(img * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
if test_y_channel:
img = rgb2ycbcr_pt(img, y_only=True)
img2 = rgb2ycbcr_pt(img2, y_only=True)

ssim_map = ((2 * mu1_mu2 + c1) * (2 * sigma12 + c2)) / ((mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2))
return ssim_map.mean()
img = img.to(torch.float64)
img2 = img2.to(torch.float64)

mse = torch.mean((img - img2)**2, dim=[1, 2, 3])
if mse == 0:
return float('inf')
return 10. * torch.log10(1. / mse)


@METRIC_REGISTRY.register()
Expand All @@ -96,23 +99,20 @@ def calculate_ssim(img, img2, crop_border, input_order='HWC', test_y_channel=Fal
Args:
img (ndarray): Images with range [0, 255].
img2 (ndarray): Images with range [0, 255].
crop_border (int): Cropped pixels in each edge of an image. These
pixels are not involved in the SSIM calculation.
crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation.
input_order (str): Whether the input order is 'HWC' or 'CHW'.
Default: 'HWC'.
test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
Returns:
float: ssim result.
float: SSIM result.
"""

assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
if input_order not in ['HWC', 'CHW']:
raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"')
img = reorder_image(img, input_order=input_order)
img2 = reorder_image(img2, input_order=input_order)
img = img.astype(np.float64)
img2 = img2.astype(np.float64)

if crop_border != 0:
img = img[crop_border:-crop_border, crop_border:-crop_border, ...]
Expand All @@ -122,7 +122,114 @@ def calculate_ssim(img, img2, crop_border, input_order='HWC', test_y_channel=Fal
img = to_y_channel(img)
img2 = to_y_channel(img2)

img = img.astype(np.float64)
img2 = img2.astype(np.float64)

ssims = []
for i in range(img.shape[2]):
ssims.append(_ssim(img[..., i], img2[..., i]))
return np.array(ssims).mean()


@METRIC_REGISTRY.register()
def calculate_ssim_pt(img, img2, crop_border, test_y_channel=False, **kwargs):
"""Calculate SSIM (structural similarity) (PyTorch version).
Ref:
Image quality assessment: From error visibility to structural similarity
The results are the same as that of the official released MATLAB code in
https://ece.uwaterloo.ca/~z70wang/research/ssim/.
For three-channel images, SSIM is calculated for each channel and then
averaged.
Args:
img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
crop_border (int): Cropped pixels in each edge of an image. These pixels are not involved in the calculation.
test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
Returns:
float: SSIM result.
"""

assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')

if crop_border != 0:
img = img[:, :, crop_border:-crop_border, crop_border:-crop_border]
img2 = img2[:, :, crop_border:-crop_border, crop_border:-crop_border]

if test_y_channel:
img = rgb2ycbcr_pt(img, y_only=True)
img2 = rgb2ycbcr_pt(img2, y_only=True)

img = img.to(torch.float64)
img2 = img2.to(torch.float64)

ssim = _ssim_pth(img * 255., img2 * 255.)
return ssim


def _ssim(img, img2):
"""Calculate SSIM (structural similarity) for one channel images.
It is called by func:`calculate_ssim`.
Args:
img (ndarray): Images with range [0, 255] with order 'HWC'.
img2 (ndarray): Images with range [0, 255] with order 'HWC'.
Returns:
float: SSIM result.
"""

c1 = (0.01 * 255)**2
c2 = (0.03 * 255)**2
kernel = cv2.getGaussianKernel(11, 1.5)
window = np.outer(kernel, kernel.transpose())

mu1 = cv2.filter2D(img, -1, window)[5:-5, 5:-5] # valid mode for window size 11
mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
mu1_sq = mu1**2
mu2_sq = mu2**2
mu1_mu2 = mu1 * mu2
sigma1_sq = cv2.filter2D(img**2, -1, window)[5:-5, 5:-5] - mu1_sq
sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
sigma12 = cv2.filter2D(img * img2, -1, window)[5:-5, 5:-5] - mu1_mu2

ssim_map = ((2 * mu1_mu2 + c1) * (2 * sigma12 + c2)) / ((mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2))
return ssim_map.mean()


def _ssim_pth(img, img2):
"""Calculate SSIM (structural similarity) (PyTorch version).
It is called by func:`calculate_ssim_pt`.
Args:
img (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
img2 (Tensor): Images with range [0, 1], shape (n, 3/1, h, w).
Returns:
float: SSIM result.
"""
c1 = (0.01 * 255)**2
c2 = (0.03 * 255)**2

kernel = cv2.getGaussianKernel(11, 1.5)
window = np.outer(kernel, kernel.transpose())
window = torch.from_numpy(window).view(1, 1, 11, 11).expand(img.size(1), 1, 11, 11).to(img.dtype).to(img.device)

mu1 = F.conv2d(img, window, stride=1, padding=0, groups=img.shape[1]) # valid mode
mu2 = F.conv2d(img2, window, stride=1, padding=0, groups=img2.shape[1]) # valid mode
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
sigma1_sq = F.conv2d(img * img, window, stride=1, padding=0, groups=img.shape[1]) - mu1_sq
sigma2_sq = F.conv2d(img2 * img2, window, stride=1, padding=0, groups=img.shape[1]) - mu2_sq
sigma12 = F.conv2d(img * img2, window, stride=1, padding=0, groups=img.shape[1]) - mu1_mu2

cs_map = (2 * sigma12 + c2) / (sigma1_sq + sigma2_sq + c2)
ssim_map = ((2 * mu1_mu2 + c1) / (mu1_sq + mu2_sq + c1)) * cs_map
return ssim_map.mean([1, 2, 3])
7 changes: 7 additions & 0 deletions basicsr/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .color_util import bgr2ycbcr, rgb2ycbcr, rgb2ycbcr_pt, ycbcr2bgr, ycbcr2rgb
from .diffjpeg import DiffJPEG
from .file_client import FileClient
from .img_process_util import USMSharp, usm_sharp
Expand All @@ -6,6 +7,12 @@
from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt

__all__ = [
# color_util.py
'bgr2ycbcr',
'rgb2ycbcr',
'rgb2ycbcr_pt',
'ycbcr2bgr',
'ycbcr2rgb',
# file_client.py
'FileClient',
# img_util.py
Expand Down
Loading

0 comments on commit 01a75fb

Please sign in to comment.