Skip to content

Commit

Permalink
[QUANTIZE] Update.
Browse files Browse the repository at this point in the history
  • Loading branch information
ZihengJiang authored and tqchen committed Dec 2, 2018
1 parent 6d751ba commit 963c758
Show file tree
Hide file tree
Showing 5 changed files with 252 additions and 25 deletions.
22 changes: 14 additions & 8 deletions python/tvm/relay/quantize/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def attach_simulated_quantize(data, kind):
True, "round", kind)


def _build_module(graph, params):
def _build_module(graph, params=None):
model, lib, params = _build(graph, target='llvm', params=params)
module = _runtime.create(model, lib, tvm.cpu(0))
module.set_input(**params)
Expand All @@ -109,12 +109,21 @@ def register_qfield_rewrite(op_name, frewrite=None, level=10):
return _reg.register(op_name, "FQFieldRewrite", frewrite, level)


def annotate(graph):
def annotate(graph, params):
    """Prepare *graph* and annotate it with simulated-quantize ops.

    Runs type inference, inference simplification, and backward/forward
    scale-axis folding, binds the entries of *params* into the graph as
    constants, folds constants, then invokes the registered annotate pass.

    Parameters
    ----------
    graph : relay graph/function with named parameters (``graph.params``)
    params : dict
        Maps parameter names to values wrapped via ``_expr.Constant``.
        # NOTE(review): values are passed straight to _expr.Constant, so
        # presumably tvm.nd.NDArray — confirm with callers.

    Returns
    -------
    The annotated graph produced by ``_quantize.annotate``.
    """
    graph = _ir_pass.infer_type(graph)
    graph = _ir_pass.simplify_inference(graph)
    graph = _ir_pass.infer_type(graph)
    graph = _ir_pass.backward_fold_scale_axis(graph)
    # infer_type again: the folding passes need up-to-date type information.
    graph = _ir_pass.infer_type(graph)
    graph = _ir_pass.forward_fold_scale_axis(graph)
    # Bind every named parameter to its concrete value so constant folding
    # can propagate weights through the graph before annotation.
    var_map = {arg.name_hint: arg for arg in graph.params}
    const_map = {var_map[key]: _expr.Constant(params[key]) for key in params}
    graph = _expr.bind(graph, const_map)
    graph = _ir_pass.fold_constant(graph)
    return _quantize.annotate(graph)


def calibrate(graph, params, dataset=None):
def calibrate(graph, dataset=None):
def _scalar(x, dtype):
return _expr.const(np.array(x).astype(dtype))

Expand Down Expand Up @@ -143,11 +152,8 @@ def visit_func(e):

if kind == QFieldKind.WEIGHT:
var = e.args[0]
if isinstance(var, _expr.Constant):
arr = var.data
else:
arr = params[var.name_hint]
raise ValueError
assert isinstance(var, _expr.Constant)
arr = var.data
scale = power2_scale(arr)
else:
scale = cfg.global_scale
Expand Down
29 changes: 25 additions & 4 deletions python/tvm/relay/quantize/quantize_ops.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import absolute_import
import tvm
import topi
import ctypes
from .. import expr as _expr
from ..op import op as _reg
from .quantize import QFieldKind, QFieldExpr, register_qfield_rewrite
Expand Down Expand Up @@ -73,7 +75,6 @@ def add_rewrite(ref_call, new_args, ctx):
elif lhs.kind == QFieldKind.REAL and rhs.kind == QFieldKind.ACTIVATION:
# add residual
# TODO record and dom_scale compute on int32
print('haha')
lhs_expr = attach_simulated_quantize(lhs.expr, QFieldKind.INPUT)
expr = _forward_op(ref_call, [lhs_expr, rhs.expr])
return QFieldExpr(expr, QFieldKind.ACTIVATION)
Expand Down Expand Up @@ -107,6 +108,19 @@ def relu_rewrite(ref_call, new_args, ctx):
return None


