Skip to content

Commit

Permalink
Reformat
Browse files Browse the repository at this point in the history
  • Loading branch information
trahman-quic1 committed Aug 25, 2022
1 parent a2a9e60 commit 0eb2d8b
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 10 deletions.
2 changes: 1 addition & 1 deletion python/tvm/topi/hexagon/qnn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@
""" Computes and schedules for Hexagon quantized ops """

from .avg_pool2d import qnn_avg_pool2d_compute, qnn_avg_pool2d_schedule
from .qadd_qsub_qmul import *
from .qadd_qsub_qmul import *
12 changes: 5 additions & 7 deletions python/tvm/topi/hexagon/qnn/qadd_qsub_qmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from tvm import te
from tvm import tir
from ..utils import get_layout_transform_fn, get_fixed_point_value
import tvm


def broadcast_axis(tensor_A, tensor_B):
Expand All @@ -51,7 +50,7 @@ def saturate(x: te.Tensor, dtype: str):
return te.max(te.min_value(dtype), te.min(x, te.max_value(dtype)))


def get_int_scale(scale_A, scale_B, scale_M, zero_point_A, zero_point_B, zero_point_M, op, dtype):
def get_int_scale(scale_A, scale_B, scale_M, zero_point_A, zero_point_B, zero_point_M, op):
"""Get fixed-point number"""
C_recip = 1 / scale_M

Expand Down Expand Up @@ -104,7 +103,7 @@ def qadd_broadcast_compute(
n_b, h_b, w_b, c_b = B_broadcast

scale_a, scale_b, rsh, corr = get_int_scale(
scale_A, scale_B, scale_M, zero_point_A, zero_point_B, zero_point_M, "qadd", "int16"
scale_A, scale_B, scale_M, zero_point_A, zero_point_B, zero_point_M, "qadd"
)

return te.compute(
Expand Down Expand Up @@ -141,7 +140,7 @@ def qsubtract_broadcast_compute(
n_b, h_b, w_b, c_b = B_broadcast

scale_a, scale_b, rsh, corr = get_int_scale(
scale_A, scale_B, scale_M, zero_point_A, zero_point_B, zero_point_M, "qsub", "int16"
scale_A, scale_B, scale_M, zero_point_A, zero_point_B, zero_point_M, "qsub"
)

return te.compute(
Expand Down Expand Up @@ -178,7 +177,7 @@ def qmultiply_broadcast_compute(
n_b, h_b, w_b, c_b = B_broadcast

scale_int, rsh, corr = get_int_scale(
scale_A, scale_B, scale_M, zero_point_A, zero_point_B, zero_point_M, "qmul", "int16"
scale_A, scale_B, scale_M, zero_point_A, zero_point_B, zero_point_M, "qmul"
)

return te.compute(
Expand All @@ -205,7 +204,6 @@ def tir_schedule_quant(
output_layout: str,
tensor_A_layout: str,
tensor_B_layout: str,
op_name: str,
):
"""Schedule for output layout nhwc-8h8w32c-2d"""
func = te.create_prim_func([tensor_A, tensor_B, out_M])
Expand Down Expand Up @@ -234,4 +232,4 @@ def tir_schedule_quant(

s.reorder(n, h_o, w_o, c_o, h_i, wio, wii, c_i)

return s
return s
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def input_np_B(input_shape_B, dtype):
def quantize_input_np_A(input_np_A, dtype):
if dtype == "uint8" or dtype == "int8":
global zero_point_A_val, scale_A_val
input_np_A_quantized, scale_A_val, zero_point_A_val= quantize_np(input_np_A, dtype)
input_np_A_quantized, scale_A_val, zero_point_A_val = quantize_np(input_np_A, dtype)
return input_np_A_quantized


Expand Down Expand Up @@ -390,4 +390,4 @@ def test_transform(


if __name__ == "__main__":
tvm.testing.main()
tvm.testing.main()

0 comments on commit 0eb2d8b

Please sign in to comment.