-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RELAY] Basic block normal form #6152
Conversation
Thanks @eric-haibin-lin , it would also be useful to have a verification analysis util that verifies the IR is in the basic block form |
@eric-haibin-lin : Thanks for the PR! I believe this feature will be really helpful for certain detection & mutation process. |
cc @MarisaKirisame @mbrookhart @electriclilies @jroesch @junrushao1994 would be great if you can help to take a look at the PR |
src/relay/transforms/pass_util.h
Outdated
* \param lifted_exprs the output set of expressions whose scope is lifted due to dependency | ||
*/ | ||
void CalcScope(const DependencyGraph& dg, | ||
std::unordered_map<DependencyGraph::Node*, Scope>* expr_scope, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why use pointer instead of returning?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@MarisaKirisame thanks for the review! For BBlock I want to get a set of expressions whose scope should be lifted and rewritten with let bindings. I found CalcScope
already performs finding the LCA for all nodes in the dependency graph, so i just added a few lines of code so that the set of corresponding expressions are also returned by this function. However, setting the return type to std::pair<std::unordered_map<node, scope> std::unordered_set<expr>>
causes cpplint to fails due to the long length. So I instead used pointers for both outputs. Please let me know the preferred way for returning both outputs. Thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should disable cppint and use tuple, or just do a few typedef to shorten stuff.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ok I'll update the code
} | ||
""" | ||
print(body) | ||
assert not check_basic_block_normal_form(body) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
is_basic_block_normal_form test if something is or is not in bbnf.
check_basic_block_normal_form should assert it, and perhaps give more useful error message
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍 Agreed.
/* Fill expressions based on each scope's let list. Different from FillANF, | ||
* only expressions with lifted scope will be pushed to the let list. | ||
*/ | ||
class FillBasicBlock : ExprFunctor<Expr(const Expr&, const Var&)> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you reuse the Fill in to_anf? the code is almost identically the same, and duplicating it will make maintenance significantly harder.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mostly this looks great, thank you!
I few nitpicks, a bunch of print statements that need to be removed, a question on naming, and one OOP design question.
Feel free to argue against my suggestions, nothing but removing the prints is really required.
src/relay/transforms/pass_util.h
Outdated
// the parent scope | ||
Scope parent; | ||
// the corresponding let list which holds all let bindings in the scope | ||
std::shared_ptr<LetList> ll = std::make_shared<LetList>(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we use a more expressive variable name, maybe let_list
?
namespace tvm { | ||
namespace relay { | ||
|
||
Expr ToBasicBlockNormalFormAux(const Expr& e) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why Basic Block normal form? Is there a more advanced version coming? Or should we just call it BlockNormalForm?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess it's because of https://en.wikipedia.org/wiki/Basic_block ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
:) I missed this naming influence in the conversations @tqchen and I had a few months ago. Thanks! That makes things a lot clearer for me.
body = relay.If(v2, true_branch, false_branch) | ||
body = relay.Let(v1, relay.add(x, one), body) | ||
func = relay.Function([x], body) | ||
print(func) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove print
} | ||
""" | ||
print(body) | ||
assert not check_basic_block_normal_form(body) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍 Agreed.
y = op.add(x, x) | ||
z = op.add(y, y) | ||
f = relay.Function([], op.add(z, z)) | ||
print(f) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove print
print(top) | ||
print('bblock=') | ||
print(bblock) | ||
assert check_basic_block_normal_form(bblock) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd love an expected graph on this test, but I don't think it's critical
add(%shared, %shared) | ||
} | ||
""" | ||
print(body) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove print
add(%shared, %shared) | ||
} | ||
""" | ||
print(body) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove print
static Expr ToANormalForm(const Expr& e, const DependencyGraph& dg, NodeScopeMap* node_scope); | ||
|
||
// For basic block normal form, bind expressions only if the original expression's | ||
// scope should be lifted | ||
static Expr ToBasicBlockNormalForm(const Expr& e, const DependencyGraph& dg, | ||
NodeScopeMap* node_scope, ExprSet* lifted); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I really like that you're reusing the infrastructure from A-Normal form to do this, there's a lot of overlap. I don't love that this class violates the Single Responsiblity Principle and that you have to introduce branches in a couple of the recursive functions to handle the two different cases. Given that this class has a private constructor, I'm not sure it 100% matters, but there might be a cleaner solution?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks for bringing this up. I was also thinking about how to best leverage the existing code without too much ad-hoc code or duplication. I think adding an inclusion/exclusion argument to Fill
is still acceptable, and the Fill constructor is private anyway
args.push_back(VisitExpr(a)); | ||
} | ||
return Compound(e, Call(VisitExpr(c->op), args, c->attrs, c->type_args), v); | ||
Expr Fill::VisitExpr_(const CallNode* c, const Var& v) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why remove final
from all of these functions?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
final
is still used in the header at declaration https://en.cppreference.com/w/cpp/language/final#Syntax
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks @mbrookhart for the review. I have updated some code as suggested
namespace tvm { | ||
namespace relay { | ||
|
||
Expr ToBasicBlockNormalFormAux(const Expr& e) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess it's because of https://en.wikipedia.org/wiki/Basic_block ?
args.push_back(VisitExpr(a)); | ||
} | ||
return Compound(e, Call(VisitExpr(c->op), args, c->attrs, c->type_args), v); | ||
Expr Fill::VisitExpr_(const CallNode* c, const Var& v) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
final
is still used in the header at declaration https://en.cppreference.com/w/cpp/language/final#Syntax
static Expr ToANormalForm(const Expr& e, const DependencyGraph& dg, NodeScopeMap* node_scope); | ||
|
||
// For basic block normal form, bind expressions only if the original expression's | ||
// scope should be lifted | ||
static Expr ToBasicBlockNormalForm(const Expr& e, const DependencyGraph& dg, | ||
NodeScopeMap* node_scope, ExprSet* lifted); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks for bringing this up. I was also thinking about how to best leverage the existing code without too much ad-hoc code or duplication. I think adding an inclusion/exclusion argument to Fill
is still acceptable, and the Fill constructor is private anyway
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
I dont think it break single responsibility - the code is doing conversion and a single configuration param denote which conversion it is. Another way to think about it is that that visitor is declaring a 'scoped mutator' and toanf/tobbnf is two subclass of it. |
That is exactly breaking the SRP, it's a class that does two mutually-exclusive conversions :) The SRP-compliant way to do this would be exactly what you said, a base class and two subclasses. That being said, since we're not allowing anyone to inherit from this class, and the constructor is private, it's probably not the end of the world, I'm cool with leaving it the way it is. |
We can allow it if needed! |
@eric-haibin-lin please at me when it is done. I will merge it then. |
* initial commit * refactor utils * add util * revert anf test * update test * fix logging * fix scope bug * complete tests * remove logging * revert refactoring * add one more test case * fix missing var binding * fix test * fix lint * fix lint * fix clang-format * fix lint * fix lint * commit missing code * add analysis api * fix lint * fix lint * lint * add test for func * address CR * fix typo * fix return type * fix lint * refactor classes * fix lint * remove prints * address comments Co-authored-by: Ubuntu <ubuntu@ip-172-31-42-138.ec2.internal>
* initial commit * refactor utils * add util * revert anf test * update test * fix logging * fix scope bug * complete tests * remove logging * revert refactoring * add one more test case * fix missing var binding * fix test * fix lint * fix lint * fix clang-format * fix lint * fix lint * commit missing code * add analysis api * fix lint * fix lint * lint * add test for func * address CR * fix typo * fix return type * fix lint * refactor classes * fix lint * remove prints * address comments Co-authored-by: Ubuntu <ubuntu@ip-172-31-42-138.ec2.internal>
* initial commit * refactor utils * add util * revert anf test * update test * fix logging * fix scope bug * complete tests * remove logging * revert refactoring * add one more test case * fix missing var binding * fix test * fix lint * fix lint * fix clang-format * fix lint * fix lint * commit missing code * add analysis api * fix lint * fix lint * lint * add test for func * address CR * fix typo * fix return type * fix lint * refactor classes * fix lint * remove prints * address comments Co-authored-by: Ubuntu <ubuntu@ip-172-31-42-138.ec2.internal>
* initial commit * refactor utils * add util * revert anf test * update test * fix logging * fix scope bug * complete tests * remove logging * revert refactoring * add one more test case * fix missing var binding * fix test * fix lint * fix lint * fix clang-format * fix lint * fix lint * commit missing code * add analysis api * fix lint * fix lint * lint * add test for func * address CR * fix typo * fix return type * fix lint * refactor classes * fix lint * remove prints * address comments Co-authored-by: Ubuntu <ubuntu@ip-172-31-42-138.ec2.internal>
* initial commit * refactor utils * add util * revert anf test * update test * fix logging * fix scope bug * complete tests * remove logging * revert refactoring * add one more test case * fix missing var binding * fix test * fix lint * fix lint * fix clang-format * fix lint * fix lint * commit missing code * add analysis api * fix lint * fix lint * lint * add test for func * address CR * fix typo * fix return type * fix lint * refactor classes * fix lint * remove prints * address comments Co-authored-by: Ubuntu <ubuntu@ip-172-31-42-138.ec2.internal>
Thanks for contributing to TVM! Please refer to guideline https://tvm.apache.org/docs/contribute/ for useful information and tips. After the pull request is submitted, please request code reviews from Reviewers by @ them in the pull request thread.
As mentioned in https://discuss.tvm.ai/t/basic-block-normal-form/5908/5, the requirements for basic block normal form are:
We rely on
analysis::DependencyGraph
to compute the dependency among nodes, and use LCA to compute the scope of each node. As we traverse the dependency graph, we record the expression with a "lifted scope" if its original scope is different from the result of LCA of nodes depending on it. For instance, a variable shared by the true_branch and false_branch in anIf
statement. If any expression with lifted scope is found, the function violates the basic block normal form.The expressions with lifted scope need to be bound by let binding. It is done with a pass almost the same as the ANF pass. Same as ANF we use a let list per scope to help construct let bindings of expressions. The main difference with ANF pass is that we only push expressions to the let list if the expression's scope is lifted, otherwise we return the original expression. When we come across If/function node, we get the expression generated from the let list of the corresponding scope.
@MarisaKirisame would you suggest we merge the two passes into a single class?