diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index bfd2a7731418..36ccd3fb6c89 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -139,7 +139,6 @@ struct StridedSliceAttrs : public tvm::AttrsNode { } }; - struct SliceLikeAttrs : public tvm::AttrsNode { Array axes; @@ -151,16 +150,16 @@ struct SliceLikeAttrs : public tvm::AttrsNode { } }; -// Clip +/*! \brief Attributes for Clip operator */ struct ClipAttrs : public tvm::AttrsNode { double a_min; double a_max; TVM_DECLARE_ATTRS(ClipAttrs, "relay.attrs.ClipAttrs") { - TVM_ATTR_FIELD(a_min) - .describe("The minimum clip value."); - TVM_ATTR_FIELD(a_max) - .describe("The maximum clip value."); + TVM_ATTR_FIELD(a_min) + .describe("The minimum clip value."); + TVM_ATTR_FIELD(a_max) + .describe("The maximum clip value."); } }; diff --git a/include/tvm/relay/op.h b/include/tvm/relay/op.h index d3c5edd31461..0fd54ff5b8fa 100644 --- a/include/tvm/relay/op.h +++ b/include/tvm/relay/op.h @@ -551,6 +551,7 @@ inline ValueType OpMap::get(const Expr& expr, return map_.get(expr, def_value); } + /*! * \brief Check that an expression is a "primtive operator". * diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 44d819ea78a3..b9d4695b70f8 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -8,7 +8,7 @@ from . import expr_functor from . import module from . import ir_pass -from .build_module import build, build_config, create_executor +from .build_module import build, build_config, create_executor, optimize from . import parser from . import debug @@ -23,6 +23,7 @@ from . import image from . import frontend from . import backend +from . import quantize from .scope_builder import ScopeBuilder diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 51a4ff873e0a..9641e0fd6fef 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -129,7 +129,7 @@ def _bind_params_by_name(func, params): return expr.bind(func, bind_dict) -def optimize(func, target, params=None): +def optimize(func, target=None, params=None): """Perform target invariant optimizations. Parameters @@ -400,7 +400,7 @@ def _make_executor(self, func): graph_json, mod, params = build(func, target=self.target) gmodule = _graph_rt.create(graph_json, mod, self.ctx) if params: - gmodule.set_input(*params) + gmodule.set_input(**params) def _graph_wrapper(*args, **kwargs): args = self._convert_args(func, args, kwargs) diff --git a/python/tvm/relay/quantize/__init__.py b/python/tvm/relay/quantize/__init__.py new file mode 100644 index 000000000000..bdb87b55518b --- /dev/null +++ b/python/tvm/relay/quantize/__init__.py @@ -0,0 +1,6 @@ +#pylint: disable=wildcard-import, redefined-builtin +"""Automatic quantization utilities.""" +from __future__ import absolute_import as _abs + +from .quantize import * +from ._annotate import register_annotate_function diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py new file mode 100644 index 000000000000..7eb8af57a70b --- /dev/null +++ b/python/tvm/relay/quantize/_annotate.py @@ -0,0 +1,246 @@ +#pylint: disable=unused-argument +"""Internal module for registering attribute for annotation.""" +from __future__ import absolute_import + +import topi +from . import _quantize +from .quantize import QAnnotateKind, current_qconfig +from .quantize import _conv_counter, _set_conv_counter +from .. import expr as _expr +from .. import op as _op +from ..op import op as _reg +from ..base import register_relay_node +from ..._ffi.function import register_func + + +@_reg.register_compute("relay.op.annotation.simulated_quantize") +def simulated_quantize_compute(attrs, inputs, out_type, target): + """Compiler for simulated_quantize.""" + assert len(inputs) == 4 + assert attrs.sign + assert attrs.rounding == "round" + + data, scale, clip_min, clip_max = inputs + + # simulate rounding error + scaled_data = topi.divide(data, scale) + clipped_data = topi.maximum(topi.minimum(scaled_data, clip_max), clip_min) + round_data = topi.round(clipped_data) + + # recover data + rdata = topi.multiply(round_data, scale) + return [rdata] + + +_reg.register_schedule("relay.op.annotation.simulated_quantize", + _reg.schedule_injective) +_reg.register_pattern("relay.op.annotation.simulated_quantize", + _reg.OpPattern.OPAQUE) + + +@register_relay_node +class QAnnotateExpr(_expr.TempExpr): + """A special kind of Expr for Annotating. + + Parameters + --------- + expr: Expr + the original relay ir expr. + + kind: QAnnotateKind + the kind of annotation field. + """ + def __init__(self, expr, kind): + self.__init_handle_by_constructor__( + _quantize.make_annotate_expr, expr, kind) + + +def _forward_op(ref_call, args): + """forward the operator of ref_call with provided arguments""" + return _expr.Call( + ref_call.op, args, ref_call.attrs, ref_call.type_args) + + +def _get_expr_kind(anno): + """Get the expression and QAnnotateKind from QAnnotateExpr or Expr""" + if isinstance(anno, QAnnotateExpr): + return anno.expr, anno.kind + return anno, None + + +def register_annotate_function(op_name, frewrite=None, level=10): + """register a rewrite function for operator, used by annotation. + + Parameters + --------- + op_name: str + The name of operation + + frewrite : function, optional + The function to be registered. + + level : int, optional + The priority level + """ + def default_rewrite(ref_call, new_args, ctx): + # recover from QAnnotateExpr + args = [_get_expr_kind(x)[0] for x in new_args] + return _forward_op(ref_call, args) + + def _register(func): + """internal register function""" + def frewrite_with_guard(ref_call, new_args, ctx): + if not current_qconfig().guard(ref_call): + return default_rewrite(ref_call, new_args, ctx) + return func(ref_call, new_args, ctx) + _op.op._Register(op_name, "FQAnnotateRewrite", frewrite_with_guard, level) + return frewrite_with_guard + + return _register(frewrite) if frewrite is not None else _register + + +@register_func("relay.quantize.attach_simulated_quantize") +def attach_simulated_quantize(data, kind, sign=True, rounding="round"): + """Attach a simulated quantize operation after input data expr. + + Parameters + --------- + data: Expr + the original data expr. + + kind: QAnnotateKind + the kind of annotation field. + """ + dom_scale = _expr.var("dom_scale") + clip_min = _expr.var("clip_min") + clip_max = _expr.var("clip_max") + return _quantize.simulated_quantize( + data, dom_scale, clip_min, clip_max, kind, sign, rounding) + + +@register_annotate_function("nn.conv2d") +def conv2d_rewrite(ref_call, new_args, ctx): + """Rewrite function for conv2d. Lhs of conv will be quantized to + input field, and rhs of conv will be quantized to weight field. + Output would be in activation field""" + cnt = _conv_counter() + if cnt < current_qconfig().skip_k_conv: + _set_conv_counter(cnt + 1) + return None + _set_conv_counter(cnt + 1) + + lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) + rhs_expr, rhs_kind = _get_expr_kind(new_args[1]) + + if lhs_kind is None or lhs_kind != QAnnotateKind.INPUT: + lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT) + + assert rhs_kind is None + rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT) + + expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) + return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) + + +@register_annotate_function("multiply") +def multiply_rewrite(ref_call, new_args, ctx): + """Rewrite function for multiply.""" + if _conv_counter() <= current_qconfig().skip_k_conv: + return None + + lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) + rhs_expr, rhs_kind = _get_expr_kind(new_args[1]) + + if lhs_kind is None and rhs_kind is None: + return None + if lhs_kind == QAnnotateKind.ACTIVATION and rhs_kind is None: + # quantize lhs to INPUT field + lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT) + # quantize rhs to WEIGHT field + rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT) + expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) + return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) + raise ValueError + + +@register_annotate_function("add") +def add_rewrite(ref_call, new_args, ctx): + """Rewrite function for add.""" + if _conv_counter() <= current_qconfig().skip_k_conv: + return None + + lhs_expr, lhs_kind = _get_expr_kind(new_args[0]) + rhs_expr, rhs_kind = _get_expr_kind(new_args[1]) + + if lhs_kind is None and rhs_kind is None: + return None + if lhs_kind is None and rhs_kind is not None: + # quantize lhs to INPUT field if it is normal expression + lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT) + if lhs_kind is not None and rhs_kind is None: + if isinstance(rhs_expr, _expr.Constant): + # quantize rhs to WEIGHT field if it is Constant + rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT) + else: + # quantize rhs to INPUT field if it is not Constant + rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT) + + expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) + return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) + + +def identity_rewrite(ref_call, new_args, ctx): + """Simply forward the original operation""" + if _conv_counter() <= current_qconfig().skip_k_conv: + return None + + x_expr, x_kind = _get_expr_kind(new_args[0]) + if x_kind is None: + return None + + ret_expr = _forward_op(ref_call, [x_expr]) + return QAnnotateExpr(ret_expr, x_kind) + + +register_annotate_function("nn.relu", identity_rewrite) +register_annotate_function("strided_slice", identity_rewrite) +register_annotate_function("nn.avg_pool2d", identity_rewrite) + + +def pool2d_rewrite(ref_call, new_args, ctx): + """Rewrite function for max pool2d""" + if _conv_counter() <= current_qconfig().skip_k_conv: + return None + expr, x_kind = _get_expr_kind(new_args[0]) + + if x_kind is None: + return None + if x_kind == QAnnotateKind.ACTIVATION: + expr = attach_simulated_quantize(expr, QAnnotateKind.INPUT) + expr = _forward_op(ref_call, [expr]) + return QAnnotateExpr(expr, QAnnotateKind.INPUT) + + +register_annotate_function("nn.max_pool2d", pool2d_rewrite) + + +@register_annotate_function("concatenate") +def concatenate_rewrite(ref_call, new_args, ctx): + """Rewrite function for concatenate""" + if _conv_counter() <= current_qconfig().skip_k_conv: + return None + + input_tuple = new_args[0] + expr_list = [_get_expr_kind(x)[0] for x in input_tuple] + kind_list = [_get_expr_kind(x)[1] for x in input_tuple] + + # make sure the inputs of concatenate are all normal + # expression or annotate expression + if kind_list[0] is None: + for k in kind_list: + assert k is None + return None + for k in kind_list: + assert k is not None + expr = _forward_op(ref_call, [_expr.Tuple(expr_list)]) + return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) diff --git a/python/tvm/relay/quantize/_quantize.py b/python/tvm/relay/quantize/_quantize.py new file mode 100644 index 000000000000..32f67cdc3d85 --- /dev/null +++ b/python/tvm/relay/quantize/_quantize.py @@ -0,0 +1,6 @@ +#pylint: disable=unused-argument +"""Internal module for quantization.""" +from __future__ import absolute_import +from tvm._ffi.function import _init_api + +_init_api("relay._quantize", __name__) diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py new file mode 100644 index 000000000000..6756090f14a7 --- /dev/null +++ b/python/tvm/relay/quantize/quantize.py @@ -0,0 +1,285 @@ +#pylint: disable=unused-argument +"""Automatic quantization toolkit.""" +from __future__ import absolute_import +import numpy as np + +from . import _quantize +from .. import expr as _expr +from .. import ir_pass as _ir_pass +from .. import build_module as _build +from .. import op as _op +from ... import make as _make +from ..base import NodeBase, register_relay_node + + +class QAnnotateKind(object): + """Denote the kind of annotation field, corresponding + to different nbit configure.""" + INPUT = 1 + WEIGHT = 2 + ACTIVATION = 3 + + +def kind2str(kind): + """Convert a `QAnnotateKind` to string""" + str_map = { + QAnnotateKind.INPUT: "input", + QAnnotateKind.WEIGHT: "weight", + QAnnotateKind.ACTIVATION: "activation", + } + assert kind in str_map + return str_map[kind] + + +@register_relay_node("relay.quantize.QConfig") +class QConfig(NodeBase): + """Configure the quantization behavior by setting config variables. + + Note + ---- + This object is backed by node system in C++, with arguments that can be + exchanged between python and C++. + + Do not construct directly, use qconfig instead. + + The fields that are backed by the C++ node are immutable once an instance + is constructed. See _node_defaults for the fields. + """ + + _node_defaults = { + "nbit_input": 8, + "nbit_weight": 8, + "nbit_activation": 32, + "dtype_input": "int8", + "dtype_weight": "int8", + "dtype_activation": "int32", + "global_scale": 8.0, + "skip_k_conv": 1, + "round_for_shift": True, + "store_lowbit_output": True, + "debug_enabled_ops": None, + } + + # pylint: disable=no-member + def __init__(self, handle): + """Initialize the function with handle + + Parameters + ---------- + handle : SymbolHandle + the handle to the underlying C++ Symbol + """ + super(QConfig, self).__init__(handle) + self.handle = handle + + def guard(self, ref_call): + op_name = ref_call.op.name + if self.debug_enabled_ops is not None: + name_list = [x.value for x in self.debug_enabled_ops] + if op_name not in name_list: + return False + return True + + def get_nbit_by_kind(self, kind): + name = kind2str(kind) + return getattr(self, 'nbit_' + name) + + def get_dtype_by_kind(self, kind): + name = kind2str(kind) + return getattr(self, 'dtype_' + name) + + def __enter__(self): + # pylint: disable=protected-access + _quantize._EnterQConfigScope(self) + return self + + def __exit__(self, ptype, value, trace): + _quantize._ExitQConfigScope(self) + + def __setattr__(self, name, value): + if name in QConfig._node_defaults: + raise AttributeError( + "'%s' object cannot set attribute '%s'" % (str(type(self)), name)) + return super(QConfig, self).__setattr__(name, value) + + +def current_qconfig(): + """Get the current quantization configuration.""" + return _quantize._GetCurrentQConfig() + + +def qconfig(**kwargs): + """Configure the quantization behavior by setting config variables. + + Parameters + --------- + nbit_dict: dict of QAnnotateKind -> int + Number of bit for every kind of annotate field. + + global_scale: float + The global scale for calibration. + + skip_k_conv: int + The number of skipped conv2d. + + round_for_shift: boolean + Whether to add bias for rounding during shift. + + store_lowbit_output: boolean + Whether to store low-bit integer back as output before dequantizing. + Some accelerators need this, e.g. VTA. + + Returns + ------- + config: QConfig + The quantization configuration + """ + node_args = {k: v if k not in kwargs else kwargs[k] + for k, v in QConfig._node_defaults.items()} + return _make.node("relay.quantize.QConfig", **node_args) + + +CONV_COUNTER = 0 + + +def _conv_counter(): + """Get the global counter for conv2d.""" + return CONV_COUNTER + + +def _set_conv_counter(n): + """Set the value of the global conv2d counter.""" + global CONV_COUNTER + CONV_COUNTER = n + + +def annotate(graph): + """Given a float32 graph, annotate will rewrite the graph + and return back a graph which simulates the error brought by + current quantization scheme. + + Parameters + --------- + graph: Function + The original graph + + Returns + ------- + ret: Function + The graph after annotation + """ + _set_conv_counter(0) # reset counter + return _quantize.annotate(graph) + + +def calibrate(graph, dataset=None): + """The calibrate procedure will try to calculate the content of + dom_scale, nbit, clip_min, clip_max for every `simulated_quantize` + operator. + + Parameters + --------- + graph: Function + The simulation graph after annotation. + + dataset: list of dict of Var -> NDArray + The calibration dataset. + + Returns + ------- + ret: Function + The graph after calibration + """ + def power2_scale(arr): + """calculate weight scale with nearest mode-2 scale""" + val = np.amax(np.abs(arr.asnumpy())) + return 2**np.math.ceil(np.math.log(val, 2)) if val > 0 else 1.0 + + cfg = current_qconfig() + const_params = {} + quantize_op = _op.get("relay.op.annotation.simulated_quantize") + + def visit_func(expr): + """Internal visit function""" + if isinstance(expr, _expr.Call) and expr.op == quantize_op: + _, ndom_scale, nclip_min, nclip_max = expr.args + attrs = expr.attrs + kind = attrs.kind + nbit = cfg.get_nbit_by_kind(kind) + + valid_bit = nbit - attrs.sign + + if kind == QAnnotateKind.WEIGHT: + var = expr.args[0] + assert isinstance(var, _expr.Constant) + scale = power2_scale(var.data) + else: + scale = cfg.global_scale + + def _make_const(val): + return _expr.const(val, 'float32') + + valid_range = 2**valid_bit + const_params[ndom_scale] = _make_const(scale / valid_range) + const_params[nclip_min] = _make_const(- (valid_range - 1)) + const_params[nclip_max] = _make_const((valid_range - 1)) + + _ir_pass.post_order_visit(graph, visit_func) + return _expr.bind(graph, const_params) + + +def realize(graph): + """The realize pass will transform the simulated quantized + graph, which computes with float32 actually, to a real low-bit + integer graph. It will replace the simulated_quantize with + several fine-grained operators like add, multiply, and shift + as more as possible for performance (fusion, etc.) + + Parameters + --------- + graph: Function + The simulated graph after calibrating. + + Returns + ------- + ret: Function + The graph after realization + """ + return _quantize.realize(graph) + + +def quantize(graph, params=None, dataset=None): + """ The quantization procedure. Before running the three main + procedure of quantization, "annotate", "calibrate" and "realize" + , we need to do "SimplifyInference", "FoldScaleAxis", "FoldConstant" + first for optimizing. + + Parameters + --------- + graph: Function + The original graph. + + params : dict of str to NDArray + Input parameters to the graph that do not change + during inference time. Used for constant folding. + + dataset: list of dict of Var -> NDArray + The calibration dataset. + + Returns + ------- + ret: Function + The graph after quantization + """ + opt_passes = ["SimplifyInference", + "FoldScaleAxis", + "FoldConstant", + "CanonicalizeOps"] + with _build.build_config(add_pass=opt_passes): + graph = _build.optimize(graph, params=params) + + graph = annotate(graph) + graph = calibrate(graph, dataset) + graph = realize(graph) + graph = _ir_pass.fold_constant(graph) + return graph diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index c1719e81a6c6..e7b4a918c984 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -228,11 +228,11 @@ void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) { void ExprVisitor::VisitType(const Type& t) { return; } - // visitor to implement apply class ExprApplyVisit : public ExprVisitor { public: explicit ExprApplyVisit(std::function f) : f_(f) {} + void VisitExpr(const Expr& e) final { if (visited_.count(e.get()) != 0) return; visited_.insert(e.get()); @@ -257,7 +257,6 @@ TVM_REGISTER_API("relay._ir_pass.post_order_visit") }); }); - // Implement bind. class ExprBinder : public ExprMutator { public: diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 6d583bfd6636..f4d625195a2f 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -1601,7 +1601,6 @@ RELAY_REGISTER_OP("slice_like") .set_attr("FTVMCompute", SliceLikeCompute) .set_attr("TOpPattern", kInjective); - // relay.layout_transform Array LayoutTransformCompute(const Attrs& attrs, const Array& inputs, diff --git a/src/relay/op/tensor/unary.cc b/src/relay/op/tensor/unary.cc index b83fdacda1ee..06720d67713c 100644 --- a/src/relay/op/tensor/unary.cc +++ b/src/relay/op/tensor/unary.cc @@ -103,8 +103,10 @@ This function takes a tensor, a minimum value `a_min`, and a maximum value `a_ma .set_attr("TOpPattern", kElemWise) .set_attr("TOpIsStateful", false) .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) +.set_attrs_type_key("relay.attrs.ClipAttrs") .set_support_level(3); + RELAY_REGISTER_UNARY_OP("floor") .describe(R"code(Returns the floor of input array, computed element-wise. )code" TVM_ADD_FILELINE) diff --git a/src/relay/pass/pattern_util.h b/src/relay/pass/pattern_util.h index 6f1215812543..82111287a4e1 100644 --- a/src/relay/pass/pattern_util.h +++ b/src/relay/pass/pattern_util.h @@ -10,6 +10,7 @@ #include #include +#include #include #include #include @@ -19,6 +20,45 @@ namespace tvm { namespace relay { +/*! + * \brief Dispatch DataType to the C++ data type + * during runtime. + */ +#define TVM_DTYPE_DISPATCH(type, DType, ...) \ + if (type == Float(64)) { \ + typedef double DType; \ + {__VA_ARGS__} \ + } else if (type == Float(32)) { \ + typedef float DType; \ + {__VA_ARGS__} \ + } else if (type == Int(64)) { \ + typedef int64_t DType; \ + {__VA_ARGS__} \ + } else if (type == Int(32)) { \ + typedef int32_t DType; \ + {__VA_ARGS__} \ + } else if (type == Int(16)) { \ + typedef int16_t DType; \ + {__VA_ARGS__} \ + } else if (type == Int(8)) { \ + typedef int8_t DType; \ + {__VA_ARGS__} \ + } else if (type == UInt(64)) { \ + typedef uint64_t DType; \ + {__VA_ARGS__} \ + } else if (type == UInt(32)) { \ + typedef uint32_t DType; \ + {__VA_ARGS__} \ + } else if (type == UInt(16)) { \ + typedef uint16_t DType; \ + {__VA_ARGS__} \ + } else if (type == UInt(8)) { \ + typedef uint8_t DType; \ + {__VA_ARGS__} \ + } else { \ + LOG(FATAL) << "unknown data type " << type; \ + } + /*! * \brief Try to match lhs and rhs via broadcasting rule, such that: * @@ -145,9 +185,10 @@ inline int64_t GetConv2DSuperChannelsDim(const CallNode* call) { */ template inline Constant MakeConstantScalar(DataType dtype, T value) { - CHECK_EQ(sizeof(T) * 8, dtype.bits()) << "data type mismatch"; runtime::NDArray arr = runtime::NDArray::Empty({}, Type2TVMType(dtype), {kDLCPU, 0}); - *static_cast(arr->data) = value; + TVM_DTYPE_DISPATCH(dtype, DType, { + *static_cast(arr->data) = value; + }) return ConstantNode::make(arr); } @@ -168,6 +209,25 @@ inline Expr Log(Expr e) { static const Op& op = Op::Get("log"); return CallNode::make(op, {e}); } +/*! + * \brief Get an immediate scalar from a Constant expr. + * + * \param expr The Constant expr. + * \return A scalar with type T. + */ +template +T GetScalarFromConstant(Expr expr) { + const auto* n = expr.as(); + CHECK(n->is_scalar()); + return static_cast(n->data->data)[0]; +} + +inline Expr Cast(Expr x, DataType dtype) { + static const Op& op = Op::Get("cast"); + auto attrs = make_node(); + attrs->dtype = dtype; + return CallNode::make(op, {x}, Attrs(attrs), {}); +} inline Expr Negative(Expr x) { static const Op& op = Op::Get("negative"); @@ -181,12 +241,39 @@ inline Expr Sqrt(Expr x) { } +inline Expr Relu(Expr x) { + static const Op& op = Op::Get("nn.relu"); + return CallNode::make(op, {x}, Attrs(), {}); +} + + +inline Expr Round(Expr x) { + static const Op& op = Op::Get("round"); + return CallNode::make(op, {x}, Attrs(), {}); +} + + +inline Expr Clip(Expr x, double a_min, double a_max) { + static const Op& op = Op::Get("clip"); + auto attrs = make_node(); + attrs->a_min = a_min; + attrs->a_max = a_max; + return CallNode::make(op, {x}, Attrs(attrs), {}); +} + + inline Expr Add(Expr lhs, Expr rhs) { static const Op& op = Op::Get("add"); return CallNode::make(op, {lhs, rhs}, Attrs(), {}); } +inline Expr Substract(Expr lhs, Expr rhs) { + static const Op& op = Op::Get("subtract"); + return CallNode::make(op, {lhs, rhs}, Attrs(), {}); +} + + inline Expr Multiply(Expr lhs, Expr rhs) { static const Op& op = Op::Get("multiply"); return CallNode::make(op, {lhs, rhs}, Attrs(), {}); @@ -208,6 +295,24 @@ inline Expr OneLike(Expr e) { return CallNode::make(op, {e}); } +inline Expr Power(Expr lhs, Expr rhs) { + static const Op& op = Op::Get("power"); + return CallNode::make(op, {lhs, rhs}, Attrs(), {}); +} + + +inline Expr RightShift(Expr x, Expr nbit) { + static const Op& op = Op::Get("right_shift"); + return CallNode::make(op, {x, nbit}, Attrs(), {}); +} + + +inline Expr LeftShift(Expr x, Expr nbit) { + static const Op& op = Op::Get("left_shift"); + return CallNode::make(op, {x, nbit}, Attrs(), {}); +} + + inline Expr ReshapeLike(Expr lhs, Expr rhs) { static const Op& op = Op::Get("reshape_like"); return CallNode::make(op, {lhs, rhs}, Attrs(), {}); diff --git a/src/relay/pass/quantize.cc b/src/relay/pass/quantize.cc new file mode 100644 index 000000000000..497f4a20dcd4 --- /dev/null +++ b/src/relay/pass/quantize.cc @@ -0,0 +1,550 @@ +/*! + * Copyright (c) 2018 by Contributors + * + * \file quantize.cc + * + * \brief transform a graph to a low-bit graph + * for compression and acceleration. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "pattern_util.h" +#include "quantize.h" + + +namespace tvm { +namespace relay { +namespace quantize { + +/*! \brief Attribute for simulated quantize operator */ +struct SimulatedQuantizeAttrs : public tvm::AttrsNode { + int kind; + bool sign; + std::string rounding; + + TVM_DECLARE_ATTRS(SimulatedQuantizeAttrs, "relay.attrs.SimulatedQuantizeAttrs") { + TVM_ATTR_FIELD(kind) + .describe("kind of field, hint for nbit/dtype configuration."); + TVM_ATTR_FIELD(sign).set_default(true) + .describe("whether to use signed data type."); + TVM_ATTR_FIELD(rounding).set_default("round") + .describe("rounding mode. Can be 'floor', 'ceil', 'round'"); + } +}; + +TVM_REGISTER_NODE_TYPE(SimulatedQuantizeAttrs); + +bool SimulatedQuantizeRel(const Array& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 5); + const auto param = attrs.as(); + CHECK(param != nullptr); + + const auto* data = types[0].as(); + CHECK(data != nullptr); + CHECK_NE(data->shape.size(), 0) << "Input shape cannot be empty"; + + reporter->Assign(types[1], TensorTypeNode::make({}, Float(32))); // dom_scale + reporter->Assign(types[2], TensorTypeNode::make({}, Float(32))); // clip_min + reporter->Assign(types[3], TensorTypeNode::make({}, Float(32))); // clip_max + reporter->Assign(types[4], types[0]); // output + return true; +} + +RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize") +.describe(R"code(simulated quantize op)code" TVM_ADD_FILELINE) +.set_num_inputs(4) +.add_argument("data", "Tensor", "The input data.") +.add_argument("dom_scale", "Tensor", "The domain scale of input data. It should be a scalar") +.add_argument("clip_min", "Tensor", "lower bound. It should be a scalar") +.add_argument("clip_max", "Tensor", "upper bound. It should be a scalar") +.set_attrs_type_key("relay.attrs.SimulatedQuantizeAttrs") +.set_support_level(10) +.add_type_rel("SimulatedQuantize", SimulatedQuantizeRel); + +TVM_REGISTER_API("relay._quantize.simulated_quantize") +.set_body_typed( + [](Expr data, Expr dom_scale, Expr clip_min, Expr clip_max, + int kind, bool sign, std::string rounding) { + auto attrs = make_node(); + attrs->kind = kind; + attrs->sign = sign; + attrs->rounding = rounding; + static const Op& op = Op::Get("relay.op.annotation.simulated_quantize"); + return CallNode::make(op, {data, dom_scale, clip_min, clip_max}, Attrs(attrs), {}); + }); + + +// ============= +// annotate pass + +Expr QAnnotateExprNode::Realize() const { + const auto& cfg = QConfig::Current(); + if (cfg->store_lowbit_output) { + // store low bit output back for VTA + const PackedFunc* f = runtime::Registry::Get("relay.quantize.attach_simulated_quantize"); + return (*f)(this->expr, static_cast(kQInput)); + } else { + return expr; + } +} + +QAnnotateExpr QAnnotateExprNode::make(Expr expr, QAnnotateKind kind) { + auto rnode = make_node(); + rnode->expr = expr; + rnode->kind = kind; + return QAnnotateExpr(rnode); +} + +TVM_REGISTER_API("relay._quantize.make_annotate_expr") +.set_body([](TVMArgs args, TVMRetValue *ret) { + *ret = QAnnotateExprNode::make(args[0], + static_cast(args[1].operator int())); + }); + + +TVM_REGISTER_API("relay._quantize.annotate") +.set_body_typed([] (const Expr& expr) { + std::function fmulti_ref = [](const Expr& e) { + if (e->derived_from()) { + const auto* n = e.as(); + CHECK(n); + const PackedFunc* f = runtime::Registry::Get("relay.quantize.attach_simulated_quantize"); + Expr ret = (*f)(n->expr, static_cast(kQInput)); + return static_cast(QAnnotateExprNode::make(ret, kQInput)); + } + return e; + }; + return ForwardRewrite(expr, "FQAnnotateRewrite", nullptr, nullptr); +}); + + +// ============= +// realize pass + +Expr QRealizeIntExprNode::Realize() const { + const auto& cfg = QConfig::Current(); + Expr data = this->data; + if (cfg->store_lowbit_output) { + data = Cast(data, cfg->dtype_input); + } + // dequantize + data = Cast(data, Float(32)); + data = Multiply(data, this->dom_scale); + return data; +} + +QRealizeIntExpr QRealizeIntExprNode::make(Expr data, Expr dom_scale, DataType dtype) { + NodePtr n = make_node(); + n->data = std::move(data); + n->dom_scale = std::move(dom_scale); + n->dtype = std::move(dtype); + return QRealizeIntExpr(n); +} + + +inline Expr ForwardOp(const Call& ref_call, const Array& args) { + return CallNode::make(ref_call->op, + args, ref_call->attrs, ref_call->type_args); +} + + +/* calculate `data * s1 / s2`, use shift if possible */ +inline Expr MulAndDiv(Expr data, float s1, float s2) { + // here we assume the dtype of data is dtype activation + const QConfig& cfg = QConfig::Current(); + if (s1 == s2) return data; + + float factor = s1 / s2; + float shift_factor = std::log2(factor); + CHECK_GT(shift_factor, 0); + if (static_cast(shift_factor) == shift_factor) { + return LeftShift(data, MakeConstantScalar(cfg->dtype_activation, + static_cast(shift_factor))); + } else if (static_cast(factor) == factor) { + return Multiply(data, MakeConstantScalar(cfg->dtype_activation, factor)); + } else { + LOG(FATAL) << "fall back to float computation"; + data = Cast(data, Float(32)); + return Multiply(data, MakeConstantScalar(Float(32), factor)); + } +} + +Expr QuantizeRealize(const Call& ref_call, + const Array& new_args, + const NodeRef& ctx) { + const QConfig& cfg = QConfig::Current(); + // do not handle data type cast + const auto param = ref_call->attrs.as(); + CHECK_EQ(param->rounding, "round"); + + Expr dom_scale = new_args[1]; + Expr clip_min = new_args[2]; + Expr clip_max = new_args[3]; + + float dom_scale_imm = GetScalarFromConstant(dom_scale); + float clip_min_imm = GetScalarFromConstant(clip_min); + float clip_max_imm = GetScalarFromConstant(clip_max); + + // x * idom_scale = y * odom_scale + // => y = x * idom_scale / odom_scale + if (const auto* n = new_args[0].as()) { + Expr data = n->data; + float idom_scale_imm = GetScalarFromConstant(n->dom_scale); + float odom_scale_imm = GetScalarFromConstant(dom_scale); + float shift_nbit = std::log2(odom_scale_imm / idom_scale_imm); + // int32->int8 + CHECK_GT(shift_nbit, 0); + if (static_cast(shift_nbit) == shift_nbit) { + // use shift + if (cfg->round_for_shift) { + float round_bias = std::pow(2.0, shift_nbit - 1); + data = Add(data, MakeConstantScalar(cfg->dtype_activation, static_cast(round_bias))); + } + data = RightShift(data, MakeConstantScalar(cfg->dtype_activation, + static_cast(shift_nbit))); + data = Clip(data, clip_min_imm, clip_max_imm); + return QRealizeIntExprNode::make(data, dom_scale, n->dtype); + } else { + // float computation + data = Cast(data, Float(32)); + Expr scaled_data = Multiply(data, Divide(n->dom_scale, dom_scale)); + Expr round_data = Clip(Round(scaled_data), clip_min_imm, clip_max_imm); + return QRealizeIntExprNode::make(round_data, dom_scale, Float(32)); + } + } + + // quantize from real + CHECK(!new_args[0]->derived_from()); + Expr data = new_args[0]; + Expr scaled_data = Multiply(data, MakeConstantScalar(Float(32), 1 / dom_scale_imm)); + Expr round_data = Clip(Round(scaled_data), clip_min_imm, clip_max_imm); + return QRealizeIntExprNode::make(round_data, dom_scale, Float(32)); +} + +RELAY_REGISTER_OP("simulated_quantize") +.set_attr("FQRealizeRewrite", QuantizeRealize); + + +Expr Conv2dRealize(const Call& ref_call, + const Array& new_args, + const NodeRef& ctx) { + const QConfig& cfg = QConfig::Current(); + CHECK_EQ(new_args.size(), 2); + if (!new_args[0]->derived_from() && !new_args[1]->derived_from()) { + return Expr(nullptr); + } + const auto* lhs = new_args[0].as(); + CHECK(lhs); + const auto* rhs = new_args[1].as(); + CHECK(rhs); + + Expr ldata = lhs->data; + if (lhs->dtype != cfg->dtype_input) { + ldata = Cast(ldata, cfg->dtype_input); + } + Expr rdata = Cast(rhs->data, cfg->dtype_weight); + + const auto ref_attrs = ref_call->attrs.as(); + auto attrs = make_node(); + *attrs = *ref_attrs; + DataType out_dtype = cfg->dtype_activation; + attrs->out_dtype = out_dtype; + + Expr ret = CallNode::make(ref_call->op, + {ldata, rdata}, Attrs(attrs), ref_call->type_args); + Expr dom_scale = FoldConstant(Multiply(lhs->dom_scale, rhs->dom_scale)); + return QRealizeIntExprNode::make(ret, dom_scale, out_dtype); +} + +RELAY_REGISTER_OP("nn.conv2d") +.set_attr("FQRealizeRewrite", Conv2dRealize); + + +Expr MulRealize(const Call& ref_call, + const Array& new_args, + const NodeRef& ctx) { + const QConfig& cfg = QConfig::Current(); + CHECK_EQ(new_args.size(), 2); + if (new_args[0].as() && new_args[1].as()) { + // execute the operation with activation data type. + const auto* lhs = new_args[0].as(); + const auto* rhs = new_args[1].as(); + Expr ldata = lhs->data; + Expr rdata = rhs->data; + + DataType dtype = cfg->dtype_activation; + if (lhs->dtype == Float(32)) { + ldata = Cast(ldata, dtype); + } else { + CHECK_EQ(lhs->dtype, dtype); + } + if (rhs->dtype == Float(32)) { + rdata = Cast(rdata, dtype); + } else { + CHECK_EQ(rhs->dtype, dtype); + } + + Expr ret = ForwardOp(ref_call, {ldata, rdata}); + Expr dom_scale = FoldConstant(Multiply(lhs->dom_scale, rhs->dom_scale)); + return QRealizeIntExprNode::make(ret, dom_scale, dtype); + } + CHECK(!new_args[0]->derived_from() && !new_args[1]->derived_from()); + return Expr(nullptr); +} + +RELAY_REGISTER_OP("multiply") +.set_attr("FQRealizeRewrite", MulRealize); + + +float ChooseDomScale(const std::vector& nptrs) { + if (nptrs.size() == 2) { + // x = a * s1, y = b * s2 + // x + y = (a * s1 / s2 + b) * s2, if s1 > s2 + // = (a + b * s2 / s1) * s1, if s2 > s1 + float s1 = GetScalarFromConstant(nptrs[0]->dom_scale); + float s2 = GetScalarFromConstant(nptrs[1]->dom_scale); + return s1 > s2 ? s2 : s1; + } else { + const QConfig& cfg = QConfig::Current(); + float scale = cfg->global_scale; + return scale / std::pow(2.0, cfg->nbit_activation - 1); + } +} + + +/* \brief Unify the dom scale of arguments */ +Array UnifyDTypeScale(const Array& args, + DataType* dtype_ptr, + Expr* scale_ptr) { + const QConfig& cfg = QConfig::Current(); + + std::vector nptrs; + Array ret; + for (auto arg : args) { + const auto* nptr = arg.as(); + CHECK(nptr); + nptrs.push_back(nptr); + ret.push_back(nptr->data); + } + + // unify the data type + DataType dtype = cfg->dtype_activation; + for (size_t i = 0; i < ret.size(); ++i) { + if (nptrs[i]->dtype != dtype) { + ret.Set(i, Cast(ret[i], dtype)); + } + } + + // unify the dom_scale + float s = ChooseDomScale(nptrs); + Expr dom_scale = MakeConstantScalar(Float(32), s); + for (size_t i = 0; i < ret.size(); ++i) { + float cur_s = GetScalarFromConstant(nptrs[i]->dom_scale); + LOG(INFO) << "unify data scale from " << cur_s << " to " << s; + ret.Set(i, MulAndDiv(ret[i], cur_s, s)); + } + + *dtype_ptr = dtype; + *scale_ptr = dom_scale; + return ret; +} + +Expr AddRealize(const Call& ref_call, + const Array& new_args, + const NodeRef& ctx) { + CHECK_EQ(new_args.size(), 2); + if (new_args[0].as() && new_args[1].as()) { + DataType dtype; + Expr dom_scale; + Array ret_args = UnifyDTypeScale(new_args, &dtype, &dom_scale); + Expr ret = ForwardOp(ref_call, ret_args); + return QRealizeIntExprNode::make(ret, dom_scale, dtype); + } + CHECK(!new_args[0]->derived_from() && !new_args[1]->derived_from()); + return Expr(nullptr); +} + +RELAY_REGISTER_OP("add") +.set_attr("FQRealizeRewrite", AddRealize); + + +Expr ConcatenateRealize(const Call& ref_call, + const Array& new_args, + const NodeRef& ctx) { + CHECK_EQ(new_args.size(), 1); + + const auto* tuple = new_args[0].as(); + CHECK(tuple); + const Array& arr = tuple->fields; + + if (arr[0].as()) { + DataType dtype; + Expr dom_scale; + Array ret_args = UnifyDTypeScale(arr, &dtype, &dom_scale); + Expr ret = ForwardOp(ref_call, {TupleNode::make(ret_args)}); + return QRealizeIntExprNode::make(ret, dom_scale, dtype); + } else { + for (auto arg : new_args) { + CHECK(!arg->derived_from()); + } + return Expr(nullptr); + } +} + +RELAY_REGISTER_OP("concatenate") +.set_attr("FQRealizeRewrite", ConcatenateRealize); + + +/* \brief forward the original operator */ +Expr IdentityRealize(const Call& ref_call, + const Array& new_args, + const NodeRef& ctx) { + CHECK_EQ(new_args.size(), 1); + if (const auto* n = new_args[0].as()) { + Expr ret = ForwardOp(ref_call, {n->data}); + return QRealizeIntExprNode::make(ret, n->dom_scale, n->dtype); + } + CHECK(!new_args[0]->derived_from()); + return Expr(nullptr); +} + +RELAY_REGISTER_OP("nn.relu") +.set_attr("FQRealizeRewrite", IdentityRealize); + +RELAY_REGISTER_OP("strided_slice") +.set_attr("FQRealizeRewrite", IdentityRealize); + + +Expr MaxPoolRealize(const Call& ref_call, + const Array& new_args, + const NodeRef& ctx) { + const QConfig& cfg = QConfig::Current(); + CHECK_EQ(new_args.size(), 1); + if (const auto* n = new_args[0].as()) { + Expr data = Cast(n->data, cfg->dtype_input); + Expr ret = ForwardOp(ref_call, {data}); + return QRealizeIntExprNode::make(ret, n->dom_scale, cfg->dtype_input); + } + CHECK(!new_args[0]->derived_from()); + return Expr(nullptr); +} + +RELAY_REGISTER_OP("nn.max_pool2d") +.set_attr("FQRealizeRewrite", MaxPoolRealize); + + +Expr AvgPoolRealize(const Call& ref_call, + const Array& new_args, + const NodeRef& ctx) { + const QConfig& cfg = QConfig::Current(); + CHECK_EQ(new_args.size(), 1); + if (const auto* n = new_args[0].as()) { + Expr data = n->data; + if (n->dtype != cfg->dtype_activation) { + data = Cast(n->data, cfg->dtype_activation); + } + Expr ret = ForwardOp(ref_call, {data}); + return QRealizeIntExprNode::make(ret, n->dom_scale, cfg->dtype_activation); + } + CHECK(!new_args[0]->derived_from()); + return Expr(nullptr); +} + +RELAY_REGISTER_OP("nn.avg_pool2d") +.set_attr("FQRealizeRewrite", AvgPoolRealize); + + +TVM_REGISTER_API("relay._quantize.realize") +.set_body_typed([](const Expr& e) { + Expr ret = ForwardRewrite(e, "FQRealizeRewrite", nullptr, nullptr); + return ret; +}); + + +// ============= +// qconfig + +QConfig qconfig() { + return QConfig(make_node()); +} + +/*! \brief Entry to hold the BuildConfig context stack. */ +struct TVMQConfigThreadLocalEntry { + /*! \brief The default build config if the stack is empty */ + QConfig default_config; + + /*! \brief The current build config context */ + std::stack context_stack; + + TVMQConfigThreadLocalEntry() : + default_config(qconfig()) { + } +}; + +/*! \brief Thread local store to hold the BuildConfig context stack. */ +typedef dmlc::ThreadLocalStore TVMQConfigThreadLocalStore; + +void QConfig::EnterQConfigScope(const QConfig& build_config) { + TVMQConfigThreadLocalEntry *entry = TVMQConfigThreadLocalStore::Get(); + entry->context_stack.push(build_config); +} + +void QConfig::ExitQConfigScope() { + TVMQConfigThreadLocalEntry *entry = TVMQConfigThreadLocalStore::Get(); + entry->context_stack.pop(); +} + +QConfig QConfig::Current() { + TVMQConfigThreadLocalEntry *entry = TVMQConfigThreadLocalStore::Get(); + if (entry->context_stack.size() > 0) { + return entry->context_stack.top(); + } + + return entry->default_config; +} + +TVM_REGISTER_NODE_TYPE(QConfigNode); + +TVM_STATIC_IR_FUNCTOR(IRPrinter, vtable) +.set_dispatch([](const QConfigNode *op, IRPrinter *p) { + p->stream << "qconfig("; + p->stream << "nbit_input=" << op->nbit_input << ", "; + p->stream << "nbit_weight=" << op->nbit_weight << ", "; + p->stream << "nbit_activation=" << op->nbit_activation << ", "; + p->stream << "global_scale=" << op->global_scale << ", "; + p->stream << "skip_k_conv==" << op->skip_k_conv << ", "; + p->stream << "round_for_shift==" << op->round_for_shift << ", "; + p->stream << "store_lowbit_output==" << op->store_lowbit_output << ", "; + p->stream << "debug_enabled_ops==" << op->debug_enabled_ops; + p->stream << ")"; +}); + +TVM_REGISTER_API("relay._quantize._GetCurrentQConfig") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = QConfig::Current(); + }); + +TVM_REGISTER_API("relay._quantize._EnterQConfigScope") +.set_body([](TVMArgs args, TVMRetValue* ret) { + QConfig target = args[0]; + QConfig::EnterQConfigScope(target); + }); + +TVM_REGISTER_API("relay._quantize._ExitQConfigScope") +.set_body([](TVMArgs args, TVMRetValue* ret) { + QConfig::ExitQConfigScope(); + }); + +} // namespace quantize +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/quantize.h b/src/relay/pass/quantize.h new file mode 100644 index 000000000000..ef44bf3718eb --- /dev/null +++ b/src/relay/pass/quantize.h @@ -0,0 +1,199 @@ +/*! + * Copyright (c) 2018 by Contributors. + * + * \file tvm/relay/pass/quantize.h + * \brief Header of definitions for quantization + */ +#ifndef TVM_RELAY_PASS_QUANTIZE_H_ +#define TVM_RELAY_PASS_QUANTIZE_H_ + +#include +#include +#include +#include "pattern_util.h" + +namespace tvm { +namespace relay { +namespace quantize { + +/*! \brief Kind of annotate field */ +enum QAnnotateKind : int { + kQInput = 1, + kQWeight = 2, + kQActivation = 3, +}; + +/*! + * \brief TempExpr used during annotate forward rewrite. + */ +class QAnnotateExpr; +/*! + * \brief TempExprNode used during annotate forward rewrite. + */ +class QAnnotateExprNode : public TempExprNode { + public: + /*! \brief The original expression */ + Expr expr; + /*! \brief The kind of annotate field */ + QAnnotateKind kind; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("expr", &expr); + v->Visit("kind", &kind); + } + + TVM_DLL static QAnnotateExpr make(Expr expr, QAnnotateKind kind); + + Expr Realize() const final; + + static constexpr const char* _type_key = "relay.QAnnotateExpr"; + TVM_DECLARE_NODE_TYPE_INFO(QAnnotateExprNode, TempExprNode); +}; + +RELAY_DEFINE_NODE_REF(QAnnotateExpr, QAnnotateExprNode, TempExpr); + + +/*! \brief TempExpr used during realize forward rewrite. */ +class QRealizeExpr; +/*! \brief TempExpr representing integer. */ +class QRealizeIntExpr; + +class QRealizeExprNode : public TempExprNode { + public: + /*! \brief The original expression */ + Expr data; + static constexpr const char* _type_key = "relay.quantize.QRealizeExpr"; + TVM_DECLARE_BASE_NODE_INFO(QRealizeExprNode, TempExprNode); +}; + +RELAY_DEFINE_NODE_REF(QRealizeExpr, QRealizeExprNode, TempExpr); + + +class QRealizeIntExprNode : public QRealizeExprNode { + public: + Expr dom_scale; + /*! \brief current data type */ + DataType dtype; + + void VisitAttrs(tvm::AttrVisitor* v) final { + v->Visit("data", &data); + v->Visit("dom_scale", &dom_scale); + v->Visit("dtype", &dtype); + } + + Expr Realize() const final; + + TVM_DLL static QRealizeIntExpr make(Expr data, Expr dom_scale, DataType dtype); + + static constexpr const char * _type_key = "relay.quantize.QRealizeIntExpr"; + TVM_DECLARE_NODE_TYPE_INFO(QRealizeIntExprNode, QRealizeExprNode); +}; + +RELAY_DEFINE_NODE_REF(QRealizeIntExpr, QRealizeIntExprNode, QRealizeExpr); + + +class QConfig; + +/*! +* \brief Container for build configuration options +*/ +class QConfigNode : public Node { + public: + int nbit_input = 8; + int nbit_weight = 8; + int nbit_activation = 32; + DataType dtype_input = Int(8); + DataType dtype_weight = Int(8); + DataType dtype_activation = Int(32); + double global_scale = 8.0; + int skip_k_conv = 1; + bool round_for_shift = true; + bool store_lowbit_output = true; + Array debug_enabled_ops = Array(NodePtr(nullptr)); + + void VisitAttrs(AttrVisitor* v) final { + v->Visit("nbit_input", &nbit_input); + v->Visit("nbit_weight", &nbit_weight); + v->Visit("nbit_activation", &nbit_activation); + v->Visit("dtype_input", &dtype_input); + v->Visit("dtype_weight", &dtype_weight); + v->Visit("dtype_activation", &dtype_activation); + v->Visit("global_scale", &global_scale); + v->Visit("skip_k_conv", &skip_k_conv); + v->Visit("round_for_shift", &round_for_shift); + v->Visit("store_lowbit_output", &store_lowbit_output); + v->Visit("debug_enabled_ops", &debug_enabled_ops); + } + + static constexpr const char* _type_key = "relay.quantize.QConfig"; + TVM_DECLARE_NODE_TYPE_INFO(QConfigNode, Node); +}; + +/*! +* \brief Container for build configuration options +*/ +class QConfig : public NodeRef { + public: + QConfig() {} + explicit QConfig(NodePtr n) : NodeRef(n) {} + + const QConfigNode* operator->() const { + return static_cast(node_.get()); + } + + QConfigNode* operator->() { + return static_cast(node_.get()); + } + + /*! + * \brief Push a new BuildConfig context onto the thread local stack. + * \param build_config The configuration to set as the current context. + */ + static void EnterQConfigScope(const QConfig& qconfig); + + /*! + * \brief Pop a build config off the thread local context stack, restoring the previous + * configuration as the current context. + */ + static void ExitQConfigScope(); + + /*! + * \brief Get the current BuildConfig context from thread local storage, or a default + * configuration if a BuildConfig scope has not been entered. + * \return The configuration that is the current context. + */ + static QConfig Current(); + + using ContainerType = QConfigNode; +}; + +/*! + * \brief RAII container to provide a scoped BuildConfig context. Pushes a configuration onto the + * context stack when constructed, and pops it when destructed. + */ +struct QConfigContext { + /*! + * \brief Enter a new BuildConfig context. The given BuildConfig becomes the new current + * context. When the BuildConfigContext is destructed, the previous context is restored. + * \param build_config The BuildConfig to set as the new current context. + */ + explicit QConfigContext(const QConfig& qconfig) { + QConfig::EnterQConfigScope(qconfig); + } + + /*! \brief Destructor. Pops the context off the thread local stack. */ + ~QConfigContext() { + QConfig::ExitQConfigScope(); + } +}; + +/*! +* \brief Construct a BuildConfig containing a new BuildConfigNode +* \return The new BuildConfig +*/ +TVM_DLL QConfig qconfig(); + +} // namespace quantize +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_PASS_QUANTIZE_H_ diff --git a/tests/python/relay/test_ad.py b/tests/python/relay/test_pass_gradient.py similarity index 100% rename from tests/python/relay/test_ad.py rename to tests/python/relay/test_pass_gradient.py diff --git a/tests/python/relay/test_pass_quantize.py b/tests/python/relay/test_pass_quantize.py new file mode 100644 index 000000000000..6d65d7b2d9ee --- /dev/null +++ b/tests/python/relay/test_pass_quantize.py @@ -0,0 +1,84 @@ +import math +import numpy as np +import tvm +from tvm import relay +from tvm.relay import quantize as qtz + + +def make_dataset(graph, size=100): + args = relay.ir_pass.infer_type(graph).params + def create_arr(var): + ttype = var.type_annotation + np_arr = np.random.uniform(-1.0, 1.0, size=ttype.concrete_shape).astype(ttype.dtype) + return tvm.ndarray.array(np_arr) + + params = {} + for arg in args: + if arg.name_hint == 'data': + dataset = [{'data': create_arr(arg)} for _ in range(size)] + else: + params[arg.name_hint] = create_arr(arg) + return dataset, params + + +def test_simulated_quantize(): + data = relay.var("data", relay.ty.TensorType((3, 4, 5, 6), "float32")) + out = qtz._annotate.attach_simulated_quantize(data, 1) + out = relay.ir_pass.infer_type(out) + assert out.checked_type == out.args[0].checked_type + assert out.args[1].checked_type == relay.ty.TensorType(tuple(), "float32") + assert out.args[2].checked_type == relay.ty.TensorType(tuple(), "float32") + assert out.args[3].checked_type == relay.ty.TensorType(tuple(), "float32") + + +def test_quantize_pass(): + def quantize_weight(arr): + maximum = np.amax(np.abs(arr.asnumpy())) + scale = 2**math.ceil(math.log(maximum, 2)) + out = np.around(arr.asnumpy() / scale * 128).astype('int8') + out = np.clip(out, -127, 127) + return relay.const(out, 'int8') + + n, c, h, w = 1, 3, 224, 224 + def make_graph(data): + weight = relay.var("conv_weight") + out = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1), channels=c) + out = relay.Function(relay.ir_pass.free_vars(out), out) + return out + + def make_qgraph(data, weight): + out = data * relay.const(32.0) + out = relay.round(out) + out = relay.clip(out, a_min=-127, a_max=127) + out = out.astype('int8') + + out = relay.nn.conv2d(out, weight, kernel_size=(3, 3), + padding=(1, 1), channels=c, out_dtype='int32') + out = out.astype('float32') + out = relay.multiply(out, relay.const(0.00024414062)) + out = relay.Function(relay.ir_pass.free_vars(out), out) + return out + + data = relay.var("data", relay.TensorType((n, c, h, w), "float32")) + graph = make_graph(data) + dataset, params = make_dataset(graph, 10) + + with qtz.qconfig(skip_k_conv=0, global_scale=4.0, + round_for_shift=False, store_lowbit_output=False): + qgraph0 = qtz.quantize(graph, params) + qgraph0 = relay.ir_pass.infer_type(qgraph0) + + conv_weight = quantize_weight(params['conv_weight']) + qgraph1 = make_qgraph(data, conv_weight) + qgraph1 = relay.ir_pass.infer_type(qgraph1) + + graph = relay.create_executor('graph') + res0 = graph.evaluate(qgraph0)(dataset[0]['data']) + res1 = graph.evaluate(qgraph1)(dataset[0]['data']) + tvm.testing.assert_allclose(res0.asnumpy(), res1.asnumpy()) + + +if __name__ == "__main__": + np.random.seed(42) + test_simulated_quantize() + test_quantize_pass()