[CINN] Add the TileBroadcastTactic for NCHW broadcast #70092
virtual void Init(ScheduleContext* context) = 0;
// Attribute key to record which tile tactic has been applied on a graph.
// Exactly one tile tactic is applied on a graph during scheduling.
static constexpr char* kTileMethod = "tile_method";
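The applied tile tactic is recorded as an attr under ScheduleBlock(root). Each graph can have exactly one tile tactic applied, which makes it easier to manage multiple tile templates. If a tile transpose tactic or similar is added later, the same mechanism can be used to tell them apart.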
If each graph can only use one tile tactic, could some special subgraphs end up with an unsuitable tile tactic?
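What I wrote is effectively a fallback flow of tile tactics: only the first tile tactic that matches is applied to the graph, and the tile general tactic acts as the final fallback. This is borrowed from TVM, which also uses a flag to mark which tile was used. A wrong match is indeed possible, so the match conditions need to be as precise as possible and the tactic implementations need to be robust, so that a wrong match costs some performance at worst and never produces wrong results.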
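The first-match fallback flow described in this comment can be sketched roughly as below. This is a minimal illustration, not the CINN implementation: Graph, TileTactic, ApplyFirstMatch, DefaultTactics and the is_nchw_broadcast flag are hypothetical stand-ins, and only the attribute key name is taken from the diff.

```cpp
#include <functional>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for a graph's root block: a flag plus string attrs.
struct Graph {
  bool is_nchw_broadcast = false;
  std::map<std::string, std::string> attrs;
};

// Attribute key, named after kTileMethod in the diff.
constexpr const char* kTileMethod = "tile_method";

// Hypothetical tile tactic: a name and a match predicate.
struct TileTactic {
  std::string name;
  std::function<bool(const Graph&)> matches;
};

// Records the first matching tactic under kTileMethod. Once the attribute is
// set, the remaining tactics are skipped, so at most one tactic is recorded.
inline std::string ApplyFirstMatch(const std::vector<TileTactic>& tactics,
                                   Graph* graph) {
  for (const TileTactic& tactic : tactics) {
    if (graph->attrs.count(kTileMethod) != 0) break;  // already tiled
    if (tactic.matches(*graph)) {
      graph->attrs[kTileMethod] = tactic.name;
    }
  }
  auto it = graph->attrs.find(kTileMethod);
  return it == graph->attrs.end() ? std::string() : it->second;
}

// Specialized tactics come first; the general tactic matches everything.
inline std::vector<TileTactic> DefaultTactics() {
  return {
      {"TileBroadcastTactic",
       [](const Graph& g) { return g.is_nchw_broadcast; }},
      {"TileGeneralTactic", [](const Graph&) { return true; }},
  };
}
```

Because the general tactic always matches, every graph ends up with exactly one recorded tactic, and the order of the list is what decides which specialized template wins.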
virtual void Init(ScheduleContext* context) {
  PADDLE_THROW(::common::errors::Unimplemented(
      "ScheduleTactic subclass must implement one of the Init method."));
}
Wouldn't it be enough to just declare this Init function as = 0?
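That causes a compile error unless every subclass overrides Init(context); a subclass that overrides only Init(context, sch) fails to compile. I looked into it and couldn't find a good way to handle this. It would require a fairly complex template class, which would make the code less clear.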
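The trade-off discussed here can be illustrated with a small sketch. It assumes placeholder ScheduleContext and IRSchedule types, a std::logic_error in place of PADDLE_THROW, and a two-argument overload that forwards to the one-argument one by default; that forwarding default and the two subclass names are assumptions made for illustration, not taken from the diff.

```cpp
#include <stdexcept>

struct ScheduleContext {};  // placeholder for the real context type
struct IRSchedule {};       // placeholder for the real schedule type

class ScheduleTactic {
 public:
  virtual ~ScheduleTactic() = default;

  // Not pure virtual: a subclass that only overrides the two-argument
  // overload stays instantiable. Misuse is reported at run time instead.
  virtual void Init(ScheduleContext* /*context*/) {
    throw std::logic_error(
        "ScheduleTactic subclass must implement one of the Init methods.");
  }

  // Assumed default: forward to the one-argument overload.
  virtual void Init(ScheduleContext* context, IRSchedule* /*sch*/) {
    Init(context);
  }
};

// Hypothetical subclass that only needs the context.
class ContextOnlyTactic : public ScheduleTactic {
 public:
  using ScheduleTactic::Init;  // keep the other overload visible
  void Init(ScheduleContext* /*context*/) override { initialized = true; }
  bool initialized = false;
};

// Hypothetical subclass that also needs the schedule.
class ScheduleAwareTactic : public ScheduleTactic {
 public:
  using ScheduleTactic::Init;  // keep the other overload visible
  void Init(ScheduleContext* /*context*/, IRSchedule* /*sch*/) override {
    initialized = true;
  }
  bool initialized = false;
};
```

Had the one-argument overload been declared = 0, ScheduleAwareTactic would remain abstract and could not be instantiated, which matches the compile error mentioned in the reply.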
const auto MulDimSize = [](int64_t a, int64_t b) {
  return (a == -1 || b == -1) ? -1 : a * b;
};

broadcast_size_ = 1;
for (int axis : broadcast_axis_) {
  broadcast_size_ = MulDimSize(broadcast_size_, loop_ranges[axis]);
}
Is loop_ranges guaranteed to hold int values?
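loop_ranges is an int64 array in which dynamic-shape dims are -1. In effect, I don't handle the case where H or W in NCHW has a dynamic shape: if the computed size is -1, the outer low_broadcast_size_ % 32 != 0 check causes the match to come out false. N and C, however, can have dynamic shapes, and I measured a small improvement in that case too.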
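To make the -1 propagation concrete, here is a small self-contained sketch. MulDimSize mirrors the lambda in the diff; FusedSize and IsStaticMultipleOf32 are hypothetical helpers, and the exact condition used in Init is an assumption based on the reply above.

```cpp
#include <cstdint>
#include <vector>

// Product of two dim sizes; -1 marks a dynamic dim and absorbs the product.
inline int64_t MulDimSize(int64_t a, int64_t b) {
  return (a == -1 || b == -1) ? -1 : a * b;
}

// Hypothetical helper: fused size of the given axes of loop_ranges.
inline int64_t FusedSize(const std::vector<int64_t>& loop_ranges,
                         const std::vector<int>& axes) {
  int64_t size = 1;
  for (int axis : axes) {
    size = MulDimSize(size, loop_ranges[axis]);
  }
  return size;
}

// Hypothetical check: in C++, -1 % 32 evaluates to -1 (division truncates
// toward zero), so a dynamic size never passes as a multiple of 32.
inline bool IsStaticMultipleOf32(int64_t size) { return size % 32 == 0; }
```

For a shape such as {8, 64, 16, 16}, fusing the last two axes gives 256 and passes the check. If either of those two dims is -1, the fused size is -1 and the check fails, so the graph falls through to another tactic.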
std::unordered_set<ir::Var> vars_in_load =
    CollectIterVars(load_node->indices);
For an index such as k % S0, is the extracted Var k or k % S0?
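It is k, because I only extract the Var itself, not the arithmetic around it and not symbolic variables (S0 and the like). This makes the broadcast detection more conservative: for example, var_1[k] = var_0[k % 16] is actually a broadcast, but I classify it as not a broadcast.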
Could classifying this case as not a broadcast cause problems in later processing?
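If it is classified as not a broadcast, this tactic is skipped and the default tile general tactic is used instead. I haven't actually run into this case in the unit tests; it is only a precaution, because the tile broadcast tactic is a fairly specialized template and should only handle cases whose conditions are clear.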
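The conservative detection discussed in this thread can be sketched on a toy expression tree. Everything below is a hypothetical illustration: the Expr type, its constructors, and BroadcastAxes are not CINN APIs, and CollectIterVars here only imitates the behavior described above (collect loop variables, ignore arithmetic and symbols).

```cpp
#include <memory>
#include <set>
#include <string>
#include <vector>

// Toy index expression: a loop variable, a symbolic dim, a constant,
// or a binary operation over two operands.
struct Expr {
  enum class Kind { kIterVar, kSymbol, kConst, kBinary };
  Kind kind;
  std::string name;  // used by kIterVar and kSymbol only
  std::vector<std::shared_ptr<Expr>> operands;
};
using ExprPtr = std::shared_ptr<Expr>;

inline ExprPtr IterVar(const std::string& n) {
  return std::make_shared<Expr>(Expr{Expr::Kind::kIterVar, n, {}});
}
inline ExprPtr Symbol(const std::string& n) {
  return std::make_shared<Expr>(Expr{Expr::Kind::kSymbol, n, {}});
}
inline ExprPtr Const() {
  return std::make_shared<Expr>(Expr{Expr::Kind::kConst, "", {}});
}
inline ExprPtr Binary(ExprPtr lhs, ExprPtr rhs) {
  return std::make_shared<Expr>(Expr{Expr::Kind::kBinary, "", {lhs, rhs}});
}

// Collects loop variables only; arithmetic and symbols are looked through.
inline void CollectIterVars(const ExprPtr& e, std::set<std::string>* out) {
  if (e->kind == Expr::Kind::kIterVar) out->insert(e->name);
  for (const ExprPtr& operand : e->operands) CollectIterVars(operand, out);
}

// Loop variables that never appear in the load indices. A non-empty result
// means the load is treated as a broadcast along those loops.
inline std::set<std::string> BroadcastAxes(
    const std::set<std::string>& loop_vars,
    const std::vector<ExprPtr>& load_indices) {
  std::set<std::string> vars_in_load;
  for (const ExprPtr& index : load_indices) {
    CollectIterVars(index, &vars_in_load);
  }
  std::set<std::string> missing;
  for (const std::string& v : loop_vars) {
    if (vars_in_load.count(v) == 0) missing.insert(v);
  }
  return missing;
}
```

With this rule, a load indexed by k % 16 inside a loop over k still mentions k, so no broadcast axis is reported and the graph would be left to the general tactic, which is the conservative outcome described in the reply.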
…Paddle#70092)" This reverts commit bf88408.
PR Category
CINN
PR Types
Improvements
Description
Implement the tiling template for NCHW broadcasts.
Introduction
This tactic performs tiling for NCHW broadcasts that have the form:
More generally, if we classify the axes into two types:
then this tactic can handle (in the sense of axis fusion):
as long as the last axis is B.
Performance Impact
This tactic primarily addresses redundant computation in broadcasts. Without it, the general tiling tactic doesn't treat the P axis as special and fuses all axes into one before tiling. As a result, P's index is often interleaved with the inner loop, such as:
thus, for every loop iteration k, we need to load and compute a new value.
In contrast, this tactic assigns a dedicated axis (blockIdx.x or blockIdx.y) to P, such as:
so that each thread only needs to load and compute it once. When dealing with complex ops (e.g. div, exp, rsqrt), this tactic can save a substantial amount of computation and bring up to a 30% speedup.
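The saving can be illustrated with a CPU-side analogy in plain C++. This is only a sketch under stated assumptions: it assumes a per-channel input of shape [C] broadcast to [N, C, H*W] with exp as the expensive op, it uses nested loops in place of blockIdx and threads, and the function names are hypothetical. It counts how often the expensive op is evaluated in each scheme; it does not reproduce the generated kernels.

```cpp
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

struct Result {
  std::vector<float> out;
  int64_t expensive_evals;  // number of exp evaluations performed
};

// Fused scheme: one flat loop. The channel index is recovered per element,
// so exp is re-evaluated for every output element.
inline Result BroadcastFused(const std::vector<float>& x, int64_t N,
                             int64_t C, int64_t HW) {
  Result r{std::vector<float>(static_cast<size_t>(N * C * HW)), 0};
  for (int64_t k = 0; k < N * C * HW; ++k) {
    const int64_t c = (k / HW) % C;  // channel index interleaved with k
    r.out[k] = std::exp(x[c]);
    ++r.expensive_evals;
  }
  return r;
}

// Dedicated-axis scheme: the (n, c) loops stand in for block indices, so
// exp is evaluated once per (n, c) and reused across the inner loop.
inline Result BroadcastDedicatedAxis(const std::vector<float>& x, int64_t N,
                                     int64_t C, int64_t HW) {
  Result r{std::vector<float>(static_cast<size_t>(N * C * HW)), 0};
  for (int64_t n = 0; n < N; ++n) {
    for (int64_t c = 0; c < C; ++c) {
      const float v = std::exp(x[c]);  // computed once, then reused
      ++r.expensive_evals;
      for (int64_t hw = 0; hw < HW; ++hw) {
        r.out[(n * C + c) * HW + hw] = v;
      }
    }
  }
  return r;
}
```

Both functions produce the same output, but the fused loop performs N*C*HW evaluations of exp while the dedicated-axis version performs N*C, which is the kind of redundancy the description refers to.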
Limitations
The implementation of this tactic has been tailored for various layouts. However, we have not yet observed consistent performance improvements with layouts other than NCHW, so it is currently restricted to NCHW.
To avoid unexpected performance degradation, this tactic also imposes constraints on dim sizes. See Init for details.
Example
Note: there are 3 ways of axis binding in this tactic for different dim sizes. See Apply for details.
Experiment Results
Tested on A100; unit: ips
Pcard-85711