@@ -4489,6 +4489,13 @@ static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
     *dsti = __float2half(*xi);
 }
 
+static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
+    const half * xi = (const half *) cxi;
+    half * dsti = (half *) cdsti;
+
+    *dsti = *xi;
+}
+
 template <cpy_kernel_t cpy_1>
 static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
                                    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
@@ -4742,6 +4749,25 @@ static __global__ void clamp_f32(const float * x, float * dst, const float min,
     dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
 }
 
+static __global__ void im2col_f32_f16(
+        const float * x, half * dst,
+        int ofs0, int ofs1, int IW, int IH, int CHW,
+        int s0, int s1, int p0, int p1, int d0, int d1) {
+    const int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0;
+    const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
+
+    const int offset_dst =
+        (threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW +
+        (blockIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z);
+
+    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+        dst[offset_dst] = __float2half(0.0f);
+    } else {
+        const int offset_src = threadIdx.x * ofs0 + blockIdx.x * ofs1;
+        dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
+    }
+}
+
 template <int qk, int qr, dequantize_kernel_t dq>
 static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) {
     const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
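Note on the kernel above: it is launched (see im2col_f32_f16_cuda further down) with grid = (IC, OH, OW) and block = (N, KH, KW), so blockIdx.x/y/z select the input channel and the output row/column, while threadIdx.x/y/z select the batch image and the kernel row/column. Each thread writes one element of an output laid out as [N, OH, OW, IC*KH*KW]. As a sanity check, here is a hypothetical host-side reference with the same index math, assuming a contiguous float input of shape [N, IC, IH, IW] (the kernel instead takes element strides ofs0/ofs1 derived from the tensor's nb); im2col_reference is illustrative and not part of ggml:

    // Hypothetical CPU reference mirroring im2col_f32_f16's indexing (needs <cuda_fp16.h> for half).
    static void im2col_reference(const float * x, half * dst,
                                 int N, int IC, int IH, int IW, int OH, int OW, int KH, int KW,
                                 int s0, int s1, int p0, int p1, int d0, int d1) {
        const int CHW = IC * KH * KW;
        for (int n  = 0; n  < N;  n++)    // threadIdx.x in the kernel
        for (int ic = 0; ic < IC; ic++)   // blockIdx.x
        for (int oh = 0; oh < OH; oh++)   // blockIdx.y
        for (int ow = 0; ow < OW; ow++)   // blockIdx.z
        for (int kh = 0; kh < KH; kh++)   // threadIdx.y
        for (int kw = 0; kw < KW; kw++) { // threadIdx.z
            const int iiw = ow * s0 + kw * d0 - p0;
            const int iih = oh * s1 + kh * d1 - p1;
            const int offset_dst = (n * OH * OW + oh * OW + ow) * CHW + (ic * KH * KW + kh * KW + kw);
            if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
                dst[offset_dst] = __float2half(0.0f);                            // zero padding
            } else {
                dst[offset_dst] = __float2half(x[((n * IC + ic) * IH + iih) * IW + iiw]);
            }
        }
    }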
@@ -5642,6 +5668,16 @@ static void ggml_cpy_f32_f16_cuda(
         (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
 }
 
+static void ggml_cpy_f16_f16_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
+    const int ne10, const int ne11, const int nb10, const int nb11, const int nb12, cudaStream_t stream) {
+
+    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, nb00, nb01, nb02, ne10, ne11, nb10, nb11, nb12);
+}
+
 static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
     scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
@@ -5725,6 +5761,15 @@ static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, c
     soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
 }
 
+static void im2col_f32_f16_cuda(const float * x, half * dst,
+    int OH, int IW, int IH, int OW, int IC,
+    int KH, int KW, int N, int ofs0, int ofs1,
+    int s0, int s1, int p0, int p1, int d0, int d1, cudaStream_t stream) {
+    dim3 block_nums(IC, OH, OW);
+    dim3 block_dims(N, KH, KW);
+    im2col_f32_f16<<<block_nums, block_dims, 0, stream>>>(x, dst, ofs0, ofs1, IW, IH, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
+}
+
 // buffer pool for cuda
 #define MAX_CUDA_BUFFERS 256
 
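The launcher maps one thread to each element of the [N, OH, OW, IC*KH*KW] output: grid = (IC, OH, OW), block = (N, KH, KW). Note this implicitly requires N*KH*KW to stay within CUDA's 1024-threads-per-block limit, which holds for the small batch and kernel sizes this path targets. A hypothetical call for a 3x3, stride-1, pad-1, dilation-1 filter over a single 3-channel 64x64 image (device buffers x and dst assumed already allocated; all values are illustrative only):

    const int N = 1, IC = 3, IH = 64, IW = 64, KH = 3, KW = 3;
    const int s0 = 1, s1 = 1, p0 = 1, p1 = 1, d0 = 1, d1 = 1;
    // usual convolution output-size formula, which is how ggml sizes dst when the graph is built
    const int OH = (IH + 2*p1 - d1*(KH - 1) - 1)/s1 + 1; // = 64
    const int OW = (IW + 2*p0 - d0*(KW - 1) - 1)/s0 + 1; // = 64
    // element strides for a contiguous [N, IC, IH, IW] float input
    const int ofs0 = IC*IH*IW; // between consecutive batch images
    const int ofs1 = IH*IW;    // between consecutive channels
    im2col_f32_f16_cuda(x, dst, OH, IW, IH, OW, IC, KH, KW, N,
                        ofs0, ofs1, s0, s1, p0, p1, d0, d1, stream);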
@@ -6522,8 +6567,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
             src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &src1_as);
             to_fp16_cuda(src1_ddf_i, src1_as_f16, ne, stream);
         }
-        const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16;
-
+        const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16;
         size_t dst_as = 0;
         half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as);
 
@@ -6698,6 +6742,45 @@ inline void ggml_cuda_op_alibi(
     (void) src1_dd;
 }
 
+inline void ggml_cuda_op_im2col(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F16);
+
+    const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
+    const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
+    const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
+    const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
+    const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
+    const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
+
+    const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
+
+    const int64_t N  = src1->ne[is_2D ? 3 : 2];
+    const int64_t IC = src1->ne[is_2D ? 2 : 1];
+    const int64_t IH = is_2D ? src1->ne[1] : 1;
+    const int64_t IW =         src1->ne[0];
+
+    const int64_t KH = is_2D ? src0->ne[1] : 1;
+    const int64_t KW =         src0->ne[0];
+
+    const int64_t OH = is_2D ? dst->ne[2] : 1;
+    const int64_t OW =         dst->ne[1];
+
+    const size_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
+    const size_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
+
+    im2col_f32_f16_cuda(src1_dd, (half*) dst_dd,
+        OH, IW, IH, OW, IC, KH, KW, N,
+        ofs0, ofs1, s0, s1, p0, p1, d0, d1, main_stream);
+
+    (void) src0;
+    (void) src0_dd;
+}
+
 inline void ggml_cuda_op_diag_mask_inf(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
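All parameters come straight off the graph: s0/s1, p0/p1, d0/d1 and the is_2D flag are the seven int32 values stored in dst->op_params when the op is created on the CPU side, and the shapes are read from the tensors themselves (ggml orders ne innermost-first). Roughly, for the 2D case with a contiguous F32 src1 the stride bookkeeping works out as follows (illustrative annotation, not additional code in the PR):

    // src1: ne = { IW, IH, IC, N }, nb = { 4, 4*IW, 4*IW*IH, 4*IW*IH*IC }  (byte strides)
    // ofs0 = nb[3] / 4 = IW*IH*IC   // float elements between consecutive batch images
    // ofs1 = nb[2] / 4 = IW*IH      // float elements between consecutive channels
    // dst:  ne = { IC*KH*KW, OW, OH, N }, hence OW = dst->ne[1] and OH = dst->ne[2] as read above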
@@ -7610,6 +7693,9 @@ static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, gg
     } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
         ggml_cpy_f32_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
                               ne10, ne11, nb10, nb11, nb12, main_stream);
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
+        ggml_cpy_f16_f16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, nb00, nb01, nb02,
+                              ne10, ne11, nb10, nb11, nb12, main_stream);
     } else {
         fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
                 ggml_type_name(src0->type), ggml_type_name(src1->type));
@@ -7641,6 +7727,10 @@ static void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1,
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_alibi);
 }
 
+void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col);
+}
+
 static void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     (void) src0;
     (void) src1;
@@ -7934,6 +8024,15 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
         return false;
     }
 
+    if (tensor->op == GGML_OP_MUL_MAT) {
+        if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3]) {
+#ifndef NDEBUG
+            fprintf(stderr, "%s: cannot compute %s: src0->ne[3] = %d, src1->ne[3] = %d - fallback to CPU\n", __func__, tensor->name, tensor->src[0]->ne[3], tensor->src[1]->ne[3]);
+#endif
+            return false;
+        }
+    }
+
     switch (tensor->op) {
         case GGML_OP_REPEAT:
             func = ggml_cuda_repeat;
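ne[3] is the outermost (batch) dimension in ggml's ordering, and this mat-mul path requires both operands to share it; when they do not, the guard returns false so the op falls back to the CPU implementation instead of being computed incorrectly on the GPU. A hypothetical shape pair that would now trigger the fallback (values are illustrative only):

    // src0->ne = { K, M, 1, 1 }   // e.g. a single shared weight matrix
    // src1->ne = { K, C, 1, 8 }   // e.g. activations carrying a batch of 8 in ne[3]
    // src0->ne[3] (1) != src1->ne[3] (8)  ->  ggml_cuda_compute_forward() returns false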
@@ -8012,6 +8111,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
         case GGML_OP_ALIBI:
             func = ggml_cuda_alibi;
             break;
+        case GGML_OP_IM2COL:
+            func = ggml_cuda_im2col;
+            break;
         default:
             return false;
     }