@@ -42,7 +42,8 @@ __device__ void BlockLoad(const InT* input,
4242 const uint32_t local_off_M = threadIdx .y + i * 16 ;
4343 const uint32_t off_m = blockIdx .x * 128 + local_off_M;
4444 const uint32_t off_k = blockIdx .y * 128 + threadIdx .x * VecSize;
45- const size_t offset = off_m * K + off_k;
45+ const size_t offset =
46+ static_cast <size_t >(off_m) * static_cast <size_t >(K) + off_k;
4647
4748 float scale;
4849 if constexpr (need_dequant) {
@@ -53,15 +54,17 @@ __device__ void BlockLoad(const InT* input,
5354
5455#pragma unroll
5556 for (uint32_t j = 0 ; j < 4 ; j += VecSize) {
56- const size_t idx = offset + j * 32 ;
57- using LoadT = VecType<InT, VecSize>;
58- LoadT data = *reinterpret_cast <const LoadT*>(input + idx);
57+ if (off_k + j * 32 < K) {
58+ const size_t idx = offset + j * 32 ;
59+ using LoadT = VecType<InT, VecSize>;
60+ LoadT data = *reinterpret_cast <const LoadT*>(input + idx);
5961#pragma unroll
60- for (uint32_t k = 0 ; k < VecSize; k++) {
61- if constexpr (need_dequant) {
62- x[i][j + k] = __float2bfloat16 (static_cast <float >(data[k]) * scale);
63- } else {
64- x[i][j + k] = (*reinterpret_cast <__nv_bfloat16*>(&data[k]));
62+ for (uint32_t k = 0 ; k < VecSize; k++) {
63+ if constexpr (need_dequant) {
64+ x[i][j + k] = __float2bfloat16 (static_cast <float >(data[k]) * scale);
65+ } else {
66+ x[i][j + k] = (*reinterpret_cast <__nv_bfloat16*>(&data[k]));
67+ }
6568 }
6669 }
6770 }
0 commit comments