[Tensorize][TOPI] Add AMX Tensorizing for int8 batch matmul #13745

Merged 6 commits on Jan 11, 2023
10 changes: 3 additions & 7 deletions python/tvm/relay/op/strategy/x86.py
@@ -23,9 +23,7 @@
from tvm.auto_scheduler import is_auto_scheduler_enabled
from tvm.meta_schedule import is_meta_schedule_enabled
from tvm.relay.ty import is_dynamic
from tvm.target import Target
from tvm.te import SpecializedCondition
from tvm.topi.x86.utils import target_has_vnni

from .. import op as _op
from .generic import *
@@ -618,24 +616,22 @@ def batch_matmul_strategy_cpu(attrs, inputs, out_type, target):
def batch_matmul_strategy_cpu(attrs, inputs, out_type, target):
"""batch_matmul x86 strategy"""
strategy = _op.OpStrategy()
mcpu = Target.current().mcpu

need_auto_scheduler_layout = is_auto_scheduler_enabled()
need_meta_schedule_layout = is_meta_schedule_enabled()

if (
not attrs.transpose_a
and attrs.transpose_b
and target_has_vnni(mcpu)
and inputs[0].dtype == "uint8"
and inputs[1].dtype == "int8"
and inputs[1].shape[-2] % 16 == 0
and inputs[1].shape[-1] % 4 == 0
):
strategy.add_implementation(
wrap_compute_batch_matmul(topi.x86.batch_matmul_vnni_compute, need_out_dtype=True),
wrap_topi_schedule(topi.x86.schedule_batch_matmul_vnni),
name="batch_matmul_vnni.x86",
wrap_compute_batch_matmul(topi.x86.batch_matmul_int8_compute, need_out_dtype=True),
wrap_topi_schedule(topi.x86.schedule_batch_matmul_int8),
name="batch_matmul_int8.x86",
plevel=10,
)
elif is_dynamic(out_type) or need_auto_scheduler_layout or need_meta_schedule_layout:
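For illustration only (not part of this patch): a minimal Relay program that satisfies the dtype and shape guards above, so the `batch_matmul_int8.x86` implementation becomes eligible. The shapes and the `-mcpu` value are assumptions for a VNNI/AMX-capable CPU.

```python
import tvm
from tvm import relay

# Shapes chosen so that N % 16 == 0 and K % 4 == 0, matching the guard above.
x = relay.var("x", shape=(8, 32, 64), dtype="uint8")
y = relay.var("y", shape=(8, 48, 64), dtype="int8")
bmm = relay.nn.batch_matmul(x, y, out_dtype="int32")  # transpose_a=False, transpose_b=True by default
mod = tvm.IRModule.from_expr(bmm)

# Example target only; any x86 CPU with VNNI or AMX should select the int8 strategy.
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm -mcpu=sapphirerapids")
```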
53 changes: 42 additions & 11 deletions python/tvm/topi/x86/batch_matmul.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name,too-many-locals,unused-variable
# pylint: disable=unused-argument
"""x86 batch_matmul operators"""
import tvm
from tvm import autotvm, te
@@ -24,18 +25,24 @@
from .. import generic, nn
from ..transform import layout_transform
from ..utils import get_const_tuple, get_max_power2_factor, traverse_inline
from .dense import dense_vnni_schedule
from .dense import dense_vnni_schedule, dense_amx_int8_schedule
from .injective import schedule_injective_from_existing
from .utils import target_has_vnni, target_has_amx


@autotvm.register_topi_compute("batch_matmul_vnni.x86")
def batch_matmul_vnni_compute(cfg, x, y, *_):
def batch_matmul_int8_compute(cfg, x, y, *_):
"""Compute for uint8 x int8 -> int32 batch_matmul"""
batch, m, k = x.shape
packed_y_layout = "BNK16n4k"
packed_y = layout_transform(y, "BNK", packed_y_layout)
_, n_o, _, n_i, _ = packed_y.shape
ak = te.reduce_axis((0, k), name="k")
mcpu = tvm.target.Target.current().mcpu
if target_has_vnni(mcpu):
attrs_info = {"schedule_rule": "batch_matmul_vnni"}
else:
attrs_info = None

z = te.compute(
(batch, m, n_o * n_i),
@@ -46,14 +53,10 @@ def batch_matmul_vnni_compute(cfg, x, y, *_):
),
axis=ak,
),
tag="batch_matmul_vnni",
attrs={"schedule_rule": "batch_matmul_vnni"},
tag="batch_matmul_int8",
attrs=attrs_info,
)

_, a_y, _ = z.op.axis
cfg.define_split("tile_y", a_y, num_outputs=2)
cfg.define_knob("layout_trans_compute_root", [0, 1])

return z
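A NumPy sketch of what the "BNK16n4k" packing amounts to, inferred from the layout string (the real packing is produced by `layout_transform`; N % 16 == 0 and K % 4 == 0 is assumed):

```python
import numpy as np

B, N, K = 2, 32, 64
y = np.random.randint(-128, 127, size=(B, N, K), dtype=np.int8)

# Split N into (N // 16, 16) and K into (K // 4, 4), ordered as B, N-outer, K-outer, 16n, 4k.
packed_y = y.reshape(B, N // 16, 16, K // 4, 4).transpose(0, 1, 3, 2, 4)

# packed_y[b, n // 16, k // 4, n % 16, k % 4] == y[b, n, k]
assert packed_y[1, 1, 3, 2, 1] == y[1, 16 + 2, 12 + 1]
```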


@@ -67,6 +70,7 @@ def batch_matmul_vnni_schedule(cfg, s, C, O, layout_trans):
# Parallelize over batch
fused = s[O].fuse(O.op.axis[0], fused_inner)
s[O].parallel(fused)
cfg.define_knob("layout_trans_compute_root", [0, 1])

if cfg["layout_trans_compute_root"].val:
s[layout_trans].compute_root()
@@ -80,6 +84,29 @@ def batch_matmul_vnni_schedule(cfg, s, C, O, layout_trans):
return s


def batch_matmul_amx_schedule(cfg, s, C, O, layout_trans):
"""Schedule batch_matmul compute using AMX tdpbusd instruction"""
# C: The output of batched GEMM
# O: The output of the fused op

# Schedule the GEMM part
s, fused_inner = dense_amx_int8_schedule(cfg, s, C, O, do_parallel=False)
# Parallelize over the outer loop
fused = s[O].fuse(O.op.axis[0], fused_inner)
s[O].parallel(fused)
cfg.define_knob("layout_trans_compute_root", [0, 1])

if cfg["layout_trans_compute_root"].val:
s[layout_trans].compute_root()
schedule_injective_from_existing(s, layout_trans)
else:
_, _, _, ni, ki = s[layout_trans].op.axis
s[layout_trans].vectorize(ki)
s[layout_trans].unroll(ni)

return s


@autotvm.register_topi_compute("batch_matmul.x86")
def batch_matmul(
cfg, tensor_a, tensor_b, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True
@@ -202,14 +229,18 @@ def _callback(op):


@autotvm.register_topi_schedule("batch_matmul_vnni.x86")
def schedule_batch_matmul_vnni(cfg, outs):
def schedule_batch_matmul_int8(cfg, outs):
"""Schedule for batch_matmul_vnni"""
s = te.create_schedule([x.op for x in outs])
mcpu = tvm.target.Target.current().mcpu

def _callback(op):
if "batch_matmul_vnni" in op.tag:
if "batch_matmul_int8" in op.tag:
layout_trans = op.input_tensors[1]
batch_matmul_vnni_schedule(cfg, s, op.output(0), outs[0], layout_trans)
if target_has_amx(mcpu):
batch_matmul_amx_schedule(cfg, s, op.output(0), outs[0], layout_trans)
elif target_has_vnni(mcpu):
batch_matmul_vnni_schedule(cfg, s, op.output(0), outs[0], layout_trans)

traverse_inline(s, outs[0].op, _callback)
return s
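The schedule entry point now dispatches on the target's mcpu, preferring the AMX schedule and falling back to VNNI. A quick way to check which path a given CPU would take, using the helpers imported above (the mcpu strings are examples):

```python
from tvm.topi.x86.utils import target_has_amx, target_has_vnni

for mcpu in ["sapphirerapids", "cascadelake", "skylake-avx512"]:
    path = "amx" if target_has_amx(mcpu) else "vnni" if target_has_vnni(mcpu) else "generic"
    print(mcpu, "->", path)
```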
21 changes: 4 additions & 17 deletions python/tvm/topi/x86/dense.py
@@ -436,7 +436,7 @@ def split_k(out, rd_axis):
cfg.define_split("tile_k", rd_axis, num_outputs=5, filter=lambda y: y.size[-1] == 128)
return cfg["tile_k"].apply(s, out, rd_axis)

a_x, a_y = C.op.axis
a_x, a_y = C.op.axis[-2:]
(a_k,) = C.op.reduce_axis
CF = s.cache_write(C, "amx.tmm")

@@ -447,16 +447,16 @@ def split_k(out, rd_axis):
s[CF].compute_at(s[C], a_yo)

(a_k_f,) = CF.op.reduce_axis
a_x_f, a_y_f = CF.op.axis
a_x_f, a_y_f = CF.op.axis[-2:]

a_xo_f, a_xi_f = s[CF].split(a_x_f, factor=32)

a_yo_f, a_yi_f = s[CF].split(a_y_f, factor=32)
a_k3_f, a_k2_f, a_k1_f, a_ko_f, a_ki_f = split_k(CF, a_k_f)
s[CF].reorder(a_k3_f, a_k2_f, a_k1_f, a_ko_f, a_xo_f, a_yo_f, a_ki_f, a_xi_f, a_yi_f)

(m, k) = CF.op.input_tensors[0].shape
(n, c, n_i, c_i) = CF.op.input_tensors[1].shape
(m, k) = CF.op.input_tensors[0].shape[-2:]
(n, c, n_i, c_i) = CF.op.input_tensors[1].shape[-4:]
n = n * n_i

s[CF].tensorize(a_ki_f, dot_32x128x32_u8s8s32_sapphirerapids(LDA=int(k)))
@@ -479,19 +479,6 @@ def split_k(out, rd_axis):
return s, fused


@autotvm.register_topi_schedule("dense_amx_int8.x86")
def schedule_dense_amx_int8(cfg, outs):
"""Create a schedule for dense_amx_int8"""
s = te.create_schedule([x.op for x in outs])

def _callback(op):
if "dense_amx_int8" in op.tag:
dense_amx_int8_schedule(cfg, s, op.output(0), outs[0])

traverse_inline(s, outs[0].op, _callback)
return s


def matmul_blas_common(cfg, tensor_a, tensor_b, bias, out_dtype, transpose_a, transpose_b, lib):
"""Compute matmul/dense using a BLAS library"""
M, K = get_const_tuple(tensor_a.shape)
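For context, the split factors above (32 for both spatial axes, 128 for the reduction) match the name of the tensorized intrinsic, `dot_32x128x32_u8s8s32_sapphirerapids`. A NumPy sketch of the math one micro-kernel invocation covers, ignoring the packed weight layout the intrinsic actually consumes (an illustrative assumption based on the name):

```python
import numpy as np

A = np.random.randint(0, 256, size=(32, 128), dtype=np.uint8)    # uint8 activation tile
W = np.random.randint(-128, 128, size=(128, 32), dtype=np.int8)  # int8 weight tile
C_tile = A.astype(np.int32) @ W.astype(np.int32)                 # (32, 32) int32 accumulator tile
```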
55 changes: 55 additions & 0 deletions tests/python/relay/test_op_level10.py
@@ -520,6 +520,61 @@ def test_batch_matmul_vnni(b, m, n, k):
np.testing.assert_equal(out, ref)


@pytest.mark.skip("skip due to AMX feature not available yet")
@pytest.mark.parametrize(
"b,m,n,k",
[
(16, 32, 32, 128),
(16, 32, 32, 127),
(16, 32, 31, 128),
],
)
def test_batch_matmul_amx(b, m, n, k):
amx_init = tvm.get_global_func("runtime.amx_init")
amx_tileconfig = tvm.get_global_func("runtime.amx_tileconfig")
assert amx_init()
assert amx_tileconfig(16, 64) # config tile size to 16 rows by 64 columns.

x_shape = (b, m, k)
y_shape = (b, n, k)
z_shape = (b, m, n)

for lhs_dtype in ["uint8", "int8"]:
x = relay.var("x", shape=x_shape, dtype=lhs_dtype)
y = relay.var("y", shape=y_shape, dtype="int8")
z = relay.var("z", shape=z_shape, dtype="int32")
bmm = relay.nn.batch_matmul(x, y, out_dtype="int32")
out = bmm + z
mod = tvm.IRModule.from_expr(out)

target = "llvm -mcpu=sapphirerapids"
with tvm.transform.PassContext(opt_level=3):
lib = relay.build(mod, target=target)

asm = lib.lib.get_source("asm")
assert "tilezero" in asm
assert "tileloaddt1" in asm
assert "tdpbusd" in asm
assert "tilestored" in asm

dev = tvm.device(target, 0)
runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))

x_np = np.random.uniform(1, 10, size=x_shape).astype(lhs_dtype)
y_np = np.random.uniform(1, 10, size=y_shape).astype("int8")
z_np = np.random.uniform(1, 10, size=z_shape).astype("int32")

runtime.set_input("x", x_np)
runtime.set_input("y", y_np)
runtime.set_input("z", z_np)
runtime.run()

out = runtime.get_output(0).numpy()
ref = tvm.topi.testing.batch_matmul(x_np, y_np, out_dtype="int32") + z_np

np.testing.assert_equal(out, ref)


@pytest.mark.skip("Requires GFX10 AMDGPU")
def test_batch_matmul_rocm_sdot4():
x_shape = (16, 32, 96)