-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Relay][Training] Make AutoDiff thread through global function. #6336
Conversation
@@ -85,7 +85,7 @@ Expr DeGlobal(const Optional<IRModule>& mod, const Expr& e) { | |||
if (mod.defined() && x) { | |||
BaseFunc base_func = mod.value()->Lookup(GetRef<GlobalVar>(x)); | |||
if (auto* n = base_func.as<FunctionNode>()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why do we need if else here? doesn't every GlobalVar map to a Function?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It could point to a PrimFunction
src/relay/transforms/gradient.cc
Outdated
if (ad_gvars->count(orig_gv) == 0) { | ||
GlobalVar gv(op->name_hint + "_grad"); | ||
(*ad_gvars)[orig_gv] = gv; | ||
Function orig_f = Downcast<Function>(mod.value()->Lookup(GetRef<GlobalVar>(op))); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
mod.value()->Lookup(orig_gv)
src/relay/transforms/gradient.cc
Outdated
if (ad_gvars->count(orig_gv) == 0) { | ||
GlobalVar gv(op->name_hint + "_grad"); | ||
(*ad_gvars)[orig_gv] = gv; | ||
Function orig_f = Downcast<Function>(mod.value()->Lookup(GetRef<GlobalVar>(op))); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
would it be a good idea to DeDup here?
m[q] = relay.Function([y], d(d(y))) | ||
g = GlobalVar('grad') | ||
m[g] = tvm.relay.transform.gradient(q, m) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add type and value check
src/relay/transforms/gradient.cc
Outdated
@@ -438,12 +449,17 @@ Expr BPEmpty() { | |||
|
|||
struct ReverseAD : ExprMutator { | |||
using ADVarMap = std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual>; | |||
|
|||
using ADGVarMap = std::unordered_map<GlobalVar, GlobalVar, ObjectPtrHash, ObjectPtrEqual>; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Better name here
@@ -438,12 +449,17 @@ Expr BPEmpty() { | |||
|
|||
struct ReverseAD : ExprMutator { | |||
using ADVarMap = std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual>; | |||
|
|||
using ADGVarMap = std::unordered_map<GlobalVar, GlobalVar, ObjectPtrHash, ObjectPtrEqual>; | |||
Optional<IRModule> mod; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this really optional?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes - sometime it is call on naked expr. i could refactor to require it to always be here, if this is what you want
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you just add a note about this to your on-going doc on things we might want to refactor?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you mean the list from ~2 years ago? that is very out of dated. I think we can just grep -r for TODO(@M.K.) in the source.
} | ||
return Call(bpv, {}); | ||
}); | ||
Expr nbp = Function({}, nbp_body, TupleType::Empty(), {}); | ||
ll->Push(RefWrite(bp, transform::ToANormalForm(nbp))); | ||
// TODO(@M.K.): ToANF should be called on rev. Enhance ToANF for that. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this comment still valid? we should add to the TODO list if so
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, i just added it yesterday, what list
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
small nits
…he#6336) * save * lint * lint * fix warning * fix test * save
…he#6336) * save * lint * lint * fix warning * fix test * save
…he#6336) * save * lint * lint * fix warning * fix test * save
@junrushao1994 @jroesch @vinx13 @altanh @icemelon9 @hypercubestart @t-vi can you guys help review?