Commit: Unittests and Refactor (tlc-pack#10)
* ut except bind params

* all complete
jinhongyii authored and MasterJH5574 committed Apr 24, 2022
1 parent e3da9b9 commit 7c14d07
Showing 11 changed files with 376 additions and 218 deletions.
9 changes: 8 additions & 1 deletion include/tvm/relax/transform.h
@@ -132,7 +132,7 @@ TVM_DLL Pass FoldConstant();
 *
 * \return The Pass.
 */
TVM_DLL Pass AnnotateOpKind();
TVM_DLL Pass AnnotateTIROpPattern();

/*!
 * \brief Layout Rewrite
@@ -156,6 +156,13 @@ TVM_DLL Pass FoldConstant();
 */
TVM_DLL Pass FuseOps(int fuse_opt_level = -1);

/*!
 * \brief Bind params of main function of the module to constant tensors.
 *
 * \return The Pass.
 */
TVM_DLL Pass BindParams(Map<String, runtime::NDArray> params);

} // namespace transform
} // namespace relax
} // namespace tvm
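A hedged sketch (not part of the commit) of how the renamed pass and the new BindParams declaration are driven from Python; the IRModule `mod`, the parameter name "w", and the weight array are illustrative placeholders:

    import numpy as np
    import tvm
    from tvm import relax

    # `mod` is assumed to be an existing relax IRModule whose "main" function
    # takes a parameter named "w"; both are placeholders for illustration.
    params = {"w": tvm.nd.array(np.zeros((4, 4), dtype="float32"))}
    seq = tvm.transform.Sequential(
        [
            relax.transform.BindParams(params),      # fold weights in as constants
            relax.transform.AnnotateTIROpPattern(),  # tag TIR funcs with pattern kinds
        ]
    )
    mod = seq(mod)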
19 changes: 17 additions & 2 deletions python/tvm/relax/transform/transform.py
@@ -165,14 +165,14 @@ def FoldConstant() -> tvm.ir.transform.Pass:
    return _ffi_api.FoldConstant()


def AnnotateOpKind() -> tvm.ir.transform.Pass:
def AnnotateTIROpPattern() -> tvm.ir.transform.Pass:
    """Annotate Op Pattern Kind for TIR functions

    Returns
    -------
    ret: tvm.ir.transform.Pass
    """
    return _ffi_api.AnnotateOpKind()
    return _ffi_api.AnnotateTIROpPattern()


def LayoutRewrite() -> tvm.ir.transform.Pass:
@@ -212,6 +212,21 @@ def FuseOps(fuse_opt_level=-1) -> tvm.ir.transform.Pass:
    return _ffi_api.FuseOps(fuse_opt_level)


def BindParams(params) -> tvm.ir.transform.Pass:
    """Bind params of main function of the module to constant tensors.

    Parameters
    ----------
    params : dict from str to ndarray
        The map from param name to constant tensors.

    Returns
    -------
    ret: tvm.ir.transform.Pass
    """
    return _ffi_api.BindParams(params)


def _wrap_class_function_pass(pass_cls, pass_info):
"""Wrap a python class as function pass"""

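Since the C++ entry point takes a Map<String, runtime::NDArray>, the dict handed to BindParams should map parameter names to tvm.nd.NDArray values. A minimal sketch, with illustrative names, of preparing that dict from numpy arrays:

    import numpy as np
    import tvm
    from tvm import relax

    np_params = {"weight": np.random.uniform(size=(8, 8)).astype("float32")}
    # Convert host numpy arrays into TVM NDArrays before handing them to the pass.
    nd_params = {name: tvm.nd.array(arr) for name, arr in np_params.items()}
    bind_pass = relax.transform.BindParams(nd_params)  # apply as: mod = bind_pass(mod)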
src/relax/transform/annotate_opkind.cc → src/relax/transform/annotate_tir_op_pattern.cc
@@ -17,8 +17,8 @@
 * under the License.
 */
/*!
 * \file src/relax/transform/annotate_opkind.cc
 * \brief Annotate OpKind for TIR functions
 * \file src/relax/transform/annotate_tir_op_pattern.cc
 * \brief Annotate Op Pattern for TIR functions
 */
#include <tvm/relax/attrs/memory.h>
#include <tvm/relax/transform.h>
@@ -51,13 +51,13 @@ IRModule Annotate(IRModule m) {

namespace transform {

Pass AnnotateOpKind() {
Pass AnnotateTIROpPattern() {
  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
      [=](IRModule mod, PassContext pc) { return Annotate(mod); };
  return CreateModulePass(pass_func, 0, "AnnotateTIROpPattern", {});
}

TVM_REGISTER_GLOBAL("relax.transform.AnnotateOpKind").set_body_typed(AnnotateOpKind);
TVM_REGISTER_GLOBAL("relax.transform.AnnotateTIROpPattern").set_body_typed(AnnotateTIROpPattern);

} // namespace transform

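For comparison, the module-pass shape used above (a function from IRModule to IRModule wrapped by CreateModulePass) has a direct Python counterpart; a hedged sketch whose body is a stand-in, not the real analysis:

    import tvm

    @tvm.transform.module_pass(opt_level=0, name="AnnotateTIROpPatternDemo")
    def annotate_demo(mod, ctx):
        # Stand-in for Annotate(mod): the real pass visits every tir::PrimFunc
        # and attaches the inferred op pattern kind as a function attribute.
        return mod

    # Usage: mod = annotate_demo(mod)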
19 changes: 15 additions & 4 deletions src/relax/transform/layout_rewrite.cc
@@ -44,9 +44,18 @@ struct LayoutRewriteInfo {
  DataType dtype;
};

Array<PrimExpr> NormalizeShapeForRelax(const Array<PrimExpr>& shape) {
  Array<PrimExpr> res;
  for (const auto& e : shape) {
    res.push_back(IntImm(DataType::Int(64), e.as<IntImmNode>()->value));
  }
  return res;
}


class LayoutRewriteInserter : public ExprMutator {
 public:
  LayoutRewriteInserter(IRModule module) : module_(module) {
  LayoutRewriteInserter(IRModule module) : module_(GetRef<IRModule>(module.CopyOnWrite())) {
    InitializeIndexMaps();
  }

@@ -63,7 +72,7 @@ class LayoutRewriteInserter : public ExprMutator {
    tir::Block block = sch->Get(block_rv);
    if (Optional<ObjectRef> ann = block->annotations.Get("layout_free_placeholders")) {
      auto layout_free_buffers = Downcast<Array<tir::Buffer>>(ann.value());

      sch->Unannotate(block_rv, "layout_free_placeholders");
      Optional<Buffer> buffer;
      int buffer_index = -1;
      int var_index = -1;
@@ -183,7 +192,8 @@ class LayoutRewriteInserter : public ExprMutator {
      GlobalVar layout_rewrite_func = CreateFuncFromIndexMap(pr.second);

      Var new_var = builder_->Emit(Call(
          call_tir_op, {layout_rewrite_func, args[pr.first], ShapeExpr(pr.second.tgt_shape)},
          call_tir_op,
          {layout_rewrite_func, args[pr.first], ShapeExpr(NormalizeShapeForRelax(pr.second.tgt_shape))},
          {}, {DynTensorType(pr.second.tgt_shape.size(), pr.second.dtype)}));
      args.Set(pr.first, new_var);
    }
@@ -196,7 +206,8 @@ class LayoutRewriteInserter : public ExprMutator {
    GlobalVar layout_rewrite_func = CreateFuncFromIndexMap(info);

    Var new_var = builder_->Emit(
        Call(call_tir_op, {layout_rewrite_func, arg, ShapeExpr(info.tgt_shape)}, {},
        Call(call_tir_op, {layout_rewrite_func, arg, ShapeExpr(NormalizeShapeForRelax(info.tgt_shape))},
             {}, {DynTensorType(info.tgt_shape.size(), info.dtype)}));
    return Call(call_tir_op, {call->args[0], new_var, call->args[2], call->args[3]}, {},
                call->type_args);
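The new NormalizeShapeForRelax helper widens every constant dimension to int64 before it is wrapped in a relax ShapeExpr, since the TIR side often carries int32 IntImms. A small Python sketch of the same normalization (assuming, as the C++ does, that every dimension is a constant):

    import tvm
    from tvm import tir

    def normalize_shape_for_relax(shape):
        # Widen each constant dimension to a 64-bit IntImm, mirroring the C++ helper.
        return [tir.IntImm("int64", dim.value) for dim in shape]

    dims = [tir.IntImm("int32", 4), tir.IntImm("int32", 8)]
    print(normalize_shape_for_relax(dims))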
57 changes: 53 additions & 4 deletions src/tir/schedule/analysis/analysis.cc
@@ -2086,6 +2086,47 @@ bool CheckInjectivePattern(const Array<PrimExpr>& indices_l, const Array<PrimExp
  return true;
}

bool CheckAllowReusePattern(const Array<PrimExpr>& indices_l, const Array<PrimExpr>& indices_r) {
  std::unordered_set<const VarNode*> vars;
  for (int i = 0; i < static_cast<int>(indices_l.size()); i++) {
    if (const auto* v = indices_l[i].as<VarNode>()) {
      vars.insert(v);
    } else {
      return false;
    }
  }
  for (const PrimExpr& e : indices_r) {
    PreOrderVisit(e, [&](const ObjectRef& node) {
      if (const auto* v = node.as<VarNode>()) {
        if (vars.count(v)) {
          vars.erase(v);
        }
      }
      return true;
    });
  }
  return !vars.empty();
}

bool CheckFMA(Stmt body) {
  if (const auto* store = body.as<BufferStoreNode>()) {
    if (const auto* add = store->value.as<AddNode>()) {
      if (const auto* l = add->a.as<BufferLoadNode>()) {
        if (const auto* r = add->b.as<MulNode>()) {
          bool incremental =
              store->buffer.same_as(l->buffer) && CheckSameArray(store->indices, l->indices);
          const auto* l_operand = r->a.as<BufferLoadNode>();
          const auto* r_operand = r->b.as<BufferLoadNode>();
          if (incremental && l_operand && r_operand) {
            return CheckAllowReusePattern(store->indices, l_operand->indices) &&
                   CheckAllowReusePattern(store->indices, r_operand->indices);
          }
        }
      }
    }
  }
  return false;
}

class PatternKindAnalyzer : public StmtExprVisitor {
  void VisitStmt_(const BufferStoreNode* op) final {
    store_indices_ = op->indices;
@@ -2127,15 +2168,23 @@ class PatternKindAnalyzer : public StmtExprVisitor {
      kind_ = std::max(kind_, index_pair_pattern);
      return;
    }

    bool has_reduction = false;
    for (IterVar it : op->iter_vars) {
      if (it->iter_type == kCommReduce) {
        has_reduction = true;
        break;
      }
    }
    if (has_reduction) {
      if (CheckFMA(op->body)) {
        kind_ = std::max(kind_, relay::kOutEWiseFusable);
      } else {
        kind_ = std::max(kind_, relay::kCommReduce);
        return;
      }
    } else {
      kind_ = relay::kOpaque;
    }

    kind_ = relay::kOpaque;
  }

  Array<PrimExpr> store_indices_;
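The store pattern CheckFMA matches is the multiply-accumulate update C[i, j] = C[i, j] + A[i, k] * B[k, j]: a BufferStore of an Add whose left operand reloads the stored buffer at the same indices and whose right operand is a Mul of two BufferLoads. A hedged TVMScript sketch (not from this commit) of a reduction block the analysis should classify as kOutEWiseFusable rather than plain kCommReduce:

    from tvm.script import tir as T

    @T.prim_func
    def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
        A = T.match_buffer(a, (128, 128), "float32")
        B = T.match_buffer(b, (128, 128), "float32")
        C = T.match_buffer(c, (128, 128), "float32")
        for i, j, k in T.grid(128, 128, 128):
            with T.block("update"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    C[vi, vj] = T.float32(0)
                # The store C[vi, vj] reuses the indices of the left load, and each
                # multiplicand omits one store index (vj for A, vi for B), which is
                # what CheckAllowReusePattern tests for.
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]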
(The remaining 6 changed files are not shown.)
