Commit f484fd3
fix: negative zero by type trait --> binary value (#1136)
## 📌 Description

Fix `has_neg_zero`: detect negative zero by its binary (bit) representation rather than by floating-point comparison against the `neg_zero` type-trait value, since IEEE-754 `==` treats `-0.0` and `+0.0` as equal.

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).
1 parent 2f01a9a commit f484fd3
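For context, a minimal host-side C++ sketch (not part of the patch; the helper name `is_negative_zero_bits` is hypothetical) of why equality against a negative-zero constant cannot work: IEEE-754 defines `-0.0f == +0.0f` as true, so only the bit pattern tells the two zeros apart.

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

// Hypothetical host-side helper mirroring the idea of the patch:
// reinterpret the float's bits and compare against the pattern with
// only the sign bit set (0x80000000).
static bool is_negative_zero_bits(float x) {
  uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  return bits == 0x80000000u;
}

int main() {
  std::printf("%d\n", -0.0f == 0.0f);                 // 1: '==' cannot tell the zeros apart
  std::printf("%d\n", is_negative_zero_bits(-0.0f));  // 1
  std::printf("%d\n", is_negative_zero_bits(0.0f));   // 0
}
```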

2 files changed: 42 additions and 43 deletions
include/flashinfer/comm/trtllm_allreduce.cuh

Lines changed: 32 additions & 5 deletions
```diff
@@ -188,11 +188,40 @@ struct neg_zero<nv_bfloat16> {
 template <typename T>
 __device__ static constexpr T neg_zero_v = neg_zero<T>::value;
 
+template <typename T>
+__device__ bool is_negative_zero(T) {
+  return false;
+}
+
+// float specialization
+template <>
+__device__ bool is_negative_zero<float>(float x) {
+  return (__float_as_int(x) == 0x80000000);
+}
+
+// double specialization
+template <>
+__device__ bool is_negative_zero<double>(double x) {
+  return (__double_as_longlong(x) == 0x8000000000000000ULL);
+}
+
+// __half specialization
+template <>
+__device__ bool is_negative_zero<__half>(__half x) {
+  return (__half_as_ushort(x) == 0x8000);
+}
+
+// __nv_bfloat16 specialization
+template <>
+__device__ bool is_negative_zero<__nv_bfloat16>(__nv_bfloat16 x) {
+  return (__bfloat16_as_ushort(x) == 0x8000);
+}
+
 template <typename T, uint32_t VEC_SIZE>
 __device__ __forceinline__ bool has_neg_zero(const vec_t<T, VEC_SIZE>& vec) {
 #pragma unroll
   for (int i = 0; i < VEC_SIZE; ++i) {
-    if (vec[i] == neg_zero_v<T>) {
+    if (is_negative_zero(vec[i])) {
       return true;
     }
   }
@@ -203,7 +232,7 @@ template <typename T, uint32_t VEC_SIZE>
 __device__ __forceinline__ void remove_neg_zero(vec_t<T, VEC_SIZE>& vec) {
 #pragma unroll
   for (int i = 0; i < VEC_SIZE; ++i) {
-    vec[i] = (vec[i] == neg_zero_v<T>) ? static_cast<T>(0.f) : vec[i];
+    vec[i] = (is_negative_zero(vec[i])) ? static_cast<T>(0.f) : vec[i];
   }
 }
```
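The two call sites were affected differently by the old `==` comparison: in `remove_neg_zero` it merely rewrote `+0` to `+0` (a harmless no-op), but in `has_neg_zero` it reported a false positive for any vector containing an ordinary `+0`, since IEEE-754 equality cannot distinguish the two zeros. The bit-pattern check fixes both call sites consistently.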
```diff
@@ -1694,10 +1723,8 @@ cudaError_t lamportInitializeAll(void* buffer_0, void* buffer_1, void* buffer_2,
   status = lamportInitialize<T>(buffer_2, size / sizeof(T), stream);
   FLASHINFER_CHECK(status == cudaSuccess, "lamportInitialize failed with error code " +
                                               std::string(cudaGetErrorString(status)));
-
+  cudaDeviceSynchronize();
   return cudaSuccess;
-  // todo(zihao): we can skip sycn with stream as below?
-  // cudaDeviceSynchronize();
 }
 
 }  // namespace trtllm_allreduce
```
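The resolved TODO raised the question of a stream-scoped wait. If that route were ever taken, a hypothetical variant (an assumption, not part of this patch) could block on `stream` alone inside `lamportInitializeAll`, in place of the device-wide call above:

```cpp
// Hypothetical stream-scoped alternative to cudaDeviceSynchronize();
// it waits only for work queued on `stream` to complete.
cudaError_t sync_status = cudaStreamSynchronize(stream);
FLASHINFER_CHECK(sync_status == cudaSuccess,
                 "cudaStreamSynchronize failed with error code " +
                     std::string(cudaGetErrorString(sync_status)));
```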

tests/test_trtllm_allreduce.py

Lines changed: 10 additions & 38 deletions
```diff
@@ -50,7 +50,7 @@ def _run_correctness_worker(world_size, rank, dtype, distributed_init_port):
 
     # below are the recommended hidden sizes for custom all-reduce in trtllm test
     # hidden_size should be in range [256, 8192], and maxHiddenSize should be 8192
-    hidden_sizes = [1024, 2048, 4096]
+    hidden_sizes = [1024, 4096]
     config_codes = [
         0,
         comm.AllReduceStrategyConfig.USE_MEMCPY,
@@ -79,7 +79,7 @@ def _run_correctness_worker(world_size, rank, dtype, distributed_init_port):
                 group=group,
             )
 
-            test_loop = 1  # could be any number
+            test_loop = 2  # could be any number
 
             # NOTE: the barrier flag should be initialized to 1, and incremented by 1 for each AR
             flag_value = 1
@@ -165,20 +165,12 @@ def _run_correctness_worker(world_size, rank, dtype, distributed_init_port):
                             )
                             dist.all_reduce(inp1_ref, group=group)
 
-                            tolerance = 1e-2 if dtype == torch.float16 else 5e-2
+                            tolerance = 1e-2 if dtype == torch.float16 else 8e-2
 
                             if fusion_op_code == comm.AllReduceFusionOp.NONE:
-                                if not torch.allclose(
+                                torch.testing.assert_close(
                                     out1, inp1_ref, atol=tolerance, rtol=3e-2
-                                ):
-                                    print(
-                                        f"test RANK {rank}: {world_size}-{dtype}-{strategy_code}-{config_code}-{fusion_op_code}-{launch_with_pdl}-{hidden_size} failed"
-                                    )
-                                    print(f"out1: {out1}")
-                                    print(f"inp1_ref: {inp1_ref}")
-                                    print(f"tolerance: {tolerance}")
-                                    print(f"rtol: {3e-2}")
-                                    pass_flag = False
+                                )
                             elif (
                                 fusion_op_code
                                 == comm.AllReduceFusionOp.RESIDUAL_RMS_NORM
@@ -198,21 +190,12 @@ def _run_correctness_worker(world_size, rank, dtype, distributed_init_port):
                                         + bias_float[i % hidden_size]
                                     )
                                 ref_half = ref_float.to(dtype)
-
-                                if not torch.allclose(
+                                torch.testing.assert_close(
                                     inter_buffer,
                                     ref_half,
                                     atol=tolerance,
                                     rtol=3e-2,
-                                ):
-                                    print(
-                                        f"test RANK {rank}: {world_size}-{dtype}-{strategy_code}-{config_code}-{fusion_op_code}-{launch_with_pdl}-{hidden_size} failed"
-                                    )
-                                    print(f"inter_buffer: {inter_buffer}")
-                                    print(f"ref_half: {ref_half}")
-                                    print(f"tolerance: {tolerance}")
-                                    print(f"rtol: {3e-2}")
-                                    pass_flag = False
+                                )
 
                                 # RMSNorm over hidden size
                                 ref_float = ref_float.view(
@@ -229,23 +212,12 @@ def _run_correctness_worker(world_size, rank, dtype, distributed_init_port):
                                     torch.float32
                                 )
                                 normed_half = normed_float.to(dtype)
-
-                                if not torch.allclose(
+                                torch.testing.assert_close(
                                     out1,
                                     normed_half.view(-1),
                                     atol=tolerance,
                                     rtol=3e-2,
-                                ):
-                                    print(
-                                        f"test RANK {rank}: {world_size}-{dtype}-{strategy_code}-{config_code}-{fusion_op_code}-{launch_with_pdl}-{hidden_size} failed"
-                                    )
-                                    print(f"out1: {out1}")
-                                    print(
-                                        f"normed_half.view(-1): {normed_half.view(-1)}"
-                                    )
-                                    print(f"tolerance: {tolerance}")
-                                    print(f"rtol: {3e-2}")
-                                    pass_flag = False
+                                )
 
                             elif (
                                 fusion_op_code
@@ -259,7 +231,7 @@ def _run_correctness_worker(world_size, rank, dtype, distributed_init_port):
             print(
                 f"test RANK {rank}: {world_size}-{dtype}-{strategy_code}-{config_code}-{fusion_op_code}-{launch_with_pdl}-{hidden_size} passed"
             )
-            # dist.barrier(group=group)
+            # torch.cuda.synchronize()
             # # you might want to enable this barrier for a better log output, but it's not mandatory across allReduce calls
     finally:
         dist.barrier(group=group)
```
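A note on the test change: `torch.testing.assert_close` raises an `AssertionError` whose message already reports the number of mismatched elements and the greatest absolute and relative differences, so the hand-rolled `torch.allclose` branches with their `print` diagnostics and `pass_flag` bookkeeping become redundant.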
