diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py
index 4585809f63e1..d0ad377203c9 100644
--- a/python/tvm/relay/op/strategy/x86.py
+++ b/python/tvm/relay/op/strategy/x86.py
@@ -23,9 +23,7 @@
 from tvm.auto_scheduler import is_auto_scheduler_enabled
 from tvm.meta_schedule import is_meta_schedule_enabled
 from tvm.relay.ty import is_dynamic
-from tvm.target import Target
 from tvm.te import SpecializedCondition
-from tvm.topi.x86.utils import target_has_vnni
 
 from .. import op as _op
 from .generic import *
@@ -618,7 +616,6 @@ def dense_pack_strategy_cpu(attrs, inputs, out_type, target):
 def batch_matmul_strategy_cpu(attrs, inputs, out_type, target):
     """batch_matmul x86 strategy"""
     strategy = _op.OpStrategy()
-    mcpu = Target.current().mcpu
     need_auto_scheduler_layout = is_auto_scheduler_enabled()
     need_meta_schedule_layout = is_meta_schedule_enabled()
 
@@ -626,16 +623,15 @@ def batch_matmul_strategy_cpu(attrs, inputs, out_type, target):
     if (
         not attrs.transpose_a
         and attrs.transpose_b
-        and target_has_vnni(mcpu)
         and inputs[0].dtype == "uint8"
         and inputs[1].dtype == "int8"
         and inputs[1].shape[-2] % 16 == 0
         and inputs[1].shape[-1] % 4 == 0
     ):
         strategy.add_implementation(
-            wrap_compute_batch_matmul(topi.x86.batch_matmul_vnni_compute, need_out_dtype=True),
-            wrap_topi_schedule(topi.x86.schedule_batch_matmul_vnni),
-            name="batch_matmul_vnni.x86",
+            wrap_compute_batch_matmul(topi.x86.batch_matmul_int8_compute, need_out_dtype=True),
+            wrap_topi_schedule(topi.x86.schedule_batch_matmul_int8),
+            name="batch_matmul_int8.x86",
             plevel=10,
         )
     elif is_dynamic(out_type) or need_auto_scheduler_layout or need_meta_schedule_layout:
diff --git a/python/tvm/topi/x86/batch_matmul.py b/python/tvm/topi/x86/batch_matmul.py
index 025f41660c9c..9f3bc2951524 100644
--- a/python/tvm/topi/x86/batch_matmul.py
+++ b/python/tvm/topi/x86/batch_matmul.py
@@ -15,6 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 # pylint: disable=invalid-name,too-many-locals,unused-variable
+# pylint: disable=unused-argument
 """x86 batch_matmul operators"""
 import tvm
 from tvm import autotvm, te
@@ -24,18 +25,24 @@
 from .. import generic, nn
 from ..transform import layout_transform
 from ..utils import get_const_tuple, get_max_power2_factor, traverse_inline
-from .dense import dense_vnni_schedule
+from .dense import dense_vnni_schedule, dense_amx_int8_schedule
 from .injective import schedule_injective_from_existing
+from .utils import target_has_vnni, target_has_amx
 
 
 @autotvm.register_topi_compute("batch_matmul_vnni.x86")
-def batch_matmul_vnni_compute(cfg, x, y, *_):
+def batch_matmul_int8_compute(cfg, x, y, *_):
     """Compute for uint8 x int8 -> int32 batch_matmul"""
     batch, m, k = x.shape
     packed_y_layout = "BNK16n4k"
     packed_y = layout_transform(y, "BNK", packed_y_layout)
     _, n_o, _, n_i, _ = packed_y.shape
     ak = te.reduce_axis((0, k), name="k")
+    mcpu = tvm.target.Target.current().mcpu
+    if target_has_vnni(mcpu):
+        attrs_info = {"schedule_rule": "batch_matmul_vnni"}
+    else:
+        attrs_info = None
 
     z = te.compute(
         (batch, m, n_o * n_i),
@@ -46,14 +53,10 @@ def batch_matmul_vnni_compute(cfg, x, y, *_):
             ),
             axis=ak,
         ),
-        tag="batch_matmul_vnni",
-        attrs={"schedule_rule": "batch_matmul_vnni"},
+        tag="batch_matmul_int8",
+        attrs=attrs_info,
     )
 
-    _, a_y, _ = z.op.axis
-    cfg.define_split("tile_y", a_y, num_outputs=2)
-    cfg.define_knob("layout_trans_compute_root", [0, 1])
-
     return z
 
 
@@ -67,6 +70,7 @@ def batch_matmul_vnni_schedule(cfg, s, C, O, layout_trans):
     # Parallelize over batch
     fused = s[O].fuse(O.op.axis[0], fused_inner)
     s[O].parallel(fused)
+    cfg.define_knob("layout_trans_compute_root", [0, 1])
 
     if cfg["layout_trans_compute_root"].val:
         s[layout_trans].compute_root()
@@ -80,6 +84,29 @@
     return s
 
 
+def batch_matmul_amx_schedule(cfg, s, C, O, layout_trans):
+    """Schedule batch_matmul compute using AMX tdpbusd instruction"""
+    # C: The output of batched GEMM
+    # O: The output of the fused op
+
+    # Schedule the GEMM part
+    s, fused_inner = dense_amx_int8_schedule(cfg, s, C, O, do_parallel=False)
+    # Parallelize over the outer loop
+    fused = s[O].fuse(O.op.axis[0], fused_inner)
+    s[O].parallel(fused)
+    cfg.define_knob("layout_trans_compute_root", [0, 1])
+
+    if cfg["layout_trans_compute_root"].val:
+        s[layout_trans].compute_root()
+        schedule_injective_from_existing(s, layout_trans)
+    else:
+        _, _, _, ni, ki = s[layout_trans].op.axis
+        s[layout_trans].vectorize(ki)
+        s[layout_trans].unroll(ni)
+
+    return s
+
+
 @autotvm.register_topi_compute("batch_matmul.x86")
 def batch_matmul(
     cfg, tensor_a, tensor_b, out_shape=None, out_dtype=None, transpose_a=False, transpose_b=True
@@ -202,14 +229,18 @@ def _callback(op):
 
 
 @autotvm.register_topi_schedule("batch_matmul_vnni.x86")
-def schedule_batch_matmul_vnni(cfg, outs):
+def schedule_batch_matmul_int8(cfg, outs):
     """Schedule for batch_matmul_vnni"""
     s = te.create_schedule([x.op for x in outs])
+    mcpu = tvm.target.Target.current().mcpu
 
     def _callback(op):
-        if "batch_matmul_vnni" in op.tag:
+        if "batch_matmul_int8" in op.tag:
             layout_trans = op.input_tensors[1]
-            batch_matmul_vnni_schedule(cfg, s, op.output(0), outs[0], layout_trans)
+            if target_has_amx(mcpu):
+                batch_matmul_amx_schedule(cfg, s, op.output(0), outs[0], layout_trans)
+            elif target_has_vnni(mcpu):
+                batch_matmul_vnni_schedule(cfg, s, op.output(0), outs[0], layout_trans)
 
     traverse_inline(s, outs[0].op, _callback)
     return s
diff --git a/python/tvm/topi/x86/dense.py b/python/tvm/topi/x86/dense.py
index ada19d598cdf..bb99a632811b 100644
--- a/python/tvm/topi/x86/dense.py
+++ b/python/tvm/topi/x86/dense.py
@@ -436,7 +436,7 @@ def split_k(out, rd_axis):
         cfg.define_split("tile_k", rd_axis, num_outputs=5, filter=lambda y: y.size[-1] == 128)
         return cfg["tile_k"].apply(s, out, rd_axis)
 
-    a_x, a_y = C.op.axis
+    a_x, a_y = C.op.axis[-2:]
     (a_k,) = C.op.reduce_axis
     CF = s.cache_write(C, "amx.tmm")
 
@@ -447,7 +447,7 @@ def split_k(out, rd_axis):
     s[CF].compute_at(s[C], a_yo)
 
     (a_k_f,) = CF.op.reduce_axis
-    a_x_f, a_y_f = CF.op.axis
+    a_x_f, a_y_f = CF.op.axis[-2:]
 
     a_xo_f, a_xi_f = s[CF].split(a_x_f, factor=32)
 
@@ -455,8 +455,8 @@ def split_k(out, rd_axis):
     a_k3_f, a_k2_f, a_k1_f, a_ko_f, a_ki_f = split_k(CF, a_k_f)
     s[CF].reorder(a_k3_f, a_k2_f, a_k1_f, a_ko_f, a_xo_f, a_yo_f, a_ki_f, a_xi_f, a_yi_f)
 
-    (m, k) = CF.op.input_tensors[0].shape
-    (n, c, n_i, c_i) = CF.op.input_tensors[1].shape
+    (m, k) = CF.op.input_tensors[0].shape[-2:]
+    (n, c, n_i, c_i) = CF.op.input_tensors[1].shape[-4:]
     n = n * n_i
 
     s[CF].tensorize(a_ki_f, dot_32x128x32_u8s8s32_sapphirerapids(LDA=int(k)))
@@ -479,19 +479,6 @@ def split_k(out, rd_axis):
     return s, fused
 
 
-@autotvm.register_topi_schedule("dense_amx_int8.x86")
-def schedule_dense_amx_int8(cfg, outs):
-    """Create a schedule for dense_amx_int8"""
-    s = te.create_schedule([x.op for x in outs])
-
-    def _callback(op):
-        if "dense_amx_int8" in op.tag:
-            dense_amx_int8_schedule(cfg, s, op.output(0), outs[0])
-
-    traverse_inline(s, outs[0].op, _callback)
-    return s
-
-
 def matmul_blas_common(cfg, tensor_a, tensor_b, bias, out_dtype, transpose_a, transpose_b, lib):
     """Compute matmul/dense using a BLAS library"""
     M, K = get_const_tuple(tensor_a.shape)
diff --git a/tests/python/relay/test_op_level10.py b/tests/python/relay/test_op_level10.py
index 619a0b5a9333..cdf4e734842b 100644
--- a/tests/python/relay/test_op_level10.py
+++ b/tests/python/relay/test_op_level10.py
@@ -520,6 +520,61 @@ def test_batch_matmul_vnni(b, m, n, k):
     np.testing.assert_equal(out, ref)
 
 
+@pytest.mark.skip("skip due to AMX feature not available yet")
+@pytest.mark.parametrize(
+    "b,m,n,k",
+    [
+        (16, 32, 32, 128),
+        (16, 32, 32, 127),
+        (16, 32, 31, 128),
+    ],
+)
+def test_batch_matmul_amx(b, m, n, k):
+    amx_init = tvm.get_global_func("runtime.amx_init")
+    amx_tileconfig = tvm.get_global_func("runtime.amx_tileconfig")
+    assert amx_init()
+    assert amx_tileconfig(16, 64)  # configure tile size to 16 rows by 64 columns.
+
+    x_shape = (b, m, k)
+    y_shape = (b, n, k)
+    z_shape = (b, m, n)
+
+    for lhs_dtype in ["uint8", "int8"]:
+        x = relay.var("x", shape=x_shape, dtype=lhs_dtype)
+        y = relay.var("y", shape=y_shape, dtype="int8")
+        z = relay.var("z", shape=z_shape, dtype="int32")
+        bmm = relay.nn.batch_matmul(x, y, out_dtype="int32")
+        out = bmm + z
+        mod = tvm.IRModule.from_expr(out)
+
+        target = "llvm -mcpu=sapphirerapids"
+        with tvm.transform.PassContext(opt_level=3):
+            lib = relay.build(mod, target=target)
+
+        asm = lib.lib.get_source("asm")
+        assert "tilezero" in asm
+        assert "tileloaddt1" in asm
+        assert "tdpbusd" in asm
+        assert "tilestored" in asm
+
+        dev = tvm.device(target, 0)
+        runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
+
+        x_np = np.random.uniform(1, 10, size=x_shape).astype(lhs_dtype)
+        y_np = np.random.uniform(1, 10, size=y_shape).astype("int8")
+        z_np = np.random.uniform(1, 10, size=z_shape).astype("int32")
+
+        runtime.set_input("x", x_np)
+        runtime.set_input("y", y_np)
+        runtime.set_input("z", z_np)
+        runtime.run()
+
+        out = runtime.get_output(0).numpy()
+        ref = tvm.topi.testing.batch_matmul(x_np, y_np, out_dtype="int32") + z_np
+
+        np.testing.assert_equal(out, ref)
+
+
 @pytest.mark.skip("Requires GFX10 AMDGPU")
 def test_batch_matmul_rocm_sdot4():
     x_shape = (16, 32, 96)
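
Note (not part of the patch): with this change the "batch_matmul_int8.x86" strategy is selected purely from dtypes and shapes, and schedule_batch_matmul_int8 then dispatches to the AMX or VNNI schedule based on Target.current().mcpu. Below is a minimal sketch of driving the VNNI path at the Relay level, analogous to the existing test_batch_matmul_vnni; the "llvm -mcpu=cascadelake" target string and the "vpdpbusd" assembly check are assumptions about a VNNI-capable target, not taken from this diff.

import tvm
from tvm import relay

# Shapes chosen to satisfy the strategy's int8 conditions:
# inputs[1].shape[-2] % 16 == 0 and inputs[1].shape[-1] % 4 == 0.
b, m, n, k = 16, 32, 32, 128
x = relay.var("x", shape=(b, m, k), dtype="uint8")
y = relay.var("y", shape=(b, n, k), dtype="int8")
mod = tvm.IRModule.from_expr(relay.nn.batch_matmul(x, y, out_dtype="int32"))

# Assumed VNNI-capable target; on Sapphire Rapids ("llvm -mcpu=sapphirerapids")
# the same op would instead be lowered through batch_matmul_amx_schedule.
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm -mcpu=cascadelake")

# The VNNI schedule should tensorize to vpdpbusd, mirroring the tile* checks
# in the new AMX test.
assert "vpdpbusd" in lib.lib.get_source("asm")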