[Unity][Pass] Operator Fusion Passes (#14001)
[Unity][Pass] Operator fusion passes

This PR introduces three passes for operator fusion:
1. AnnotateTIROpPattern: analyzes the operator pattern kind of each TIR PrimFunc.
2. FuseOps: fuses operators in Relax functions, adding a new fused primitive
Relax function for each fused group.
3. FuseTIR: fuses the corresponding TIR PrimFuncs for each fused Relax function (see the usage sketch below).
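
A rough usage sketch (not part of this commit; `mod` is a placeholder for an IRModule that mixes Relax functions and TIR PrimFuncs):

import tvm
from tvm import relax

seq = tvm.transform.Sequential(
    [
        relax.transform.AnnotateTIROpPattern(),  # tag each PrimFunc with its op pattern kind
        relax.transform.FuseOps(),               # group fusable bindings into primitive Relax functions
        relax.transform.FuseTIR(),               # merge each group's PrimFuncs into a single PrimFunc
    ]
)
fused_mod = seq(mod)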
Hzfengsy authored and tqchen committed Mar 20, 2023

1 parent 62daae4 commit 20adb37
Showing 10 changed files with 3,441 additions and 2 deletions.
11 changes: 11 additions & 0 deletions include/tvm/relax/analysis.h
@@ -260,6 +260,17 @@ TVM_DLL bool IsBaseOf(const StructInfo& base, const StructInfo& derived,
TVM_DLL StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs,
arith::Analyzer* ana = nullptr);

/*!
* \brief Analyze the op pattern kind of a PrimFunc, which is used by Relax FuseOps.
*
* \param func The PrimFunc to be analyzed.
* \return The op pattern kind.
*
* \note This analysis applies to TIR functions but is primarily used by Relax passes.
* As a result we place it under the relax namespace.
*/
TVM_DLL relay::OpPatternKind AnalyzeOpPatternKind(const tir::PrimFunc& func);

/*!
* \brief Check if the given PrimFunc is essentially doing a reshape operation.
* The reshape operation also includes expand_dims, squeeze, flatten, etc.
14 changes: 13 additions & 1 deletion include/tvm/tir/buffer.h
@@ -34,6 +34,18 @@
namespace tvm {
namespace tir {

#ifndef TVM_INDEX_DEFAULT_I64
#define TVM_INDEX_DEFAULT_I64 1
#endif
/*! \brief If TVM_INDEX_DEFAULT_I64 is set, return int64; otherwise return int32. */
inline DataType DefaultIndexType() {
#if TVM_INDEX_DEFAULT_I64
  return DataType::Int(64);
#else
  return DataType::Int(32);
#endif
}

// forward declare Stmt
class Stmt;

@@ -135,7 +147,7 @@ class BufferNode : public Object {

/*! \return preferred index type for this buffer node */
DataType DefaultIndexType() const {
return shape.size() != 0 ? shape[0].dtype() : DataType::Int(32);
return shape.size() != 0 ? shape[0].dtype() : tvm::tir::DefaultIndexType();
}

/*! \brief Determine the offset in the buffer of the given index.
43 changes: 43 additions & 0 deletions python/tvm/relax/transform/transform.py
@@ -105,6 +105,49 @@ def AttachGlobalSymbol() -> tvm.ir.transform.Pass:
return _ffi_api.AttachGlobalSymbol() # type: ignore


def AnnotateTIROpPattern() -> tvm.ir.transform.Pass:
    """Annotate op pattern kind for TIR functions.

    Returns
    -------
    ret : tvm.ir.transform.Pass
    """
    return _ffi_api.AnnotateTIROpPattern()  # type: ignore


def FuseOps(fuse_opt_level=-1) -> tvm.ir.transform.Pass:
    """This pass groups bindings in a dataflow block of Relax functions and generates a new
    grouped Relax function for each group, according to the fusion algorithm described in the
    pass implementation. By grouping bindings into new Relax functions, the bindings in the
    function being manipulated are replaced with calls to the new grouped functions.

    A follow-up pass named "FuseTIR" will generate a TIR PrimFunc for each grouped function.

    Parameters
    ----------
    fuse_opt_level : int
        The level of fuse optimization. -1 indicates that the level will be
        inferred from pass context.

    Returns
    -------
    ret : tvm.transform.Pass
        The registered pass for operator fusion.
    """
    return _ffi_api.FuseOps(fuse_opt_level)  # type: ignore


def FuseTIR() -> tvm.ir.transform.Pass:
    """Fuse primitive Relax functions into larger TIR PrimFuncs where possible.

    Returns
    -------
    ret : tvm.transform.Pass
        The registered pass for TIR fusion.
    """
    return _ffi_api.FuseTIR()  # type: ignore


def _wrap_class_function_pass(pass_cls, pass_info):
    """Wrap a python class as function pass."""

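As a quick, hedged illustration of the Python API added above (module and function names are placeholders; the test file later in this commit exercises the same flow):

from tvm import relax

# `mod` is assumed to be an IRModule containing a PrimFunc named "tir_matmul".
annotated = relax.transform.AnnotateTIROpPattern()(mod)
print(annotated["tir_matmul"].attrs["op_pattern"])  # 4 == kOutEWiseFusable for a matmul

# FuseOps groups fusable bindings; FuseTIR then emits one PrimFunc per fused group.
fused = relax.transform.FuseTIR()(relax.transform.FuseOps()(annotated))
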
55 changes: 55 additions & 0 deletions src/relax/transform/annotate_tir_op_pattern.cc
@@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file src/relax/transform/annotate_tir_op_pattern.cc
* \brief Annotate op pattern kind for TIR functions. This pass works on TIR PrimFuncs,
* but the annotations are needed for Relax fusion, so we put it in the relax namespace.
*/
#include <tvm/relax/analysis.h>
#include <tvm/relax/transform.h>
#include <tvm/tir/transform.h>

namespace tvm {
namespace relax {

tir::PrimFunc AnnotateOpPattern(tir::PrimFunc f) {
  if (f->HasNonzeroAttr("op_pattern")) {
    return f;
  } else {
    relay::OpPatternKind kind = AnalyzeOpPatternKind(f);
    return WithAttr(std::move(f), "op_pattern", Integer(static_cast<int>(kind)));
  }
}

namespace transform {

Pass AnnotateTIROpPattern() {
  auto pass_func = [=](tir::PrimFunc f, IRModule m, PassContext ctx) {
    return AnnotateOpPattern(std::move(f));
  };
  return tir::transform::CreatePrimFuncPass(pass_func, 0, "AnnotateTIROpPattern", {});
}

TVM_REGISTER_GLOBAL("relax.transform.AnnotateTIROpPattern").set_body_typed(AnnotateTIROpPattern);

} // namespace transform

} // namespace relax
} // namespace tvm
909 changes: 909 additions & 0 deletions src/relax/transform/fuse_ops.cc

Large diffs are not rendered by default.

728 changes: 728 additions & 0 deletions src/relax/transform/fuse_tir.cc

Large diffs are not rendered by default.
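
Since the fuse_ops.cc and fuse_tir.cc diffs are not rendered here, the following is a minimal sketch of the kind of module these passes operate on, assuming the unity-era BlockBuilder and emit_te APIs; the TE compute definitions are illustrative, not taken from this commit:

import tvm
from tvm import relax, te


def te_add(a, b):
    return te.compute(a.shape, lambda *i: a(*i) + b(*i), name="add")


def te_relu(x):
    return te.compute(x.shape, lambda *i: te.max(x(*i), tvm.tir.const(0.0, "float32")), name="relu")


bb = relax.BlockBuilder()
x = relax.Var("x", relax.TensorStructInfo((8, 8), "float32"))
y = relax.Var("y", relax.TensorStructInfo((8, 8), "float32"))
with bb.function("main", [x, y]):
    with bb.dataflow():
        lv0 = bb.emit_te(te_add, x, y)  # elementwise add -> kElemWise
        gv = bb.emit_output(bb.emit_te(te_relu, lv0))  # elementwise relu -> kElemWise
    bb.emit_func_output(gv)
mod = bb.get()

# After AnnotateTIROpPattern + FuseOps, the add and relu bindings land in one fused
# primitive Relax function; FuseTIR then merges their PrimFuncs into a single PrimFunc.
mod = relax.transform.AnnotateTIROpPattern()(mod)
mod = relax.transform.FuseOps()(mod)
mod = relax.transform.FuseTIR()(mod)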

360 changes: 360 additions & 0 deletions tests/python/relax/test_transform_annotate_tir_op_pattern.py
@@ -0,0 +1,360 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import enum

import tvm
import tvm.script
import tvm.testing
from tvm import relax
from tvm.script import tir as T


class OpPatternKind(enum.IntEnum):
kElemWise = 0
kBroadcast = 1
kInjective = 2
kCommReduce = 3
kOutEWiseFusable = 4
kTuple = 7
kOpaque = 8


def test_annotate_opkind_outewisefusable():
@tvm.script.ir_module
class InputModule:
@T.prim_func
def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:
T.func_attr({"global_symbol": "tir_matmul"})
m = T.var("int32")
n = T.var("int32")
k = T.var("int32")
A = T.match_buffer(x, (m, n))
B = T.match_buffer(y, (n, k))
C = T.match_buffer(z, (m, k))

for i, j, k in T.grid(m, k, n):
with T.block("matmul"):
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
with T.init():
C[vi, vj] = T.float32(0)
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

mod = InputModule
new_mod = relax.transform.AnnotateTIROpPattern()(mod)
assert new_mod["tir_matmul"].attrs["op_pattern"] == OpPatternKind.kOutEWiseFusable


def test_annotate_opkind_outewisefusable_int_var_signature():
@tvm.script.ir_module
class InputModule:
@T.prim_func
def tir_matmul(x: T.handle, y: T.handle, z: T.handle, m: T.int64, n: T.int64, k: T.int64):
T.func_attr({"global_symbol": "tir_matmul"})
A = T.match_buffer(x, (m, n))
B = T.match_buffer(y, (n, k))
C = T.match_buffer(z, (m, k))

for i, j, k in T.grid(m, k, n):
with T.block("matmul"):
vi, vj, vk = T.axis.remap("SSR", [i, j, k])
with T.init():
C[vi, vj] = T.float32(0)
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]

mod = InputModule
new_mod = relax.transform.AnnotateTIROpPattern()(mod)
assert new_mod["tir_matmul"].attrs["op_pattern"] == OpPatternKind.kOutEWiseFusable


def test_annotate_opkind_reduce():
@tvm.script.ir_module
class InputModule:
@T.prim_func
def sum(x: T.handle, y: T.handle) -> None:
T.func_attr({"global_symbol": "elemwise"})
A = T.match_buffer(x, (16, 16))
B = T.match_buffer(y, (16,))

for i, j in T.grid(16, 16):
with T.block("matmul"):
vi, vj = T.axis.remap("SR", [i, j])
with T.init():
B[vi] = 0.0
B[vi] += A[vi, vj]

mod = InputModule
new_mod = relax.transform.AnnotateTIROpPattern()(mod)
assert new_mod["sum"].attrs["op_pattern"] == OpPatternKind.kCommReduce


def test_annotate_opkind_ewise():
@tvm.script.ir_module
class InputModule:
@T.prim_func
def elemwise(x: T.handle, y: T.handle) -> None:
T.func_attr({"global_symbol": "elemwise"})
A = T.match_buffer(x, (16, 16))
B = T.match_buffer(y, (16, 16))

for i, j in T.grid(16, 16):
with T.block("matmul"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj] + 1.0

mod = InputModule
new_mod = relax.transform.AnnotateTIROpPattern()(mod)
assert new_mod["elemwise"].attrs["op_pattern"] == OpPatternKind.kElemWise


def test_annotate_opkind_broadcast():
@tvm.script.ir_module
class InputModule:
@T.prim_func
def broadcast(x: T.handle, y: T.handle) -> None:
T.func_attr({"global_symbol": "elemwise"})
A = T.match_buffer(x, (16, 16))
B = T.match_buffer(y, (16, 16, 16, 16))

for i0, j0, i1, j1 in T.grid(16, 16, 16, 16):
with T.block("matmul"):
vi0, vj0, vi1, vj1 = T.axis.remap("SSSS", [i0, j0, i1, j1])
B[vi0, vj0, vi1, vj1] = A[vj0, vj1]

mod = InputModule
new_mod = relax.transform.AnnotateTIROpPattern()(mod)
assert new_mod["broadcast"].attrs["op_pattern"] == OpPatternKind.kBroadcast


def test_annotate_opkind_injective():
@tvm.script.ir_module
class InputModule:
@T.prim_func
def injective(x: T.handle, y: T.handle) -> None:
T.func_attr({"global_symbol": "elemwise"})
A = T.match_buffer(x, (4, 4, 4, 4))
B = T.match_buffer(y, (16, 16))

for i, j in T.grid(16, 16):
with T.block("matmul"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi // 4, vj // 4, vi % 4, vj % 4]

mod = InputModule
new_mod = relax.transform.AnnotateTIROpPattern()(mod)
assert new_mod["injective"].attrs["op_pattern"] == OpPatternKind.kInjective


def test_annotate_opkind_bias_add():
@tvm.script.ir_module
class InputModule:
@T.prim_func
def tir_bias_add(
A: T.Buffer((1, 1000), "float32"),
B: T.Buffer((1000,), "float32"),
C: T.Buffer((1, 1000), "float32"),
) -> None:
# function attr dict
T.func_attr({"global_symbol": "tir_bias_add", "tir.noalias": True})
# body
# with T.block("root")
for i0, i1 in T.grid(1, 1000):
with T.block("T_add"):
ax0, ax1 = T.axis.remap("SS", [i0, i1])
T.reads(A[ax0, ax1], B[ax1])
T.writes(C[ax0, ax1])
C[ax0, ax1] = A[ax0, ax1] + B[ax1]

mod = InputModule
new_mod = relax.transform.AnnotateTIROpPattern()(mod)
assert new_mod["tir_bias_add"].attrs["op_pattern"] == OpPatternKind.kElemWise


def test_annotate_opkind_add_broadcast_with_unit_shape():
@tvm.script.ir_module
class InputModule:
@T.prim_func
def add_with_unit_dim_len_broadcast(
A: T.Buffer((1, 64, 112, 112), "float32"),
B: T.Buffer((64, 1, 1), "float32"),
C: T.Buffer((1, 64, 112, 112), "float32"),
) -> None:
T.func_attr({"global_symbol": "add5", "tir.noalias": True})
for i0, i1, i2, i3 in T.grid(1, 64, 112, 112):
with T.block("T_add"):
ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
T.reads(A[ax0, ax1, ax2, ax3], B[ax1, 0, 0])
T.writes(C[ax0, ax1, ax2, ax3])
C[ax0, ax1, ax2, ax3] = A[ax0, ax1, ax2, ax3] + B[ax1, 0, 0]

mod = InputModule
new_mod = relax.transform.AnnotateTIROpPattern()(mod)
assert new_mod["add_with_unit_dim_len_broadcast"].attrs["op_pattern"] == OpPatternKind.kElemWise


def test_annotate_opkind_add_zero_dim_element_wise():
@tvm.script.ir_module
class InputModule:
@T.prim_func
def add_zero_dim(
A: T.Buffer((128,), "float32"),
B: T.Buffer((), "float32"),
C: T.Buffer((128,), "float32"),
) -> None:
T.func_attr({"global_symbol": "add8", "tir.noalias": True})
for i0 in T.serial(128):
with T.block("T_add"):
ax0 = T.axis.spatial(128, i0)
T.reads(A[ax0], B[()])
T.writes(C[ax0])
C[ax0] = A[ax0] + B[()]

mod = InputModule
new_mod = relax.transform.AnnotateTIROpPattern()(mod)
assert new_mod["add_zero_dim"].attrs["op_pattern"] == OpPatternKind.kElemWise


def test_annotate_opkind_pooling():
@tvm.script.ir_module
class InputModule:
@T.prim_func
def max_pool2d(
rxplaceholder_1: T.Buffer((1, 64, 112, 112), "float32"),
tensor_1: T.Buffer((1, 64, 56, 56), "float32"),
) -> None:
# function attr dict
T.func_attr({"global_symbol": "max_pool2d", "tir.noalias": True})
# body
# with T.block("root")
pad_temp_1 = T.alloc_buffer([1, 64, 114, 114], dtype="float32")
for i0, i1, i2, i3 in T.grid(1, 64, 114, 114):
with T.block("pad_temp"):
ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
T.reads(rxplaceholder_1[ax0, ax1, ax2 - 1, ax3 - 1])
T.writes(pad_temp_1[ax0, ax1, ax2, ax3])
pad_temp_1[ax0, ax1, ax2, ax3] = T.if_then_else(
1 <= ax2 and ax2 < 113 and 1 <= ax3 and ax3 < 113,
rxplaceholder_1[ax0, ax1, ax2 - 1, ax3 - 1],
T.float32(-3.4028234663852886e38),
dtype="float32",
)
for i0, i1, i2, i3, i4, i5 in T.grid(1, 64, 56, 56, 3, 3):
with T.block("tensor"):
ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5])
T.reads(
tensor_1[ax0, ax1, ax2, ax3],
pad_temp_1[ax0, ax1, ax2 * 2 + rv0, ax3 * 2 + rv1],
)
T.writes(tensor_1[ax0, ax1, ax2, ax3])
with T.init():
tensor_1[ax0, ax1, ax2, ax3] = T.float32(-3.4028234663852886e38)
tensor_1[ax0, ax1, ax2, ax3] = T.max(
tensor_1[ax0, ax1, ax2, ax3],
pad_temp_1[ax0, ax1, ax2 * 2 + rv0, ax3 * 2 + rv1],
)

mod = InputModule
new_mod = relax.transform.AnnotateTIROpPattern()(mod)
assert new_mod["max_pool2d"].attrs["op_pattern"] == OpPatternKind.kOutEWiseFusable


def test_annotate_opkind_softmax():
@tvm.script.ir_module
class InputModule:
@T.prim_func
def softmax(
rxplaceholder_1: T.Buffer((16, 16), "float32"),
T_softmax_norm_1: T.Buffer((16, 16), "float32"),
) -> None:
# function attr dict
T.func_attr({"global_symbol": "softmax", "tir.noalias": True})
# body
# with T.block("root")
T_softmax_maxelem_1 = T.alloc_buffer([16], dtype="float32")
T_softmax_exp_1 = T.alloc_buffer([16, 16], dtype="float32")
T_softmax_expsum_1 = T.alloc_buffer([16], dtype="float32")
for i0_7, i1_3 in T.grid(16, 16):
with T.block("T_softmax_maxelem"):
i0_8, k = T.axis.remap("SR", [i0_7, i1_3])
T.reads(T_softmax_maxelem_1[i0_8], rxplaceholder_1[i0_8, k])
T.writes(T_softmax_maxelem_1[i0_8])
with T.init():
T_softmax_maxelem_1[i0_8] = T.float32(-3.4028234663852886e38)
T_softmax_maxelem_1[i0_8] = T.max(
T_softmax_maxelem_1[i0_8], rxplaceholder_1[i0_8, k]
)
for i0_9, i1_4 in T.grid(16, 16):
with T.block("T_softmax_exp"):
i0_10, i1_5 = T.axis.remap("SS", [i0_9, i1_4])
T.reads(rxplaceholder_1[i0_10, i1_5], T_softmax_maxelem_1[i0_10])
T.writes(T_softmax_exp_1[i0_10, i1_5])
T_softmax_exp_1[i0_10, i1_5] = T.exp(
rxplaceholder_1[i0_10, i1_5] - T_softmax_maxelem_1[i0_10], dtype="float32"
)
for i0_11, i1_6 in T.grid(16, 16):
with T.block("T_softmax_expsum"):
i0_12, k = T.axis.remap("SR", [i0_11, i1_6])
T.reads(T_softmax_expsum_1[i0_12], T_softmax_exp_1[i0_12, k])
T.writes(T_softmax_expsum_1[i0_12])
with T.init():
T_softmax_expsum_1[i0_12] = T.float32(0)
T_softmax_expsum_1[i0_12] = (
T_softmax_expsum_1[i0_12] + T_softmax_exp_1[i0_12, k]
)
for i0_13, i1_7 in T.grid(16, 16):
with T.block("T_softmax_norm"):
i0_14, i1_8 = T.axis.remap("SS", [i0_13, i1_7])
T.reads(T_softmax_exp_1[i0_14, i1_8], T_softmax_expsum_1[i0_14])
T.writes(T_softmax_norm_1[i0_14, i1_8])
T.block_attr({"axis": 1})
T_softmax_norm_1[i0_14, i1_8] = (
T_softmax_exp_1[i0_14, i1_8] / T_softmax_expsum_1[i0_14]
)

mod = InputModule
new_mod = relax.transform.AnnotateTIROpPattern()(mod)
assert new_mod["softmax"].attrs["op_pattern"] == OpPatternKind.kOutEWiseFusable


def test_multiple_buffer_stores_fallback():
@tvm.script.ir_module
class CumsumModule:
@T.prim_func
def cumsum(var_rxplaceholder: T.handle, out_buf: T.Buffer(160, "float32")):
rxplaceholder = T.match_buffer(
var_rxplaceholder, [10, 16], dtype="float32", offset_factor=1
)
with T.block("cumsum_generic"):
T.reads(rxplaceholder[0:10, 0:16])
T.writes(out_buf[0:160])
for fused in T.parallel(1):
out_buf[fused * 160] = rxplaceholder[fused * 160 // 16, fused * 160 % 16]
for v_k in T.serial(159):
out_buf[fused * 160 + (v_k + 1)] = (
out_buf[fused * 160 + (v_k + 1 - 1)]
+ rxplaceholder[
(fused * 160 + (v_k + 1)) // 16,
(fused * 160 + (v_k + 1)) % 16,
]
)

mod = CumsumModule
new_mod = relax.transform.AnnotateTIROpPattern()(mod)
assert new_mod["cumsum"].attrs["op_pattern"] == OpPatternKind.kOpaque


if __name__ == "__main__":
tvm.testing.main()
759 changes: 759 additions & 0 deletions tests/python/relax/test_transform_fuse_ops.py

Large diffs are not rendered by default.

563 changes: 563 additions & 0 deletions tests/python/relax/test_transform_fuse_tir.py

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion tests/python/relax/test_tvmscript_parser.py
@@ -1073,5 +1073,4 @@ def mul_add(x: R.Tensor) -> R.Tensor:


if __name__ == "__main__":
test_cross_function_call()
tvm.testing.main()
