diff --git a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/device/fmha_device_bwd.hpp b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/device/fmha_device_bwd.hpp index 765f3b4f78..512c7e0a98 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/device/fmha_device_bwd.hpp +++ b/fbgemm_gpu/experimental/gen_ai/src/attention/cuda/cutlass_blackwell_fmha/device/fmha_device_bwd.hpp @@ -293,7 +293,7 @@ class Sm100FmhaBwd { ElementAccumulator* scaled_lse = reinterpret_cast(workspace_scaled_lse); ElementAccumulator* dQ_acc = reinterpret_cast(workspace_dQ); params_.dQ_acc = dQ_acc; - params_.dQ_acc_size = B*H*Q*D * sizeof(ElementAccumulator); + params_.dQ_acc_size = static_cast(B)*H*Q*D * sizeof(ElementAccumulator); auto args_sum_OdO = to_sum_OdO_arguments(args, sum_OdO, scaled_lse); auto args_convert = to_convert_arguments(args, dQ_acc); params_.op_sum_OdO.initialize(args_sum_OdO, nullptr, stream); @@ -320,9 +320,9 @@ class Sm100FmhaBwd { int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment char* workspace_chr = reinterpret_cast(workspace); ElementAccumulator* sum_OdO = reinterpret_cast(workspace_chr); - workspace_chr += B*H*Q * sizeof(ElementAccumulator); + workspace_chr += static_cast(B)*H*Q * sizeof(ElementAccumulator); ElementAccumulator* scaled_lse = reinterpret_cast(workspace_chr); - workspace_chr += B*H*Q * sizeof(ElementAccumulator); + workspace_chr += static_cast(B)*H*Q * sizeof(ElementAccumulator); ElementAccumulator* dQ_acc = reinterpret_cast(workspace_chr); return initialize_split(args, dQ_acc, sum_OdO, scaled_lse, stream); }