
Commit

save
lint

lint

lint

fix lint
MarisaKirisame committed Aug 23, 2020
1 parent aae096a commit 9cea77a
Showing 13 changed files with 139 additions and 83 deletions.
28 changes: 28 additions & 0 deletions include/tvm/relay/feature.h
@@ -160,6 +160,34 @@ inline FeatureSet DetectFeature(const Expr& expr, const IRModule& mod) {
return DetectFeature(expr) + DetectFeature(mod);
}

/*!
* \brief Check that the expression only uses features from the given feature set.
*
* \param expr The expression.
* \param fs The allowed feature set.
*/
void CheckFeature(const RelayExpr& expr, const FeatureSet& fs);

/*!
* \brief Check that the module only uses features from the given feature set.
*
* \param mod The module.
* \param fs The allowed feature set.
*/
void CheckFeature(const IRModule& mod, const FeatureSet& fs);

/*!
* \brief Check that the expression and the module only use features from the given feature set.
*
* \param expr The expression.
* \param mod The module.
* \param fs The allowed feature set.
*/
inline void CheckFeature(const RelayExpr& expr, const IRModule& mod, const FeatureSet& fs) {
CheckFeature(expr, fs);
CheckFeature(mod, fs);
}

} // namespace relay
} // namespace tvm
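
The three CheckFeature overloads above let a pass assert up front that its input only uses supported language features. CheckFeature itself gets no Python binding in this commit, but the same subset check can be sketched on top of the existing relay.analysis.detect_feature API; a minimal sketch, assuming only public relay APIs (the helper name check_feature and the allowed set are illustrative):

from tvm import relay
from tvm.relay.analysis import detect_feature
from tvm.relay.feature import Feature

def check_feature(expr, allowed):
    # Mirrors the C++ CheckFeature: fail if the expression uses any
    # feature outside the allowed set.
    used = detect_feature(expr)
    assert used.issubset(allowed), "unsupported features: %s" % (used - allowed)

x = relay.var("x", shape=(1,), dtype="float32")
f = relay.Function([x], x + x)
check_feature(f, {Feature.fVar, Feature.fFunction, Feature.fCall, Feature.fOp})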

9 changes: 9 additions & 0 deletions include/tvm/relay/transform.h
@@ -147,6 +147,15 @@ TVM_DLL Pass ToBasicBlockNormalForm();
*/
TVM_DLL Pass ToANormalForm();

/*!
* \brief ToANormalForm, but applied to a single expression (which may be an
* incomplete subgraph) rather than a whole module.
*
* \param expr The expression.
*
* \return The transformed expression.
*/
TVM_DLL Expr ToANormalForm(const Expr& expr);
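
This overload is C++-only; from Python the module-level pass remains the entry point. A minimal usage sketch (the example program is illustrative):

import tvm
from tvm import relay

x = relay.var("x", shape=(1,), dtype="float32")
y = x + x
f = relay.Function([x], y * y)  # y has two users, so f is a graph rather than a tree
mod = relay.transform.ToANormalForm()(tvm.IRModule.from_expr(f))
print(mod)  # the shared value is now let-bound exactly once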

/*!
* \brief Turn an expression into continuation passing style (CPS).
*
2 changes: 2 additions & 0 deletions python/tvm/relay/prelude.py
@@ -17,6 +17,7 @@
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""A prelude containing useful global functions and ADT definitions."""
from tvm.ir import IRModule, TypeCall
from tvm import relay

from .ty import GlobalTypeVar, TensorType, Any, scalar_type
from .expr import Var, GlobalVar, If, const
@@ -1237,6 +1238,7 @@ def __init__(self, mod=None):
mod = IRModule()
self.mod = mod
self.load_prelude()
self.mod = relay.transform.ToANormalForm()(self.mod)

def get_name(self, canonical, dtype):
"""Get name corresponding to the canonical name"""
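
With this change the Prelude normalizes its module to A-normal form right after loading, which is why fGraph disappears from the expected feature sets in the test diff further down. A quick check, as a sketch:

from tvm.relay.analysis import detect_feature
from tvm.relay.feature import Feature
from tvm.relay.prelude import Prelude

p = Prelude()
assert Feature.fGraph not in detect_feature(p.mod)
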
18 changes: 14 additions & 4 deletions src/relay/analysis/feature.cc
@@ -46,6 +46,7 @@ FeatureSet DetectFeature(const Expr& expr) {
ExprVisitor::VisitExpr(expr);
} else {
if (!IsAtomic(expr)) {
std::cout << AsText(expr) << std::endl;
fs += fGraph;
}
}
@@ -88,10 +89,8 @@ FeatureSet DetectFeature(const Expr& expr) {

FeatureSet DetectFeature(const IRModule& mod) {
FeatureSet fs = FeatureSet::No();
if (mod.defined()) {
for (const auto& f : mod->functions) {
fs += DetectFeature(f.second);
}
for (const auto& f : mod->functions) {
fs += DetectFeature(f.second);
}
return fs;
}
@@ -106,5 +105,16 @@ Array<Integer> PyDetectFeature(const Expr& expr, const Optional<IRModule>& mod)

TVM_REGISTER_GLOBAL("relay.analysis.detect_feature").set_body_typed(PyDetectFeature);

void CheckFeature(const Expr& expr, const FeatureSet& fs) {
CHECK(DetectFeature(expr).is_subset_of(fs)) << AsText(expr, false) << "\nhas more features than\n"
<< fs << " supports";
}

void CheckFeature(const IRModule& mod, const FeatureSet& fs) {
for (const auto& f : mod->functions) {
CheckFeature(f.second, fs);
}
}

} // namespace relay
} // namespace tvm
22 changes: 16 additions & 6 deletions src/relay/transforms/gradient.cc
@@ -24,6 +24,7 @@
#include <tvm/ir/type_functor.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/feature.h>
#include <tvm/relay/transform.h>
#include <tvm/te/operation.h>

@@ -81,7 +82,7 @@ Type WithGradientType(const Type& t) {
Expr DeGlobal(const Optional<IRModule>& mod, const Expr& e) {
const auto* x = e.as<GlobalVarNode>();

if (mod.defined() && (x)) {
if (mod.defined() && x) {
BaseFunc base_func = mod.value()->Lookup(GetRef<GlobalVar>(x));
if (auto* n = base_func.as<FunctionNode>()) {
return n->body;
@@ -404,7 +405,7 @@ Expr GetRev(const Type& forward_type, const Expr& e, LetList* ll) {

/*! \brief ReverseType(t) -> t. Get the original value. */
Expr GetValue(const Type& forward_type, const Expr& e, LetList* ll) {
auto val = [&](const Expr& e) { return GetField(e, 0); };
auto val = [&](const Expr& e) { return ll->Push(GetField(e, 0)); };
auto val_type = [&](const Type& forward_type) { return forward_type; };
return LiftTensor(val, val_type, forward_type, e, ll);
}
@@ -508,16 +509,19 @@ struct ReverseAD : ExprMutator {
return Call(bpv, {});
}),
TupleType::Empty(), {});
ll->Push(RefWrite(bp, nbp));
ll->Push(RefWrite(bp, transform::ToANormalForm(nbp)));
// TODO(@M.K.): ToANF should be called on rev. Enhance ToANF for that.
return ret;
});
}
return ExprMutator::VisitExpr_(call);
}

Expr VisitExpr_(const ConstantNode* op) final {
Expr e = GetRef<Expr>(op);
return Pair(e, RefCreate(ZerosLike(e)));
return LetList::With([&](LetList* ll) {
Expr e = ll->Push(GetRef<Expr>(op));
return Pair(e, RefCreate(ZerosLike(e)));
});
}

Expr VisitExpr_(const IfNode* op) final {
@@ -568,6 +572,10 @@ bool MissingGrad(const Expr& e) {
}

Expr Gradient(const Expr& re, const Optional<IRModule>& mod) {
CheckFeature(re, FeatureSet::All() - fGraph);
if (mod.defined()) {
CheckFeature(mod.value(), FeatureSet::All() - fGraph);
}
auto e = DeGlobal(mod, re);
auto f = e.as<FunctionNode>();
CHECK(f) << "input need to be a function";
@@ -619,7 +627,9 @@ Expr Gradient(const Expr& re, const Optional<IRModule>& mod) {
};
return Pair(get_final_result(c, f->body->checked_type()), Tuple(ret));
});
return Function(f->params, body, GradRetType(GetRef<Function>(f)), {});
auto ret = Function(f->params, body, GradRetType(GetRef<Function>(f)), {});
CheckFeature(ret, FeatureSet::All() - fGraph);
return ret;
}

TVM_REGISTER_GLOBAL("relay._transform.gradient").set_body_typed(Gradient);
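
Gradient now rejects graph-form input up front and asserts that its own output is also graph-free. From Python this means normalizing before differentiating; a sketch assuming the existing relay.transform.gradient binding (the pipeline below is illustrative, not part of this commit):

import tvm
from tvm import relay

x = relay.var("x", shape=(3,), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], x * x))
mod = relay.transform.ToANormalForm()(mod)  # satisfy the new fGraph check
mod = relay.transform.InferType()(mod)      # gradient needs checked types
back = relay.transform.gradient(mod["main"], mod)
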
8 changes: 5 additions & 3 deletions src/relay/transforms/lazy_gradient_init.cc
@@ -63,6 +63,7 @@
#include <tvm/node/structural_equal.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/feature.h>
#include <tvm/relay/transform.h>

#include "let_list.h"
@@ -82,7 +83,6 @@ class InputVisitor : public ExprFunctor<Expr(const Expr&, const Type&)> {
explicit InputVisitor(IRModule module) : module_(module) {}

Expr VisitExpr_(const VarNode* op, const Type& t) final {
std::cout << op->type_annotation << std::endl;
return WrapExpr(GetRef<Var>(op), op->type_annotation);
}

@@ -93,7 +93,7 @@ class InputVisitor : public ExprFunctor<Expr(const Expr&, const Type&)> {
private:
IRModule module_;

Expr WrapExpr(const Expr expr, const Type& type) {
Expr WrapExpr(const Expr expr, const Type& type, LetList* ll) {
if (type.as<TensorTypeNode>()) {
return Call(module_->GetConstructor("GradCell", "Raw"), {expr}, Attrs(), {type});
} else if (auto* type_anno = type.as<TupleTypeNode>()) {
@@ -293,7 +293,9 @@ class LazyGradientInitializer : public ExprMutator, public TypeMutator {
};

Expr LazyGradientInit(const Expr& e, IRModule mod) {
return LazyGradientInitializer(mod).Transform(e);
auto ret = LazyGradientInitializer(mod).Transform(e);
CheckFeature(e, mod, FeatureSet::All() - fGraph);
return ret;
}

namespace transform {
10 changes: 7 additions & 3 deletions src/relay/transforms/partial_eval.cc
@@ -92,6 +92,7 @@
#include <tvm/ir/type_functor.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/feature.h>
#include <tvm/relay/interpreter.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/relay/transform.h>
@@ -1195,9 +1196,12 @@ IRModule PartialEval(const IRModule& m) {
namespace transform {

Pass PartialEval() {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
[=](IRModule m, PassContext pc) { return relay::PartialEval(m); };
return CreateModulePass(pass_func, 1, "PartialEvaluate", {});
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule m,
PassContext pc) {
CheckFeature(m, FeatureSet::All() - fGraph);
return relay::PartialEval(m);
};
return CreateModulePass(pass_func, 1, "PartialEval", {});
}

TVM_REGISTER_GLOBAL("relay._transform.PartialEvaluate").set_body_typed(PartialEval);
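
The PartialEval pass now performs the same feature gate before running. Running it from Python, as a sketch (the example module is illustrative):

import tvm
from tvm import relay

body = relay.const(1.0) + relay.const(2.0)
mod = tvm.IRModule.from_expr(relay.Function([], body))
mod = relay.transform.ToANormalForm()(mod)  # the pass now rejects fGraph input
mod = relay.transform.PartialEvaluate()(mod)
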
58 changes: 30 additions & 28 deletions src/relay/transforms/to_a_normal_form.cc
@@ -252,32 +252,6 @@ Expr Fill::VisitExpr_(const MatchNode* m, const Var& v) {
return Compound(e, Match(data, clauses, m->complete), v);
}

Expr ToANormalFormAux(const Expr& e) {
/* When you lift a lambda, what is inside is also being lifted.
*
* So we must determine the scope of the lambda before determining the scope of its body.
*
* To make this more principled,
* we always determine the scope of the parent before determining the scope of its children.
*
* So we calculate all the dependencies between nodes.
*/
support::Arena arena;
DependencyGraph dg = DependencyGraph::Create(&arena, e);
/* In order to model the new subscopes created by lambdas, if-else, and pattern matching,
* we assign a scope to each edge as well.
* The scope of an edge is either the parent's scope, or a new subscope of the parent's scope.
*
* So, the scope of the whole expr is global.
* The scope of any subexpr is the lowest common ancestor of the scopes of all incoming edges.
*
* Every scope additionally contains a LetList which collects all values of that scope.
* We do an additional pass to fill all the LetLists and we are done.
*/
std::pair<NodeScopeMap, ExprSet> scopes = CalcScope(dg);
return Fill::ToANormalForm(e, dg, &scopes.first);
}

IRModule ToANormalForm(const IRModule& m) {
DLOG(INFO) << "ToANF:" << std::endl << m;

@@ -288,7 +262,7 @@ IRModule ToANormalForm(const IRModule& m) {
if (const auto* n = it.second.as<FunctionNode>()) {
if (n->GetAttr<String>(attr::kCompiler).defined()) continue;
}
Expr ret = TransformF([&](const Expr& e) { return ToANormalFormAux(e); }, it.second);
Expr ret = TransformF([&](const Expr& e) { return transform::ToANormalForm(e); }, it.second);
CHECK_EQ(FreeVars(ret).size(), 0)
<< AsText(ret) << " should not have free vars: " << FreeVars(ret);
updates.Set(it.first, Downcast<Function>(ret));
@@ -305,13 +279,41 @@

namespace transform {

Expr ToANormalForm(const Expr& e) {
/* When you lift a lambda, what is inside is also being lifted.
*
* So we must determine the scope of the lambda before determining the scope of its body.
*
* To make this more principled,
* we always determine the scope of the parent before determining the scope of its children.
*
* So we calculate all the dependencies between nodes.
*/
support::Arena arena;
DependencyGraph dg = DependencyGraph::Create(&arena, e);
/* In order to model the new subscopes created by lambdas, if-else, and pattern matching,
* we assign a scope to each edge as well.
* The scope of an edge is either the parent's scope, or a new subscope of the parent's scope.
*
* So, the scope of the whole expr is global.
* The scope of any subexpr is the lowest common ancestor of the scopes of all incoming edges.
*
* Every scope additionally contains a LetList which collects all values of that scope.
* We do an additional pass to fill all the LetLists and we are done.
*/
std::pair<NodeScopeMap, ExprSet> scopes = CalcScope(dg);
return Fill::ToANormalForm(e, dg, &scopes.first);
}

Pass ToANormalForm() {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
[=](IRModule m, PassContext pc) { return relay::ToANormalForm(m); };
return CreateModulePass(pass_func, 1, "ToANormalForm", {});
}

TVM_REGISTER_GLOBAL("relay._transform.ToANormalForm").set_body_typed(ToANormalForm);
TVM_REGISTER_GLOBAL("relay._transform.ToANormalForm").set_body_typed([]() {
return ToANormalForm();
});

} // namespace transform
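
The effect of the scope computation above is observable through detect_feature: sharing a subexpression introduces fGraph, and ToANormalForm eliminates it by let-binding each shared value once. A sketch (the example program is illustrative):

import tvm
from tvm import relay
from tvm.relay.analysis import detect_feature
from tvm.relay.feature import Feature

x = relay.var("x", shape=(1,), dtype="float32")
y = x + x
f = relay.Function([x], y * y)  # y has two users: a graph
assert Feature.fGraph in detect_feature(f)

mod = relay.transform.ToANormalForm()(tvm.IRModule.from_expr(f))
assert Feature.fGraph not in detect_feature(mod)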

3 changes: 3 additions & 0 deletions src/relay/transforms/to_cps.cc
@@ -52,6 +52,7 @@
*/
#include <tvm/ir/type_functor.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/feature.h>
#include <tvm/relay/pattern_functor.h>
#include <tvm/relay/transform.h>

@@ -301,11 +302,13 @@ Function ToCPS(const Function& f, const IRModule& m, CPSMap* cm) {
}

Function ToCPS(const Function& f, const IRModule& m) {
CheckFeature(f, m, FeatureSet::All() - fGraph);
CPSMap cps;
return ToCPS(f, m, &cps);
}

Function UnCPS(const Function& f) {
CheckFeature(f, FeatureSet::All() - fGraph);
CHECK_GT(f->params.size(), 0);
std::vector<Var> new_params;
for (const auto& p : f->params) {
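
ToCPS and UnCPS now state their no-graph assumption explicitly instead of failing somewhere inside the transform. A round-trip sketch from Python, assuming the existing relay.transform.to_cps / un_cps bindings (the example function is illustrative):

import tvm
from tvm import relay

x = relay.var("x", shape=(1,), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([x], x + x))
mod = relay.transform.InferType()(mod)  # to_cps needs checked types
cps = relay.transform.to_cps(mod["main"], mod)
back = relay.transform.un_cps(cps)
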
2 changes: 0 additions & 2 deletions tests/python/relay/test_analysis_feature.py
@@ -39,7 +39,6 @@ def test_prelude():
Feature.fIf,
Feature.fConstructor,
Feature.fMatch,
Feature.fGraph
])


@@ -65,7 +64,6 @@ def test_ad():
Feature.fRefCreate,
Feature.fRefRead,
Feature.fRefWrite,
Feature.fGraph
])

