Skip to content

Commit

Permalink
tanh float16 (apache#12165)
Browse files Browse the repository at this point in the history
Co-authored-by: aakaverm <aakaverm@qti.qualcomm.com>
  • Loading branch information
2 people authored and Mikael Sevenier committed Jul 26, 2022
1 parent 0a7b682 commit 754d40a
Show file tree
Hide file tree
Showing 3 changed files with 166 additions and 0 deletions.
1 change: 1 addition & 0 deletions python/tvm/topi/hexagon/slice_ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,4 @@
from .conv2d import *
from .reshape import reshape_compute, reshape_stir_schedule
from .relu import relu_compute, relu_stir_schedule
from .tanh import tanh_te_compute, tanhf16_schedule
56 changes: 56 additions & 0 deletions python/tvm/topi/hexagon/slice_ops/tanh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name

""" Hexagon tanh slice op compute and schedule """
import tvm
from tvm import te, tir
from ..utils import get_layout_transform_fn


def tanh_te_compute(in_tensor):
out_tensor = te.compute(
in_tensor.shape, lambda n, h, w, c: tvm.tir.tanh(in_tensor[n, h, w, c]), name="tanhf16"
)
return out_tensor


def tanhf16_stir_sched_nhwc(func, in_layout, out_layout, h_split_factor=8):
"""Schedule for nhwc fp16 to nchw fp16 layout"""
sch = tir.Schedule(func, debug_mask="all")
block_name = "tanhf16"
n, h, w, c = sch.get_loops(sch.get_block(block_name))
h_outer, h_inner = sch.split(h, [None, h_split_factor])
w_outer, w_inner = sch.split(w, [None, 4])
c_outer, c_inner = sch.split(c, [None, 32])
w_inner_o, w_inner_i = sch.split(w_inner, [None, 2])
sch.reorder(n, h_outer, w_outer, c_outer, h_inner, w_inner_o, c_inner, w_inner_i)
sch.transform_layout(block_name, "A", in_layout)
sch.transform_layout(block_name, block_name, out_layout)
fused = sch.fuse(c_inner, w_inner_i)
sch.vectorize(fused)
return sch


def tanhf16_schedule(tanh_func, in_layout_str, out_layout_str):
in_layout_transform_func = get_layout_transform_fn(in_layout_str)
out_layout_transform_func = get_layout_transform_fn(out_layout_str)
return tanhf16_stir_sched_nhwc(
tanh_func,
in_layout_transform_func,
out_layout_transform_func,
)
109 changes: 109 additions & 0 deletions tests/python/contrib/test_hexagon/topi/test_tanh_slice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
""" Test for Hexagon slice tanh op """
import numpy as np
import pytest

import tvm
import tvm.testing
from tvm import te
import tvm.topi.hexagon.slice_ops as sl
import tvm.contrib.hexagon
from ..infrastructure import allocate_hexagon_array, transform_numpy

# pylint: disable=invalid-name


class TestTanhSlice:
"""For Testing Tanh fp16 op"""

input_shape, orig_layout, input_layout, output_layout, axis_sep = tvm.testing.parameters(
((1, 8, 4, 32), "nhwc", "nhwc-8h2w32c2w-2d", "nhwc-8h2w32c2w-2d", [4]),
((1, 16, 12, 64), "nhwc", "nhwc-8h2w32c2w-2d", "nhwc-8h2w32c2w-2d", [4]),
((1, 64, 64, 32), "nhwc", "nhwc-8h2w32c2w-2d", "nhwc-8h2w32c2w-2d", [4]),
)
dtype = tvm.testing.parameter("float16")
working_scope = tvm.testing.parameter("global.vtcm")

@tvm.testing.fixture
def input_np(self, input_shape, dtype):
return np.random.uniform(size=input_shape).astype(dtype)

@tvm.testing.fixture
def transformed_input_np(self, input_np, orig_layout, input_layout):
return transform_numpy(input_np, orig_layout, input_layout)

@tvm.testing.fixture
def expected_output_np(self, input_np):
ref_np = np.tanh(input_np)
return ref_np

@tvm.testing.fixture
def transformed_expected_output_np(self, expected_output_np, orig_layout, output_layout):
return transform_numpy(expected_output_np, orig_layout, output_layout)

@tvm.testing.requires_hexagon
def test_tanh(
self,
input_shape,
dtype,
input_layout,
output_layout,
transformed_input_np,
transformed_expected_output_np,
axis_sep,
hexagon_session,
working_scope,
):
"""Top Level testing function for tanh fp16 op"""

target_hexagon = tvm.target.hexagon("v69")
target = tvm.target.Target(target_hexagon, host=target_hexagon)
A = te.placeholder(input_shape, name="A", dtype=dtype)
M = sl.tanh_te_compute(A)
tanhf16_func = te.create_prim_func([A, M])
tir_s = sl.tanhf16_schedule(tanhf16_func, input_layout, output_layout)
A_data = allocate_hexagon_array(
hexagon_session.device,
data=transformed_input_np,
axis_separators=axis_sep,
mem_scope=working_scope,
)
M_data = allocate_hexagon_array(
hexagon_session.device,
tensor_shape=transformed_expected_output_np.shape,
dtype=transformed_expected_output_np.dtype,
axis_separators=axis_sep,
mem_scope=working_scope,
)
with tvm.transform.PassContext(opt_level=3):
tir_irm = tvm.lower(tir_s.mod, [A, M], name="tanhf16")
runtime_module = tvm.build(tir_irm, target=target, name="tanhf16")
mod = hexagon_session.load_module(runtime_module)

mod(A_data, M_data)
output_np = M_data.numpy()
tvm.testing.assert_allclose(
output_np,
transformed_expected_output_np,
1e-3,
1e-3,
)


if __name__ == "__main__":
sys.exit(pytest.main(sys.argv))

0 comments on commit 754d40a

Please sign in to comment.