[RELAY][PASS] FoldScaleAxis Forward #2020
Conversation
@jroesch @ZihengJiang @yzhliu @zhiics @FrozenGene @merrymercy @srkreddy1238 @masahi please review
python/tvm/relay/op/transform.py (Outdated)

@@ -49,17 +49,17 @@ def transpose(data, axes=None):
     return _make.transpose(data, list(axes))


-def squeeze(data, axes=None):
+def squeeze(data, axis=None):
     """Squeeze axes in the array.

     Parameters
     ----------
     data : relay.Expr

Review comment: tvm.relay.Expr?
    if (!rhs.defined()) return rhs;
    AxesSet ret;
    size_t i = 0, j = 0;
    while (i < lhs.size() && j < rhs.size()) {

Review comment: Out of curiosity, are both lhs and rhs always sorted?

Reply: Yes, this is a requirement of the axis set. Thanks for pointing this out; I will add a comment block about it.
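As the reply notes, the two-pointer walk in the C++ snippet above is only a correct intersection when both axis lists are sorted ascending: whenever the heads differ, the smaller head can never appear later in the other list, so it can be safely discarded. A minimal Python sketch of the same logic (the function name is illustrative, not TVM's API):

```python
def intersect_axes(lhs, rhs):
    """Intersect two sorted axis lists with a two-pointer walk.

    Correctness depends on both inputs being sorted ascending:
    when the heads differ, the smaller one cannot occur later
    in the other list, so it is dropped.
    """
    ret = []
    i, j = 0, 0
    while i < len(lhs) and j < len(rhs):
        if lhs[i] == rhs[j]:
            ret.append(lhs[i])
            i += 1
            j += 1
        elif lhs[i] < rhs[j]:
            i += 1
        else:
            j += 1
    return ret
```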
src/relay/pass/fold_scale_axis.cc (Outdated)

    /*!
     * \brief The transform function, transform an old call to
     * new one given the new args.
     */

Review comment: a new one
src/relay/pass/fold_scale_axis.cc (Outdated)

    std::unordered_map<const Node*, AxesSet> message_;
    // Update the message stored at node.
    void Update(const Expr& node, const AxesSet& axes) {
      // We run interection of messages:

Review comment: s/interection/intersection
src/relay/pass/type_infer.cc (Outdated)

    Expr new_e = ExprMutator::VisitExpr_(op);
    if (!checked_type.same_as(new_e->checked_type_)) {
      // new_call and new_var's code is only going to be valid for VarNode/CallNode.
      // Compiler optimization will likely fold the these away for other nodes.

Review comment: fold these
src/relay/pass/fold_scale_axis.cc (Outdated)

    // Conv2D consumes the scale axis during transformation.
    STuple Conv2DForwardTransform(const Call& ref_call,
                                  const AxesSet& expected_axes,

Review comment: expected_axes is unused. Should it be attached to rnode?

Review comment: also indent
python/tvm/relay/op/transform.py (Outdated)

         Axes to remove.
-        If axes = [] or = None, remove all axis of dimensions 1.
+        If axes = None, remove all axis of dimensions 1.

Review comment: change the comment also

Reply: I do not get what you mean.

Review comment: axes should be changed to axis here.

Reply: gotcha
Thanks @zhiics @ZihengJiang for the helpful reviews, please check again.
src/relay/pass/pattern_util.h (Outdated)

      }
      ++j;
    } else {
      if (i >= base) {

Review comment: else if
src/relay/pass/fold_scale_axis.cc (Outdated)

    }

    void VisitExpr_(const TupleGetItemNode* op) {
      // pass, do nothing

Review comment: Why don't you visit inside? Maybe there is an optimization opportunity inside.

Reply: good catch
src/relay/pass/fold_scale_axis.cc (Outdated)

    void VisitExpr_(const IfNode* op) {
      ExprVisitor::VisitExpr_(op);
      // do pass through condition.

Review comment: remove the trailing period.
    // AddSub
    Array<AxesSet> AddSubForwardPrep(const Call& call, AxesSet out_axes) {
      const auto* tlhs = call->args[0]->type_as<TensorTypeNode>();
      const auto* trhs = call->args[1]->type_as<TensorTypeNode>();

Review comment: Do we need to add a check here?

Reply: type_as does the check.
Thanks @MarisaKirisame @ZihengJiang, I have made follow-up changes, please check again.
* [RELAY][PASS] FoldScaleAxis Forward
* Introduce helper function type_as
* Update per review comment
* Fix according to comments
This is a first serious attempt to implement an NN-related optimization pass on Relay.
Hopefully, this can serve as an example of how optimizations can be done on the NNVMv2 (Relay) IR, and of how Relay's infrastructure makes writing optimizations more principled.
Goal
Fold the scaling of an axis (usually caused by BatchNorm) into the weight of a downstream conv2d. For example:
Old:
Transformed:
Further constant folding then folds the multiplication into the weight, removing the scaling from the network.
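The rewrite is valid because scaling the data along the channel axis that conv2d reduces over is equivalent to scaling the weight along its input-channel axis. A minimal numpy sketch of that identity, with matmul standing in for conv2d (this is an illustration of the algebra, not TVM code):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 3))   # data: batch x in_channels
w = rng.standard_normal((3, 5))   # weight: in_channels x out_channels
s = rng.standard_normal(3)        # per-input-channel scale (e.g. from BatchNorm)

# Old: scale the data, then apply the weight.
old = (x * s) @ w
# Transformed: fold the scale into the weight instead.
new = x @ (s[:, None] * w)

assert np.allclose(old, new)
```

Once the scale sits next to the constant weight, constant folding can multiply them at compile time.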
The Algorithm
So far only the forward direction is implemented. The general idea is that we transform each Expr into a tuple (value, axes, scale), where the original result equals value multiplied by scale broadcast along the recorded axes. We can then propagate this signal forward and fold the scale where it is consumed. However, a scale may never be consumed if no dense/conv2d follows the multiplication.
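The invariant behind the (value, axes, scale) triple can be checked numerically: broadcasting the scale along the recorded axes and multiplying must recover the original tensor. The helper below is an illustrative sketch, not TVM's implementation:

```python
import numpy as np

def reconstruct(value, axes, scale):
    """Recover the original tensor from a (value, axes, scale) triple.

    `scale` holds one factor per position along each axis in `axes`;
    reshape it so it broadcasts against `value`, then multiply.
    """
    shape = [1] * value.ndim
    for k, ax in enumerate(axes):
        shape[ax] = scale.shape[k]
    return value * scale.reshape(shape)

x = np.arange(24.0).reshape(2, 3, 4)
s = np.array([2.0, 3.0, 4.0])  # scale along axis 1
# The expression (x scaled along axis 1) is represented as (x, [1], s):
assert np.allclose(reconstruct(x, [1], s), x * s.reshape(1, 3, 1))
```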
In order to make sure every scale we send out is eventually consumed, we run a backward "preparation phase", which propagates the demand for potential axis scaling back to the inputs.
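The preparation phase can be pictured as message passing over the dataflow graph: each consumer sends back the set of axes it could fold, and a node feeding several consumers keeps only the intersection of the demands, so no scale is emitted that some path cannot consume. A toy sketch of that update rule (names are illustrative, not TVM's API):

```python
def update(messages, node, axes):
    """Record the axes demanded of `node`, intersecting with any
    demand already received from another consumer."""
    if node in messages:
        messages[node] = sorted(set(messages[node]) & set(axes))
    else:
        messages[node] = sorted(axes)
    return messages[node]

msgs = {}
update(msgs, "x", [1])     # one consumer (conv2d) can fold axis 1
update(msgs, "x", [1, 2])  # another consumer could fold axes 1 and 2
assert msgs["x"] == [1]    # only the commonly consumable axis survives
```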
The new pass is more general than the FoldScaleAxis pass in nnvm.