Skip to content

Commit 99f3f15

Browse files
authored
cuda : fix im2col_f32_f16 (#658)
1 parent 6b846cb commit 99f3f15

File tree

1 file changed: 4 additions and 4 deletions

1 file changed

+4
-4
lines changed

src/ggml-cuda.cu

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -5259,17 +5259,17 @@ static __global__ void im2col_f32_f16(
52595259
const int ky = (i - kd) / OW;
52605260
const int ix = i % OW;
52615261

5262-
const int iiw = ix * s0 + kx * d0 - p0;
5263-
const int iih = blockIdx.y * s1 + ky * d1 - p1;
5262+
const int64_t iiw = ix * s0 + kx * d0 - p0;
5263+
const int64_t iih = blockIdx.y * s1 + ky * d1 - p1;
52645264

5265-
const int offset_dst =
5265+
const int64_t offset_dst =
52665266
(blockIdx.y * OW + ix) * CHW +
52675267
(blockIdx.z * (KW * KH) + ky * KW + kx);
52685268

52695269
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
52705270
dst[offset_dst] = __float2half(0.0f);
52715271
} else {
5272-
const int offset_src = blockIdx.z * offset_delta;
5272+
const int64_t offset_src = blockIdx.z * offset_delta;
52735273
dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
52745274
}
52755275
}

0 commit comments

Comments
 (0)