Skip to content

Commit

Permalink
volatile test
Browse files Browse the repository at this point in the history
  • Loading branch information
hanzhi713 committed Sep 19, 2024
1 parent 7874c89 commit 4ba295f
Showing 1 changed file with 24 additions and 2 deletions.
26 changes: 24 additions & 2 deletions csrc/custom_all_reduce.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ DINLINE void multi_gpu_barrier(const RankSignals& sg, Signal* self_sg,
&sg.signals[threadIdx.x]->peer_counter[val % 2][blockIdx.x][rank];
auto self_counter_ptr =
&self_sg->peer_counter[val % 2][blockIdx.x][threadIdx.x];
if constexpr (need_fence) {
if constexpr (false) {
st_flag_release(peer_counter_ptr, val);
while (ld_flag_acquire(self_counter_ptr) != val);
} else {
Expand Down Expand Up @@ -240,6 +240,17 @@ __global__ void __launch_bounds__(512, 1)
tmps[i] = get_tmp_buf<P>(sg.signals[target]);
}
auto tmp_out = tmps[0];
// Uncomment to populate cache of peer values, which are currently in a indeterminate state
// for (int idx = tid; idx < largest_part; idx += stride) {
// #pragma unroll
// for (int i = 0; i < ngpus; i++) {
// int gather_from_rank = ((rank + i) % ngpus);
// if (gather_from_rank == ngpus - 1 || idx < part) {
// int dst_idx = gather_from_rank * part + idx;
// ((P*)result)[dst_idx] = tmps[i][idx];
// }
// }
// }
multi_gpu_barrier<ngpus, true>(sg, self_sg, rank);
// stage 1: reduce scatter
for (int idx = start + tid; idx < end; idx += stride) {
Expand All @@ -258,7 +269,18 @@ __global__ void __launch_bounds__(512, 1)
int gather_from_rank = ((rank + i) % ngpus);
if (gather_from_rank == ngpus - 1 || idx < part) {
int dst_idx = gather_from_rank * part + idx;
((P*)result)[dst_idx] = tmps[i][idx];
uint4 tmp;
// note: relaxed.sys also works
asm volatile (
"ld.volatile.global.v4.u32 {%0, %1, %2, %3}, [%4];\n"
: "=r"(tmp.x), "=r"(tmp.y), "=r"(tmp.z), "=r"(tmp.w)
: "l"(&tmps[i][idx])
);
// asm volatile (
// "st.global.v4.u32 [%4], {%0, %1, %2, %3};\n"
// :: "r"(tmp.x), "r"(tmp.y), "r"(tmp.z), "r"(tmp.w), "l"(&((P*)result)[dst_idx])
// );
((P*)result)[dst_idx] = *reinterpret_cast<P*>(&tmp);
}
}
}
Expand Down

0 comments on commit 4ba295f

Please sign in to comment.