[CUDA] dense_tensorcore/batch_matmul_tensorcore support int8/int4 (apache#8402)

* add int8/int4 tensorcore for dense/batch_matmul

* fix bug

* fix lint

* Apply suggestions from code review

Co-authored-by: Chenfan <jcf94@outlook.com>

* fix for reviewer

* fix lint

Co-authored-by: Chenfan <jcf94@outlook.com>
2 people authored and ylc committed Jan 13, 2022
1 parent 7fcbc6a commit 21838ff
Showing 9 changed files with 392 additions and 201 deletions.
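As context for the diffs below, here is a minimal Relay-level sketch of the kind of workload this commit enables; the shapes, variable names, and the int32 accumulator are illustrative assumptions, not taken from the patch. It builds an int8 dense whose (batch, in_dim, out_dim) is a multiple of (16, 16, 16), which the updated strategies can lower through the tensorcore schedules on a tensor-core-capable GPU.

# Illustrative sketch only; shapes and names are assumptions, not part of this commit.
import tvm
from tvm import relay

data = relay.var("data", shape=(16, 64), dtype="int8")      # (batch, in_dim)
weight = relay.var("weight", shape=(32, 64), dtype="int8")  # (out_dim, in_dim)
# int8 tensorcore matmuls accumulate into a wider dtype
out = relay.nn.dense(data, weight, out_dtype="int32")
mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))
# On a tensor-core-capable GPU, relay.build(mod, target="cuda") can now pick
# the dense_tensorcore.cuda implementation for this workload.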
15 changes: 9 additions & 6 deletions python/tvm/relay/op/strategy/cuda.py
@@ -844,13 +844,16 @@ def batch_matmul_strategy_cuda(attrs, inputs, out_type, target):
x, y = inputs
_, M, K = get_const_tuple(x.shape)
_, N, K = get_const_tuple(y.shape)
if x.dtype in ["float16", "int8", "uint8"] and (
(M % 8 == 0 and K % 16 == 0 and N % 32 == 0)
or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0)
or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0)
):
if (
x.dtype in ["float16", "int8", "uint8"]
and (
(M % 8 == 0 and K % 16 == 0 and N % 32 == 0)
or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0)
or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0)
)
) or (x.dtype in ["int4", "uint4"] and K % 32 == 0 and M % 8 == 0 and N % 8 == 0):
strategy.add_implementation(
wrap_compute_batch_matmul(topi.cuda.batch_matmul_tensorcore),
wrap_compute_batch_matmul(topi.cuda.batch_matmul_tensorcore, need_out_dtype=True),
wrap_topi_schedule(topi.cuda.schedule_batch_matmul_tensorcore),
name="batch_matmul_tensorcore.cuda",
plevel=20,
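Restated outside the diff, the new dispatch condition reduces to the following standalone predicate (a sketch for clarity; can_use_batch_matmul_tensorcore is a hypothetical helper name, not a function in TVM):

def can_use_batch_matmul_tensorcore(dtype, M, K, N):
    """Hypothetical restatement of the updated batch_matmul strategy check."""
    if dtype in ["float16", "int8", "uint8"]:
        # fp16/int8 fragments: (M, K, N) multiples of (8, 16, 32), (16, 16, 16) or (32, 16, 8)
        return (
            (M % 8 == 0 and K % 16 == 0 and N % 32 == 0)
            or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0)
            or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0)
        )
    if dtype in ["int4", "uint4"]:
        # int4/uint4 only have the 8x8x32 fragment shape
        return M % 8 == 0 and K % 32 == 0 and N % 8 == 0
    return False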
83 changes: 44 additions & 39 deletions python/tvm/topi/cuda/batch_matmul_tensorcore.py
@@ -29,10 +29,10 @@


@autotvm.register_topi_compute("batch_matmul_tensorcore.cuda")
def batch_matmul_tensorcore(cfg, x, y, out_shape=None):
def batch_matmul_tensorcore(cfg, x, y, out_shape=None, out_dtype=None):
"""batch matmul tensorcore operator on cuda"""
# todo: deal with out_shape for broadcast, liuxin.ai
return batch_matmul_tensorcore_cuda(x, y)
return batch_matmul_tensorcore_cuda(x, y, out_dtype)


@autotvm.register_topi_schedule("batch_matmul_tensorcore.cuda")
@@ -57,10 +57,8 @@ def _schedule(cfg, s, C):
A, B = s[C].op.input_tensors
batch, m_dim, k_dim = get_const_tuple(A.shape)
batch, n_dim, k_dim = get_const_tuple(B.shape)
data_dtype = A.dtype
out_dtype = C.dtype
# inline astype fp16
s[A].compute_inline()
s[B].compute_inline()

# Explicit memory access
AS = s.cache_read(A, "shared", [C])
@@ -94,32 +92,37 @@ def _schedule(cfg, s, C):
cfg.define_knob("vec", [1, 2, 4, 8])

# Ensure that the default parameters are applicable when autotvm is not in use
if m_dim % 32 == 0 and n_dim % 8 == 0:
cfg.define_knob("wmma_m", [32, 16, 8])
elif m_dim % 16 == 0 and n_dim % 16 == 0:
cfg.define_knob("wmma_m", [16, 8, 32])
elif m_dim % 8 == 0 and n_dim % 32 == 0:
cfg.define_knob("wmma_m", [8, 16, 32])
if data_dtype in ["float16", "uint8", "int8"]:
if m_dim % 32 == 0 and n_dim % 8 == 0:
cfg.define_knob("wmma_m", [32, 16, 8])
elif m_dim % 16 == 0 and n_dim % 16 == 0:
cfg.define_knob("wmma_m", [16, 8, 32])
elif m_dim % 8 == 0 and n_dim % 32 == 0:
cfg.define_knob("wmma_m", [8, 16, 32])
wmma_k = 16
wmma_m = cfg["wmma_m"].val
if wmma_m == 16:
wmma_n = 16
elif wmma_m == 8:
wmma_n = 32
elif wmma_m == 32:
wmma_n = 8
elif data_dtype in ["int4", "uint4"]:
wmma_m = wmma_n = 8
wmma_k = 32
else:
raise ValueError("data dtype %s is not yet supported" % data_dtype)

warp_size = 32
wmma_k = 16
block_row_warps = cfg["block_row_warps"].val
block_col_warps = cfg["block_col_warps"].val
warp_row_tiles = cfg["warp_row_tiles"].val
warp_col_tiles = cfg["warp_col_tiles"].val
chunk = cfg["chunk"].val
offset = cfg["offset"].val
offsetCS = cfg["offsetCS"].val
wmma_m = cfg["wmma_m"].val
vec = cfg["vec"].val

if wmma_m == 16:
wmma_n = 16
elif wmma_m == 8:
wmma_n = 32
elif wmma_m == 32:
wmma_n = 8

# Define the stride of intrin functions
AS_align = chunk * wmma_k + offset
BS_align = chunk * wmma_k + offset
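The dtype-dependent knob setup above amounts to a small mapping from data type to the WMMA fragment shapes the hardware exposes; a hedged restatement follows (the helper name and return format are mine, not the schedule's):

def wmma_shape_candidates(data_dtype):
    """Hypothetical helper: legal (wmma_m, wmma_n, wmma_k) fragments per dtype."""
    if data_dtype in ["float16", "uint8", "int8"]:
        # wmma_k is fixed at 16; the chosen wmma_m determines wmma_n
        return [(16, 16, 16), (8, 32, 16), (32, 8, 16)]
    if data_dtype in ["int4", "uint4"]:
        # int4 tensor cores support only the 8x8x32 fragment
        return [(8, 8, 32)]
    raise ValueError("data dtype %s is not yet supported" % data_dtype)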
@@ -211,10 +214,8 @@ def shared_shedule(stage, strides):
shared_shedule(BS, BS_align)

shape = (wmma_m, wmma_n, wmma_k)
# TODO: add checking here, datatype casting may cause precision loss
in_dtype = "float16"
AL_gemm = te.placeholder((wmma_m, wmma_k), name="AL_gemm", dtype=in_dtype)
BL_gemm = te.placeholder((wmma_n, wmma_k), name="BL_gemm", dtype=in_dtype)
AL_gemm = te.placeholder((wmma_m, wmma_k), name="AL_gemm", dtype=data_dtype)
BL_gemm = te.placeholder((wmma_n, wmma_k), name="BL_gemm", dtype=data_dtype)
k_gemm = te.reduce_axis((0, wmma_k), name="k_gemm")
CL_compute = te.compute(
(wmma_m, wmma_n),
@@ -236,7 +237,7 @@ def shared_shedule(stage, strides):
"row_major",
(wmma_m, wmma_k),
(wmma_m, wmma_k),
"float16",
data_dtype,
),
)
s[BF].tensorize(
Expand All @@ -248,7 +249,7 @@ def shared_shedule(stage, strides):
"col_major",
(wmma_n, wmma_k),
(wmma_n, wmma_k),
"float16",
data_dtype,
),
)
s[CF].tensorize(
Expand All @@ -270,7 +271,7 @@ def _callback(op):
return s


def batch_matmul_tensorcore_cuda(x, y):
def batch_matmul_tensorcore_cuda(x, y, out_dtype=None):
"""Computes batch matrix multiplication of `x` and `y` when `x` and `y` are
data in batch.
@@ -294,22 +295,26 @@ def batch_matmul_tensorcore_cuda(x, y):
assert x_shape[2] == y_shape[2], "shapes of x and y is inconsistent"
batch, M, K = x.shape
N = y.shape[1]
out_dtype = x.dtype

assert (
(M % 8 == 0 and K % 16 == 0 and N % 32 == 0)
or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0)
or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0)
), "The shape of (M, K, N) must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32)"

x_16 = te.compute((batch, M, K), lambda b, i, k: x[b, i, k].astype("float16"))
y_16 = te.compute((batch, N, K), lambda b, j, k: y[b, j, k].astype("float16"))
if out_dtype is None:
out_dtype = x.dtype

assert x.dtype == y.dtype
assert x.dtype in ["float16", "uint8", "int8", "uint4", "int4"]
if x.dtype in ["float16", "uint8", "int8"]:
assert (
(M % 8 == 0 and K % 16 == 0 and N % 32 == 0)
or (M % 16 == 0 and K % 16 == 0 and N % 16 == 0)
or (M % 32 == 0 and K % 16 == 0 and N % 8 == 0)
), "The shape of (M, K, N) must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32)"
else:
assert (
M % 8 == 0 and K % 32 == 0 and N % 8 == 0
), "The shape of (M, K, N) must be multiple of (8, 32, 8)"

k = te.reduce_axis((0, K), name="k")
return te.compute(
(batch, M, N),
lambda b, i, j: te.sum(
x_16[b, i, k].astype(out_dtype) * y_16[b, j, k].astype(out_dtype), axis=k
),
lambda b, i, j: te.sum(x[b, i, k].astype(out_dtype) * y[b, j, k].astype(out_dtype), axis=k),
tag="batch_matmul_tensorcore",
)
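As a concrete (hypothetical) instantiation of the reworked compute, an int4 batch matmul at the smallest legal tile sizes could be built as follows; the batch size and the int32 accumulator are illustrative assumptions:

from tvm import te
from tvm.topi.cuda.batch_matmul_tensorcore import batch_matmul_tensorcore_cuda

# Smallest int4 tile: (M, K, N) must be multiples of (8, 32, 8)
x = te.placeholder((4, 8, 32), name="x", dtype="int4")  # (batch, M, K)
y = te.placeholder((4, 8, 32), name="y", dtype="int4")  # (batch, N, K)
# The compute now casts x and y to out_dtype inside the reduction instead of
# going through an intermediate float16 stage.
out = batch_matmul_tensorcore_cuda(x, y, out_dtype="int32")  # shape (4, 8, 8)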
81 changes: 44 additions & 37 deletions python/tvm/topi/cuda/dense_tensorcore.py
@@ -60,22 +60,27 @@ def dense_tensorcore_cuda(data, weight, bias=None, out_dtype=None):
out_dtype = data.dtype
batch, in_dim = get_const_tuple(data.shape)
out_dim, _ = get_const_tuple(weight.shape)
assert (
(batch % 8 == 0 and in_dim % 16 == 0 and out_dim % 32 == 0)
or (batch % 16 == 0 and in_dim % 16 == 0 and out_dim % 16 == 0)
or (batch % 32 == 0 and in_dim % 16 == 0 and out_dim % 8 == 0)
), (
"The shape of (batch, in_dim, out_dim) "
"must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for now"
)

assert data.dtype == weight.dtype
assert data.dtype in ["float16", "int8", "uint8", "int4", "uint4"]
if data.dtype in ["float16", "int8", "uint8"]:
assert (
(batch % 8 == 0 and in_dim % 16 == 0 and out_dim % 32 == 0)
or (batch % 16 == 0 and in_dim % 16 == 0 and out_dim % 16 == 0)
or (batch % 32 == 0 and in_dim % 16 == 0 and out_dim % 8 == 0)
), (
"The shape of (batch, in_dim, out_dim) "
"must be multiple of (16, 16, 16) or (32, 16, 8) or (8, 16, 32) for now"
)
else:
assert (
batch % 8 == 0 and in_dim % 32 == 0 and out_dim % 8 == 0
), "The shape of (batch, in_dim, out_dim) must be multiple of (8, 32, 8)"

k = te.reduce_axis((0, in_dim), name="k")
data_16 = te.compute((batch, in_dim), lambda b, i: data[b, i].astype("float16"))
weight_16 = te.compute((out_dim, in_dim), lambda o, i: weight[o, i].astype("float16"))
matmul = te.compute(
(batch, out_dim),
lambda i, j: te.sum(
data_16[i, k].astype(out_dtype) * weight_16[j, k].astype(out_dtype), axis=k
),
lambda i, j: te.sum(data[i, k].astype(out_dtype) * weight[j, k].astype(out_dtype), axis=k),
name="T_dense",
tag="dense_tensorcore",
)
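The dense path admits the same kind of minimal sketch (shapes and the int32 accumulator are illustrative assumptions) for a uint8 workload that satisfies the (16, 16, 16) multiple:

from tvm import te
from tvm.topi.cuda.dense_tensorcore import dense_tensorcore_cuda

data = te.placeholder((16, 32), name="data", dtype="uint8")      # (batch, in_dim)
weight = te.placeholder((16, 32), name="weight", dtype="uint8")  # (out_dim, in_dim)
out = dense_tensorcore_cuda(data, weight, out_dtype="int32")     # (batch, out_dim) = (16, 16)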
Expand All @@ -92,9 +97,8 @@ def _schedule_dense_tensorcore(cfg, s, C):
"""Schedule dense operator using Tensorcore"""
A, B = s[C].op.input_tensors
batch, out_dim = get_const_tuple(C.shape)
data_dtype = A.dtype
out_dtype = C.dtype
s[A].compute_inline()
s[B].compute_inline()

# Explicit memory access
AS = s.cache_read(A, "shared", [C])
@@ -127,33 +131,38 @@ def _schedule_dense_tensorcore(cfg, s, C):
cfg.define_knob("offsetCS", [0, 8])
cfg.define_knob("vec", [1, 2, 4, 8])

# Ensure that the default parameters are applicable when autotvm is not in use
if batch % 32 == 0 and out_dim % 8 == 0:
cfg.define_knob("wmma_m", [32, 16, 8])
elif batch % 16 == 0 and out_dim % 16 == 0:
cfg.define_knob("wmma_m", [16, 8, 32])
elif batch % 8 == 0 and out_dim % 32 == 0:
cfg.define_knob("wmma_m", [8, 16, 32])
if data_dtype in ["float16", "int8", "uint8"]:
# Ensure that the default parameters are applicable when autotvm is not in use
if batch % 32 == 0 and out_dim % 8 == 0:
cfg.define_knob("wmma_m", [32, 16, 8])
elif batch % 16 == 0 and out_dim % 16 == 0:
cfg.define_knob("wmma_m", [16, 8, 32])
elif batch % 8 == 0 and out_dim % 32 == 0:
cfg.define_knob("wmma_m", [8, 16, 32])
wmma_k = 16
wmma_m = cfg["wmma_m"].val
if wmma_m == 16:
wmma_n = 16
elif wmma_m == 8:
wmma_n = 32
elif wmma_m == 32:
wmma_n = 8
elif data_dtype in ["int4", "uint4"]:
wmma_m = wmma_n = 8
wmma_k = 32
else:
raise ValueError("data dtype %s is not yet supported" % data_dtype)

warp_size = 32
wmma_k = 16
block_row_warps = cfg["block_row_warps"].val
block_col_warps = cfg["block_col_warps"].val
warp_row_tiles = cfg["warp_row_tiles"].val
warp_col_tiles = cfg["warp_col_tiles"].val
chunk = cfg["chunk"].val
offset = cfg["offset"].val
offsetCS = cfg["offsetCS"].val
wmma_m = cfg["wmma_m"].val
vec = cfg["vec"].val

if wmma_m == 16:
wmma_n = 16
elif wmma_m == 8:
wmma_n = 32
elif wmma_m == 32:
wmma_n = 8

# Define the stride of intrin functions
AS_align = chunk * wmma_k + offset
BS_align = chunk * wmma_k + offset
@@ -245,10 +254,8 @@ def shared_shedule(stage, strides):
shared_shedule(BS, BS_align)

shape = (wmma_m, wmma_n, wmma_k)
# TODO: add checking here, datatype casting may cause precision loss
in_dtype = "float16"
AL_gemm = te.placeholder((wmma_m, wmma_k), name="AL_gemm", dtype=in_dtype)
BL_gemm = te.placeholder((wmma_n, wmma_k), name="BL_gemm", dtype=in_dtype)
AL_gemm = te.placeholder((wmma_m, wmma_k), name="AL_gemm", dtype=data_dtype)
BL_gemm = te.placeholder((wmma_n, wmma_k), name="BL_gemm", dtype=data_dtype)
k_gemm = te.reduce_axis((0, wmma_k), name="k_gemm")
CL_compute = te.compute(
(wmma_m, wmma_n),
@@ -264,13 +271,13 @@ def shared_shedule(stage, strides):
s[AF].tensorize(
b_ii,
intrin_wmma_load_matrix_A(
AF_stride, AS_stride, shape, "row_major", (wmma_m, wmma_k), (wmma_m, wmma_k), "float16"
AF_stride, AS_stride, shape, "row_major", (wmma_m, wmma_k), (wmma_m, wmma_k), data_dtype
),
)
s[BF].tensorize(
o_ii,
intrin_wmma_load_matrix_W(
BF_stride, BS_stride, shape, "col_major", (wmma_n, wmma_k), (wmma_n, wmma_k), "float16"
BF_stride, BS_stride, shape, "col_major", (wmma_n, wmma_k), (wmma_n, wmma_k), data_dtype
),
)
s[CF].tensorize(