135135 GGML_METAL_KERNEL_TYPE_ROPE_F16,
136136 GGML_METAL_KERNEL_TYPE_ALIBI_F32,
137137 GGML_METAL_KERNEL_TYPE_IM2COL_F16,
138+ GGML_METAL_KERNEL_TYPE_IM2COL_F32,
138139 GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
139140 GGML_METAL_KERNEL_TYPE_PAD_F32,
140141 GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
@@ -506,6 +507,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
506507 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true );
507508 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true );
508509 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true );
510+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true );
509511 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true );
510512 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true );
511513 GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true );
@@ -630,6 +632,10 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
630632 case GGML_OP_ALIBI:
631633 case GGML_OP_ROPE:
632634 case GGML_OP_IM2COL:
635+ return true ;
636+ case GGML_OP_POOL_1D:
637+ case GGML_OP_POOL_2D:
638+ return false ;
633639 case GGML_OP_UPSCALE:
634640 case GGML_OP_PAD:
635641 case GGML_OP_ARGSORT:
@@ -2015,14 +2021,15 @@ static bool ggml_metal_graph_compute(
20152021 {
20162022 GGML_ASSERT (src0->type == GGML_TYPE_F16);
20172023 GGML_ASSERT (src1->type == GGML_TYPE_F32);
2018- GGML_ASSERT ( dst->type == GGML_TYPE_F16);
2024+ GGML_ASSERT ( dst->type == GGML_TYPE_F16 || dst-> type == GGML_TYPE_F32 );
20192025
20202026 const int32_t s0 = ((const int32_t *)(dst->op_params ))[0 ];
20212027 const int32_t s1 = ((const int32_t *)(dst->op_params ))[1 ];
20222028 const int32_t p0 = ((const int32_t *)(dst->op_params ))[2 ];
20232029 const int32_t p1 = ((const int32_t *)(dst->op_params ))[3 ];
20242030 const int32_t d0 = ((const int32_t *)(dst->op_params ))[4 ];
20252031 const int32_t d1 = ((const int32_t *)(dst->op_params ))[5 ];
2032+
20262033 const bool is_2D = ((const int32_t *)(dst->op_params ))[6 ] == 1 ;
20272034
20282035 const int32_t N = src1->ne [is_2D ? 3 : 2 ];
@@ -2043,8 +2050,8 @@ static bool ggml_metal_graph_compute(
20432050
20442051 id <MTLComputePipelineState > pipeline = nil ;
20452052
2046- switch (src0 ->type ) {
2047- case GGML_TYPE_F32: GGML_ASSERT ( false && " not implemented " ) ; break ;
2053+ switch (dst ->type ) {
2054+ case GGML_TYPE_F32: pipeline = ctx-> kernels [GGML_METAL_KERNEL_TYPE_IM2COL_F32]. pipeline ; break ;
20482055 case GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline ; break ;
20492056 default : GGML_ASSERT (false );
20502057 };
0 commit comments