Skip to content

Commit

Permalink
[Relay][Training] Make AutoDiff thread through global function. (apac…
Browse files Browse the repository at this point in the history
…he#6336)

* save

* lint

* lint

* fix warning

* fix test

* save
  • Loading branch information
MarisaKirisame authored and kevinthesun committed Sep 18, 2020
1 parent bab6148 commit 984666c
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 25 deletions.
2 changes: 1 addition & 1 deletion src/printer/doc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ Doc Doc::Indent(int indent, Doc doc) {
}

Doc Doc::StrLiteral(const std::string& value, std::string quote) {
// TODO(M.K.): add escape.
// TODO(@M.K.): add escape.
Doc doc;
return doc << quote << value << quote;
}
Expand Down
106 changes: 83 additions & 23 deletions src/relay/transforms/gradient.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ Type WithGradientType(const Type&);
Expr FirstOrderGradient(const Expr& e, const Optional<IRModule>& mod);

Type WithGradientType(const Type& t) {
// TODO(M.K.): stricter checking
// TODO(@M.K.): stricter checking
auto ty = t.as<FuncTypeNode>();
CHECK(ty) << "input should be a function";
return FuncType(ty->arg_types, TupleType({ty->ret_type, TupleType(ty->arg_types)}), {}, {});
Expand All @@ -85,7 +85,7 @@ Expr DeGlobal(const Optional<IRModule>& mod, const Expr& e) {
if (mod.defined() && x) {
BaseFunc base_func = mod.value()->Lookup(GetRef<GlobalVar>(x));
if (auto* n = base_func.as<FunctionNode>()) {
return n->body;
return GetRef<Function>(n);
} else {
return e;
}
Expand Down Expand Up @@ -338,11 +338,22 @@ Expr FirstOrderGradient(const Expr& re, const Optional<IRModule>& mod) {

TVM_REGISTER_GLOBAL("relay._transform.first_order_gradient").set_body_typed(FirstOrderGradient);

Type bpt = RelayRefType(FuncType({}, TupleType(Array<Type>()), {}, {}));

struct ReverseADType : TypeMutator {
Type VisitType_(const TensorTypeNode* ttn) final {
Type t = GetRef<Type>(ttn);
return TupleType({t, RelayRefType(t)});
}

Type VisitType_(const FuncTypeNode* ftn) final {
std::vector<Type> arg_types;
for (const auto& t : ftn->arg_types) {
arg_types.push_back(VisitType(t));
}
arg_types.push_back(bpt);
return FuncType(arg_types, ftn->ret_type, ftn->type_params, ftn->type_constraints);
}
};

Type ReverseType(const Type& t) { return ReverseADType()(t); }
Expand Down Expand Up @@ -438,12 +449,18 @@ Expr BPEmpty() {

struct ReverseAD : ExprMutator {
using ADVarMap = std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual>;

using ADGlobalVarMap = std::unordered_map<GlobalVar, GlobalVar, ObjectPtrHash, ObjectPtrEqual>;
Optional<IRModule> mod;
// TODO(@M.K.) refactor AD to always use mod.
Var bp;
std::shared_ptr<ADVarMap> ad_vars;
std::shared_ptr<ADGlobalVarMap> ad_gvars;
const OpAttrMap<FPrimalGradient> rev_map = Op::GetAttrMap<FPrimalGradient>("FPrimalGradient");

explicit ReverseAD(const Var& bp, std::shared_ptr<ADVarMap> ad_vars) : bp(bp), ad_vars(ad_vars) {}
explicit ReverseAD(const Optional<IRModule>& mod, const Var& bp,
const std::shared_ptr<ADVarMap>& ad_vars,
const std::shared_ptr<ADGlobalVarMap>& ad_gvars)
: mod(mod), bp(bp), ad_vars(ad_vars), ad_gvars(ad_gvars) {}

Expr VisitExpr_(const OpNode* op) final {
LOG(FATAL) << "op should only be inside call";
Expand Down Expand Up @@ -481,9 +498,8 @@ struct ReverseAD : ExprMutator {
Expr nbp = Function({}, LetList::With([&](LetList* ll) {
// we need a new ReverseAD visitor to avoid clobbering the bp local var
auto dup_bp = ll->Push(BPEmpty());
ReverseAD dup_diff(dup_bp, ad_vars);
auto dup_ad = ll->Push(dup_diff.VisitExpr(DeDup(x)));

auto dup_ad =
ll->Push(ReverseAD(mod, dup_bp, ad_vars, ad_gvars)(DeDup(x)));
TransferGrads(call->checked_type(), ret, dup_ad, ll);
ll->Push(Call(RefRead(dup_bp), {}));
return Call(bpv, {});
Expand Down Expand Up @@ -518,22 +534,29 @@ struct ReverseAD : ExprMutator {
orig_var->checked_type_ = call->checked_type();
auto ret = ll->Push(GetRev(call->checked_type(), orig_var, ll));
auto bpv = ll->Push(RefRead(bp));
Expr nbp = Function({}, LetList::With([&](LetList* ll) {
tvm::Array<Expr> rev =
rev_map[op_ref](orig, GetGrad(call->checked_type(), ret, ll));
CHECK(args.size() == rev.size());
for (size_t i = 0; i < args.size(); ++i) {
UpdateGrad(call->args[i]->checked_type(), args[i], rev[i], ll);
}
return Call(bpv, {});
}),
TupleType::Empty(), {});
Expr nbp_body = LetList::With([&](LetList* ll) {
tvm::Array<Expr> rev = rev_map[op_ref](orig, GetGrad(call->checked_type(), ret, ll));
CHECK(args.size() == rev.size());
for (size_t i = 0; i < args.size(); ++i) {
UpdateGrad(call->args[i]->checked_type(), args[i], rev[i], ll);
}
return Call(bpv, {});
});
Expr nbp = Function({}, nbp_body, TupleType::Empty(), {});
ll->Push(RefWrite(bp, transform::ToANormalForm(nbp)));
// TODO(@M.K.): ToANF should be called on rev. Enhance ToANF for that.
return ret;
});
} else if (call->op.as<ConstructorNode>()) {
return ExprMutator::VisitExpr_(call);
} else {
std::vector<Expr> args;
for (const auto& arg : call->args) {
args.push_back(VisitExpr(arg));
}
args.push_back(bp);
return Call(VisitExpr(call->op), args);
}
return ExprMutator::VisitExpr_(call);
}

Expr VisitExpr_(const ConstantNode* op) final {
Expand All @@ -559,6 +582,39 @@ struct ReverseAD : ExprMutator {
return ad_vars->at(var_ref);
}

Expr VisitExpr_(const GlobalVarNode* op) final {
// todo: concatenating string to add attribute seems like a brittle hack.
// maybe get module indexed by a rose tree of string?
CHECK(mod.defined());
auto orig_gv = GetRef<GlobalVar>(op);
if (ad_gvars->count(orig_gv) == 0) {
GlobalVar gv(op->name_hint + "_grad");
(*ad_gvars)[orig_gv] = gv;
Function orig_f = Downcast<Function>(DeDup(mod.value()->Lookup(orig_gv)));
std::vector<Var> params;
for (const auto& p : orig_f->params) {
params.push_back(Downcast<Var>(VisitExpr(p)));
}
params.push_back(bp);
Expr body = VisitExpr(orig_f->body);
Function f(params, body, VisitType(orig_f->ret_type), orig_f->type_params, orig_f->attrs);
std::cout << "gv " << op->name_hint << ": " << AsText(f, false) << std::endl;
mod.value()->Add(gv, f);
}
return ad_gvars->at(orig_gv);
}

Expr VisitExpr_(const FunctionNode* op) final {
std::vector<Var> params;
for (const auto& var : op->params) {
params.push_back(Downcast<Var>(VisitExpr(var)));
}
auto new_bp = Var("bp", bpt);
params.push_back(new_bp);
return Function(params, ReverseAD(mod, new_bp, ad_vars, ad_gvars)(op->body),
VisitType(op->ret_type), op->type_params, op->attrs);
}

Type VisitType(const Type& t) final { return t.defined() ? ReverseType(t) : t; }
};

Expand Down Expand Up @@ -604,12 +660,16 @@ Expr Gradient(const Expr& re, const Optional<IRModule>& mod) {
}
CHECK(!MissingGrad(e)) << "input has operators with missing gradients";
Expr body = LetList::With([&](LetList* ll) {
Var bp = ll->Push(BPEmpty());
Expr rev = ReverseAD(bp, std::make_shared<ReverseAD::ADVarMap>())(e);
std::vector<Expr> args;
Var bp = ll->Push(BPEmpty(), bpt);
Expr rev = ReverseAD(mod, bp, std::make_shared<ReverseAD::ADVarMap>(),
std::make_shared<ReverseAD::ADGlobalVarMap>())(e);
std::vector<Expr> normal_args, args;
for (const auto& p : f->params) {
args.push_back(ll->Push(Pair(p, RefCreate(ZerosLike(p)))));
auto x = ll->Push(Pair(p, RefCreate(ZerosLike(p))));
normal_args.push_back(x);
args.push_back(x);
}
args.push_back(bp);
auto c = ll->Push(Call(rev, args));
std::function<void(const Expr&, const Type&)> init_grad;
init_grad = [&](const Expr& e, const Type& t) {
Expand All @@ -626,7 +686,7 @@ Expr Gradient(const Expr& re, const Optional<IRModule>& mod) {
init_grad(c, f->body->checked_type());
ll->Push(Call(RefRead(bp), {}));
std::vector<Expr> ret;
for (const auto& a : args) {
for (const auto& a : normal_args) {
ret.push_back(RefRead(GetField(a, 1)));
}
std::function<Expr(const Expr&, const Type&)> get_final_result;
Expand Down
41 changes: 40 additions & 1 deletion tests/python/relay/test_pass_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import tvm
from tvm import te
from tvm import relay
from tvm.relay import GlobalVar
from tvm.relay.analysis import free_vars, free_type_vars
from tvm.relay import create_executor, transform
from tvm.relay.transform import gradient
Expand All @@ -29,7 +30,7 @@
import tvm.relay.op as op


def test_id():
def test_fo_id():
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
Expand All @@ -44,6 +45,21 @@ def test_id():
tvm.testing.assert_allclose(forward.asnumpy(), x.asnumpy())
tvm.testing.assert_allclose(grad.asnumpy(), np.ones_like(x.asnumpy()))

def test_id():
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
x = relay.var("x", t)
func = relay.Function([x], x)
func = run_infer_type(func)
back_func = run_infer_type(gradient(func))
assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
ex = create_executor()
x = rand(dtype, *shape)
forward, (grad,) = ex.evaluate(back_func)(x)
tvm.testing.assert_allclose(forward.asnumpy(), x.asnumpy())
tvm.testing.assert_allclose(grad.asnumpy(), np.ones_like(x.asnumpy()))


def test_relu():
shape = (10, 10)
Expand Down Expand Up @@ -341,5 +357,28 @@ def test_no_duplication():
counts = count_ops(gr)
assert counts['nn.dense'] == 3, "We expect 3 dense (1 forward, two backward)"


def test_global_function():
m = tvm.IRModule()
shape = (10, 10)
dtype = 'float32'
t = relay.TensorType(shape, dtype)
x = relay.Var('x', t)
d = GlobalVar('double')
m[d] = relay.Function([x], x + x)
y = relay.Var('y', t)
q = GlobalVar('q')
m[q] = relay.Function([y], d(d(y)))
g = GlobalVar('grad')
m[g] = tvm.relay.transform.gradient(q, m)
back_func = m[g]
assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, relay.TupleType([t])]))
ex = create_executor(mod=m)
x = rand(dtype, *shape)
forward, (grad,) = ex.evaluate(back_func)(x)
tvm.testing.assert_allclose(forward.asnumpy(), 4 * x.asnumpy())
tvm.testing.assert_allclose(grad.asnumpy(), 4 * np.ones_like(x.asnumpy()))


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit 984666c

Please sign in to comment.