From 57853c1b45251391e3b5f716456a391786cd18df Mon Sep 17 00:00:00 2001 From: Chuanwei Yi Date: Fri, 24 Oct 2025 03:11:55 -0700 Subject: [PATCH] fix integer overflow problem Summary: X-link: https://github.com/facebookresearch/FBGEMM/pull/2055 When token size is very big, the integer multiplication will lead to integer overflow therefore lead to a cuda memory access error. Reviewed By: royren622 Differential Revision: D85330676 --- .../cuda/cutlass_blackwell_fmha/device/fmha_device_bwd.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/device/fmha_device_bwd.hpp b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/device/fmha_device_bwd.hpp index 765f3b4f78..512c7e0a98 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/device/fmha_device_bwd.hpp +++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/device/fmha_device_bwd.hpp @@ -293,7 +293,7 @@ class Sm100FmhaBwd { ElementAccumulator* scaled_lse = reinterpret_cast(workspace_scaled_lse); ElementAccumulator* dQ_acc = reinterpret_cast(workspace_dQ); params_.dQ_acc = dQ_acc; - params_.dQ_acc_size = B*H*Q*D * sizeof(ElementAccumulator); + params_.dQ_acc_size = static_cast(B)*H*Q*D * sizeof(ElementAccumulator); auto args_sum_OdO = to_sum_OdO_arguments(args, sum_OdO, scaled_lse); auto args_convert = to_convert_arguments(args, dQ_acc); params_.op_sum_OdO.initialize(args_sum_OdO, nullptr, stream); @@ -320,9 +320,9 @@ class Sm100FmhaBwd { int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment char* workspace_chr = reinterpret_cast(workspace); ElementAccumulator* sum_OdO = reinterpret_cast(workspace_chr); - workspace_chr += B*H*Q * sizeof(ElementAccumulator); + workspace_chr += static_cast(B)*H*Q * sizeof(ElementAccumulator); ElementAccumulator* scaled_lse = reinterpret_cast(workspace_chr); - workspace_chr += B*H*Q * sizeof(ElementAccumulator); + workspace_chr += static_cast(B)*H*Q * sizeof(ElementAccumulator); ElementAccumulator* dQ_acc = reinterpret_cast(workspace_chr); return initialize_split(args, dQ_acc, sum_OdO, scaled_lse, stream); }