
[CINN] Add the TileBroadcastTactic for NCHW broadcast #70092

Merged
merged 1 commit into from
Dec 16, 2024

@lshpku lshpku commented Dec 10, 2024

PR Category

CINN

PR Types

Improvements

Description

Implement the tiling template for NCHW broadcasts.

Introduction

This tactic performs tiling for NCHW broadcasts that have the form:

   [1, C, 1, 1] => [N, C, H, W].

More generally, if we classify the axes into two types:

  • B - broadcast axis, i.e. N, H, and W
  • P - preserved axis, i.e. C

then this tactic can handle (in the sense of axis fusion):

   [1, P, 1] => [B, P, B],

as long as the last axis is B.
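
The classification above can be sketched as follows. This is a minimal illustration, not the PR's implementation: the helper names `FusedAxisPattern` and `IsBPBBroadcast` are hypothetical, and axes are classified purely by comparing input and output dim sizes (the PR itself analyzes load indices).

```cpp
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical helper: label each axis 'B' (broadcast) or 'P' (preserved) by
// comparing input and output dims, fusing adjacent axes of the same type.
std::string FusedAxisPattern(const std::vector<int64_t>& in_shape,
                             const std::vector<int64_t>& out_shape) {
  std::string pattern;
  if (in_shape.size() != out_shape.size()) return pattern;
  for (std::size_t i = 0; i < in_shape.size(); ++i) {
    // Assumption: an axis is a broadcast axis when the input dim is 1 and
    // the output dim is not.
    const bool is_broadcast = in_shape[i] == 1 && out_shape[i] != 1;
    const char type = is_broadcast ? 'B' : 'P';
    // Adjacent axes of the same type collapse into one fused axis.
    if (pattern.empty() || pattern.back() != type) pattern.push_back(type);
  }
  return pattern;
}

// True when the fused pattern is exactly [B, P, B], so the last axis is B.
bool IsBPBBroadcast(const std::vector<int64_t>& in_shape,
                    const std::vector<int64_t>& out_shape) {
  return FusedAxisPattern(in_shape, out_shape) == "BPB";
}
```

For [1, C, 1, 1] => [N, C, H, W], H and W fuse into a single trailing B axis, giving the pattern B, P, B. An NHWC-style broadcast ends in P and is rejected by this check.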

Performance Impact

This tactic primarily addresses redundant computation in broadcasts. Without it, the general tiling tactic doesn't treat the P axis specially and fuses all axes into one before tiling. As a result, P's index is often interleaved with the inner loop, such as:

   ... = exp(var[k * 256 + threadIdx.x])

so for every loop iteration k, we need to load and compute a new value.

In contrast, this tactic assigns a dedicated axis (blockIdx.x or blockIdx.y) to P, such as:

   ... = exp(var[blockIdx.x])

so that each thread only needs to load and compute it once. For complex ops (e.g. div, exp, rsqrt), this tactic can save a significant amount of computation and bring up to a 30% speedup.
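
The difference can be illustrated with a CPU-side sketch of the work done by one thread. The function names, the `ThreadResult` struct, and the way the channel index is derived from the fused index are assumptions made for illustration; the generated kernels may look different.

```cpp
#include <cmath>
#include <cstdint>
#include <vector>

struct ThreadResult {
  double sum;     // sum of the values this thread would write
  int exp_calls;  // how many times exp was evaluated
};

// General tiling (assumed form): the channel index is derived from the fused
// index, which depends on k, so exp is re-evaluated in every iteration.
ThreadResult FusedThread(const std::vector<double>& var, int64_t hw,
                         int64_t base, int tid, int loops) {
  const int64_t c = static_cast<int64_t>(var.size());
  ThreadResult r{0.0, 0};
  for (int k = 0; k < loops; ++k) {
    const int64_t flat = base + static_cast<int64_t>(k) * 256 + tid;
    const int64_t channel = (flat / hw) % c;
    r.sum += std::exp(var[channel]);
    ++r.exp_calls;
  }
  return r;
}

// Broadcast tiling: the channel is fixed per block (blockIdx.x), so exp can
// be evaluated once and reused across the inner loop.
ThreadResult DedicatedThread(const std::vector<double>& var,
                             int64_t block_idx_x, int loops) {
  const double value = std::exp(var[block_idx_x]);
  ThreadResult r{0.0, 1};
  for (int k = 0; k < loops; ++k) r.sum += value;
  return r;
}
```

Both functions produce the same values when the thread stays within one channel, but the second evaluates exp once instead of once per iteration.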

Limitations

The implementation of this tactic has been tailored for various layouts. However, we have not yet observed consistent performance improvements with layouts other than NCHW, so it currently applies only to NCHW.

To avoid unexpected performance degradation, this tactic also imposes constraints on dim sizes. See Init for details.

Example

  for (i, 0, 128):       # N
    for (j, 0, 256):     # C
      for (k, 0, 32):    # H
        for (a, 0, 32):  # W
          ScheduleBlock(var_1):
            var_1[i, j, k, a] = exp(var[j])
=>
  for (blockIdx.y, 0, 128):         # N
    for (blockIdx.x, 0, 256):       # C
      for (k, 0, 4):                # HW / 256
        for (threadIdx.x, 0, 256):  # HW % 256
          ScheduleBlock(var_1):
             var_1[blockIdx.y, blockIdx.x, ...] = exp(var[blockIdx.x])

Note: there are 3 ways of axis binding in this tactic for different dim sizes. See Apply for details.
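
As a sanity check on the example, the two loop nests can be compared on the CPU. This sketch assumes one plausible reading of the tiled form, in which the fused HW index `k * 256 + threadIdx.x` addresses the contiguous H*W plane of each (N, C) pair; the elided output indices in the example are not reproduced, and the function names are hypothetical.

```cpp
#include <cmath>
#include <cstdint>
#include <vector>

// Original loop nest: var_1[i, j, k, a] = exp(var[j]).
std::vector<double> ReferenceBroadcast(const std::vector<double>& var,
                                       int64_t n, int64_t h, int64_t w) {
  const int64_t c = static_cast<int64_t>(var.size());
  std::vector<double> out(n * c * h * w);
  for (int64_t i = 0; i < n; ++i)
    for (int64_t j = 0; j < c; ++j)
      for (int64_t k = 0; k < h; ++k)
        for (int64_t a = 0; a < w; ++a)
          out[((i * c + j) * h + k) * w + a] = std::exp(var[j]);
  return out;
}

// Tiled loop nest (assumed mapping): N -> blockIdx.y, C -> blockIdx.x,
// fused HW split into an inner loop k and 256 threads.
// Assumes h * w is divisible by 256.
std::vector<double> TiledBroadcast(const std::vector<double>& var,
                                   int64_t n, int64_t h, int64_t w) {
  constexpr int64_t kThreads = 256;
  const int64_t c = static_cast<int64_t>(var.size());
  const int64_t hw = h * w;
  std::vector<double> out(n * c * hw);
  for (int64_t block_y = 0; block_y < n; ++block_y) {
    for (int64_t block_x = 0; block_x < c; ++block_x) {
      // The value depends only on blockIdx.x, so it is computed once here.
      const double value = std::exp(var[block_x]);
      for (int64_t k = 0; k < hw / kThreads; ++k)
        for (int64_t tid = 0; tid < kThreads; ++tid)
          out[(block_y * c + block_x) * hw + k * kThreads + tid] = value;
    }
  }
  return out;
}
```

With H = W = 32, HW is 1024, which splits into 4 iterations of 256 threads, matching the loop extents shown in the example.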

Experiment Results

Tested on A100; unit: ips.

| Model | OFF | ON | Speedup |
|---|---|---|---|
| PP-LCNet_x1_0_bs1024_fp16 | 3010 | 3088 | 2.59% |
| MobileNetV3_small_x1_0_bs1024_fp16 | 3047 | 3101 | 1.77% |
| ResNet50_bs256_fp16 | 1437 | 1459 | 1.53% |
| PP-HGNetV2-B0_bs64_fp16 | 1687 | 1706 | 1.13% |

Pcard-85711


paddle-bot bot commented Dec 10, 2024

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@lshpku lshpku force-pushed the tile-broadcast-tactic branch from 1ba4452 to 6014a56 Compare December 13, 2024 02:41
@lshpku lshpku changed the title [CINN] Implement the TileBroadcastTactic [CINN] Add the TileBroadcastTactic for NCHW broadcast Dec 13, 2024
@lshpku lshpku force-pushed the tile-broadcast-tactic branch 2 times, most recently from 6061701 to a926811 Compare December 13, 2024 13:35
@lshpku lshpku force-pushed the tile-broadcast-tactic branch from a926811 to b56a3ee Compare December 15, 2024 10:21
virtual void Init(ScheduleContext* context) = 0;
// Attribute key to record which tile tactic has been applied on a graph.
// Exactly one tile tactic is applied on a graph during scheduling.
static constexpr char* kTileMethod = "tile_method";
@lshpku lshpku (Contributor Author) commented Dec 16, 2024

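The tile tactic that has been applied is recorded as an attr on ScheduleBlock(root). Each graph can have exactly one tile tactic applied, which makes it easier to manage multiple tile templates. If we later add, say, a tile transpose tactic, the same mechanism can be used to tell them apart.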

Contributor

If each graph can use only one tile tactic, could some special subgraphs end up with an unsuitable tile tactic applied?

Contributor Author

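What I've written is effectively a fallback flow for tile tactics: only the first matching tile tactic is applied to a graph, with the tile general tactic as the last resort. This is modeled on TVM, which also uses a flag to mark which tiling was used. A mismatch is indeed possible, so the match conditions need to be as precise as possible, and the tactic implementations need to be robust, so that a wrong match costs at most some performance and never produces wrong results.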
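
The fallback flow described in this thread can be sketched as below. Everything except the attribute key "tile_method" is hypothetical: the `Graph` and `TileTactic` types, the function names, and the tactic name strings are stand-ins, not CINN's actual classes.

```cpp
#include <functional>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for a graph and the attrs of its root ScheduleBlock.
struct Graph {
  bool is_nchw_broadcast;
  std::map<std::string, std::string> root_attrs;
};

// Hypothetical tactic descriptor: a name plus a match predicate.
struct TileTactic {
  std::string name;
  std::function<bool(const Graph&)> matches;
};

// Specialized tactics come first; the general tactic always matches and
// therefore acts as the fallback.
std::vector<TileTactic> DefaultTactics() {
  return {
      {"TileBroadcastTactic",
       [](const Graph& g) { return g.is_nchw_broadcast; }},
      {"TileGeneralTactic", [](const Graph&) { return true; }},
  };
}

// Applies the first matching tactic and records it under "tile_method".
// Returns the recorded name, or an empty string if nothing matched.
std::string ApplyFirstMatchingTactic(Graph* graph,
                                     const std::vector<TileTactic>& tactics) {
  for (const TileTactic& tactic : tactics) {
    if (tactic.matches(*graph)) {
      graph->root_attrs["tile_method"] = tactic.name;
      return tactic.name;
    }
  }
  return "";
}
```

Because exactly one tactic is recorded per graph, later passes can check the attribute to see which tiling was applied.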

Comment on lines +98 to +101
virtual void Init(ScheduleContext* context) {
PADDLE_THROW(::common::errors::Unimplemented(
"ScheduleTactic subclass must implement one of the Init method."));
}
Contributor

Wouldn't it be enough to just declare this Init function as = 0?

Contributor Author

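That would fail to compile unless every subclass overrides Init(context); overriding only Init(context, sch) would produce a compile error. I looked into this and couldn't find a good way to handle it. It would take a fairly complicated template class, which would make the code less clear.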
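
A minimal sketch of the trade-off discussed here, using stand-in types: `ScheduleContext` and `Schedule` are empty placeholders, `std::logic_error` replaces `PADDLE_THROW`, and the assumption that both overloads have throwing defaults is mine. If both overloads were pure virtual, a subclass overriding only one of them would remain abstract and could not be instantiated.

```cpp
#include <stdexcept>

struct ScheduleContext {};  // placeholder for the real context type
struct Schedule {};         // hypothetical stand-in for the `sch` argument

class ScheduleTactic {
 public:
  virtual ~ScheduleTactic() = default;
  // Default implementations throw, so a subclass may override either one.
  virtual void Init(ScheduleContext*) {
    throw std::logic_error("subclass must implement one of the Init methods");
  }
  virtual void Init(ScheduleContext*, Schedule*) {
    throw std::logic_error("subclass must implement one of the Init methods");
  }
};

// Overrides only the two-argument overload; still instantiable because the
// one-argument overload is not pure virtual.
class TwoArgTactic : public ScheduleTactic {
 public:
  void Init(ScheduleContext*, Schedule*) override { initialized = true; }
  bool initialized = false;
};

// Returns true if calling the one-argument Init on `tactic` throws.
bool ThrowsOnOneArgInit(ScheduleTactic* tactic) {
  ScheduleContext context;
  try {
    tactic->Init(&context);
  } catch (const std::logic_error&) {
    return true;
  }
  return false;
}
```

The cost of this design is that forgetting to override either overload is detected at run time rather than at compile time.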


Comment on lines +297 to +304
const auto MulDimSize = [](int64_t a, int64_t b) {
return (a == -1 || b == -1) ? -1 : a * b;
};

broadcast_size_ = 1;
for (int axis : broadcast_axis_) {
broadcast_size_ = MulDimSize(broadcast_size_, loop_ranges[axis]);
}
Contributor

Is loop_ranges guaranteed to hold int values?

Contributor Author

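loop_ranges is an int64 array, and dynamic-shape dims are -1. In effect, I don't handle the case where H or W in NCHW is dynamic: if the result comes out as -1, the outer check low_broadcast_size_ % 32 != 0 will be false. N and C can be dynamic, though, and I measured a small improvement in that case as well.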
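
The -1 propagation can be seen in a standalone version of the snippet under review. `MulDimSize` follows the lambda in the PR; the wrapper `FusedSize` and its parameter names are hypothetical.

```cpp
#include <cstdint>
#include <vector>

// Same rule as the lambda in the PR: -1 marks a dynamic dim and is sticky.
int64_t MulDimSize(int64_t a, int64_t b) {
  return (a == -1 || b == -1) ? -1 : a * b;
}

// Hypothetical wrapper: product of the selected axes, or -1 if any selected
// axis is dynamic.
int64_t FusedSize(const std::vector<int64_t>& loop_ranges,
                  const std::vector<int>& axes) {
  int64_t size = 1;
  for (int axis : axes) {
    size = MulDimSize(size, loop_ranges[axis]);
  }
  return size;
}
```

A dynamic N does not affect the fused H*W size, while a dynamic H or W makes that size -1.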

Comment on lines +176 to +177
std::unordered_set<ir::Var> vars_in_load =
CollectIterVars(load_node->indices);
Contributor

For an index like k % S0, is the extracted Var k or k % S0?

Contributor Author

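It's k, because here I only extract the Var itself, not the arithmetic around it, nor symbolic variables (S0 and the like). This makes the broadcast check more conservative: for example, var_1[k] = var_0[k % 16] is actually a broadcast, but I classify it as not a broadcast.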

Contributor

Could classifying this case as not a broadcast cause problems in later processing?

Contributor Author

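If it's classified as not a broadcast, this tactic is skipped and the default tile general tactic is used instead. I haven't actually hit this case in the unit tests; it's just to be safe, since the tile broadcast tactic is a fairly specialized template and should only handle cases whose conditions are clear-cut.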
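
The conservative check described in this thread can be sketched with a toy expression tree. The `Expr` type, its builders, and `IsBroadcastAxis` are hypothetical and much simpler than CINN's IR; only the name `CollectIterVars` comes from the PR, and its signature here is an assumption.

```cpp
#include <memory>
#include <set>
#include <string>
#include <vector>

enum class ExprKind { kIterVar, kSymbol, kBinary };

struct Expr;
using ExprPtr = std::shared_ptr<const Expr>;

// Toy index expression: a loop variable, a symbolic dim, or a binary op.
struct Expr {
  ExprKind kind;
  std::string name;  // used by kIterVar and kSymbol
  ExprPtr lhs;       // used by kBinary
  ExprPtr rhs;       // used by kBinary
};

ExprPtr IterVar(const std::string& name) {
  return std::make_shared<Expr>(Expr{ExprKind::kIterVar, name, nullptr, nullptr});
}
ExprPtr Symbol(const std::string& name) {
  return std::make_shared<Expr>(Expr{ExprKind::kSymbol, name, nullptr, nullptr});
}
ExprPtr Binary(ExprPtr lhs, ExprPtr rhs) {
  return std::make_shared<Expr>(Expr{ExprKind::kBinary, "", lhs, rhs});
}

void CollectInto(const ExprPtr& expr, std::set<std::string>* vars) {
  if (!expr) return;
  // Only loop variables are collected; symbols and operators are ignored.
  if (expr->kind == ExprKind::kIterVar) vars->insert(expr->name);
  CollectInto(expr->lhs, vars);
  CollectInto(expr->rhs, vars);
}

std::set<std::string> CollectIterVars(const std::vector<ExprPtr>& indices) {
  std::set<std::string> vars;
  for (const ExprPtr& index : indices) CollectInto(index, &vars);
  return vars;
}

// A loop variable counts as a broadcast axis only if it never appears in the
// load indices, so var_0[k % 16] makes k a non-broadcast axis.
bool IsBroadcastAxis(const std::string& loop_var,
                     const std::vector<ExprPtr>& load_indices) {
  return CollectIterVars(load_indices).count(loop_var) == 0;
}
```

Under this rule, an index such as k % S0 yields only k, and any loop variable that appears in a load index in any form is treated as preserved.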

@lshpku lshpku merged commit bf88408 into PaddlePaddle:develop Dec 16, 2024
28 checks passed
lshpku added a commit to lshpku/Paddle that referenced this pull request Dec 17, 2024