diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py index bbc853ae5706d..5f2ca2212425c 100644 --- a/python/tvm/relay/quantize/quantize.py +++ b/python/tvm/relay/quantize/quantize.py @@ -98,7 +98,7 @@ def attach_simulated_quantize(data, kind): True, "round", kind) -def _build_module(graph, params): +def _build_module(graph, params=None): model, lib, params = _build(graph, target='llvm', params=params) module = _runtime.create(model, lib, tvm.cpu(0)) module.set_input(**params) @@ -109,12 +109,21 @@ def register_qfield_rewrite(op_name, frewrite=None, level=10): return _reg.register(op_name, "FQFieldRewrite", frewrite, level) -def annotate(graph): +def annotate(graph, params): graph = _ir_pass.infer_type(graph) + graph = _ir_pass.simplify_inference(graph) + graph = _ir_pass.infer_type(graph) + graph = _ir_pass.backward_fold_scale_axis(graph) + graph = _ir_pass.infer_type(graph) + graph = _ir_pass.forward_fold_scale_axis(graph) + var_map = {arg.name_hint: arg for arg in graph.params} + const_map = {var_map[key]: _expr.Constant(params[key]) for key in params} + graph = _expr.bind(graph, const_map) + graph = _ir_pass.fold_constant(graph) return _quantize.annotate(graph) -def calibrate(graph, params, dataset=None): +def calibrate(graph, dataset=None): def _scalar(x, dtype): return _expr.const(np.array(x).astype(dtype)) @@ -143,11 +152,8 @@ def visit_func(e): if kind == QFieldKind.WEIGHT: var = e.args[0] - if isinstance(var, _expr.Constant): - arr = var.data - else: - arr = params[var.name_hint] - raise ValueError + assert isinstance(var, _expr.Constant) + arr = var.data scale = power2_scale(arr) else: scale = cfg.global_scale diff --git a/python/tvm/relay/quantize/quantize_ops.py b/python/tvm/relay/quantize/quantize_ops.py index 54f8a4024b338..ae7335643c68e 100644 --- a/python/tvm/relay/quantize/quantize_ops.py +++ b/python/tvm/relay/quantize/quantize_ops.py @@ -1,5 +1,7 @@ from __future__ import absolute_import +import tvm import topi +import ctypes from .. import expr as _expr from ..op import op as _reg from .quantize import QFieldKind, QFieldExpr, register_qfield_rewrite @@ -73,7 +75,6 @@ def add_rewrite(ref_call, new_args, ctx): elif lhs.kind == QFieldKind.REAL and rhs.kind == QFieldKind.ACTIVATION: # add residual # TODO record and dom_scale compute on int32 - print('haha') lhs_expr = attach_simulated_quantize(lhs.expr, QFieldKind.INPUT) expr = _forward_op(ref_call, [lhs_expr, rhs.expr]) return QFieldExpr(expr, QFieldKind.ACTIVATION) @@ -107,6 +108,19 @@ def relu_rewrite(ref_call, new_args, ctx): return None +@tvm.register_func("debug_print") +def debug_print_impl(x, y, msg): + print(ctypes.string_at(msg)) + print(x.asnumpy()) + x.copyto(y) + + +def debug_print(x, msg): + return tvm.extern(x.shape, [x], lambda ins, outs: + tvm.call_packed("debug_print", ins[0], outs[0], msg), + name='debug') + + @_reg.register_compute("simulated_quantize") def simulated_quantize_compute(attrs, inputs, output_type, target): """Compiler for simulated_quantize.""" @@ -117,18 +131,25 @@ def simulated_quantize_compute(attrs, inputs, output_type, target): data, scale, bit, clip_min, clip_max = inputs if attrs.kind == QFieldKind.REAL: - # dequantize, do nothing return [topi.identity(data)] # simulate rounding error + # data = debug_print(data, 'original_data') scaled_data = topi.divide(data, scale) + # scaled_data = debug_print(scaled_data, 'scaled_data') clipped_data = topi.maximum(topi.minimum(scaled_data, clip_max), clip_min) round_data = topi.round(clipped_data) + # round_data = debug_print(round_data, 'round_data') # recover data rdata = topi.multiply(round_data, scale) return [rdata] -_reg.register_schedule("simulated_quantize", _reg.schedule_injective) -_reg.register_pattern("simulated_quantize", _reg.OpPattern.BROADCAST) +def schedule_naive(attrs, outputs, target): + s = tvm.create_schedule([x.op for x in outputs]) + return s + + +_reg.register_schedule("simulated_quantize", schedule_naive) +_reg.register_pattern("simulated_quantize", _reg.OpPattern.OPAQUE) diff --git a/src/relay/pass/quantize.cc b/src/relay/pass/quantize.cc index c35be315b2e2a..cb103abae5fb0 100644 --- a/src/relay/pass/quantize.cc +++ b/src/relay/pass/quantize.cc @@ -185,6 +185,7 @@ Expr QuantizeQStateRewrite(const Call& ref_call, Expr clip_min = new_args[3]; Expr clip_max = new_args[4]; + float dom_scale_imm = GetScalarFromConstant(dom_scale); int bit_imm = GetScalarFromConstant(bit); float clip_min_imm = GetScalarFromConstant(clip_min); float clip_max_imm = GetScalarFromConstant(clip_max); @@ -223,6 +224,7 @@ Expr QuantizeQStateRewrite(const Call& ref_call, return QIntStateNode::make(round_data, dom_scale, bit_imm, Float(32)); } } else if (const auto* n = new_args[0].as()) { + LOG(FATAL) << "wrong"; Expr data = n->data; if (kind == kReal) LOG(FATAL) << "wrong"; Expr scaled_data = Divide(data, dom_scale); @@ -232,7 +234,7 @@ Expr QuantizeQStateRewrite(const Call& ref_call, Expr data = new_args[0]; // normal expr if (kind == kReal) LOG(FATAL) << "wrong"; - Expr scaled_data = Divide(data, dom_scale); + Expr scaled_data = Multiply(data, MakeConstantScalar(Float(32), 1 / dom_scale_imm)); Expr round_data = Clip(Round(scaled_data), clip_min_imm, clip_max_imm); return QIntStateNode::make(round_data, dom_scale, bit_imm, Float(32)); } @@ -267,7 +269,13 @@ Expr Conv2dQStateRewrite(const Call& ref_call, Expr ldata = Cast(lhs->data, Int(lhs->safe_nbit)); Expr rdata = Cast(rhs->data, Int(rhs->safe_nbit)); - Expr ret = ForwardOp(ref_call, {ldata, rdata}); + const auto ref_attrs = ref_call->attrs.as(); + auto attrs = make_node(); + *attrs = *ref_attrs; + attrs->out_dtype = Int(32); + + Expr ret = CallNode::make(ref_call->op, + {ldata, rdata}, Attrs(attrs), ref_call->type_args); Expr dom_scale = FoldConstant(Multiply(lhs->dom_scale, rhs->dom_scale)); return QIntStateNode::make(ret, dom_scale, GetConfigBit(kQActivation), Int(32)); } @@ -325,7 +333,7 @@ inline Expr MulAndDiv(Expr data, float s1, float s2) { return LeftShift(data, MakeConstantScalar(Int(32), static_cast(magnitude))); } else { data = Cast(data, Float(32)); - return Divide(Multiply(data, MakeConstantScalar(Float(32), s1)), MakeConstantScalar(Float(32), s2)); + return Multiply(data, MakeConstantScalar(Float(32), s1 / s2)); } } diff --git a/tests/python/quantize/evaluate_gluon_model.py b/tests/python/quantize/evaluate_gluon_model.py new file mode 100644 index 0000000000000..565cb841caef4 --- /dev/null +++ b/tests/python/quantize/evaluate_gluon_model.py @@ -0,0 +1,177 @@ +import logging +import argparse +import os +import mxnet as mx +from mxnet import gluon +from mxnet.gluon.model_zoo import vision +from gluoncv.data import imagenet + +# Two functions for reading data from record file or raw images +def get_val_data(rec_val, + batch_size, + num_workers=4): + rec_val = os.path.expanduser(rec_val) + mean_rgb = [123.68, 116.779, 103.939] + std_rgb = [58.393, 57.12, 57.375] + def batch_fn(batch, ctx): + data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0) + label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0) + return data, label + + val_data = mx.io.ImageRecordIter( + path_imgrec = rec_val, + preprocess_threads = num_workers, + shuffle = True, + batch_size = batch_size, + resize = 256, + data_shape = (3, 224, 224), + mean_r = mean_rgb[0], + mean_g = mean_rgb[1], + mean_b = mean_rgb[2], + std_r = std_rgb[0], + std_g = std_rgb[1], + std_b = std_rgb[2], + ) + return val_data, batch_fn + + +def evaluate(args, graph, lib, params, ctx): + """Evaluate on the validation set.""" + import tvm + from tvm.contrib import graph_runtime + + # tetup dataset. + batch_size = args.batch_size + val_data, batch_fn = get_val_data(args.rec_val, batch_size) + # create runtime module + m = graph_runtime.create(graph, lib, ctx) + m.set_input(**params) + oshape = (batch_size, args.num_classes) + out_arr = tvm.nd.empty(oshape, "float32") + # setup evaluaiton metric + acc_top1 = mx.metric.Accuracy() + acc_top5 = mx.metric.TopKAccuracy(5) + val_data.reset() + acc_top1.reset() + acc_top5.reset() + # Execute + for i, batch in enumerate(val_data): + data, label = batch_fn(batch, [mx.cpu(0)]) + m.run(data=data[0].asnumpy()) + m.get_output(0, out_arr) + acc_top1.update(label, [mx.nd.array(out_arr.asnumpy())]) + acc_top5.update(label, [mx.nd.array(out_arr.asnumpy())]) + + if args.log_interval and not (i + 1) % args.log_interval: + _, top1 = acc_top1.get() + _, top5 = acc_top5.get() + nsamples = (i + 1) * batch_size + logging.info('[%d samples] validation: acc-top1=%f acc-top5=%f', nsamples, top1, top5) + logging.info('[final] validation: acc-top1=%f acc-top5=%f', top1, top5) + + +def build_nnvm(args, gluon_model): + """Build with nnvm path""" + import tvm + import nnvm + import nnvm.compiler + net, params = nnvm.frontend.from_mxnet(gluon_model) + data_shape = (args.batch_size, 3, 224, 224) + shape_dict = {'data': data_shape} + shape_dict.update({k: v.shape for k, v in params.items()}) + dtype_dict = {"data": "float32"} + target = args.target + + with nnvm.compiler.build_config(opt_level=3): + graph, lib, params = nnvm.compiler.build( + net, target, shape_dict, dtype_dict, params=params) + ctx = tvm.nd.context(target, 0) + return graph, lib,params, ctx + + +def build_relay(args, gluon_model): + """Build with relay.""" + import tvm + from tvm import relay + data_shape = (args.batch_size, 3, 224, 224) + net, params = relay.frontend.from_mxnet(gluon_model, {"data": data_shape}) + target = args.target + with relay.build_config(opt_level=3): + graph, lib, params = relay.build( + net, target, params=params) + ctx = tvm.nd.context(target, 0) + return graph, lib, params, ctx + + +def build_quantize(args, gluon_model): + print('build quantize') + """Build with relay.""" + import tvm + from tvm import relay + from tvm.relay import quantize as qtz + from tvm.relay import ir_pass + data_shape = (args.batch_size, 3, 224, 224) + net, params = relay.frontend.from_mxnet(gluon_model, {"data": data_shape}) + target = args.target + + with qtz.qconfig(skip_k_conv=2, global_scale=2.0): + # print(net.astext()) + + graph = net + # graph = ir_pass.infer_type(graph) + # graph = ir_pass.simplify_inference(graph) + # var_map = {arg.name_hint: arg for arg in graph.params} + # const_map = {var_map[key]: tvm.relay.const(params[key]) for key in params} + # graph = tvm.relay.bind(graph, const_map) + # graph = ir_pass.fold_constant(graph) + # print('after const folding\n') + # print(graph.astext()) + qgraph = qtz.annotate(graph, params) + # print('after annotate\n') + # print(qgraph.astext()) + qgraph = qtz.calibrate(qgraph) + # print('after calibrate\n') + # print(qgraph.astext()) + # qgraph = qtz.realize(qgraph) + # print('after realize\n') + # print(qgraph.astext()) + + with relay.build_config(opt_level=3): + graph, lib, params = relay.build( + qgraph, target, params=params) + ctx = tvm.nd.context(target, 0) + return graph, lib, params, ctx + + +def main(args): + gluon_model = vision.get_model(args.model, pretrained=True) + if args.use_nnvm: + graph, lib, params, ctx = build_nnvm(args, gluon_model) + else: + # graph, lib, params, ctx = build_relay(args, gluon_model) + graph, lib, params, ctx = build_quantize(args, gluon_model) + + logging.info("Finish building model %s...", args.model) + evaluate(args, graph, lib, params, ctx) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Evaluate ImageNet validation accuracy") + parser.add_argument("--rec-val", type=str, default="~/.mxnet/datasets/imagenet/rec/val.rec", + help="the validation data") + parser.add_argument("--num-classes", type=int, default=1000, + help="batch size") + parser.add_argument("--model", type=str, default="resnet18_v1", + help="Name of the model") + parser.add_argument("--log-interval", type=int, default=100, + help="log interval") + parser.add_argument("--batch-size", type=int, default=1, + help="batch size") + parser.add_argument("--target", type=str, default="cuda", + help="target option") + parser.add_argument("--use-nnvm", action="store_true", + help="Use legacy nnvm compiler") + args = parser.parse_args() + logging.basicConfig(level=logging.INFO) + logging.info(args) + main(args) diff --git a/tests/python/quantize/test_pass_quantize.py b/tests/python/quantize/test_pass_quantize.py index 3a4f6afe89f93..9f5670736a4c7 100644 --- a/tests/python/quantize/test_pass_quantize.py +++ b/tests/python/quantize/test_pass_quantize.py @@ -11,7 +11,7 @@ def test_simulated_quantize(): bit = relay.var("bit") clip_min = relay.var("clip_min") clip_max = relay.var("clip_max") - out = relay.simulated_quantize(data, scale, bit, clip_min, clip_max, sign=True, rounding='round') + out = qtz.simulated_quantize(data, scale, bit, clip_min, clip_max, sign=True, rounding='round', kind=0) out = relay.ir_pass.infer_type(out) assert out.checked_type == out.args[0].checked_type assert out.args[1].checked_type == relay.ty.TensorType(tuple(), "float32") @@ -33,7 +33,7 @@ def residual_block(data, cnt): data = relay.var("data", relay.TensorType((n, c, h, w), "float32")) out = data - for i in range(2): + for i in range(1): out = residual_block(out, i) out = relay.ir_pass.infer_type(out) @@ -42,7 +42,7 @@ def residual_block(data, cnt): def make_dataset(args, size=100): def create_arr(var): ttype = var.type_annotation - np_arr = np.random.rand(*ttype.concrete_shape).astype(ttype.dtype) + np_arr = np.random.uniform(-1.0, 1.0, size=ttype.concrete_shape).astype(ttype.dtype) return tvm.ndarray.array(np_arr) params = {} @@ -57,22 +57,37 @@ def create_arr(var): old_graph = relay.Function(args, out) dataset, params = make_dataset(args, 10) - with qtz.qconfig(skip_k_conv=1): + mod = qtz._build_module(old_graph, params) + mod.set_input(**dataset[0]) + mod.run() + out = mod.get_output(0) + + with qtz.qconfig(skip_k_conv=0, global_scale=4.0): + old_graph = relay.ir_pass.infer_type(old_graph) print('before:') print(old_graph.astext()) - qgraph = qtz.annotate(old_graph) - print('after annotate:') - print(qgraph.astext()) - print('\n') - qgraph = qtz.calibrate(qgraph, params, dataset) + qgraph = qtz.annotate(old_graph, params) + # print('after annotate:') + # print(qgraph.astext()) + # print('\n') + qgraph = qtz.calibrate(qgraph, dataset) + qgraph = relay.ir_pass.infer_type(qgraph) print('after calibrate:') print(qgraph.astext()) - raise ValueError print('\n') qgraph = qtz.realize(qgraph) print('after realize:') print(qgraph.astext()) + qmod = qtz._build_module(qgraph) + qmod.set_input(**dataset[0]) + qmod.run() + qout = qmod.get_output(0) + print('out:') + print(out) + print('qout:') + print(qout) + if __name__ == "__main__": test_simulated_quantize()