@tvm.register_func("debug_print")
def debug_print_impl(x, y, msg):
    """Packed-func backend for ``debug_print``.

    Prints *msg* (a C string pointer) and the contents of tensor *x*,
    then copies *x* into the output tensor *y* unchanged (identity op).
    """
    # msg arrives as a raw char* from the packed-call boundary.
    print(ctypes.string_at(msg))
    print(x.asnumpy())
    x.copyto(y)


def debug_print(x, msg):
    """Wrap tensor *x* in an identity extern op that prints *msg* and the
    tensor contents at run time (handled by the "debug_print" packed func).
    """
    def _fcompute(ins, outs):
        return tvm.call_packed("debug_print", ins[0], outs[0], msg)

    return tvm.extern(x.shape, [x], _fcompute, name='debug')


@_reg.register_compute("simulated_quantize")
def simulated_quantize_compute(attrs, inputs, output_type, target):
"""Compiler for simulated_quantize."""
Expand All @@ -117,18 +131,25 @@ def simulated_quantize_compute(attrs, inputs, output_type, target):
data, scale, bit, clip_min, clip_max = inputs

if attrs.kind == QFieldKind.REAL:
# dequantize, do nothing
return [topi.identity(data)]

# simulate rounding error
# data = debug_print(data, 'original_data')
scaled_data = topi.divide(data, scale)
# scaled_data = debug_print(scaled_data, 'scaled_data')
clipped_data = topi.maximum(topi.minimum(scaled_data, clip_max), clip_min)
round_data = topi.round(clipped_data)
# round_data = debug_print(round_data, 'round_data')

# recover data
rdata = topi.multiply(round_data, scale)
return [rdata]


_reg.register_schedule("simulated_quantize", _reg.schedule_injective)
_reg.register_pattern("simulated_quantize", _reg.OpPattern.BROADCAST)
def schedule_naive(attrs, outputs, target):
    """Return a plain, untransformed schedule covering all *outputs*."""
    out_ops = [out.op for out in outputs]
    return tvm.create_schedule(out_ops)


_reg.register_schedule("simulated_quantize", schedule_naive)
_reg.register_pattern("simulated_quantize", _reg.OpPattern.OPAQUE)
14 changes: 11 additions & 3 deletions src/relay/pass/quantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ Expr QuantizeQStateRewrite(const Call& ref_call,
Expr clip_min = new_args[3];
Expr clip_max = new_args[4];

float dom_scale_imm = GetScalarFromConstant<float>(dom_scale);
int bit_imm = GetScalarFromConstant<int>(bit);
float clip_min_imm = GetScalarFromConstant<float>(clip_min);
float clip_max_imm = GetScalarFromConstant<float>(clip_max);
Expand Down Expand Up @@ -223,6 +224,7 @@ Expr QuantizeQStateRewrite(const Call& ref_call,
return QIntStateNode::make(round_data, dom_scale, bit_imm, Float(32));
}
} else if (const auto* n = new_args[0].as<QRealStateNode>()) {
LOG(FATAL) << "wrong";
Expr data = n->data;
if (kind == kReal) LOG(FATAL) << "wrong";
Expr scaled_data = Divide(data, dom_scale);
Expand All @@ -232,7 +234,7 @@ Expr QuantizeQStateRewrite(const Call& ref_call,
Expr data = new_args[0];
// normal expr
if (kind == kReal) LOG(FATAL) << "wrong";
Expr scaled_data = Divide(data, dom_scale);
Expr scaled_data = Multiply(data, MakeConstantScalar(Float(32), 1 / dom_scale_imm));
Expr round_data = Clip(Round(scaled_data), clip_min_imm, clip_max_imm);
return QIntStateNode::make(round_data, dom_scale, bit_imm, Float(32));
}
Expand Down Expand Up @@ -267,7 +269,13 @@ Expr Conv2dQStateRewrite(const Call& ref_call,
Expr ldata = Cast(lhs->data, Int(lhs->safe_nbit));
Expr rdata = Cast(rhs->data, Int(rhs->safe_nbit));

Expr ret = ForwardOp(ref_call, {ldata, rdata});
const auto ref_attrs = ref_call->attrs.as<Conv2DAttrs>();
auto attrs = make_node<Conv2DAttrs>();
*attrs = *ref_attrs;
attrs->out_dtype = Int(32);

Expr ret = CallNode::make(ref_call->op,
{ldata, rdata}, Attrs(attrs), ref_call->type_args);
Expr dom_scale = FoldConstant(Multiply(lhs->dom_scale, rhs->dom_scale));
return QIntStateNode::make(ret, dom_scale, GetConfigBit(kQActivation), Int(32));
}
Expand Down Expand Up @@ -325,7 +333,7 @@ inline Expr MulAndDiv(Expr data, float s1, float s2) {
return LeftShift(data, MakeConstantScalar(Int(32), static_cast<int>(magnitude)));
} else {
data = Cast(data, Float(32));
return Divide(Multiply(data, MakeConstantScalar(Float(32), s1)), MakeConstantScalar(Float(32), s2));
return Multiply(data, MakeConstantScalar(Float(32), s1 / s2));
}
}

Expand Down
177 changes: 177 additions & 0 deletions tests/python/quantize/evaluate_gluon_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
import logging
import argparse
import os
import mxnet as mx
from mxnet import gluon
from mxnet.gluon.model_zoo import vision
from gluoncv.data import imagenet

# Two functions for reading data from record file or raw images
def get_val_data(rec_val,
                 batch_size,
                 num_workers=4):
    """Build the ImageNet validation iterator and a batch-splitting helper.

    Returns a tuple ``(val_data, batch_fn)``: ``val_data`` is an
    ``mx.io.ImageRecordIter`` over the record file *rec_val*, and
    ``batch_fn(batch, ctx)`` splits a batch's data and labels across the
    given context list.
    """
    rec_path = os.path.expanduser(rec_val)
    # Standard ImageNet per-channel statistics used for normalization.
    mean_rgb = [123.68, 116.779, 103.939]
    std_rgb = [58.393, 57.12, 57.375]

    def batch_fn(batch, ctx):
        data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
        label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
        return data, label

    val_data = mx.io.ImageRecordIter(
        path_imgrec=rec_path,
        preprocess_threads=num_workers,
        shuffle=True,
        batch_size=batch_size,
        resize=256,
        data_shape=(3, 224, 224),
        mean_r=mean_rgb[0],
        mean_g=mean_rgb[1],
        mean_b=mean_rgb[2],
        std_r=std_rgb[0],
        std_g=std_rgb[1],
        std_b=std_rgb[2],
    )
    return val_data, batch_fn


def evaluate(args, graph, lib, params, ctx):
    """Evaluate the compiled model on the ImageNet validation set.

    Logs top-1/top-5 accuracy every ``args.log_interval`` batches and a
    final summary after the full pass.
    """
    import tvm
    from tvm.contrib import graph_runtime

    # Set up the dataset.
    batch_size = args.batch_size
    val_data, batch_fn = get_val_data(args.rec_val, batch_size)
    # Create the runtime module and load the compiled parameters.
    m = graph_runtime.create(graph, lib, ctx)
    m.set_input(**params)
    oshape = (batch_size, args.num_classes)
    out_arr = tvm.nd.empty(oshape, "float32")
    # Set up evaluation metrics.
    acc_top1 = mx.metric.Accuracy()
    acc_top5 = mx.metric.TopKAccuracy(5)
    val_data.reset()
    acc_top1.reset()
    acc_top5.reset()
    # Execute.
    for i, batch in enumerate(val_data):
        data, label = batch_fn(batch, [mx.cpu(0)])
        m.run(data=data[0].asnumpy())
        m.get_output(0, out_arr)
        acc_top1.update(label, [mx.nd.array(out_arr.asnumpy())])
        acc_top5.update(label, [mx.nd.array(out_arr.asnumpy())])

        if args.log_interval and not (i + 1) % args.log_interval:
            _, top1 = acc_top1.get()
            _, top5 = acc_top5.get()
            nsamples = (i + 1) * batch_size
            logging.info('[%d samples] validation: acc-top1=%f acc-top5=%f', nsamples, top1, top5)
    # BUG FIX: top1/top5 were previously bound only inside the periodic
    # logging branch above, raising NameError when log_interval was 0 or
    # the dataset was shorter than one interval. Recompute them here so
    # the final summary is always valid and up to date.
    _, top1 = acc_top1.get()
    _, top5 = acc_top5.get()
    logging.info('[final] validation: acc-top1=%f acc-top5=%f', top1, top5)


