diff --git a/mmedit/evaluation/metrics/__init__.py b/mmedit/evaluation/metrics/__init__.py
index d89a216e07..e6a0aa3e01 100644
--- a/mmedit/evaluation/metrics/__init__.py
+++ b/mmedit/evaluation/metrics/__init__.py
@@ -1,40 +1,45 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .connectivity_error import ConnectivityError
 from .equivariance import Equivariance
 from .fid import FrechetInceptionDistance, TransFID
+from .gradient_error import GradientError
 from .inception_score import InceptionScore, TransIS
-from .matting import SAD, ConnectivityError, GradientError, MattingMSE
+from .mae import MAE
+from .matting_mse import MattingMSE
 from .ms_ssim import MultiScaleStructureSimilarity
+from .mse import MSE
 from .niqe import NIQE, niqe
-from .pixel_metrics import MAE, MSE, PSNR, SNR, psnr, snr
 from .ppl import PerceptualPathLength
 from .precision_and_recall import PrecisionAndRecall
+from .psnr import PSNR, psnr
+from .sad import SAD
+from .snr import SNR, snr
 from .ssim import SSIM, ssim
 from .swd import SlicedWassersteinDistance
 
 __all__ = [
-    'ConnectivityError',
-    'GradientError',
     'MAE',
-    'MattingMSE',
     'MSE',
-    'NIQE',
-    'niqe',
     'PSNR',
     'psnr',
-    'SAD',
     'SNR',
     'snr',
     'SSIM',
     'ssim',
-    'Equivariance',
+    'MultiScaleStructureSimilarity',
     'FrechetInceptionDistance',
+    'TransFID',
     'InceptionScore',
-    'MultiScaleStructureSimilarity',
+    'TransIS',
+    'SAD',
+    'MattingMSE',
+    'ConnectivityError',
+    'GradientError',
     'PerceptualPathLength',
-    'MultiScaleStructureSimilarity',
     'PrecisionAndRecall',
     'SlicedWassersteinDistance',
-    'TransFID',
-    'TransIS',
+    'NIQE',
+    'niqe',
+    'Equivariance',
 ]
diff --git a/mmedit/evaluation/metrics/connectivity_error.py b/mmedit/evaluation/metrics/connectivity_error.py
new file mode 100644
index 0000000000..507359db90
--- /dev/null
+++ b/mmedit/evaluation/metrics/connectivity_error.py
@@ -0,0 +1,117 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+"""Evaluation metrics used in Image Matting."""
+
+from typing import List, Sequence
+
+import cv2
+import numpy as np
+from mmengine.evaluator import BaseMetric
+
+from mmedit.registry import METRICS
+from .metrics_utils import _fetch_data_and_check, average
+
+
+@METRICS.register_module()
+class ConnectivityError(BaseMetric):
+    """Connectivity error for evaluating alpha matte prediction.
+
+    .. note::
+
+        Current implementation assumes image / alpha / trimap arrays in
+        numpy format with pixel values ranging from 0 to 255.
+
+    .. note::
+
+        pred_alpha should be masked by trimap before passing
+        into this metric.
+
+    Args:
+        step (float): Step of threshold when computing intersection between
+            `alpha` and `pred_alpha`. Defaults to 0.1.
+        norm_constant (int): Divide the result to reduce its magnitude.
+            Defaults to 1000.
+
+    Default prefix: ''
+
+    Metrics:
+        - ConnectivityError (float): Connectivity Error
+    """
+
+    def __init__(
+        self,
+        step=0.1,
+        norm_constant=1000,
+        **kwargs,
+    ) -> None:
+        self.step = step
+        self.norm_constant = norm_constant
+        super().__init__(**kwargs)
+
+    def process(self, data_batch: Sequence[dict],
+                data_samples: Sequence[dict]) -> None:
+        """Process one batch of data samples and predictions. The processed
+        results should be stored in ``self.results``, which will be used to
+        compute the metrics when all batches have been processed.
+
+        Args:
+            data_batch (Sequence[dict]): A batch of data from the dataloader.
+            data_samples (Sequence[dict]): A batch of outputs from
+                the model.
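+
+        A minimal sketch of one ``data_sample`` accepted here (illustrative
+        only; the exact keys are produced by the matting dataset pipeline
+        and are validated in ``_fetch_data_and_check``)::
+
+            {
+                'ori_trimap': np.ndarray (H, W, 1), uint8,
+                'ori_alpha': np.ndarray (H, W, 1), uint8,
+                'output': {'pred_alpha': {'data': torch.Tensor (H, W)}},
+            }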
+ """ + + for data_sample in data_samples: + pred_alpha, gt_alpha, trimap = _fetch_data_and_check(data_sample) + + thresh_steps = np.arange(0, 1 + self.step, self.step) + round_down_map = -np.ones_like(gt_alpha) + for i in range(1, len(thresh_steps)): + gt_alpha_thresh = gt_alpha >= thresh_steps[i] + pred_alpha_thresh = pred_alpha >= thresh_steps[i] + intersection = gt_alpha_thresh & pred_alpha_thresh + intersection = intersection.astype(np.uint8) + + # connected components + _, output, stats, _ = cv2.connectedComponentsWithStats( + intersection, connectivity=4) + # start from 1 in dim 0 to exclude background + size = stats[1:, -1] + + # largest connected component of the intersection + omega = np.zeros_like(gt_alpha) + if len(size) != 0: + max_id = np.argmax(size) + # plus one to include background + omega[output == max_id + 1] = 1 + + mask = (round_down_map == -1) & (omega == 0) + round_down_map[mask] = thresh_steps[i - 1] + round_down_map[round_down_map == -1] = 1 + + gt_alpha_diff = gt_alpha - round_down_map + pred_alpha_diff = pred_alpha - round_down_map + # only calculate difference larger than or equal to 0.15 + gt_alpha_phi = 1 - gt_alpha_diff * (gt_alpha_diff >= 0.15) + pred_alpha_phi = 1 - pred_alpha_diff * (pred_alpha_diff >= 0.15) + + connectivity_error = np.sum( + np.abs(gt_alpha_phi - pred_alpha_phi) * (trimap == 128)) + + # divide by 1000 to reduce the magnitude of the result + connectivity_error /= self.norm_constant + + self.results.append({'conn_err': connectivity_error}) + + def compute_metrics(self, results: List): + """Compute the metrics from processed results. + + Args: + results (dict): The processed results of each batch. + + Returns: + Dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. + """ + + conn_err = average(results, 'conn_err') + + return {'ConnectivityError': conn_err} diff --git a/mmedit/evaluation/metrics/gradient_error.py b/mmedit/evaluation/metrics/gradient_error.py new file mode 100644 index 0000000000..de5a15dccc --- /dev/null +++ b/mmedit/evaluation/metrics/gradient_error.py @@ -0,0 +1,95 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Sequence + +import cv2 +import numpy as np +from mmengine.evaluator import BaseMetric + +from mmedit.registry import METRICS +from ..functional import gauss_gradient +from .metrics_utils import _fetch_data_and_check, average + + +@METRICS.register_module() +class GradientError(BaseMetric): + """Gradient error for evaluating alpha matte prediction. + + .. note:: + + Current implementation assume image / alpha / trimap array in numpy + format and with pixel value ranging from 0 to 255. + + .. note:: + + pred_alpha should be masked by trimap before passing + into this metric + + Args: + sigma (float): Standard deviation of the gaussian kernel. + Defaults to 1.4 . + norm_const (int): Divide the result to reduce its magnitude. + Defaults to 1000 . + + Default prefix: '' + + Metrics: + - GradientError (float): Gradient Error + """ + + def __init__( + self, + sigma=1.4, + norm_constant=1000, + **kwargs, + ) -> None: + self.sigma = sigma + self.norm_constant = norm_constant + super().__init__(**kwargs) + + def process(self, data_batch: Sequence[dict], + data_samples: Sequence[dict]) -> None: + """Process one batch of data samples and predictions. The processed + results should be stored in ``self.results``, which will be used to + compute the metrics when all batches have been processed. 
+
+        Args:
+            data_batch (Sequence[dict]): A batch of data from the dataloader.
+            data_samples (Sequence[dict]): A batch of outputs from
+                the model.
+        """
+
+        for data_sample in data_samples:
+            pred_alpha, gt_alpha, trimap = _fetch_data_and_check(data_sample)
+
+            gt_alpha_normed = np.zeros_like(gt_alpha)
+            pred_alpha_normed = np.zeros_like(pred_alpha)
+
+            cv2.normalize(gt_alpha, gt_alpha_normed, 1.0, 0.0, cv2.NORM_MINMAX)
+            cv2.normalize(pred_alpha, pred_alpha_normed, 1.0, 0.0,
+                          cv2.NORM_MINMAX)
+
+            gt_alpha_grad = gauss_gradient(gt_alpha_normed, self.sigma)
+            pred_alpha_grad = gauss_gradient(pred_alpha_normed, self.sigma)
+            # this is the sum over n samples
+            grad_loss = ((gt_alpha_grad - pred_alpha_grad)**2 *
+                         (trimap == 128)).sum()
+
+            # divide by norm_constant to reduce the magnitude of the result
+            grad_loss /= self.norm_constant
+
+            self.results.append({'grad_err': grad_loss})
+
+    def compute_metrics(self, results: List):
+        """Compute the metrics from processed results.
+
+        Args:
+            results (List): The processed results of each batch.
+
+        Returns:
+            Dict: The computed metrics. The keys are the names of the metrics,
+            and the values are corresponding results.
+        """
+
+        grad_err = average(results, 'grad_err')
+
+        return {'GradientError': grad_err}
diff --git a/mmedit/evaluation/metrics/mae.py b/mmedit/evaluation/metrics/mae.py
new file mode 100644
index 0000000000..0acd972a63
--- /dev/null
+++ b/mmedit/evaluation/metrics/mae.py
@@ -0,0 +1,60 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+"""Evaluation metrics based on pixels."""
+
+import numpy as np
+
+from mmedit.registry import METRICS
+from .base_sample_wise_metric import BaseSampleWiseMetric
+
+
+@METRICS.register_module()
+class MAE(BaseSampleWiseMetric):
+    """Mean Absolute Error metric for image.
+
+    mean(abs(a-b))
+
+    Args:
+
+        gt_key (str): Key of ground-truth. Default: 'gt_img'
+        pred_key (str): Key of prediction. Default: 'pred_img'
+        mask_key (str, optional): Key of mask. If mask_key is None, calculate
+            over all regions. Default: None
+        collect_device (str): Device name used for collecting results from
+            different ranks during distributed training. Must be 'cpu' or
+            'gpu'. Defaults to 'cpu'.
+        prefix (str, optional): The prefix that will be added in the metric
+            names to disambiguate homonymous metrics of different evaluators.
+            If prefix is not provided in the argument, self.default_prefix
+            will be used instead. Default: None
+
+    Metrics:
+        - MAE (float): Mean of Absolute Error
+    """
+
+    metric = 'MAE'
+
+    def process_image(self, gt, pred, mask):
+        """Process an image.
+
+        Args:
+            gt (Tensor | np.ndarray): GT image.
+            pred (Tensor | np.ndarray): Pred image.
+            mask (Tensor | np.ndarray): Mask of evaluation.
+        Returns:
+            result (np.ndarray): MAE result.
+        """
+
+        gt = gt / 255.
+        pred = pred / 255.
+
+        diff = gt - pred
+        diff = abs(diff)
+
+        if self.mask_key is not None:
+            diff *= mask  # broadcast for channel dimension
+            scale = np.prod(diff.shape) / np.prod(mask.shape)
+            result = diff.sum() / (mask.sum() * scale + 1e-12)
+        else:
+            result = diff.mean()
+
+        return result
diff --git a/mmedit/evaluation/metrics/matting.py b/mmedit/evaluation/metrics/matting.py
deleted file mode 100644
index ba4c6e30f9..0000000000
--- a/mmedit/evaluation/metrics/matting.py
+++ /dev/null
@@ -1,396 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
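
A quick sanity check of the masked-MAE scaling in the new mae.py above (a
standalone sketch; it assumes a single-channel mask broadcast over a
3-channel image, as in the metric's own tests):

    import numpy as np

    gt = np.ones((4, 4, 3)) * 255.    # ground truth
    pred = np.ones((4, 4, 3)) * 254.  # prediction, off by 1 everywhere
    mask = np.zeros((4, 4, 1))
    mask[:2] = 1                      # evaluate only the top half

    diff = np.abs(gt / 255. - pred / 255.)
    diff *= mask  # broadcast over the channel dimension
    # `scale` compensates for the mask having 1 channel while diff has 3
    scale = np.prod(diff.shape) / np.prod(mask.shape)
    print(diff.sum() / (mask.sum() * scale + 1e-12))  # 1/255 ~= 0.0039216
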
-"""Evaluation metrics used in Image Matting.""" - -from typing import List, Sequence - -import cv2 -import numpy as np -from mmengine.evaluator import BaseMetric - -from mmedit.registry import METRICS -from ..functional import gauss_gradient -from .metrics_utils import average - - -def _assert_ndim(input, name, ndim, shape_hint): - if input.ndim != ndim: - raise ValueError( - f'{name} should be of shape {shape_hint}, but got {input.shape}.') - - -def _assert_masked(pred_alpha, trimap): - if (pred_alpha[trimap == 0] != 0).any() or (pred_alpha[trimap == 255] != - 255).any(): - raise ValueError( - 'pred_alpha should be masked by trimap before evaluation') - - -def _fetch_data_and_check(data_samples): - """Fetch and check data from one item of data_batch and predictions. - - Args: - data_batch (dict): One item of data_batch. - predictions (dict): One item of predictions. - - Returns: - pred_alpha (Tensor): Pred_alpha data of predictions. - ori_alpha (Tensor): Ori_alpha data of data_batch. - ori_trimap (Tensor): Ori_trimap data of data_batch. - """ - ori_trimap = data_samples['ori_trimap'][:, :, 0] - ori_alpha = data_samples['ori_alpha'][:, :, 0] - pred_alpha = data_samples['output']['pred_alpha']['data'] # 2D tensor - pred_alpha = pred_alpha.cpu().numpy() - - _assert_ndim(ori_trimap, 'trimap', 2, 'HxW') - _assert_ndim(ori_alpha, 'gt_alpha', 2, 'HxW') - _assert_ndim(pred_alpha, 'pred_alpha', 2, 'HxW') - _assert_masked(pred_alpha, ori_trimap) - - # dtype uint8 -> float64 - pred_alpha = pred_alpha / 255.0 - ori_alpha = ori_alpha / 255.0 - # test shows that using float32 vs float64 differs final results at 1e-4 - # speed are comparable, so we choose float64 for accuracy - - return pred_alpha, ori_alpha, ori_trimap - - -@METRICS.register_module() -class SAD(BaseMetric): - """Sum of Absolute Differences metric for image matting. - - This metric compute per-pixel absolute difference and sum across all - pixels. - i.e. sum(abs(a-b)) / norm_const - - .. note:: - - Current implementation assume image / alpha / trimap array in numpy - format and with pixel value ranging from 0 to 255. - - .. note:: - - pred_alpha should be masked by trimap before passing - into this metric - - Default prefix: '' - - Args: - norm_const (int): Divide the result to reduce its magnitude. - Default to 1000. - - Metrics: - - SAD (float): Sum of Absolute Differences - """ - - default_prefix = '' - - def __init__( - self, - norm_const=1000, - **kwargs, - ) -> None: - self.norm_const = norm_const - super().__init__(**kwargs) - - def process(self, data_batch: Sequence[dict], - data_samples: Sequence[dict]) -> None: - """Process one batch of data and predictions. - - Args: - data_batch (Sequence[Tuple[Any, dict]]): A batch of data - from the dataloader. - predictions (Sequence[dict]): A batch of outputs from - the model. - """ - for data_sample in data_samples: - pred_alpha, gt_alpha, _ = _fetch_data_and_check(data_sample) - - # divide by 1000 to reduce the magnitude of the result - sad_sum = np.abs(pred_alpha - gt_alpha).sum() / self.norm_const - - result = {'sad': sad_sum} - - self.results.append(result) - - def compute_metrics(self, results: List): - """Compute the metrics from processed results. - - Args: - results (dict): The processed results of each batch. - - Returns: - Dict: The computed metrics. The keys are the names of the metrics, - and the values are corresponding results. 
- """ - - sad = average(results, 'sad') - - return {'SAD': sad} - - -@METRICS.register_module() -class MattingMSE(BaseMetric): - """Mean Squared Error metric for image matting. - - This metric compute per-pixel squared error average across all - pixels. - i.e. mean((a-b)^2) / norm_const - - .. note:: - - Current implementation assume image / alpha / trimap array in numpy - format and with pixel value ranging from 0 to 255. - - .. note:: - - pred_alpha should be masked by trimap before passing - into this metric - - Default prefix: '' - - Args: - norm_const (int): Divide the result to reduce its magnitude. - Default to 1000. - - Metrics: - - MattingMSE (float): Mean of Squared Error - """ - - default_prefix = '' - - def __init__( - self, - norm_const=1000, - **kwargs, - ) -> None: - self.norm_const = norm_const - super().__init__(**kwargs) - - def process(self, data_batch: Sequence[dict], - data_samples: Sequence[dict]) -> None: - """Process one batch of data and predictions. - - Args: - data_batch (Sequence[dict]): A batch of data - from the dataloader. - data_samples (Sequence[dict]): A batch of outputs from - the model. - """ - for data_sample in data_samples: - pred_alpha, gt_alpha, trimap = _fetch_data_and_check(data_sample) - - weight_sum = (trimap == 128).sum() - if weight_sum != 0: - mse_result = ((pred_alpha - gt_alpha)**2).sum() / weight_sum - else: - mse_result = 0 - - self.results.append({'mse': mse_result}) - - def compute_metrics(self, results: List): - """Compute the metrics from processed results. - - Args: - results (dict): The processed results of each batch. - - Returns: - Dict: The computed metrics. The keys are the names of the metrics, - and the values are corresponding results. - """ - - mse = average(results, 'mse') - - return {'MattingMSE': mse} - - -@METRICS.register_module() -class GradientError(BaseMetric): - """Gradient error for evaluating alpha matte prediction. - - .. note:: - - Current implementation assume image / alpha / trimap array in numpy - format and with pixel value ranging from 0 to 255. - - .. note:: - - pred_alpha should be masked by trimap before passing - into this metric - - Args: - sigma (float): Standard deviation of the gaussian kernel. - Defaults to 1.4 . - norm_const (int): Divide the result to reduce its magnitude. - Defaults to 1000 . - - Default prefix: '' - - Metrics: - - GradientError (float): Gradient Error - """ - - def __init__( - self, - sigma=1.4, - norm_constant=1000, - **kwargs, - ) -> None: - self.sigma = sigma - self.norm_constant = norm_constant - super().__init__(**kwargs) - - def process(self, data_batch: Sequence[dict], - data_samples: Sequence[dict]) -> None: - """Process one batch of data samples and predictions. The processed - results should be stored in ``self.results``, which will be used to - compute the metrics when all batches have been processed. - - Args: - data_batch (Sequence[dict]): A batch of data from the dataloader. - predictions (Sequence[dict]): A batch of outputs from - the model. 
- """ - - for data_sample in data_samples: - pred_alpha, gt_alpha, trimap = _fetch_data_and_check(data_sample) - - gt_alpha_normed = np.zeros_like(gt_alpha) - pred_alpha_normed = np.zeros_like(pred_alpha) - - cv2.normalize(gt_alpha, gt_alpha_normed, 1.0, 0.0, cv2.NORM_MINMAX) - cv2.normalize(pred_alpha, pred_alpha_normed, 1.0, 0.0, - cv2.NORM_MINMAX) - - gt_alpha_grad = gauss_gradient(gt_alpha_normed, self.sigma) - pred_alpha_grad = gauss_gradient(pred_alpha_normed, self.sigma) - # this is the sum over n samples - grad_loss = ((gt_alpha_grad - pred_alpha_grad)**2 * - (trimap == 128)).sum() - - # divide by 1000 to reduce the magnitude of the result - grad_loss /= self.norm_constant - - self.results.append({'grad_err': grad_loss}) - - def compute_metrics(self, results: List): - """Compute the metrics from processed results. - - Args: - results (dict): The processed results of each batch. - - Returns: - Dict: The computed metrics. The keys are the names of the metrics, - and the values are corresponding results. - """ - - grad_err = average(results, 'grad_err') - - return {'GradientError': grad_err} - - -@METRICS.register_module() -class ConnectivityError(BaseMetric): - """Connectivity error for evaluating alpha matte prediction. - - .. note:: - - Current implementation assume image / alpha / trimap array in numpy - format and with pixel value ranging from 0 to 255. - - .. note:: - - pred_alpha should be masked by trimap before passing - into this metric - - Args: - step (float): Step of threshold when computing intersection between - `alpha` and `pred_alpha`. Default to 0.1 . - norm_const (int): Divide the result to reduce its magnitude. - Default to 1000. - - Default prefix: '' - - Metrics: - - ConnectivityError (float): Connectivity Error - """ - - def __init__( - self, - step=0.1, - norm_constant=1000, - **kwargs, - ) -> None: - self.step = step - self.norm_constant = norm_constant - super().__init__(**kwargs) - - def process(self, data_batch: Sequence[dict], - data_samples: Sequence[dict]) -> None: - """Process one batch of data samples and predictions. The processed - results should be stored in ``self.results``, which will be used to - compute the metrics when all batches have been processed. - - Args: - data_batch (Sequence[dict]): A batch of data from the dataloader. - predictions (Sequence[dict]): A batch of outputs from - the model. 
- """ - - for data_sample in data_samples: - pred_alpha, gt_alpha, trimap = _fetch_data_and_check(data_sample) - - thresh_steps = np.arange(0, 1 + self.step, self.step) - round_down_map = -np.ones_like(gt_alpha) - for i in range(1, len(thresh_steps)): - gt_alpha_thresh = gt_alpha >= thresh_steps[i] - pred_alpha_thresh = pred_alpha >= thresh_steps[i] - intersection = gt_alpha_thresh & pred_alpha_thresh - intersection = intersection.astype(np.uint8) - - # connected components - _, output, stats, _ = cv2.connectedComponentsWithStats( - intersection, connectivity=4) - # start from 1 in dim 0 to exclude background - size = stats[1:, -1] - - # largest connected component of the intersection - omega = np.zeros_like(gt_alpha) - if len(size) != 0: - max_id = np.argmax(size) - # plus one to include background - omega[output == max_id + 1] = 1 - - mask = (round_down_map == -1) & (omega == 0) - round_down_map[mask] = thresh_steps[i - 1] - round_down_map[round_down_map == -1] = 1 - - gt_alpha_diff = gt_alpha - round_down_map - pred_alpha_diff = pred_alpha - round_down_map - # only calculate difference larger than or equal to 0.15 - gt_alpha_phi = 1 - gt_alpha_diff * (gt_alpha_diff >= 0.15) - pred_alpha_phi = 1 - pred_alpha_diff * (pred_alpha_diff >= 0.15) - - connectivity_error = np.sum( - np.abs(gt_alpha_phi - pred_alpha_phi) * (trimap == 128)) - - # divide by 1000 to reduce the magnitude of the result - connectivity_error /= self.norm_constant - - self.results.append({'conn_err': connectivity_error}) - - def compute_metrics(self, results: List): - """Compute the metrics from processed results. - - Args: - results (dict): The processed results of each batch. - - Returns: - Dict: The computed metrics. The keys are the names of the metrics, - and the values are corresponding results. - """ - - conn_err = average(results, 'conn_err') - - return {'ConnectivityError': conn_err} diff --git a/mmedit/evaluation/metrics/matting_mse.py b/mmedit/evaluation/metrics/matting_mse.py new file mode 100644 index 0000000000..d734c01bb2 --- /dev/null +++ b/mmedit/evaluation/metrics/matting_mse.py @@ -0,0 +1,82 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Sequence + +from mmengine.evaluator import BaseMetric + +from mmedit.registry import METRICS +from .metrics_utils import _fetch_data_and_check, average + + +@METRICS.register_module() +class MattingMSE(BaseMetric): + """Mean Squared Error metric for image matting. + + This metric compute per-pixel squared error average across all + pixels. + i.e. mean((a-b)^2) / norm_const + + .. note:: + + Current implementation assume image / alpha / trimap array in numpy + format and with pixel value ranging from 0 to 255. + + .. note:: + + pred_alpha should be masked by trimap before passing + into this metric + + Default prefix: '' + + Args: + norm_const (int): Divide the result to reduce its magnitude. + Default to 1000. + + Metrics: + - MattingMSE (float): Mean of Squared Error + """ + + default_prefix = '' + + def __init__( + self, + norm_const=1000, + **kwargs, + ) -> None: + self.norm_const = norm_const + super().__init__(**kwargs) + + def process(self, data_batch: Sequence[dict], + data_samples: Sequence[dict]) -> None: + """Process one batch of data and predictions. + + Args: + data_batch (Sequence[dict]): A batch of data + from the dataloader. + data_samples (Sequence[dict]): A batch of outputs from + the model. 
+ """ + for data_sample in data_samples: + pred_alpha, gt_alpha, trimap = _fetch_data_and_check(data_sample) + + weight_sum = (trimap == 128).sum() + if weight_sum != 0: + mse_result = ((pred_alpha - gt_alpha)**2).sum() / weight_sum + else: + mse_result = 0 + + self.results.append({'mse': mse_result}) + + def compute_metrics(self, results: List): + """Compute the metrics from processed results. + + Args: + results (dict): The processed results of each batch. + + Returns: + Dict: The computed metrics. The keys are the names of the metrics, + and the values are corresponding results. + """ + + mse = average(results, 'mse') + + return {'MattingMSE': mse} diff --git a/mmedit/evaluation/metrics/metrics_utils.py b/mmedit/evaluation/metrics/metrics_utils.py index 7519f54111..4ed5602a64 100644 --- a/mmedit/evaluation/metrics/metrics_utils.py +++ b/mmedit/evaluation/metrics/metrics_utils.py @@ -6,6 +6,50 @@ from mmedit.utils import reorder_image +def _assert_ndim(input, name, ndim, shape_hint): + if input.ndim != ndim: + raise ValueError( + f'{name} should be of shape {shape_hint}, but got {input.shape}.') + + +def _assert_masked(pred_alpha, trimap): + if (pred_alpha[trimap == 0] != 0).any() or (pred_alpha[trimap == 255] != + 255).any(): + raise ValueError( + 'pred_alpha should be masked by trimap before evaluation') + + +def _fetch_data_and_check(data_samples): + """Fetch and check data from one item of data_batch and predictions. + + Args: + data_batch (dict): One item of data_batch. + predictions (dict): One item of predictions. + + Returns: + pred_alpha (Tensor): Pred_alpha data of predictions. + ori_alpha (Tensor): Ori_alpha data of data_batch. + ori_trimap (Tensor): Ori_trimap data of data_batch. + """ + ori_trimap = data_samples['ori_trimap'][:, :, 0] + ori_alpha = data_samples['ori_alpha'][:, :, 0] + pred_alpha = data_samples['output']['pred_alpha']['data'] # 2D tensor + pred_alpha = pred_alpha.cpu().numpy() + + _assert_ndim(ori_trimap, 'trimap', 2, 'HxW') + _assert_ndim(ori_alpha, 'gt_alpha', 2, 'HxW') + _assert_ndim(pred_alpha, 'pred_alpha', 2, 'HxW') + _assert_masked(pred_alpha, ori_trimap) + + # dtype uint8 -> float64 + pred_alpha = pred_alpha / 255.0 + ori_alpha = ori_alpha / 255.0 + # test shows that using float32 vs float64 differs final results at 1e-4 + # speed are comparable, so we choose float64 for accuracy + + return pred_alpha, ori_alpha, ori_trimap + + def average(results, key): """Average of key in results(list[dict]). diff --git a/mmedit/evaluation/metrics/mse.py b/mmedit/evaluation/metrics/mse.py new file mode 100644 index 0000000000..863fbb97cf --- /dev/null +++ b/mmedit/evaluation/metrics/mse.py @@ -0,0 +1,57 @@ +# Copyright (c) OpenMMLab. All rights reserved. +"""Evaluation metrics based on pixels.""" + +from mmedit.registry import METRICS +from .base_sample_wise_metric import BaseSampleWiseMetric + + +@METRICS.register_module() +class MSE(BaseSampleWiseMetric): + """Mean Squared Error metric for image. + + mean((a-b)^2) + + Args: + + gt_key (str): Key of ground-truth. Default: 'gt_img' + pred_key (str): Key of prediction. Default: 'pred_img' + mask_key (str, optional): Key of mask, if mask_key is None, calculate + all regions. Default: None + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. 
+            If prefix is not provided in the argument, self.default_prefix
+            will be used instead. Default: None
+
+    Metrics:
+        - MSE (float): Mean of Squared Error
+    """
+
+    metric = 'MSE'
+
+    def process_image(self, gt, pred, mask):
+        """Process an image.
+
+        Args:
+            gt (Tensor | np.ndarray): GT image.
+            pred (Tensor | np.ndarray): Pred image.
+            mask (Tensor | np.ndarray): Mask of evaluation.
+        Returns:
+            result (np.ndarray): MSE result.
+        """
+
+        gt = gt / 255.
+        pred = pred / 255.
+
+        diff = gt - pred
+        diff *= diff
+
+        if self.mask_key is not None:
+            diff *= mask
+            result = diff.sum() / mask.sum()
+        else:
+            result = diff.mean()
+
+        return result
diff --git a/mmedit/evaluation/metrics/pixel_metrics.py b/mmedit/evaluation/metrics/pixel_metrics.py
deleted file mode 100644
index 19479bc4a6..0000000000
--- a/mmedit/evaluation/metrics/pixel_metrics.py
+++ /dev/null
@@ -1,360 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-"""Evaluation metrics based on pixels."""
-
-from typing import Optional
-
-import numpy as np
-
-from mmedit.registry import METRICS
-from .base_sample_wise_metric import BaseSampleWiseMetric
-from .metrics_utils import img_transform
-
-
-@METRICS.register_module()
-class MAE(BaseSampleWiseMetric):
-    """Mean Absolute Error metric for image.
-
-    mean(abs(a-b))
-
-    Args:
-
-        gt_key (str): Key of ground-truth. Default: 'gt_img'
-        pred_key (str): Key of prediction. Default: 'pred_img'
-        mask_key (str, optional): Key of mask, if mask_key is None, calculate
-            all regions. Default: None
-        collect_device (str): Device name used for collecting results from
-            different ranks during distributed training. Must be 'cpu' or
-            'gpu'. Defaults to 'cpu'.
-        prefix (str, optional): The prefix that will be added in the metric
-            names to disambiguate homonymous metrics of different evaluators.
-            If prefix is not provided in the argument, self.default_prefix
-            will be used instead. Default: None
-
-    Metrics:
-        - MAE (float): Mean of Absolute Error
-    """
-
-    metric = 'MAE'
-
-    def process_image(self, gt, pred, mask):
-        """Process an image.
-
-        Args:
-            gt (Tensor | np.ndarray): GT image.
-            pred (Tensor | np.ndarray): Pred image.
-            mask (Tensor | np.ndarray): Mask of evaluation.
-        Returns:
-            result (np.ndarray): MAE result.
-        """
-
-        gt = gt / 255.
-        pred = pred / 255.
-
-        diff = gt - pred
-        diff = abs(diff)
-
-        if self.mask_key is not None:
-            diff *= mask  # broadcast for channel dimension
-            scale = np.prod(diff.shape) / np.prod(mask.shape)
-            result = diff.sum() / (mask.sum() * scale + 1e-12)
-        else:
-            result = diff.mean()
-
-        return result
-
-
-@METRICS.register_module()
-class MSE(BaseSampleWiseMetric):
-    """Mean Squared Error metric for image.
-
-    mean((a-b)^2)
-
-    Args:
-
-        gt_key (str): Key of ground-truth. Default: 'gt_img'
-        pred_key (str): Key of prediction. Default: 'pred_img'
-        mask_key (str, optional): Key of mask, if mask_key is None, calculate
-            all regions. Default: None
-        collect_device (str): Device name used for collecting results from
-            different ranks during distributed training. Must be 'cpu' or
-            'gpu'. Defaults to 'cpu'.
-        prefix (str, optional): The prefix that will be added in the metric
-            names to disambiguate homonymous metrics of different evaluators.
-            If prefix is not provided in the argument, self.default_prefix
-            will be used instead. Default: None
-
-    Metrics:
-        - MSE (float): Mean of Squared Error
-    """
-
-    metric = 'MSE'
-
-    def process_image(self, gt, pred, mask):
-        """Process an image.
-
-        Args:
-            gt (Torch | np.ndarray): GT image.
-            pred (Torch | np.ndarray): Pred image.
-            mask (Torch | np.ndarray): Mask of evaluation.
-        Returns:
-            result (np.ndarray): MSE result.
-        """
-
-        gt = gt / 255.
-        pred = pred / 255.
-
-        diff = gt - pred
-        diff *= diff
-
-        if self.mask_key is not None:
-            diff *= mask
-            result = diff.sum() / mask.sum()
-        else:
-            result = diff.mean()
-
-        return result
-
-
-@METRICS.register_module()
-class PSNR(BaseSampleWiseMetric):
-    """Peak Signal-to-Noise Ratio.
-
-    Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
-
-    Args:
-
-        gt_key (str): Key of ground-truth. Default: 'gt_img'
-        pred_key (str): Key of prediction. Default: 'pred_img'
-        collect_device (str): Device name used for collecting results from
-            different ranks during distributed training. Must be 'cpu' or
-            'gpu'. Defaults to 'cpu'.
-        prefix (str, optional): The prefix that will be added in the metric
-            names to disambiguate homonymous metrics of different evaluators.
-            If prefix is not provided in the argument, self.default_prefix
-            will be used instead. Default: None
-        crop_border (int): Cropped pixels in each edges of an image. These
-            pixels are not involved in the PSNR calculation. Default: 0.
-        input_order (str): Whether the input order is 'HWC' or 'CHW'.
-            Default: 'CHW'.
-        convert_to (str): Whether to convert the images to other color models.
-            If None, the images are not altered. When computing for 'Y',
-            the images are assumed to be in BGR order. Options are 'Y' and
-            None. Default: None.
-
-    Metrics:
-        - PSNR (float): Peak Signal-to-Noise Ratio
-    """
-
-    metric = 'PSNR'
-
-    def __init__(self,
-                 gt_key: str = 'gt_img',
-                 pred_key: str = 'pred_img',
-                 collect_device: str = 'cpu',
-                 prefix: Optional[str] = None,
-                 crop_border=0,
-                 input_order='CHW',
-                 convert_to=None) -> None:
-        super().__init__(
-            gt_key=gt_key,
-            pred_key=pred_key,
-            mask_key=None,
-            collect_device=collect_device,
-            prefix=prefix)
-
-        self.crop_border = crop_border
-        self.input_order = input_order
-        self.convert_to = convert_to
-
-    def process_image(self, gt, pred, mask):
-        """Process an image.
-
-        Args:
-            gt (Torch | np.ndarray): GT image.
-            pred (Torch | np.ndarray): Pred image.
-            mask (Torch | np.ndarray): Mask of evaluation.
-        Returns:
-            np.ndarray: PSNR result.
-        """
-
-        return psnr(
-            img1=gt,
-            img2=pred,
-            crop_border=self.crop_border,
-            input_order=self.input_order,
-            convert_to=self.convert_to,
-            channel_order=self.channel_order)
-
-
-@METRICS.register_module()
-class SNR(BaseSampleWiseMetric):
-    """Signal-to-Noise Ratio.
-
-    Ref: https://en.wikipedia.org/wiki/Signal-to-noise_ratio
-
-    Args:
-
-        gt_key (str): Key of ground-truth. Default: 'gt_img'
-        pred_key (str): Key of prediction. Default: 'pred_img'
-        collect_device (str): Device name used for collecting results from
-            different ranks during distributed training. Must be 'cpu' or
-            'gpu'. Defaults to 'cpu'.
-        prefix (str, optional): The prefix that will be added in the metric
-            names to disambiguate homonymous metrics of different evaluators.
-            If prefix is not provided in the argument, self.default_prefix
-            will be used instead. Default: None
-        crop_border (int): Cropped pixels in each edges of an image. These
-            pixels are not involved in the SNR calculation. Default: 0.
-        input_order (str): Whether the input order is 'HWC' or 'CHW'.
-            Default: 'CHW'.
-        convert_to (str): Whether to convert the images to other color models.
-            If None, the images are not altered. When computing for 'Y',
-            the images are assumed to be in BGR order. Options are 'Y' and
-            None. Default: None.
-
-    Metrics:
-        - SNR (float): Signal-to-Noise Ratio
-    """
-
-    metric = 'SNR'
-
-    def __init__(self,
-                 gt_key: str = 'gt_img',
-                 pred_key: str = 'pred_img',
-                 collect_device: str = 'cpu',
-                 prefix: Optional[str] = None,
-                 crop_border=0,
-                 input_order='CHW',
-                 convert_to=None) -> None:
-        super().__init__(
-            gt_key=gt_key,
-            pred_key=pred_key,
-            mask_key=None,
-            collect_device=collect_device,
-            prefix=prefix)
-
-        self.crop_border = crop_border
-        self.input_order = input_order
-        self.convert_to = convert_to
-
-    def process_image(self, gt, pred, mask):
-        """Process an image.
-
-        Args:
-            gt (Torch | np.ndarray): GT image.
-            pred (Torch | np.ndarray): Pred image.
-            mask (Torch | np.ndarray): Mask of evaluation.
-        Returns:
-            np.ndarray: SNR result.
-        """
-
-        return snr(
-            gt=gt,
-            pred=pred,
-            crop_border=self.crop_border,
-            input_order=self.input_order,
-            convert_to=self.convert_to,
-            channel_order=self.channel_order)
-
-
-def psnr(img1,
-         img2,
-         crop_border=0,
-         input_order='HWC',
-         convert_to=None,
-         channel_order='rgb'):
-    """Calculate PSNR (Peak Signal-to-Noise Ratio).
-
-    Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
-
-    Args:
-        img1 (ndarray): Images with range [0, 255].
-        img2 (ndarray): Images with range [0, 255].
-        crop_border (int): Cropped pixels in each edges of an image. These
-            pixels are not involved in the PSNR calculation. Default: 0.
-        input_order (str): Whether the input order is 'HWC' or 'CHW'.
-            Default: 'HWC'.
-        convert_to (str): Whether to convert the images to other color models.
-            If None, the images are not altered. When computing for 'Y',
-            the images are assumed to be in BGR order. Options are 'Y' and
-            None. Default: None.
-        channel_order (str): The channel order of image. Default: 'rgb'.
-
-    Returns:
-        result (float): PSNR result.
-    """
-
-    assert img1.shape == img2.shape, (
-        f'Image shapes are different: {img1.shape}, {img2.shape}.')
-
-    img1 = img_transform(
-        img1,
-        crop_border=crop_border,
-        input_order=input_order,
-        convert_to=convert_to,
-        channel_order=channel_order)
-    img2 = img_transform(
-        img2,
-        crop_border=crop_border,
-        input_order=input_order,
-        convert_to=convert_to,
-        channel_order=channel_order)
-
-    mse_value = ((img1 - img2)**2).mean()
-    if mse_value == 0:
-        result = float('inf')
-    else:
-        result = 20. * np.log10(255. / np.sqrt(mse_value))
-
-    return result
-
-
-def snr(gt,
-        pred,
-        crop_border=0,
-        input_order='HWC',
-        convert_to=None,
-        channel_order='rgb'):
-    """Calculate PSNR (Peak Signal-to-Noise Ratio).
-
-    Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
-
-    Args:
-        gt (ndarray): Images with range [0, 255].
-        pred (ndarray): Images with range [0, 255].
-        crop_border (int): Cropped pixels in each edges of an image. These
-            pixels are not involved in the PSNR calculation. Default: 0.
-        input_order (str): Whether the input order is 'HWC' or 'CHW'.
-            Default: 'HWC'.
-        convert_to (str): Whether to convert the images to other color models.
-            If None, the images are not altered. When computing for 'Y',
-            the images are assumed to be in BGR order. Options are 'Y' and
-            None. Default: None.
-        channel_order (str): The channel order of image. Default: 'rgb'.
-
-    Returns:
-        float: SNR result.
- """ - - assert gt.shape == pred.shape, ( - f'Image shapes are different: {gt.shape}, {pred.shape}.') - - gt = img_transform( - gt, - crop_border=crop_border, - input_order=input_order, - convert_to=convert_to, - channel_order=channel_order) - pred = img_transform( - pred, - crop_border=crop_border, - input_order=input_order, - convert_to=convert_to, - channel_order=channel_order) - - signal = ((gt)**2).mean() - noise = ((gt - pred)**2).mean() - - result = 10. * np.log10(signal / noise) - - return result diff --git a/mmedit/evaluation/metrics/psnr.py b/mmedit/evaluation/metrics/psnr.py new file mode 100644 index 0000000000..9aec992b1a --- /dev/null +++ b/mmedit/evaluation/metrics/psnr.py @@ -0,0 +1,131 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + +import numpy as np + +from mmedit.registry import METRICS +from .base_sample_wise_metric import BaseSampleWiseMetric +from .metrics_utils import img_transform + + +@METRICS.register_module() +class PSNR(BaseSampleWiseMetric): + """Peak Signal-to-Noise Ratio. + + Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio + + Args: + + gt_key (str): Key of ground-truth. Default: 'gt_img' + pred_key (str): Key of prediction. Default: 'pred_img' + collect_device (str): Device name used for collecting results from + different ranks during distributed training. Must be 'cpu' or + 'gpu'. Defaults to 'cpu'. + prefix (str, optional): The prefix that will be added in the metric + names to disambiguate homonymous metrics of different evaluators. + If prefix is not provided in the argument, self.default_prefix + will be used instead. Default: None + crop_border (int): Cropped pixels in each edges of an image. These + pixels are not involved in the PSNR calculation. Default: 0. + input_order (str): Whether the input order is 'HWC' or 'CHW'. + Default: 'CHW'. + convert_to (str): Whether to convert the images to other color models. + If None, the images are not altered. When computing for 'Y', + the images are assumed to be in BGR order. Options are 'Y' and + None. Default: None. + + Metrics: + - PSNR (float): Peak Signal-to-Noise Ratio + """ + + metric = 'PSNR' + + def __init__(self, + gt_key: str = 'gt_img', + pred_key: str = 'pred_img', + collect_device: str = 'cpu', + prefix: Optional[str] = None, + crop_border=0, + input_order='CHW', + convert_to=None) -> None: + super().__init__( + gt_key=gt_key, + pred_key=pred_key, + mask_key=None, + collect_device=collect_device, + prefix=prefix) + + self.crop_border = crop_border + self.input_order = input_order + self.convert_to = convert_to + + def process_image(self, gt, pred, mask): + """Process an image. + + Args: + gt (Torch | np.ndarray): GT image. + pred (Torch | np.ndarray): Pred image. + mask (Torch | np.ndarray): Mask of evaluation. + Returns: + np.ndarray: PSNR result. + """ + + return psnr( + img1=gt, + img2=pred, + crop_border=self.crop_border, + input_order=self.input_order, + convert_to=self.convert_to, + channel_order=self.channel_order) + + +def psnr(img1, + img2, + crop_border=0, + input_order='HWC', + convert_to=None, + channel_order='rgb'): + """Calculate PSNR (Peak Signal-to-Noise Ratio). + + Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio + + Args: + img1 (ndarray): Images with range [0, 255]. + img2 (ndarray): Images with range [0, 255]. + crop_border (int): Cropped pixels in each edges of an image. These + pixels are not involved in the PSNR calculation. Default: 0. + input_order (str): Whether the input order is 'HWC' or 'CHW'. 
+            Default: 'HWC'.
+        convert_to (str): Whether to convert the images to other color models.
+            If None, the images are not altered. When computing for 'Y',
+            the images are assumed to be in BGR order. Options are 'Y' and
+            None. Default: None.
+        channel_order (str): The channel order of image. Default: 'rgb'.
+
+    Returns:
+        result (float): PSNR result.
+    """
+
+    assert img1.shape == img2.shape, (
+        f'Image shapes are different: {img1.shape}, {img2.shape}.')
+
+    img1 = img_transform(
+        img1,
+        crop_border=crop_border,
+        input_order=input_order,
+        convert_to=convert_to,
+        channel_order=channel_order)
+    img2 = img_transform(
+        img2,
+        crop_border=crop_border,
+        input_order=input_order,
+        convert_to=convert_to,
+        channel_order=channel_order)
+
+    mse_value = ((img1 - img2)**2).mean()
+    if mse_value == 0:
+        result = float('inf')
+    else:
+        result = 20. * np.log10(255. / np.sqrt(mse_value))
+
+    return result
diff --git a/mmedit/evaluation/metrics/sad.py b/mmedit/evaluation/metrics/sad.py
new file mode 100644
index 0000000000..05abb6153f
--- /dev/null
+++ b/mmedit/evaluation/metrics/sad.py
@@ -0,0 +1,82 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import List, Sequence
+
+import numpy as np
+from mmengine.evaluator import BaseMetric
+
+from mmedit.registry import METRICS
+from .metrics_utils import _fetch_data_and_check, average
+
+
+@METRICS.register_module()
+class SAD(BaseMetric):
+    """Sum of Absolute Differences metric for image matting.
+
+    This metric computes the per-pixel absolute difference and sums it
+    across all pixels,
+    i.e. sum(abs(a-b)) / norm_const
+
+    .. note::
+
+        Current implementation assumes image / alpha / trimap arrays in
+        numpy format with pixel values ranging from 0 to 255.
+
+    .. note::
+
+        pred_alpha should be masked by trimap before passing
+        into this metric.
+
+    Default prefix: ''
+
+    Args:
+        norm_const (int): Divide the result to reduce its magnitude.
+            Defaults to 1000.
+
+    Metrics:
+        - SAD (float): Sum of Absolute Differences
+    """
+
+    default_prefix = ''
+
+    def __init__(
+        self,
+        norm_const=1000,
+        **kwargs,
+    ) -> None:
+        self.norm_const = norm_const
+        super().__init__(**kwargs)
+
+    def process(self, data_batch: Sequence[dict],
+                data_samples: Sequence[dict]) -> None:
+        """Process one batch of data and predictions.
+
+        Args:
+            data_batch (Sequence[dict]): A batch of data
+                from the dataloader.
+            data_samples (Sequence[dict]): A batch of outputs from
+                the model.
+        """
+        for data_sample in data_samples:
+            pred_alpha, gt_alpha, _ = _fetch_data_and_check(data_sample)
+
+            # divide by norm_const to reduce the magnitude of the result
+            sad_sum = np.abs(pred_alpha - gt_alpha).sum() / self.norm_const
+
+            result = {'sad': sad_sum}
+
+            self.results.append(result)
+
+    def compute_metrics(self, results: List):
+        """Compute the metrics from processed results.
+
+        Args:
+            results (List): The processed results of each batch.
+
+        Returns:
+            Dict: The computed metrics. The keys are the names of the metrics,
+            and the values are corresponding results.
+        """
+
+        sad = average(results, 'sad')
+
+        return {'SAD': sad}
diff --git a/mmedit/evaluation/metrics/snr.py b/mmedit/evaluation/metrics/snr.py
new file mode 100644
index 0000000000..e94d35c4f6
--- /dev/null
+++ b/mmedit/evaluation/metrics/snr.py
@@ -0,0 +1,130 @@
+# Copyright (c) OpenMMLab. All rights reserved.
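+"""Signal-to-Noise Ratio metric (split from the former pixel_metrics.py)."""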
+from typing import Optional
+
+import numpy as np
+
+from mmedit.registry import METRICS
+from .base_sample_wise_metric import BaseSampleWiseMetric
+from .metrics_utils import img_transform
+
+
+@METRICS.register_module()
+class SNR(BaseSampleWiseMetric):
+    """Signal-to-Noise Ratio.
+
+    Ref: https://en.wikipedia.org/wiki/Signal-to-noise_ratio
+
+    Args:
+
+        gt_key (str): Key of ground-truth. Default: 'gt_img'
+        pred_key (str): Key of prediction. Default: 'pred_img'
+        collect_device (str): Device name used for collecting results from
+            different ranks during distributed training. Must be 'cpu' or
+            'gpu'. Defaults to 'cpu'.
+        prefix (str, optional): The prefix that will be added in the metric
+            names to disambiguate homonymous metrics of different evaluators.
+            If prefix is not provided in the argument, self.default_prefix
+            will be used instead. Default: None
+        crop_border (int): Cropped pixels in each edges of an image. These
+            pixels are not involved in the SNR calculation. Default: 0.
+        input_order (str): Whether the input order is 'HWC' or 'CHW'.
+            Default: 'CHW'.
+        convert_to (str): Whether to convert the images to other color models.
+            If None, the images are not altered. When computing for 'Y',
+            the images are assumed to be in BGR order. Options are 'Y' and
+            None. Default: None.
+
+    Metrics:
+        - SNR (float): Signal-to-Noise Ratio
+    """
+
+    metric = 'SNR'
+
+    def __init__(self,
+                 gt_key: str = 'gt_img',
+                 pred_key: str = 'pred_img',
+                 collect_device: str = 'cpu',
+                 prefix: Optional[str] = None,
+                 crop_border=0,
+                 input_order='CHW',
+                 convert_to=None) -> None:
+        super().__init__(
+            gt_key=gt_key,
+            pred_key=pred_key,
+            mask_key=None,
+            collect_device=collect_device,
+            prefix=prefix)
+
+        self.crop_border = crop_border
+        self.input_order = input_order
+        self.convert_to = convert_to
+
+    def process_image(self, gt, pred, mask):
+        """Process an image.
+
+        Args:
+            gt (Tensor | np.ndarray): GT image.
+            pred (Tensor | np.ndarray): Pred image.
+            mask (Tensor | np.ndarray): Mask of evaluation.
+        Returns:
+            np.ndarray: SNR result.
+        """
+
+        return snr(
+            gt=gt,
+            pred=pred,
+            crop_border=self.crop_border,
+            input_order=self.input_order,
+            convert_to=self.convert_to,
+            channel_order=self.channel_order)
+
+
+def snr(gt,
+        pred,
+        crop_border=0,
+        input_order='HWC',
+        convert_to=None,
+        channel_order='rgb'):
+    """Calculate SNR (Signal-to-Noise Ratio).
+
+    Ref: https://en.wikipedia.org/wiki/Signal-to-noise_ratio
+
+    Args:
+        gt (ndarray): Images with range [0, 255].
+        pred (ndarray): Images with range [0, 255].
+        crop_border (int): Cropped pixels in each edges of an image. These
+            pixels are not involved in the SNR calculation. Default: 0.
+        input_order (str): Whether the input order is 'HWC' or 'CHW'.
+            Default: 'HWC'.
+        convert_to (str): Whether to convert the images to other color models.
+            If None, the images are not altered. When computing for 'Y',
+            the images are assumed to be in BGR order. Options are 'Y' and
+            None. Default: None.
+        channel_order (str): The channel order of image. Default: 'rgb'.
+
+    Returns:
+        float: SNR result.
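+
+    Example (a quick sketch; the numbers match this metric's unit test):
+
+        >>> import numpy as np
+        >>> gt = np.ones((32, 32, 3)) * 2
+        >>> pred = np.ones((32, 32, 3))
+        >>> round(snr(gt, pred, input_order='HWC'), 4)  # 10 * log10(4 / 1)
+        6.0206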
+ """ + + assert gt.shape == pred.shape, ( + f'Image shapes are different: {gt.shape}, {pred.shape}.') + + gt = img_transform( + gt, + crop_border=crop_border, + input_order=input_order, + convert_to=convert_to, + channel_order=channel_order) + pred = img_transform( + pred, + crop_border=crop_border, + input_order=input_order, + convert_to=convert_to, + channel_order=channel_order) + + signal = ((gt)**2).mean() + noise = ((gt - pred)**2).mean() + + result = 10. * np.log10(signal / noise) + + return result diff --git a/tests/test_evaluation/test_metrics/test_connectivity_error.py b/tests/test_evaluation/test_metrics/test_connectivity_error.py new file mode 100644 index 0000000000..29ce9be702 --- /dev/null +++ b/tests/test_evaluation/test_metrics/test_connectivity_error.py @@ -0,0 +1,108 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from pathlib import Path + +import numpy as np +import pytest +import torch + +from mmedit.datasets.transforms import LoadImageFromFile +from mmedit.evaluation.metrics import ConnectivityError + + +class TestMattingMetrics: + + @classmethod + def setup_class(cls): + # Make sure these values are immutable across different test cases. + + # This test depends on the interface of loading + # if loading is changed, data should be change accordingly. + test_path = Path(__file__).parent.parent.parent + alpha_path = ( + test_path / 'data' / 'matting_dataset' / 'alpha' / 'GT05.jpg') + + results = dict(alpha_path=alpha_path) + config = dict(key='alpha') + image_loader = LoadImageFromFile(**config) + results = image_loader(results) + assert results['alpha'].ndim == 3 + + gt_alpha = np.ones((32, 32), dtype=np.uint8) * 255 + trimap = np.zeros((32, 32), dtype=np.uint8) + trimap[:16, :16] = 128 + trimap[16:, 16:] = 255 + # non-masked pred_alpha + pred_alpha = torch.zeros((32, 32), dtype=torch.uint8) + # masked pred_alpha + masked_pred_alpha = pred_alpha.clone() + masked_pred_alpha[trimap == 0] = 0 + masked_pred_alpha[trimap == 255] = 255 + + gt_alpha = gt_alpha[..., None] + trimap = trimap[..., None] + # pred_alpha = pred_alpha.unsqueeze(0) + # masked_pred_alpha = masked_pred_alpha.unsqueeze(0) + + cls.data_batch = [{ + 'inputs': [], + 'data_samples': { + 'ori_trimap': trimap, + 'ori_alpha': gt_alpha, + }, + }] + + cls.data_samples = [d_['data_samples'] for d_ in cls.data_batch] + + cls.bad_preds1_ = [{'pred_alpha': dict(data=pred_alpha)}] + # pred_alpha should be masked by trimap before evaluation + cls.bad_preds1 = copy.deepcopy(cls.data_samples) + for d, p in zip(cls.bad_preds1, cls.bad_preds1_): + d['output'] = p + + cls.bad_preds2_ = [{'pred_alpha': dict(data=pred_alpha[0])}] + # pred_alpha should be 3 dimensional + cls.bad_preds2 = copy.deepcopy(cls.data_samples) + for d, p in zip(cls.bad_preds2, cls.bad_preds2_): + d['output'] = p + + cls.good_preds_ = [{'pred_alpha': dict(data=masked_pred_alpha)}] + cls.good_preds = copy.deepcopy((cls.data_samples)) + for d, p in zip(cls.good_preds, cls.good_preds_): + d['output'] = p + + def test_connectivity_error(self): + """Test connectivity error for evaluating predicted alpha matte.""" + + data_batch, bad_pred1, bad_pred2, good_pred = ( + self.data_batch, + self.bad_preds1, + self.bad_preds2, + self.good_preds, + ) + + conn_err = ConnectivityError() + + with pytest.raises(ValueError): + conn_err.process(data_batch, bad_pred1) + + with pytest.raises(ValueError): + conn_err.process(data_batch, bad_pred2) + + # process 2 batches + conn_err.process(data_batch, good_pred) + conn_err.process(data_batch, good_pred) + 
+        assert conn_err.results == [
+            {
+                'conn_err': 0.256,
+            },
+            {
+                'conn_err': 0.256,
+            },
+        ]
+
+        res = conn_err.compute_metrics(conn_err.results)
+
+        assert list(res.keys()) == ['ConnectivityError']
+        assert np.allclose(res['ConnectivityError'], 0.256)
diff --git a/tests/test_evaluation/test_metrics/test_matting.py b/tests/test_evaluation/test_metrics/test_gradient_error.py
similarity index 54%
rename from tests/test_evaluation/test_metrics/test_matting.py
rename to tests/test_evaluation/test_metrics/test_gradient_error.py
index 6777912a3b..f6720585c5 100644
--- a/tests/test_evaluation/test_metrics/test_matting.py
+++ b/tests/test_evaluation/test_metrics/test_gradient_error.py
@@ -7,8 +7,7 @@
 import torch
 
 from mmedit.datasets.transforms import LoadImageFromFile
-from mmedit.evaluation.metrics import (SAD, ConnectivityError, GradientError,
-                                       MattingMSE)
+from mmedit.evaluation.metrics import GradientError
 
 
 class TestMattingMetrics:
 
@@ -72,78 +71,6 @@ def setup_class(cls):
         for d, p in zip(cls.good_preds, cls.good_preds_):
             d['output'] = p
 
-    def test_sad(self):
-        """Test SAD for evaluating predicted alpha matte."""
-
-        data_batch, bad_pred1, bad_pred2, good_pred = (
-            self.data_batch,
-            self.bad_preds1,
-            self.bad_preds2,
-            self.good_preds,
-        )
-
-        sad = SAD()
-
-        with pytest.raises(ValueError):
-            sad.process(data_batch, bad_pred1)
-
-        with pytest.raises(ValueError):
-            sad.process(data_batch, bad_pred2)
-
-        # process 2 batches
-        sad.process(data_batch, good_pred)
-        sad.process(data_batch, good_pred)
-
-        assert sad.results == [
-            {
-                'sad': 0.768,
-            },
-            {
-                'sad': 0.768,
-            },
-        ]
-
-        res = sad.compute_metrics(sad.results)
-
-        assert list(res.keys()) == ['SAD']
-        np.testing.assert_almost_equal(res['SAD'], 0.768)
-
-    def test_mse(self):
-        """Test MattingMSE for evaluating predicted alpha matte."""
-
-        data_batch, bad_pred1, bad_pred2, good_pred = (
-            self.data_batch,
-            self.bad_preds1,
-            self.bad_preds2,
-            self.good_preds,
-        )
-
-        mse = MattingMSE()
-
-        with pytest.raises(ValueError):
-            mse.process(data_batch, bad_pred1)
-
-        with pytest.raises(ValueError):
-            mse.process(data_batch, bad_pred2)
-
-        # process 2 batches
-        mse.process(data_batch, good_pred)
-        mse.process(data_batch, good_pred)
-
-        assert mse.results == [
-            {
-                'mse': 3.0,
-            },
-            {
-                'mse': 3.0,
-            },
-        ]
-
-        res = mse.compute_metrics(mse.results)
-
-        assert list(res.keys()) == ['MattingMSE']
-        np.testing.assert_almost_equal(res['MattingMSE'], 3.0)
-
     def test_gradient_error(self):
         """Test gradient error for evaluating predicted alpha matte."""
 
@@ -176,39 +103,3 @@ def test_gradient_error(self):
 
         assert list(res.keys()) == ['GradientError']
         np.testing.assert_almost_equal(el['grad_err'], 0.0028887)
         # assert np.allclose(res['GradientError'], 0.0028887)
-
-    def test_connectivity_error(self):
-        """Test connectivity error for evaluating predicted alpha matte."""
-
-        data_batch, bad_pred1, bad_pred2, good_pred = (
-            self.data_batch,
-            self.bad_preds1,
-            self.bad_preds2,
-            self.good_preds,
-        )
-
-        conn_err = ConnectivityError()
-
-        with pytest.raises(ValueError):
-            conn_err.process(data_batch, bad_pred1)
-
-        with pytest.raises(ValueError):
-            conn_err.process(data_batch, bad_pred2)
-
-        # process 2 batches
-        conn_err.process(data_batch, good_pred)
-        conn_err.process(data_batch, good_pred)
-
-        assert conn_err.results == [
-            {
-                'conn_err': 0.256,
-            },
-            {
-                'conn_err': 0.256,
-            },
-        ]
-
-        res = conn_err.compute_metrics(conn_err.results)
-
-        assert list(res.keys()) == ['ConnectivityError']
-        assert np.allclose(res['ConnectivityError'], 0.256)
diff --git a/tests/test_evaluation/test_metrics/test_mae.py b/tests/test_evaluation/test_metrics/test_mae.py
new file mode 100644
index 0000000000..ef03773d1b
--- /dev/null
+++ b/tests/test_evaluation/test_metrics/test_mae.py
@@ -0,0 +1,51 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from copy import deepcopy
+
+import numpy as np
+import torch
+
+from mmedit.evaluation.metrics import MAE
+
+
+class TestPixelMetrics:
+
+    @classmethod
+    def setup_class(cls):
+
+        mask = np.ones((32, 32, 3)) * 2
+        mask[:16] *= 0
+        gt = np.ones((32, 32, 3)) * 2
+        data_sample = dict(gt_img=gt, mask=mask, gt_channel_order='bgr')
+        cls.data_batch = [dict(data_samples=data_sample)]
+        cls.predictions = [dict(pred_img=np.ones((32, 32, 3)))]
+
+        cls.data_batch.append(
+            dict(
+                data_samples=dict(
+                    gt_img=torch.from_numpy(gt),
+                    mask=torch.from_numpy(mask),
+                    img_channel_order='bgr')))
+        cls.predictions.append({
+            k: torch.from_numpy(deepcopy(v))
+            for (k, v) in cls.predictions[0].items()
+        })
+
+        for d, p in zip(cls.data_batch, cls.predictions):
+            d['output'] = p
+        cls.predictions = cls.data_batch
+
+    def test_mae(self):
+
+        # Single MAE
+        mae = MAE()
+        mae.process(self.data_batch, self.predictions)
+        result = mae.compute_metrics(mae.results)
+        assert 'MAE' in result
+        np.testing.assert_almost_equal(result['MAE'], 0.003921568627)
+
+        # Masked MAE
+        mae = MAE(mask_key='mask', prefix='MAE')
+        mae.process(self.data_batch, self.predictions)
+        result = mae.compute_metrics(mae.results)
+        assert 'MAE' in result
+        np.testing.assert_almost_equal(result['MAE'], 0.003921568627)
diff --git a/tests/test_evaluation/test_metrics/test_matting_mse.py b/tests/test_evaluation/test_metrics/test_matting_mse.py
new file mode 100644
index 0000000000..532f248928
--- /dev/null
+++ b/tests/test_evaluation/test_metrics/test_matting_mse.py
@@ -0,0 +1,108 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+from pathlib import Path
+
+import numpy as np
+import pytest
+import torch
+
+from mmedit.datasets.transforms import LoadImageFromFile
+from mmedit.evaluation.metrics import MattingMSE
+
+
+class TestMattingMetrics:
+
+    @classmethod
+    def setup_class(cls):
+        # Make sure these values are immutable across different test cases.
+
+        # This test depends on the interface of loading;
+        # if loading is changed, data should be changed accordingly.
+        test_path = Path(__file__).parent.parent.parent
+        alpha_path = (
+            test_path / 'data' / 'matting_dataset' / 'alpha' / 'GT05.jpg')
+
+        results = dict(alpha_path=alpha_path)
+        config = dict(key='alpha')
+        image_loader = LoadImageFromFile(**config)
+        results = image_loader(results)
+        assert results['alpha'].ndim == 3
+
+        gt_alpha = np.ones((32, 32), dtype=np.uint8) * 255
+        trimap = np.zeros((32, 32), dtype=np.uint8)
+        trimap[:16, :16] = 128
+        trimap[16:, 16:] = 255
+        # non-masked pred_alpha
+        pred_alpha = torch.zeros((32, 32), dtype=torch.uint8)
+        # masked pred_alpha
+        masked_pred_alpha = pred_alpha.clone()
+        masked_pred_alpha[trimap == 0] = 0
+        masked_pred_alpha[trimap == 255] = 255
+
+        gt_alpha = gt_alpha[..., None]
+        trimap = trimap[..., None]
+        # pred_alpha = pred_alpha.unsqueeze(0)
+        # masked_pred_alpha = masked_pred_alpha.unsqueeze(0)
+
+        cls.data_batch = [{
+            'inputs': [],
+            'data_samples': {
+                'ori_trimap': trimap,
+                'ori_alpha': gt_alpha,
+            },
+        }]
+
+        cls.data_samples = [d_['data_samples'] for d_ in cls.data_batch]
+
+        cls.bad_preds1_ = [{'pred_alpha': dict(data=pred_alpha)}]
+        # pred_alpha should be masked by trimap before evaluation
+        cls.bad_preds1 = copy.deepcopy(cls.data_samples)
+        for d, p in zip(cls.bad_preds1, cls.bad_preds1_):
+            d['output'] = p
+
+        cls.bad_preds2_ = [{'pred_alpha': dict(data=pred_alpha[0])}]
+        # pred_alpha should be 2 dimensional
+        cls.bad_preds2 = copy.deepcopy(cls.data_samples)
+        for d, p in zip(cls.bad_preds2, cls.bad_preds2_):
+            d['output'] = p
+
+        cls.good_preds_ = [{'pred_alpha': dict(data=masked_pred_alpha)}]
+        cls.good_preds = copy.deepcopy(cls.data_samples)
+        for d, p in zip(cls.good_preds, cls.good_preds_):
+            d['output'] = p
+
+    def test_mse(self):
+        """Test MattingMSE for evaluating predicted alpha matte."""
+
+        data_batch, bad_pred1, bad_pred2, good_pred = (
+            self.data_batch,
+            self.bad_preds1,
+            self.bad_preds2,
+            self.good_preds,
+        )
+
+        mse = MattingMSE()
+
+        with pytest.raises(ValueError):
+            mse.process(data_batch, bad_pred1)
+
+        with pytest.raises(ValueError):
+            mse.process(data_batch, bad_pred2)
+
+        # process 2 batches
+        mse.process(data_batch, good_pred)
+        mse.process(data_batch, good_pred)
+
+        assert mse.results == [
+            {
+                'mse': 3.0,
+            },
+            {
+                'mse': 3.0,
+            },
+        ]
+
+        res = mse.compute_metrics(mse.results)
+
+        assert list(res.keys()) == ['MattingMSE']
+        np.testing.assert_almost_equal(res['MattingMSE'], 3.0)
diff --git a/tests/test_evaluation/test_metrics/test_metrics_utils.py b/tests/test_evaluation/test_metrics/test_metrics_utils.py
index 2cf7417815..290682a3b0 100644
--- a/tests/test_evaluation/test_metrics/test_metrics_utils.py
+++ b/tests/test_evaluation/test_metrics/test_metrics_utils.py
@@ -1,7 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import numpy as np
+import pytest
 
 from mmedit.evaluation.metrics import metrics_utils
+from mmedit.evaluation.metrics.metrics_utils import reorder_image
 
 
 def test_average():
@@ -29,3 +31,21 @@ def test_obtain_data():
         data_sample = {'data_samples': {key: img}}
         result = metrics_utils.obtain_data(data_sample, key)
         assert not (result - img).any()
+
+
+def test_reorder_image():
+    img_hw = np.ones((32, 32))
+    img_hwc = np.ones((32, 32, 3))
+    img_chw = np.ones((3, 32, 32))
+
+    with pytest.raises(ValueError):
+        reorder_image(img_hw, 'HH')
+
+    output = reorder_image(img_hw)
+    assert output.shape == (32, 32, 1)
+
+    output = reorder_image(img_hwc)
+    assert output.shape == (32, 32, 3)
+
+    output = reorder_image(img_chw, input_order='CHW')
+    assert output.shape == (32, 32, 3)
diff --git a/tests/test_evaluation/test_metrics/test_mse.py b/tests/test_evaluation/test_metrics/test_mse.py
new file mode 100644
index 0000000000..5b0fdfeeab
--- /dev/null
+++ b/tests/test_evaluation/test_metrics/test_mse.py
@@ -0,0 +1,51 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from copy import deepcopy
+
+import numpy as np
+import torch
+
+from mmedit.evaluation.metrics import MSE
+
+
+class TestPixelMetrics:
+
+    @classmethod
+    def setup_class(cls):
+
+        mask = np.ones((32, 32, 3)) * 2
+        mask[:16] *= 0
+        gt = np.ones((32, 32, 3)) * 2
+        data_sample = dict(gt_img=gt, mask=mask, gt_channel_order='bgr')
+        cls.data_batch = [dict(data_samples=data_sample)]
+        cls.predictions = [dict(pred_img=np.ones((32, 32, 3)))]
+
+        cls.data_batch.append(
+            dict(
+                data_samples=dict(
+                    gt_img=torch.from_numpy(gt),
+                    mask=torch.from_numpy(mask),
+                    img_channel_order='bgr')))
+        cls.predictions.append({
+            k: torch.from_numpy(deepcopy(v))
+            for (k, v) in cls.predictions[0].items()
+        })
+
+        for d, p in zip(cls.data_batch, cls.predictions):
+            d['output'] = p
+        cls.predictions = cls.data_batch
+
+    def test_mse(self):
+
+        # Single MSE
+        mse = MSE()
+        mse.process(self.data_batch, self.predictions)
+        result = mse.compute_metrics(mse.results)
+        assert 'MSE' in result
+        np.testing.assert_almost_equal(result['MSE'], 0.000015378700496)
+
+        # Masked MSE
+        mse = MSE(mask_key='mask', prefix='MSE')
+        mse.process(self.data_batch, self.predictions)
+        result = mse.compute_metrics(mse.results)
+        assert 'MSE' in result
+        np.testing.assert_almost_equal(result['MSE'], 0.000015378700496)
diff --git a/tests/test_evaluation/test_metrics/test_pixel_metrics.py b/tests/test_evaluation/test_metrics/test_pixel_metrics.py
deleted file mode 100644
index 32476fb24c..0000000000
--- a/tests/test_evaluation/test_metrics/test_pixel_metrics.py
+++ /dev/null
@@ -1,199 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from copy import deepcopy - -import numpy as np -import pytest -import torch - -from mmedit.evaluation.metrics import MAE, MSE, PSNR, SNR, psnr -from mmedit.evaluation.metrics.metrics_utils import reorder_image - - -class TestPixelMetrics: - - @classmethod - def setup_class(cls): - - mask = np.ones((32, 32, 3)) * 2 - mask[:16] *= 0 - gt = np.ones((32, 32, 3)) * 2 - data_sample = dict(gt_img=gt, mask=mask, gt_channel_order='bgr') - cls.data_batch = [dict(data_samples=data_sample)] - cls.predictions = [dict(pred_img=np.ones((32, 32, 3)))] - - cls.data_batch.append( - dict( - data_samples=dict( - gt_img=torch.from_numpy(gt), - mask=torch.from_numpy(mask), - img_channel_order='bgr'))) - cls.predictions.append({ - k: torch.from_numpy(deepcopy(v)) - for (k, v) in cls.predictions[0].items() - }) - - for d, p in zip(cls.data_batch, cls.predictions): - d['output'] = p - cls.predictions = cls.data_batch - - def test_mae(self): - - # Single MAE - mae = MAE() - mae.process(self.data_batch, self.predictions) - result = mae.compute_metrics(mae.results) - assert 'MAE' in result - np.testing.assert_almost_equal(result['MAE'], 0.003921568627) - - # Masked MAE - mae = MAE(mask_key='mask', prefix='MAE') - mae.process(self.data_batch, self.predictions) - result = mae.compute_metrics(mae.results) - assert 'MAE' in result - np.testing.assert_almost_equal(result['MAE'], 0.003921568627) - - def test_mse(self): - - # Single MSE - mae = MSE() - mae.process(self.data_batch, self.predictions) - result = mae.compute_metrics(mae.results) - assert 'MSE' in result - np.testing.assert_almost_equal(result['MSE'], 0.000015378700496) - - # Masked MSE - mae = MSE(mask_key='mask', prefix='MSE') - mae.process(self.data_batch, self.predictions) - result = mae.compute_metrics(mae.results) - assert 'MSE' in result - np.testing.assert_almost_equal(result['MSE'], 0.000015378700496) - - def test_psnr(self): - - psnr_ = PSNR() - psnr_.process(self.data_batch, self.predictions) - result = psnr_.compute_metrics(psnr_.results) - assert 'PSNR' in result - np.testing.assert_almost_equal(result['PSNR'], 48.1308036) - - def test_snr(self): - - snr_ = SNR() - snr_.process(self.data_batch, self.predictions) - result = snr_.compute_metrics(snr_.results) - assert 'SNR' in result - np.testing.assert_almost_equal(result['SNR'], 6.0206001996994) - - -def test_reorder_image(): - img_hw = np.ones((32, 32)) - img_hwc = np.ones((32, 32, 3)) - img_chw = np.ones((3, 32, 32)) - - with pytest.raises(ValueError): - reorder_image(img_hw, 'HH') - - output = reorder_image(img_hw) - assert output.shape == (32, 32, 1) - - output = reorder_image(img_hwc) - assert output.shape == (32, 32, 3) - - output = reorder_image(img_chw, input_order='CHW') - assert output.shape == (32, 32, 3) - - -def test_psnr(): - img_hw_1 = np.ones((32, 32)) - img_hwc_1 = np.ones((32, 32, 3)) - img_chw_1 = np.ones((3, 32, 32)) - img_hw_2 = np.ones((32, 32)) * 2 - img_hwc_2 = np.ones((32, 32, 3)) * 2 - img_chw_2 = np.ones((3, 32, 32)) * 2 - - with pytest.raises(ValueError): - psnr(img_hw_1, img_hw_2, crop_border=0, input_order='HH') - - with pytest.raises(ValueError): - psnr(img_hw_1, img_hw_2, crop_border=0, convert_to='ABC') - - psnr_result = psnr(img_hw_1, img_hw_2, crop_border=0) - np.testing.assert_almost_equal(psnr_result, 48.1308036) - psnr_result = psnr(img_hwc_1, img_hwc_2, crop_border=0, input_order='HWC') - np.testing.assert_almost_equal(psnr_result, 48.1308036) - psnr_result = psnr(img_chw_1, img_chw_2, crop_border=0, input_order='CHW') - 
np.testing.assert_almost_equal(psnr_result, 48.1308036) - - psnr_result = psnr(img_hw_1, img_hw_2, crop_border=2) - np.testing.assert_almost_equal(psnr_result, 48.1308036) - psnr_result = psnr(img_hwc_1, img_hwc_2, crop_border=3, input_order='HWC') - np.testing.assert_almost_equal(psnr_result, 48.1308036) - psnr_result = psnr(img_chw_1, img_chw_2, crop_border=4, input_order='CHW') - np.testing.assert_almost_equal(psnr_result, 48.1308036) - - psnr_result = psnr(img_hwc_1, img_hwc_2, crop_border=0, convert_to=None) - np.testing.assert_almost_equal(psnr_result, 48.1308036) - psnr_result = psnr(img_hwc_1, img_hwc_2, crop_border=0, convert_to='Y') - np.testing.assert_almost_equal(psnr_result, 49.4527218) - - # test float inf - psnr_result = psnr(img_hw_1, img_hw_1, crop_border=0) - assert psnr_result == float('inf') - - # test uint8 - img_hw_1 = np.zeros((32, 32), dtype=np.uint8) - img_hw_2 = np.ones((32, 32), dtype=np.uint8) * 255 - psnr_result = psnr(img_hw_1, img_hw_2, crop_border=0) - assert psnr_result == 0 - - -def test_snr(): - img_hw_1 = np.ones((32, 32)) - img_hwc_1 = np.ones((32, 32, 3)) - img_chw_1 = np.ones((3, 32, 32)) - img_hw_2 = np.ones((32, 32)) * 2 - img_hwc_2 = np.ones((32, 32, 3)) * 2 - img_chw_2 = np.ones((3, 32, 32)) * 2 - - with pytest.raises(ValueError): - psnr(img_hw_1, img_hw_2, crop_border=0, input_order='HH') - - with pytest.raises(ValueError): - psnr(img_hw_1, img_hw_2, crop_border=0, convert_to='ABC') - - psnr_result = psnr(img_hw_1, img_hw_2, crop_border=0) - np.testing.assert_almost_equal(psnr_result, 48.1308036) - psnr_result = psnr(img_hwc_1, img_hwc_2, crop_border=0, input_order='HWC') - np.testing.assert_almost_equal(psnr_result, 48.1308036) - psnr_result = psnr(img_chw_1, img_chw_2, crop_border=0, input_order='CHW') - np.testing.assert_almost_equal(psnr_result, 48.1308036) - - psnr_result = psnr(img_hw_1, img_hw_2, crop_border=2) - np.testing.assert_almost_equal(psnr_result, 48.1308036) - psnr_result = psnr(img_hwc_1, img_hwc_2, crop_border=3, input_order='HWC') - np.testing.assert_almost_equal(psnr_result, 48.1308036) - psnr_result = psnr(img_chw_1, img_chw_2, crop_border=4, input_order='CHW') - np.testing.assert_almost_equal(psnr_result, 48.1308036) - - psnr_result = psnr(img_hwc_1, img_hwc_2, crop_border=0, convert_to=None) - np.testing.assert_almost_equal(psnr_result, 48.1308036) - psnr_result = psnr(img_hwc_1, img_hwc_2, crop_border=0, convert_to='Y') - np.testing.assert_almost_equal(psnr_result, 49.4527218) - - # test float inf - psnr_result = psnr(img_hw_1, img_hw_1, crop_border=0) - assert psnr_result == float('inf') - - # test uint8 - img_hw_1 = np.zeros((32, 32), dtype=np.uint8) - img_hw_2 = np.ones((32, 32), dtype=np.uint8) * 255 - psnr_result = psnr(img_hw_1, img_hw_2, crop_border=0) - assert psnr_result == 0 - - -t = TestPixelMetrics() -t.setup_class() -t.test_mae() -t.test_mse() -t.test_psnr() -t.test_snr() diff --git a/tests/test_evaluation/test_metrics/test_psnr.py b/tests/test_evaluation/test_metrics/test_psnr.py new file mode 100644 index 0000000000..b220c651aa --- /dev/null +++ b/tests/test_evaluation/test_metrics/test_psnr.py @@ -0,0 +1,88 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
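+#
+# Reference for the expected values: PSNR = 10 * log10(255 ** 2 / MSE).
+# Every pixel pair below differs by exactly 1, so MSE = 1 and
+# PSNR = 20 * log10(255) ~= 48.1308 dB, independent of crop_border.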
+from copy import deepcopy + +import numpy as np +import pytest +import torch + +from mmedit.evaluation.metrics import PSNR, psnr + + +class TestPixelMetrics: + + @classmethod + def setup_class(cls): + + mask = np.ones((32, 32, 3)) * 2 + mask[:16] *= 0 + gt = np.ones((32, 32, 3)) * 2 + data_sample = dict(gt_img=gt, mask=mask, gt_channel_order='bgr') + cls.data_batch = [dict(data_samples=data_sample)] + cls.predictions = [dict(pred_img=np.ones((32, 32, 3)))] + + cls.data_batch.append( + dict( + data_samples=dict( + gt_img=torch.from_numpy(gt), + mask=torch.from_numpy(mask), + img_channel_order='bgr'))) + cls.predictions.append({ + k: torch.from_numpy(deepcopy(v)) + for (k, v) in cls.predictions[0].items() + }) + + for d, p in zip(cls.data_batch, cls.predictions): + d['output'] = p + cls.predictions = cls.data_batch + + def test_psnr(self): + + psnr_ = PSNR() + psnr_.process(self.data_batch, self.predictions) + result = psnr_.compute_metrics(psnr_.results) + assert 'PSNR' in result + np.testing.assert_almost_equal(result['PSNR'], 48.1308036) + + +def test_psnr(): + img_hw_1 = np.ones((32, 32)) + img_hwc_1 = np.ones((32, 32, 3)) + img_chw_1 = np.ones((3, 32, 32)) + img_hw_2 = np.ones((32, 32)) * 2 + img_hwc_2 = np.ones((32, 32, 3)) * 2 + img_chw_2 = np.ones((3, 32, 32)) * 2 + + with pytest.raises(ValueError): + psnr(img_hw_1, img_hw_2, crop_border=0, input_order='HH') + + with pytest.raises(ValueError): + psnr(img_hw_1, img_hw_2, crop_border=0, convert_to='ABC') + + psnr_result = psnr(img_hw_1, img_hw_2, crop_border=0) + np.testing.assert_almost_equal(psnr_result, 48.1308036) + psnr_result = psnr(img_hwc_1, img_hwc_2, crop_border=0, input_order='HWC') + np.testing.assert_almost_equal(psnr_result, 48.1308036) + psnr_result = psnr(img_chw_1, img_chw_2, crop_border=0, input_order='CHW') + np.testing.assert_almost_equal(psnr_result, 48.1308036) + + psnr_result = psnr(img_hw_1, img_hw_2, crop_border=2) + np.testing.assert_almost_equal(psnr_result, 48.1308036) + psnr_result = psnr(img_hwc_1, img_hwc_2, crop_border=3, input_order='HWC') + np.testing.assert_almost_equal(psnr_result, 48.1308036) + psnr_result = psnr(img_chw_1, img_chw_2, crop_border=4, input_order='CHW') + np.testing.assert_almost_equal(psnr_result, 48.1308036) + + psnr_result = psnr(img_hwc_1, img_hwc_2, crop_border=0, convert_to=None) + np.testing.assert_almost_equal(psnr_result, 48.1308036) + psnr_result = psnr(img_hwc_1, img_hwc_2, crop_border=0, convert_to='Y') + np.testing.assert_almost_equal(psnr_result, 49.4527218) + + # test float inf + psnr_result = psnr(img_hw_1, img_hw_1, crop_border=0) + assert psnr_result == float('inf') + + # test uint8 + img_hw_1 = np.zeros((32, 32), dtype=np.uint8) + img_hw_2 = np.ones((32, 32), dtype=np.uint8) * 255 + psnr_result = psnr(img_hw_1, img_hw_2, crop_border=0) + assert psnr_result == 0 diff --git a/tests/test_evaluation/test_metrics/test_sad.py b/tests/test_evaluation/test_metrics/test_sad.py new file mode 100644 index 0000000000..f05dd3d5b8 --- /dev/null +++ b/tests/test_evaluation/test_metrics/test_sad.py @@ -0,0 +1,108 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from pathlib import Path + +import numpy as np +import pytest +import torch + +from mmedit.datasets.transforms import LoadImageFromFile +from mmedit.evaluation.metrics import SAD + + +class TestMattingMetrics: + + @classmethod + def setup_class(cls): + # Make sure these values are immutable across different test cases. 
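+        # The expected SAD of 0.768 asserted below is consistent with
+        # summing |pred - gt| over alphas rescaled to [0, 1] and dividing
+        # by 1000: the masked prediction disagrees with the all-255 GT on
+        # the 256 unknown pixels and the 512 background pixels, and
+        # (256 + 512) / 1000 = 0.768.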
+
+        # This test depends on the interface of loading;
+        # if loading is changed, data should be changed accordingly.
+        test_path = Path(__file__).parent.parent.parent
+        alpha_path = (
+            test_path / 'data' / 'matting_dataset' / 'alpha' / 'GT05.jpg')
+
+        results = dict(alpha_path=alpha_path)
+        config = dict(key='alpha')
+        image_loader = LoadImageFromFile(**config)
+        results = image_loader(results)
+        assert results['alpha'].ndim == 3
+
+        gt_alpha = np.ones((32, 32), dtype=np.uint8) * 255
+        trimap = np.zeros((32, 32), dtype=np.uint8)
+        trimap[:16, :16] = 128
+        trimap[16:, 16:] = 255
+        # non-masked pred_alpha
+        pred_alpha = torch.zeros((32, 32), dtype=torch.uint8)
+        # masked pred_alpha
+        masked_pred_alpha = pred_alpha.clone()
+        masked_pred_alpha[trimap == 0] = 0
+        masked_pred_alpha[trimap == 255] = 255
+
+        gt_alpha = gt_alpha[..., None]
+        trimap = trimap[..., None]
+        # pred_alpha = pred_alpha.unsqueeze(0)
+        # masked_pred_alpha = masked_pred_alpha.unsqueeze(0)
+
+        cls.data_batch = [{
+            'inputs': [],
+            'data_samples': {
+                'ori_trimap': trimap,
+                'ori_alpha': gt_alpha,
+            },
+        }]
+
+        cls.data_samples = [d_['data_samples'] for d_ in cls.data_batch]
+
+        cls.bad_preds1_ = [{'pred_alpha': dict(data=pred_alpha)}]
+        # pred_alpha should be masked by trimap before evaluation
+        cls.bad_preds1 = copy.deepcopy(cls.data_samples)
+        for d, p in zip(cls.bad_preds1, cls.bad_preds1_):
+            d['output'] = p
+
+        cls.bad_preds2_ = [{'pred_alpha': dict(data=pred_alpha[0])}]
+        # pred_alpha should be 3-dimensional
+        cls.bad_preds2 = copy.deepcopy(cls.data_samples)
+        for d, p in zip(cls.bad_preds2, cls.bad_preds2_):
+            d['output'] = p
+
+        cls.good_preds_ = [{'pred_alpha': dict(data=masked_pred_alpha)}]
+        cls.good_preds = copy.deepcopy(cls.data_samples)
+        for d, p in zip(cls.good_preds, cls.good_preds_):
+            d['output'] = p
+
+    def test_sad(self):
+        """Test SAD for evaluating predicted alpha matte."""
+
+        data_batch, bad_pred1, bad_pred2, good_pred = (
+            self.data_batch,
+            self.bad_preds1,
+            self.bad_preds2,
+            self.good_preds,
+        )
+
+        sad = SAD()
+
+        with pytest.raises(ValueError):
+            sad.process(data_batch, bad_pred1)
+
+        with pytest.raises(ValueError):
+            sad.process(data_batch, bad_pred2)
+
+        # process 2 batches
+        sad.process(data_batch, good_pred)
+        sad.process(data_batch, good_pred)
+
+        assert sad.results == [
+            {
+                'sad': 0.768,
+            },
+            {
+                'sad': 0.768,
+            },
+        ]
+
+        res = sad.compute_metrics(sad.results)
+
+        assert list(res.keys()) == ['SAD']
+        np.testing.assert_almost_equal(res['SAD'], 0.768)
diff --git a/tests/test_evaluation/test_metrics/test_snr.py b/tests/test_evaluation/test_metrics/test_snr.py
new file mode 100644
index 0000000000..9354bb68b9
--- /dev/null
+++ b/tests/test_evaluation/test_metrics/test_snr.py
@@ -0,0 +1,88 @@
+# Copyright (c) OpenMMLab. All rights reserved.
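+#
+# Reference for the expected value: SNR = 10 * log10(signal / noise), with
+# signal = mean(gt ** 2) and noise = mean((gt - pred) ** 2). With gt == 2
+# and pred == 1 everywhere, this gives 10 * log10(4) ~= 6.0206 dB.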
+from copy import deepcopy
+
+import numpy as np
+import pytest
+import torch
+
+from mmedit.evaluation.metrics import SNR, snr
+
+
+class TestPixelMetrics:
+
+    @classmethod
+    def setup_class(cls):
+
+        mask = np.ones((32, 32, 3)) * 2
+        mask[:16] *= 0
+        gt = np.ones((32, 32, 3)) * 2
+        data_sample = dict(gt_img=gt, mask=mask, gt_channel_order='bgr')
+        cls.data_batch = [dict(data_samples=data_sample)]
+        cls.predictions = [dict(pred_img=np.ones((32, 32, 3)))]
+
+        cls.data_batch.append(
+            dict(
+                data_samples=dict(
+                    gt_img=torch.from_numpy(gt),
+                    mask=torch.from_numpy(mask),
+                    img_channel_order='bgr')))
+        cls.predictions.append({
+            k: torch.from_numpy(deepcopy(v))
+            for (k, v) in cls.predictions[0].items()
+        })
+
+        for d, p in zip(cls.data_batch, cls.predictions):
+            d['output'] = p
+        cls.predictions = cls.data_batch
+
+    def test_snr(self):
+
+        snr_ = SNR()
+        snr_.process(self.data_batch, self.predictions)
+        result = snr_.compute_metrics(snr_.results)
+        assert 'SNR' in result
+        np.testing.assert_almost_equal(result['SNR'], 6.0206001996994)
+
+
+def test_snr():
+    # Unlike psnr, snr is asymmetric: the ground truth is passed first and
+    # its power forms the numerator of the ratio.
+    img_hw_1 = np.ones((32, 32))
+    img_hwc_1 = np.ones((32, 32, 3))
+    img_chw_1 = np.ones((3, 32, 32))
+    img_hw_2 = np.ones((32, 32)) * 2
+    img_hwc_2 = np.ones((32, 32, 3)) * 2
+    img_chw_2 = np.ones((3, 32, 32)) * 2
+
+    with pytest.raises(ValueError):
+        snr(img_hw_2, img_hw_1, crop_border=0, input_order='HH')
+
+    with pytest.raises(ValueError):
+        snr(img_hw_2, img_hw_1, crop_border=0, convert_to='ABC')
+
+    # gt == 2, pred == 1 -> 10 * log10(4 / 1) ~= 6.0206 dB; the loose
+    # tolerance absorbs float32 / float64 rounding differences
+    snr_result = snr(img_hw_2, img_hw_1, crop_border=0)
+    np.testing.assert_almost_equal(snr_result, 6.0206, decimal=5)
+    snr_result = snr(img_hwc_2, img_hwc_1, crop_border=0, input_order='HWC')
+    np.testing.assert_almost_equal(snr_result, 6.0206, decimal=5)
+    snr_result = snr(img_chw_2, img_chw_1, crop_border=0, input_order='CHW')
+    np.testing.assert_almost_equal(snr_result, 6.0206, decimal=5)
+
+    # the images are constant, so cropping borders cannot change the ratio
+    snr_result = snr(img_hw_2, img_hw_1, crop_border=2)
+    np.testing.assert_almost_equal(snr_result, 6.0206, decimal=5)
+    snr_result = snr(img_hwc_2, img_hwc_1, crop_border=3, input_order='HWC')
+    np.testing.assert_almost_equal(snr_result, 6.0206, decimal=5)
+    snr_result = snr(img_chw_2, img_chw_1, crop_border=4, input_order='CHW')
+    np.testing.assert_almost_equal(snr_result, 6.0206, decimal=5)
+
+    snr_result = snr(img_hwc_2, img_hwc_1, crop_border=0, convert_to=None)
+    np.testing.assert_almost_equal(snr_result, 6.0206, decimal=5)
+    # Y-channel conversion rescales both images; only sanity-check it here
+    snr_result = snr(img_hwc_2, img_hwc_1, crop_border=0, convert_to='Y')
+    assert np.isfinite(snr_result)
+
+    # test float inf: zero noise
+    snr_result = snr(img_hw_1, img_hw_1, crop_border=0)
+    assert snr_result == float('inf')
+
+    # test uint8: equal signal and noise power -> 0 dB
+    img_hw_1 = np.zeros((32, 32), dtype=np.uint8)
+    img_hw_2 = np.ones((32, 32), dtype=np.uint8) * 255
+    snr_result = snr(img_hw_2, img_hw_1, crop_border=0)
+    assert snr_result == 0
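+
+
+def test_snr_reference_formula():
+    # A minimal numpy-only cross-check of the expected value above; the
+    # function name is local to this test file, not part of mmedit's API.
+    gt = np.ones((32, 32)) * 2
+    pred = np.ones((32, 32))
+    reference = 10 * np.log10((gt**2).mean() / ((gt - pred)**2).mean())
+    np.testing.assert_almost_equal(reference, 6.0206, decimal=5)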