diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index 58e7a115444e..615b9003adc9 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -83,6 +83,87 @@ def visit_call(self, call): self.signature["ret_dtype"] = op.ret_type.dtype +def select_gemm_kernel( + cutlass_profiler, MM, KK, NN, out_dtype, batched, profile_all, use_multiprocessing +): + """Run CUTLASS profiler to select the best kernel, or return the default one for dynamic + workloads.""" + if any(isinstance(s, tvm.tir.Any) for s in [MM, KK, NN]): + out = cutlass_profiler.get_default(out_dtype, batched=batched) + logger.info("Picked the default kernel %s", out["name"]) + else: + out = cutlass_profiler.profile( + MM, + NN, + KK, + out_dtype, + batched=batched, + profile_all=profile_all, + use_multiprocessing=use_multiprocessing, + ) + if profile_all: + logger.info("The best kernel is %s", out["name"]) + else: + logger.info("Picked the first kernel found %s", out["name"]) + return out + + +def handle_batch_matmul( + cutlass_profiler, op_type, arg0_shape, arg1_shape, out_dtype, profile_all, use_multiprocessing +): + """Profile and select a kernel for batch_matmul op workload.""" + MM = arg0_shape[1] + KK = arg0_shape[2] + NN = arg1_shape[1] + + out = select_gemm_kernel( + cutlass_profiler, MM, KK, NN, out_dtype, True, profile_all, use_multiprocessing + ) + + if op_type == "cutlass.batch_matmul": + cutlass_op_def = out["opdef"] + else: + raise ValueError("%s pattern is not implemented." % op_type) + + return { + "batch": arg0_shape[0], + "batch_stride_A": arg0_shape[1] * arg0_shape[2], + "batch_stride_B": arg1_shape[1] * arg1_shape[2], + "batch_stride_C": arg0_shape[1] * arg1_shape[1], + "cutlass_op_def": cutlass_op_def, + "cutlass_op_name": out["name"], + } + + +def handle_dense( + cutlass_profiler, op_type, arg0_shape, arg1_shape, out_dtype, profile_all, use_multiprocessing +): + """Profile and select a kernel for dense op workload.""" + MM = arg0_shape[0] + KK = arg0_shape[1] + NN = arg1_shape[0] + + out = select_gemm_kernel( + cutlass_profiler, MM, KK, NN, out_dtype, False, profile_all, use_multiprocessing + ) + + if op_type == "cutlass.dense": + cutlass_op_def = out["opdef"] + elif op_type == "cutlass.dense_bias": + cutlass_op_def = out["opdef_bias"] + elif op_type == "cutlass.dense_bias_relu": + cutlass_op_def = out["opdef_bias_relu"] + elif "cutlass.dense_bias_gelu" in op_type: + cutlass_op_def = out["opdef_bias_gelu"] + else: + raise ValueError("%s pattern is not implemented." % op_type) + + return { + "cutlass_op_def": cutlass_op_def, + "cutlass_op_name": out["name"], + } + + def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, tmp_dir="./tmp"): """Given a module partitioned for CUTLASS offloading, profile each workload to select which kernels to emit. @@ -123,41 +204,41 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t if "cutlass" in fun_name: num_cutlass_partition += 1 annotator.visit(func) - # call cutlass profiler to find best settings, update attr - new_attrs = {} + out_dtype = annotator.signature["ret_dtype"] + op_type = annotator.signature["op_type"] + + new_attrs = {"op_type": op_type} new_attrs.update(annotator.signature) - for key in func.attrs.keys(): - new_attrs[key] = func.attrs[key] - # call profiler + new_attrs.update(func.attrs) arg0_shape = new_attrs["arg0_shape"] arg1_shape = new_attrs["arg1_shape"] - MM = arg0_shape[0] - KK = arg0_shape[1] - NN = arg1_shape[0] - out_dtype = annotator.signature["ret_dtype"] - if any(isinstance(s, tvm.tir.Any) for s in [MM, KK, NN]): - out = cutlass_profiler.get_default(out_dtype) - logger.info("Picked the default kernel %s", out["name"]) - else: - out = cutlass_profiler.profile( - MM, NN, KK, out_dtype, profile_all, use_multiprocessing + + if "batch_matmul" in op_type: + new_attrs.update( + handle_batch_matmul( + cutlass_profiler, + op_type, + arg0_shape, + arg1_shape, + out_dtype, + profile_all, + use_multiprocessing, + ) + ) + elif "dense" in op_type: + new_attrs.update( + handle_dense( + cutlass_profiler, + op_type, + arg0_shape, + arg1_shape, + out_dtype, + profile_all, + use_multiprocessing, + ) ) - if profile_all: - logger.info("The best kernel is %s", out["name"]) - else: - logger.info("Picked the first kernel found %s", out["name"]) - - if new_attrs["op_type"] == "cutlass.dense": - new_attrs["cutlass_op_def"] = out["opdef"] - elif new_attrs["op_type"] == "cutlass.dense_bias": - new_attrs["cutlass_op_def"] = out["opdef_bias"] - elif new_attrs["op_type"] == "cutlass.dense_bias_relu": - new_attrs["cutlass_op_def"] = out["opdef_bias_relu"] - elif "cutlass.dense_bias_gelu" in new_attrs["op_type"]: - new_attrs["cutlass_op_def"] = out["opdef_bias_gelu"] else: - raise ValueError("%s pattern is not implemented." % new_attrs["op_type"]) - new_attrs["cutlass_op_name"] = out["name"] + raise ValueError("%s unsupported composite" % op_type) if new_attrs["cutlass_op_name"].find("_tn_align") > 0: new_attrs["lda"] = "K" diff --git a/python/tvm/contrib/cutlass/gemm_operation.py b/python/tvm/contrib/cutlass/gemm_operation.py index e53b3ee7b93a..4673b4bdea65 100644 --- a/python/tvm/contrib/cutlass/gemm_operation.py +++ b/python/tvm/contrib/cutlass/gemm_operation.py @@ -174,7 +174,7 @@ def __init__(self): >""" self.gemm_template = """ // Gemm operator ${operation_name} - using Operation_${operation_name} = cutlass::gemm::device::Gemm< + using Operation_${operation_name} = cutlass::gemm::device::${kernel_name}< ${element_a}, ${layout_a}, ${element_b}, ${layout_b}, ${element_c}, ${layout_c}, @@ -189,13 +189,12 @@ def __init__(self): ${stages}, ${align_a}, ${align_b}, - false, + ${split_k_serial} ${math_operation} - ${residual} >; """ - def emit(self, operation, no_beta_scaling=False): + def emit(self, operation, no_beta_scaling=False, batched=False): """Instantiate a GEMM kernel from given `operation`.""" warp_shape = [ operation.tile_description.threadblock_shape[idx] @@ -206,8 +205,6 @@ def emit(self, operation, no_beta_scaling=False): min(operation.C.alignment * DataTypeSize[operation.C.element], 128) // DataTypeSize[operation.C.element] ) - residual = "" - complex_transform_tag = "cutlass::ComplexTransform::kNone" values = { "operation_name": operation.procedural_name(), "element_a": DataTypeTag[operation.A.element], @@ -243,14 +240,14 @@ def emit(self, operation, no_beta_scaling=False): "stages": str(operation.tile_description.stages), "align_a": str(operation.A.alignment), "align_b": str(operation.B.alignment), - "transform_a": complex_transform_tag, - "transform_b": complex_transform_tag, "math_operation": MathOperationTag[ operation.tile_description.math_instruction.math_operation ], - "residual": residual, } + values["kernel_name"] = "GemmBatched" if batched else "Gemm" + values["split_k_serial"] = "" if batched else "false," + gemm_template = substitute_template( self.gemm_template, { diff --git a/python/tvm/contrib/cutlass/gen_gemm.py b/python/tvm/contrib/cutlass/gen_gemm.py index a43c6d414e38..1ed4bfe5fc4c 100644 --- a/python/tvm/contrib/cutlass/gen_gemm.py +++ b/python/tvm/contrib/cutlass/gen_gemm.py @@ -47,6 +47,7 @@ def create_gemm_operator( alignment_constraints, epilogue_functor=EpilogueFunctor.LinearCombination, swizzling_functor=SwizzlingFunctor.Identity8, + batched=False, ): """Exhaustively instantiate all kernels from a given configuration.""" ret = [] @@ -55,6 +56,9 @@ def create_gemm_operator( element_a, element_b, element_c, element_epilogue = data_type + if batched: + swizzling_functor = SwizzlingFunctor.Batched + for layout in layouts: for tile_description in tile_descriptions: for alignment in alignment_constraints: @@ -109,15 +113,17 @@ def create_gemm_operator( kernel_emitter = EmitGemmInstance() op_entry["op"] = op op_entry["name"] = op.procedural_name() - op_entry["opdef"] = kernel_emitter.emit(op) - op_entry["opdef_bias"] = kernel_emitter.emit(op_bias, no_beta_scaling=True) + op_entry["opdef"] = kernel_emitter.emit(op, batched=batched) + op_entry["opdef_bias"] = kernel_emitter.emit( + op_bias, no_beta_scaling=True, batched=batched + ) op_entry["opdef_bias_relu"] = kernel_emitter.emit( - op_bias_relu, no_beta_scaling=True + op_bias_relu, no_beta_scaling=True, batched=batched ) - op_entry["opdef_bias_gelu"] = kernel_emitter.emit(op_bias_gelu) + op_entry["opdef_bias_gelu"] = kernel_emitter.emit(op_bias_gelu, batched=batched) op_entry["src"] = profiler_emitter.emit( op.procedural_name(), - op_entry["opdef"], + kernel_emitter.emit(op, batched=False), DataTypeTag[element_a], DataTypeTag[element_b], DataTypeTag[element_c], @@ -128,7 +134,9 @@ def create_gemm_operator( return ret -def generate_tensor_op_common(math_instructions, alignment_constraints, get_tile_descriptions): +def generate_tensor_op_common( + math_instructions, alignment_constraints, get_tile_descriptions, batched=False +): """Common kernel generator to be used by archtecture specific generators.""" ops = [] layouts = [ @@ -143,14 +151,16 @@ def generate_tensor_op_common(math_instructions, alignment_constraints, get_tile math_inst.element_accumulator, ] - out = create_gemm_operator(layouts, tile_descriptions, data_type, alignment_constraints) + out = create_gemm_operator( + layouts, tile_descriptions, data_type, alignment_constraints, batched=batched + ) ops.extend(out) return ops -def generate_sm75_tensor_op_1688(out_dtype): +def generate_sm75_tensor_op_1688(out_dtype, batched=False): """Generate GEMM kernels for Turing.""" assert out_dtype in ["float32", "float16"] math_instructions = { @@ -192,11 +202,11 @@ def get_tile_descriptions(math_inst): ] return generate_tensor_op_common( - math_instructions, alignment_constraints, get_tile_descriptions + math_instructions, alignment_constraints, get_tile_descriptions, batched ) -def generate_sm80_tensor_op_16816(out_dtype): +def generate_sm80_tensor_op_16816(out_dtype, batched=False): """Generate GEMM kernels for Ampere.""" assert out_dtype in ["float32", "float16"] math_instructions = { @@ -250,7 +260,7 @@ def get_tile_descriptions(math_inst): ] return generate_tensor_op_common( - math_instructions, alignment_constraints, get_tile_descriptions + math_instructions, alignment_constraints, get_tile_descriptions, batched ) @@ -350,17 +360,19 @@ def check_align(self, op_name, M): return False return True - def get_default(self, out_dtype): + def get_default(self, out_dtype, batched=False): """Return the default kernel for the requested architecture. For now, the default kernel was picked arbitrary. """ - ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype) + ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, batched) default_kernel_name = DEFAULT_KERNELS[self.sm][out_dtype] filtered = list(filter(lambda op: op["name"] == default_kernel_name, ops)) assert len(filtered) == 1 return filtered[0] - def profile(self, M, N, K, out_dtype, profile_all=True, use_multiprocessing=False): + def profile( + self, M, N, K, out_dtype, profile_all=True, use_multiprocessing=False, batched=False + ): """Profile and select the best kernel from candidate kernels. If profile_all is False, return immediately after the first applicable kernel is found. If use_multiprocessing is True, compile all profiler executables in parallel. @@ -368,7 +380,7 @@ def profile(self, M, N, K, out_dtype, profile_all=True, use_multiprocessing=Fals if (M, N, K) in self.cache: return self.cache[(M, N, K)] - ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype) + ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, batched) ops = list(filter(lambda op: self.check_align(op["name"], M), ops)) for op in ops: diff --git a/python/tvm/contrib/cutlass/library.py b/python/tvm/contrib/cutlass/library.py index 7d544293901a..a3b90ff83d1f 100644 --- a/python/tvm/contrib/cutlass/library.py +++ b/python/tvm/contrib/cutlass/library.py @@ -160,6 +160,7 @@ class SwizzlingFunctor(enum.Enum): Identity2 = enum_auto() Identity4 = enum_auto() Identity8 = enum_auto() + Batched = enum_auto() SwizzlingFunctorTag = { @@ -167,6 +168,7 @@ class SwizzlingFunctor(enum.Enum): SwizzlingFunctor.Identity2: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>", SwizzlingFunctor.Identity4: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>", SwizzlingFunctor.Identity8: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>", + SwizzlingFunctor.Batched: "cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle", } diff --git a/python/tvm/relay/op/contrib/cutlass.py b/python/tvm/relay/op/contrib/cutlass.py index 631089ce766d..8ed371844a1c 100644 --- a/python/tvm/relay/op/contrib/cutlass.py +++ b/python/tvm/relay/op/contrib/cutlass.py @@ -20,13 +20,13 @@ def make_gelu_pattern(bias_out, out_dtype="float16"): - mul = is_op("multiply")(bias_out, is_constant()) + mul = is_op("multiply")(bias_out, is_constant() | wildcard()) if out_dtype == "float16": erf = is_op("cast")(is_op("erf")(is_op("cast")(mul))) else: erf = is_op("erf")(mul) - mul_half = is_op("multiply")(erf, is_constant()) - add = is_op("add")(mul_half, is_constant()) + mul_half = is_op("multiply")(erf, is_constant() | wildcard()) + add = is_op("add")(mul_half, is_constant() | wildcard()) return is_op("multiply")(add, bias_out) @@ -51,6 +51,10 @@ def make_gemm_pattern(with_bias=True, with_act=None, out_dtype="float16"): return make_gelu_pattern(gemm_out, out_dtype) +def make_batch_matmul_pattern(): + return is_op("nn.batch_matmul")(wildcard(), wildcard()) + + def partition_for_cutlass(mod): """Partition the input module into CUTLASS-supported subgraphs.""" dense_pat = ("cutlass.dense", make_gemm_pattern(False, None)) @@ -67,6 +71,7 @@ def partition_for_cutlass(mod): dense_bias_relu_pat, dense_bias_pat, dense_pat, + ("cutlass.batch_matmul", make_batch_matmul_pattern()), ] mod = transform.MergeComposite(cutlass_patterns)(mod) mod = transform.AnnotateTarget(["cutlass"])(mod) diff --git a/src/relay/backend/contrib/cutlass/codegen.cc b/src/relay/backend/contrib/cutlass/codegen.cc index c1217a08b712..f154f8641a64 100644 --- a/src/relay/backend/contrib/cutlass/codegen.cc +++ b/src/relay/backend/contrib/cutlass/codegen.cc @@ -54,19 +54,21 @@ std::string GetDimAsStr(ObjectRef dim) { return kAnyDim; } -Str2StrMap DenseArgs(const Map& attrs) { +inline void CutlassPrint(std::ostringstream& os, const std::string& stmt, int indent = 2) { + for (int i = 0; i < indent; ++i) { + os << " "; + } + os << stmt; +} + +Str2StrMap GemmArgsCommon(const Map& attrs) { Str2StrMap args; auto arg0_dtype = std::string(attrs["arg0_dtype"].as()->data); auto arg1_dtype = std::string(attrs["arg1_dtype"].as()->data); auto ret_dtype = std::string(attrs["ret_dtype"].as()->data); - auto arg0_shape = attrs["arg0_shape"].as(); - auto arg1_shape = attrs["arg1_shape"].as(); args["ElementInputA"] = dtype_map.at(arg0_dtype); args["ElementInputB"] = dtype_map.at(arg1_dtype); args["ElementOutput"] = dtype_map.at(ret_dtype); - args["M"] = GetDimAsStr(arg0_shape->at(0)); - args["K"] = GetDimAsStr(arg0_shape->at(1)); - args["N"] = GetDimAsStr(arg1_shape->at(0)); args["op_def"] = std::string(attrs["cutlass_op_def"].as()->data); args["op_name"] = std::string(attrs["cutlass_op_name"].as()->data); args["op_type"] = std::string(attrs["op_type"].as()->data); @@ -76,23 +78,33 @@ Str2StrMap DenseArgs(const Map& attrs) { return args; } -inline void CutlassPrint(std::ostringstream& os, const std::string& stmt, int indent = 2) { - for (int i = 0; i < indent; ++i) { - os << " "; - } - os << stmt; +Str2StrMap DenseArgs(const Map& attrs) { + Str2StrMap args = GemmArgsCommon(attrs); + auto arg0_shape = attrs["arg0_shape"].as(); + auto arg1_shape = attrs["arg1_shape"].as(); + args["M"] = GetDimAsStr(arg0_shape->at(0)); + args["K"] = GetDimAsStr(arg0_shape->at(1)); + args["N"] = GetDimAsStr(arg1_shape->at(0)); + return args; } -std::string DenseOp(std::string id, const Str2StrMap& attrs, - const std::vector& func_args) { - bool has_bias = false; - bool is_gelu = - attrs.at("op_type").find("cutlass.dense_bias_gelu") != std::string::npos; // fp32 or fp16 - if (attrs.at("op_type") == "cutlass.dense_bias" || - attrs.at("op_type") == "cutlass.dense_bias_relu" || is_gelu) { - has_bias = true; - } - std::ostringstream gemm_decl; +Str2StrMap BatchMatmulArgs(const Map& attrs) { + Str2StrMap args = GemmArgsCommon(attrs); + args["batch"] = GetDimAsStr(attrs["batch"]); + args["batch_stride_A"] = GetDimAsStr(attrs["batch_stride_A"]); + args["batch_stride_B"] = GetDimAsStr(attrs["batch_stride_B"]); + args["batch_stride_C"] = GetDimAsStr(attrs["batch_stride_C"]); + auto arg0_shape = attrs["arg0_shape"].as(); + auto arg1_shape = attrs["arg1_shape"].as(); + args["M"] = GetDimAsStr(arg0_shape->at(1)); + args["K"] = GetDimAsStr(arg0_shape->at(2)); + args["N"] = GetDimAsStr(arg1_shape->at(1)); + return args; +} + +void AppendPrologue(std::ostringstream& gemm_decl, const Str2StrMap& attrs, + const std::vector& func_args, const std::string& kernel, + bool has_bias, bool is_gelu, int m_axis_idx, int n_axis_idx, int k_axis_idx) { CutlassPrint(gemm_decl, "using ElementInputA = " + attrs.at("ElementInputA") + ";\n"); CutlassPrint(gemm_decl, "using ElementInputB = " + attrs.at("ElementInputB") + ";\n"); CutlassPrint(gemm_decl, "using ElementOutput = " + attrs.at("ElementOutput") + ";\n"); @@ -107,11 +119,10 @@ std::string DenseOp(std::string id, const Str2StrMap& attrs, return attrs.at(axis); } }; - CutlassPrint(gemm_decl, "int M = " + get_dim("M", 0, 0) + ";\n"); - CutlassPrint(gemm_decl, "int N = " + get_dim("N", 1, 0) + ";\n"); - CutlassPrint(gemm_decl, "int K = " + get_dim("K", 0, 1) + ";\n"); + CutlassPrint(gemm_decl, "int M = " + get_dim("M", 0, m_axis_idx) + ";\n"); + CutlassPrint(gemm_decl, "int N = " + get_dim("N", 1, n_axis_idx) + ";\n"); + CutlassPrint(gemm_decl, "int K = " + get_dim("K", 0, k_axis_idx) + ";\n"); CutlassPrint(gemm_decl, "cutlass::gemm::GemmCoord problem_size(M, N, K);\n"); - // Initialize alpha for dot product computation CutlassPrint(gemm_decl, "ElementComputeEpilogue alpha = ElementComputeEpilogue(1);\n"); if (is_gelu) { // GeLU epilogue does not compile with NoBetaScaling, so we explicitly specify the scale. @@ -120,11 +131,6 @@ std::string DenseOp(std::string id, const Str2StrMap& attrs, CutlassPrint(gemm_decl, "ElementComputeEpilogue beta = ElementComputeEpilogue(0);\n"); } - // Split K dimension into 1 partitions - CutlassPrint(gemm_decl, "int split_k_slices = 1;\n"); - - // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch - // instantiated CUTLASS kernel ICHECK(func_args.size() >= 2); CutlassPrint(gemm_decl, "void* ptr_a = (void*)(" + func_args[0] + "->data);\n"); CutlassPrint(gemm_decl, "void* ptr_b = (void*)(" + func_args[1] + "->data);\n"); @@ -132,33 +138,24 @@ std::string DenseOp(std::string id, const Str2StrMap& attrs, ICHECK(func_args.size() >= 3); CutlassPrint(gemm_decl, "void* ptr_c_bias = (void*)(" + func_args[2] + "->data);\n"); } + CutlassPrint(gemm_decl, "void* ptr_out = (void*)(out0);\n"); - CutlassPrint(gemm_decl, "typename Gemm::Arguments arguments{\n"); + CutlassPrint(gemm_decl, "using " + kernel + " = Operation_" + attrs.at("op_name") + ";\n"); + CutlassPrint(gemm_decl, "typename " + kernel + "::Arguments arguments{\n"); CutlassPrint(gemm_decl, " problem_size,\n"); - CutlassPrint(gemm_decl, " {static_cast(ptr_a), " + attrs.at("lda") + "},\n"); - CutlassPrint(gemm_decl, " {static_cast(ptr_b), " + attrs.at("ldb") + "},\n"); - if (has_bias) { - CutlassPrint(gemm_decl, " {static_cast(ptr_c_bias), 0},\n"); - } else { - CutlassPrint(gemm_decl, " {static_cast(ptr_out), " + attrs.at("ldc") + "},\n"); - } - CutlassPrint(gemm_decl, " {static_cast(ptr_out), " + attrs.at("ldc") + "},\n"); - if (has_bias && !is_gelu) { - CutlassPrint(gemm_decl, " {alpha},\n"); - } else { - // For GeLU, we explicitly specify the scale. - CutlassPrint(gemm_decl, " {alpha, beta},\n"); - } - CutlassPrint(gemm_decl, " split_k_slices};\n"); +} +void AppendGemmExecute(std::ostringstream& gemm_decl, const std::string& kernel) { // Using the arguments, query for extra workspace required for matrix multiplication computation - CutlassPrint(gemm_decl, "size_t workspace_size = Gemm::get_workspace_size(arguments);\n"); + CutlassPrint(gemm_decl, + "size_t workspace_size = " + kernel + "::get_workspace_size(arguments);\n"); // Allocate workspace memory CutlassPrint(gemm_decl, "cutlass::device_memory::allocation workspace(workspace_size);\n"); // Instantiate CUTLASS kernel depending on template - CutlassPrint(gemm_decl, "Gemm gemm_op;\n"); + CutlassPrint(gemm_decl, kernel + " gemm_op;\n"); + // Check the problem size is supported or not CutlassPrint(gemm_decl, "cutlass::Status status = gemm_op.can_implement(arguments);\n"); CutlassPrint(gemm_decl, "CHECK(status == cutlass::Status::kSuccess);\n"); @@ -168,6 +165,72 @@ std::string DenseOp(std::string id, const Str2StrMap& attrs, // Launch initialized CUTLASS kernel CutlassPrint(gemm_decl, "status = gemm_op();\n"); CutlassPrint(gemm_decl, "CHECK(status == cutlass::Status::kSuccess);\n"); +} + +std::string DenseOp(std::string id, const Str2StrMap& attrs, + const std::vector& func_args) { + bool has_bias = false; + bool is_gelu = + attrs.at("op_type").find("cutlass.dense_bias_gelu") != std::string::npos; // fp32 or fp16 + if (attrs.at("op_type") == "cutlass.dense_bias" || + attrs.at("op_type") == "cutlass.dense_bias_relu" || is_gelu) { + has_bias = true; + } + std::ostringstream gemm_decl; + AppendPrologue(gemm_decl, attrs, func_args, "Gemm", has_bias, is_gelu, 0, 0, 1); + + CutlassPrint(gemm_decl, " {static_cast(ptr_a), " + attrs.at("lda") + "},\n"); + CutlassPrint(gemm_decl, " {static_cast(ptr_b), " + attrs.at("ldb") + "},\n"); + if (has_bias) { + CutlassPrint(gemm_decl, " {static_cast(ptr_c_bias), 0},\n"); + } else { + CutlassPrint(gemm_decl, " {static_cast(ptr_out), " + attrs.at("ldc") + "},\n"); + } + CutlassPrint(gemm_decl, " {static_cast(ptr_out), " + attrs.at("ldc") + "},\n"); + if (has_bias && !is_gelu) { + CutlassPrint(gemm_decl, " {alpha},\n"); + } else { + // For GeLU, we explicitly specify the scale. + CutlassPrint(gemm_decl, " {alpha, beta},\n"); + } + CutlassPrint(gemm_decl, " 1};\n"); // split_k_slices + + AppendGemmExecute(gemm_decl, "Gemm"); + return gemm_decl.str(); +} + +std::string BatchMatmulOp(std::string id, const Str2StrMap& attrs, + const std::vector& func_args) { + std::ostringstream gemm_decl; + AppendPrologue(gemm_decl, attrs, func_args, "BatchedGemm", false, false, 1, 1, 2); + + auto get_batch_stride = [&attrs, &func_args](const std::string& name, int arg0_idx, int arg1_idx, + int arg0_axis_idx, int arg1_axis_idx) { + if (attrs.at(name) == kAnyDim) { + return func_args[arg0_idx] + "->shape[" + std::to_string(arg0_axis_idx) + "] * " + + func_args[arg1_idx] + "->shape[" + std::to_string(arg1_axis_idx) + "]"; + } else { + return attrs.at(name); + } + }; + + CutlassPrint(gemm_decl, " {static_cast(ptr_a), " + attrs.at("lda") + "},\n"); + CutlassPrint(gemm_decl, get_batch_stride("batch_stride_A", 0, 0, 1, 2) + ",\n"); + CutlassPrint(gemm_decl, " {static_cast(ptr_b), " + attrs.at("ldb") + "},\n"); + CutlassPrint(gemm_decl, get_batch_stride("batch_stride_B", 1, 1, 1, 2) + ",\n"); + CutlassPrint(gemm_decl, " {static_cast(ptr_out), " + attrs.at("ldc") + "},\n"); + CutlassPrint(gemm_decl, get_batch_stride("batch_stride_C", 0, 1, 1, 1) + ",\n"); + CutlassPrint(gemm_decl, " {static_cast(ptr_out), " + attrs.at("ldc") + "},\n"); + CutlassPrint(gemm_decl, get_batch_stride("batch_stride_C", 0, 1, 1, 1) + ",\n"); + CutlassPrint(gemm_decl, " {alpha, beta},\n"); + + if (attrs.at("batch") == kAnyDim) { + CutlassPrint(gemm_decl, func_args[0] + "->shape[0]" + "};\n"); + } else { + CutlassPrint(gemm_decl, attrs.at("batch") + "};\n"); + } + + AppendGemmExecute(gemm_decl, "BatchedGemm"); return gemm_decl.str(); } @@ -279,6 +342,11 @@ class CodegenCutlass : public MemoizedExprTranslator>, publi {"nn.dense", add_or_bias_add, "multiply", "erf", "multiply", "add", "multiply"}); return GenerateBody(dense_call, "cutlass_dense_bias_gelu", GetArgumentNames(caller), DenseArgs(std::ref(attrs_))); + } else if (pattern_name == "cutlass.batch_matmul") { + const auto* batch_matmul_call = + GetRootCall(callee->body.as(), 0, {"nn.batch_matmul"}); + return GenerateBody(batch_matmul_call, "cutlass_batch_matmul", GetArgumentNames(caller), + BatchMatmulArgs(std::ref(attrs_))); } LOG(FATAL) << "Unknown composite function: " << pattern_name; return {}; @@ -322,6 +390,8 @@ class CodegenCutlass : public MemoizedExprTranslator>, publi if (func_name == "cutlass_dense" || func_name == "cutlass_dense_bias" || func_name == "cutlass_dense_bias_relu" || func_name == "cutlass_dense_bias_gelu") { ret.decl = DenseOp(ext_func_id_, attribute_args, func_args); + } else if (func_name == "cutlass_batch_matmul") { + ret.decl = BatchMatmulOp(ext_func_id_, attribute_args, func_args); } return ret; } @@ -374,6 +444,7 @@ class CutlassModuleCodegen : public CSourceModuleCodegenBase { code_stream_ << "#include \n"; code_stream_ << "#include \n"; code_stream_ << "#include \n"; + code_stream_ << "#include \n"; code_stream_ << "#include \n"; code_stream_ << "#include \n"; diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py index 0927c41981bd..5a1ff8b2c17d 100644 --- a/tests/python/contrib/test_cutlass.py +++ b/tests/python/contrib/test_cutlass.py @@ -56,14 +56,16 @@ def get_ref_vm(mod, params, target="cuda"): return VirtualMachine(vm_exec, dev), dev -def get_output(rt_mod, x): - rt_mod.set_input("data", x) +def get_output(rt_mod, names, inputs): + for name, inp in zip(names, inputs): + rt_mod.set_input(name, inp) rt_mod.run() return rt_mod.get_output(0).asnumpy() -def get_output_vm(vm, x): - return vm.invoke("main", data=x).numpy() +def get_output_vm(vm, names, inputs): + params = dict(zip(names, inputs)) + return vm.invoke("main", **params).numpy() def get_dense_with_shape(data_shape, weight_shape, out_dtype="float16"): @@ -98,6 +100,16 @@ def get_dense_bias_gelu(M, N, K, out_dtype="float16"): return add * bias_add +def get_batch_matmul_with_shape(x_shape, y_shape, out_dtype="float16"): + x = relay.var("x", shape=x_shape, dtype="float16") + y = relay.var("y", shape=y_shape, dtype="float16") + return relay.nn.batch_matmul(x, y, out_dtype=out_dtype) + + +def get_batch_matmul(batch, M, N, K, out_dtype="float16"): + return get_batch_matmul_with_shape((batch, M, K), (batch, N, K), out_dtype="float16") + + def profile_and_build(mod, params, sm, tmp_dir="./tmp", lib_path="compile.so"): mod = partition_for_cutlass(mod) mod, num_cutlass_partition = tune_cutlass_kernels( @@ -123,7 +135,9 @@ def profile_and_build_vm( return VirtualMachine(vm_exec, dev), dev, num_cutlass_partition -def verify(func, M, N, K, ref_target="cuda", sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False): +def verify_dense( + func, M, N, K, ref_target="cuda", sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False +): if not has_cutlass(): return mod = tvm.IRModule.from_expr(func) @@ -151,14 +165,14 @@ def verify(func, M, N, K, ref_target="cuda", sm=80, atol=1e-5, rtol=1e-5, run_be rt_mod_ref, dev = get_ref_vm(mod, params, target=ref_target) x = tvm.nd.array(np_data, device=dev) - out = get_output_vm(rt_mod, x) - ref_out = get_output_vm(rt_mod_ref, x) + out = get_output_vm(rt_mod, ["data"], [x]) + ref_out = get_output_vm(rt_mod_ref, ["data"], [x]) else: rt_mod_ref, dev = get_ref_rt_mod(mod, params, target=ref_target) rt_mod, dev, num_partition = profile_and_build(mod, params, sm) x = tvm.nd.array(np_data, device=dev) - out = get_output(rt_mod, x) - ref_out = get_output(rt_mod_ref, x) + out = get_output(rt_mod, ["data"], [x]) + ref_out = get_output(rt_mod_ref, ["data"], [x]) assert num_partition > 0 np.testing.assert_allclose(out, ref_out, atol=atol, rtol=rtol) @@ -168,29 +182,65 @@ def verify(func, M, N, K, ref_target="cuda", sm=80, atol=1e-5, rtol=1e-5, run_be print("TVM with target %s:" % ref_target, rt_mod_ref.benchmark(dev, number=1, repeat=600)) +def verify_batch_matmul( + func, batch, M, N, K, ref_target="cuda", sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False +): + if not has_cutlass(): + return + mod = tvm.IRModule.from_expr(func) + typ = relay.transform.InferType()(mod)["main"].body.checked_type + use_vm = any(isinstance(s, tvm.tir.Any) for s in typ.shape) + x_np = np.random.uniform(-1, 1, (batch, M, K)).astype("float16") + y_np = np.random.uniform(-1, 1, (batch, N, K)).astype("float16") + + if use_vm: + rt_mod, dev, num_partition = profile_and_build_vm(mod, {}, sm) + rt_mod_ref, dev = get_ref_vm(mod, {}, target=ref_target) + assert num_partition > 0 + x = tvm.nd.array(x_np, device=dev) + y = tvm.nd.array(y_np, device=dev) + out = get_output_vm(rt_mod, ["x", "y"], [x, y]) + ref_out = get_output_vm(rt_mod_ref, ["x", "y"], [x, y]) + else: + rt_mod, dev, num_partition = profile_and_build(mod, {}, sm) + rt_mod_ref, dev = get_ref_rt_mod(mod, {}) + assert num_partition > 0 + + x = tvm.nd.array(x_np, device=dev) + y = tvm.nd.array(y_np, device=dev) + out = get_output(rt_mod, ["x", "y"], [x, y]) + ref_out = get_output(rt_mod_ref, ["x", "y"], [x, y]) + + np.testing.assert_allclose(out, ref_out, atol=atol, rtol=rtol) + + if True: + print("CUTLASS:", rt_mod.benchmark(dev, number=1, repeat=600)) + print("TVM Tensorcore (no tuning):", rt_mod_ref.benchmark(dev, number=1, repeat=600)) + + M = 1820 N = 768 K = 768 def test_dense(): - verify(get_dense(M, N, K), M, N, K) - verify(get_dense(M, N, K, out_dtype="float32"), M, N, K) + verify_dense(get_dense(M, N, K), M, N, K) + verify_dense(get_dense(M, N, K, out_dtype="float32"), M, N, K) def test_dense_bias(): - verify(get_dense_bias(M, N, K), M, N, K) - verify(get_dense_bias(M, N, K, out_dtype="float32"), M, N, K) + verify_dense(get_dense_bias(M, N, K), M, N, K) + verify_dense(get_dense_bias(M, N, K, out_dtype="float32"), M, N, K) def test_dense_bias_relu(): - verify(get_dense_bias_relu(M, N, K), M, N, K) - verify(get_dense_bias_relu(M, N, K, out_dtype="float32"), M, N, K) + verify_dense(get_dense_bias_relu(M, N, K), M, N, K) + verify_dense(get_dense_bias_relu(M, N, K, out_dtype="float32"), M, N, K) def test_dense_bias_gelu(): - verify(get_dense_bias_gelu(M, N, K), M, N, K, atol=1e-3, rtol=1e-3) - verify(get_dense_bias_gelu(M, N, K, out_dtype="float32"), M, N, K, atol=1e-3, rtol=1e-3) + verify_dense(get_dense_bias_gelu(M, N, K), M, N, K, atol=1e-3, rtol=1e-3) + verify_dense(get_dense_bias_gelu(M, N, K, out_dtype="float32"), M, N, K, atol=1e-3, rtol=1e-3) def test_dense_dynamic(): @@ -200,7 +250,7 @@ def test_dense_dynamic(): if has_cublas(): # TVM native fp16 dense (without tensorcore), using fp16 accum, seems to have accuracy issues # Use cublas as a reference - verify( + verify_dense( get_dense_with_shape(data_shape, weight_shape), M, N, @@ -208,7 +258,7 @@ def test_dense_dynamic(): ref_target="cuda -libs=cublas", ) - verify( + verify_dense( get_dense_with_shape(data_shape, weight_shape, out_dtype="float32"), M, N, @@ -218,5 +268,26 @@ def test_dense_dynamic(): ) +def test_batch_matmul(): + batch = 8 + verify_batch_matmul(get_batch_matmul(batch, M, N, K), batch, M, N, K) + verify_batch_matmul(get_batch_matmul(batch, M, N, K, out_dtype="float32"), batch, M, N, K) + + if has_cublas(): + # Test dynamic shape batch_matmul + # AutoTVM does not seem to support it + x_shape = (relay.Any(), relay.Any(), K) + y_shape = (relay.Any(), relay.Any(), K) + + verify_batch_matmul( + get_batch_matmul_with_shape(x_shape, y_shape), + batch, + M, + N, + K, + ref_target="cuda -libs=cublas", + ) + + if __name__ == "__main__": pytest.main([__file__])