Skip to content

Commit

Permalink
[QNN] Concatenate operator
Browse files: browse the repository at this point in the history
  • Loading branch information
anijain2305 committed Aug 8, 2019
1 parent 3ac27fc commit 43be829
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 0 deletions.
53 changes: 53 additions & 0 deletions python/tvm/relay/qnn/op/qnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from __future__ import absolute_import as _abs
from . import _make
from tvm import relay

def requantize(data,
input_scale,
Expand Down Expand Up @@ -72,3 +73,55 @@ def requantize(data,
output_zero_point,
rounding,
out_dtype)

def concatenate(data,
                input_scales,
                input_zero_points,
                output_scale,
                output_zero_point,
                output_dtype,
                axis):
    """Concatenate the quantized input tensors along the given axis.

    Every input whose (scale, zero point) pair differs from the output pair is
    first passed through ``requantize``; the resulting tensors are then joined
    with ``relay.concatenate``.

    Parameters
    ----------
    data : Union(List[relay.Expr], Tuple[relay.Expr])
        The quantized tensors to concatenate.

    input_scales : List[float32]
        The scale of each quantized tensor, one entry per tensor in ``data``.

    input_zero_points : List[int32]
        The zero point of each quantized tensor, one entry per tensor in ``data``.

    output_scale : float32
        The scale of the output tensor.

    output_zero_point : int32
        The zero point of the output tensor.

    output_dtype : str
        Forwarded as ``out_dtype`` to ``requantize`` for every input that needs
        requantization.

    axis : int
        The axis along which the tensors are concatenated.

    Returns
    -------
    result : relay.Expr
        The concatenated tensor.

    Raises
    ------
    ValueError
        If ``data``, ``input_scales`` and ``input_zero_points`` do not all have
        the same number of entries.
    """
    data = list(data)
    input_scales = list(input_scales)
    input_zero_points = list(input_zero_points)

    # Fail early with a clear message instead of an IndexError half-way through
    # the loop (too few params) or silently ignoring surplus params (too many).
    if not len(data) == len(input_scales) == len(input_zero_points):
        raise ValueError(
            "data, input_scales and input_zero_points must have the same length, "
            "got %d, %d and %d" % (len(data), len(input_scales), len(input_zero_points)))

    # If the qnn params of an input do not match the output qnn params, requantize that
    # input to the output params.
    # NOTE(review): inputs whose qnn params already match are passed through untouched, so
    # their dtype is not forced to output_dtype - confirm callers guarantee matching dtypes.
    requantized_exprs = []
    for quantized_expr, scale, zero_point in zip(data, input_scales, input_zero_points):
        if scale != output_scale or zero_point != output_zero_point:
            quantized_expr = requantize(quantized_expr,
                                        input_scale=scale,
                                        input_zero_point=zero_point,
                                        output_scale=output_scale,
                                        output_zero_point=output_zero_point,
                                        out_dtype=output_dtype)
        requantized_exprs.append(quantized_expr)

    # As all tensors now share the same qnn params, we can directly call relay concatenate.
    return relay.concatenate(tuple(requantized_exprs), axis)
55 changes: 55 additions & 0 deletions tests/python/relay/test_qnn_concatenate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import tvm
import numpy as np
from tvm import relay
from tvm.contrib import graph_runtime
import topi.testing

def test_qnn_concatenate():
    """Compare qnn concatenate of two int32 tensors with a NumPy reference.

    Both inputs use the output scale with zero point 0, while the output zero
    point is 1; the reference output is the NumPy concatenation plus one.
    """
    dtype = 'int32'
    concat_axis = 0
    shape = (1, 64)

    # Both inputs are quantized with the same scale.
    x_scale = (62 + 64) / (np.power(2, 32) - 1.0)
    y_scale = (62 + 64) / (np.power(2, 32) - 1.0)

    x_data = np.arange(-32, 32, 1).reshape(shape).astype(dtype)
    y_data = np.arange(-64, 64, 2).reshape(shape).astype(dtype)

    # Build the relay function wrapping the qnn concatenate op.
    x = relay.var("x", shape=shape, dtype=dtype)
    y = relay.var("y", shape=shape, dtype=dtype)
    z = relay.qnn.op.concatenate((x, y),
                                 input_scales=[x_scale, y_scale],
                                 input_zero_points=[0, 0],
                                 output_scale=y_scale,
                                 output_zero_point=1,
                                 output_dtype=dtype,
                                 axis=concat_axis)

    mod = relay.Module.from_expr(relay.Function([x, y], z))
    mod = relay.transform.Legalize()(mod)
    func = mod["main"]

    # Reference result computed with NumPy.
    golden_output = np.add(1, np.concatenate((x_data, y_data), axis=concat_axis))

    intrp = relay.create_executor("graph", ctx=tvm.cpu(0), target="llvm")
    op_res = intrp.evaluate(func)(x_data, y_data)
    np.testing.assert_equal(op_res.asnumpy(), golden_output)

# Allow running this test file directly as a script.
if __name__ == "__main__":
    test_qnn_concatenate()

0 comments on commit 43be829

Please sign in to comment.