|
23 | 23 | #include <infiniband/mlx5dv.h> |
24 | 24 | #include <non_abi/device/threadgroup/nvshmemi_common_device_defines.cuh> |
25 | 25 | #include <device_host_transport/nvshmem_common_ibgda.h> |
| 26 | +#ifdef __NVCC__ |
| 27 | +#include <cub/cub.cuh> |
| 28 | +#endif |
26 | 29 | // clang-format on |
27 | | - |
28 | 30 | #include "paddle/fluid/distributed/collective/deep_ep/kernels/configs.cuh" |
29 | 31 | #include "paddle/fluid/distributed/collective/deep_ep/kernels/exception.cuh" |
30 | 32 | #include "paddle/fluid/distributed/collective/deep_ep/kernels/ibgda_device.cuh" |
31 | 33 | #include "paddle/fluid/distributed/collective/deep_ep/kernels/launch.cuh" |
32 | 34 | #include "paddle/phi/kernels/funcs/aligned_vector.h" |
33 | | - |
34 | 35 | namespace deep_ep { |
35 | 36 |
|
36 | 37 | namespace internode_ll { |
@@ -189,7 +190,32 @@ __global__ __launch_bounds__( |
189 | 190 | // Note(zkk) |
190 | 191 | // Introduce a run_deepep_loop flag so that DeepEP's original code does not need to be modified any further. |
191 | 192 | int run_deepep_loop = 1; |
192 | | - if (use_expertwise_scale) { |
| 193 | + if (use_expertwise_scale && kUseFP8) { // w4afp8 |
| 194 | + run_deepep_loop = 0; |
| 195 | + for (int ii = 0; ii < num_topk; ii++) { |
| 196 | + int tmp_id = topk_idx[ii + token_idx * num_topk]; |
| 197 | + float scale = expertwise_scale[tmp_id]; |
| 198 | + for (int i = thread_id; i < hidden_bf16_int4; i += num_threads) { |
| 199 | + auto int4_value = __ldg(x_int4 + i); |
| 200 | + auto bf16_values = reinterpret_cast<nv_bfloat16*>(&int4_value); |
| 201 | + int2 int2_value; |
| 202 | + phi::AlignedVector<phi::dtype::float8_e4m3fn, 8> res_vec; |
| 203 | + const float max_bound = 448.f; |
| 204 | + const float min_bound = -448.f; |
| 205 | + for (int j = 0; j < 8; j++) { |
| 206 | + float quant_value = |
| 207 | + max_bound * scale * static_cast<float>(bf16_values[j]); |
| 208 | + quant_value = quant_value > max_bound ? max_bound : quant_value; |
| 209 | + quant_value = quant_value < min_bound ? min_bound : quant_value; |
| 210 | + res_vec[j] = static_cast<phi::dtype::float8_e4m3fn>(quant_value); |
| 211 | + } |
| 212 | + phi::Store(res_vec, |
| 213 | + reinterpret_cast<phi::dtype::float8_e4m3fn*>(rdma_x) + |
| 214 | + (ii + token_idx * num_topk) * num_bytes_per_msg + |
| 215 | + sizeof(int4) + i * sizeof(res_vec)); |
| 216 | + } |
| 217 | + } |
| 218 | + } else if (use_expertwise_scale) { // w4aint8 |
193 | 219 | run_deepep_loop = 0; |
194 | 220 | for (int ii = 0; ii < num_topk; ii++) { |
195 | 221 | int tmp_id = topk_idx[ii + token_idx * num_topk]; |
@@ -224,7 +250,7 @@ __global__ __launch_bounds__( |
224 | 250 | // Read |
225 | 251 | auto int4_value = __ldg(x_int4 + i); |
226 | 252 |
|
227 | | - if (kUseFP8) { |
| 253 | + if (kUseFP8 && !use_expertwise_scale) { |
228 | 254 | // Calculate local amax |
229 | 255 | auto bf16_values = reinterpret_cast<nv_bfloat16*>(&int4_value); |
230 | 256 | float fp32_values[kNumElemsPerRead]; |
@@ -502,7 +528,7 @@ LOW_LATENCY_DISPATCH_RECV: |
502 | 528 | st_na_global); |
503 | 529 |
|
504 | 530 | // Copy scales |
505 | | - if (kUseFP8) { |
| 531 | + if (kUseFP8 && !use_expertwise_scale) { |
506 | 532 | const auto src_scales = reinterpret_cast<float*>( |
507 | 533 | reinterpret_cast<uint8_t*>(src_data) + hidden_bytes); |
508 | 534 | const auto dst_scales = |
|
0 commit comments