-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Bugfix] [Relay] fix a bug of printing dataflow pattern #15350
Conversation
Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment. Generated by tvm-bot |
btw, does the dataflow pattern support recursion (loop) at all? |
include/tvm/relay/dataflow_pattern.h
Outdated
|
||
std::unordered_map<DFPattern, std::pair<size_t, std::string>, ObjectPtrHash, ObjectPtrEqual> | ||
memo_{}; | ||
std::vector<DFPattern> recursed_patterns{}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please document what you mean by "recursed_patterns". Probably "recursive_patterns" would be more correct.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Documentation added. To be consistent with the printed text, I choose to name this variable to be auxiliary_patterns
include/tvm/relay/dataflow_pattern.h
Outdated
string_stream << "Main pattern is:" << std::endl; | ||
string_stream << printer.string_stream.str(); | ||
string_stream << std::endl; | ||
string_stream << "Auxiliary patterns are:"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove "is" and "are" above
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
Surprisingly, the answer is yes. Indeed, I was thinking to add the support for recursion, only to find that the existing dataflow pattern language has alreadly (silently) support it. TVM_REGISTER_GLOBAL("relay.dataflow_pattern.my_pattern")
.set_body_typed([]() {
DFPattern dense_pattern = IsOp("nn.dense")({IsWildcard(), IsWildcard()});
ObjectPtr<CallPatternNode> the_pattern_ptr = make_object<CallPatternNode>();
the_pattern_ptr->op = IsOp("cast");
the_pattern_ptr->args.clear();
CallPattern the_pattern = CallPattern(the_pattern_ptr);
AltPattern or_pattern{the_pattern, dense_pattern};
the_pattern_ptr->args.push_back(or_pattern);
//LOG(INFO) << PrettyPrint(the_pattern); // TODO: BUG!
return the_pattern;
}); This simple pattern matches a nn.dense followed by an arbitrary number of cast. You can test this pattern via the following python code: class TheRewrite(DFPatternCallback):
def __init__(self):
super(TheRewrite, self).__init__(rewrite_once = True)
pattern = tvm.get_global_func("relay.dataflow_pattern.my_pattern")()
self.pattern = pattern
def callback(self, pre, post, node_map):
return relay.nn.relu(post)
mod = create_model() # define a model
the_rewrite = TheRewrite()
out = rewrite(the_rewrite, mod["main"]) Another application of recursion is PR #15362, which I do not know how to achieve without recursion. That PR is useful, and can really improve the computational graph for some quantized models. I would like to examine the pattern matching code further in the following days. |
@tvm-bot rerun |
When recursion of dataflow patterns is used, the pattern graph may not be a DAG. If recursion is encountered, ReprPrint of dataflow pattern may fall in a dead loop.
This PR solves the bug of dataflow pattern printing, and is the first PR for the pre-RFC.