22#include " convert.cuh"
33
44struct conv_params {
5- const int64_t IW, IH;
6- const int64_t OW, OH;
7- const int64_t KW, KH;
8- const int64_t ST_X, ST_Y;
9- const int64_t PD_X, PD_Y;
10- const int64_t DL_X, DL_Y;
11- const int64_t IC, OC;
12- const int64_t B;
5+ const int IW, IH;
6+ const int OW, OH;
7+ const int KW, KH;
8+ const int ST_X, ST_Y;
9+ const int PD_X, PD_Y;
10+ const int DL_X, DL_Y;
11+ const int IC, OC;
12+ const int B;
1313 const int64_t TOTAL;
1414 // helpers
15- const int64_t IC_KH_KW, N_OH_OW;
15+ const int IC_KH_KW, N_OH_OW;
1616};
1717
18- __device__ __forceinline__ static int calculate_input_coord (int64_t out_coord,
19- int64_t kern_coord,
20- int64_t stride,
21- int64_t dilation,
22- int64_t padding) {
18+ __device__ __forceinline__ static int calculate_input_coord (int out_coord,
19+ int kern_coord,
20+ int stride,
21+ int dilation,
22+ int padding) {
2323 return out_coord * stride + kern_coord * dilation - padding;
2424}
2525
2626struct whcn_layout {
27- __device__ __forceinline__ static int64_t input_index (int64_t n,
28- int64_t c,
29- int64_t y,
30- int64_t x,
31- const conv_params & P) {
27+ __device__ __forceinline__ static int64_t input_index (int n, int c, int y, int x, const conv_params & P) {
3228 return n * (P.IC * P.IW * P.IH ) + c * P.IW * P.IH + y * P.IW + x;
3329 }
3430
35- __device__ __forceinline__ static int64_t kernel_index (int64_t c_out,
36- int64_t c_in,
37- int64_t ky,
38- int64_t kx,
39- const conv_params & P) {
31+ __device__ __forceinline__ static int64_t kernel_index (int c_out, int c_in, int ky, int kx, const conv_params & P) {
4032 return c_out * (P.IC * P.KH * P.KW ) + c_in * (P.KH * P.KW ) + ky * P.KW + kx;
4133 }
4234
43- __device__ __forceinline__ static int64_t output_index (int64_t n,
44- int64_t c,
45- int64_t y,
46- int64_t x,
47- const conv_params & P) {
35+ __device__ __forceinline__ static int64_t output_index (int n, int c, int y, int x, const conv_params & P) {
4836 return n * (P.OC * P.OW * P.OH ) + c * P.OW * P.OH + y * P.OW + x;
4937 }
5038
5139 __device__ __forceinline__ static void unpack_ickhkw (int64_t idx,
52- int64_t & ic,
53- int64_t & kh,
54- int64_t & kw,
40+ int & ic,
41+ int & kh,
42+ int & kw,
5543 const conv_params & P) {
5644 ic = idx / (P.KW * P.KH );
5745 int64_t r = idx - ic * (P.KW * P.KH );
@@ -60,9 +48,9 @@ struct whcn_layout {
6048 }
6149
6250 __device__ __forceinline__ static void unpack_nohow (int64_t idx,
63- int64_t & n,
64- int64_t & oh,
65- int64_t & ow,
51+ int & n,
52+ int & oh,
53+ int & ow,
6654 const conv_params & P) {
6755 n = idx / (P.OH * P.OW );
6856 int64_t r = idx - n * (P.OH * P.OW );
@@ -111,8 +99,8 @@ template <typename layout> class float_mma {
11199 }
112100 }
113101
114- __device__ __forceinline__ void store_result (const int64_t OC_BASE,
115- const int64_t NOHOW_BASE,
102+ __device__ __forceinline__ void store_result (const int OC_BASE,
103+ const int NOHOW_BASE,
116104 float * __restrict__ OUT,
117105 const conv_params & P) const {
118106 const int lane_id = threadIdx .x % WARP_SIZE;
@@ -122,14 +110,13 @@ template <typename layout> class float_mma {
122110 const int m = e / WMMA_N;
123111 const int n = e % WMMA_N;
124112
125- const int64_t oc = OC_BASE + m;
126- const int64_t nohow = NOHOW_BASE + n;
113+ const int oc = OC_BASE + m;
114+ const int nohow = NOHOW_BASE + n;
127115
128116 if (oc < P.OC && nohow < P.N_OH_OW ) {
129- int64_t n_, oh, ow;
117+ int n_, oh, ow;
130118 layout::unpack_nohow (nohow, n_, oh, ow, P);
131- const int64_t out_idx = layout::output_index (n_, oc, oh, ow, P);
132- OUT[out_idx] = acc[i];
119+ OUT[layout::output_index (n_, oc, oh, ow, P)] = acc[i];
133120 }
134121 }
135122 }
@@ -158,27 +145,30 @@ template <typename layout> class half_mma {
158145 }
159146 }
160147
161- __device__ __forceinline__ void mma (const half * A_sh, const half * B_sh, const int strideA, const int strideB) {
148+ __device__ __forceinline__ void mma (const half * __restrict__ A_sh,
149+ const half * __restrict__ B_sh,
150+ const int strideA,
151+ const int strideB) {
162152 ggml_cuda_mma::load_ldmatrix (a_frag, (const half2 *) A_sh, strideA / 2 );
163153 ggml_cuda_mma::load_ldmatrix_trans (b_frag, (const half2 *) B_sh, strideB / 2 );
164154 ggml_cuda_mma::mma (c_frag, a_frag, b_frag);
165155 }
166156
167- __device__ __forceinline__ void store_result (const int64_t OC_BASE,
168- const int64_t NOHOW_BASE,
169- float * OUT,
157+ __device__ __forceinline__ void store_result (const int OC_BASE,
158+ const int NOHOW_BASE,
159+ float * __restrict__ OUT,
170160 const conv_params & P) const {
171161# pragma unroll
172162 for (int l = 0 ; l < tile_acc::ne; ++l) {
173- const int64_t e = tile_acc::get_i (l) * WMMA_N + tile_acc::get_j (l);
174- const int m = e / WMMA_N;
175- const int n = e % WMMA_N;
163+ const int e = tile_acc::get_i (l) * WMMA_N + tile_acc::get_j (l);
164+ const int m = e / WMMA_N;
165+ const int n = e % WMMA_N;
176166
177- const int64_t oc = OC_BASE + m;
178- const int64_t nohow = NOHOW_BASE + n;
167+ const int oc = OC_BASE + m;
168+ const int nohow = NOHOW_BASE + n;
179169
180170 if (oc < P.OC && nohow < (P.N_OH_OW )) {
181- int64_t n, oh, ow;
171+ int n, oh, ow;
182172 layout::unpack_nohow (nohow, n, oh, ow, P);
183173 OUT[layout::output_index (n, oc, oh, ow, P)] = c_frag.x [l];
184174 }
@@ -228,8 +218,8 @@ template <typename layout> class half_mma {
228218 }
229219 }
230220
231- __device__ __forceinline__ void store_result (const int64_t OC_BASE,
232- const int64_t NOHOW_BASE,
221+ __device__ __forceinline__ void store_result (const int OC_BASE,
222+ const int NOHOW_BASE,
233223 float * __restrict__ OUT,
234224 const conv_params & P) const {
235225 const int lane_id = threadIdx .x % WARP_SIZE;
@@ -239,14 +229,13 @@ template <typename layout> class half_mma {
239229 const int m = e / WMMA_N;
240230 const int n = e % WMMA_N;
241231
242- const int64_t oc = OC_BASE + m;
243- const int64_t nohow = NOHOW_BASE + n;
232+ const int oc = OC_BASE + m;
233+ const int nohow = NOHOW_BASE + n;
244234
245235 if (oc < P.OC && nohow < P.N_OH_OW ) {
246- int64_t n_, oh, ow;
236+ int n_, oh, ow;
247237 layout::unpack_nohow (nohow, n_, oh, ow, P);
248- const int64_t out_idx = layout::output_index (n_, oc, oh, ow, P);
249- OUT[out_idx] = acc[i];
238+ OUT[layout::output_index (n_, oc, oh, ow, P)] = acc[i];
250239 }
251240 }
252241 }
@@ -258,26 +247,26 @@ template <typename T, typename layout, typename mma, int num_warps>
258247__global__ void conv2d_kernel (const float * IN, const T * IK, float * Out, const conv_params P) {
259248 extern __shared__ unsigned char smem_raw[];
260249
261- const int64_t NUM_IC_TILES = (P.IC_KH_KW + BS_ICKHKW - 1 ) / BS_ICKHKW;
262- const int64_t warpId = threadIdx .y ;
250+ const int NUM_IC_TILES = (P.IC_KH_KW + BS_ICKHKW - 1 ) / BS_ICKHKW;
251+ const int warpId = threadIdx .y ;
263252
264- const int64_t WARPS_PER_NOHOW = max (1 , BS_NOHOW / WMMA_N);
265- const int64_t total_warps_need = (((BS_OC * BS_NOHOW) + (WMMA_M * WMMA_N) - 1 ) / (WMMA_M * WMMA_N));
266- const int64_t num_work_per_warps = (total_warps_need + num_warps - 1 ) / num_warps;
253+ const int WARPS_PER_NOHOW = max (1 , BS_NOHOW / WMMA_N);
254+ const int total_warps_need = (((BS_OC * BS_NOHOW) + (WMMA_M * WMMA_N) - 1 ) / (WMMA_M * WMMA_N));
255+ const int num_work_per_warps = (total_warps_need + num_warps - 1 ) / num_warps;
267256
268257 mma acc[num_work_per_warps];
269258
270- const int64_t num_block_nohow = (P.N_OH_OW + BS_NOHOW - 1 ) / BS_NOHOW;
271- const int64_t BL_IDX_OC = blockIdx .x / num_block_nohow;
272- const int64_t BL_IDX_NOHOW = blockIdx .x % num_block_nohow;
259+ const int num_block_nohow = (P.N_OH_OW + BS_NOHOW - 1 ) / BS_NOHOW;
260+ const int BL_IDX_OC = blockIdx .x / num_block_nohow;
261+ const int BL_IDX_NOHOW = blockIdx .x % num_block_nohow;
273262
274- const int64_t BLOCK_OC_BASE = BL_IDX_OC * BS_OC;
275- const int64_t BLOCK_NOHOW_BASE = BL_IDX_NOHOW * BS_NOHOW;
263+ const int BLOCK_OC_BASE = BL_IDX_OC * BS_OC;
264+ const int BLOCK_NOHOW_BASE = BL_IDX_NOHOW * BS_NOHOW;
276265
277266 unsigned char * ptr = smem_raw;
278267
279- const int64_t A_total = BS_OC * BS_ICKHKW;
280- const int64_t B_total = BS_ICKHKW * BS_NOHOW;
268+ const int A_total = BS_OC * BS_ICKHKW;
269+ const int B_total = BS_ICKHKW * BS_NOHOW;
281270
282271 size_t offsetA = (size_t ) A_total * sizeof (T);
283272 T * A_sh = reinterpret_cast <T *>(ptr);
@@ -287,33 +276,33 @@ __global__ void conv2d_kernel(const float * IN, const T * IK, float * Out, const
287276 T * B_sh = reinterpret_cast <T *>(ptr);
288277 ptr += offsetB;
289278
290- int64_t ic, kh, kw;
291- int64_t n, oh, ow;
292- for (int64_t t = 0 ; t < NUM_IC_TILES; ++t) {
279+ int ic, kh, kw;
280+ int n, oh, ow;
281+ for (int t = 0 ; t < NUM_IC_TILES; ++t) {
293282#pragma unroll
294- for (int64_t tid = (threadIdx .y * blockDim .x + threadIdx .x ); tid < A_total; tid += (blockDim .x * blockDim .y )) {
283+ for (int tid = (threadIdx .y * blockDim .x + threadIdx .x ); tid < A_total; tid += (blockDim .x * blockDim .y )) {
295284 const int row = tid / BS_ICKHKW;
296285 const int col = tid % BS_ICKHKW;
297286
298- int64_t shared_oc = BLOCK_OC_BASE + row;
299- int64_t shared_ickhkw = t * BS_ICKHKW + col;
287+ int shared_oc = BLOCK_OC_BASE + row;
288+ int shared_ickhkw = t * BS_ICKHKW + col;
300289
301290 T val = ggml_cuda_cast<T>(0 );
302291 if (shared_oc < P.OC && shared_ickhkw < P.IC_KH_KW ) {
303292 layout::unpack_ickhkw (shared_ickhkw, ic, kh, kw, P);
304293
305- const int64_t kidx = layout::kernel_index (shared_oc, ic, kh, kw, P);
306- val = IK[kidx];
294+ const int kidx = layout::kernel_index (shared_oc, ic, kh, kw, P);
295+ val = IK[kidx];
307296 }
308297 A_sh[row * BS_ICKHKW + col] = val;
309298 }
310299#pragma unroll
311- for (int64_t tid = (threadIdx .y * blockDim .x + threadIdx .x ); tid < B_total; tid += (blockDim .x * blockDim .y )) {
300+ for (int tid = (threadIdx .y * blockDim .x + threadIdx .x ); tid < B_total; tid += (blockDim .x * blockDim .y )) {
312301 const int brow = tid / BS_NOHOW;
313302 const int bcol = tid % BS_NOHOW;
314303
315- int64_t IC_KH_KW_IDX = t * BS_ICKHKW + brow;
316- int64_t N_OH_OW_IDX = BLOCK_NOHOW_BASE + bcol;
304+ int IC_KH_KW_IDX = t * BS_ICKHKW + brow;
305+ int N_OH_OW_IDX = BLOCK_NOHOW_BASE + bcol;
317306
318307 T val = ggml_cuda_cast<T>(0 );
319308 if (N_OH_OW_IDX < P.N_OH_OW && IC_KH_KW_IDX < P.IC_KH_KW ) {
@@ -333,10 +322,10 @@ __global__ void conv2d_kernel(const float * IN, const T * IK, float * Out, const
333322
334323#pragma unroll
335324 for (int warp = warpId, i = 0 ; warp < total_warps_need; warp += num_warps, i++) {
336- const int64_t WARP_OC = warp / WARPS_PER_NOHOW;
337- const int64_t WARP_NOHOW = warp % WARPS_PER_NOHOW;
338- const T * A_warp_base = A_sh + WARP_OC * WMMA_M * BS_ICKHKW;
339- const T * B_warp_base = B_sh + WARP_NOHOW * WMMA_N;
325+ const int WARP_OC = warp / WARPS_PER_NOHOW;
326+ const int WARP_NOHOW = warp % WARPS_PER_NOHOW;
327+ const T * A_warp_base = A_sh + WARP_OC * WMMA_M * BS_ICKHKW;
328+ const T * B_warp_base = B_sh + WARP_NOHOW * WMMA_N;
340329#pragma unroll
341330 for (int k_tile = 0 ; k_tile < BS_ICKHKW; k_tile += WMMA_K) {
342331 const T * A_k_ptr = A_warp_base + k_tile;
@@ -349,10 +338,10 @@ __global__ void conv2d_kernel(const float * IN, const T * IK, float * Out, const
349338
350339#pragma unroll
351340 for (int warp = warpId, i = 0 ; warp < total_warps_need; warp += num_warps, i++) {
352- const int64_t WARP_OC = warp / WARPS_PER_NOHOW;
353- const int64_t WARP_NOHOW = warp % WARPS_PER_NOHOW;
354- const int64_t OC_BASE = BLOCK_OC_BASE + WARP_OC * WMMA_M;
355- const int64_t NOHOW_BASE = BLOCK_NOHOW_BASE + WARP_NOHOW * WMMA_N;
341+ const int WARP_OC = warp / WARPS_PER_NOHOW;
342+ const int WARP_NOHOW = warp % WARPS_PER_NOHOW;
343+ const int OC_BASE = BLOCK_OC_BASE + WARP_OC * WMMA_M;
344+ const int NOHOW_BASE = BLOCK_NOHOW_BASE + WARP_NOHOW * WMMA_N;
356345 acc[i].store_result (OC_BASE, NOHOW_BASE, Out, P);
357346 }
358347}
@@ -454,8 +443,8 @@ void ggml_cuda_op_conv2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
454443 const int B = input->ne [3 ]; // n_batches
455444
456445 const int64_t TOTAL = B * OC * OH * OW;
457- const int64_t IC_KH_KW = IC * KH * KW;
458- const int64_t N_OH_OW = B * OH * OW;
446+ const int IC_KH_KW = IC * KH * KW;
447+ const int N_OH_OW = B * OH * OW;
459448 conv_params params = { IW, IH, OW, OH, KW, KH, ST_X, ST_Y, PD_X,
460449 PD_Y, DL_X, DL_Y, IC, OC, B, TOTAL, IC_KH_KW, N_OH_OW };
461450
0 commit comments