From 49870155f65fb5f0594c1a56df348ba998c00f7b Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Mon, 12 Jul 2021 20:50:46 +0000 Subject: [PATCH 1/4] [CUDA] Improve injective schedule to enable half2 --- python/tvm/topi/cuda/injective.py | 36 ++++++++++++++++++++++++++++--- 1 file changed, 33 insertions(+), 3 deletions(-) diff --git a/python/tvm/topi/cuda/injective.py b/python/tvm/topi/cuda/injective.py index cce56b796cea..4bd63ab7b1ae 100644 --- a/python/tvm/topi/cuda/injective.py +++ b/python/tvm/topi/cuda/injective.py @@ -16,6 +16,8 @@ # under the License. # pylint: disable=invalid-name, unused-variable, """Schedule for composition of injective operator""" +import numpy as np + import tvm from tvm import te from .. import utils @@ -36,13 +38,21 @@ def schedule_injective_from_existing(sch, out): sch: Schedule The updated schedule. """ + + def find_nearest_small_factor(num, target): + """Find the nearest factor of the given number that is smaller than the target.""" + for i in range(target, 0, -1): + if num % i == 0: + return i + # Unreachable because i=1 must hold. + return -1 + fused = sch[out].fuse(*sch[out].op.axis) num_thread = tvm.target.Target.current(allow_none=False).max_num_threads max_block = 256 - # vectorize on fp16 data type. This allows to better utilize the memory - # bandwidth. - vector_width = 4 if out.dtype == "float16" else 1 + # Vectorize on fp16 data type to enable half2 for better memory bandwidth utilization. + vector_width = 2 if out.dtype == "float16" else 1 is_dynamic_output = False for dim in out.shape: @@ -54,6 +64,26 @@ def schedule_injective_from_existing(sch, out): try: const_size = utils.get_const_int(out_len) + + # Adjust block and thread to make sure they are dividable so that vectorize can be + # correctly applied. + if vector_width > 1 and const_size % vector_width == 0: + remain_size = const_size // vector_width + cand_sizes = [] + for curr_size in [num_thread, max_block]: + cand_sizes.append( + curr_size + if remain_size % curr_size == 0 + else find_nearest_small_factor(remain_size, curr_size) + ) + remain_size //= cand_sizes[-1] + + # If the product of candidate dividable (block * thread) is too small, + # then the performance may be worse even half2 is enabled. Note that 0.7 + # is just a heuristic ratio and may not be optimal for all workloads. + if np.prod(cand_sizes) / (max_block * num_thread) >= 0.7: + max_block, num_thread = cand_sizes + need_block_split = const_size > max_block * num_thread * vector_width except ValueError: need_block_split = False From 69bcfd1672e7f519b339fba39e7b5535e91a3586 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Mon, 12 Jul 2021 21:39:05 +0000 Subject: [PATCH 2/4] lint --- python/tvm/topi/cuda/injective.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/tvm/topi/cuda/injective.py b/python/tvm/topi/cuda/injective.py index 4bd63ab7b1ae..d2bec84f5089 100644 --- a/python/tvm/topi/cuda/injective.py +++ b/python/tvm/topi/cuda/injective.py @@ -68,15 +68,15 @@ def find_nearest_small_factor(num, target): # Adjust block and thread to make sure they are dividable so that vectorize can be # correctly applied. if vector_width > 1 and const_size % vector_width == 0: - remain_size = const_size // vector_width + remain_total_size = const_size // vector_width cand_sizes = [] - for curr_size in [num_thread, max_block]: + for max_size in [num_thread, max_block]: cand_sizes.append( - curr_size - if remain_size % curr_size == 0 - else find_nearest_small_factor(remain_size, curr_size) + max_size + if remain_total_size % max_size == 0 + else find_nearest_small_factor(remain_total_size, max_size) ) - remain_size //= cand_sizes[-1] + remain_total_size //= cand_sizes[-1] # If the product of candidate dividable (block * thread) is too small, # then the performance may be worse even half2 is enabled. Note that 0.7 From f9cbf9f7ee6bf1e456e5a7ae4bab5ca9cda3a66e Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Tue, 13 Jul 2021 00:25:38 +0000 Subject: [PATCH 3/4] fix --- python/tvm/topi/cuda/injective.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/topi/cuda/injective.py b/python/tvm/topi/cuda/injective.py index d2bec84f5089..0faddc31c25a 100644 --- a/python/tvm/topi/cuda/injective.py +++ b/python/tvm/topi/cuda/injective.py @@ -82,7 +82,7 @@ def find_nearest_small_factor(num, target): # then the performance may be worse even half2 is enabled. Note that 0.7 # is just a heuristic ratio and may not be optimal for all workloads. if np.prod(cand_sizes) / (max_block * num_thread) >= 0.7: - max_block, num_thread = cand_sizes + num_thread, max_block = cand_sizes need_block_split = const_size > max_block * num_thread * vector_width except ValueError: From d5409c83871580bc877c0191eb3c2ba15829e6c9 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Tue, 13 Jul 2021 16:36:01 +0000 Subject: [PATCH 4/4] trigger ci