From 4791015ae233928526076b213f010bbae851308b Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 20 Jan 2022 11:41:54 +0900 Subject: [PATCH 01/15] introduce profile_all_alignments option --- python/tvm/contrib/cutlass/gen_conv2d.py | 25 ++++++------ python/tvm/contrib/cutlass/gen_tensor_op.py | 44 ++++++++++++++++++--- 2 files changed, 50 insertions(+), 19 deletions(-) diff --git a/python/tvm/contrib/cutlass/gen_conv2d.py b/python/tvm/contrib/cutlass/gen_conv2d.py index c09017adfd95..cf96fb594396 100644 --- a/python/tvm/contrib/cutlass/gen_conv2d.py +++ b/python/tvm/contrib/cutlass/gen_conv2d.py @@ -168,14 +168,6 @@ def get_default(self, op_type, out_dtype, arg0_dtype, arg1_dtype, use_3xtf32): ) return {"name": name, "opdef": opdef} - def check_align(self, op_name, C, K): - """Filter out kernels that cannot be supported.""" - match = re.match(".*_align([1-9]+)", op_name) - assert match is not None and len(match.groups()) == 1 - # The same alignment is used for all axes - align = int(match.groups()[0]) - return all([dim % align == 0 for dim in [C, K]]) - def select_op( self, d_shape, @@ -187,6 +179,7 @@ def select_op( data_dtype, weight_dtype, use_3xtf32, + profile_all_alignments=False, profile_all=True, use_multiprocessing=False, ): @@ -216,11 +209,15 @@ def select_op( return self.cache[workload] ops = GENERATOR_FUNC_TABLE[self.sm]( - out_dtype, data_dtype, weight_dtype, enumerate_conv2d_operators, use_3xtf32 + out_dtype, + data_dtype, + weight_dtype, + enumerate_conv2d_operators, + lambda align: all([dim % align == 0 for dim in [IC, OC]]), + use_3xtf32, + profile_all_alignments, ) - ops = list(filter(lambda op: self.check_align(op["name"], IC, OC), ops)) - if profile_all: self.engine.compile_all(ops, use_multiprocessing) @@ -252,6 +249,7 @@ def profile( data_dtype, weight_dtype, use_3xtf32=True, + profile_all_alignments=False, profile_all=True, use_multiprocessing=False, ): @@ -269,8 +267,9 @@ def profile( data_dtype, weight_dtype, use_3xtf32, - profile_all=profile_all, - use_multiprocessing=use_multiprocessing, + profile_all_alignments, + profile_all, + use_multiprocessing ) name, opdef = create_conv2d_operator_with_epilogue( diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index 6bb4f290233e..97af84e76990 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -62,7 +62,9 @@ def generate_tensor_op_common( return ops -def generate_sm75_tensor_op_1688(out_dtype, arg0_dtype, arg1_dtype, op_creator): +def generate_sm75_tensor_op_1688( + out_dtype, arg0_dtype, arg1_dtype, op_creator, check_align, _, profile_all_alignments=False +): """Generate GEMM or Conv2D kernels for Turing.""" assert out_dtype in ["float32", "float16", "int32"] min_cc = 75 @@ -114,6 +116,12 @@ def generate_sm75_tensor_op_1688(out_dtype, arg0_dtype, arg1_dtype, op_creator): ([64, 64, 64], 2, [2, 2, 1], min_cc, max_cc), ] + alignment_constraints = [align for align in alignment_constraints if check_align(align)] + assert len(alignment_constraints) > 0 + + if not profile_all_alignments: + alignment_constraints = [alignment_constraints[0]] + def get_tile_descriptions(math_inst): return [ TileDescription(threadblock_shape, stages, warp_count, math_inst, min_cc, max_cc) @@ -125,7 +133,15 @@ def get_tile_descriptions(math_inst): ) -def generate_sm80_tensor_op_16816(out_dtype, arg0_dtype, arg1_dtype, op_creator, use_3xtf32=True): +def generate_sm80_tensor_op_16816( + out_dtype, + arg0_dtype, + arg1_dtype, + op_creator, + check_align, + use_3xtf32=True, + profile_all_alignments=False, +): """Generate GEMM or Conv2D kernels for Ampere.""" min_cc = 80 max_cc = 1024 @@ -218,15 +234,31 @@ def get_tile_descriptions(math_inst): for threadblock_shape, stages, warp_count, min_cc, max_cc in tile_descriptions ] + alignment_constraints = [align for align in alignment_constraints if check_align(align)] + + if len(alignment_constraints) > 0 and not profile_all_alignments: + alignment_constraints = [alignment_constraints[0]] + if arg0_dtype != "float32" and arg1_dtype != "float32": - sm75_kernels = generate_sm75_tensor_op_1688(out_dtype, arg0_dtype, arg1_dtype, op_creator) + sm75_kernels = generate_sm75_tensor_op_1688( + out_dtype, + arg0_dtype, + arg1_dtype, + op_creator, + check_align, + False, + profile_all_alignments, + ) else: # TF32 (float32 + float32 case) is only supported on sm80 sm75_kernels = [] - sm80_kernels = generate_tensor_op_common( - math_instructions, alignment_constraints, get_tile_descriptions, op_creator - ) + if len(alignment_constraints) > 0: + sm80_kernels = generate_tensor_op_common( + math_instructions, alignment_constraints, get_tile_descriptions, op_creator + ) + else: + sm80_kernels = [] # TODO(masahi): For int8 kernels, The CUTLASS generator modifies the output tensor alignment # after ops are created. Revisit how important this modification is. From 7a8b8977be58c4063c237b02286b6e66d030b015 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 20 Jan 2022 13:43:42 +0900 Subject: [PATCH 02/15] add profile_all_alignment option to API --- python/tvm/contrib/cutlass/build.py | 6 +++++- tests/python/contrib/test_cutlass.py | 1 + 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index c919ff283343..ba7360794628 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -272,7 +272,7 @@ def handle_conv2d( def tune_cutlass_kernels( - mod, sm, use_3xtf32=True, profile_all=True, use_multiprocessing=False, tmp_dir="./tmp" + mod, sm, use_3xtf32=True, profile_all_alignments=False, profile_all=True, use_multiprocessing=False, tmp_dir="./tmp" ): """Given a module partitioned for CUTLASS offloading, profile each workload to select which kernels to emit. @@ -286,6 +286,9 @@ def tune_cutlass_kernels( An integer specifying the compute capability. For example, 75 for Turing and 80 or 86 for Ampere. + profile_all_alignments : bool + TODO + profile_all : bool Whether or not profile all candidate kernels, or stop profiling after the first applicable kernel is found. @@ -342,6 +345,7 @@ def tune_cutlass_kernels( arg0_dtype, arg1_dtype, use_3xtf32, + profile_all_alignments, profile_all, use_multiprocessing, ) diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py index 57f2f39c641b..395cb660bbed 100644 --- a/tests/python/contrib/test_cutlass.py +++ b/tests/python/contrib/test_cutlass.py @@ -188,6 +188,7 @@ def profile_and_build( mod, sm, use_3xtf32=use_3xtf32, + profile_all_alignments=True, profile_all=False, use_multiprocessing=False, tmp_dir=tmp_dir, From 589fe6579ea5123224d730665c845c543254770c Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 20 Jan 2022 15:54:20 +0900 Subject: [PATCH 03/15] wip --- python/tvm/contrib/cutlass/build.py | 2 ++ python/tvm/contrib/cutlass/gen_gemm.py | 2 +- python/tvm/runtime/module.py | 6 ++++-- tests/python/contrib/test_cutlass.py | 9 +++++---- 4 files changed, 12 insertions(+), 7 deletions(-) diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index ba7360794628..5b005cdde883 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -237,6 +237,7 @@ def handle_conv2d( data_dtype, weight_dtype, use_3xtf32, + profile_all_alignments, profile_all, use_multiprocessing, ): @@ -257,6 +258,7 @@ def handle_conv2d( data_dtype, weight_dtype, use_3xtf32, + profile_all_alignments, profile_all=profile_all, use_multiprocessing=use_multiprocessing, ) diff --git a/python/tvm/contrib/cutlass/gen_gemm.py b/python/tvm/contrib/cutlass/gen_gemm.py index 445acb9305c8..7f02124d12ee 100644 --- a/python/tvm/contrib/cutlass/gen_gemm.py +++ b/python/tvm/contrib/cutlass/gen_gemm.py @@ -168,7 +168,7 @@ def get_default( For now, the default kernel was picked arbitrary. """ ops = GENERATOR_FUNC_TABLE[self.sm]( - out_dtype, arg0_dtype, arg1_dtype, enumerate_gemm_operators, use_3xtf32 + out_dtype, arg0_dtype, arg1_dtype, enumerate_gemm_operators, lambda _: True, use_3xtf32 ) default_kernel_name = DEFAULT_KERNELS[self.sm][(arg0_dtype, out_dtype)] diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index da7c52ad119e..cf2787dda750 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -411,7 +411,8 @@ def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=No "c", "cc", "cpp", - ], "The module.format needs to be either c, cc or cpp" + "cu", + ], "The module.format needs to be either c, cc, cpp or cu." object_format = module.format has_c_module = True else: @@ -426,7 +427,8 @@ def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=No "c", "cc", "cpp", - ], "The module.format needs to be either c, cc or cpp" + "cu", + ], "The module.format needs to be either c, cc, cpp, or cu." object_format = module.format else: object_format = "c" diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py index 395cb660bbed..5c7342056948 100644 --- a/tests/python/contrib/test_cutlass.py +++ b/tests/python/contrib/test_cutlass.py @@ -531,9 +531,9 @@ def test_conv2d(): mod_nchw = get_conv2d_nchw(d_shape, w_shape, padding) mod_dyn = get_conv2d_nchw(dyn_batch_shape, w_shape, padding) - verify_conv2d( - mod_dyn, mod_nchw, d_shape, w_shape, sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False - ) + # verify_conv2d( + # mod_dyn, mod_nchw, d_shape, w_shape, sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False + # ) for data_dtype, weight_dtype, out_dtype in [ ("float32", "float32", "float32"), # 3xtf32 @@ -632,4 +632,5 @@ def test_conv2d_residual_block(): if __name__ == "__main__": - pytest.main([__file__]) + # pytest.main([__file__]) + test_conv2d() From eab051fcb1162cfa4fa14db56cfb6723552f4505 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 23 Jan 2022 11:43:40 +0900 Subject: [PATCH 04/15] fixed dynamic case --- python/tvm/contrib/cutlass/gen_gemm.py | 27 +++++++++++++------------- tests/python/contrib/test_cutlass.py | 8 ++++---- 2 files changed, 17 insertions(+), 18 deletions(-) diff --git a/python/tvm/contrib/cutlass/gen_gemm.py b/python/tvm/contrib/cutlass/gen_gemm.py index 7f02124d12ee..34aced8523f6 100644 --- a/python/tvm/contrib/cutlass/gen_gemm.py +++ b/python/tvm/contrib/cutlass/gen_gemm.py @@ -30,12 +30,7 @@ def create_gemm_operator_with_epilogue( - op_type, - tile_description, - data_type, - alignment, - swizzling_functor, - batched=False, + op_type, tile_description, data_type, alignment, swizzling_functor, batched=False, ): """ Instantiate a cutlass kernel from the given configuration, @@ -63,8 +58,9 @@ def create_gemm_operator_with_epilogue( swizzling_functor, ) - return op.procedural_name(), EmitGemmInstance().emit( - op, no_beta_scaling=no_beta_scaling, batched=batched + return ( + op.procedural_name(), + EmitGemmInstance().emit(op, no_beta_scaling=no_beta_scaling, batched=batched), ) @@ -168,8 +164,15 @@ def get_default( For now, the default kernel was picked arbitrary. """ ops = GENERATOR_FUNC_TABLE[self.sm]( - out_dtype, arg0_dtype, arg1_dtype, enumerate_gemm_operators, lambda _: True, use_3xtf32 + out_dtype, + arg0_dtype, + arg1_dtype, + enumerate_gemm_operators, + lambda _: True, + use_3xtf32, + profile_all_alignments=True, # To include align1 kernels ) + default_kernel_name = DEFAULT_KERNELS[self.sm][(arg0_dtype, out_dtype)] if arg0_dtype == "float32": @@ -212,11 +215,7 @@ def select_op( return op ops = GENERATOR_FUNC_TABLE[self.sm]( - out_dtype, - arg0_dtype, - arg1_dtype, - enumerate_gemm_operators, - use_3xtf32=use_3xtf32, + out_dtype, arg0_dtype, arg1_dtype, enumerate_gemm_operators, use_3xtf32=use_3xtf32, ) ops = list(filter(lambda op: self.check_align(op["name"], M, N, K), ops)) diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py index 5c7342056948..8a3312047ce6 100644 --- a/tests/python/contrib/test_cutlass.py +++ b/tests/python/contrib/test_cutlass.py @@ -188,7 +188,7 @@ def profile_and_build( mod, sm, use_3xtf32=use_3xtf32, - profile_all_alignments=True, + profile_all_alignments=False, profile_all=False, use_multiprocessing=False, tmp_dir=tmp_dir, @@ -531,9 +531,9 @@ def test_conv2d(): mod_nchw = get_conv2d_nchw(d_shape, w_shape, padding) mod_dyn = get_conv2d_nchw(dyn_batch_shape, w_shape, padding) - # verify_conv2d( - # mod_dyn, mod_nchw, d_shape, w_shape, sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False - # ) + verify_conv2d( + mod_dyn, mod_nchw, d_shape, w_shape, sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False + ) for data_dtype, weight_dtype, out_dtype in [ ("float32", "float32", "float32"), # 3xtf32 From 47272e522630d57359b15df01d3fd23243172e5b Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 23 Jan 2022 11:45:20 +0900 Subject: [PATCH 05/15] black --- python/tvm/contrib/cutlass/build.py | 10 ++++++++-- python/tvm/contrib/cutlass/gen_conv2d.py | 2 +- tests/python/contrib/test_cutlass.py | 13 ++----------- 3 files changed, 11 insertions(+), 14 deletions(-) diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index 5b005cdde883..409c437a8230 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -237,7 +237,7 @@ def handle_conv2d( data_dtype, weight_dtype, use_3xtf32, - profile_all_alignments, + profile_all_alignments, profile_all, use_multiprocessing, ): @@ -274,7 +274,13 @@ def handle_conv2d( def tune_cutlass_kernels( - mod, sm, use_3xtf32=True, profile_all_alignments=False, profile_all=True, use_multiprocessing=False, tmp_dir="./tmp" + mod, + sm, + use_3xtf32=True, + profile_all_alignments=False, + profile_all=True, + use_multiprocessing=False, + tmp_dir="./tmp", ): """Given a module partitioned for CUTLASS offloading, profile each workload to select which kernels to emit. diff --git a/python/tvm/contrib/cutlass/gen_conv2d.py b/python/tvm/contrib/cutlass/gen_conv2d.py index cf96fb594396..19ed4a86035e 100644 --- a/python/tvm/contrib/cutlass/gen_conv2d.py +++ b/python/tvm/contrib/cutlass/gen_conv2d.py @@ -269,7 +269,7 @@ def profile( use_3xtf32, profile_all_alignments, profile_all, - use_multiprocessing + use_multiprocessing, ) name, opdef = create_conv2d_operator_with_epilogue( diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py index 8a3312047ce6..24189bd28bce 100644 --- a/tests/python/contrib/test_cutlass.py +++ b/tests/python/contrib/test_cutlass.py @@ -350,12 +350,7 @@ def test_dense(): ) # 3xtf32 verify_dense( - dense_fp32, - M, - N, - K, - data_dtype="float32", - weight_dtype="float32", + dense_fp32, M, N, K, data_dtype="float32", weight_dtype="float32", ) @@ -382,11 +377,7 @@ def test_dense_dynamic(): # TVM native fp16 dense (without tensorcore), using fp16 accum, seems to have accuracy issues # Use cublas as a reference verify_dense( - get_dense_with_shape(data_shape, weight_shape), - M, - N, - K, - ref_target="cuda -libs=cublas", + get_dense_with_shape(data_shape, weight_shape), M, N, K, ref_target="cuda -libs=cublas", ) verify_dense( From 31fdd6641aa18a7d052438acd9a9a1e515585446 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 23 Jan 2022 11:50:34 +0900 Subject: [PATCH 06/15] update gen_gemm too --- python/tvm/contrib/cutlass/gen_gemm.py | 28 +++++++++++++------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/python/tvm/contrib/cutlass/gen_gemm.py b/python/tvm/contrib/cutlass/gen_gemm.py index 34aced8523f6..569d0e788f62 100644 --- a/python/tvm/contrib/cutlass/gen_gemm.py +++ b/python/tvm/contrib/cutlass/gen_gemm.py @@ -16,7 +16,6 @@ # under the License. # pylint: disable=invalid-name """GEMM kernel generator and profiler for CUTLASS.""" -import re from .gemm_operation import GemmOperation, EmitGemmInstance from .gemm_profiler import GemmProfilerEmitter from .gen_tensor_op import ProfilerEngine, GENERATOR_FUNC_TABLE, EPILOGUE_MAP @@ -146,17 +145,6 @@ def __init__(self, sm, cutlass_path, binary_path): self.sm = sm self.cache = {} - def check_align(self, op_name, M, N, K): - """Filter out kernels that cannot be supported.""" - match = re.match(".*_align([1-9]+)", op_name) - assert match is not None and len(match.groups()) == 1 - # The same alignment is used for all axes - align = int(match.groups()[0]) - # TODO(masahi): CUTLASS alignment check on gemm kernels is too restrictive. - # See https://github.com/NVIDIA/cutlass/issues/362. - # When the above issue is resolved, we can remove the alignment check on M below. - return all([dim % align == 0 for dim in [M, N, K]]) - def get_default( self, op_type, out_dtype, arg0_dtype, arg1_dtype, use_3xtf32=True, batched=False ): @@ -203,6 +191,7 @@ def select_op( arg0_dtype, arg1_dtype, use_3xtf32, + profile_all_alignments=False, profile_all=True, use_multiprocessing=False, ): @@ -214,10 +203,19 @@ def select_op( op = self.cache[(M, N, K)] return op + # TODO(masahi): CUTLASS alignment check on gemm kernels is too restrictive. + # See https://github.com/NVIDIA/cutlass/issues/362. + # When the above issue is resolved, we can remove the alignment check on M below. + ops = GENERATOR_FUNC_TABLE[self.sm]( - out_dtype, arg0_dtype, arg1_dtype, enumerate_gemm_operators, use_3xtf32=use_3xtf32, + out_dtype, + arg0_dtype, + arg1_dtype, + enumerate_gemm_operators, + lambda align: all([dim % align == 0 for dim in [M, N, K]]), + use_3xtf32=use_3xtf32, + profile_all_alignments=profile_all_alignments, ) - ops = list(filter(lambda op: self.check_align(op["name"], M, N, K), ops)) if profile_all: self.engine.compile_all(ops, use_multiprocessing) @@ -243,6 +241,7 @@ def profile( arg0_dtype, arg1_dtype, use_3xtf32=True, + profile_all_alignments=False, profile_all=True, use_multiprocessing=False, batched=False, @@ -259,6 +258,7 @@ def profile( arg0_dtype, arg1_dtype, use_3xtf32, + profile_all_alignments=profile_all_alignments, profile_all=profile_all, use_multiprocessing=use_multiprocessing, ) From d1439162e37084a9f95759e799ab8b980e4b2e04 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 23 Jan 2022 11:54:44 +0900 Subject: [PATCH 07/15] minor improvement --- python/tvm/contrib/cutlass/gen_gemm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/contrib/cutlass/gen_gemm.py b/python/tvm/contrib/cutlass/gen_gemm.py index 569d0e788f62..fdaaac9ae6e3 100644 --- a/python/tvm/contrib/cutlass/gen_gemm.py +++ b/python/tvm/contrib/cutlass/gen_gemm.py @@ -156,9 +156,9 @@ def get_default( arg0_dtype, arg1_dtype, enumerate_gemm_operators, - lambda _: True, + lambda align: align == 1, # Only request align1 kernels use_3xtf32, - profile_all_alignments=True, # To include align1 kernels + profile_all_alignments=False, ) default_kernel_name = DEFAULT_KERNELS[self.sm][(arg0_dtype, out_dtype)] From 022498b6455c048266a6e660d6cb0bfd09e1004c Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 23 Jan 2022 11:55:48 +0900 Subject: [PATCH 08/15] fix --- python/tvm/contrib/cutlass/gen_gemm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/contrib/cutlass/gen_gemm.py b/python/tvm/contrib/cutlass/gen_gemm.py index fdaaac9ae6e3..696396b845b2 100644 --- a/python/tvm/contrib/cutlass/gen_gemm.py +++ b/python/tvm/contrib/cutlass/gen_gemm.py @@ -156,9 +156,9 @@ def get_default( arg0_dtype, arg1_dtype, enumerate_gemm_operators, - lambda align: align == 1, # Only request align1 kernels + lambda align: align == 1 # Only request align1 kernels use_3xtf32, - profile_all_alignments=False, + profile_all_alignments=True, # To include all align1 kernels ) default_kernel_name = DEFAULT_KERNELS[self.sm][(arg0_dtype, out_dtype)] From 71ccb95a18bccc44b67941206476febaf2e48417 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 24 Jan 2022 01:44:26 +0900 Subject: [PATCH 09/15] all tests work --- python/tvm/contrib/cutlass/gen_gemm.py | 9 +++++++-- tests/python/contrib/test_cutlass.py | 16 ++++++++++++---- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/python/tvm/contrib/cutlass/gen_gemm.py b/python/tvm/contrib/cutlass/gen_gemm.py index 696396b845b2..f80f691f44e3 100644 --- a/python/tvm/contrib/cutlass/gen_gemm.py +++ b/python/tvm/contrib/cutlass/gen_gemm.py @@ -29,7 +29,12 @@ def create_gemm_operator_with_epilogue( - op_type, tile_description, data_type, alignment, swizzling_functor, batched=False, + op_type, + tile_description, + data_type, + alignment, + swizzling_functor, + batched=False, ): """ Instantiate a cutlass kernel from the given configuration, @@ -156,7 +161,7 @@ def get_default( arg0_dtype, arg1_dtype, enumerate_gemm_operators, - lambda align: align == 1 # Only request align1 kernels + lambda align: align == 1, # Only request align1 kernels use_3xtf32, profile_all_alignments=True, # To include all align1 kernels ) diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py index 24189bd28bce..a0ac1a4c6bad 100644 --- a/tests/python/contrib/test_cutlass.py +++ b/tests/python/contrib/test_cutlass.py @@ -350,7 +350,12 @@ def test_dense(): ) # 3xtf32 verify_dense( - dense_fp32, M, N, K, data_dtype="float32", weight_dtype="float32", + dense_fp32, + M, + N, + K, + data_dtype="float32", + weight_dtype="float32", ) @@ -377,7 +382,11 @@ def test_dense_dynamic(): # TVM native fp16 dense (without tensorcore), using fp16 accum, seems to have accuracy issues # Use cublas as a reference verify_dense( - get_dense_with_shape(data_shape, weight_shape), M, N, K, ref_target="cuda -libs=cublas", + get_dense_with_shape(data_shape, weight_shape), + M, + N, + K, + ref_target="cuda -libs=cublas", ) verify_dense( @@ -623,5 +632,4 @@ def test_conv2d_residual_block(): if __name__ == "__main__": - # pytest.main([__file__]) - test_conv2d() + pytest.main([__file__]) From 514810a6b39748c0a68d4c54599ca21c72e077f6 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 24 Jan 2022 01:50:30 +0900 Subject: [PATCH 10/15] add doc --- python/tvm/contrib/cutlass/build.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index 409c437a8230..c88f310e636e 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -294,8 +294,12 @@ def tune_cutlass_kernels( An integer specifying the compute capability. For example, 75 for Turing and 80 or 86 for Ampere. + use_3xtf32 : bool + Wheter or not use slower but very accurate (compared to tf32) 3xtf32 mode for + fp32 inputs on tensorcore. + profile_all_alignments : bool - TODO + When True, profile all kernal varaints with smaller alignments than the largest possible. profile_all : bool Whether or not profile all candidate kernels, or stop profiling after From 07bb45b647da0f5eaed3ae264d7d6e7dc6601709 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 24 Jan 2022 02:00:27 +0900 Subject: [PATCH 11/15] fixed for sm = 75 case --- python/tvm/contrib/cutlass/gen_gemm.py | 2 +- tests/python/contrib/test_cutlass.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/python/tvm/contrib/cutlass/gen_gemm.py b/python/tvm/contrib/cutlass/gen_gemm.py index f80f691f44e3..f7e216deba44 100644 --- a/python/tvm/contrib/cutlass/gen_gemm.py +++ b/python/tvm/contrib/cutlass/gen_gemm.py @@ -218,7 +218,7 @@ def select_op( arg1_dtype, enumerate_gemm_operators, lambda align: all([dim % align == 0 for dim in [M, N, K]]), - use_3xtf32=use_3xtf32, + use_3xtf32, profile_all_alignments=profile_all_alignments, ) diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py index a0ac1a4c6bad..d792bd2954dd 100644 --- a/tests/python/contrib/test_cutlass.py +++ b/tests/python/contrib/test_cutlass.py @@ -240,6 +240,9 @@ def verify_dense( ): if not has_cutlass(): return + if sm < 80 and data_dtype == "float32": + return + mod = tvm.IRModule.from_expr(func) typ = relay.transform.InferType()(mod)["main"].body.checked_type out_dtype = typ.dtype @@ -451,6 +454,8 @@ def verify_conv2d( ): if not has_cutlass(): return + if sm < 80 and data_dtype == "float32": + return mod_nchw = tvm.IRModule.from_expr(expr_nchw) mod_ref = tvm.IRModule.from_expr(expr_ref) From 1fb329639e483e3f76b08f16ae53bb10ad1541f4 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 24 Jan 2022 02:04:45 +0900 Subject: [PATCH 12/15] fix typo --- python/tvm/contrib/cutlass/build.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index c88f310e636e..72f4a28fc10c 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -299,7 +299,7 @@ def tune_cutlass_kernels( fp32 inputs on tensorcore. profile_all_alignments : bool - When True, profile all kernal varaints with smaller alignments than the largest possible. + When True, profile all kernal variants with smaller alignments than the largest possible. profile_all : bool Whether or not profile all candidate kernels, or stop profiling after From ba1bbb9c050ffad0cfd6a4fe921936d6618af9f3 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 24 Jan 2022 11:30:24 +0900 Subject: [PATCH 13/15] remove unused import --- python/tvm/contrib/cutlass/gen_conv2d.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/tvm/contrib/cutlass/gen_conv2d.py b/python/tvm/contrib/cutlass/gen_conv2d.py index 19ed4a86035e..a6e6300f213a 100644 --- a/python/tvm/contrib/cutlass/gen_conv2d.py +++ b/python/tvm/contrib/cutlass/gen_conv2d.py @@ -16,7 +16,6 @@ # under the License. # pylint: disable=invalid-name """Conv2d kernel generator and profiler for CUTLASS.""" -import re from .conv2d_operation import Conv2dOperation, EmitConv2dInstance from .gen_gemm import CutlassGemmProfiler from .conv2d_profiler import Conv2dProfilerEmitter From d15a995aa817a2af96853f50f7ab17729658e510 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 25 Jan 2022 04:40:36 +0900 Subject: [PATCH 14/15] profile_all -> find_first_valid --- python/tvm/contrib/cutlass/build.py | 30 ++++++++++++------------ python/tvm/contrib/cutlass/gen_conv2d.py | 12 +++++----- python/tvm/contrib/cutlass/gen_gemm.py | 12 +++++----- tests/python/contrib/test_cutlass.py | 2 +- 4 files changed, 28 insertions(+), 28 deletions(-) diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index 72f4a28fc10c..fb59d02f9450 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -104,7 +104,7 @@ def select_gemm_kernel( arg1_dtype, use_3xtf32, batched, - profile_all, + find_first_valid, use_multiprocessing, ): """Run CUTLASS profiler to select the best kernel, or return the default one for dynamic @@ -126,10 +126,10 @@ def select_gemm_kernel( arg1_dtype, use_3xtf32, batched=batched, - profile_all=profile_all, + find_first_valid=find_first_valid, use_multiprocessing=use_multiprocessing, ) - if profile_all: + if not find_first_valid: logger.info("The best kernel is %s", name) else: logger.info("Picked the first kernel found %s", name) @@ -146,7 +146,7 @@ def handle_batch_matmul( arg0_dtype, arg1_dtype, use_3xtf32, - profile_all, + find_first_valid, use_multiprocessing, ): """Profile and select a kernel for batch_matmul op workload.""" @@ -165,7 +165,7 @@ def handle_batch_matmul( arg1_dtype, use_3xtf32, True, - profile_all, + find_first_valid, use_multiprocessing, ) @@ -191,7 +191,7 @@ def handle_dense( arg0_dtype, arg1_dtype, use_3xtf32, - profile_all, + find_first_valid, use_multiprocessing, ): """Profile and select a kernel for dense op workload.""" @@ -210,7 +210,7 @@ def handle_dense( arg1_dtype, use_3xtf32, False, - profile_all, + find_first_valid, use_multiprocessing, ) @@ -238,7 +238,7 @@ def handle_conv2d( weight_dtype, use_3xtf32, profile_all_alignments, - profile_all, + find_first_valid, use_multiprocessing, ): """Profile and select a kernel for conv2d op workload.""" @@ -259,10 +259,10 @@ def handle_conv2d( weight_dtype, use_3xtf32, profile_all_alignments, - profile_all=profile_all, + find_first_valid=find_first_valid, use_multiprocessing=use_multiprocessing, ) - if profile_all: + if not find_first_valid: logger.info("The best kernel is %s", name) else: logger.info("Picked the first kernel found %s", name) @@ -278,7 +278,7 @@ def tune_cutlass_kernels( sm, use_3xtf32=True, profile_all_alignments=False, - profile_all=True, + find_first_valid=False, use_multiprocessing=False, tmp_dir="./tmp", ): @@ -301,7 +301,7 @@ def tune_cutlass_kernels( profile_all_alignments : bool When True, profile all kernal variants with smaller alignments than the largest possible. - profile_all : bool + find_first_valid : bool Whether or not profile all candidate kernels, or stop profiling after the first applicable kernel is found. @@ -358,7 +358,7 @@ def tune_cutlass_kernels( arg1_dtype, use_3xtf32, profile_all_alignments, - profile_all, + find_first_valid, use_multiprocessing, ) ) @@ -373,7 +373,7 @@ def tune_cutlass_kernels( arg0_dtype, arg1_dtype, use_3xtf32, - profile_all, + find_first_valid, use_multiprocessing, ) ) @@ -388,7 +388,7 @@ def tune_cutlass_kernels( arg0_dtype, arg1_dtype, use_3xtf32, - profile_all, + find_first_valid, use_multiprocessing, ) ) diff --git a/python/tvm/contrib/cutlass/gen_conv2d.py b/python/tvm/contrib/cutlass/gen_conv2d.py index a6e6300f213a..5b37dca5b297 100644 --- a/python/tvm/contrib/cutlass/gen_conv2d.py +++ b/python/tvm/contrib/cutlass/gen_conv2d.py @@ -179,7 +179,7 @@ def select_op( weight_dtype, use_3xtf32, profile_all_alignments=False, - profile_all=True, + find_first_valid=False, use_multiprocessing=False, ): """ @@ -217,7 +217,7 @@ def select_op( profile_all_alignments, ) - if profile_all: + if find_first_valid: self.engine.compile_all(ops, use_multiprocessing) args = ( @@ -228,7 +228,7 @@ def select_op( for op in ops: out = self.engine.evaluate(op, args.split(" ")) op["runtime"] = out - if out < float("inf") and not profile_all: + if out < float("inf") and find_first_valid: self.cache[workload] = op return op @@ -249,11 +249,11 @@ def profile( weight_dtype, use_3xtf32=True, profile_all_alignments=False, - profile_all=True, + find_first_valid=False, use_multiprocessing=False, ): """Profile and select the best kernel from candidate kernels. - If profile_all is False, return immediately after the first applicable kernel is found. + If find_first_valid is True, return immediately after the first applicable kernel is found. If use_multiprocessing is True, compile all profiler executables in parallel. """ op = self.select_op( @@ -267,7 +267,7 @@ def profile( weight_dtype, use_3xtf32, profile_all_alignments, - profile_all, + find_first_valid, use_multiprocessing, ) diff --git a/python/tvm/contrib/cutlass/gen_gemm.py b/python/tvm/contrib/cutlass/gen_gemm.py index f7e216deba44..f75822dd0f9e 100644 --- a/python/tvm/contrib/cutlass/gen_gemm.py +++ b/python/tvm/contrib/cutlass/gen_gemm.py @@ -197,7 +197,7 @@ def select_op( arg1_dtype, use_3xtf32, profile_all_alignments=False, - profile_all=True, + find_first_valid=False, use_multiprocessing=False, ): """ @@ -222,13 +222,13 @@ def select_op( profile_all_alignments=profile_all_alignments, ) - if profile_all: + if find_first_valid: self.engine.compile_all(ops, use_multiprocessing) for op in ops: out = self.engine.evaluate(op, [M, N, K]) op["runtime"] = out - if out < float("inf") and not profile_all: + if out < float("inf") and find_first_valid: self.cache[(M, N, K)] = op return op @@ -247,12 +247,12 @@ def profile( arg1_dtype, use_3xtf32=True, profile_all_alignments=False, - profile_all=True, + find_first_valid=False, use_multiprocessing=False, batched=False, ): """Profile and select the best kernel from candidate kernels. - If profile_all is False, return immediately after the first applicable kernel is found. + If find_first_valid is True, return immediately after the first applicable kernel is found. If use_multiprocessing is True, compile all profiler executables in parallel. """ op = self.select_op( @@ -264,7 +264,7 @@ def profile( arg1_dtype, use_3xtf32, profile_all_alignments=profile_all_alignments, - profile_all=profile_all, + find_first_valid=find_first_valid, use_multiprocessing=use_multiprocessing, ) diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py index d792bd2954dd..00506ecf0527 100644 --- a/tests/python/contrib/test_cutlass.py +++ b/tests/python/contrib/test_cutlass.py @@ -189,7 +189,7 @@ def profile_and_build( sm, use_3xtf32=use_3xtf32, profile_all_alignments=False, - profile_all=False, + find_first_valid=True, use_multiprocessing=False, tmp_dir=tmp_dir, ) From 6de3d6429b90ea4b96f30ff8f215dc6990ee1fe4 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 25 Jan 2022 04:44:37 +0900 Subject: [PATCH 15/15] fix --- python/tvm/contrib/cutlass/gen_conv2d.py | 2 +- python/tvm/contrib/cutlass/gen_gemm.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/contrib/cutlass/gen_conv2d.py b/python/tvm/contrib/cutlass/gen_conv2d.py index 5b37dca5b297..b6dba009f2b2 100644 --- a/python/tvm/contrib/cutlass/gen_conv2d.py +++ b/python/tvm/contrib/cutlass/gen_conv2d.py @@ -217,7 +217,7 @@ def select_op( profile_all_alignments, ) - if find_first_valid: + if not find_first_valid: self.engine.compile_all(ops, use_multiprocessing) args = ( diff --git a/python/tvm/contrib/cutlass/gen_gemm.py b/python/tvm/contrib/cutlass/gen_gemm.py index f75822dd0f9e..bb591985cab5 100644 --- a/python/tvm/contrib/cutlass/gen_gemm.py +++ b/python/tvm/contrib/cutlass/gen_gemm.py @@ -222,7 +222,7 @@ def select_op( profile_all_alignments=profile_all_alignments, ) - if find_first_valid: + if not find_first_valid: self.engine.compile_all(ops, use_multiprocessing) for op in ops: