Skip to content

Commit

Permalink
fix: fix custom_ar_kernel threshold
Browse files Browse the repository at this point in the history
  • Loading branch information
byshiue committed Mar 6, 2023
1 parent fba7567 commit f7a4418
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/fastertransformer/kernels/custom_ar_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ void invokeOneOrTwoShotAllReduceKernel(AllReduceParams<T>& param, cudaStream_t s
size_t elts_total = param.elts_total;
int blocks_per_grid = 1, threads_per_block = DEFAULT_BLOCK_SIZE;
int kernel_algo = 1;
if (elts_total <= DEFALUT_ALGO_AR_SIZE_THRESHOLD) {
if (elts_total * sizeof(T) <= DEFALUT_ALGO_AR_SIZE_THRESHOLD) {
kernel_algo = 0;
}

Expand Down
2 changes: 1 addition & 1 deletion src/fastertransformer/kernels/custom_ar_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
#define RANKS_PER_NODE 8
#define WARP_SIZE 32
#define DEFAULT_BLOCK_SIZE 1024
#define DEFALUT_ALGO_AR_SIZE_THRESHOLD 196608
#define DEFALUT_ALGO_AR_SIZE_THRESHOLD 393216

namespace fastertransformer {

Expand Down

0 comments on commit f7a4418

Please sign in to comment.