diff --git a/python/tvm/relay/qnn/op/qnn.py b/python/tvm/relay/qnn/op/qnn.py
index 1717bc42fe946..2e9c0d83148a2 100644
--- a/python/tvm/relay/qnn/op/qnn.py
+++ b/python/tvm/relay/qnn/op/qnn.py
@@ -19,6 +19,7 @@
 from __future__ import absolute_import as _abs
 from . import _make
+from tvm import relay
 
 
 def requantize(data,
                input_scale,
@@ -72,3 +73,58 @@ def requantize(data,
                             output_zero_point,
                             rounding,
                             out_dtype)
+
+def concatenate(data,
+                input_scales,
+                input_zero_points,
+                output_scale,
+                output_zero_point,
+                output_dtype,
+                axis):
+    """Concatenate the quantized input tensors along the given axis.
+
+    Parameters
+    ----------
+    data : Union(List[relay.Expr], Tuple[relay.Expr])
+        A list of quantized tensors.
+
+    input_scales : List[float32]
+        The scales of the input tensors, one per tensor.
+
+    input_zero_points : List[int32]
+        The zero points of the input tensors, one per tensor.
+
+    output_scale : float32
+        The scale of the output tensor.
+
+    output_zero_point : int32
+        The zero point of the output tensor.
+
+    output_dtype : str
+        The data type of the output tensor.
+
+    axis : int
+        The axis along which the tensors are concatenated.
+
+    Returns
+    -------
+    result : relay.Expr
+        The concatenated quantized tensor.
+    """
+
+    data = list(data)
+    requantized_exprs = list(data)
+    # If an input's qnn params do not match the output qnn params, requantize
+    # that input to the output scale and zero point.
+    for idx, quantized_expr in enumerate(data):
+        scale = input_scales[idx]
+        zero_point = input_zero_points[idx]
+        if scale != output_scale or zero_point != output_zero_point:
+            requantized_exprs[idx] = requantize(quantized_expr,
+                                                input_scale=scale,
+                                                input_zero_point=zero_point,
+                                                output_scale=output_scale,
+                                                output_zero_point=output_zero_point,
+                                                out_dtype=output_dtype)
+    # All tensors now share the same qnn params, so we can call relay concatenate directly.
+    return relay.concatenate(tuple(requantized_exprs), axis)
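
Note (reviewer sketch, not part of the patch): the requantize fallback above is
what makes concatenating tensors with different qnn params legal. A rough numpy
model of the requantize arithmetic, ignoring the rounding-mode and saturation
handling that the real qnn.requantize lowering performs; requantize_ref is a
hypothetical helper introduced only for illustration:

    import numpy as np

    def requantize_ref(q, in_scale, in_zp, out_scale, out_zp):
        # A quantized tensor represents the real value scale * (q - zero_point).
        real = in_scale * (q.astype('float64') - in_zp)
        # Re-express that real value in the output quantization scheme.
        return np.round(real / out_scale).astype('int32') + out_zp

    # Two inputs quantized with different scales...
    a = np.array([10, 20], dtype='int32')   # scale 0.5, zero point 0
    b = np.array([40, 80], dtype='int32')   # scale 0.25, zero point 0
    # ...must be brought to the common output scheme (scale 0.25, zero point 0)
    # before a plain concatenate is numerically meaningful.
    print(np.concatenate([requantize_ref(a, 0.5, 0, 0.25, 0), b]))
    # -> [20 40 40 80]; both halves now encode the same real values as before.
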
diff --git a/tests/python/relay/test_qnn_concatenate.py b/tests/python/relay/test_qnn_concatenate.py
new file mode 100644
index 0000000000000..febc12cde3d90
--- /dev/null
+++ b/tests/python/relay/test_qnn_concatenate.py
@@ -0,0 +1,55 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import numpy as np
+from tvm import relay
+
+def test_qnn_concatenate():
+    data_dtype = 'int32'
+    axis = 0
+    x_data = np.arange(-32, 32, 1).reshape(1, 64).astype(data_dtype)
+    y_data = np.arange(-64, 64, 2).reshape(1, 64).astype(data_dtype)
+    x_scale = (62 + 64) / (np.power(2, 32) - 1.0)
+    y_scale = (62 + 64) / (np.power(2, 32) - 1.0)
+
+    x = relay.var("x", shape=(1, 64), dtype=data_dtype)
+    y = relay.var("y", shape=(1, 64), dtype=data_dtype)
+    z = relay.qnn.op.concatenate((x, y),
+                                 input_scales=[x_scale, y_scale],
+                                 input_zero_points=[0, 0],
+                                 output_scale=y_scale,
+                                 output_zero_point=1,
+                                 output_dtype=data_dtype,
+                                 axis=axis)
+
+    func = relay.Function([x, y], z)
+    mod = relay.Module.from_expr(func)
+    mod = relay.transform.Legalize()(mod)
+    func = mod["main"]
+
+    # The input scales match the output scale and the input zero points are
+    # zero, so requantization reduces to adding the output zero point (1).
+    golden_output = np.concatenate((x_data, y_data), axis=axis)
+    golden_output = np.add(1, golden_output)
+
+    intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
+    op_res = intrp.evaluate(func)(x_data, y_data)
+    np.testing.assert_equal(op_res.asnumpy(), golden_output)
+
+if __name__ == '__main__':
+    test_qnn_concatenate()
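
Usage sketch (illustrative, not part of the patch): when an input's qnn params
differ from the output's, the op should wrap that input in qnn.requantize
before the vanilla concatenate, which printing the expression makes visible.
The test above only exercises the zero-point shift; this exercises the scale
mismatch path:

    from tvm import relay

    x = relay.var("x", shape=(1, 4), dtype='int32')
    y = relay.var("y", shape=(1, 4), dtype='int32')
    # x's scale (0.5) differs from the output scale (0.25), so only x is
    # requantized; y already matches and passes through untouched.
    z = relay.qnn.op.concatenate((x, y),
                                 input_scales=[0.5, 0.25],
                                 input_zero_points=[0, 0],
                                 output_scale=0.25,
                                 output_zero_point=0,
                                 output_dtype='int32',
                                 axis=0)
    print(relay.Function([x, y], z))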