[CUTLASS] Support batch_matmul #9439

Merged: 9 commits, Nov 4, 2021
141 changes: 111 additions & 30 deletions python/tvm/contrib/cutlass/build.py
@@ -83,6 +83,87 @@ def visit_call(self, call):
self.signature["ret_dtype"] = op.ret_type.dtype


def select_gemm_kernel(
cutlass_profiler, MM, KK, NN, out_dtype, batched, profile_all, use_multiprocessing
):
"""Run CUTLASS profiler to select the best kernel, or return the default one for dynamic
workloads."""
if any(isinstance(s, tvm.tir.Any) for s in [MM, KK, NN]):
out = cutlass_profiler.get_default(out_dtype, batched=batched)
logger.info("Picked the default kernel %s", out["name"])
else:
out = cutlass_profiler.profile(
MM,
NN,
KK,
out_dtype,
batched=batched,
profile_all=profile_all,
use_multiprocessing=use_multiprocessing,
)
if profile_all:
logger.info("The best kernel is %s", out["name"])
else:
logger.info("Picked the first kernel found %s", out["name"])
return out


def handle_batch_matmul(
cutlass_profiler, op_type, arg0_shape, arg1_shape, out_dtype, profile_all, use_multiprocessing
):
"""Profile and select a kernel for batch_matmul op workload."""
MM = arg0_shape[1]
KK = arg0_shape[2]
NN = arg1_shape[1]

out = select_gemm_kernel(
cutlass_profiler, MM, KK, NN, out_dtype, True, profile_all, use_multiprocessing
)

if op_type == "cutlass.batch_matmul":
cutlass_op_def = out["opdef"]
else:
raise ValueError("%s pattern is not implemented." % op_type)

return {
"batch": arg0_shape[0],
"batch_stride_A": arg0_shape[1] * arg0_shape[2],
"batch_stride_B": arg1_shape[1] * arg1_shape[2],
"batch_stride_C": arg0_shape[1] * arg1_shape[1],
"cutlass_op_def": cutlass_op_def,
"cutlass_op_name": out["name"],
}


def handle_dense(
cutlass_profiler, op_type, arg0_shape, arg1_shape, out_dtype, profile_all, use_multiprocessing
):
"""Profile and select a kernel for dense op workload."""
MM = arg0_shape[0]
KK = arg0_shape[1]
NN = arg1_shape[0]

out = select_gemm_kernel(
cutlass_profiler, MM, KK, NN, out_dtype, False, profile_all, use_multiprocessing
)

if op_type == "cutlass.dense":
cutlass_op_def = out["opdef"]
elif op_type == "cutlass.dense_bias":
cutlass_op_def = out["opdef_bias"]
elif op_type == "cutlass.dense_bias_relu":
cutlass_op_def = out["opdef_bias_relu"]
elif "cutlass.dense_bias_gelu" in op_type:
cutlass_op_def = out["opdef_bias_gelu"]
else:
raise ValueError("%s pattern is not implemented." % op_type)

return {
"cutlass_op_def": cutlass_op_def,
"cutlass_op_name": out["name"],
}


def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, tmp_dir="./tmp"):
"""Given a module partitioned for CUTLASS offloading, profile each workload to select which
kernels to emit.
@@ -123,41 +204,41 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t
if "cutlass" in fun_name:
num_cutlass_partition += 1
annotator.visit(func)
# call cutlass profiler to find best settings, update attr
new_attrs = {}
out_dtype = annotator.signature["ret_dtype"]
op_type = annotator.signature["op_type"]

new_attrs = {"op_type": op_type}
new_attrs.update(annotator.signature)
for key in func.attrs.keys():
new_attrs[key] = func.attrs[key]
# call profiler
new_attrs.update(func.attrs)
arg0_shape = new_attrs["arg0_shape"]
arg1_shape = new_attrs["arg1_shape"]
MM = arg0_shape[0]
KK = arg0_shape[1]
NN = arg1_shape[0]
out_dtype = annotator.signature["ret_dtype"]
if any(isinstance(s, tvm.tir.Any) for s in [MM, KK, NN]):
out = cutlass_profiler.get_default(out_dtype)
logger.info("Picked the default kernel %s", out["name"])
else:
out = cutlass_profiler.profile(
MM, NN, KK, out_dtype, profile_all, use_multiprocessing

if "batch_matmul" in op_type:
new_attrs.update(
handle_batch_matmul(
cutlass_profiler,
op_type,
arg0_shape,
arg1_shape,
out_dtype,
profile_all,
use_multiprocessing,
)
)
elif "dense" in op_type:
new_attrs.update(
handle_dense(
cutlass_profiler,
op_type,
arg0_shape,
arg1_shape,
out_dtype,
profile_all,
use_multiprocessing,
)
)
if profile_all:
logger.info("The best kernel is %s", out["name"])
else:
logger.info("Picked the first kernel found %s", out["name"])

if new_attrs["op_type"] == "cutlass.dense":
new_attrs["cutlass_op_def"] = out["opdef"]
elif new_attrs["op_type"] == "cutlass.dense_bias":
new_attrs["cutlass_op_def"] = out["opdef_bias"]
elif new_attrs["op_type"] == "cutlass.dense_bias_relu":
new_attrs["cutlass_op_def"] = out["opdef_bias_relu"]
elif "cutlass.dense_bias_gelu" in new_attrs["op_type"]:
new_attrs["cutlass_op_def"] = out["opdef_bias_gelu"]
else:
raise ValueError("%s pattern is not implemented." % new_attrs["op_type"])
new_attrs["cutlass_op_name"] = out["name"]
raise ValueError("%s unsupported composite" % op_type)

if new_attrs["cutlass_op_name"].find("_tn_align") > 0:
new_attrs["lda"] = "K"
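
As an aside, the batch_matmul handler above maps TVM's nn.batch_matmul layout (A: (batch, M, K), B: (batch, N, K)) onto a single batched GEMM problem, so the batch strides are just per-batch element counts. A minimal sketch with hypothetical shapes, mirroring the dictionary returned by handle_batch_matmul:

    # Hypothetical shapes following nn.batch_matmul: A is (batch, M, K), B is (batch, N, K).
    arg0_shape = (16, 128, 64)   # A
    arg1_shape = (16, 256, 64)   # B

    attrs = {
        "batch": arg0_shape[0],                           # 16
        "batch_stride_A": arg0_shape[1] * arg0_shape[2],  # M * K = 8192
        "batch_stride_B": arg1_shape[1] * arg1_shape[2],  # N * K = 16384
        "batch_stride_C": arg0_shape[1] * arg1_shape[1],  # M * N = 32768
    }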
15 changes: 6 additions & 9 deletions python/tvm/contrib/cutlass/gemm_operation.py
@@ -174,7 +174,7 @@ def __init__(self):
>"""
self.gemm_template = """
// Gemm operator ${operation_name}
using Operation_${operation_name} = cutlass::gemm::device::Gemm<
using Operation_${operation_name} = cutlass::gemm::device::${kernel_name}<
${element_a}, ${layout_a},
${element_b}, ${layout_b},
${element_c}, ${layout_c},
@@ -189,13 +189,12 @@ def __init__(self):
${stages},
${align_a},
${align_b},
false,
${split_k_serial}
${math_operation}
${residual}
>;
"""

def emit(self, operation, no_beta_scaling=False):
def emit(self, operation, no_beta_scaling=False, batched=False):
"""Instantiate a GEMM kernel from given `operation`."""
warp_shape = [
operation.tile_description.threadblock_shape[idx]
@@ -206,8 +205,6 @@ def emit(self, operation, no_beta_scaling=False):
min(operation.C.alignment * DataTypeSize[operation.C.element], 128)
// DataTypeSize[operation.C.element]
)
residual = ""
complex_transform_tag = "cutlass::ComplexTransform::kNone"
values = {
"operation_name": operation.procedural_name(),
"element_a": DataTypeTag[operation.A.element],
@@ -243,14 +240,14 @@ def emit(self, operation, no_beta_scaling=False, batched=False):
"stages": str(operation.tile_description.stages),
"align_a": str(operation.A.alignment),
"align_b": str(operation.B.alignment),
"transform_a": complex_transform_tag,
"transform_b": complex_transform_tag,
"math_operation": MathOperationTag[
operation.tile_description.math_instruction.math_operation
],
"residual": residual,
}

values["kernel_name"] = "GemmBatched" if batched else "Gemm"
values["split_k_serial"] = "" if batched else "false,"

gemm_template = substitute_template(
self.gemm_template,
{
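
The emitter keeps a single C++ template and only swaps the device-level class: ${kernel_name} becomes GemmBatched and the split-K argument is dropped when batched=True, presumably because the batched device template does not take one. A small sketch of the substitution, assuming substitute_template performs plain ${key} replacement (the render helper below is hypothetical):

    import re

    def render(template, values):
        # Hypothetical stand-in for substitute_template(): plain ${key} replacement.
        return re.sub(r"\$\{(\w+)\}", lambda m: values[m.group(1)], template)

    header = "using Op = cutlass::gemm::device::${kernel_name}<"
    print(render(header, {"kernel_name": "GemmBatched"}))
    # using Op = cutlass::gemm::device::GemmBatched<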
42 changes: 27 additions & 15 deletions python/tvm/contrib/cutlass/gen_gemm.py
@@ -47,6 +47,7 @@ def create_gemm_operator(
alignment_constraints,
epilogue_functor=EpilogueFunctor.LinearCombination,
swizzling_functor=SwizzlingFunctor.Identity8,
batched=False,
):
"""Exhaustively instantiate all kernels from a given configuration."""
ret = []
@@ -55,6 +56,9 @@

element_a, element_b, element_c, element_epilogue = data_type

if batched:
swizzling_functor = SwizzlingFunctor.Batched

for layout in layouts:
for tile_description in tile_descriptions:
for alignment in alignment_constraints:
@@ -109,15 +113,17 @@
kernel_emitter = EmitGemmInstance()
op_entry["op"] = op
op_entry["name"] = op.procedural_name()
op_entry["opdef"] = kernel_emitter.emit(op)
op_entry["opdef_bias"] = kernel_emitter.emit(op_bias, no_beta_scaling=True)
op_entry["opdef"] = kernel_emitter.emit(op, batched=batched)
op_entry["opdef_bias"] = kernel_emitter.emit(
op_bias, no_beta_scaling=True, batched=batched
)
op_entry["opdef_bias_relu"] = kernel_emitter.emit(
op_bias_relu, no_beta_scaling=True
op_bias_relu, no_beta_scaling=True, batched=batched
)
op_entry["opdef_bias_gelu"] = kernel_emitter.emit(op_bias_gelu)
op_entry["opdef_bias_gelu"] = kernel_emitter.emit(op_bias_gelu, batched=batched)
op_entry["src"] = profiler_emitter.emit(
op.procedural_name(),
op_entry["opdef"],
kernel_emitter.emit(op, batched=False),
DataTypeTag[element_a],
DataTypeTag[element_b],
DataTypeTag[element_c],
Expand All @@ -128,7 +134,9 @@ def create_gemm_operator(
return ret


def generate_tensor_op_common(math_instructions, alignment_constraints, get_tile_descriptions):
def generate_tensor_op_common(
math_instructions, alignment_constraints, get_tile_descriptions, batched=False
):
"""Common kernel generator to be used by archtecture specific generators."""
ops = []
layouts = [
@@ -143,14 +151,16 @@ def generate_tensor_op_common(math_instructions, alignment_constraints, get_tile
math_inst.element_accumulator,
]

out = create_gemm_operator(layouts, tile_descriptions, data_type, alignment_constraints)
out = create_gemm_operator(
layouts, tile_descriptions, data_type, alignment_constraints, batched=batched
)

ops.extend(out)

return ops


def generate_sm75_tensor_op_1688(out_dtype):
def generate_sm75_tensor_op_1688(out_dtype, batched=False):
"""Generate GEMM kernels for Turing."""
assert out_dtype in ["float32", "float16"]
math_instructions = {
@@ -192,11 +202,11 @@ def get_tile_descriptions(math_inst):
]

return generate_tensor_op_common(
math_instructions, alignment_constraints, get_tile_descriptions
math_instructions, alignment_constraints, get_tile_descriptions, batched
)


def generate_sm80_tensor_op_16816(out_dtype):
def generate_sm80_tensor_op_16816(out_dtype, batched=False):
"""Generate GEMM kernels for Ampere."""
assert out_dtype in ["float32", "float16"]
math_instructions = {
@@ -250,7 +260,7 @@ def get_tile_descriptions(math_inst):
]

return generate_tensor_op_common(
math_instructions, alignment_constraints, get_tile_descriptions
math_instructions, alignment_constraints, get_tile_descriptions, batched
)


@@ -350,25 +360,27 @@ def check_align(self, op_name, M):
return False
return True

def get_default(self, out_dtype):
def get_default(self, out_dtype, batched=False):
"""Return the default kernel for the requested architecture.
For now, the default kernel is picked arbitrarily.
"""
ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype)
ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, batched)
default_kernel_name = DEFAULT_KERNELS[self.sm][out_dtype]
filtered = list(filter(lambda op: op["name"] == default_kernel_name, ops))
assert len(filtered) == 1
return filtered[0]

def profile(self, M, N, K, out_dtype, profile_all=True, use_multiprocessing=False):
def profile(
self, M, N, K, out_dtype, profile_all=True, use_multiprocessing=False, batched=False
):
"""Profile and select the best kernel from candidate kernels.
If profile_all is False, return immediately after the first applicable kernel is found.
If use_multiprocessing is True, compile all profiler executables in parallel.
"""
if (M, N, K) in self.cache:
return self.cache[(M, N, K)]

ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype)
ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, batched)
ops = list(filter(lambda op: self.check_align(op["name"], M), ops))

for op in ops:
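
With the batched flag threaded through generate_tensor_op_common, both architecture generators can enumerate GemmBatched kernel definitions next to the existing Gemm ones. A hedged usage sketch (the generator name and dictionary keys come from this file; the dtype and printed fields are illustrative):

    from tvm.contrib.cutlass.gen_gemm import generate_sm80_tensor_op_16816

    # Enumerate Ampere fp16 kernels whose "opdef" uses cutlass::gemm::device::GemmBatched.
    ops = generate_sm80_tensor_op_16816("float16", batched=True)
    print(len(ops), ops[0]["name"])
    print(ops[0]["opdef"][:80])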
2 changes: 2 additions & 0 deletions python/tvm/contrib/cutlass/library.py
@@ -160,13 +160,15 @@ class SwizzlingFunctor(enum.Enum):
Identity2 = enum_auto()
Identity4 = enum_auto()
Identity8 = enum_auto()
Batched = enum_auto()


SwizzlingFunctorTag = {
SwizzlingFunctor.Identity1: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>",
SwizzlingFunctor.Identity2: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<2>",
SwizzlingFunctor.Identity4: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>",
SwizzlingFunctor.Identity8: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>",
SwizzlingFunctor.Batched: "cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle",
}


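
The new enum member maps straight to CUTLASS's batched identity threadblock swizzle via SwizzlingFunctorTag; a quick check of the table (import path inferred from this file's location):

    from tvm.contrib.cutlass.library import SwizzlingFunctor, SwizzlingFunctorTag

    print(SwizzlingFunctorTag[SwizzlingFunctor.Batched])
    # cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle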
11 changes: 8 additions & 3 deletions python/tvm/relay/op/contrib/cutlass.py
@@ -20,13 +20,13 @@


def make_gelu_pattern(bias_out, out_dtype="float16"):
mul = is_op("multiply")(bias_out, is_constant())
mul = is_op("multiply")(bias_out, is_constant() | wildcard())
if out_dtype == "float16":
erf = is_op("cast")(is_op("erf")(is_op("cast")(mul)))
else:
erf = is_op("erf")(mul)
mul_half = is_op("multiply")(erf, is_constant())
add = is_op("add")(mul_half, is_constant())
mul_half = is_op("multiply")(erf, is_constant() | wildcard())
add = is_op("add")(mul_half, is_constant() | wildcard())
return is_op("multiply")(add, bias_out)


@@ -51,6 +51,10 @@ def make_gemm_pattern(with_bias=True, with_act=None, out_dtype="float16"):
return make_gelu_pattern(gemm_out, out_dtype)


def make_batch_matmul_pattern():
return is_op("nn.batch_matmul")(wildcard(), wildcard())


def partition_for_cutlass(mod):
"""Partition the input module into CUTLASS-supported subgraphs."""
dense_pat = ("cutlass.dense", make_gemm_pattern(False, None))
@@ -67,6 +71,7 @@ def partition_for_cutlass(mod):
dense_bias_relu_pat,
dense_bias_pat,
dense_pat,
("cutlass.batch_matmul", make_batch_matmul_pattern()),
]
mod = transform.MergeComposite(cutlass_patterns)(mod)
mod = transform.AnnotateTarget(["cutlass"])(mod)
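
Putting the pieces together, the new pattern lets a Relay nn.batch_matmul be offloaded and tuned end to end. A minimal sketch, assuming tune_cutlass_kernels is importable from tvm.contrib.cutlass.build and returns the annotated module together with the number of CUTLASS partitions:

    import tvm
    from tvm import relay
    from tvm.relay.op.contrib.cutlass import partition_for_cutlass
    from tvm.contrib.cutlass.build import tune_cutlass_kernels

    # Hypothetical fp16 workload: A is (batch, M, K) = (8, 128, 64), B is (batch, N, K) = (8, 256, 64).
    a = relay.var("a", shape=(8, 128, 64), dtype="float16")
    b = relay.var("b", shape=(8, 256, 64), dtype="float16")
    mod = tvm.IRModule.from_expr(relay.Function([a, b], relay.nn.batch_matmul(a, b)))

    mod = partition_for_cutlass(mod)
    # sm=80 targets Ampere; profile_all=False takes the first applicable kernel.
    mod, num_cutlass_partition = tune_cutlass_kernels(mod, sm=80, profile_all=False)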