Skip to content

Commit

Permalink
[Unity][Pass] Operator Fusion Passes (#14001)
Browse files Browse the repository at this point in the history
[Unity][Pass] Operator fusion passes

This PR introduces three passes for operator fusion:
1. AnnotateTIROpPattern: analysis the operator kind from PrimFunc.
2. FuseOps: fuse operators for Relax functions, which adds a new fused
relax primitive function.
3. FuseTIR: fuse corresponding TIR PrimFuncs for the fused relax.
  • Loading branch information
Hzfengsy authored and tqchen committed Feb 22, 2023
1 parent a7c5975 commit 59dfc38
Show file tree
Hide file tree
Showing 10 changed files with 3,441 additions and 2 deletions.
11 changes: 11 additions & 0 deletions include/tvm/relax/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,17 @@ TVM_DLL bool IsBaseOf(const StructInfo& base, const StructInfo& derived,
TVM_DLL StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs,
arith::Analyzer* ana = nullptr);

/*!
* \brief Annotate Op Pattern Kind for PrimFunc, which is used in relax FuseOps.
*
* \param func The PrimFunc to be analyzed.
* \return The Op Pattern Kind.
*
* \note This analysis applies on TIR function but is primarily used by relax passes.
* As a result we place it under the relax namespace.
*/
TVM_DLL relay::OpPatternKind AnalyzeOpPatternKind(const tir::PrimFunc& func);

/*!
* \brief Check if the given PrimFunc is essentially doing a reshape operation.
* The reshape operation also includes expand_dims, squeeze, flatten, etc.
Expand Down
14 changes: 13 additions & 1 deletion include/tvm/tir/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,18 @@
namespace tvm {
namespace tir {

#ifndef TVM_INDEX_DEFAULT_I64
#define TVM_INDEX_DEFAULT_I64 1
#endif
/*! \brief if TVM_INDEX_DEFAULT_I64 is set, return int64, otherwise return int32 */
inline DataType DefaultIndexType() {
#if TVM_INDEX_DEFAULT_I64
return DataType::Int(64);
#else
return DataType::Int(32);
#endif
}

// forward declare Stmt
class Stmt;

Expand Down Expand Up @@ -135,7 +147,7 @@ class BufferNode : public Object {

/*! \return preferred index type for this buffer node */
DataType DefaultIndexType() const {
return shape.size() != 0 ? shape[0].dtype() : DataType::Int(32);
return shape.size() != 0 ? shape[0].dtype() : tvm::tir::DefaultIndexType();
}

/*! \brief Determine the offset in the buffer of the given index.
Expand Down
43 changes: 43 additions & 0 deletions python/tvm/relax/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,49 @@ def AttachGlobalSymbol() -> tvm.ir.transform.Pass:
return _ffi_api.AttachGlobalSymbol() # type: ignore


def AnnotateTIROpPattern() -> tvm.ir.transform.Pass:
"""Annotate Op Pattern Kind for TIR functions
Returns
-------
ret: tvm.ir.transform.Pass
"""
return _ffi_api.AnnotateTIROpPattern() # type: ignore


def FuseOps(fuse_opt_level=-1) -> tvm.ir.transform.Pass:
"""This pass groups bindings in a dataflow block of Relax functions and generate a new grouped
Relax function for each group, according to the fusion algorithm described in the pass
implementation. By grouping bindings into new Relax functions, we substitute the bindings in
the function being manipulated into function calls to the new grouped function.
A follow-up pass named "FuseTIR" will generate a TIR PrimFunc for each grouped function.
Parameters
----------
fuse_opt_level : int
The level of fuse optimization. -1 indicates that the level will be
inferred from pass context.
Returns
-------
ret : tvm.transform.Pass
The registered pass for operator fusion.
"""
return _ffi_api.FuseOps(fuse_opt_level) # type: ignore


def FuseTIR() -> tvm.ir.transform.Pass:
"""Fuse primitive relax function into a larger TIR function if possible
Returns
-------
ret : tvm.transform.Pass
The registered pass for tir fusion.
"""
return _ffi_api.FuseTIR() # type: ignore


def _wrap_class_function_pass(pass_cls, pass_info):
"""Wrap a python class as function pass."""

Expand Down
55 changes: 55 additions & 0 deletions src/relax/transform/annotate_tir_op_pattern.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file src/relax/transform/annotate_tir_op_pattern.cc
* \brief Annotate Op Pattern for TIR functions. It is a pass works on TIR PrimFuncs,
* but they are needed for relax fusion. So we put them in the relax namespace.
*/
#include <tvm/relax/analysis.h>
#include <tvm/relax/transform.h>
#include <tvm/tir/transform.h>

namespace tvm {
namespace relax {

tir::PrimFunc AnnotateOpPattern(tir::PrimFunc f) {
if (f->HasNonzeroAttr("op_pattern")) {
return f;
} else {
relay::OpPatternKind kind = AnalyzeOpPatternKind(f);
return WithAttr(std::move(f), "op_pattern", Integer(static_cast<int>(kind)));
}
}

namespace transform {

Pass AnnotateTIROpPattern() {
auto pass_func = [=](tir::PrimFunc f, IRModule m, PassContext ctx) {
return AnnotateOpPattern(std::move(f));
};
return tir::transform::CreatePrimFuncPass(pass_func, 0, "AnnotateTIROpPattern", {});
}

TVM_REGISTER_GLOBAL("relax.transform.AnnotateTIROpPattern").set_body_typed(AnnotateTIROpPattern);

} // namespace transform

} // namespace relax
} // namespace tvm
Loading

0 comments on commit 59dfc38

Please sign in to comment.