-
Notifications
You must be signed in to change notification settings - Fork 3.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Datumaro] Mean and std for dataset (#1734)
* Add meanstd * Add stats cli * Update changelog Co-authored-by: Nikita Manovich <40690625+nmanovic@users.noreply.github.com>
- Loading branch information
1 parent
3fee4cf
commit 12f7855
Showing
4 changed files
with
148 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
|
||
# Copyright (C) 2020 Intel Corporation | ||
# | ||
# SPDX-License-Identifier: MIT | ||
|
||
import cv2 | ||
import numpy as np | ||
|
||
|
||
def mean_std(dataset): | ||
""" | ||
Computes unbiased mean and std. dev. for dataset images, channel-wise. | ||
""" | ||
# Use an online algorithm to: | ||
# - handle different image sizes | ||
# - avoid cancellation problem | ||
|
||
stats = np.empty((len(dataset), 2, 3), dtype=np.double) | ||
counts = np.empty(len(dataset), dtype=np.uint32) | ||
|
||
mean = lambda i, s: s[i][0] | ||
var = lambda i, s: s[i][1] | ||
|
||
for i, item in enumerate(dataset): | ||
counts[i] = np.prod(item.image.size) | ||
|
||
image = item.image.data | ||
if len(image.shape) == 2: | ||
image = image[:, :, np.newaxis] | ||
else: | ||
image = image[:, :, :3] | ||
# opencv is much faster than numpy here | ||
cv2.meanStdDev(image.astype(np.double) / 255, | ||
mean=mean(i, stats), stddev=var(i, stats)) | ||
|
||
# make variance unbiased | ||
np.multiply(np.square(stats[:, 1]), | ||
(counts / (counts - 1))[:, np.newaxis], | ||
out=stats[:, 1]) | ||
|
||
_, mean, var = StatsCounter().compute_stats(stats, counts, mean, var) | ||
return mean * 255, np.sqrt(var) * 255 | ||
|
||
class StatsCounter: | ||
# Implements online parallel computation of sample variance | ||
# https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm | ||
|
||
# Needed do avoid catastrophic cancellation in floating point computations | ||
@staticmethod | ||
def pairwise_stats(count_a, mean_a, var_a, count_b, mean_b, var_b): | ||
delta = mean_b - mean_a | ||
m_a = var_a * (count_a - 1) | ||
m_b = var_b * (count_b - 1) | ||
M2 = m_a + m_b + delta ** 2 * count_a * count_b / (count_a + count_b) | ||
return ( | ||
count_a + count_b, | ||
mean_a * 0.5 + mean_b * 0.5, | ||
M2 / (count_a + count_b - 1) | ||
) | ||
|
||
# stats = float array of shape N, 2 * d, d = dimensions of values | ||
# count = integer array of shape N | ||
# mean_accessor = function(idx, stats) to retrieve element mean | ||
# variance_accessor = function(idx, stats) to retrieve element variance | ||
# Recursively computes total count, mean and variance, does O(log(N)) calls | ||
@staticmethod | ||
def compute_stats(stats, counts, mean_accessor, variance_accessor): | ||
m = mean_accessor | ||
v = variance_accessor | ||
n = len(stats) | ||
if n == 1: | ||
return counts[0], m(0, stats), v(0, stats) | ||
if n == 2: | ||
return __class__.pairwise_stats( | ||
counts[0], m(0, stats), v(0, stats), | ||
counts[1], m(1, stats), v(1, stats) | ||
) | ||
h = n // 2 | ||
return __class__.pairwise_stats( | ||
*__class__.compute_stats(stats[:h], counts[:h], m, v), | ||
*__class__.compute_stats(stats[h:], counts[h:], m, v) | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
import numpy as np | ||
|
||
from datumaro.components.extractor import Extractor, DatasetItem | ||
from datumaro.components.operations import mean_std | ||
|
||
from unittest import TestCase | ||
|
||
|
||
class TestOperations(TestCase): | ||
def test_mean_std(self): | ||
expected_mean = [100, 50, 150] | ||
expected_std = [20, 50, 10] | ||
|
||
class TestExtractor(Extractor): | ||
def __iter__(self): | ||
return iter([ | ||
DatasetItem(id=1, image=np.random.normal( | ||
expected_mean, expected_std, | ||
size=(w, h, 3)) | ||
) | ||
for i, (w, h) in enumerate([ | ||
(3000, 100), (800, 600), (400, 200), (700, 300) | ||
]) | ||
]) | ||
|
||
actual_mean, actual_std = mean_std(TestExtractor()) | ||
|
||
for em, am in zip(expected_mean, actual_mean): | ||
self.assertAlmostEqual(em, am, places=0) | ||
for estd, astd in zip(expected_std, actual_std): | ||
self.assertAlmostEqual(estd, astd, places=0) |