[Relay] Alter Op Layout #2150
Conversation
Force-pushed from 16d6294 to 18d138a
src/relay/pass/simplify_bias_add.cc
Outdated
namespace tvm {
namespace relay {

class BiasAddSimplifier : public ExprMutator {
can we fold this into simplify inference?
No, because it is not only used for inference.
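For context, here is a minimal sketch (using the Relay Python API; names and shapes are illustrative, not taken from the PR) of what the bias_add lowering produces, and why it is usable outside inference-only graphs:

```python
import tvm
from tvm import relay

x = relay.var("x", shape=(1, 64, 56, 56))  # NCHW data
b = relay.var("b", shape=(64,))            # per-channel bias

# Before canonicalization: the composite op.
before = relay.nn.bias_add(x, b, axis=1)

# After: the bias is expanded to (64, 1, 1) so it broadcasts against the
# data along the channel axis, then combined with a plain broadcast add.
after = relay.add(x, relay.expand_dims(b, axis=1, num_newaxis=2))
```

Because the result is just expand_dims + add, later passes only need to understand the primitive ops, regardless of whether the graph is for training or inference.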
Force-pushed from 49d22b3 to 1747a21
Will review tonight.
include/tvm/relay/op_attr_types.h
Outdated
 * operator with other expressions.
 *
 * \param attrs The attribute of the node.
 * \param inputs The arguments of this operator.
args & tinfos
// if the new layout changes width or height dimension,
// fallback to old layout;
input = raw_layout;
}
Since you change the param instead, this check & fallback should be of no use.
But I guess we need to modify the pool compute so that it knows how to deal with w & h.
Sorry, there was a small mistake here. I fixed it. In the fallback case, we don't change the param.
src/relay/op/tensor/transform.cc
Outdated
// relay.layout_transform
std::pair<Layout, Layout> RemoveLeadingReduandantDimensions(
    const Layout &src_layout, const Layout &dst_layout, size_t keep_size) {
  // For a broadcast operator, if the input is a 3-dimensional tensor (64, 1, 1),
Suggest making it clearer, e.g., when broadcasting (1, 64, 16, 16) with (64, 1, 1), we can still apply NCHW -> NCHW16c to the right tensor by deleting N and applying the normal CHW -> CHW16c layout transform. Something like that.
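A hedged sketch of that scenario using the layout_transform op added in this PR (shapes are made up for illustration):

```python
from tvm import relay

x = relay.var("x", shape=(1, 64, 16, 16))  # layout NCHW
b = relay.var("b", shape=(64, 1, 1))       # bias-like tensor, layout CHW

# The 4-d tensor follows the NCHW -> NCHW16c rule directly.
x_packed = relay.layout_transform(x, "NCHW", "NCHW16c")

# The 3-d tensor cannot take the 4-d rule as-is, but after deleting the
# leading N, the remaining CHW -> CHW16c rule applies, and the broadcast
# still works in the packed layout.
b_packed = relay.layout_transform(b, "CHW", "CHW16c")

out = relay.add(x_packed, b_packed)  # (1, 4, 16, 16, 16) + (4, 1, 1, 16)
```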
src/relay/op/tensor/transform.cc
Outdated
  // we can still apply 4-dimensional rule NCHW -> NCHW16c to it.
  // In this case, we will delete leading redundant dimensions.
  CHECK_GE(src_layout.ndim(), keep_size)
    << "Can only apply layout transform rule for smaller tensor dimensions";
This error msg is not clear. Suggest printing out src_layout and keep_size.
src/relay/pass/alter_op_layout.cc
Outdated

// old_in, old_out = op.infer(old_in)
bool success = false;
std::tie(old_in, old_out, success) = CallInfer(ref_call, old_in);
Trying to learn: how does batch_flatten's FInferCorrectLayout get registered?
It is not defined. When we meet ops without FInferCorrectLayout, such as batch_flatten, we transform their inputs back to the old layout.
https://github.com/merrymercy/tvm/blob/ad256f9d2fddf59c6c4224ab78be18437a13e6b2/src/relay/pass/alter_op_layout.cc#L99
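Conceptually, the fallback looks like this (a sketch with made-up shapes, not code from the pass): the packed intermediate is converted back to the old layout right before the op that has no FInferCorrectLayout.

```python
from tvm import relay

# Stand-in for the output of an altered op, already in the packed layout.
y_packed = relay.var("y", shape=(1, 4, 7, 7, 16))  # NCHW16c

# batch_flatten registers no FInferCorrectLayout, so the pass falls back:
# convert its input back to the old NCHW layout before calling the op.
y_nchw = relay.layout_transform(y_packed, "NCHW16c", "NCHW")
out = relay.nn.batch_flatten(y_nchw)
```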
src/relay/pass/alter_op_layout.h
Outdated
  lhs = rhs;
}

return Array<Array<Layout> > {{lhs, lhs}, {lhs}};
what if len(lhs) < len(rhs) ?
I agree that more checks and a good fallback should be considered here.
Will let you know when it's ready for review.
src/relay/op/tensor/transform.cc
Outdated
  << "cannot convert from " << param->src_layout << " to " << param->dst_layout;

std::tie(src_layout, dst_layout) = RemoveLeadingReduandantDimensions(
    src_layout, dst_layout, inputs[0]->shape.size());
RemoveLeadingReduandantDimensions has various checks, but the program should still compile even when these restrictions are not satisfied. E.g., broadcast(N16nCHW, CHW) -> broadcast(NCHW, CHW) will fail here, but we should actually be able to make it compile.
Also, I'm a little worried that the approach allows layout and shape to mismatch, with the mismatch handled in the layout_transform operator. In some sense, other operators then need to know how layout_transform works in order to decide whether to infer a mismatched-shape layout. From a system-design point of view, that's not great.
Also, if later on there's another pass that deals with layout, it needs to carefully handle the mismatch and understand where the mismatch comes from.
Good points. I will clean up this mismatch inside this pass.
python/tvm/relay/op/transform.py
Outdated


def layout_transform(data, src_layout, dst_layout):
    """Transform the layout of an tensor
Should be Transform the layout of a tensor.
@merrymercy can you rebase the current PR? Given the interest in this feature, it would be great to see if we can recover some of the end-to-end example benchmarks. #2163 should enable most of the common networks already.
@tqchen Some operators (winograd, NCHWc) are still missing. Will do them in follow-up PRs.
Force-pushed from ede6608 to 13338b9
@yzhliu comments are addressed, please review again.
# concatenate
@_reg.register_compute("concatenate")
def concatenate_compute(attrs, inputs, output_type, target):
Do we also need a correct layout function for concatenate?
No. For concatenate, all its inputs will be transformed back to the old layouts. This is the default fallback case.
is it possible for concat to propagate layout?
Yes. We should add it.
src/relay/op/tensor/transform.cc
Outdated
    const Array<Layout> &in_layouts,
    const Array<Array<IndexExpr>> &in_shapes) {
  const ConcatenateAttrs* param = attrs.as<ConcatenateAttrs>();
  CHECK_EQ(in_layouts.size(), 2);
Should allow more than two inputs?
This is wrong. Not ready for review.
Force-pushed from 81934f9 to eb299ac
@kevinthesun @tqchen please review again.
Force-pushed from 15abe1b to 51a20ba
python/tvm/relay/ir_pass.py
Outdated
@@ -191,6 +191,23 @@ def simplify_inference(expr):
    return _ir_pass.simplify_inference(expr)


def simplify_bias_add(expr):
Let us rename this to canonicalize_ops, since bias_add lowering can be viewed as a canonicalization step for ops; we might add more op canonicalizations later (e.g. mean).
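Assuming the rename lands and the pass stays exposed next to simplify_inference in relay.ir_pass, usage would presumably look like this (a hedged sketch, not the final API):

```python
from tvm import relay

x = relay.var("x", shape=(1, 64, 56, 56))
b = relay.var("b", shape=(64,))
func = relay.Function([x, b], relay.nn.bias_add(x, b))

# Lowers bias_add to expand_dims + broadcast add; future op
# canonicalizations (e.g. mean) would live in the same pass.
func = relay.ir_pass.canonicalize_ops(func)
```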
TransformMemorizer() {}
explicit TransformMemorizer(NodePtr<Node> n) : NodeRef(n) {}

TransformMemorizerNode* operator->() {
You can just mark TransformMemorizerNode* operator->() as const, to allow it to be accessed through a const reference.
But I need to modify the content?
src/relay/pass/alter_op_layout.cc
Outdated

// Make a transform CallNode
Expr TransformLayout(Expr raw, Layout src_layout, Layout dst_layout) {
  if (src_layout.Equals(dst_layout))
Either keep the return on a single line, or enclose it with { }.
}

// Memorize layout transform so we can reuse internal transformed nodes
class TransformMemorizerNode : public Node {
Expand the comment a bit: what is the key of the map, and what is the content?
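For readers, the intent of the memorizer can be sketched in Python (a conceptual stand-in for the C++ class, not its implementation): the key of the map is the (source expr, src layout, dst layout) triple and the content is the already-built layout_transform call, so transforming the same tensor to the same layout twice reuses one node.

```python
from tvm import relay

memo = {}  # (id of original expr, src layout, dst layout) -> transformed expr

def transform(raw, src_layout, dst_layout):
    """Return raw converted from src_layout to dst_layout, reusing nodes."""
    if src_layout == dst_layout:
        return raw
    # Key by object identity, mirroring the const Node* component in C++.
    key = (id(raw), src_layout, dst_layout)
    if key not in memo:
        memo[key] = relay.layout_transform(raw, src_layout, dst_layout)
    return memo[key]
```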
src/relay/pass/alter_op_layout.cc
Outdated
using TransformKey = std::tuple<const Node*, std::string, std::string>;
struct key_hash : public std::unary_function<TransformKey, std::size_t> {
  std::size_t operator()(const TransformKey& k) const {
    return std::hash<const Node*>()(std::get<0>(k)) ^
use dmlc::CombineHash
    return static_cast<TransformMemorizerNode*>(node_.get());
  }

  // Transform layout with memorizer
document all the arguments and return value
I think the rule is that functions in header files should be documented, while functions in src files are not required to be documented thoroughly. Is that true?
That is true, but it is still helpful to document some of them to make the code more readable.
src/relay/pass/alter_op_layout.cc
Outdated
// Transform layout with memorizer
Expr Transform(Expr raw, const Layout& src_layout, const Layout& dst_layout) {
  if (src_layout.Equals(dst_layout))
    return raw;
Keep return raw on the same line as the if.

// Call FInferCorrectLayout of an op.
// Return inferred_input_layout, inferred_output_layout, success
std::tuple<Array<Layout>, Array<Layout>, bool> CallInfer(
document the fields of the return value and the arguments
src/relay/pass/alter_op_layout.cc
Outdated
    const NodeRef& ctx) {
  std::vector<LayoutAlternatedExpr> inputs;
  std::vector<Expr> normal_new_args;
  Array<Array<IndexExpr>> input_shapes;
Add a space between the closing > > for compatibility with old compilers.
// NOTE: discard the "const" qualifier
TransformMemorizer memorizer = Downcast<TransformMemorizer>(ctx);

// fill incomplete state and expand tuple
Explain that we always expect LayoutAlternatedExpr, and this is used to convert the arguments into LayoutAlternatedExpr.
@yzhliu @jroesch @kevinthesun @ajtulloch can you please take a look at this PR if you have time?
Force-pushed from 45f0754 to 8f9e396
Comments are addressed!
Looks good to me. Ping @jroesch.
@merrymercy please rebase against the latest master.
Force-pushed from f2af1ef to e9cf2bb
Not sure if I am the only one seeing this issue in a local build.
Ignore the above issue. Looks like the submodule update didn't sync all changes.
* [RELAY] Finish alter op pass
* [RELAY] AlterOpLayout Pass
* fix broadcast operators
* fix broadcast operators
* fix broadcast operators
* Support concatenate
* address comments
* address comments
* add comments
* rebase
Port alter_op_layout pass to relay. The new pass does forward rewrite in a single pass.
The major improvements over the old pass:
This PR also did some cleanup of operators.
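A hedged usage sketch (assuming the pass is exposed as relay.ir_pass.alter_op_layout, alongside the other passes touched here): operators opt in through their FTVMAlterOpLayout / FInferCorrectLayout attributes, and the pass rewrites the function in one forward traversal, inserting layout_transform ops at layout boundaries.

```python
from tvm import relay

x = relay.var("x", shape=(1, 64, 56, 56))
w = relay.var("w", shape=(64, 64, 3, 3))
y = relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1))
func = relay.Function([x, w], relay.nn.batch_flatten(y))

# Infer types first, then alter layouts; ops whose layouts were altered
# get layout_transform nodes inserted around them automatically.
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.alter_op_layout(func)
```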
cc @yzhliu @kevinthesun @tqchen