Importer improvements such that full flow works end to end
Josh Fromm committed Mar 2, 2023
1 parent 3a5627d commit 2a1da3a
Showing 2 changed files with 71 additions and 61 deletions.
25 changes: 12 additions & 13 deletions python/tvm/octo/compile.py
@@ -21,6 +21,7 @@
from typing import Union, Optional, Dict, List
import tvm
from tvm import relax
from tvm.relax.backend.contrib.cutlass import partition_for_cutlass
from .utils import get_cuda_target, get_llvm_target
from .octo_model import OctoModel

@@ -132,24 +133,22 @@ def offload_cutlass(mod: tvm.IRModule, target: tvm.target.Target) -> tvm.IRModule:
# Extract the sm version of the current target.
assert target.arch, "Target architecture must be specified."
sm = int(target.arch.split("_")[1])
# Cutlass only has support up to sm80, future sms will work with
# earlier kernels though.
if sm > 80:
sm = 80

# Apply partitioning to offload patterns to cutlass.
mod = partition_for_cutlass(mod)

# Construct CUTLASS codegen pass.
cutlass_codegen_pass = relax.transform.RunCodegen(
{"cutlass": {"sm": sm, "find_first_valid": True}}
)

# Construct pattern identification pass.
# TODO(jwfromm) rebase on cutlass pattern language

# Run passes on input module.
seq = tvm.transform.Sequential(
[
# relax.transform.FuseOpPattern(patterns, annotate_codegen=True),
cutlass_codegen_pass
]
)

return seq(mod)
# Generate code for matched cutlass kernels.
mod = cutlass_codegen_pass(mod)
return mod
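
For reference, the new offload flow condenses to the following sketch. The wrapper name offload_cutlass_sketch and the assumption that the caller supplies the module and SM version are illustrative only; the calls themselves mirror the code above.

import tvm
from tvm import relax
from tvm.relax.backend.contrib.cutlass import partition_for_cutlass

def offload_cutlass_sketch(mod: tvm.IRModule, sm: int) -> tvm.IRModule:
    # CUTLASS kernels only go up to sm80; newer architectures reuse the sm80 kernels.
    sm = min(sm, 80)
    # Group offloadable patterns into composite functions for cutlass.
    mod = partition_for_cutlass(mod)
    # Generate code for the matched cutlass kernels.
    return relax.transform.RunCodegen({"cutlass": {"sm": sm, "find_first_valid": True}})(mod)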


def compile(
@@ -200,7 +199,7 @@ def compile(

# Match subgraphs that can be offloaded to cutlass and offload them.
# TODO(jwfromm) Currently doesnt work, get one e2e example.
# offload_cutlass(relax_mod, target)
relax_mod = offload_cutlass(relax_mod, target)

# Perform legalization to lower Relax operators.
relax_mod = relax.transform.LegalizeOps()(relax_mod)
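
The net effect in compile() is that cutlass offloading now runs before operator legalization. A minimal sketch of that ordering, assuming offload_cutlass is importable from this file and that relax_mod and target are built as in the surrounding code:

import tvm
from tvm import relax
from tvm.octo.compile import offload_cutlass  # the function modified above

def lower_with_cutlass(relax_mod: tvm.IRModule, target: tvm.target.Target) -> tvm.IRModule:
    # Offload supported subgraphs to cutlass first...
    relax_mod = offload_cutlass(relax_mod, target)
    # ...then legalize the remaining Relax operators.
    return relax.transform.LegalizeOps()(relax_mod)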
107 changes: 59 additions & 48 deletions python/tvm/relax/frontend/onnx_frontend.py
Expand Up @@ -288,7 +288,11 @@ def _impl_v13(cls, bb, inputs, attr):
if alpha is not None:
A = bb.normalize(relax.op.multiply(A, relax.const(alpha, dtype=dtype)))

Y = bb.emit_te(topi.matmul, A, B, transA, transB)
if transA:
A = relax.op.permute_dims(A, [1, 0])
if transB:
B = relax.op.permute_dims(B, [1, 0])
Y = bb.normalize(relax.op.matmul(A, B))

if C is not None:
if beta is not None:
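
An illustrative NumPy reference (not part of the commit) for the Gemm lowering above: ONNX Gemm computes alpha * A' @ B' + beta * C with A' and B' optionally transposed, which the importer now expresses as explicit permute_dims calls followed by a plain matmul.

import numpy as np

def gemm_reference(A, B, C=None, alpha=1.0, beta=1.0, transA=False, transB=False):
    if transA:
        A = A.T  # mirrors relax.op.permute_dims(A, [1, 0])
    if transB:
        B = B.T  # mirrors relax.op.permute_dims(B, [1, 0])
    Y = alpha * (A @ B)
    if C is not None:
        Y = Y + beta * C
    return Y

A = np.random.rand(4, 3).astype("float32")
B = np.random.rand(4, 5).astype("float32")
assert gemm_reference(A, B, transA=True).shape == (3, 5)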
@@ -839,97 +843,104 @@ def _impl_v1(cls, bb, inputs, attr):
assert past is None, "past K, V state is not currently supported"
assert extra_add is None, "extra add to QxK not currently supported"

split_1 = bb.emit_te(topi.split, weight, 3, 1)
split_1 = bb.normalize(relax.op.split(weight, 3, 1))
# split weight and biases and do the matmuls
w_Q, w_K, w_V = bb.emit(split_1[0]), bb.emit(split_1[1]), bb.emit(split_1[2])
w_Q, w_K, w_V = split_1[0], split_1[1], split_1[2]

split_2 = bb.emit_te(topi.split, bias, 3, 0)
b_Q, b_K, b_V = bb.emit(split_2[0]), bb.emit(split_2[1]), bb.emit(split_2[2])
split_2 = bb.normalize(relax.op.split(bias, 3, 0))
b_Q, b_K, b_V = split_2[0], split_2[1], split_2[2]
# need to merge batch dimensions since TVM matmul is 2D

# TODO(@yuchen): check reverse_reshape, a hack here
input_emb = bb.emit_te(
topi.reshape, input_emb, (input_emb_shape[0] * input_emb_shape[1], input_emb_shape[2])
input_emb = bb.normalize(
relax.op.reshape(
input_emb, (input_emb_shape[0] * input_emb_shape[1], input_emb_shape[2])
)
)

mul = bb.emit_te(topi.nn.matmul, input_emb, w_Q)
mul = bb.normalize(relax.op.matmul(input_emb, w_Q))

Q = bb.emit_te(topi.add, mul, b_Q)
Q = bb.normalize(relax.op.add(mul, b_Q))

mul2 = bb.emit_te(topi.nn.matmul, input_emb, w_K)
K = bb.emit_te(topi.add, mul2, b_K)
mul2 = bb.normalize(relax.op.matmul(input_emb, w_K))
K = bb.normalize(relax.op.add(mul2, b_K))

mul3 = bb.emit_te(topi.nn.matmul, input_emb, w_V)
V = bb.emit_te(topi.add, mul3, b_V)
mul3 = bb.normalize(relax.op.matmul(input_emb, w_V))
V = bb.normalize(relax.op.add(mul3, b_V))

# massage tensors in preparation for batched matmul
def massage(bb, tensor):
tensor = bb.emit_te(topi.reshape, tensor, (batch_size, seq_len, num_heads, head_size))
tensor = bb.normalize(
relax.op.reshape(tensor, (batch_size, seq_len, num_heads, head_size))
)

# (batch_size, num_heads, seq_len, head_size)
tensor = bb.emit_te(topi.transpose, tensor, [0, 2, 1, 3])
tensor = bb.normalize(relax.op.permute_dims(tensor, [0, 2, 1, 3]))
tensor_shape = [val.value for val in tensor.struct_info.shape.values]

# (batch_size * num_heads, seq_len, head_size)
# TODO(@yuchen): check reverse_reshape, hack here
return bb.emit_te(
topi.reshape,
tensor,
(tensor_shape[0] * tensor_shape[1], tensor_shape[2], tensor_shape[3]),
return bb.normalize(
relax.op.reshape(
tensor, (tensor_shape[0] * tensor_shape[1], tensor_shape[2], tensor_shape[3])
)
)

Q = massage(bb, Q)
K = massage(bb, K)
V = massage(bb, V)

K_present = bb.emit_te(topi.reshape, K, (batch_size, num_heads, seq_len, head_size))
V_present = bb.emit_te(topi.reshape, V, (batch_size, num_heads, seq_len, head_size))
K_present = bb.normalize(relax.op.reshape(K, (batch_size, num_heads, seq_len, head_size)))
V_present = bb.normalize(relax.op.reshape(V, (batch_size, num_heads, seq_len, head_size)))
present = bb.emit_te(topi.stack, [K_present, V_present], 0)

att_scores = bb.emit_te(topi.nn.batch_matmul, Q, K, transpose_a=False, transpose_b=True)
att_scores = bb.normalize(relax.op.matmul(Q, relax.op.permute_dims(K, [0, 2, 1])))
score_dtype = att_scores.checked_type.dtype
att_scores = bb.emit_te(
topi.multiply,
att_scores,
relax.const(1 / _np.sqrt(head_size), dtype=att_scores.checked_type.dtype),
att_scores = bb.normalize(
relax.op.multiply(
att_scores,
relax.const(1 / _np.sqrt(head_size), dtype=att_scores.checked_type.dtype),
)
)
att_scores = bb.normalize(
relax.op.reshape(att_scores, (batch_size, num_heads, seq_len, seq_len))
)
att_scores = bb.emit_te(topi.reshape, att_scores, (batch_size, num_heads, seq_len, seq_len))

# build the attention mask
att_mask = bb.emit_te(topi.cast, mask_index, score_dtype)
att_mask = bb.normalize(relax.op.astype(mask_index, score_dtype))
att_mask = bb.emit_te(topi.expand_dims, att_mask, 1, num_newaxis=2)
att_mask = bb.emit_te(topi.subtract, relax.const(1, dtype=score_dtype), att_mask)
att_mask = bb.emit_te(topi.multiply, att_mask, relax.const(-10000, dtype=score_dtype))
att_mask = relax.op.subtract(relax.const(1, dtype=score_dtype), att_mask)
att_mask = relax.op.multiply(att_mask, relax.const(-10000, dtype=score_dtype))

# apply the mask
att_scores = bb.emit_te(topi.add, att_scores, att_mask)
att_scores = bb.emit_te(
topi.reshape, att_scores, (batch_size * num_heads, seq_len, seq_len)
att_scores = relax.op.add(att_scores, att_mask)
att_scores = bb.normalize(
relax.op.reshape(att_scores, (batch_size * num_heads, seq_len, seq_len))
)

att_probs = bb.emit_te(topi.nn.softmax, att_scores, axis=-1)
att_probs = relax.op.nn.softmax(att_scores, axis=-1)

output = bb.emit_te(
topi.nn.batch_matmul, att_probs, V, transpose_a=False, transpose_b=False
)
output = bb.normalize(relax.op.matmul(att_probs, V))

# TODO(@yuchen): check reverse_reshape, hack here
output_shape = [val.value for val in output.struct_info.shape.values]
output = bb.emit_te(
topi.reshape,
output,
(
int(output_shape[0]) // num_heads,
num_heads,
int(output_shape[1]),
int(output_shape[2]),
),
output = bb.normalize(
relax.op.reshape(
output,
(
int(output_shape[0]) // num_heads,
num_heads,
int(output_shape[1]),
int(output_shape[2]),
),
)
)

output = bb.emit_te(topi.transpose, output, axes=[0, 2, 1, 3])
output = bb.normalize(relax.op.permute_dims(output, axes=[0, 2, 1, 3]))
output_shape = [val.value for val in output.struct_info.shape.values]
output = bb.emit_te(
topi.reshape, output, (int(output_shape[0]), int(output_shape[1]), out_hidden)
output = bb.normalize(
relax.op.reshape(output, (int(output_shape[0]), int(output_shape[1]), out_hidden))
)
return relax.Tuple([output, present])
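
The converted Attention op above amounts to standard multi-head scaled dot-product attention with an additive mask. The following NumPy sketch (shapes and values are assumptions for illustration, not from the commit) traces the same sequence of splits, projections, reshapes, masking, and softmax:

import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

batch_size, seq_len, hidden, num_heads = 2, 8, 16, 4
head_size = hidden // num_heads

input_emb = np.random.rand(batch_size, seq_len, hidden).astype("float32")
weight = np.random.rand(hidden, 3 * hidden).astype("float32")   # fused QKV weight
bias = np.random.rand(3 * hidden).astype("float32")              # fused QKV bias
mask_index = np.ones((batch_size, seq_len), dtype="float32")     # 1 = attend, 0 = masked

# Split the fused projections, as relax.op.split(weight, 3, 1) and split(bias, 3, 0) do.
w_Q, w_K, w_V = np.split(weight, 3, axis=1)
b_Q, b_K, b_V = np.split(bias, 3, axis=0)

def project(x, w, b):
    # (batch, seq, hidden) -> (batch, num_heads, seq, head_size), mirroring massage().
    y = x.reshape(batch_size * seq_len, hidden) @ w + b
    y = y.reshape(batch_size, seq_len, num_heads, head_size)
    return y.transpose(0, 2, 1, 3)

Q = project(input_emb, w_Q, b_Q)
K = project(input_emb, w_K, b_K)
V = project(input_emb, w_V, b_V)

# Scaled dot-product scores plus the additive (1 - mask) * -10000 attention mask.
scores = (Q @ K.transpose(0, 1, 3, 2)) / np.sqrt(head_size)
att_mask = (1.0 - mask_index) * -10000.0
scores = scores + att_mask[:, None, None, :]

probs = softmax(scores, axis=-1)
output = (probs @ V).transpose(0, 2, 1, 3).reshape(batch_size, seq_len, hidden)
assert output.shape == (batch_size, seq_len, hidden)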
