[HAGO] Support scales vector in simulated_quantize. (apache#7)
* [HAGO] Support scales vector in simulated_quantize.

* [HAGO] Support per channel scales during simulation.
ZihengJiang authored and hypercubestart committed Mar 12, 2021
1 parent c7bb6ba commit a4f28a9
Showing 15 changed files with 369 additions and 432 deletions.
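
For intuition before the diff: this commit lets `simulated_quantize` accept a scale *vector* (one entry per channel along `axis`) in place of a scalar, broadcasting it across the tensor before the round/clip/rescale steps. The following standalone NumPy sketch mirrors that behavior under stated assumptions; the function name and values are hypothetical, not code from this commit.

    import numpy as np

    def simulate_quantize(data, scale, clip_min, clip_max, axis=1):
        # Hypothetical sketch of per-tensor vs. per-channel simulation.
        scale = np.asarray(scale, dtype=np.float32)
        if scale.ndim == 1:
            # Reshape (C,) -> e.g. (1, C, 1, 1) so it broadcasts along `axis`,
            # analogous to the topi.expand_like calls in the diff below.
            shape = [1] * data.ndim
            shape[axis] = scale.shape[0]
            scale = scale.reshape(shape)
        scaled = data / scale                            # simulate rounding error
        rounded = np.round(scaled)
        clipped = np.clip(rounded, clip_min, clip_max)   # simulate clipping error
        return clipped * scale                           # back to real values

    x = np.random.randn(2, 3, 4, 4).astype(np.float32)          # NCHW data
    per_channel = np.array([0.1, 0.05, 0.2], dtype=np.float32)  # one scale per channel
    y = simulate_quantize(x, per_channel, clip_min=-127, clip_max=127)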
@@ -19,4 +19,4 @@
from __future__ import absolute_import
import tvm._ffi

tvm._ffi._init_api("hago._quantize", __name__)
tvm._ffi._init_api("hago.quantize", __name__)
33 changes: 27 additions & 6 deletions python/tvm/hago/_op_attrs.py
@@ -109,7 +109,6 @@ def compare(origin, ret, msg):
return ret



def isclose(old, new, rtol, atol):
# compare two arrays under quantized situation
thold = np.max(np.abs(old))
@@ -147,12 +146,13 @@ def check_overflow(data, in_dtype, output):
tvm.nd.array(arr.astype('float32')).copyto(output)


@_reg.register_compute("hago.simulated_quantize")
@_reg.register_compute("nn.simulated_quantize")
def simulated_quantize_compute(attrs, inputs, out_type):
"""Compiler for simulated_quantize."""
assert len(inputs) == 5
assert attrs.sign
assert attrs.rounding == "round"
axis = attrs.axis

data, in_scale, out_scale, clip_min, clip_max = inputs
data = my_print(data, '\n\n*******************************************')
@@ -161,31 +161,52 @@ def simulated_quantize_compute(attrs, inputs, out_type):
origin = data
data = inspect(data, 'original data')

##################################
# simulate overflow truncate error
if attrs.in_dtype != 'float32':
# data = topi.divide(data, in_scale)
# data = tvm.extern(data.shape, [data], lambda ins, outs: tvm.call_packed(
# "tvm.contrib.check_overflow", ins[0], str(attrs.in_dtype), outs[0]))
# data = topi.multiply(data, in_scale)

data = topi.divide(data, in_scale)
if len(in_scale.shape) == 1:
assert axis is not None
assert len(out_scale.shape) == 0
# per-channel dequantize
expand_axes = [i for i in range(len(data.shape)) if i != axis]
in_scale = topi.expand_like(in_scale, data, expand_axes)
data = topi.divide(data, in_scale)
else:
data = topi.divide(data, in_scale)
data = topi.cast(topi.round(data), 'int64')
data = topi.cast(data, attrs.in_dtype)
data = topi.multiply(data, in_scale)

########################################
# dequantize, directly return real value
if attrs.out_dtype == 'float32':
data = my_print(data, '*******************************************\n\n')
return [topi.identity(data)]


#########################
# simulate rounding error
scaled_data = topi.divide(data, out_scale)
if len(out_scale.shape) == 1:
assert axis is not None
assert len(in_scale.shape) == 0
# per-channel quantize
expand_axes = [i for i in range(len(data.shape)) if i != axis]
out_scale = topi.expand_like(out_scale, data, expand_axes)
scaled_data = topi.divide(data, out_scale)
else:
scaled_data = topi.divide(data, out_scale)
scaled_data = inspect(scaled_data, 'scaled data')

round_data = topi.round(scaled_data)
round_data = inspect(round_data, 'round data')

#########################
# simulate clipping error
clipped_data = topi.maximum(topi.minimum(round_data, clip_max), clip_min)
clipped_data = inspect(clipped_data, 'clipped data')

Expand All @@ -199,8 +220,8 @@ def simulated_quantize_compute(attrs, inputs, out_type):
return [ret]


_reg.register_schedule("hago.simulated_quantize", tvm.relay.op.strategy.schedule_simulated_quantize)
_reg.register_pattern("hago.simulated_quantize", _reg.OpPattern.OPAQUE)
_reg.register_schedule("nn.simulated_quantize", tvm.relay.op.strategy.schedule_simulated_quantize)
_reg.register_pattern("nn.simulated_quantize", _reg.OpPattern.OPAQUE)


# infer scale function registered for ops
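A note on the `expand_axes` computation above: `[i for i in range(len(data.shape)) if i != axis]` lists every axis of `data` except the channel axis, and `topi.expand_like` inserts those axes into the 1-D scale vector so it matches `data`'s shape. A hedged NumPy analogue (the helper name is ours, not TVM's):

    import numpy as np

    def expand_like(vec, data, expand_axes):
        # Insert the listed axes into `vec` so it matches data's rank,
        # approximating what topi.expand_like does in the compute above.
        out = vec
        for ax in sorted(expand_axes):
            out = np.expand_dims(out, ax)
        return np.broadcast_to(out, data.shape)

    data = np.zeros((2, 3, 4, 4), dtype=np.float32)          # NCHW
    in_scale = np.array([0.1, 0.05, 0.2], dtype=np.float32)  # (C,)
    axis = 1
    expand_axes = [i for i in range(data.ndim) if i != axis]  # [0, 2, 3]
    assert expand_like(in_scale, data, expand_axes).shape == data.shape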
90 changes: 68 additions & 22 deletions python/tvm/hago/analysis.py
@@ -17,8 +17,7 @@
from __future__ import absolute_import

from .base import *
from . import _quantize
from .topology import Topology, analyze_topology
from .topology import Topology, NodeKind, analyze_topology
from .quantize import create_quantizer
from .record import Strategy
from ..contrib import graph_runtime
@@ -31,49 +30,96 @@
from collections import OrderedDict

class Stats(object):
def __init__(self, data):
def __init__(self, topology, data):
"""
data: intermediate data * number_of_batches
data: [num of intermediate data][number of batch][tensor]
"""
self.data = data
# Range represents avg min/max
self.range = []
self.power_of_two_range = []
for idx in range(len(data)):
samples = len(self.data[idx])
arr = np.concatenate(self.data[idx]).reshape(samples, -1)
avg_min = np.average(np.min(arr, axis=1))
avg_max = np.average(np.max(arr, axis=1))
arange = np.amax([np.abs(avg_min), np.abs(avg_max)])
self.range.append(arange)
power_of_two_range = 2**np.math.ceil(np.math.log(arange, 2)) if arange > 0 else 1.0
self.power_of_two_range.append(power_of_two_range)
self.topology = topology
self.node_kinds = list(topology.node2kind().values())
self.node_edges = list(topology.node2edges().values())
print(self.node_kinds)
self.data = []
for idx, batched_data in enumerate(data):
if self.node_kinds[idx] in (NodeKind.Input, NodeKind.Activation):
flatten_data = np.concatenate(batched_data)
elif self.node_kinds[idx] == NodeKind.Weight:
flatten_data = batched_data[0]
else:
raise ValueError
self.data.append(flatten_data)
self._avg_range = None
self._pot_range = None

def __len__(self):
return len(self.data)

def data(self, idx):
return self.data[idx]

def _round2pot(self, x):
pot = 2**np.math.ceil(np.math.log(x, 2)) if x > 0 else 1.0
return pot

def _calculate_avg_range(self, arr):
num_samples = arr.shape[0]
arr = np.reshape(arr, (num_samples, -1))
avg_min = np.average(np.min(arr, axis=1))
avg_max = np.average(np.max(arr, axis=1))
arange = np.amax([np.abs(avg_min), np.abs(avg_max)])
return arange

@property
def avg_range(self):
if self._avg_range is None:
self._avg_range = []
for idx, arr in enumerate(self.data):
if self.node_kinds[idx] in (NodeKind.Input, NodeKind.Activation):
arange = self._calculate_avg_range(arr)
elif self.node_kinds[idx] == NodeKind.Weight:
axis = current_qconfig().per_channel_scale_axis
out_edges = self.node_edges[idx]
assert len(out_edges) == 1
op_node = out_edges[0][1]
print(op_node.op.name)
if axis is not None and op_node.op.name in ['nn.dense', 'nn.conv2d']:
# per channel scales
axis = current_qconfig().per_channel_scale_axis
arr = np.moveaxis(arr, axis, 0)
num_scales = arr.shape[0]
arr = np.reshape(arr, (num_scales, -1))
arange = np.amax(np.abs(arr), axis=1)
else:
arange = np.amax(np.abs(arr))
self._avg_range.append(arange)
return self._avg_range

@property
def pot_range(self):
if self._pot_range is None:
self._pot_range = [self._round2pot(r) for r in self.avg_range]
return self._pot_range

def mean(self, idx):
pass

def variance(self, idx):
pass
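
The per-channel branch of `avg_range` above moves the channel axis to the front, flattens the rest, and takes one max-abs per channel; `pot_range` then rounds each range up to a power of two. A standalone NumPy sketch of both steps (the weight shape is illustrative, not from the diff):

    import math
    import numpy as np

    weight = np.random.randn(8, 16, 3, 3)   # e.g. a conv2d weight, channels on axis 0
    axis = 0
    arr = np.moveaxis(weight, axis, 0)      # bring channel axis to the front
    arr = arr.reshape(arr.shape[0], -1)     # (num_channels, everything_else)
    ranges = np.amax(np.abs(arr), axis=1)   # one range per channel, shape (8,)

    # pot_range: round each range up to the nearest power of two.
    pots = [2 ** math.ceil(math.log2(r)) if r > 0 else 1.0 for r in ranges]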


def collect_stats(graph, dataset, ctx, target):

def collect_stats(graph, topology, dataset, ctx, target):
assert isinstance(graph, relay.Function)
assert graph == topology.graph
logging.info("collecting statistics for calibration...")
outputs = []
nodes = []
def fvisit(node):
if isinstance(node, (relay.Var, relay.Constant, relay.Call)):
outputs.append(node)
nodes.append(node)
relay.analysis.post_order_visit(graph, fvisit)
out = relay.Tuple(outputs)
out = relay.Tuple(nodes)
func = relay.Function(graph.params, out)
outputs = evaluate(func, dataset, ctx, target)
stats = Stats(outputs)
stats = Stats(topology, outputs)
logging.info("statistics collected")
return stats

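`collect_stats` gathers every `Var`, `Constant`, and `Call` in post order and rewraps them as a single tuple-valued function, so one evaluation over the calibration dataset yields all intermediate tensors at once. A minimal sketch of that rewrapping on a toy graph (the shapes are arbitrary):

    from tvm import relay

    x = relay.var("x", shape=(1, 3, 8, 8))
    w = relay.var("w", shape=(16, 3, 3, 3))
    func = relay.Function([x, w], relay.nn.conv2d(x, w))

    nodes = []
    def fvisit(node):
        if isinstance(node, (relay.Var, relay.Constant, relay.Call)):
            nodes.append(node)
    relay.analysis.post_order_visit(func, fvisit)

    # One function whose single output is a tuple of every collected node;
    # evaluating it returns all intermediate tensors for calibration.
    tuple_func = relay.Function(func.params, relay.Tuple(nodes))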
