Commit 6201c5c

save
lint

lint

lint

fix lint

lint

update

lint

save

save

save

lint

format

format

save

save

fix

use a form more suitable for numeric check

save
MarisaKirisame committed Aug 24, 2020
1 parent aae096a commit 6201c5c
Showing 16 changed files with 266 additions and 161 deletions.
36 changes: 36 additions & 0 deletions include/tvm/relay/feature.h
@@ -29,6 +29,7 @@
#include <tvm/relay/expr.h>

#include <bitset>
#include <string>

namespace tvm {
namespace relay {
@@ -124,6 +125,13 @@ class FeatureSet {
*/
bool is_subset_of(const FeatureSet& rhs) const { return ((*this) - rhs).bs_.none(); }

/*!
* \brief Pretty Print the FeatureSet.
*
* \return a string representation.
*/
std::string Print() const;

private:
std::bitset<feature_count> bs_;
FeatureSet() = default;
@@ -160,6 +168,34 @@ inline FeatureSet DetectFeature(const Expr& expr, const IRModule& mod) {
return DetectFeature(expr) + DetectFeature(mod);
}

/*!
* \brief Check the feature of the program.
*
* \param expr The expression.
* \param fs The feature set of the program.
*/
void CheckFeature(const RelayExpr& expr, const FeatureSet& fs);

/*!
* \brief Check the feature of the program.
*
* \param mod The module.
* \param fs The feature set of the program.
*/
void CheckFeature(const IRModule& mod, const FeatureSet& fs);

/*!
* \brief Check the feature of the program.
*
* \param expr The expression.
* \param mod The module.
* \param fs The feature set of the program.
*/
inline void CheckFeature(const RelayExpr& expr, const IRModule& mod, const FeatureSet& fs) {
CheckFeature(expr, fs);
CheckFeature(mod, fs);
}

} // namespace relay
} // namespace tvm
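
The three CheckFeature overloads compose with the existing DetectFeature entry points, and Print makes failures readable. A usage sketch (not part of this commit; assumes a constructed IRModule `mod` and the Feature enum values declared earlier in this header):

#include <tvm/relay/feature.h>

void RejectGraphForm(const tvm::IRModule& mod) {
  using namespace tvm::relay;
  FeatureSet detected = DetectFeature(mod);
  LOG(INFO) << "detected features: " << detected.Print();  // e.g. "[fVar, fCall, ]"
  // CHECK-fails, printing the offending features, if `mod` relies on
  // graph-form sharing of subexpressions.
  CheckFeature(mod, FeatureSet::All() - fGraph);
}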

9 changes: 9 additions & 0 deletions include/tvm/relay/transform.h
@@ -147,6 +147,15 @@ TVM_DLL Pass ToBasicBlockNormalForm();
*/
TVM_DLL Pass ToANormalForm();

/*!
* \brief ToANormalForm, but applied to a single expression (a possibly incomplete graph).
*
* \param expr the expression.
*
* \return The transformed program.
*/
TVM_DLL Expr ToANormalForm(const Expr& expr);

/*!
* \brief Turn an expression into continuation passing style (CPS).
*
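
Unlike the ToANormalForm() pass above, the new overload normalizes a free-standing expression without an enclosing IRModule, which is what the gradient pass below needs while it is still constructing a closure. A hypothetical call site (`body` is any already-built relay::Expr, possibly with shared subexpressions):

tvm::relay::Expr anf = tvm::relay::transform::ToANormalForm(body);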
2 changes: 2 additions & 0 deletions python/tvm/relay/prelude.py
@@ -17,6 +17,7 @@
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""A prelude containing useful global functions and ADT definitions."""
from tvm.ir import IRModule, TypeCall
from tvm import relay

from .ty import GlobalTypeVar, TensorType, Any, scalar_type
from .expr import Var, GlobalVar, If, const
@@ -1237,6 +1238,7 @@ def __init__(self, mod=None):
mod = IRModule()
self.mod = mod
self.load_prelude()
self.mod = relay.transform.ToANormalForm()(self.mod)

def get_name(self, canonical, dtype):
"""Get name corresponding to the canonical name"""
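
The added line normalizes the freshly loaded prelude to A-normal form once, at construction, so passes that now reject graph-form sharing (such as the gradient pass below) never trip over the prelude. The same pass application written in C++ (a sketch, assuming an existing IRModule `mod`):

mod = tvm::relay::transform::ToANormalForm()(mod);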
51 changes: 47 additions & 4 deletions src/relay/analysis/feature.cc
@@ -86,12 +86,43 @@ FeatureSet DetectFeature(const Expr& expr) {
return fd.fs;
}

std::string FeatureSet::Print() const {
std::string ret;
ret += "[";
size_t detected = 0;
#define DETECT_FEATURE(FEATURE_NAME) \
++detected; \
if (bs_[FEATURE_NAME]) { \
ret += #FEATURE_NAME; \
ret += ", "; \
}
DETECT_FEATURE(fVar);
DETECT_FEATURE(fGlobalVar);
DETECT_FEATURE(fConstant);
DETECT_FEATURE(fTuple);
DETECT_FEATURE(fTupleGetItem);
DETECT_FEATURE(fFunction);
DETECT_FEATURE(fOp);
DETECT_FEATURE(fCall);
DETECT_FEATURE(fLet);
DETECT_FEATURE(fIf);
DETECT_FEATURE(fRefCreate);
DETECT_FEATURE(fRefRead);
DETECT_FEATURE(fRefWrite);
DETECT_FEATURE(fConstructor);
DETECT_FEATURE(fMatch);
DETECT_FEATURE(fGraph);
DETECT_FEATURE(fLetRec);
#undef DETECT_FEATURE
CHECK(detected == feature_count) << "some feature not printed";
ret += "]";
return ret;
}

FeatureSet DetectFeature(const IRModule& mod) {
FeatureSet fs = FeatureSet::No();
- if (mod.defined()) {
-   for (const auto& f : mod->functions) {
-     fs += DetectFeature(f.second);
-   }
+ for (const auto& f : mod->functions) {
+   fs += DetectFeature(f.second);
}
return fs;
}
@@ -106,5 +137,17 @@ Array<Integer> PyDetectFeature(const Expr& expr, const Optional<IRModule>& mod)

TVM_REGISTER_GLOBAL("relay.analysis.detect_feature").set_body_typed(PyDetectFeature);

void CheckFeature(const Expr& expr, const FeatureSet& fs) {
auto dfs = DetectFeature(expr);
CHECK(dfs.is_subset_of(fs)) << AsText(expr, false)
<< "\nhas unsupported feature: " << (dfs - fs).Print();
}

void CheckFeature(const IRModule& mod, const FeatureSet& fs) {
for (const auto& f : mod->functions) {
CheckFeature(f.second, fs);
}
}

} // namespace relay
} // namespace tvm
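
FeatureSet::Print is built on an X-macro-style list: each DETECT_FEATURE line both counts a feature and conditionally appends its name, and the final CHECK(detected == feature_count) turns a forgotten enum entry into a runtime failure rather than a silently incomplete string. A standalone sketch of the same pattern (hypothetical flags, not TVM code):

#include <bitset>
#include <cassert>
#include <string>

enum Flag { kRed, kGreen, kBlue, flag_count };

std::string PrintFlags(const std::bitset<flag_count>& bs) {
  std::string ret = "[";
  size_t listed = 0;
#define LIST_FLAG(NAME) \
  ++listed;             \
  if (bs[NAME]) {       \
    ret += #NAME;       \
    ret += ", ";        \
  }
  LIST_FLAG(kRed);
  LIST_FLAG(kGreen);
  LIST_FLAG(kBlue);
#undef LIST_FLAG
  // Fails at runtime if a Flag was added to the enum but not listed above.
  assert(listed == flag_count && "some flag not printed");
  return ret + "]";
}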
55 changes: 42 additions & 13 deletions src/relay/transforms/gradient.cc
@@ -24,6 +24,7 @@
#include <tvm/ir/type_functor.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/feature.h>
#include <tvm/relay/transform.h>
#include <tvm/te/operation.h>

@@ -81,7 +82,7 @@ Type WithGradientType(const Type& t) {
Expr DeGlobal(const Optional<IRModule>& mod, const Expr& e) {
const auto* x = e.as<GlobalVarNode>();

- if (mod.defined() && (x)) {
+ if (mod.defined() && x) {
BaseFunc base_func = mod.value()->Lookup(GetRef<GlobalVar>(x));
if (auto* n = base_func.as<FunctionNode>()) {
return n->body;
@@ -354,7 +355,7 @@ Expr LiftTensor(const std::function<Expr(const Expr& t)>& f,
LetList* ll) {
CHECK(IsAtomic(e)) << e;
if (forward_type.as<TensorTypeNode>()) {
- auto ret = f(e);
+ auto ret = ll->Push(f(e));
ret->checked_type_ = tf(forward_type);
return ret;
} else if (auto* tt = forward_type.as<TupleTypeNode>()) {
@@ -365,9 +366,9 @@ Expr LiftTensor(const std::function<Expr(const Expr& t)>& f,
fields.push_back(field);
types.push_back(field->checked_type_);
}
- auto ret = Tuple(fields);
+ auto ret = ll->Push(Tuple(fields));
ret->checked_type_ = TupleType(types);
- return std::move(ret);
+ return ret;
} else {
LOG(FATAL) << "unsupported input/output type: " << tt;
throw;
@@ -395,9 +396,10 @@ void TransferGrads(const Type& forward_type, const Expr& from, const Expr& to, L
}
}

// TODO(@M.K.): why take Expr?
/*! \brief t -> ReverseType(t). Transform to Reverse Mode Value. */
Expr GetRev(const Type& forward_type, const Expr& e, LetList* ll) {
- auto rev = [&](const Expr& e) { return Pair(e, ll->Push(RefCreate(ZerosLike(e)))); };
+ auto rev = [&](const Expr& e) { return Pair(e, RefCreate(ZerosLike(e))); };
auto rev_type = [&](const Type& forward_type) { return ReverseType(forward_type); };
return LiftTensor(rev, rev_type, forward_type, e, ll);
}
@@ -411,14 +413,14 @@ Expr GetValue(const Type& forward_type, const Expr& e, LetList* ll) {

/*! \brief ReverseType(t) -> t. Get the gradient. */
Expr GetGrad(const Type& forward_type, const Expr& e, LetList* ll) {
- auto grad = [&](const Expr& e) { return ll->Push(RefRead(GetField(e, 1))); };
+ auto grad = [&](const Expr& e) { return RefRead(GetField(e, 1)); };
auto grad_type = [&](const Type& forward_type) { return forward_type; };
return LiftTensor(grad, grad_type, forward_type, e, ll);
}

void UpdateGrad(const Type& t, const Expr& arg, const Expr& grad, LetList* ll) {
if (t.as<TensorTypeNode>()) {
- ll->Push(RefWrite(GetField(arg, 1), Add(ll->Push(RefRead(GetField(arg, 1))), grad)));
+ ll->Push(RefWrite(GetField(arg, 1), Add(RefRead(GetField(arg, 1)), grad)));
} else if (auto* tt = t.as<TupleTypeNode>()) {
for (size_t i = 0; i < tt->fields.size(); ++i) {
UpdateGrad(tt->fields[i], ll->Push(GetField(arg, i)), ll->Push(GetField(grad, i)), ll);
@@ -448,14 +450,32 @@ struct ReverseAD : ExprMutator {
throw;
}

Expr Remap(const Expr& e) {
struct Remapper : ExprMutator {
std::shared_ptr<ADVarMap> ad_vars;
LetList* ll;
Remapper(const std::shared_ptr<ADVarMap>& ad_vars, LetList* ll) : ad_vars(ad_vars), ll(ll) {}
Expr VisitExpr_(const VarNode* var) final {
// Vars already rewritten by ReverseAD map to their AD pair; read back the
// forward value so the checkpointed body has no free AD Vars.
auto var_ref = GetRef<Var>(var);
if (ad_vars->count(var_ref) == 0) {
return var_ref;
} else {
return GetValue(var_ref->checked_type(), ad_vars->at(var_ref), ll);
}
}
};
return LetList::With([&](LetList* ll) { return Remapper(ad_vars, ll)(e); });
}

Expr VisitCheckpoint(const CallNode* call) {
const OpNode* op_node = call->op.as<OpNode>();
CHECK(op_node) << "expected op in call";
Op op_ref = GetRef<Op>(op_node);
CHECK(op_ref->name == "annotation.checkpoint") << "expected checkpoint annotation";
auto x = call->args[0];
return LetList::With([&](LetList* ll) {
- auto x_var = ll->Push(x);
+ auto x_var = ll->Push(Remap(x));
auto ret = ll->Push(GetRev(call->checked_type(), x_var, ll));
auto bpv = ll->Push(RefRead(bp));
Expr nbp = Function({}, LetList::With([&](LetList* ll) {
@@ -508,16 +528,19 @@ struct ReverseAD : ExprMutator {
return Call(bpv, {});
}),
TupleType::Empty(), {});
- ll->Push(RefWrite(bp, nbp));
+ ll->Push(RefWrite(bp, transform::ToANormalForm(nbp)));
// TODO(@M.K.): ToANF should be called on rev. Enhance ToANF for that.
return ret;
});
}
return ExprMutator::VisitExpr_(call);
}

Expr VisitExpr_(const ConstantNode* op) final {
- Expr e = GetRef<Expr>(op);
- return Pair(e, RefCreate(ZerosLike(e)));
+ return LetList::With([&](LetList* ll) {
+   Expr e = ll->Push(GetRef<Expr>(op));
+   return Pair(e, RefCreate(ZerosLike(e)));
+ });
}

Expr VisitExpr_(const IfNode* op) final {
@@ -528,7 +551,7 @@ struct ReverseAD : ExprMutator {
Expr VisitExpr_(const VarNode* var) final {
// memoize Var -> ADVar so we don't end up with free Vars when checkpointing
auto var_ref = GetRef<Var>(var);
- if (!ad_vars->count(var_ref)) {
+ if (ad_vars->count(var_ref) == 0) {
auto res = Downcast<Var>(ExprMutator::VisitExpr_(var));
(*ad_vars)[var_ref] = res;
}
@@ -568,6 +591,10 @@ bool MissingGrad(const Expr& e) {
}

Expr Gradient(const Expr& re, const Optional<IRModule>& mod) {
CheckFeature(re, FeatureSet::All() - fGraph);
if (mod.defined()) {
CheckFeature(mod.value(), FeatureSet::All() - fGraph);
}
auto e = DeGlobal(mod, re);
auto f = e.as<FunctionNode>();
CHECK(f) << "input need to be a function";
@@ -619,7 +646,9 @@ Expr Gradient(const Expr& re, const Optional<IRModule>& mod) {
};
return Pair(get_final_result(c, f->body->checked_type()), Tuple(ret));
});
- return Function(f->params, body, GradRetType(GetRef<Function>(f)), {});
+ auto ret = Function(f->params, body, GradRetType(GetRef<Function>(f)), {});
+ CheckFeature(ret, FeatureSet::All() - fGraph);
+ return ret;
}

TVM_REGISTER_GLOBAL("relay._transform.gradient").set_body_typed(Gradient);
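
The common thread in this file: intermediate values now go through LetList::Push, so the pass both consumes and produces let-bound (fGraph-free) programs, which the new CheckFeature calls at the entry and exit of Gradient enforce. A sketch of the LetList idiom (Pair and Add are helpers already used in this file; MakeInput is hypothetical):

// Each Push binds its argument to a fresh Var via a let, so the final
// expression is a let-chain (ANF) instead of a DAG of shared subtrees.
Expr example = LetList::With([&](LetList* ll) {
  Var x = ll->Push(MakeInput());  // let x = ...;
  Var y = ll->Push(Add(x, x));    // let y = add(x, x);
  return Pair(y, y);              // y is shared by name, not by subtree
});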