@@ -159,7 +159,7 @@ struct LayerNormDataWriter {
159159 temp_dst[j] = static_cast <T>((buffer[i * VecSize + j] - row_mean) *
160160 row_inv_var);
161161 }
162- v_dst[threadIdx .x + blockDim .x * i] = temp_dst;
162+ v_dst[threadIdx .x + static_cast < int64_t >( blockDim .x ) * i] = temp_dst;
163163 }
164164 } else {
165165 const VecScaleT *__restrict__ v_scale =
@@ -168,7 +168,7 @@ struct LayerNormDataWriter {
168168 reinterpret_cast <const VecScaleT *__restrict__ >(bias);
169169 if (valid_scale && valid_bias) {
170170 for (int i = 0 ; i < write_times; ++i) {
171- int idx = threadIdx .x + blockDim .x * i;
171+ int64_t idx = threadIdx .x + static_cast < int64_t >( blockDim .x ) * i;
172172 VecT temp_dst;
173173 VecScaleT temp_v_scale = v_scale[idx];
174174 VecScaleT temp_v_bias = v_bias[idx];
@@ -184,7 +184,7 @@ struct LayerNormDataWriter {
184184 } else {
185185 if (valid_scale) {
186186 for (int i = 0 ; i < write_times; ++i) {
187- int idx = threadIdx .x + blockDim .x * i;
187+ int64_t idx = threadIdx .x + static_cast < int64_t >( blockDim .x ) * i;
188188 VecT temp_dst;
189189 VecScaleT temp_v_scale = v_scale[idx];
190190#pragma unroll
@@ -232,27 +232,27 @@ struct LayerNormDataWriter<T, U, IsSameType, 1> {
232232 if ((!valid_scale) && (!valid_bias)) {
233233 if (threadIdx .x < last_tid_idx) {
234234 for (int i = 0 ; i < cols_this_thread; ++i) {
235- row_dst[threadIdx .x + last_tid_idx * i] =
235+ row_dst[threadIdx .x + static_cast < int64_t >( last_tid_idx) * i] =
236236 (buffer[i] - row_mean) * row_inv_var;
237237 }
238238 } else {
239239 for (int i = 0 ; i < cols_this_thread; ++i) {
240- row_dst[last_tid_idx * write_times + i] =
240+ row_dst[static_cast < int64_t >( last_tid_idx) * write_times + i] =
241241 (buffer[i] - row_mean) * row_inv_var;
242242 }
243243 }
244244 } else if (valid_scale && valid_bias) {
245245 if (threadIdx .x < last_tid_idx) {
246246 for (int i = 0 ; i < cols_this_thread; ++i) {
247- int idx = threadIdx .x + last_tid_idx * i;
247+ int64_t idx = threadIdx .x + static_cast < int64_t >( last_tid_idx) * i;
248248 row_dst[idx] =
249249 static_cast <T>(static_cast <U>(scale[idx]) *
250250 (buffer[i] - row_mean) * row_inv_var +
251251 static_cast <U>(bias[idx]));
252252 }
253253 } else {
254254 for (int i = 0 ; i < cols_this_thread; ++i) {
255- int idx = last_tid_idx * write_times + i;
255+ int64_t idx = static_cast < int64_t >( last_tid_idx) * write_times + i;
256256 row_dst[idx] =
257257 static_cast <T>(static_cast <U>(scale[idx]) *
258258 (buffer[i] - row_mean) * row_inv_var +
@@ -263,27 +263,27 @@ struct LayerNormDataWriter<T, U, IsSameType, 1> {
263263 if (valid_scale) {
264264 if (threadIdx .x < last_tid_idx) {
265265 for (int i = 0 ; i < cols_this_thread; ++i) {
266- int idx = threadIdx .x + last_tid_idx * i;
266+ int64_t idx = threadIdx .x + static_cast < int64_t >( last_tid_idx) * i;
267267 row_dst[idx] = static_cast <T>(static_cast <U>(scale[idx]) *
268268 (buffer[i] - row_mean) * row_inv_var);
269269 }
270270 } else {
271271 for (int i = 0 ; i < cols_this_thread; ++i) {
272- int idx = last_tid_idx * write_times + i;
272+ int64_t idx = static_cast < int64_t >( last_tid_idx) * write_times + i;
273273 row_dst[idx] = static_cast <T>(static_cast <U>(scale[idx]) *
274274 (buffer[i] - row_mean) * row_inv_var);
275275 }
276276 }
277277 } else {
278278 if (threadIdx .x < last_tid_idx) {
279279 for (int i = 0 ; i < cols_this_thread; ++i) {
280- int idx = threadIdx .x + last_tid_idx * i;
280+ int64_t idx = threadIdx .x + static_cast < int64_t >( last_tid_idx) * i;
281281 row_dst[idx] = static_cast <T>((buffer[i] - row_mean) * row_inv_var +
282282 static_cast <U>(bias[idx]));
283283 }
284284 } else {
285285 for (int i = 0 ; i < cols_this_thread; ++i) {
286- int idx = last_tid_idx * write_times + i;
286+ int64_t idx = static_cast < int64_t >( last_tid_idx) * write_times + i;
287287 row_dst[idx] = static_cast <T>((buffer[i] - row_mean) * row_inv_var +
288288 static_cast <U>(bias[idx]));
289289 }
0 commit comments