Skip to content

Commit

Permalink
[Unity][Pass] LambdaLift pass (#14012)
Browse files Browse the repository at this point in the history
  • Loading branch information
yongwww authored Feb 16, 2023
1 parent 85b8a41 commit ab43ba8
Show file tree
Hide file tree
Showing 6 changed files with 855 additions and 0 deletions.
57 changes: 57 additions & 0 deletions include/tvm/relax/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,63 @@ TVM_DLL bool IsBaseOf(const StructInfo& base, const StructInfo& derived,
TVM_DLL StructInfo StructInfoLCA(const StructInfo& lhs, const StructInfo& rhs,
arith::Analyzer* ana = nullptr);

//-----------------------------------
// General IR analysis
//-----------------------------------
/*!
* \brief Get all bound variables from expression expr.
*
* Bound variables are all variables that are declared in the expr.
* They only have meaning inside that expr, and can only be used in it.
*
* \param expr the expression.
*
* \return List of bound vars, in the PostDFS order in the expression.
*/
TVM_DLL tvm::Array<Var> BoundVars(const Expr& expr);

/*!
* \brief Get free type parameters from expression expr.
*
* Free variables are variables that are not bound by a
* varbinding or a function parameter in the context.
*
* \param expr the expression.
*
* \return List of free vars, in the PostDFS order in the expression.
*/
TVM_DLL tvm::Array<Var> FreeVars(const Expr& expr);

/*!
* \brief Get all variables from expression expr.
*
* \param expr the expression.
*
* \return List of all vars, in the PostDFS order in the expression.
*/
TVM_DLL tvm::Array<Var> AllVars(const Expr& expr);

/*!
* \brief Get all global variables used in calls in expression expr.
*
* \param expr the expression.
*
* \return List of all global variables called in expr.
*/
TVM_DLL tvm::Array<GlobalVar> CalledGlobalVars(const Expr& expr);

/*!
* \brief Get all global variables from expression expr.
*
* AllVars is a superset of BoundVars and FreeVars.
* The union of BoundVars and FreeVars is Allvars.
*
* \param expr the expression.
*
* \return List of all global variables, in the PostDFS order in the expression.
*/
TVM_DLL tvm::Array<GlobalVar> AllGlobalVars(const Expr& expr);

/*!
* \brief Annotate Op Pattern Kind for PrimFunc, which is used in relax FuseOps.
*
Expand Down
10 changes: 10 additions & 0 deletions python/tvm/relax/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,16 @@ def ToNonDataflow() -> tvm.ir.transform.Pass:
return _ffi_api.ToNonDataflow() # type: ignore


def LambdaLift():
"""A pass that lifts local functions into global.
Returns
-------
ret : tvm.ir.transform.Pass
"""
return _ffi_api.LambdaLift()


def CallTIRRewrite() -> tvm.ir.transform.Pass:
"""Perform explicit tensor allocation for call_tir.
Expand Down
173 changes: 173 additions & 0 deletions src/relax/analysis/analysis.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
*
* \file analysis.cc
*
* \brief Analysis functions for Relax.
*/

#include <tvm/relax/analysis.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/tir/expr_functor.h>

namespace tvm {
namespace relax {

template <typename T>
struct InsertionSet {
std::unordered_set<T, ObjectPtrHash, ObjectPtrEqual> set;
std::vector<T> data;
void Insert(const T& t) {
if (set.count(t) == 0) {
set.insert(t);
data.push_back(t);
}
}
};

class VarVisitor : protected ExprVisitor {
public:
Array<Var> Free(const Expr& expr) {
this->VisitExpr(expr);
Array<Var> ret;
for (const auto& v : vars_.data) {
if (bound_vars_.set.count(v) == 0) {
ret.push_back(v);
}
}
return ret;
}

Array<Var> Collect() {
Array<Var> ret;
for (const auto& v : bound_vars_.data) {
ret.push_back(v);
}
return ret;
}

Array<Var> Bound(const Expr& expr) {
this->VisitExpr(expr);
return Collect();
}

Array<Var> All(const Expr& expr) {
this->VisitExpr(expr);
Array<Var> ret;
for (const auto& v : vars_.data) {
ret.push_back(v);
}
return ret;
}

Array<GlobalVar> AllGlobalVars(const Expr& expr) {
this->VisitExpr(expr);
Array<GlobalVar> ret;
for (const auto& v : global_vars_.data) {
ret.push_back(v);
}
return ret;
}

Array<GlobalVar> CalledGlobalVars(const Expr& expr) {
this->VisitExpr(expr);
Array<GlobalVar> ret;
for (const auto& v : called_global_vars_.data) {
ret.push_back(v);
}
return ret;
}

void MarkBounded(const Var& v) {
bound_vars_.Insert(v);
vars_.Insert(v);
}

void VisitExpr_(const VarNode* var) final { vars_.Insert(GetRef<Var>(var)); }

void VisitExpr_(const FunctionNode* op) final {
for (const auto& param : op->params) {
MarkBounded(param);
}
VisitExpr(op->body);
}

void VisitExpr_(const GlobalVarNode* op) final { global_vars_.Insert(GetRef<GlobalVar>(op)); }

void VisitExpr_(const CallNode* call_node) final {
VisitSpan(call_node->span);
VisitExpr(call_node->op);

for (StructInfo sinfo_arg : call_node->sinfo_args) {
VisitExprDepStructInfoField(sinfo_arg);
}

for (Expr arg : call_node->args) {
VisitExpr(arg);
}

if (const GlobalVarNode* global_var_node = call_node->op.as<GlobalVarNode>()) {
called_global_vars_.Insert(GetRef<GlobalVar>(global_var_node));
}
}

void VisitBinding_(const VarBindingNode* binding) final {
MarkBounded(binding->var);
VisitExpr(binding->value);
VisitVarDef(binding->var);
}

void VisitBinding_(const MatchCastNode* binding) final {
MarkBounded(binding->var);
ExprVisitor::VisitBinding_(binding);
}

private:
InsertionSet<Var> vars_;
InsertionSet<Var> bound_vars_;
InsertionSet<GlobalVar> global_vars_;
InsertionSet<GlobalVar> called_global_vars_;
};

tvm::Array<Var> FreeVars(const Expr& expr) { return VarVisitor().Free(expr); }

tvm::Array<Var> BoundVars(const Expr& expr) { return VarVisitor().Bound(expr); }

tvm::Array<Var> AllVars(const Expr& expr) { return VarVisitor().All(expr); }

tvm::Array<GlobalVar> AllGlobalVars(const Expr& expr) { return VarVisitor().AllGlobalVars(expr); }

tvm::Array<GlobalVar> CalledGlobalVars(const Expr& expr) {
return VarVisitor().CalledGlobalVars(expr);
}

TVM_REGISTER_GLOBAL("relax.analysis.free_vars").set_body_typed(FreeVars);

TVM_REGISTER_GLOBAL("relax.analysis.bound_vars").set_body_typed(BoundVars);

TVM_REGISTER_GLOBAL("relax.analysis.all_vars").set_body_typed(AllVars);

TVM_REGISTER_GLOBAL("relax.analysis.all_global_vars").set_body_typed(AllGlobalVars);

TVM_REGISTER_GLOBAL("relax.analysis.called_global_vars").set_body_typed(CalledGlobalVars);

} // namespace relax
} // namespace tvm
Loading

0 comments on commit ab43ba8

Please sign in to comment.