[CUDA] Improve injective schedule to enable half2 #8457

Merged · 4 commits · Jul 14, 2021
36 changes: 33 additions & 3 deletions python/tvm/topi/cuda/injective.py
@@ -16,6 +16,8 @@
# under the License.
# pylint: disable=invalid-name, unused-variable,
"""Schedule for composition of injective operator"""
import numpy as np

import tvm
from tvm import te
from .. import utils
@@ -36,13 +38,21 @@ def schedule_injective_from_existing(sch, out):
sch: Schedule
The updated schedule.
"""

def find_nearest_small_factor(num, target):
"""Find the nearest factor of the given number that is smaller than the target."""
for i in range(target, 0, -1):
if num % i == 0:
return i
# Unreachable: i = 1 always divides num.
return -1

fused = sch[out].fuse(*sch[out].op.axis)
num_thread = tvm.target.Target.current(allow_none=False).max_num_threads
max_block = 256

- # vectorize on fp16 data type. This allows to better utilize the memory
- # bandwidth.
- vector_width = 4 if out.dtype == "float16" else 1
+ # Vectorize on fp16 data type to enable half2 for better memory bandwidth utilization.
+ vector_width = 2 if out.dtype == "float16" else 1

is_dynamic_output = False
for dim in out.shape:
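
For context, the fuse/split/vectorize pattern this schedule relies on looks roughly like the standalone sketch below. It is not part of the patch; the element-wise op, tensor size, and thread count are hypothetical, and only the innermost `factor=2` split corresponds to `vector_width` above.

```python
import tvm
from tvm import te

n = 4096
A = te.placeholder((n,), dtype="float16", name="A")
B = te.compute((n,), lambda i: A[i] + A[i], name="B")

s = te.create_schedule(B.op)
fused = s[B].fuse(*s[B].op.axis)
# Peel off the innermost lanes first so adjacent fp16 elements form one half2.
bx_tx, vec = s[B].split(fused, factor=2)    # vector_width = 2
bx, tx = s[B].split(bx_tx, factor=128)      # hypothetical threads per block
s[B].bind(bx, te.thread_axis("blockIdx.x"))
s[B].bind(tx, te.thread_axis("threadIdx.x"))
s[B].vectorize(vec)

# The lowered IR shows a vectorized (ramp) inner access; when built for a
# CUDA target this is what can become half2 loads and stores.
print(tvm.lower(s, [A, B], simple_mode=True))
```
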
@@ -54,6 +64,26 @@ def schedule_injective_from_existing(sch, out):

try:
const_size = utils.get_const_int(out_len)

# Adjust the block and thread sizes so that they evenly divide the workload
# and vectorization can be correctly applied.
if vector_width > 1 and const_size % vector_width == 0:
remain_total_size = const_size // vector_width
cand_sizes = []
for max_size in [num_thread, max_block]:
cand_sizes.append(
max_size
if remain_total_size % max_size == 0
else find_nearest_small_factor(remain_total_size, max_size)
)
remain_total_size //= cand_sizes[-1]

# If the product of the candidate (block * thread) factors is too small,
# performance may be worse even with half2 enabled. Note that 0.7 is just
# a heuristic ratio and may not be optimal for all workloads.
if np.prod(cand_sizes) / (max_block * num_thread) >= 0.7:
num_thread, max_block = cand_sizes

need_block_split = const_size > max_block * num_thread * vector_width
except ValueError:
need_block_split = False
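
To make the divisibility heuristic concrete, here is a hedged, self-contained rerun of the logic above. The values of `num_thread`, `max_block`, and `const_size` are hypothetical (a CUDA target commonly reports `max_num_threads = 1024`):

```python
import numpy as np

def find_nearest_small_factor(num, target):
    """Largest factor of num that does not exceed target."""
    for i in range(target, 0, -1):
        if num % i == 0:
            return i
    return -1  # unreachable: i == 1 always divides num

num_thread, max_block = 1024, 256  # hypothetical target limits
vector_width = 2

for const_size in (2 * 1024 * 256, 6000):
    remain = const_size // vector_width
    cand_sizes = []
    for max_size in (num_thread, max_block):
        cand_sizes.append(
            max_size if remain % max_size == 0
            else find_nearest_small_factor(remain, max_size)
        )
        remain //= cand_sizes[-1]
    ratio = np.prod(cand_sizes) / (max_block * num_thread)
    print(const_size, cand_sizes, round(ratio, 3), ratio >= 0.7)
    # 524288 -> [1024, 256], ratio 1.0   -> adjust sizes, vectorize with half2
    # 6000   -> [1000, 3],   ratio 0.011 -> keep the default block/thread limits
```

The second case shows why the 0.7 guard exists: a poorly factorable size would leave most of the (block * thread) budget unused, so the schedule keeps the original limits rather than shrinking them just to enable half2.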