partitioning looks good
masahi committed Nov 3, 2021
1 parent 59112fd commit a5740bc
Showing 4 changed files with 16 additions and 8 deletions.
10 changes: 7 additions & 3 deletions python/tvm/contrib/cutlass/build.py
@@ -111,18 +111,22 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t
                 NN,
                 KK,
                 annotator.signature["ret_dtype"],
-                profile_all,
-                use_multiprocessing,
+                batched=True,
+                profile_all=profile_all,
+                use_multiprocessing=use_multiprocessing,
             )
             new_attrs["batch"] = arg0_shape[0]
+            new_attrs["batch_stride_A"] = arg0_shape[1] * arg0_shape[2]
+            new_attrs["batch_stride_B"] = arg1_shape[1] * arg1_shape[2]
+            new_attrs["batch_stride_C"] = arg0_shape[1] * arg1_shape[2]
         else:
             MM = arg0_shape[0]
             KK = arg0_shape[1]
             NN = arg1_shape[0]
             out = cutlass_profiler.profile(
                 MM, NN, KK, annotator.signature["ret_dtype"], profile_all, use_multiprocessing
             )
-        if new_attrs["op_type"] == "cutlass.dense":
+        if new_attrs["op_type"] in ["cutlass.dense", "cutlass.batch_matmul"]:
             new_attrs["cutlass_op_def"] = out["opdef"]
         elif new_attrs["op_type"] == "cutlass.dense_bias":
             new_attrs["cutlass_op_def"] = out["opdef_bias"]
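For context on the stride attributes recorded above: Relay's `nn.batch_matmul` takes `A` of shape `(batch, M, K)` and `B` of shape `(batch, N, K)` and computes `C[i] = A[i] @ B[i].T`, so each per-batch stride is just the element count of one 2-D slice. A minimal, standalone sketch of that arithmetic (`batch_matmul_strides` is an illustrative helper, not part of this patch):

```python
# Illustrative only: reproduces the stride arithmetic from build.py above,
# under Relay's nn.batch_matmul convention A: (batch, M, K), B: (batch, N, K).
def batch_matmul_strides(a_shape, b_shape):
    batch, M, K = a_shape
    _, N, _ = b_shape
    return {
        "batch": batch,
        "batch_stride_A": M * K,  # arg0_shape[1] * arg0_shape[2]
        "batch_stride_B": N * K,  # arg1_shape[1] * arg1_shape[2]
        "batch_stride_C": M * N,  # arg0_shape[1] * arg1_shape[2]
    }

print(batch_matmul_strides((8, 128, 64), (8, 256, 64)))
# {'batch': 8, 'batch_stride_A': 8192, 'batch_stride_B': 16384, 'batch_stride_C': 32768}
```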
2 changes: 1 addition & 1 deletion python/tvm/contrib/cutlass/gen_gemm.py
@@ -141,7 +141,7 @@ def generate_tensor_op_common(math_instructions, alignment_constraints, get_tile
             math_inst.element_accumulator,
         ]

-        out = create_gemm_operator(layouts, tile_descriptions, data_type, alignment_constraints, batched)
+        out = create_gemm_operator(layouts, tile_descriptions, data_type, alignment_constraints, batched=batched)

         ops.extend(out)

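A small illustration of why the call site switches from a positional flag to a keyword argument (`create_op` is a hypothetical stand-in, not the real `create_gemm_operator`): if the callee ever gains another parameter ahead of `batched`, a bare positional flag silently binds to the wrong slot, while the keyword form keeps its meaning.

```python
# Hypothetical stand-in: suppose a `swizzle` parameter were later inserted
# ahead of `batched`. A positional True would land on swizzle, not batched.
def create_op(layouts, tiles, dtype, alignments, swizzle=None, batched=False):
    return {"swizzle": swizzle, "batched": batched}

assert create_op([], [], "f16", [], True)["batched"] is False      # positional: wrong slot
assert create_op([], [], "f16", [], batched=True)["batched"] is True  # keyword: intended
```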
6 changes: 3 additions & 3 deletions python/tvm/relay/op/contrib/cutlass.py
@@ -51,8 +51,8 @@ def make_gemm_pattern(with_bias=True, with_act=None, out_dtype="float16"):
         return make_gelu_pattern(gemm_out, out_dtype)


-def make_batched_matmul_pattern():
-    return is_op("nn.batched_matmul")(wildcard(), wildcard())
+def make_batch_matmul_pattern():
+    return is_op("nn.batch_matmul")(wildcard(), wildcard())


 def partition_for_cutlass(mod):
@@ -71,7 +71,7 @@ def partition_for_cutlass(mod):
         dense_bias_relu_pat,
         dense_bias_pat,
         dense_pat,
-        make_batched_matmul_pattern()
+        ("cutlass.batch_matmul", make_batch_matmul_pattern())
     ]
     mod = transform.MergeComposite(cutlass_patterns)(mod)
     mod = transform.AnnotateTarget(["cutlass"])(mod)
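A self-contained sketch of how `MergeComposite` consumes the `(name, pattern)` pair registered above, assuming a TVM build with the Relay dataflow-pattern API; the shapes and dtypes are illustrative. The name becomes the function's `Composite` attribute, which is what the `op_type` checks in `build.py` key on — the bare pattern appended by the old code carried no such name.

```python
import tvm
from tvm import relay
from tvm.relay.dataflow_pattern import is_op, wildcard

# A tiny module containing one nn.batch_matmul to be partitioned.
x = relay.var("x", shape=(8, 128, 64), dtype="float16")
y = relay.var("y", shape=(8, 256, 64), dtype="float16")
mod = tvm.IRModule.from_expr(relay.nn.batch_matmul(x, y))

# MergeComposite expects (name, pattern) pairs.
patterns = [("cutlass.batch_matmul", is_op("nn.batch_matmul")(wildcard(), wildcard()))]
mod = relay.transform.MergeComposite(patterns)(mod)
print(mod)  # the matmul is now wrapped in a composite function
```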
6 changes: 5 additions & 1 deletion tests/python/contrib/test_cutlass.py
@@ -77,6 +77,9 @@ def profile_and_build(mod, params, sm, tmp_dir="./tmp", lib_path="compile.so"):
     mod, num_cutlass_partition = tune_cutlass_kernels(
         mod, sm, profile_all=False, use_multiprocessing=False, tmp_dir=tmp_dir
     )
+    print(mod)
+    return None, None, None
+
     with tvm.transform.PassContext(opt_level=3):
         lib = relay.build(mod, target="cuda", params=params)
         lib = build_cutlass_kernels(lib, sm, tmp_dir, lib_path)
@@ -120,8 +123,9 @@ def verify_batch_matmul(func, batch, M, N, K, sm=80, atol=1e-5, rtol=1e-5, run_b
     x_np = np.random.uniform(-1, 1, (batch, M, K)).astype("float16")
     y_np = np.random.uniform(-1, 1, (batch, N, K)).astype("float16")

-    rt_mod_ref, dev = get_ref_rt_mod(mod, {})
     rt_mod, dev, num_partition = profile_and_build(mod, {}, sm)
+    return
+    rt_mod_ref, dev = get_ref_rt_mod(mod, {})
     assert num_partition > 0

     x = tvm.nd.array(x_np, device=dev)
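The `print(mod)` and early `return` lines appear to be temporary debugging aids, consistent with the commit message checking that partitioning works. Once removed, the test compares CUTLASS output against a reference build; a NumPy-only sketch of the semantics being checked (`ref_batch_matmul` is an illustrative name):

```python
import numpy as np

# nn.batch_matmul contracts the last axis of both operands:
# C[b] = A[b] @ B[b].T, with A of shape (batch, M, K) and B of shape (batch, N, K).
def ref_batch_matmul(x_np, y_np):
    return np.einsum("bmk,bnk->bmn", x_np, y_np)

x_np = np.random.uniform(-1, 1, (2, 4, 3)).astype("float16")
y_np = np.random.uniform(-1, 1, (2, 5, 3)).astype("float16")
assert ref_batch_matmul(x_np, y_np).shape == (2, 4, 5)
```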
