Skip to content

Commit 3a40815

Browse files
authored
[Bugfix] Added missing thread offsets and other information to reduce. (#646)
1 parent b060c9f commit 3a40815

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

src/tl_templates/cuda/reduce.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ struct AllReduce {
4242
if constexpr (offset == scale) {
4343
return x;
4444
} else {
45-
return AllReduce<Reducer, offset, scale>::run(x, red_buf);
45+
return AllReduce<Reducer, offset, scale, thread_offset, all_threads>::run(
46+
x, red_buf);
4647
}
4748
}
4849

@@ -51,7 +52,7 @@ struct AllReduce {
5152
constexpr int offset = threads / 2;
5253
if constexpr (offset >= 32) {
5354
asm volatile("bar.sync %0, %1;" : : "r"(1), "r"(all_threads));
54-
red_buf[threadIdx.x] = x;
55+
red_buf[threadIdx.x - thread_offset] = x;
5556
// TODO(lei): maybe we can merge the two bar.sync into one?
5657
asm volatile("bar.sync %0, %1;" : : "r"(2), "r"(all_threads));
5758
x = Reducer()(x, red_buf[(threadIdx.x - thread_offset) ^ offset]);
@@ -61,8 +62,8 @@ struct AllReduce {
6162
if constexpr (offset == scale) {
6263
return x;
6364
} else {
64-
return AllReduce<Reducer, offset, scale, all_threads>::run_hopper(
65-
x, red_buf);
65+
return AllReduce<Reducer, offset, scale, thread_offset,
66+
all_threads>::run_hopper(x, red_buf);
6667
}
6768
}
6869
};

0 commit comments

Comments
 (0)