@@ -42,7 +42,8 @@ struct AllReduce {
4242 if constexpr (offset == scale) {
4343 return x;
4444 } else {
45- return AllReduce<Reducer, offset, scale>::run (x, red_buf);
45+ return AllReduce<Reducer, offset, scale, thread_offset, all_threads>::run (
46+ x, red_buf);
4647 }
4748 }
4849
@@ -51,7 +52,7 @@ struct AllReduce {
5152 constexpr int offset = threads / 2 ;
5253 if constexpr (offset >= 32 ) {
5354 asm volatile (" bar.sync %0, %1;" : : " r" (1 ), " r" (all_threads));
54- red_buf[threadIdx.x ] = x;
55+ red_buf[threadIdx.x - thread_offset ] = x;
5556 // TODO(lei): maybe we can merge the two bar.sync into one?
5657 asm volatile (" bar.sync %0, %1;" : : " r" (2 ), " r" (all_threads));
5758 x = Reducer ()(x, red_buf[(threadIdx.x - thread_offset) ^ offset]);
@@ -61,8 +62,8 @@ struct AllReduce {
6162 if constexpr (offset == scale) {
6263 return x;
6364 } else {
64- return AllReduce<Reducer, offset, scale, all_threads>:: run_hopper (
65- x, red_buf);
65+ return AllReduce<Reducer, offset, scale, thread_offset,
66+ all_threads>:: run_hopper ( x, red_buf);
6667 }
6768 }
6869};
0 commit comments