[CUTLASS] Support batch_matmul (apache#9439)
* Import batched gemm change

commit cfacfa2
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Mon Nov 1 15:57:49 2021 +0900

    change is_constant pattern to wildcard in gelu pattern

commit 84da943
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Mon Nov 1 05:41:11 2021 +0900

    fixed batch stride C

commit 66e5779
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sun Oct 31 20:47:16 2021 +0900

    refactoring codegen

commit 561daea
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sun Oct 31 20:05:20 2021 +0900

    generated kernel compiled and result match

commit a5740bc
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sun Oct 31 19:36:53 2021 +0900

    partitioning looks good

commit 59112fd
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sun Oct 31 19:01:47 2021 +0900

    [WIP] cutlass batch matmul support

* fixed test

* refactoring

* gelu test fixed

* more refactor

* batch_matmul fp32 accum working

* dynamic batch matmul working

* black

* remove doc TODO
masahi authored and ylc committed Jan 13, 2022
1 parent 099ae14 commit fc04ea8
Showing 7 changed files with 363 additions and 124 deletions.
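As a rough orientation sketch, the flow this change enables looks something like the following, using the entry points visible in the diff (partition_for_cutlass, tune_cutlass_kernels) together with standard Relay APIs. The shapes, the sm value, and the tuple returned by tune_cutlass_kernels are assumptions based on the surrounding TVM code rather than on lines shown here.

# Hypothetical usage sketch: offload nn.batch_matmul to CUTLASS and pick a kernel.
import tvm
from tvm import relay
from tvm.relay.op.contrib.cutlass import partition_for_cutlass
from tvm.contrib.cutlass.build import tune_cutlass_kernels

# TVM's nn.batch_matmul convention: (batch, M, K) x (batch, N, K) -> (batch, M, N).
x = relay.var("x", shape=(8, 32, 64), dtype="float16")
y = relay.var("y", shape=(8, 96, 64), dtype="float16")
mod = tvm.IRModule.from_expr(relay.Function([x, y], relay.nn.batch_matmul(x, y)))

mod = partition_for_cutlass(mod)  # wraps the matmul in a "cutlass.batch_matmul" composite
# Return convention (annotated module, number of CUTLASS partitions) is assumed here.
mod, num_cutlass_partition = tune_cutlass_kernels(mod, sm=80, profile_all=False)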
141 changes: 111 additions & 30 deletions python/tvm/contrib/cutlass/build.py
@@ -83,6 +83,87 @@ def visit_call(self, call):
self.signature["ret_dtype"] = op.ret_type.dtype


def select_gemm_kernel(
cutlass_profiler, MM, KK, NN, out_dtype, batched, profile_all, use_multiprocessing
):
"""Run CUTLASS profiler to select the best kernel, or return the default one for dynamic
workloads."""
if any(isinstance(s, tvm.tir.Any) for s in [MM, KK, NN]):
out = cutlass_profiler.get_default(out_dtype, batched=batched)
logger.info("Picked the default kernel %s", out["name"])
else:
out = cutlass_profiler.profile(
MM,
NN,
KK,
out_dtype,
batched=batched,
profile_all=profile_all,
use_multiprocessing=use_multiprocessing,
)
if profile_all:
logger.info("The best kernel is %s", out["name"])
else:
logger.info("Picked the first kernel found %s", out["name"])
return out


def handle_batch_matmul(
cutlass_profiler, op_type, arg0_shape, arg1_shape, out_dtype, profile_all, use_multiprocessing
):
"""Profile and select a kernel for batch_matmul op workload."""
MM = arg0_shape[1]
KK = arg0_shape[2]
NN = arg1_shape[1]

out = select_gemm_kernel(
cutlass_profiler, MM, KK, NN, out_dtype, True, profile_all, use_multiprocessing
)

if op_type == "cutlass.batch_matmul":
cutlass_op_def = out["opdef"]
else:
raise ValueError("%s pattern is not implemented." % op_type)

return {
"batch": arg0_shape[0],
"batch_stride_A": arg0_shape[1] * arg0_shape[2],
"batch_stride_B": arg1_shape[1] * arg1_shape[2],
"batch_stride_C": arg0_shape[1] * arg1_shape[1],
"cutlass_op_def": cutlass_op_def,
"cutlass_op_name": out["name"],
}


def handle_dense(
cutlass_profiler, op_type, arg0_shape, arg1_shape, out_dtype, profile_all, use_multiprocessing
):
"""Profile and select a kernel for dense op workload."""
MM = arg0_shape[0]
KK = arg0_shape[1]
NN = arg1_shape[0]

out = select_gemm_kernel(
cutlass_profiler, MM, KK, NN, out_dtype, False, profile_all, use_multiprocessing
)

if op_type == "cutlass.dense":
cutlass_op_def = out["opdef"]
elif op_type == "cutlass.dense_bias":
cutlass_op_def = out["opdef_bias"]
elif op_type == "cutlass.dense_bias_relu":
cutlass_op_def = out["opdef_bias_relu"]
elif "cutlass.dense_bias_gelu" in op_type:
cutlass_op_def = out["opdef_bias_gelu"]
else:
raise ValueError("%s pattern is not implemented." % op_type)

return {
"cutlass_op_def": cutlass_op_def,
"cutlass_op_name": out["name"],
}


def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, tmp_dir="./tmp"):
"""Given a module partitioned for CUTLASS offloading, profile each workload to select which
kernels to emit.
@@ -123,41 +204,41 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t
if "cutlass" in fun_name:
num_cutlass_partition += 1
annotator.visit(func)
# call cutlass profiler to find best settings, update attr
new_attrs = {}
out_dtype = annotator.signature["ret_dtype"]
op_type = annotator.signature["op_type"]

new_attrs = {"op_type": op_type}
new_attrs.update(annotator.signature)
for key in func.attrs.keys():
new_attrs[key] = func.attrs[key]
# call profiler
new_attrs.update(func.attrs)
arg0_shape = new_attrs["arg0_shape"]
arg1_shape = new_attrs["arg1_shape"]
MM = arg0_shape[0]
KK = arg0_shape[1]
NN = arg1_shape[0]
out_dtype = annotator.signature["ret_dtype"]
if any(isinstance(s, tvm.tir.Any) for s in [MM, KK, NN]):
out = cutlass_profiler.get_default(out_dtype)
logger.info("Picked the default kernel %s", out["name"])
else:
out = cutlass_profiler.profile(
MM, NN, KK, out_dtype, profile_all, use_multiprocessing

if "batch_matmul" in op_type:
new_attrs.update(
handle_batch_matmul(
cutlass_profiler,
op_type,
arg0_shape,
arg1_shape,
out_dtype,
profile_all,
use_multiprocessing,
)
)
elif "dense" in op_type:
new_attrs.update(
handle_dense(
cutlass_profiler,
op_type,
arg0_shape,
arg1_shape,
out_dtype,
profile_all,
use_multiprocessing,
)
)
if profile_all:
logger.info("The best kernel is %s", out["name"])
else:
logger.info("Picked the first kernel found %s", out["name"])

if new_attrs["op_type"] == "cutlass.dense":
new_attrs["cutlass_op_def"] = out["opdef"]
elif new_attrs["op_type"] == "cutlass.dense_bias":
new_attrs["cutlass_op_def"] = out["opdef_bias"]
elif new_attrs["op_type"] == "cutlass.dense_bias_relu":
new_attrs["cutlass_op_def"] = out["opdef_bias_relu"]
elif "cutlass.dense_bias_gelu" in new_attrs["op_type"]:
new_attrs["cutlass_op_def"] = out["opdef_bias_gelu"]
else:
raise ValueError("%s pattern is not implemented." % new_attrs["op_type"])
new_attrs["cutlass_op_name"] = out["name"]
raise ValueError("%s unsupported composite" % op_type)

if new_attrs["cutlass_op_name"].find("_tn_align") > 0:
new_attrs["lda"] = "K"
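To make the batch-stride bookkeeping in handle_batch_matmul concrete: nn.batch_matmul takes A of shape (batch, M, K) and a transposed B of shape (batch, N, K), so the per-batch element strides handed to the batched GEMM are M*K, N*K, and M*N. A minimal sketch of that arithmetic with an illustrative shape; the helper name is hypothetical and not part of the diff.

def batch_strides(arg0_shape, arg1_shape):
    # arg0: (batch, M, K), arg1: (batch, N, K) -- B is transposed in nn.batch_matmul.
    _, M, K = arg0_shape
    _, N, _ = arg1_shape
    return {
        "batch_stride_A": M * K,  # elements between consecutive A matrices
        "batch_stride_B": N * K,  # elements between consecutive B matrices
        "batch_stride_C": M * N,  # elements between consecutive output matrices
    }

assert batch_strides((8, 32, 64), (8, 96, 64)) == {
    "batch_stride_A": 2048,
    "batch_stride_B": 6144,
    "batch_stride_C": 3072,
}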
15 changes: 6 additions & 9 deletions python/tvm/contrib/cutlass/gemm_operation.py
@@ -174,7 +174,7 @@ def __init__(self):
>"""
self.gemm_template = """
// Gemm operator ${operation_name}
using Operation_${operation_name} = cutlass::gemm::device::Gemm<
using Operation_${operation_name} = cutlass::gemm::device::${kernel_name}<
${element_a}, ${layout_a},
${element_b}, ${layout_b},
${element_c}, ${layout_c},
@@ -189,13 +189,12 @@ def __init__(self):
${stages},
${align_a},
${align_b},
false,
${split_k_serial}
${math_operation}
${residual}
>;
"""

def emit(self, operation, no_beta_scaling=False):
def emit(self, operation, no_beta_scaling=False, batched=False):
"""Instantiate a GEMM kernel from given `operation`."""
warp_shape = [
operation.tile_description.threadblock_shape[idx]
@@ -206,8 +205,6 @@ def emit(self, operation, no_beta_scaling=False):
min(operation.C.alignment * DataTypeSize[operation.C.element], 128)
// DataTypeSize[operation.C.element]
)
residual = ""
complex_transform_tag = "cutlass::ComplexTransform::kNone"
values = {
"operation_name": operation.procedural_name(),
"element_a": DataTypeTag[operation.A.element],
@@ -243,14 +240,14 @@ def __init__(self):
"stages": str(operation.tile_description.stages),
"align_a": str(operation.A.alignment),
"align_b": str(operation.B.alignment),
"transform_a": complex_transform_tag,
"transform_b": complex_transform_tag,
"math_operation": MathOperationTag[
operation.tile_description.math_instruction.math_operation
],
"residual": residual,
}

values["kernel_name"] = "GemmBatched" if batched else "Gemm"
values["split_k_serial"] = "" if batched else "false,"

gemm_template = substitute_template(
self.gemm_template,
{
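The template change above swaps the device-level kernel class and drops the SplitKSerial argument when a batched GEMM is emitted. A standalone sketch of the two substitution values that emit() now fills in; this is an illustration of the logic, not the emitter itself.

def batched_template_values(batched):
    return {
        # cutlass::gemm::device::GemmBatched vs. cutlass::gemm::device::Gemm
        "kernel_name": "GemmBatched" if batched else "Gemm",
        # GemmBatched takes no SplitKSerial template argument, so the slot is left empty.
        "split_k_serial": "" if batched else "false,",
    }

print(batched_template_values(False))  # {'kernel_name': 'Gemm', 'split_k_serial': 'false,'}
print(batched_template_values(True))   # {'kernel_name': 'GemmBatched', 'split_k_serial': ''}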
42 changes: 27 additions & 15 deletions python/tvm/contrib/cutlass/gen_gemm.py
@@ -47,6 +47,7 @@ def create_gemm_operator(
alignment_constraints,
epilogue_functor=EpilogueFunctor.LinearCombination,
swizzling_functor=SwizzlingFunctor.Identity8,
batched=False,
):
"""Exhaustively instantiate all kernels from a given configuration."""
ret = []
@@ -55,6 +56,9 @@ def create_gemm_operator(

element_a, element_b, element_c, element_epilogue = data_type

if batched:
swizzling_functor = SwizzlingFunctor.Batched

for layout in layouts:
for tile_description in tile_descriptions:
for alignment in alignment_constraints:
@@ -109,15 +113,17 @@ def create_gemm_operator(
kernel_emitter = EmitGemmInstance()
op_entry["op"] = op
op_entry["name"] = op.procedural_name()
op_entry["opdef"] = kernel_emitter.emit(op)
op_entry["opdef_bias"] = kernel_emitter.emit(op_bias, no_beta_scaling=True)
op_entry["opdef"] = kernel_emitter.emit(op, batched=batched)
op_entry["opdef_bias"] = kernel_emitter.emit(
op_bias, no_beta_scaling=True, batched=batched
)
op_entry["opdef_bias_relu"] = kernel_emitter.emit(
op_bias_relu, no_beta_scaling=True
op_bias_relu, no_beta_scaling=True, batched=batched
)
op_entry["opdef_bias_gelu"] = kernel_emitter.emit(op_bias_gelu)
op_entry["opdef_bias_gelu"] = kernel_emitter.emit(op_bias_gelu, batched=batched)
op_entry["src"] = profiler_emitter.emit(
op.procedural_name(),
op_entry["opdef"],
kernel_emitter.emit(op, batched=False),
DataTypeTag[element_a],
DataTypeTag[element_b],
DataTypeTag[element_c],
@@ -128,7 +134,9 @@ def create_gemm_operator(
return ret


def generate_tensor_op_common(math_instructions, alignment_constraints, get_tile_descriptions):
def generate_tensor_op_common(
math_instructions, alignment_constraints, get_tile_descriptions, batched=False
):
"""Common kernel generator to be used by archtecture specific generators."""
ops = []
layouts = [
@@ -143,14 +151,16 @@ def generate_tensor_op_common(
math_inst.element_accumulator,
]

out = create_gemm_operator(layouts, tile_descriptions, data_type, alignment_constraints)
out = create_gemm_operator(
layouts, tile_descriptions, data_type, alignment_constraints, batched=batched
)

ops.extend(out)

return ops


def generate_sm75_tensor_op_1688(out_dtype):
def generate_sm75_tensor_op_1688(out_dtype, batched=False):
"""Generate GEMM kernels for Turing."""
assert out_dtype in ["float32", "float16"]
math_instructions = {
@@ -192,11 +202,11 @@ def get_tile_descriptions(math_inst):
]

return generate_tensor_op_common(
math_instructions, alignment_constraints, get_tile_descriptions
math_instructions, alignment_constraints, get_tile_descriptions, batched
)


def generate_sm80_tensor_op_16816(out_dtype):
def generate_sm80_tensor_op_16816(out_dtype, batched=False):
"""Generate GEMM kernels for Ampere."""
assert out_dtype in ["float32", "float16"]
math_instructions = {
@@ -250,7 +260,7 @@ def get_tile_descriptions(math_inst):
]

return generate_tensor_op_common(
math_instructions, alignment_constraints, get_tile_descriptions
math_instructions, alignment_constraints, get_tile_descriptions, batched
)


@@ -350,25 +360,27 @@ def check_align(self, op_name, M):
return False
return True

def get_default(self, out_dtype):
def get_default(self, out_dtype, batched=False):
"""Return the default kernel for the requested architecture.
For now, the default kernel is picked arbitrarily.
"""
ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype)
ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, batched)
default_kernel_name = DEFAULT_KERNELS[self.sm][out_dtype]
filtered = list(filter(lambda op: op["name"] == default_kernel_name, ops))
assert len(filtered) == 1
return filtered[0]

def profile(self, M, N, K, out_dtype, profile_all=True, use_multiprocessing=False):
def profile(
self, M, N, K, out_dtype, profile_all=True, use_multiprocessing=False, batched=False
):
"""Profile and select the best kernel from candidate kernels.
If profile_all is False, return immediately after the first applicable kernel is found.
If use_multiprocessing is True, compile all profiler executables in parallel.
"""
if (M, N, K) in self.cache:
return self.cache[(M, N, K)]

ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype)
ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, batched)
ops = list(filter(lambda op: self.check_align(op["name"], M), ops))

for op in ops:
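With the batched flag threaded through the generators and the profiler, batched kernel definitions come out of the same entry points. A rough sketch, assuming the generator module imports cleanly outside a full profiling run; the assertion only inspects the emitted source strings.

from tvm.contrib.cutlass.gen_gemm import generate_sm80_tensor_op_16816

ops = generate_sm80_tensor_op_16816("float16", batched=True)
# Each op definition should now instantiate cutlass::gemm::device::GemmBatched,
# while the profiler source ("src") is still emitted with batched=False.
assert all("GemmBatched" in op["opdef"] for op in ops)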
2 changes: 2 additions & 0 deletions python/tvm/contrib/cutlass/library.py
@@ -160,13 +160,15 @@ class SwizzlingFunctor(enum.Enum):
Identity2 = enum_auto()
Identity4 = enum_auto()
Identity8 = enum_auto()
Batched = enum_auto()


SwizzlingFunctorTag = {
SwizzlingFunctor.Identity1: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>",
SwizzlingFunctor.Identity2: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>",
SwizzlingFunctor.Identity4: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>",
SwizzlingFunctor.Identity8: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>",
SwizzlingFunctor.Batched: "cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle",
}


11 changes: 8 additions & 3 deletions python/tvm/relay/op/contrib/cutlass.py
@@ -20,13 +20,13 @@


def make_gelu_pattern(bias_out, out_dtype="float16"):
mul = is_op("multiply")(bias_out, is_constant())
mul = is_op("multiply")(bias_out, is_constant() | wildcard())
if out_dtype == "float16":
erf = is_op("cast")(is_op("erf")(is_op("cast")(mul)))
else:
erf = is_op("erf")(mul)
mul_half = is_op("multiply")(erf, is_constant())
add = is_op("add")(mul_half, is_constant())
mul_half = is_op("multiply")(erf, is_constant() | wildcard())
add = is_op("add")(mul_half, is_constant() | wildcard())
return is_op("multiply")(add, bias_out)


@@ -51,6 +51,10 @@ def make_gemm_pattern(with_bias=True, with_act=None, out_dtype="float16"):
return make_gelu_pattern(gemm_out, out_dtype)


def make_batch_matmul_pattern():
return is_op("nn.batch_matmul")(wildcard(), wildcard())


def partition_for_cutlass(mod):
"""Partition the input module into CUTLASS-supported subgraphs."""
dense_pat = ("cutlass.dense", make_gemm_pattern(False, None))
@@ -67,6 +71,7 @@ def partition_for_cutlass(mod):
dense_bias_relu_pat,
dense_bias_pat,
dense_pat,
("cutlass.batch_matmul", make_batch_matmul_pattern()),
]
mod = transform.MergeComposite(cutlass_patterns)(mod)
mod = transform.AnnotateTarget(["cutlass"])(mod)
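On the Relay side the new composite is deliberately permissive: make_batch_matmul_pattern matches any nn.batch_matmul call, and the gelu pattern now tolerates non-constant scaling operands because is_constant() was relaxed to is_constant() | wildcard(). A small sketch of what the batch_matmul pattern matches, using the Relay dataflow pattern API; the example shapes are illustrative.

from tvm import relay
from tvm.relay.dataflow_pattern import is_op, wildcard

# Same structure as make_batch_matmul_pattern() above.
pat = is_op("nn.batch_matmul")(wildcard(), wildcard())

x = relay.var("x", shape=(8, 32, 64), dtype="float16")
y = relay.var("y", shape=(8, 96, 64), dtype="float16")
print(pat.match(relay.nn.batch_matmul(x, y)))  # True: any batch_matmul can be offloaded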