Skip to content

Commit

Permalink
Make batch matrix multiplication on GPU tunable (apache#5752)
Browse files Browse the repository at this point in the history
This is primarily aimed at the AMD GPU backend and done as part
of a project for AMD, but should work for all users of the GPU
schedule.
  • Loading branch information
t-vi authored and Trevor Morris committed Jun 12, 2020
1 parent 56dfec2 commit 629ad8d
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 26 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relay/op/strategy/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
"""batch_matmul cuda strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_batch_matmul(topi.nn.batch_matmul),
wrap_compute_batch_matmul(topi.cuda.batch_matmul),
wrap_topi_schedule(topi.cuda.schedule_batch_matmul),
name="batch_matmul.cuda",
plevel=10)
Expand Down
80 changes: 56 additions & 24 deletions topi/python/topi/cuda/batch_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,24 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,too-many-locals,unused-variable
# pylint: disable=invalid-name,too-many-locals,unused-variable,unused-argument
"""cuda batch_matmul operators"""
import tvm
from tvm import autotvm
from tvm import te
from tvm.contrib import cublas
from tvm.autotvm.task.space import SplitEntity, OtherOptionEntity
from .. import nn
from ..util import traverse_inline, get_const_tuple, get_max_power2_factor

def schedule_batch_matmul(outs):
@autotvm.register_topi_compute("batch_matmul.cuda")
def batch_matmul(cfg, x, y):
"""Compute conv2d with NCHW layout"""
return nn.batch_matmul(x, y)


@autotvm.register_topi_schedule("batch_matmul.cuda")
def schedule_batch_matmul(cfg, outs):
"""Schedule for batch_matmul
Parameters
Expand All @@ -37,7 +48,7 @@ def schedule_batch_matmul(outs):
outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
s = te.create_schedule([x.op for x in outs])

def _schedule(op):
def _schedule(cfg, op):
C = op.output(0)
A, B = s[C].op.input_tensors
_, M, N = get_const_tuple(C.shape)
Expand All @@ -51,55 +62,76 @@ def _schedule(op):
C = s.outputs[0].output(0)

b, y, x = s[C].op.axis
y_bn = get_max_power2_factor(M, 64)
x_bn = get_max_power2_factor(N, 64)
by, y = s[C].split(y, y_bn)
bx, x = s[C].split(x, x_bn)
y_nthreads = min(y_bn, 8)
x_nthreads = min(x_bn, 8)
ty, yi = s[C].split(y, nparts=y_nthreads)
tx, xi = s[C].split(x, nparts=x_nthreads)
thread_x = te.thread_axis((0, x_nthreads), "threadIdx.x")
thread_y = te.thread_axis((0, y_nthreads), "threadIdx.y")
k, = s[CC].op.reduce_axis

cfg.define_split("tile_y", y, num_outputs=3)
cfg.define_split("tile_x", x, num_outputs=3)
cfg.define_split("tile_k", k, num_outputs=2)
cfg.define_knob("auto_unroll_max_step", [8, 16, 32, 64])
target = tvm.target.Target.current()
if target.target_name in ['nvptx', 'rocm']:
# llvm-based backends cannot do non-explicit unrolling
cfg.define_knob("unroll_explicit", [1])
else:
cfg.define_knob("unroll_explicit", [0, 1])

if cfg.is_fallback:
y_bn = get_max_power2_factor(M, 64)
x_bn = get_max_power2_factor(N, 64)
y_nthreads = min(y_bn, 8)
x_nthreads = min(x_bn, 8)
cfg['tile_x'] = SplitEntity([-1, x_nthreads, x_bn // x_nthreads])
cfg['tile_y'] = SplitEntity([-1, y_nthreads, y_bn // y_nthreads])
cfg['tile_k'] = SplitEntity([-1, 8])
cfg['auto_unroll_max_step'] = OtherOptionEntity(16)

by, ty, yi = cfg["tile_y"].apply(s, C, y)
bx, tx, xi = cfg["tile_x"].apply(s, C, x)

thread_x = te.thread_axis("threadIdx.x")
thread_y = te.thread_axis("threadIdx.y")

s[C].reorder(b, by, bx, ty, tx, yi, xi)
s[C].bind(b, te.thread_axis("blockIdx.z"))
s[C].bind(by, te.thread_axis("blockIdx.y"))
s[C].bind(bx, te.thread_axis("blockIdx.x"))
s[C].bind(ty, thread_y)
s[C].bind(tx, thread_x)
s[C].pragma(yi, "auto_unroll_max_step", 16)
s[C].pragma(yi, "auto_unroll_max_step", cfg['auto_unroll_max_step'].val)
s[C].pragma(yi, 'unroll_explicit', cfg['unroll_explicit'].val)

s[CC].compute_at(s[C], tx)
_, yi, xi = s[CC].op.axis
k, = s[CC].op.reduce_axis
ko, ki = s[CC].split(k, 8)
ko, ki = cfg["tile_k"].apply(s, CC, k)
s[CC].reorder(ko, ki, yi, xi)
s[CC].pragma(ki, "auto_unroll_max_step", 16)
s[CC].pragma(ki, "auto_unroll_max_step", cfg['auto_unroll_max_step'].val)
s[CC].pragma(ki, 'unroll_explicit', cfg['unroll_explicit'].val)

s[AA].compute_at(s[CC], ko)
s[AL].compute_at(s[CC], ki)
s[BB].compute_at(s[CC], ko)
s[BL].compute_at(s[CC], ki)
_, y, k = s[AA].op.axis
ty, yi = s[AA].split(y, nparts=y_nthreads)
tx, ki = s[AA].split(k, nparts=x_nthreads)
ty, yi = s[AA].split(y, nparts=cfg["tile_y"].size[1])
tx, ki = s[AA].split(k, nparts=cfg["tile_x"].size[1])
s[AA].reorder(ty, tx, yi, ki)
s[AA].bind(ty, thread_y)
s[AA].bind(tx, thread_x)
s[AA].pragma(yi, "auto_unroll_max_step", 16)
s[AA].pragma(yi, "auto_unroll_max_step", cfg['auto_unroll_max_step'].val)
s[AA].pragma(yi, 'unroll_explicit', cfg['unroll_explicit'].val)

_, x, k = s[BB].op.axis
ty, xi = s[BB].split(x, nparts=y_nthreads)
tx, ki = s[BB].split(k, nparts=x_nthreads)
ty, xi = s[BB].split(x, nparts=cfg["tile_y"].size[1])
tx, ki = s[BB].split(k, nparts=cfg["tile_x"].size[1])
s[BB].bind(ty, thread_y)
s[BB].bind(tx, thread_x)
s[BB].reorder(ty, tx, xi, ki)
s[BB].pragma(xi, "auto_unroll_max_step", 16)
s[BB].pragma(xi, "auto_unroll_max_step", cfg['auto_unroll_max_step'].val)
s[BB].pragma(xi, 'unroll_explicit', cfg['unroll_explicit'].val)

def _callback(op):
if "batch_matmul" in op.tag:
_schedule(op)
_schedule(cfg, op)

traverse_inline(s, outs[0].op, _callback)
return s
Expand Down
2 changes: 1 addition & 1 deletion topi/tests/python/test_topi_batch_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
_batch_matmul_implement = {
"generic": (topi.nn.batch_matmul, topi.generic.schedule_batch_matmul),
"cpu": (topi.x86.batch_matmul, topi.x86.schedule_batch_matmul),
"gpu": (topi.nn.batch_matmul, topi.cuda.schedule_batch_matmul),
"gpu": (topi.cuda.batch_matmul, topi.cuda.schedule_batch_matmul),
}

def verify_batch_matmul(batch, M, N, K):
Expand Down

0 comments on commit 629ad8d

Please sign in to comment.