
fix lint
masahi committed Mar 2, 2020
1 parent da89492 commit 5be737e
Showing 2 changed files with 101 additions and 56 deletions.
139 changes: 83 additions & 56 deletions python/tvm/relay/frontend/qnn_torch.py
@@ -1,13 +1,36 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=import-self, invalid-name, unused-argument
""" Functions to convert quantized torch models to QNN"""

import numpy as np

import torch

import tvm
import numpy as np
from tvm import relay
from tvm.relay import expr as _expr
from tvm.relay import op as _op
from tvm.relay.frontend.common import infer_shape


class QuantParam:
""" A placeholder for weight quantization parameters """

def __init__(self, weight, bias, scale, zero_point, param_key):
param_prefix = param_key[:-len("._packed_params")]
self.weight_var = _expr.var(param_prefix + "_weight",
@@ -26,7 +49,7 @@ def __init__(self, weight, bias, scale, zero_point, param_key):
self.zero_point = _expr.const(zero_point, dtype="int32")


def unpack_quant_params(param_name, packed_params, unpack_func):
def _unpack_quant_params(param_name, packed_params, unpack_func):
qweight, bias = unpack_func(packed_params)
weight_np = qweight.dequantize().numpy()

@@ -43,6 +66,7 @@ def unpack_quant_params(param_name, packed_params, unpack_func):


def get_weight_quant_params(script_module):
""" Retrive and unpack weight parameters from quantized modules """
conv_packed_params = []
linear_packed_params = []

@@ -64,14 +88,15 @@ def get_weight_quant_params(script_module):
assert param_name in state_dict
key = name + "." + param_name
packed_param = state_dict[param_name]
quant_params[key] = unpack_quant_params(key, packed_param,
unpack_func)
quant_params[key] = _unpack_quant_params(key, packed_param,
unpack_func)

return quant_params


def add_quant_params_to_outputs(outputs, output_index_map,
packed_param_map, quant_params):
""" Add quant params to outputs so that they can be referenced later """
for node_name, packed_param_name in packed_param_map.items():
qparam = quant_params[packed_param_name]
output_index_map[node_name] = len(outputs)
@@ -82,7 +107,7 @@ def add_quant_params_to_outputs(outputs, output_index_map,
outputs.append(param_tup)


def get_quant_param_for_input(input_value):
def _get_quant_param_for_input(input_value):
output_quant_param_indices = {
"aten::quantize_per_tensor": (1, 2),
"quantized::conv2d": (6, 7),
@@ -106,18 +131,18 @@ def dfs(current_node):
scale = current_node.inputsAt(indices[0])
zp = current_node.inputsAt(indices[1])
return scale, zp
else:
# Assume quantized tensor comes earlier in the args
for arg in current_node.inputs():
return dfs(arg.node())

# Assume quantized tensor comes earlier in the args
for arg in current_node.inputs():
return dfs(arg.node())

assert False, "No producer for %s" % (str(current_node))

return dfs(input_value.node())

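For orientation: _get_quant_param_for_input starts from an op's input value and walks backwards through producer nodes until it reaches an operator whose output scale and zero point sit at known argument positions (the output_quant_param_indices table above). A minimal, torch-free sketch of that backward walk, using a toy Node class rather than the real torch._C graph API:

class Node:
    """Toy stand-in for a TorchScript graph node (not the torch._C API)."""
    def __init__(self, kind, inputs, qparams=None):
        self.kind = kind          # e.g. "aten::quantize_per_tensor"
        self.inputs = inputs      # producer Nodes
        self.qparams = qparams    # (scale, zero_point) if this op defines them

def find_qparams(node, known_ops):
    # Stop at the first producer whose output quant params are known.
    if node.kind in known_ops:
        return node.qparams
    for producer in node.inputs:  # assume the quantized tensor comes first
        return find_qparams(producer, known_ops)
    raise RuntimeError("No producer with quant params for %s" % node.kind)

quant = Node("aten::quantize_per_tensor", [], qparams=(0.1, 128))
relu = Node("aten::relu", [quant])
assert find_qparams(relu, {"aten::quantize_per_tensor"}) == (0.1, 128)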

def get_add_scalar_output_quant_param(input_scale, input_zero_point,
scalar):
def _get_add_scalar_output_quant_param(input_scale, input_zero_point,
scalar):
# refer to aten/src/ATen/native/quantized/cpu/qadd.cpp
q_min = 0
q_max = 255
@@ -139,8 +164,8 @@ def get_add_scalar_output_quant_param(input_scale, input_zero_point,
return s_prime, z_prime


def get_mul_scalar_output_quant_param(input_scale, input_zero_point,
scalar):
def _get_mul_scalar_output_quant_param(input_scale, input_zero_point,
scalar):
# refer to aten/src/ATen/native/quantized/cpu/qmul.cpp
q_min = 0
q_max = 255
@@ -161,19 +186,19 @@ def get_mul_scalar_output_quant_param(input_scale, input_zero_point,
return s_prime, z_prime


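A hedged sketch of the output quant param arithmetic for quantized::mul_scalar, assuming the uint8 range [0, 255] set up above (the add_scalar branches live in the collapsed hunk, so they are not reproduced here):

def mul_scalar_out_qparams(input_scale, input_zero_point, scalar,
                           q_min=0, q_max=255):
    # scalar > 0: only the scale changes; quantized values are reused as-is.
    if scalar > 0.0:
        return scalar * input_scale, input_zero_point
    # scalar == 0: the result is identically zero, so any valid qparams work.
    if scalar == 0.0:
        return 1.0, 0
    # scalar < 0: the magnitude rescales and the zero point is mirrored,
    # which is why _mul_scalar below rewrites the payload as 255 - x.
    return abs(scalar) * input_scale, q_max - (input_zero_point - q_min)

print(mul_scalar_out_qparams(0.05, 100, -2.0))   # -> (0.1, 155)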
def add_output_quant_params_to_scalar_op(node, graph,
input_scale, input_zero_point,
scalar):
def _add_output_quant_params_to_scalar_op(node, graph,
input_scale, input_zero_point,
scalar):
operator = node.kind()

if operator == "quantized::mul_scalar":
out_scale, out_zero_point = \
get_mul_scalar_output_quant_param(input_scale, input_zero_point,
scalar)
_get_mul_scalar_output_quant_param(input_scale, input_zero_point,
scalar)
elif operator == "quantized::add_scalar":
out_scale, out_zero_point = \
get_add_scalar_output_quant_param(input_scale, input_zero_point,
scalar)
_get_add_scalar_output_quant_param(input_scale, input_zero_point,
scalar)
else:
assert False, "unsupported scalar op: %s" % operator

@@ -191,11 +216,13 @@ def add_output_quant_params_to_scalar_op(node, graph,


def add_input_quant_params_to_op_inputs(graph):
# Quantized operators in PyTorch do not take input quant params as
# arguments. But QNN expects them to be passed in as arguments.
# To simplify the translation of inputs, we add input quant params
# to inputs of PyTorch quantized operator nodes. See _impl in
# _quantized_conv2d() below for an example of why this is helpful.
"""
Quantized operators in PyTorch do not take input quant params as
arguments. But QNN expects them to be passed in as arguments.
To simplify the translation of inputs, we add input quant params
to inputs of PyTorch quantized operator nodes. See _impl in
_quantized_conv2d() below for an example of why this is helpful.
"""
num_quantized_inputs = {"quantized::conv2d": 1,
"quantized::conv2d_relu": 1,
"quantized::linear": 1,
@@ -227,12 +254,12 @@ def add_input_quant_params_to_op_inputs(graph):
if operator == "quantized::cat":
inputs = node.inputsAt(0).node().inputs()
for inp in inputs:
scale, zp = get_quant_param_for_input(inp)
scale, zp = _get_quant_param_for_input(inp)
input_scales.append(scale)
input_zero_points.append(zp)
else:
for i in range(num_quantized_inputs[operator]):
scale, zp = get_quant_param_for_input(node.inputsAt(i))
scale, zp = _get_quant_param_for_input(node.inputsAt(i))
input_scales.append(scale)
input_zero_points.append(zp)

@@ -241,16 +268,17 @@ def add_input_quant_params_to_op_inputs(graph):
inp_scale = input_scales[0].node().f("value")
inp_zero_point = input_zero_points[0].node().i("value")

add_output_quant_params_to_scalar_op(node, graph,
inp_scale, inp_zero_point,
scalar)
_add_output_quant_params_to_scalar_op(node, graph,
inp_scale, inp_zero_point,
scalar)

for scale, zp in zip(input_scales, input_zero_points):
node.addInput(scale)
node.addInput(zp)


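The net effect of this pass: each quantized op node gets its input scale and zero point appended as two extra inputs. For quantized::conv2d, whose output scale and zero point already sit at argument positions 6 and 7 (per the output_quant_param_indices table above), the appended pair therefore lands at positions 8 and 9, which is why _quantized_conv2d reads inputs[8] further down. A sketch of the resulting layout; the ordering of positions 2-5 is an assumption about the TorchScript signature of that era, not something shown in this diff:

# Assumed input layout of a quantized::conv2d node after this pass:
#   [0] data                 [1] packed (weight, w_scale, w_zp, bias)
#   [2] stride  [3] padding  [4] dilation  [5] groups       (assumed order)
#   [6] output_scale         [7] output_zero_point          (already present)
#   [8] input_scale          [9] input_zero_point           (appended here)
def read_conv2d_qparams(inputs):
    output_scale, output_zero_point = inputs[6], inputs[7]
    input_scale, input_zero_point = inputs[8], inputs[9]
    return input_scale, input_zero_point, output_scale, output_zero_point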
def add_quant_params(params, quant_params):
""" Add quant parameters to TVM param map """
for qparam in quant_params.values():
params[qparam.weight_var.name_hint] = tvm.nd.array(qparam.weight)
if qparam.bias is not None:
@@ -283,31 +311,31 @@ def quantized_relu(data, input_zero_point):


def _quantize_per_tensor():
def _impl(inputs, input_type):
def _impl(inputs, _):
return relay.qnn.op.quantize(inputs[0], _expr.const(inputs[1]),
_expr.const(inputs[2]), out_dtype="uint8",
axis=1)
return _impl


def _dequantize():
def _impl(inputs, input_type):
def _impl(inputs, _):
inp_scale = _expr.const(inputs[1])
inp_zero_point = _expr.const(inputs[2])
return relay.qnn.op.dequantize(inputs[0], inp_scale, inp_zero_point)
return _impl

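These two converters implement the standard affine (asymmetric) uint8 scheme. A minimal NumPy sketch of the arithmetic relay.qnn.op.quantize/dequantize perform here, glossing over QNN's exact rounding mode:

import numpy as np

def quantize_uint8(x, scale, zero_point):
    return np.clip(np.round(x / scale) + zero_point, 0, 255).astype("uint8")

def dequantize(xq, scale, zero_point):
    return (xq.astype("float32") - zero_point) * scale

x = np.array([-0.5, 0.0, 0.7], dtype="float32")
xq = quantize_uint8(x, scale=0.01, zero_point=128)   # -> [ 78 128 198]
print(dequantize(xq, 0.01, 128))                     # close to the original x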

def get_numpy(relay_const_scalar):
def _get_numpy(relay_const_scalar):
return relay_const_scalar.data.asnumpy()


def get_scalar(relay_const_scalar):
return np.asscalar(get_numpy(relay_const_scalar))
def _get_scalar(relay_const_scalar):
return np.asscalar(_get_numpy(relay_const_scalar))

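Small aside: np.asscalar has been deprecated in NumPy for some time (and removed in later releases). If that matters for the NumPy in use, an equivalent helper would be:

def _get_scalar(relay_const_scalar):
    # .item() is the ndarray method that np.asscalar wrapped.
    return _get_numpy(relay_const_scalar).item()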

def _quantized_conv2d(with_relu=False):
def _impl(inputs, input_type):
def _impl(inputs, _):
# refer to src/ATen/native/quantized/cpu/qconv.cpp
# inputs[0]: input tensor
# inputs[1]: (weight, scale, zero_point, bias)
@@ -338,7 +366,7 @@ def _impl(inputs, input_type):
out_channels = weight_shape[0]

if padding[0] != 0 or padding[1] != 0:
pad_val = get_scalar(input_zero_point)
pad_val = _get_scalar(input_zero_point)
inp = _op.nn.pad(inputs[0], pad_width=((0, 0),
(0, 0),
(padding[0], padding[0]),
@@ -356,7 +384,7 @@ def _impl(inputs, input_type):
channels=out_channels)

# input scale * weight scale
requant_input_scale = _expr.const(inputs[8] * get_numpy(weight_scale))
requant_input_scale = _expr.const(inputs[8] * _get_numpy(weight_scale))
bias_var = inputs[1][3]

if bias_var is not None:
@@ -374,7 +402,7 @@ def _impl(inputs, input_type):
out_dtype="int32", axis=1)
clip_min = 0
if with_relu:
clip_min = get_scalar(output_zero_point)
clip_min = _get_scalar(output_zero_point)

clip = _op.tensor.clip(requantized, clip_min, 255.)
return _op.cast(clip, dtype="uint8")
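Both the conv and the linear converter use the same requantize pattern: the int32 accumulator effectively carries scale input_scale * weight_scale, and relay.qnn.op.requantize maps that onto the op's own output scale and zero point before the clip and cast above. A rough scalar sketch of that arithmetic, ignoring QNN's exact rounding rules:

def requantize_scalar(acc_int32, input_scale, weight_scale,
                      output_scale, output_zero_point):
    real_value = acc_int32 * input_scale * weight_scale
    q = int(round(real_value / output_scale)) + output_zero_point
    return min(max(q, 0), 255)   # mirrors the final clip to uint8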
@@ -383,7 +411,7 @@ def _impl(inputs, input_type):


def _binop(relay_op, with_relu=False):
def _impl(inputs, input_type):
def _impl(inputs, _):
output_scale = _expr.const(inputs[2])
output_zero_point = _expr.const(inputs[3])
assert len(inputs) == 8, "Input quant params not found in op inputs"
@@ -421,7 +449,7 @@ def _impl(inputs, input_type):


def _linear(with_relu=False):
def _impl(inputs, input_type):
def _impl(inputs, _):
weight = inputs[1][0]
weight_scale = inputs[1][1]
weight_zero_point = inputs[1][2]
@@ -437,7 +465,7 @@ def _impl(inputs, input_type):
input_scale, weight_scale,
units=weight_shape[0])

requant_input_scale = _expr.const(inputs[4] * get_numpy(weight_scale))
requant_input_scale = _expr.const(inputs[4] * _get_numpy(weight_scale))
bias_var = inputs[1][3]

if bias_var is not None:
@@ -455,7 +483,7 @@ def _impl(inputs, input_type):
out_dtype="int32", axis=1)
clip_min = 0
if with_relu:
clip_min = get_scalar(output_zero_point)
clip_min = _get_scalar(output_zero_point)

clip = _op.tensor.clip(requantized, clip_min, 255.)
return _op.cast(clip, dtype="uint8")
@@ -464,7 +492,7 @@ def _impl(inputs, input_type):


def _cat():
def _impl(inputs, input_type):
def _impl(inputs, _):
axis = inputs[1]
output_scale = _expr.const(inputs[2])
output_zero_point = _expr.const(inputs[3])
@@ -485,7 +513,7 @@ def _impl(inputs, input_type):


def _add_scalar():
def _impl(inputs, input_type):
def _impl(inputs, _):
# refer to aten/src/ATen/native/quantized/cpu/qadd.cpp
assert len(inputs) == 6, "Input quant params not found in op inputs"
s = inputs[4]
@@ -504,9 +532,8 @@ def _impl(inputs, input_type):
dequantized_add = _op.tensor.add(dequant, _expr.const(c_q * s))
return relay.qnn.op.quantize(dequantized_add, out_scale, out_zp,
axis=1, out_dtype="uint8")
else:
# only scale change
return inputs[0]
# only scale change
return inputs[0]

return _impl

@@ -517,7 +544,7 @@ def _impl(inputs, input_type):


def _relu6():
def _impl(inputs, input_type):
def _impl(inputs, _):
assert len(inputs) == 4, "Input quant params not found in op inputs"
input_scale = inputs[2]
input_zero_point = inputs[3]
@@ -527,7 +554,7 @@ def _impl(inputs, input_type):


def _mul_scalar():
def _impl(inputs, input_type):
def _impl(inputs, _):
# refer to aten/src/ATen/native/quantized/cpu/qmul.cpp
assert len(inputs) == 6, "Input quant params not found in op inputs"
other_val = inputs[1] # scalar
@@ -538,12 +565,12 @@ def _impl(inputs, input_type):
elif other_val == 0.0:
shape = infer_shape(inputs[0])
return _op.full(_expr.const(0), shape, dtype="uint8")
else:
q_min = 0
q_max = 255
bias = _expr.const(q_max + q_min, dtype="int8")
int8 = bias - _op.cast(inputs[0], "int8")
return _op.cast(int8, "uint8")

q_min = 0
q_max = 255
bias = _expr.const(q_max + q_min, dtype="int8")
int8 = bias - _op.cast(inputs[0], "int8")
return _op.cast(int8, "uint8")

return _impl

18 changes: 18 additions & 0 deletions tests/python/frontend/pytorch/qnn_test.py
@@ -1,3 +1,21 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=import-self, invalid-name, unused-argument
""" Tests on quantized torch model conversion """
import os

from PIL import Image

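For context, a hedged end-to-end sketch of the flow qnn_test.py exercises: statically quantize a small module with the eager-mode torch.quantization API, trace it, and feed it to the Relay PyTorch frontend. The from_pytorch call signature is an assumption (it has changed across TVM versions), so treat this as a sketch rather than a copy of the test:

import torch
from tvm import relay

class Small(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.conv = torch.nn.Conv2d(3, 8, 3, padding=1)
        self.relu = torch.nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        return self.dequant(self.relu(self.conv(self.quant(x))))

model = Small().eval()
model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
torch.quantization.prepare(model, inplace=True)
model(torch.rand(1, 3, 32, 32))                       # calibration pass
torch.quantization.convert(model, inplace=True)

script = torch.jit.trace(model, torch.rand(1, 3, 32, 32)).eval()
# Assumed signature: a list of (input name, shape) pairs.
mod, params = relay.frontend.from_pytorch(script, [("input", (1, 3, 32, 32))])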