[Relay][Quantization] Per-Channel FQ2I (apache#8883)
* WIP support per-channel quantization

* more WIP

* More WIP

* fix issue with per-channel bias_add

* Fix fake quantize tests (apache#4)

* Fixed fake quantize issues.

* Formatting.

* Cleanup unused imports

* Fix real int8 tests.

* Add Relu

* One more little one (apache#5)

* Fixed fake quantize issues.

* Formatting.

* Cleanup unused imports

* Fix real int8 tests.

* Fix requantize shape bug.

* Non-working Per-channel Dense

* Fix legalization for non spatial operators. (apache#6)

* Fix legalization for non spatial operators.

* Fix axis checks for end2end functionality.

* fix axis normalization

fix lint

fix lint again

* Per channel fq2i (apache#8)

* Fix bug in requantize dimension expansion.

* Format.

Co-authored-by: Josh Fromm <jwfromm@octoml.ai>

* respond to review comments

Co-authored-by: Josh Fromm <jwfromm@octoml.ai>
2 people authored and ylc committed Jan 13, 2022
1 parent cc803b1 commit bd0eed6
Showing 15 changed files with 315 additions and 59 deletions.
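Before the per-file diffs, a minimal sketch of the idea (shapes and values are illustrative, not taken from the patch): per-channel quantization attaches a 1-D scale/zero-point vector plus the axis it indexes to a tensor, instead of a single scalar pair, so dequantization broadcasts along that axis.

import numpy as np

# Per-channel affine quantization: real ~= scale[c] * (q - zero_point[c]),
# where c indexes the chosen axis (here axis 0, the output channels of a weight).
weight_q = np.random.randint(-128, 127, size=(4, 8)).astype("int8")   # [O, I], made-up values
scale = np.array([0.02, 0.01, 0.05, 0.03], dtype="float32")           # one scale per output channel
zero_point = np.zeros(4, dtype="int32")
axis = 0

# Dequantize by broadcasting the per-channel parameters along `axis`.
bshape = [1] * weight_q.ndim
bshape[axis] = -1
weight_real = scale.reshape(bshape) * (weight_q.astype("int32") - zero_point.reshape(bshape))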
8 changes: 6 additions & 2 deletions include/tvm/ir/affine_type.h
@@ -71,24 +71,28 @@ class TensorAffineTypeNode : public AffineTypeNode {
RelayExpr zero_point;
/*! \brief The data type of this type */
DataType dtype;
/*! \brief The axis for per-channel quantization */
int axis;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("scale", &scale);
v->Visit("zero_point", &zero_point);
v->Visit("dtype", &dtype);
v->Visit("axis", &axis);
}

bool SEqualReduce(const TensorAffineTypeNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return equal(scale, other->scale) && equal(zero_point, other->zero_point) &&
equal(dtype, other->dtype);
equal(dtype, other->dtype) && equal(axis, other->axis);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce->MarkGraphNode();
hash_reduce(scale);
hash_reduce(zero_point);
hash_reduce(dtype);
hash_reduce(axis);
}

static constexpr const char* _type_key = "TensorAffineType";
@@ -101,7 +105,7 @@ class TensorAffineTypeNode : public AffineTypeNode {
*/
class TensorAffineType : public AffineType {
public:
TVM_DLL TensorAffineType(RelayExpr scale, RelayExpr zero_point, DataType dtype);
TVM_DLL TensorAffineType(RelayExpr scale, RelayExpr zero_point, DataType dtype, int axis);

TVM_DEFINE_OBJECT_REF_METHODS(TensorAffineType, AffineType, TensorAffineTypeNode);
};
9 changes: 7 additions & 2 deletions python/tvm/ir/affine_type.py
@@ -48,10 +48,15 @@ class TensorAffineType(AffineType):
dtype : str
The content data type.
axis : int
The axis for per-channel quantization.
"""

def __init__(self, scale, zero_point, dtype):
self.__init_handle_by_constructor__(_ffi_api.TensorAffineType, scale, zero_point, dtype)
def __init__(self, scale, zero_point, dtype, axis=-1):
self.__init_handle_by_constructor__(
_ffi_api.TensorAffineType, scale, zero_point, dtype, axis
)


@tvm._ffi.register_object("TupleAffineType")
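With the Python constructor above extended, a TensorAffineType carrying per-channel parameters could be built roughly like this (values are hypothetical; axis still defaults to -1 for the scalar case):

from tvm import relay
from tvm.ir import TensorAffineType

# Four output channels quantized along axis 0; scale and zero point are 1-D constants.
scale = relay.const([0.02, 0.01, 0.05, 0.03], dtype="float32")
zero_point = relay.const([0, 0, 0, 0], dtype="int32")
t = TensorAffineType(scale, zero_point, "int8", axis=0)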
2 changes: 1 addition & 1 deletion python/tvm/relay/frontend/onnx.py
@@ -490,7 +490,7 @@ def _impl_v1(cls, inputs, attr, params):
attr["dilations"] = [1] + list(attr["dilations"])
if "pads" in attr:
attr["pads"] = [0, attr["pads"][0], 0, attr["pads"][1]]

attr["channels"] = kernel_shapes[0][0]
out = AttrCvt(
op_name=dimension_picker("conv"),
transforms={
33 changes: 29 additions & 4 deletions python/tvm/relay/qnn/op/legalizations.py
@@ -20,6 +20,7 @@

import tvm
from tvm import relay
from tvm._ffi.base import TVMError
from .. import op as reg

#################################################
@@ -139,11 +140,35 @@ def helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay_op):
data, kernel, input_zero_point, kernel_zero_point, _, _ = inputs

shift_data = relay.subtract(
relay.cast(data, dtype="int16"), relay.cast(input_zero_point, dtype="int16")
)
# If kernel zero point is a scalar we can directly subtract it.
if len(types[3].shape) == 0:
shift_kernel = relay.subtract(
relay.cast(kernel, dtype="int16"), relay.cast(kernel_zero_point, dtype="int16")
)
# Otherwise it needs to be broadcast.
else:
# Determine output axis of kernel for spatial operations.
if hasattr(attrs, "kernel_layout"):
output_axis = tvm.tir.layout(attrs["kernel_layout"]).index_of("O")
# For dense operations, broadcast to [N, K] layout.
elif isinstance(attrs, relay.op.op_attrs.DenseAttrs):
output_axis = 0
# For matrix multiplication instead expand to [K, N] layout.
elif isinstance(attrs, relay.op.op_attrs.MatmulAttrs):
output_axis = 1
else:
raise TVMError(
"Legalization of %s is not yet supported with per channel parameters"
% str(type(attrs))
)

shift_kernel = relay.nn.bias_add(
relay.cast(kernel, dtype="int16"),
relay.cast(kernel_zero_point, dtype="int16"),
output_axis,
)
new_attrs = {k: attrs[k] for k in attrs.keys()}
return relay_op(shift_data, shift_kernel, **new_attrs)

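As a rough illustration of the branch added above, this is the kind of broadcast the helper builds for a non-spatial operator such as nn.dense, where the per-channel kernel zero point indexes output axis 0 (shapes and values are invented):

from tvm import relay

kernel = relay.var("kernel", shape=(4, 8), dtype="int8")        # dense weight in [N, K] layout
kernel_zero_point = relay.const([1, 2, 3, 4], dtype="int32")    # one zero point per output channel

# bias_add broadcasts the 1-D vector along the chosen output axis
# (index_of("O") for conv kernel layouts, 0 for DenseAttrs, 1 for MatmulAttrs).
shift_kernel = relay.nn.bias_add(
    relay.cast(kernel, dtype="int16"),
    relay.cast(kernel_zero_point, dtype="int16"),
    axis=0,
)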
9 changes: 7 additions & 2 deletions python/tvm/relay/qnn/op/qnn.py
@@ -276,8 +276,10 @@ def conv2d(
):
r"""Quantized 2D convolution.
This operator convolves quantized data with quantized kernel.
If doing per-channel quantization, qnn expects the kernel_scale
and optionally the kernel_zero_point to be 1-D vectors instead of scalars.
The scale of the output quantized tensor is the product of the kernel_scale and
input_scale of the input quantized tensors. The zero point of the output
quantized tensor is 0. By default, the dtype of output is int32. Please also
refer to Requantize operator to understand how to scale back the int32
@@ -544,6 +546,9 @@ def dense(
`Y = X * W`
If doing per-channel quantization, qnn expects the kernel_scale
and optionally the kernel_zero_point to be 1-D vectors instead of scalars.
Parameters
----------
data : tvm.relay.Expr
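A hedged example of the per-channel calling convention the updated docstrings describe for qnn.conv2d; every shape and quantization parameter below is a placeholder, not something taken from the patch:

from tvm import relay

data = relay.var("data", shape=(1, 8, 32, 32), dtype="int8")
weight = relay.var("weight", shape=(16, 8, 3, 3), dtype="int8")    # OIHW

out = relay.qnn.op.conv2d(
    data,
    weight,
    input_zero_point=relay.const(0, "int32"),
    kernel_zero_point=relay.const([0] * 16, "int32"),              # optionally 1-D, one per output channel
    input_scale=relay.const(0.05, "float32"),
    kernel_scale=relay.const([0.01] * 16, "float32"),              # 1-D, one per output channel
    kernel_size=(3, 3),
    channels=16,
    padding=(1, 1),
)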
79 changes: 68 additions & 11 deletions python/tvm/relay/transform/fake_quantization_to_integer.py
@@ -18,13 +18,22 @@
import tvm
from tvm import relay
from tvm.ir import TensorAffineType, TupleAffineType
from tvm.tir import bijective_layout
from ..op import register_fake_quantization_to_integer


def fold_constant(expr):
return relay.transform.FoldConstantExpr(expr, tvm.IRModule())


def get_zeros(scale):
return fold_constant(relay.op.cast(relay.op.zeros_like(scale), "int32"))


def infer_shape(expr):
return relay.transform.InferType()(tvm.IRModule.from_expr(expr))["main"].body.checked_type.shape


@register_fake_quantization_to_integer("qnn.dequantize")
def dequantize(expr, type_map):
"""Remove dequantize op"""
@@ -52,8 +61,13 @@ def quantize(expr, type_map):
expr.args[1],
expr.args[2],
out_dtype=expr.attrs.out_dtype,
axis=t.axis,
)
return [out, TensorAffineType(expr.args[1], expr.args[2], expr.attrs.out_dtype)]

return [
out,
TensorAffineType(expr.args[1], expr.args[2], expr.attrs.out_dtype, expr.attrs.axis),
]


def register_unary_identity(op_name):
@@ -94,14 +108,19 @@ def bias_add(expr, type_map):
b_t = type_map[b]
in_scale = fold_constant(x_t.scale)
in_zero_point = fold_constant(x_t.zero_point)
if not tvm.ir.structural_equal(x_t, b_t):
if not (
tvm.ir.structural_equal(x_t.scale, b_t.scale)
and tvm.ir.structural_equal(x_t.zero_point, b_t.zero_point)
and tvm.ir.structural_equal(x_t.dtype, b_t.dtype)
):
b = relay.qnn.op.requantize(
b,
b_t.scale,
b_t.zero_point,
in_scale,
in_zero_point,
out_dtype=x_t.dtype,
axis=0,
)
out = relay.op.nn.bias_add(x, b, **expr.attrs)
return [out, x_t]
@@ -116,11 +135,13 @@ def conv2d(expr, type_map):
x_t = type_map[x]
w_t = type_map[weight]
conv_scale = fold_constant(x_t.scale * w_t.scale)
conv_zp = relay.const(0)
conv_zp = get_zeros(conv_scale)
out = relay.qnn.op.conv2d(
x, weight, x_t.zero_point, w_t.zero_point, x_t.scale, w_t.scale, **attrs
)
return [out, TensorAffineType(conv_scale, conv_zp, out.attrs.out_dtype)]
out_layout = attrs["out_layout"] if attrs["out_layout"] != "" else attrs["data_layout"]
out_axis = bijective_layout(out_layout, "NCHW").backward_index(list(range(4)))[1]
return [out, TensorAffineType(conv_scale, conv_zp, out.attrs.out_dtype, out_axis.value)]


@register_fake_quantization_to_integer("nn.dense")
@@ -132,11 +153,11 @@ def dense(expr, type_map):
x_t = type_map[x]
w_t = type_map[weight]
dense_scale = fold_constant(x_t.scale * w_t.scale)
dense_zp = relay.const(0)
dense_zp = get_zeros(dense_scale)
out = relay.qnn.op.dense(
x, weight, x_t.zero_point, w_t.zero_point, x_t.scale, w_t.scale, **attrs
)
return [out, TensorAffineType(dense_scale, dense_zp, out.attrs.out_dtype)]
return [out, TensorAffineType(dense_scale, dense_zp, out.attrs.out_dtype, 1)]


@register_fake_quantization_to_integer("nn.batch_matmul")
@@ -148,7 +169,7 @@ def batch_matmul(expr, type_map):
matmul_scale = fold_constant(x_t.scale * y_t.scale)
matmul_zp = relay.const(0)
out = relay.qnn.op.batch_matmul(x, y, x_t.zero_point, y_t.zero_point, x_t.scale, y_t.scale)
return [out, TensorAffineType(matmul_scale, matmul_zp, out.attrs.out_dtype)]
return [out, TensorAffineType(matmul_scale, matmul_zp, out.attrs.out_dtype, x_t.axis)]


@register_fake_quantization_to_integer("concatenate")
@@ -198,19 +219,52 @@ def clip(expr, type_map):
amax = expr.attrs.a_max
scale = fold_constant(t.scale)
z_p = fold_constant(t.zero_point)
if isinstance(scale, relay.expr.Constant) and isinstance(z_p, relay.expr.Constant):
if (
isinstance(scale, relay.expr.Constant)
and scale.data.numpy().size == 1
and isinstance(z_p, relay.expr.Constant)
and z_p.data.numpy().size == 1
):
scale = scale.data.numpy().item()
z_p = z_p.data.numpy().item()
new_min = int(amin / scale + z_p)
new_max = int(amax / scale + z_p)
out = relay.op.clip(arg, new_min, new_max)
else:
amin = relay.op.round(relay.op.const(amin) / scale + z_p)
amax = relay.op.round(relay.op.const(amax) / scale + z_p)
out = relay.op.minimum(relay.op.maximum(arg, amin), amax)
if not isinstance(amin, relay.expr.Constant):
amin = relay.op.const(amin)
if not isinstance(amax, relay.expr.Constant):
amax = relay.op.const(amax)

scale_shape = infer_shape(scale)
if len(scale_shape) > 0 and scale_shape[0] > 1:
b_shape = [1] * len(infer_shape(arg))
b_shape[t.axis] = -1
amin = relay.op.reshape(relay.op.broadcast_to(amin, scale_shape), b_shape)
amax = relay.op.reshape(relay.op.broadcast_to(amax, scale_shape), b_shape)
amin = relay.qnn.op.quantize(amin, scale, z_p, t.axis, t.dtype)
amax = relay.qnn.op.quantize(amax, scale, z_p, t.axis, t.dtype)
out = relay.op.minimum(relay.op.maximum(arg, fold_constant(amin)), fold_constant(amax))

return [out, t]


@register_fake_quantization_to_integer("nn.relu")
def relu(expr, type_map):
"""Rewrite a relu op"""
arg = expr.args[0]
t = type_map[arg]
scale_shape = infer_shape(t.scale)
z_p = t.zero_point
assert len(scale_shape) <= 1
if len(scale_shape) == 1 and scale_shape[0] > 1:
b_shape = [1] * len(infer_shape(arg))
b_shape[t.axis] = -1
z_p = relay.op.reshape(relay.op.broadcast_to(z_p, scale_shape), b_shape)
zero = relay.op.cast(z_p, t.dtype)
return [relay.op.maximum(arg, fold_constant(zero)), t]


@register_fake_quantization_to_integer("nn.pad")
def pad(expr, type_map):
"""Rewite an nn.pad op"""
@@ -231,6 +285,7 @@ def pad(expr, type_map):
t.scale,
t.zero_point,
out_dtype=t.dtype,
axis=pad_t.axis,
)
else:
## If the pad-value is a constant, we need to quantize it
@@ -319,6 +374,7 @@ def binary(expr, type_map):
out_t.scale,
out_t.zero_point,
out_dtype=out_t.dtype,
axis=left_t.axis,
)

if right_t != out_t:
@@ -329,6 +385,7 @@
out_t.scale,
out_t.zero_point,
out_dtype=out_t.dtype,
axis=right_t.axis,
)
out = op(left, right)
return [out, out_t]
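End to end, the rewrites in this file are exercised by graphs of the following shape, loosely modeled on the pass's unit tests (per-channel scales sit on the weight's output-channel axis; shapes and values are placeholders):

import numpy as np
import tvm
from tvm import relay

# Fake-quantized graph: dequantize -> float conv2d -> quantize.
x = relay.var("x", shape=(1, 3, 8, 8), dtype="int8")
w = relay.const(np.random.randint(-128, 127, size=(16, 3, 3, 3)).astype("int8"))
w_scale = relay.const(np.random.uniform(0.01, 0.1, size=(16,)).astype("float32"))
zero = relay.const(0)

op = relay.nn.conv2d(
    relay.qnn.op.dequantize(x, relay.const(2.0), zero),
    relay.qnn.op.dequantize(w, w_scale, zero, axis=0),   # per-channel weight scales
    kernel_size=[3, 3],
)
expr = relay.qnn.op.quantize(op, relay.const(1.0), zero, out_dtype="int8")

mod = tvm.IRModule.from_expr(expr)
mod = relay.transform.InferType()(mod)
# Replaces the fake-quantized region with integer qnn ops, tracking per-channel axes.
mod = relay.transform.FakeQuantizationToInteger()(mod)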
10 changes: 6 additions & 4 deletions src/ir/affine_type.cc
@@ -30,26 +30,28 @@ namespace tvm {
using tvm::ReprPrinter;
using namespace tvm::runtime;

TensorAffineType::TensorAffineType(RelayExpr scale, RelayExpr zero_point, DataType dtype) {
TensorAffineType::TensorAffineType(RelayExpr scale, RelayExpr zero_point, DataType dtype,
int axis) {
ObjectPtr<TensorAffineTypeNode> n = make_object<TensorAffineTypeNode>();
n->scale = std::move(scale);
n->zero_point = std::move(zero_point);
n->dtype = std::move(dtype);
n->axis = std::move(axis);
data_ = std::move(n);
}

TVM_REGISTER_NODE_TYPE(TensorAffineTypeNode);

TVM_REGISTER_GLOBAL("ir.TensorAffineType")
.set_body_typed([](RelayExpr scale, RelayExpr zero_point, DataType dtype) {
return TensorAffineType(scale, zero_point, dtype);
.set_body_typed([](RelayExpr scale, RelayExpr zero_point, DataType dtype, int axis) {
return TensorAffineType(scale, zero_point, dtype, axis);
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TensorAffineTypeNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const TensorAffineTypeNode*>(ref.get());
p->stream << "TensorAffineType(" << node->scale << ", " << node->zero_point << ", "
<< node->dtype << ")";
<< node->dtype << ", " << node->axis << ")";
});

TupleAffineType::TupleAffineType(Array<TensorAffineType> types) {
11 changes: 6 additions & 5 deletions src/relay/qnn/op/convolution.cc
@@ -495,7 +495,7 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const Expr& kernel_zero_point,
* \param input_zero_point The input zero point expr.
* \param param The qnn conv2d attributes.
* \param out_channels The number of output channels.
* \return The sequence of Relay operatos for term3.
* \return The sequence of Relay operators for term3.
* \note The term3 looks like this
*
* Sigma(c,r,s) zp_a * QW(k, c, r, s)
@@ -625,7 +625,7 @@ Expr Conv2DCombineTerms(const Expr& term1, const Expr& term2, const Expr& term3,
* \note Lowering of the qnn.conv2d operator
* A quantized tensor is represented in following manner
* A = scale_a x (QA - zp_A)
* where QA is quantized tensor, scale_a and zp_A are quantizations
* where QA is quantized tensor, scale_a and zp_A are quantization
* params.
*
* Quantized convolution will convolve two quantized tensors and returns a
@@ -662,8 +662,8 @@ Expr Conv2DCombineTerms(const Expr& term1, const Expr& term2, const Expr& term3,
* a workaround, we fall back to simpler lowering using int32 conv if
* the conv is dilated. We fallback also in case of grouped conv.
*
* For depthwise, we can similarly unroll the computation. The intial compute is as follows
* wehere cm = channel_multiplier
* For depthwise, we can similarly unroll the computation. The initial compute is as follows
* where cm = channel_multiplier
*
* Qc(n, oc, oh, ow) = Sigma(r, s) (Qw(oc/m, oc%/m, r, s) - zp_w)
* * (Qa(n, oc/cm, oh + r, ow + s) - zp_a)
@@ -693,12 +693,13 @@ Expr QnnConv2DCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
Expr kernel_zero_point = new_args[3];
const auto* param = attrs.as<Conv2DAttrs>();
ICHECK(param != nullptr);
// Assertion checks for exisiing support.
// Assertion checks for existing support.
ICHECK(param->data_layout == "NCHW" || param->data_layout == "NHWC")
<< "qnn.conv2d supports only NCHW/NHWC input data layout.";
ICHECK(param->kernel_layout == "OIHW" || param->kernel_layout == "HWIO" ||
param->kernel_layout == "HWOI")
<< "qnn.conv2d supports only OIHW/HWIO/HWOI kernel data layout.";
ICHECK(param->kernel_size.defined()) << "qnn.conv2d requires kernel size to be specified.";

int batch_size, in_channels, out_channels, kernel_h, kernel_w, channel_multiplier;
std::tie(batch_size, in_channels, out_channels, kernel_h, kernel_w, channel_multiplier) =
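The lowering comments above expand the zero-point-shifted convolution into four terms; the same algebra, checked on a plain matrix product (the reduction structure is identical), looks roughly like this:

import numpy as np

rng = np.random.default_rng(0)
QA = rng.integers(0, 256, size=(2, 16)).astype(np.int64)     # quantized activations (made-up)
QW = rng.integers(-128, 128, size=(16, 4)).astype(np.int64)  # quantized weights (made-up)
zp_a, zp_w = 7, 3
K = QA.shape[1]                                              # size of the reduction axis

direct = (QA - zp_a) @ (QW - zp_w)

term1 = QA @ QW
term2 = zp_w * QA.sum(axis=1, keepdims=True)   # zp_w * Sigma(QA)
term3 = zp_a * QW.sum(axis=0, keepdims=True)   # zp_a * Sigma(QW)
term4 = zp_a * zp_w * K                        # constant term
assert np.array_equal(direct, term1 - term2 - term3 + term4)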