
Commit

[Refactoring] decompose the implementations of different metrics into several files (#1161)

* refactor metrics
* add UT for refactored metrics
zengyh1900 authored Oct 9, 2022
1 parent f6886a1 commit 944d3a8
Showing 22 changed files with 1,439 additions and 1,078 deletions.
31 changes: 18 additions & 13 deletions mmedit/evaluation/metrics/__init__.py
@@ -1,40 +1,45 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+from .connectivity_error import ConnectivityError
 from .equivariance import Equivariance
 from .fid import FrechetInceptionDistance, TransFID
+from .gradient_error import GradientError
 from .inception_score import InceptionScore, TransIS
-from .matting import SAD, ConnectivityError, GradientError, MattingMSE
+from .mae import MAE
+from .matting_mse import MattingMSE
 from .ms_ssim import MultiScaleStructureSimilarity
+from .mse import MSE
 from .niqe import NIQE, niqe
-from .pixel_metrics import MAE, MSE, PSNR, SNR, psnr, snr
 from .ppl import PerceptualPathLength
 from .precision_and_recall import PrecisionAndRecall
+from .psnr import PSNR, psnr
+from .sad import SAD
+from .snr import SNR, snr
 from .ssim import SSIM, ssim
 from .swd import SlicedWassersteinDistance

 __all__ = [
+    'ConnectivityError',
+    'GradientError',
     'MAE',
+    'MattingMSE',
     'MSE',
+    'NIQE',
+    'niqe',
     'PSNR',
     'psnr',
+    'SAD',
     'SNR',
     'snr',
     'SSIM',
     'ssim',
-    'MultiScaleStructureSimilarity',
+    'Equivariance',
     'FrechetInceptionDistance',
-    'TransFID',
     'InceptionScore',
     'MultiScaleStructureSimilarity',
-    'TransIS',
-    'SAD',
-    'MattingMSE',
-    'ConnectivityError',
-    'GradientError',
     'PerceptualPathLength',
-    'MultiScaleStructureSimilarity',
     'PrecisionAndRecall',
     'SlicedWassersteinDistance',
-    'NIQE',
-    'niqe',
-    'Equivariance',
+    'TransFID',
+    'TransIS',
 ]
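
For orientation, here is a minimal sketch (not part of the commit; the no-argument constructors rely on the defaults documented in the files below) showing that the refactor only moves implementations, so the import surface of the package stays the same:

from mmedit.evaluation.metrics import MAE, PSNR, ConnectivityError

mae = MAE()                  # now defined in mae.py instead of pixel_metrics.py
psnr = PSNR()                # now defined in psnr.py
conn = ConnectivityError()   # now defined in connectivity_error.py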
117 changes: 117 additions & 0 deletions mmedit/evaluation/metrics/connectivity_error.py
@@ -0,0 +1,117 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""Evaluation metrics used in Image Matting."""

from typing import List, Sequence

import cv2
import numpy as np
from mmengine.evaluator import BaseMetric

from mmedit.registry import METRICS
from .metrics_utils import _fetch_data_and_check, average


@METRICS.register_module()
class ConnectivityError(BaseMetric):
    """Connectivity error for evaluating alpha matte prediction.

    .. note::

        The current implementation assumes that the image / alpha / trimap
        arrays are numpy arrays with pixel values ranging from 0 to 255.

    .. note::

        ``pred_alpha`` should be masked by the trimap before being passed
        into this metric.

    Args:
        step (float): Step of the threshold when computing the intersection
            between ``alpha`` and ``pred_alpha``. Defaults to 0.1.
        norm_constant (int): Divisor used to reduce the magnitude of the
            result. Defaults to 1000.

    Default prefix: ''

    Metrics:
        - ConnectivityError (float): Connectivity Error
    """

    def __init__(
        self,
        step=0.1,
        norm_constant=1000,
        **kwargs,
    ) -> None:
        self.step = step
        self.norm_constant = norm_constant
        super().__init__(**kwargs)

    def process(self, data_batch: Sequence[dict],
                data_samples: Sequence[dict]) -> None:
        """Process one batch of data samples and predictions. The processed
        results should be stored in ``self.results``, which will be used to
        compute the metrics when all batches have been processed.

        Args:
            data_batch (Sequence[dict]): A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """

        for data_sample in data_samples:
            pred_alpha, gt_alpha, trimap = _fetch_data_and_check(data_sample)

            thresh_steps = np.arange(0, 1 + self.step, self.step)
            round_down_map = -np.ones_like(gt_alpha)
            for i in range(1, len(thresh_steps)):
                gt_alpha_thresh = gt_alpha >= thresh_steps[i]
                pred_alpha_thresh = pred_alpha >= thresh_steps[i]
                intersection = gt_alpha_thresh & pred_alpha_thresh
                intersection = intersection.astype(np.uint8)

                # connected components
                _, output, stats, _ = cv2.connectedComponentsWithStats(
                    intersection, connectivity=4)
                # start from 1 in dim 0 to exclude background
                size = stats[1:, -1]

                # largest connected component of the intersection
                omega = np.zeros_like(gt_alpha)
                if len(size) != 0:
                    max_id = np.argmax(size)
                    # plus one to include background
                    omega[output == max_id + 1] = 1

                mask = (round_down_map == -1) & (omega == 0)
                round_down_map[mask] = thresh_steps[i - 1]
            round_down_map[round_down_map == -1] = 1

            gt_alpha_diff = gt_alpha - round_down_map
            pred_alpha_diff = pred_alpha - round_down_map
            # only calculate differences larger than or equal to 0.15
            gt_alpha_phi = 1 - gt_alpha_diff * (gt_alpha_diff >= 0.15)
            pred_alpha_phi = 1 - pred_alpha_diff * (pred_alpha_diff >= 0.15)

            connectivity_error = np.sum(
                np.abs(gt_alpha_phi - pred_alpha_phi) * (trimap == 128))

            # divide by norm_constant to reduce the magnitude of the result
            connectivity_error /= self.norm_constant

            self.results.append({'conn_err': connectivity_error})

    def compute_metrics(self, results: List):
        """Compute the metrics from processed results.

        Args:
            results (List): The processed results of each batch.

        Returns:
            Dict: The computed metrics. The keys are the names of the metrics,
            and the values are the corresponding results.
        """

        conn_err = average(results, 'conn_err')

        return {'ConnectivityError': conn_err}
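
Because the class registers itself with METRICS, the usual way to enable it is through an evaluator config rather than direct construction. A hypothetical snippet in the common OpenMMLab config style (the val_evaluator name and exact fields are assumptions, not part of this commit):

val_evaluator = [
    dict(type='ConnectivityError', step=0.1, norm_constant=1000),
    dict(type='GradientError', sigma=1.4, norm_constant=1000),
]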
95 changes: 95 additions & 0 deletions mmedit/evaluation/metrics/gradient_error.py
@@ -0,0 +1,95 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Sequence

import cv2
import numpy as np
from mmengine.evaluator import BaseMetric

from mmedit.registry import METRICS
from ..functional import gauss_gradient
from .metrics_utils import _fetch_data_and_check, average


@METRICS.register_module()
class GradientError(BaseMetric):
    """Gradient error for evaluating alpha matte prediction.

    .. note::

        The current implementation assumes that the image / alpha / trimap
        arrays are numpy arrays with pixel values ranging from 0 to 255.

    .. note::

        ``pred_alpha`` should be masked by the trimap before being passed
        into this metric.

    Args:
        sigma (float): Standard deviation of the Gaussian kernel.
            Defaults to 1.4.
        norm_constant (int): Divisor used to reduce the magnitude of the
            result. Defaults to 1000.

    Default prefix: ''

    Metrics:
        - GradientError (float): Gradient Error
    """

    def __init__(
        self,
        sigma=1.4,
        norm_constant=1000,
        **kwargs,
    ) -> None:
        self.sigma = sigma
        self.norm_constant = norm_constant
        super().__init__(**kwargs)

    def process(self, data_batch: Sequence[dict],
                data_samples: Sequence[dict]) -> None:
        """Process one batch of data samples and predictions. The processed
        results should be stored in ``self.results``, which will be used to
        compute the metrics when all batches have been processed.

        Args:
            data_batch (Sequence[dict]): A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """

        for data_sample in data_samples:
            pred_alpha, gt_alpha, trimap = _fetch_data_and_check(data_sample)

            gt_alpha_normed = np.zeros_like(gt_alpha)
            pred_alpha_normed = np.zeros_like(pred_alpha)

            cv2.normalize(gt_alpha, gt_alpha_normed, 1.0, 0.0, cv2.NORM_MINMAX)
            cv2.normalize(pred_alpha, pred_alpha_normed, 1.0, 0.0,
                          cv2.NORM_MINMAX)

            gt_alpha_grad = gauss_gradient(gt_alpha_normed, self.sigma)
            pred_alpha_grad = gauss_gradient(pred_alpha_normed, self.sigma)
            # this is the sum over n samples
            grad_loss = ((gt_alpha_grad - pred_alpha_grad)**2 *
                         (trimap == 128)).sum()

            # divide by norm_constant to reduce the magnitude of the result
            grad_loss /= self.norm_constant

            self.results.append({'grad_err': grad_loss})

    def compute_metrics(self, results: List):
        """Compute the metrics from processed results.

        Args:
            results (List): The processed results of each batch.

        Returns:
            Dict: The computed metrics. The keys are the names of the metrics,
            and the values are the corresponding results.
        """

        grad_err = average(results, 'grad_err')

        return {'GradientError': grad_err}
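
To see what the metric responds to, here is a toy numeric sketch (my own example, not from the commit; np.gradient stands in for the Gaussian-derivative operator gauss_gradient, and the alphas are already scaled to [0, 1]):

import numpy as np

gt = np.zeros((8, 8))
gt[:, 4:] = 1.0                   # hard vertical edge
pred = np.zeros((8, 8))
pred[:, 5:] = 1.0                 # same edge, shifted by one pixel
trimap = np.full((8, 8), 128)     # mark everything as the unknown region

def grad_mag(alpha):
    gy, gx = np.gradient(alpha)
    return np.sqrt(gx**2 + gy**2)

# non-zero because the predicted edge is displaced from the ground truth
loss = ((grad_mag(gt) - grad_mag(pred))**2 * (trimap == 128)).sum() / 1000
print(loss)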
60 changes: 60 additions & 0 deletions mmedit/evaluation/metrics/mae.py
@@ -0,0 +1,60 @@
# Copyright (c) OpenMMLab. All rights reserved.
"""Evaluation metrics based on pixels."""

import numpy as np

from mmedit.registry import METRICS
from .base_sample_wise_metric import BaseSampleWiseMetric


@METRICS.register_module()
class MAE(BaseSampleWiseMetric):
    """Mean Absolute Error metric for images.

    mean(abs(a - b))

    Args:
        gt_key (str): Key of ground truth. Default: 'gt_img'
        pred_key (str): Key of prediction. Default: 'pred_img'
        mask_key (str, optional): Key of mask. If mask_key is None, evaluate
            over all regions. Default: None
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added to the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, self.default_prefix
            will be used instead. Default: None

    Metrics:
        - MAE (float): Mean of Absolute Error
    """

    metric = 'MAE'

    def process_image(self, gt, pred, mask):
        """Process an image.

        Args:
            gt (Tensor | np.ndarray): GT image.
            pred (Tensor | np.ndarray): Predicted image.
            mask (Tensor | np.ndarray): Mask of evaluation.

        Returns:
            result (np.ndarray): MAE result.
        """

        gt = gt / 255.
        pred = pred / 255.

        diff = gt - pred
        diff = abs(diff)

        if self.mask_key is not None:
            diff *= mask  # broadcast over the channel dimension
            scale = np.prod(diff.shape) / np.prod(mask.shape)
            result = diff.sum() / (mask.sum() * scale + 1e-12)
        else:
            result = diff.mean()

        return result
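
A toy check (my numbers, not from the commit) of why scale is needed: when the mask lacks a channel dimension, mask.sum() under-counts the masked values by a factor of the channel count, and scale restores the balance:

import numpy as np

gt = np.full((2, 2, 3), 255.0)                        # H x W x C image
pred = np.zeros((2, 2, 3))                            # worst-case prediction
mask = np.array([[1.0, 0.0], [0.0, 0.0]])[..., None]  # evaluate one pixel

diff = np.abs(gt / 255. - pred / 255.) * mask         # sums to 3: 1 pixel x 3 channels
scale = np.prod(diff.shape) / np.prod(mask.shape)     # 12 / 4 = 3 channels
print(diff.sum() / (mask.sum() * scale + 1e-12))      # 1.0, the maximum MAE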