Skip to content

Commit

Permalink
[CINN] Adjust tile config parameters based on model benchmarks
Browse files Browse the repository at this point in the history
  • Loading branch information
lshpku committed Nov 14, 2024
1 parent 63d63dc commit 127b515
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions paddle/cinn/ir/group_schedule/config/group_tile_config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,8 @@ TileConfigMap BuildPureStaticShapeConfig(
if (last_dim == "R") {
rd_thread_num = 32;
int64_t remain_reduce_numel = CeilDiv(reduce_numel, 32);
if (remain_reduce_numel <= 8 && spatial_numel > 1) {
if ((remain_reduce_numel <= 8 && spatial_numel > 1) ||
(spatial_numel > remain_reduce_numel * 128)) {
sp_thread_num = Trim(spatial_numel, 1, 8);
reduce_method = WarpReduceMethod();
} else {
Expand Down Expand Up @@ -247,7 +248,7 @@ TileConfigMap BuildPureStaticShapeConfig(
return 1;
}
int64_t expected = spatial_numel / (sm_count * 4);
return CeilPow2(Trim(expected, 1, 32));
return CeilPow2(Trim(expected, 1, 4));
}();

int64_t sp_upper_bound = base_info->spatial_numel > 1 ? kMaxNumel : 1;
Expand Down

0 comments on commit 127b515

Please sign in to comment.