Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[MXNET-688] Fix quantization divide by zero errors (#11833)
Browse files Browse the repository at this point in the history
* Fix quantization bug

* Added tests and made sure the edge case is now handled correctly, without off-by-one errors

* Changed back to the original truncated distribution, but with a different KL divergence calculation

* Reorder back to original format

* Reorder back to original format (again)

* Change comments

* Clarified comments

* Changed norm division
  • Loading branch information
OneRaynyDay authored and reminisce committed Jul 24, 2018
1 parent 4bb141d commit 55fef30
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 14 deletions.
33 changes: 20 additions & 13 deletions python/mxnet/contrib/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,8 @@ def _smooth_distribution(p, eps=0.0001):
is_nonzeros = (p != 0).astype(np.float32)
n_zeros = is_zeros.sum()
n_nonzeros = p.size - n_zeros
if not n_nonzeros:
raise ValueError('The discrete probability distribution is malformed. All entries are 0.')
eps1 = eps * float(n_zeros) / float(n_nonzeros)
assert eps1 < 1.0, 'n_zeros=%d, n_nonzeros=%d, eps1=%f' % (n_zeros, n_nonzeros, eps1)
hist = p.astype(np.float32)
Expand All @@ -252,6 +254,9 @@ def _smooth_distribution(p, eps=0.0001):
# pylint: disable=line-too-long
def _get_optimal_threshold(arr, num_bins=8001, num_quantized_bins=255):
"""Given a dataset, find the optimal threshold for quantizing it.
The reference distribution is `q`, and the candidate distribution is `p`.
`q` is a truncated version of the original distribution.
Ref: http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf
"""
if isinstance(arr, NDArray):
Expand All @@ -274,22 +279,21 @@ def _get_optimal_threshold(arr, num_bins=8001, num_quantized_bins=255):
max_val = np.max(arr)
th = max(abs(min_val), abs(max_val))

hist, hist_edeges = np.histogram(arr, bins=num_bins, range=(-th, th))
hist, hist_edges = np.histogram(arr, bins=num_bins, range=(-th, th))
zero_bin_idx = num_bins // 2
num_half_quantized_bins = num_quantized_bins // 2
assert np.allclose(hist_edeges[zero_bin_idx] + hist_edeges[zero_bin_idx + 1],
assert np.allclose(hist_edges[zero_bin_idx] + hist_edges[zero_bin_idx + 1],
0, rtol=1e-5, atol=1e-7)

thresholds = np.zeros(num_bins // 2 + 1 - num_quantized_bins // 2)
divergence = np.zeros_like(thresholds)
quantized_bins = np.zeros(num_quantized_bins, dtype=np.int32)
# i means the number of bins on half axis excluding the zero bin
# i means the number of bins on half axis excluding the zero bin.
for i in range(num_quantized_bins // 2,
num_bins // 2 + 1):
p_bin_idx_start = zero_bin_idx - i
p_bin_idx_stop = zero_bin_idx + i + 1
thresholds[i - num_half_quantized_bins] = hist_edeges[p_bin_idx_stop]
# sliced_nd_hist is used to generate candidate distribution q
thresholds[i - num_half_quantized_bins] = hist_edges[p_bin_idx_stop]
sliced_nd_hist = hist[p_bin_idx_start:p_bin_idx_stop]

# generate reference distribution p
Expand All @@ -303,32 +307,35 @@ def _get_optimal_threshold(arr, num_bins=8001, num_quantized_bins=255):
right_outlier_count = np.sum(hist[p_bin_idx_stop:])
p[-1] += right_outlier_count
# is_nonzeros[k] indicates whether hist[k] is nonzero
is_nonzeros = (sliced_nd_hist != 0).astype(np.int32)
is_nonzeros = (p != 0).astype(np.int32)

# calculate how many bins should be merged to generate quantized distribution q
num_merged_bins = p.size // num_quantized_bins
num_merged_bins = sliced_nd_hist.size // num_quantized_bins
# merge hist into num_quantized_bins bins
for j in range(num_quantized_bins):
start = j * num_merged_bins
stop = start + num_merged_bins
quantized_bins[j] = sliced_nd_hist[start:stop].sum()
quantized_bins[-1] += sliced_nd_hist[num_quantized_bins * num_merged_bins:].sum()
# expand quantized_bins into p.size bins
q = np.zeros(p.size, dtype=np.float32)
q = np.zeros(sliced_nd_hist.size, dtype=np.float32)
for j in range(num_quantized_bins):
start = j * num_merged_bins
if j == num_quantized_bins - 1:
stop = -1
stop = len(is_nonzeros)
else:
stop = start + num_merged_bins
norm = is_nonzeros[start:stop].sum()
if norm != 0:
q[start:stop] = float(quantized_bins[j]) / float(norm)
q[sliced_nd_hist == 0] = 0
q[p == 0] = 0
p = _smooth_distribution(p)
q = _smooth_distribution(q)
# There is a chance that q is an invalid probability distribution.
try:
q = _smooth_distribution(q)
except ValueError:
divergence[i - num_half_quantized_bins] = float("inf")
divergence[i - num_half_quantized_bins] = stats.entropy(p, q)
quantized_bins[:] = 0

min_divergence_idx = np.argmin(divergence)
min_divergence = divergence[min_divergence_idx]
Expand All @@ -352,7 +359,7 @@ def _get_optimal_thresholds(nd_dict, num_bins=8001, num_quantized_bins=255, logg
layer_names = list(nd_dict.keys())
for name in layer_names:
assert name in nd_dict
min_val, max_val, min_divergence, opt_th =\
min_val, max_val, min_divergence, opt_th = \
_get_optimal_threshold(nd_dict[name], num_bins=num_bins,
num_quantized_bins=num_quantized_bins)
del nd_dict[name] # release the memory of ndarray
Expand Down
25 changes: 24 additions & 1 deletion tests/python/quantization/test_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import os
import mxnet as mx
import numpy as np
from mxnet.test_utils import assert_almost_equal, rand_ndarray, rand_shape_nd, same, DummyIter
from mxnet.test_utils import assert_almost_equal, assert_exception, rand_ndarray, rand_shape_nd, same, DummyIter
from common import with_seed
from mxnet.module import Module
from mxnet.io import NDArrayIter
Expand Down Expand Up @@ -463,6 +463,7 @@ def check_qsym_qdtype(qsym, qdtype):
for qdtype in ['int8', 'uint8']:
check_quantize_model(qdtype)


@with_seed()
def test_quantize_sym_with_calib():
sym = get_fp32_sym()
Expand All @@ -485,6 +486,28 @@ def test_quantize_sym_with_calib():
assert_almost_equal(np.array([lhs]), np.array([rhs]), rtol=1e-3, atol=1e-4)


@with_seed()
def test_smooth_distribution():
    """Exercise `_smooth_distribution` on an all-zero input and on a Dirac delta."""
    # An all-zero input must be rejected with a ValueError.
    assert_exception(
        lambda: mx.contrib.quant._smooth_distribution(np.zeros((2,)), eps=1e-3),
        ValueError)

    # A single non-zero entry surrounded by four zero entries.
    spike = np.zeros((5,))
    spike[2] = 1

    # Expected result: every entry is raised by eps, then the non-zero entry
    # is lowered by 5 * eps, i.e. a net change of -4 * eps for that entry.
    expected = spike + 1e-3
    expected[2] -= 5e-3

    smoothed = mx.contrib.quant._smooth_distribution(spike, eps=1e-3)
    assert_almost_equal(smoothed, expected)


@with_seed()
def test_optimal_threshold_adversarial_case():
    """Check the optimal threshold when every value sits at one edge of the range."""
    # The worst case for `_get_optimal_threshold` is when the values are concentrated
    # at one edge, i.e. the histogram looks like [0, 0, ..., 1000].
    # We want to make sure that the optimal threshold in this case is the max.
    arr = np.array([2] * 1000)
    res = mx.contrib.quant._get_optimal_threshold(arr, num_quantized_bins=5)
    # res[3] is the selected threshold and should be 2. Compare the absolute
    # difference: without abs(), any threshold below 2 would also pass.
    assert abs(res[3] - 2) < 1e-5


@with_seed()
@unittest.skip("Flaky test: https://github.com/apache/incubator-mxnet/issues/11456")
def test_get_optimal_thresholds():
Expand Down

0 comments on commit 55fef30

Please sign in to comment.