Skip to content

Commit f849dcd

Browse files
Chuanwei Yi authored and meta-codesync[bot] committed
fix integer overflow problem (#5051)
Summary: Pull Request resolved: #5051 X-link: https://github.com/facebookresearch/FBGEMM/pull/2055 When the token size is very large, the integer multiplications used to compute the workspace sizes and offsets (e.g. B*H*Q*D) overflow, which leads to a CUDA memory access error. Casting the first operand to size_t makes the products evaluate in size_t instead. Reviewed By: royren622 Differential Revision: D85330676 fbshipit-source-id: 51a5f61ed52f43358b61f4e689b33b95d051f987
1 parent 3c4ae35 commit f849dcd

File tree

1 file changed

+3
-3
lines changed
  • fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/device

1 file changed

+3
-3
lines changed

fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/device/fmha_device_bwd.hpp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -293,7 +293,7 @@ class Sm100FmhaBwd {
293293
ElementAccumulator* scaled_lse = reinterpret_cast<ElementAccumulator*>(workspace_scaled_lse);
294294
ElementAccumulator* dQ_acc = reinterpret_cast<ElementAccumulator*>(workspace_dQ);
295295
params_.dQ_acc = dQ_acc;
296-
params_.dQ_acc_size = B*H*Q*D * sizeof(ElementAccumulator);
296+
params_.dQ_acc_size = static_cast<size_t>(B)*H*Q*D * sizeof(ElementAccumulator);
297297
auto args_sum_OdO = to_sum_OdO_arguments(args, sum_OdO, scaled_lse);
298298
auto args_convert = to_convert_arguments(args, dQ_acc);
299299
params_.op_sum_OdO.initialize(args_sum_OdO, nullptr, stream);
@@ -320,9 +320,9 @@ class Sm100FmhaBwd {
320320
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
321321
char* workspace_chr = reinterpret_cast<char*>(workspace);
322322
ElementAccumulator* sum_OdO = reinterpret_cast<ElementAccumulator*>(workspace_chr);
323-
workspace_chr += B*H*Q * sizeof(ElementAccumulator);
323+
workspace_chr += static_cast<size_t>(B)*H*Q * sizeof(ElementAccumulator);
324324
ElementAccumulator* scaled_lse = reinterpret_cast<ElementAccumulator*>(workspace_chr);
325-
workspace_chr += B*H*Q * sizeof(ElementAccumulator);
325+
workspace_chr += static_cast<size_t>(B)*H*Q * sizeof(ElementAccumulator);
326326
ElementAccumulator* dQ_acc = reinterpret_cast<ElementAccumulator*>(workspace_chr);
327327
return initialize_split(args, dQ_acc, sum_OdO, scaled_lse, stream);
328328
}

0 commit comments

Comments
 (0)