Skip to content

Commit ac5e0c0

Browse files
committed
CUDA: conv2d convert int64_t to int
1 parent d633cee commit ac5e0c0

File tree

1 file changed

+81
-92
lines changed

1 file changed

+81
-92
lines changed

ggml/src/ggml-cuda/conv2d.cu

Lines changed: 81 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -2,56 +2,44 @@
22
#include "convert.cuh"
33

44
struct conv_params {
5-
const int64_t IW, IH;
6-
const int64_t OW, OH;
7-
const int64_t KW, KH;
8-
const int64_t ST_X, ST_Y;
9-
const int64_t PD_X, PD_Y;
10-
const int64_t DL_X, DL_Y;
11-
const int64_t IC, OC;
12-
const int64_t B;
5+
const int IW, IH;
6+
const int OW, OH;
7+
const int KW, KH;
8+
const int ST_X, ST_Y;
9+
const int PD_X, PD_Y;
10+
const int DL_X, DL_Y;
11+
const int IC, OC;
12+
const int B;
1313
const int64_t TOTAL;
1414
// helpers
15-
const int64_t IC_KH_KW, N_OH_OW;
15+
const int IC_KH_KW, N_OH_OW;
1616
};
1717

18-
__device__ __forceinline__ static int calculate_input_coord(int64_t out_coord,
19-
int64_t kern_coord,
20-
int64_t stride,
21-
int64_t dilation,
22-
int64_t padding) {
18+
// Returns out_coord * stride + kern_coord * dilation - padding.
// No range check is performed here: the result may be negative or exceed the input
// extent, so the caller has to validate it before using it as an index.
__device__ __forceinline__ static int calculate_input_coord(int out_coord,
                                                            int kern_coord,
                                                            int stride,
                                                            int dilation,
                                                            int padding) {
    const int scaled_out = out_coord * stride;      // contribution of the output position
    const int tap_offset = kern_coord * dilation;   // contribution of the kernel tap
    return scaled_out + tap_offset - padding;
}
2525

2626
struct whcn_layout {
27-
__device__ __forceinline__ static int64_t input_index(int64_t n,
28-
int64_t c,
29-
int64_t y,
30-
int64_t x,
31-
const conv_params & P) {
27+
// Flat offset of input element (n, c, y, x). The strides, as read off the expression,
// are IC*IH*IW per n, IH*IW per c and IW per y.
//
// The extents in conv_params are 32-bit ints, so the products are widened to int64_t
// *before* multiplying. Evaluating them in int and converting only the final sum would
// overflow for tensors with more than 2^31-1 elements, even though the return type is
// 64-bit.
__device__ __forceinline__ static int64_t input_index(int n, int c, int y, int x, const conv_params & P) {
    const int64_t plane = (int64_t) P.IW * P.IH;  // elements in one (n, c) slice
    return ((int64_t) n * P.IC + c) * plane + (int64_t) y * P.IW + x;
}
3430

35-
__device__ __forceinline__ static int64_t kernel_index(int64_t c_out,
36-
int64_t c_in,
37-
int64_t ky,
38-
int64_t kx,
39-
const conv_params & P) {
31+
// Flat offset of kernel element (c_out, c_in, ky, kx). The strides, as read off the
// expression, are IC*KH*KW per c_out, KH*KW per c_in and KW per ky.
//
// The extents in conv_params are 32-bit ints, so the products are widened to int64_t
// *before* multiplying; otherwise the arithmetic would be done (and could overflow)
// in int, with only the final sum converted to the 64-bit return type.
__device__ __forceinline__ static int64_t kernel_index(int c_out, int c_in, int ky, int kx, const conv_params & P) {
    const int64_t window = (int64_t) P.KH * P.KW;  // elements in one (c_out, c_in) slice
    return ((int64_t) c_out * P.IC + c_in) * window + (int64_t) ky * P.KW + kx;
}
4234

43-
__device__ __forceinline__ static int64_t output_index(int64_t n,
44-
int64_t c,
45-
int64_t y,
46-
int64_t x,
47-
const conv_params & P) {
35+
// Flat offset of output element (n, c, y, x). The strides, as read off the expression,
// are OC*OH*OW per n, OH*OW per c and OW per y.
//
// The extents in conv_params are 32-bit ints, so the products are widened to int64_t
// *before* multiplying. Evaluating them in int and converting only the final sum would
// overflow for outputs with more than 2^31-1 elements, even though the return type is
// 64-bit.
__device__ __forceinline__ static int64_t output_index(int n, int c, int y, int x, const conv_params & P) {
    const int64_t plane = (int64_t) P.OW * P.OH;  // elements in one (n, c) slice
    return ((int64_t) n * P.OC + c) * plane + (int64_t) y * P.OW + x;
}
5038

5139
__device__ __forceinline__ static void unpack_ickhkw(int64_t idx,
52-
int64_t & ic,
53-
int64_t & kh,
54-
int64_t & kw,
40+
int & ic,
41+
int & kh,
42+
int & kw,
5543
const conv_params & P) {
5644
ic = idx / (P.KW * P.KH);
5745
int64_t r = idx - ic * (P.KW * P.KH);
@@ -60,9 +48,9 @@ struct whcn_layout {
6048
}
6149

6250
__device__ __forceinline__ static void unpack_nohow(int64_t idx,
63-
int64_t & n,
64-
int64_t & oh,
65-
int64_t & ow,
51+
int & n,
52+
int & oh,
53+
int & ow,
6654
const conv_params & P) {
6755
n = idx / (P.OH * P.OW);
6856
int64_t r = idx - n * (P.OH * P.OW);
@@ -111,8 +99,8 @@ template <typename layout> class float_mma {
11199
}
112100
}
113101

114-
__device__ __forceinline__ void store_result(const int64_t OC_BASE,
115-
const int64_t NOHOW_BASE,
102+
__device__ __forceinline__ void store_result(const int OC_BASE,
103+
const int NOHOW_BASE,
116104
float * __restrict__ OUT,
117105
const conv_params & P) const {
118106
const int lane_id = threadIdx.x % WARP_SIZE;
@@ -122,14 +110,13 @@ template <typename layout> class float_mma {
122110
const int m = e / WMMA_N;
123111
const int n = e % WMMA_N;
124112

125-
const int64_t oc = OC_BASE + m;
126-
const int64_t nohow = NOHOW_BASE + n;
113+
const int oc = OC_BASE + m;
114+
const int nohow = NOHOW_BASE + n;
127115

128116
if (oc < P.OC && nohow < P.N_OH_OW) {
129-
int64_t n_, oh, ow;
117+
int n_, oh, ow;
130118
layout::unpack_nohow(nohow, n_, oh, ow, P);
131-
const int64_t out_idx = layout::output_index(n_, oc, oh, ow, P);
132-
OUT[out_idx] = acc[i];
119+
OUT[layout::output_index(n_, oc, oh, ow, P)] = acc[i];
133120
}
134121
}
135122
}
@@ -158,27 +145,30 @@ template <typename layout> class half_mma {
158145
}
159146
}
160147

161-
__device__ __forceinline__ void mma(const half * A_sh, const half * B_sh, const int strideA, const int strideB) {
148+
// Loads a_frag from A_sh (load_ldmatrix) and b_frag from B_sh (load_ldmatrix_trans),
// then calls ggml_cuda_mma::mma on c_frag with both fragments.
// Both pointers are reinterpreted as half2 and the strides are passed divided by two —
// presumably because the loaders expect strides in half2 units; confirm against
// ggml_cuda_mma.
__device__ __forceinline__ void mma(const half * __restrict__ A_sh,
                                    const half * __restrict__ B_sh,
                                    const int strideA,
                                    const int strideB) {
    const half2 * a_src = (const half2 *) A_sh;
    const half2 * b_src = (const half2 *) B_sh;

    ggml_cuda_mma::load_ldmatrix(a_frag, a_src, strideA / 2);
    ggml_cuda_mma::load_ldmatrix_trans(b_frag, b_src, strideB / 2);
    ggml_cuda_mma::mma(c_frag, a_frag, b_frag);
}
166156

167-
__device__ __forceinline__ void store_result(const int64_t OC_BASE,
168-
const int64_t NOHOW_BASE,
169-
float * OUT,
157+
__device__ __forceinline__ void store_result(const int OC_BASE,
158+
const int NOHOW_BASE,
159+
float * __restrict__ OUT,
170160
const conv_params & P) const {
171161
# pragma unroll
172162
for (int l = 0; l < tile_acc::ne; ++l) {
173-
const int64_t e = tile_acc::get_i(l) * WMMA_N + tile_acc::get_j(l);
174-
const int m = e / WMMA_N;
175-
const int n = e % WMMA_N;
163+
const int e = tile_acc::get_i(l) * WMMA_N + tile_acc::get_j(l);
164+
const int m = e / WMMA_N;
165+
const int n = e % WMMA_N;
176166

177-
const int64_t oc = OC_BASE + m;
178-
const int64_t nohow = NOHOW_BASE + n;
167+
const int oc = OC_BASE + m;
168+
const int nohow = NOHOW_BASE + n;
179169

180170
if (oc < P.OC && nohow < (P.N_OH_OW)) {
181-
int64_t n, oh, ow;
171+
int n, oh, ow;
182172
layout::unpack_nohow(nohow, n, oh, ow, P);
183173
OUT[layout::output_index(n, oc, oh, ow, P)] = c_frag.x[l];
184174
}
@@ -228,8 +218,8 @@ template <typename layout> class half_mma {
228218
}
229219
}
230220

231-
__device__ __forceinline__ void store_result(const int64_t OC_BASE,
232-
const int64_t NOHOW_BASE,
221+
__device__ __forceinline__ void store_result(const int OC_BASE,
222+
const int NOHOW_BASE,
233223
float * __restrict__ OUT,
234224
const conv_params & P) const {
235225
const int lane_id = threadIdx.x % WARP_SIZE;
@@ -239,14 +229,13 @@ template <typename layout> class half_mma {
239229
const int m = e / WMMA_N;
240230
const int n = e % WMMA_N;
241231

242-
const int64_t oc = OC_BASE + m;
243-
const int64_t nohow = NOHOW_BASE + n;
232+
const int oc = OC_BASE + m;
233+
const int nohow = NOHOW_BASE + n;
244234

245235
if (oc < P.OC && nohow < P.N_OH_OW) {
246-
int64_t n_, oh, ow;
236+
int n_, oh, ow;
247237
layout::unpack_nohow(nohow, n_, oh, ow, P);
248-
const int64_t out_idx = layout::output_index(n_, oc, oh, ow, P);
249-
OUT[out_idx] = acc[i];
238+
OUT[layout::output_index(n_, oc, oh, ow, P)] = acc[i];
250239
}
251240
}
252241
}
@@ -258,26 +247,26 @@ template <typename T, typename layout, typename mma, int num_warps>
258247
__global__ void conv2d_kernel(const float * IN, const T * IK, float * Out, const conv_params P) {
259248
extern __shared__ unsigned char smem_raw[];
260249

261-
const int64_t NUM_IC_TILES = (P.IC_KH_KW + BS_ICKHKW - 1) / BS_ICKHKW;
262-
const int64_t warpId = threadIdx.y;
250+
const int NUM_IC_TILES = (P.IC_KH_KW + BS_ICKHKW - 1) / BS_ICKHKW;
251+
const int warpId = threadIdx.y;
263252

264-
const int64_t WARPS_PER_NOHOW = max(1, BS_NOHOW / WMMA_N);
265-
const int64_t total_warps_need = (((BS_OC * BS_NOHOW) + (WMMA_M * WMMA_N) - 1) / (WMMA_M * WMMA_N));
266-
const int64_t num_work_per_warps = (total_warps_need + num_warps - 1) / num_warps;
253+
const int WARPS_PER_NOHOW = max(1, BS_NOHOW / WMMA_N);
254+
const int total_warps_need = (((BS_OC * BS_NOHOW) + (WMMA_M * WMMA_N) - 1) / (WMMA_M * WMMA_N));
255+
const int num_work_per_warps = (total_warps_need + num_warps - 1) / num_warps;
267256

268257
mma acc[num_work_per_warps];
269258

270-
const int64_t num_block_nohow = (P.N_OH_OW + BS_NOHOW - 1) / BS_NOHOW;
271-
const int64_t BL_IDX_OC = blockIdx.x / num_block_nohow;
272-
const int64_t BL_IDX_NOHOW = blockIdx.x % num_block_nohow;
259+
const int num_block_nohow = (P.N_OH_OW + BS_NOHOW - 1) / BS_NOHOW;
260+
const int BL_IDX_OC = blockIdx.x / num_block_nohow;
261+
const int BL_IDX_NOHOW = blockIdx.x % num_block_nohow;
273262

274-
const int64_t BLOCK_OC_BASE = BL_IDX_OC * BS_OC;
275-
const int64_t BLOCK_NOHOW_BASE = BL_IDX_NOHOW * BS_NOHOW;
263+
const int BLOCK_OC_BASE = BL_IDX_OC * BS_OC;
264+
const int BLOCK_NOHOW_BASE = BL_IDX_NOHOW * BS_NOHOW;
276265

277266
unsigned char * ptr = smem_raw;
278267

279-
const int64_t A_total = BS_OC * BS_ICKHKW;
280-
const int64_t B_total = BS_ICKHKW * BS_NOHOW;
268+
const int A_total = BS_OC * BS_ICKHKW;
269+
const int B_total = BS_ICKHKW * BS_NOHOW;
281270

282271
size_t offsetA = (size_t) A_total * sizeof(T);
283272
T * A_sh = reinterpret_cast<T *>(ptr);
@@ -287,33 +276,33 @@ __global__ void conv2d_kernel(const float * IN, const T * IK, float * Out, const
287276
T * B_sh = reinterpret_cast<T *>(ptr);
288277
ptr += offsetB;
289278

290-
int64_t ic, kh, kw;
291-
int64_t n, oh, ow;
292-
for (int64_t t = 0; t < NUM_IC_TILES; ++t) {
279+
int ic, kh, kw;
280+
int n, oh, ow;
281+
for (int t = 0; t < NUM_IC_TILES; ++t) {
293282
#pragma unroll
294-
for (int64_t tid = (threadIdx.y * blockDim.x + threadIdx.x); tid < A_total; tid += (blockDim.x * blockDim.y)) {
283+
for (int tid = (threadIdx.y * blockDim.x + threadIdx.x); tid < A_total; tid += (blockDim.x * blockDim.y)) {
295284
const int row = tid / BS_ICKHKW;
296285
const int col = tid % BS_ICKHKW;
297286

298-
int64_t shared_oc = BLOCK_OC_BASE + row;
299-
int64_t shared_ickhkw = t * BS_ICKHKW + col;
287+
int shared_oc = BLOCK_OC_BASE + row;
288+
int shared_ickhkw = t * BS_ICKHKW + col;
300289

301290
T val = ggml_cuda_cast<T>(0);
302291
if (shared_oc < P.OC && shared_ickhkw < P.IC_KH_KW) {
303292
layout::unpack_ickhkw(shared_ickhkw, ic, kh, kw, P);
304293

305-
const int64_t kidx = layout::kernel_index(shared_oc, ic, kh, kw, P);
306-
val = IK[kidx];
294+
const int kidx = layout::kernel_index(shared_oc, ic, kh, kw, P);
295+
val = IK[kidx];
307296
}
308297
A_sh[row * BS_ICKHKW + col] = val;
309298
}
310299
#pragma unroll
311-
for (int64_t tid = (threadIdx.y * blockDim.x + threadIdx.x); tid < B_total; tid += (blockDim.x * blockDim.y)) {
300+
for (int tid = (threadIdx.y * blockDim.x + threadIdx.x); tid < B_total; tid += (blockDim.x * blockDim.y)) {
312301
const int brow = tid / BS_NOHOW;
313302
const int bcol = tid % BS_NOHOW;
314303

315-
int64_t IC_KH_KW_IDX = t * BS_ICKHKW + brow;
316-
int64_t N_OH_OW_IDX = BLOCK_NOHOW_BASE + bcol;
304+
int IC_KH_KW_IDX = t * BS_ICKHKW + brow;
305+
int N_OH_OW_IDX = BLOCK_NOHOW_BASE + bcol;
317306

318307
T val = ggml_cuda_cast<T>(0);
319308
if (N_OH_OW_IDX < P.N_OH_OW && IC_KH_KW_IDX < P.IC_KH_KW) {
@@ -333,10 +322,10 @@ __global__ void conv2d_kernel(const float * IN, const T * IK, float * Out, const
333322

334323
#pragma unroll
335324
for (int warp = warpId, i = 0; warp < total_warps_need; warp += num_warps, i++) {
336-
const int64_t WARP_OC = warp / WARPS_PER_NOHOW;
337-
const int64_t WARP_NOHOW = warp % WARPS_PER_NOHOW;
338-
const T * A_warp_base = A_sh + WARP_OC * WMMA_M * BS_ICKHKW;
339-
const T * B_warp_base = B_sh + WARP_NOHOW * WMMA_N;
325+
const int WARP_OC = warp / WARPS_PER_NOHOW;
326+
const int WARP_NOHOW = warp % WARPS_PER_NOHOW;
327+
const T * A_warp_base = A_sh + WARP_OC * WMMA_M * BS_ICKHKW;
328+
const T * B_warp_base = B_sh + WARP_NOHOW * WMMA_N;
340329
#pragma unroll
341330
for (int k_tile = 0; k_tile < BS_ICKHKW; k_tile += WMMA_K) {
342331
const T * A_k_ptr = A_warp_base + k_tile;
@@ -349,10 +338,10 @@ __global__ void conv2d_kernel(const float * IN, const T * IK, float * Out, const
349338

350339
#pragma unroll
351340
for (int warp = warpId, i = 0; warp < total_warps_need; warp += num_warps, i++) {
352-
const int64_t WARP_OC = warp / WARPS_PER_NOHOW;
353-
const int64_t WARP_NOHOW = warp % WARPS_PER_NOHOW;
354-
const int64_t OC_BASE = BLOCK_OC_BASE + WARP_OC * WMMA_M;
355-
const int64_t NOHOW_BASE = BLOCK_NOHOW_BASE + WARP_NOHOW * WMMA_N;
341+
const int WARP_OC = warp / WARPS_PER_NOHOW;
342+
const int WARP_NOHOW = warp % WARPS_PER_NOHOW;
343+
const int OC_BASE = BLOCK_OC_BASE + WARP_OC * WMMA_M;
344+
const int NOHOW_BASE = BLOCK_NOHOW_BASE + WARP_NOHOW * WMMA_N;
356345
acc[i].store_result(OC_BASE, NOHOW_BASE, Out, P);
357346
}
358347
}
@@ -454,8 +443,8 @@ void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
454443
const int B = input->ne[3]; // n_batches
455444

456445
const int64_t TOTAL = B * OC * OH * OW;
457-
const int64_t IC_KH_KW = IC * KH * KW;
458-
const int64_t N_OH_OW = B * OH * OW;
446+
const int IC_KH_KW = IC * KH * KW;
447+
const int N_OH_OW = B * OH * OW;
459448
conv_params params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X,
460449
PD_Y, DL_X, DL_Y, IC, OC, B, TOTAL, IC_KH_KW, N_OH_OW };
461450

0 commit comments

Comments
 (0)