Skip to content

Commit efb7bdb

Browse files
committed
metal : add im2col F32 dst support (#5132)
1 parent 1560630 commit efb7bdb

File tree

2 files changed

+39
-7
lines changed

2 files changed

+39
-7
lines changed

ggml-metal.m

Lines changed: 10 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -135,6 +135,7 @@
135135
GGML_METAL_KERNEL_TYPE_ROPE_F16,
136136
GGML_METAL_KERNEL_TYPE_ALIBI_F32,
137137
GGML_METAL_KERNEL_TYPE_IM2COL_F16,
138+
GGML_METAL_KERNEL_TYPE_IM2COL_F32,
138139
GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
139140
GGML_METAL_KERNEL_TYPE_PAD_F32,
140141
GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
@@ -506,6 +507,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
506507
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
507508
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
508509
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
510+
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
509511
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
510512
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
511513
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
@@ -630,6 +632,10 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
630632
case GGML_OP_ALIBI:
631633
case GGML_OP_ROPE:
632634
case GGML_OP_IM2COL:
635+
return true;
636+
case GGML_OP_POOL_1D:
637+
case GGML_OP_POOL_2D:
638+
return false;
633639
case GGML_OP_UPSCALE:
634640
case GGML_OP_PAD:
635641
case GGML_OP_ARGSORT:
@@ -2015,14 +2021,15 @@ static bool ggml_metal_graph_compute(
20152021
{
20162022
GGML_ASSERT(src0->type == GGML_TYPE_F16);
20172023
GGML_ASSERT(src1->type == GGML_TYPE_F32);
2018-
GGML_ASSERT( dst->type == GGML_TYPE_F16);
2024+
GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
20192025

20202026
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
20212027
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
20222028
const int32_t p0 = ((const int32_t *)(dst->op_params))[2];
20232029
const int32_t p1 = ((const int32_t *)(dst->op_params))[3];
20242030
const int32_t d0 = ((const int32_t *)(dst->op_params))[4];
20252031
const int32_t d1 = ((const int32_t *)(dst->op_params))[5];
2032+
20262033
const bool is_2D = ((const int32_t *)(dst->op_params))[6] == 1;
20272034

20282035
const int32_t N = src1->ne[is_2D ? 3 : 2];
@@ -2043,8 +2050,8 @@ static bool ggml_metal_graph_compute(
20432050

20442051
id<MTLComputePipelineState> pipeline = nil;
20452052

2046-
switch (src0->type) {
2047-
case GGML_TYPE_F32: GGML_ASSERT(false && "not implemented"); break;
2053+
switch (dst->type) {
2054+
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break;
20482055
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
20492056
default: GGML_ASSERT(false);
20502057
};

ggml-metal.metal

Lines changed: 29 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1775,9 +1775,29 @@ kernel void kernel_rope(
17751775
template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>;
17761776
template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>;
17771777

1778-
kernel void kernel_im2col_f16(
1778+
typedef void (im2col_t)(
17791779
device const float * x,
1780-
device half * dst,
1780+
device char * dst,
1781+
constant int32_t & ofs0,
1782+
constant int32_t & ofs1,
1783+
constant int32_t & IW,
1784+
constant int32_t & IH,
1785+
constant int32_t & CHW,
1786+
constant int32_t & s0,
1787+
constant int32_t & s1,
1788+
constant int32_t & p0,
1789+
constant int32_t & p1,
1790+
constant int32_t & d0,
1791+
constant int32_t & d1,
1792+
uint3 tgpig[[threadgroup_position_in_grid]],
1793+
uint3 tgpg[[threadgroups_per_grid]],
1794+
uint3 tpitg[[thread_position_in_threadgroup]],
1795+
uint3 ntg[[threads_per_threadgroup]]);
1796+
1797+
template <typename T>
1798+
kernel void kernel_im2col(
1799+
device const float * x,
1800+
device char * dst,
17811801
constant int32_t & ofs0,
17821802
constant int32_t & ofs1,
17831803
constant int32_t & IW,
@@ -1800,14 +1820,19 @@ kernel void kernel_im2col_f16(
18001820
(tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
18011821
(tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);
18021822

1823+
device T * pdst = (device T *) (dst);
1824+
18031825
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
1804-
dst[offset_dst] = 0.0f;
1826+
pdst[offset_dst] = 0.0f;
18051827
} else {
18061828
const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
1807-
dst[offset_dst] = x[offset_src + iih * IW + iiw];
1829+
pdst[offset_dst] = x[offset_src + iih * IW + iiw];
18081830
}
18091831
}
18101832

1833+
template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col<float>;
1834+
template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col<half>;
1835+
18111836
kernel void kernel_upscale_f32(
18121837
device const char * src0,
18131838
device char * dst,

0 commit comments

Comments
 (0)