From a5740bcf5287097b64dff8adb50f0cddc2c41349 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 31 Oct 2021 19:36:53 +0900 Subject: [PATCH] partitioning looks good --- python/tvm/contrib/cutlass/build.py | 10 +++++++--- python/tvm/contrib/cutlass/gen_gemm.py | 2 +- python/tvm/relay/op/contrib/cutlass.py | 6 +++--- tests/python/contrib/test_cutlass.py | 6 +++++- 4 files changed, 16 insertions(+), 8 deletions(-) diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index b33d87505ecd..5e68570e504f 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -111,10 +111,14 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t NN, KK, annotator.signature["ret_dtype"], - profile_all, - use_multiprocessing, batched=True, + profile_all=profile_all, + use_multiprocessing=use_multiprocessing, ) + new_attrs["batch"] = arg0_shape[0] + new_attrs["batch_stride_A"] = arg0_shape[1] * arg0_shape[2] + new_attrs["batch_stride_B"] = arg1_shape[1] * arg1_shape[2] + new_attrs["batch_stride_C"] = arg0_shape[1] * arg1_shape[2] else: MM = arg0_shape[0] KK = arg0_shape[1] @@ -122,7 +126,7 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t out = cutlass_profiler.profile( MM, NN, KK, annotator.signature["ret_dtype"], profile_all, use_multiprocessing ) - if new_attrs["op_type"] == "cutlass.dense": + if new_attrs["op_type"] in ["cutlass.dense", "cutlass.batch_matmul"]: new_attrs["cutlass_op_def"] = out["opdef"] elif new_attrs["op_type"] == "cutlass.dense_bias": new_attrs["cutlass_op_def"] = out["opdef_bias"] diff --git a/python/tvm/contrib/cutlass/gen_gemm.py b/python/tvm/contrib/cutlass/gen_gemm.py index 44b57d29f46a..53fed5828c9c 100644 --- a/python/tvm/contrib/cutlass/gen_gemm.py +++ b/python/tvm/contrib/cutlass/gen_gemm.py @@ -141,7 +141,7 @@ def generate_tensor_op_common(math_instructions, alignment_constraints, get_tile math_inst.element_accumulator, ] - out = create_gemm_operator(layouts, tile_descriptions, data_type, alignment_constraints, batched) + out = create_gemm_operator(layouts, tile_descriptions, data_type, alignment_constraints, batched=batched) ops.extend(out) diff --git a/python/tvm/relay/op/contrib/cutlass.py b/python/tvm/relay/op/contrib/cutlass.py index 44821293e70b..1d6d8a3d6662 100644 --- a/python/tvm/relay/op/contrib/cutlass.py +++ b/python/tvm/relay/op/contrib/cutlass.py @@ -51,8 +51,8 @@ def make_gemm_pattern(with_bias=True, with_act=None, out_dtype="float16"): return make_gelu_pattern(gemm_out, out_dtype) -def make_batched_matmul_pattern(): - return is_op("nn.batched_matmul")(wildcard(), wildcard() +def make_batch_matmul_pattern(): + return is_op("nn.batch_matmul")(wildcard(), wildcard()) def partition_for_cutlass(mod): @@ -71,7 +71,7 @@ def partition_for_cutlass(mod): dense_bias_relu_pat, dense_bias_pat, dense_pat, - make_batched_matmul_pattern() + ("cutlass.batch_matmul", make_batch_matmul_pattern()) ] mod = transform.MergeComposite(cutlass_patterns)(mod) mod = transform.AnnotateTarget(["cutlass"])(mod) diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py index fb446dc58ef6..cd538927d376 100644 --- a/tests/python/contrib/test_cutlass.py +++ b/tests/python/contrib/test_cutlass.py @@ -77,6 +77,9 @@ def profile_and_build(mod, params, sm, tmp_dir="./tmp", lib_path="compile.so"): mod, num_cutlass_partition = tune_cutlass_kernels( mod, sm, profile_all=False, use_multiprocessing=False, tmp_dir=tmp_dir ) + print(mod) + return None, None, None + with tvm.transform.PassContext(opt_level=3): lib = relay.build(mod, target="cuda", params=params) lib = build_cutlass_kernels(lib, sm, tmp_dir, lib_path) @@ -120,8 +123,9 @@ def verify_batch_matmul(func, batch, M, N, K, sm=80, atol=1e-5, rtol=1e-5, run_b x_np = np.random.uniform(-1, 1, (batch, M, K)).astype("float16") y_np = np.random.uniform(-1, 1, (batch, N, K)).astype("float16") - rt_mod_ref, dev = get_ref_rt_mod(mod, {}) rt_mod, dev, num_partition = profile_and_build(mod, {}, sm) + return + rt_mod_ref, dev = get_ref_rt_mod(mod, {}) assert num_partition > 0 x = tvm.nd.array(x_np, device=dev)