def build_nnvm(args, gluon_model):
    """Compile *gluon_model* through the legacy nnvm pipeline.

    Returns ``(graph, lib, params, ctx)`` for the target given in *args*.
    """
    import tvm
    import nnvm
    import nnvm.compiler

    net, params = nnvm.frontend.from_mxnet(gluon_model)
    input_shape = (args.batch_size, 3, 224, 224)
    shape_dict = {'data': input_shape}
    shape_dict.update({name: arr.shape for name, arr in params.items()})
    dtype_dict = {"data": "float32"}
    target = args.target

    with nnvm.compiler.build_config(opt_level=3):
        graph, lib, params = nnvm.compiler.build(
            net, target, shape_dict, dtype_dict, params=params)
    ctx = tvm.nd.context(target, 0)
    return graph, lib, params, ctx


def build_relay(args, gluon_model):
    """Compile *gluon_model* with the Relay pipeline (no quantization).

    Returns ``(graph, lib, params, ctx)`` for the target given in *args*.
    """
    import tvm
    from tvm import relay

    input_shape = (args.batch_size, 3, 224, 224)
    net, params = relay.frontend.from_mxnet(gluon_model, {"data": input_shape})
    target = args.target
    with relay.build_config(opt_level=3):
        graph, lib, params = relay.build(net, target, params=params)
    ctx = tvm.nd.context(target, 0)
    return graph, lib, params, ctx


def build_quantize(args, gluon_model):
    """Compile *gluon_model* with the Relay pipeline plus quantization.

    The graph is annotated with simulated-quantize ops and calibrated
    before the final build. Returns ``(graph, lib, params, ctx)``.
    """
    import tvm
    from tvm import relay
    from tvm.relay import quantize as qtz

    logging.info("Building quantized model...")
    data_shape = (args.batch_size, 3, 224, 224)
    net, params = relay.frontend.from_mxnet(gluon_model, {"data": data_shape})
    target = args.target

    # skip_k_conv / global_scale are experiment settings; tune as needed.
    with qtz.qconfig(skip_k_conv=2, global_scale=2.0):
        qgraph = qtz.annotate(net, params)
        qgraph = qtz.calibrate(qgraph)
        # NOTE(review): qtz.realize(qgraph) is intentionally not applied
        # yet; the simulated-quantize ops are compiled as-is.

    with relay.build_config(opt_level=3):
        graph, lib, params = relay.build(qgraph, target, params=params)
    ctx = tvm.nd.context(target, 0)
    return graph, lib, params, ctx


def main(args):
    """Fetch the pretrained model, compile it, and run ImageNet validation."""
    gluon_model = vision.get_model(args.model, pretrained=True)
    if args.use_nnvm:
        graph, lib, params, ctx = build_nnvm(args, gluon_model)
    else:
        # Swap in build_relay(args, gluon_model) for an unquantized baseline.
        graph, lib, params, ctx = build_quantize(args, gluon_model)

    logging.info("Finish building model %s...", args.model)
    evaluate(args, graph, lib, params, ctx)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluate ImageNet validation accuracy")
    parser.add_argument("--rec-val", type=str, default="~/.mxnet/datasets/imagenet/rec/val.rec",
                        help="the validation data")
    # BUG FIX: the help text for --num-classes previously said "batch size".
    parser.add_argument("--num-classes", type=int, default=1000,
                        help="number of output classes")
    parser.add_argument("--model", type=str, default="resnet18_v1",
                        help="Name of the model")
    parser.add_argument("--log-interval", type=int, default=100,
                        help="log interval")
    parser.add_argument("--batch-size", type=int, default=1,
                        help="batch size")
    parser.add_argument("--target", type=str, default="cuda",
                        help="target option")
    parser.add_argument("--use-nnvm", action="store_true",
                        help="Use legacy nnvm compiler")
    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO)
    logging.info(args)
    main(args)
Loading

0 comments on commit 963c758

Please sign in to comment.