From 2a1da3a6ff014311ffdb44daebe6523b8634024d Mon Sep 17 00:00:00 2001 From: Josh Fromm Date: Thu, 2 Mar 2023 13:58:55 -0800 Subject: [PATCH] Importer improvements such that full flow works end to end --- python/tvm/octo/compile.py | 25 +++-- python/tvm/relax/frontend/onnx_frontend.py | 107 ++++++++++++--------- 2 files changed, 71 insertions(+), 61 deletions(-) diff --git a/python/tvm/octo/compile.py b/python/tvm/octo/compile.py index 10d0681f9f..a78b7aa856 100644 --- a/python/tvm/octo/compile.py +++ b/python/tvm/octo/compile.py @@ -21,6 +21,7 @@ from typing import Union, Optional, Dict, List import tvm from tvm import relax +from tvm.relax.backend.contrib.cutlass import partition_for_cutlass from .utils import get_cuda_target, get_llvm_target from .octo_model import OctoModel @@ -132,24 +133,22 @@ def offload_cutlass(mod: tvm.IRModule, target: tvm.target.Target) -> tvm.IRModul # Extract the sm version of the current target. assert target.arch, "Target architecture must be specified." sm = int(target.arch.split("_")[1]) + # Cutlass only has support up to sm80, future sms will work with + # earlier kernels though. + if sm > 80: + sm = 80 + + # Apply partitioning to offload patterns to cutlass. + mod = partition_for_cutlass(mod) # Construct CUTLASS codegen pass. cutlass_codegen_pass = relax.transform.RunCodegen( {"cutlass": {"sm": sm, "find_first_valid": True}} ) - # Construct pattern identification pass. - # TODO(jwfromm) rebase on cutlass pattern language - - # Run passes on input module. - seq = tvm.transform.Sequential( - [ - # relax.transform.FuseOpPattern(patterns, annotate_codegen=True), - cutlass_codegen_pass - ] - ) - - return seq(mod) + # Generate code for matched cutlass kernels. + mod = cutlass_codegen_pass(mod) + return mod def compile( @@ -200,7 +199,7 @@ def compile( # Match subgraphs that can be offloaded to cutlass and offload them. # TODO(jwfromm) Currently doesnt work, get one e2e example. - # offload_cutlass(relax_mod, target) + relax_mod = offload_cutlass(relax_mod, target) # Perform legalization to lower Relax operators. relax_mod = relax.transform.LegalizeOps()(relax_mod) diff --git a/python/tvm/relax/frontend/onnx_frontend.py b/python/tvm/relax/frontend/onnx_frontend.py index ea585ca9bc..8354fc3b6e 100644 --- a/python/tvm/relax/frontend/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx_frontend.py @@ -288,7 +288,11 @@ def _impl_v13(cls, bb, inputs, attr): if alpha is not None: A = bb.normalize(relax.op.multiply(A, relax.const(alpha, dtype=dtype))) - Y = bb.emit_te(topi.matmul, A, B, transA, transB) + if transA: + A = relax.op.permute_dims(A, [1, 0]) + if transB: + B = relax.op.permute_dims(B, [1, 0]) + Y = bb.normalize(relax.op.matmul(A, B)) if C is not None: if beta is not None: @@ -839,97 +843,104 @@ def _impl_v1(cls, bb, inputs, attr): assert past is None, "past K, V state is not currently supported" assert extra_add is None, "extra add to QxK not currently supported" - split_1 = bb.emit_te(topi.split, weight, 3, 1) + split_1 = bb.normalize(relax.op.split(weight, 3, 1)) # split weight and biases and do the matmuls - w_Q, w_K, w_V = bb.emit(split_1[0]), bb.emit(split_1[1]), bb.emit(split_1[2]) + w_Q, w_K, w_V = split_1[0], split_1[1], split_1[2] split_2 = bb.emit_te(topi.split, bias, 3, 0) - b_Q, b_K, b_V = bb.emit(split_2[0]), bb.emit(split_2[1]), bb.emit(split_2[2]) + split_2 = bb.normalize(relax.op.split(bias, 3, 0)) + b_Q, b_K, b_V = split_2[0], split_2[1], split_2[2] # need to merge batch dimensions since TVM matmul is 2D # TODO(@yuchen): check reverse_reshape, a hack here - input_emb = bb.emit_te( - topi.reshape, input_emb, (input_emb_shape[0] * input_emb_shape[1], input_emb_shape[2]) + input_emb = bb.normalize( + relax.op.reshape( + input_emb, (input_emb_shape[0] * input_emb_shape[1], input_emb_shape[2]) + ) ) - mul = bb.emit_te(topi.nn.matmul, input_emb, w_Q) + mul = bb.normalize(relax.op.matmul(input_emb, w_Q)) - Q = bb.emit_te(topi.add, mul, b_Q) + Q = bb.normalize(relax.op.add(mul, b_Q)) - mul2 = bb.emit_te(topi.nn.matmul, input_emb, w_K) - K = bb.emit_te(topi.add, mul2, b_K) + mul2 = bb.normalize(relax.op.matmul(input_emb, w_K)) + K = bb.normalize(relax.op.add(mul2, b_K)) - mul3 = bb.emit_te(topi.nn.matmul, input_emb, w_V) - V = bb.emit_te(topi.add, mul3, b_V) + mul3 = bb.normalize(relax.op.matmul(input_emb, w_V)) + V = bb.normalize(relax.op.add(mul3, b_V)) # massage tensors in preparation for batched matmul def massage(bb, tensor): - tensor = bb.emit_te(topi.reshape, tensor, (batch_size, seq_len, num_heads, head_size)) + tensor = bb.normalize( + relax.op.reshape(tensor, (batch_size, seq_len, num_heads, head_size)) + ) # (batch_size, num_heads, seq_len, head_size) - tensor = bb.emit_te(topi.transpose, tensor, [0, 2, 1, 3]) + tensor = bb.normalize(relax.op.permute_dims(tensor, [0, 2, 1, 3])) tensor_shape = [val.value for val in tensor.struct_info.shape.values] # (batch_size * num_heads, seq_len, head_size) # TODO(@yuchen): check reverse_reshape, hack here - return bb.emit_te( - topi.reshape, - tensor, - (tensor_shape[0] * tensor_shape[1], tensor_shape[2], tensor_shape[3]), + return bb.normalize( + relax.op.reshape( + tensor, (tensor_shape[0] * tensor_shape[1], tensor_shape[2], tensor_shape[3]) + ) ) Q = massage(bb, Q) K = massage(bb, K) V = massage(bb, V) - K_present = bb.emit_te(topi.reshape, K, (batch_size, num_heads, seq_len, head_size)) - V_present = bb.emit_te(topi.reshape, V, (batch_size, num_heads, seq_len, head_size)) + K_present = bb.normalize(relax.op.reshape(K, (batch_size, num_heads, seq_len, head_size))) + V_present = bb.normalize(relax.op.reshape(V, (batch_size, num_heads, seq_len, head_size))) present = bb.emit_te(topi.stack, [K_present, V_present], 0) - att_scores = bb.emit_te(topi.nn.batch_matmul, Q, K, transpose_a=False, transpose_b=True) + att_scores = bb.normalize(relax.op.matmul(Q, relax.op.permute_dims(K, [0, 2, 1]))) score_dtype = att_scores.checked_type.dtype - att_scores = bb.emit_te( - topi.multiply, - att_scores, - relax.const(1 / _np.sqrt(head_size), dtype=att_scores.checked_type.dtype), + att_scores = bb.normalize( + relax.op.multiply( + att_scores, + relax.const(1 / _np.sqrt(head_size), dtype=att_scores.checked_type.dtype), + ) + ) + att_scores = bb.normalize( + relax.op.reshape(att_scores, (batch_size, num_heads, seq_len, seq_len)) ) - att_scores = bb.emit_te(topi.reshape, att_scores, (batch_size, num_heads, seq_len, seq_len)) # build the attention mask - att_mask = bb.emit_te(topi.cast, mask_index, score_dtype) + att_mask = bb.normalize(relax.op.astype(mask_index, score_dtype)) att_mask = bb.emit_te(topi.expand_dims, att_mask, 1, num_newaxis=2) - att_mask = bb.emit_te(topi.subtract, relax.const(1, dtype=score_dtype), att_mask) - att_mask = bb.emit_te(topi.multiply, att_mask, relax.const(-10000, dtype=score_dtype)) + att_mask = relax.op.subtract(relax.const(1, dtype=score_dtype), att_mask) + att_mask = relax.op.multiply(att_mask, relax.const(-10000, dtype=score_dtype)) # apply the mask - att_scores = bb.emit_te(topi.add, att_scores, att_mask) - att_scores = bb.emit_te( - topi.reshape, att_scores, (batch_size * num_heads, seq_len, seq_len) + att_scores = relax.op.add(att_scores, att_mask) + att_scores = bb.normalize( + relax.op.reshape(att_scores, (batch_size * num_heads, seq_len, seq_len)) ) - att_probs = bb.emit_te(topi.nn.softmax, att_scores, axis=-1) + att_probs = relax.op.nn.softmax(att_scores, axis=-1) - output = bb.emit_te( - topi.nn.batch_matmul, att_probs, V, transpose_a=False, transpose_b=False - ) + output = bb.normalize(relax.op.matmul(att_probs, V)) # TODO(@yuchen): check reverse_reshape, hack here output_shape = [val.value for val in output.struct_info.shape.values] - output = bb.emit_te( - topi.reshape, - output, - ( - int(output_shape[0]) // num_heads, - num_heads, - int(output_shape[1]), - int(output_shape[2]), - ), + output = bb.normalize( + relax.op.reshape( + output, + ( + int(output_shape[0]) // num_heads, + num_heads, + int(output_shape[1]), + int(output_shape[2]), + ), + ) ) - output = bb.emit_te(topi.transpose, output, axes=[0, 2, 1, 3]) + output = bb.normalize(relax.op.permute_dims(output, axes=[0, 2, 1, 3])) output_shape = [val.value for val in output.struct_info.shape.values] - output = bb.emit_te( - topi.reshape, output, (int(output_shape[0]), int(output_shape[1]), out_hidden) + output = bb.normalize( + relax.op.reshape(output, (int(output_shape[0]), int(output_shape[1]), out_hidden)) ) return relax.Tuple([output, present])