Skip to content

Commit cf32a9b

Browse files
authored Nov 17, 2024
metal : refactor kernel args into structs (#10238)
* metal : add kernel arg structs (wip) * metal : fattn args ggml-ci * metal : cont + avoid potential int overflow [no ci] * metal : mul mat struct (wip) * cont : mul mat vec * cont : pass by reference * cont : args is first argument * cont : use char ptr * cont : shmem style * cont : thread counters style * cont : mul mm id ggml-ci * cont : int safety + register optimizations ggml-ci * metal : GGML_OP_CONCAT ggml-ci * metal : GGML_OP_ADD, GGML_OP_SUB, GGML_OP_MUL, GGML_OP_DIV * metal : GGML_OP_REPEAT * metal : GGML_OP_CPY * metal : GGML_OP_RMS_NORM * metal : GGML_OP_NORM * metal : add TODOs for rest of ops * ggml : add ggml-metal-impl.h ggml-ci
1 parent a431782 commit cf32a9b

File tree

5 files changed

+1949
-2596
lines changed

5 files changed

+1949
-2596
lines changed
 

‎Makefile

+4-1
Original file line numberDiff line numberDiff line change
@@ -906,16 +906,19 @@ endif # GGML_METAL
906906
ifdef GGML_METAL
907907
ggml/src/ggml-metal/ggml-metal.o: \
908908
ggml/src/ggml-metal/ggml-metal.m \
909+
ggml/src/ggml-metal/ggml-metal-impl.h \
909910
ggml/include/ggml-metal.h \
910911
ggml/include/ggml.h
911912
$(CC) $(CFLAGS) -c $< -o $@
912913

913914
ifdef GGML_METAL_EMBED_LIBRARY
914915
ggml/src/ggml-metal-embed.o: \
915916
ggml/src/ggml-metal/ggml-metal.metal \
917+
ggml/src/ggml-metal/ggml-metal-impl.h \
916918
ggml/src/ggml-common.h
917919
@echo "Embedding Metal library"
918-
@sed -e '/__embed_ggml-common.h__/r ggml/src/ggml-common.h' -e '/__embed_ggml-common.h__/d' < ggml/src/ggml-metal/ggml-metal.metal > ggml/src/ggml-metal/ggml-metal-embed.metal
920+
@sed -e '/__embed_ggml-common.h__/r ggml/src/ggml-common.h' -e '/__embed_ggml-common.h__/d' < ggml/src/ggml-metal/ggml-metal.metal > ggml/src/ggml-metal/ggml-metal-embed.metal.tmp
921+
@sed -e '/#include "ggml-metal-impl.h"/r ggml/src/ggml-metal/ggml-metal-impl.h' -e '/#include "ggml-metal-impl.h"/d' < ggml/src/ggml-metal/ggml-metal-embed.metal.tmp > ggml/src/ggml-metal/ggml-metal-embed.metal
919922
$(eval TEMP_ASSEMBLY=$(shell mktemp -d))
920923
@echo ".section __DATA, __ggml_metallib" > $(TEMP_ASSEMBLY)/ggml-metal-embed.s
921924
@echo ".globl _ggml_metallib_start" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s

‎ggml/src/ggml-metal/CMakeLists.txt

+11-7
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,10 @@ if (GGML_METAL_USE_BF16)
2525
add_compile_definitions(GGML_METAL_USE_BF16)
2626
endif()
2727

28-
# copy ggml-common.h and ggml-metal.metal to bin directory
29-
configure_file(../ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY)
30-
configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY)
28+
# copy metal files to bin directory
29+
configure_file(../ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY)
30+
configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY)
31+
configure_file(ggml-metal-impl.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal-impl.h COPYONLY)
3132

3233
if (GGML_METAL_EMBED_LIBRARY)
3334
enable_language(ASM)
@@ -36,24 +37,27 @@ if (GGML_METAL_EMBED_LIBRARY)
3637

3738
set(METALLIB_COMMON "${CMAKE_CURRENT_SOURCE_DIR}/../ggml-common.h")
3839
set(METALLIB_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal")
40+
set(METALLIB_IMPL "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal-impl.h")
3941

4042
file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/autogenerated")
4143

4244
# merge ggml-common.h, ggml-metal-impl.h and ggml-metal.metal into a single file
43-
set(METALLIB_EMBED_ASM "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.s")
44-
set(METALLIB_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal")
45+
set(METALLIB_EMBED_ASM "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.s")
46+
set(METALLIB_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal")
47+
set(METALLIB_SOURCE_EMBED_TMP "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal.tmp")
4548

4649
add_custom_command(
4750
OUTPUT ${METALLIB_EMBED_ASM}
4851
COMMAND echo "Embedding Metal library"
49-
COMMAND sed -e '/__embed_ggml-common.h__/r ${METALLIB_COMMON}' -e '/__embed_ggml-common.h__/d' < ${METALLIB_SOURCE} > ${METALLIB_SOURCE_EMBED}
52+
COMMAND sed -e '/__embed_ggml-common.h__/r ${METALLIB_COMMON}' -e '/__embed_ggml-common.h__/d' < ${METALLIB_SOURCE} > ${METALLIB_SOURCE_EMBED_TMP}
53+
COMMAND sed -e '/\#include \"ggml-metal-impl.h\"/r ${METALLIB_IMPL}' -e '/\#include \"ggml-metal-impl.h\"/d' < ${METALLIB_SOURCE_EMBED_TMP} > ${METALLIB_SOURCE_EMBED}
5054
COMMAND echo ".section __DATA,__ggml_metallib" > ${METALLIB_EMBED_ASM}
5155
COMMAND echo ".globl _ggml_metallib_start" >> ${METALLIB_EMBED_ASM}
5256
COMMAND echo "_ggml_metallib_start:" >> ${METALLIB_EMBED_ASM}
5357
COMMAND echo ".incbin \\\"${METALLIB_SOURCE_EMBED}\\\"" >> ${METALLIB_EMBED_ASM}
5458
COMMAND echo ".globl _ggml_metallib_end" >> ${METALLIB_EMBED_ASM}
5559
COMMAND echo "_ggml_metallib_end:" >> ${METALLIB_EMBED_ASM}
56-
DEPENDS ggml-metal.metal ../ggml-common.h
60+
DEPENDS ../ggml-common.h ggml-metal.metal ggml-metal-impl.h
5761
COMMENT "Generate assembly for embedded Metal library"
5862
)
5963

‎ggml/src/ggml-metal/ggml-metal-impl.h

+249
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,249 @@
1+
#ifndef GGML_METAL_IMPL
2+
#define GGML_METAL_IMPL
3+
4+
// kernel argument structs
5+
//
6+
// - element counters (e.g. ne00) typically use int32_t to reduce register usage
7+
// however, be careful of int overflows when using those in the kernel implementation
8+
//
9+
// - strides (e.g. nb00) use uint64_t
10+
11+
typedef struct {
12+
int32_t ne00;
13+
int32_t ne01;
14+
int32_t ne02;
15+
int32_t ne03;
16+
uint64_t nb00;
17+
uint64_t nb01;
18+
uint64_t nb02;
19+
uint64_t nb03;
20+
int32_t ne10;
21+
int32_t ne11;
22+
int32_t ne12;
23+
int32_t ne13;
24+
uint64_t nb10;
25+
uint64_t nb11;
26+
uint64_t nb12;
27+
uint64_t nb13;
28+
int32_t ne0;
29+
int32_t ne1;
30+
int32_t ne2;
31+
int32_t ne3;
32+
uint64_t nb0;
33+
uint64_t nb1;
34+
uint64_t nb2;
35+
uint64_t nb3;
36+
int32_t dim;
37+
} ggml_metal_kargs_concat;
38+
39+
typedef struct {
40+
int32_t ne00;
41+
int32_t ne01;
42+
int32_t ne02;
43+
int32_t ne03;
44+
uint64_t nb00;
45+
uint64_t nb01;
46+
uint64_t nb02;
47+
uint64_t nb03;
48+
int32_t ne10;
49+
int32_t ne11;
50+
int32_t ne12;
51+
int32_t ne13;
52+
uint64_t nb10;
53+
uint64_t nb11;
54+
uint64_t nb12;
55+
uint64_t nb13;
56+
int32_t ne0;
57+
int32_t ne1;
58+
int32_t ne2;
59+
int32_t ne3;
60+
uint64_t nb0;
61+
uint64_t nb1;
62+
uint64_t nb2;
63+
uint64_t nb3;
64+
uint64_t offs;
65+
} ggml_metal_kargs_bin;
66+
67+
typedef struct {
68+
int32_t ne00;
69+
int32_t ne01;
70+
int32_t ne02;
71+
int32_t ne03;
72+
uint64_t nb00;
73+
uint64_t nb01;
74+
uint64_t nb02;
75+
uint64_t nb03;
76+
int32_t ne0;
77+
int32_t ne1;
78+
int32_t ne2;
79+
int32_t ne3;
80+
uint64_t nb0;
81+
uint64_t nb1;
82+
uint64_t nb2;
83+
uint64_t nb3;
84+
} ggml_metal_kargs_repeat;
85+
86+
typedef struct {
87+
int64_t ne00;
88+
int64_t ne01;
89+
int64_t ne02;
90+
int64_t ne03;
91+
uint64_t nb00;
92+
uint64_t nb01;
93+
uint64_t nb02;
94+
uint64_t nb03;
95+
int64_t ne0;
96+
int64_t ne1;
97+
int64_t ne2;
98+
int64_t ne3;
99+
uint64_t nb0;
100+
uint64_t nb1;
101+
uint64_t nb2;
102+
uint64_t nb3;
103+
} ggml_metal_kargs_cpy;
104+
105+
typedef struct {
106+
int32_t ne00;
107+
int32_t ne01;
108+
int32_t ne02;
109+
int32_t ne03;
110+
uint64_t nb00;
111+
uint64_t nb01;
112+
uint64_t nb02;
113+
uint64_t nb03;
114+
int32_t ne0;
115+
int32_t ne1;
116+
int32_t ne2;
117+
int32_t ne3;
118+
uint64_t nb0;
119+
uint64_t nb1;
120+
uint64_t nb2;
121+
uint64_t nb3;
122+
int32_t n_past;
123+
int32_t n_dims;
124+
int32_t n_ctx_orig;
125+
float freq_base;
126+
float freq_scale;
127+
float ext_factor;
128+
float attn_factor;
129+
float beta_fast;
130+
float beta_slow;
131+
} ggml_metal_kargs_rope;
132+
133+
typedef struct {
134+
int32_t ne01;
135+
int32_t ne02;
136+
int32_t ne03;
137+
uint64_t nb01;
138+
uint64_t nb02;
139+
uint64_t nb03;
140+
int32_t ne11;
141+
int32_t ne_12_2; // assume K and V are same shape
142+
int32_t ne_12_3;
143+
uint64_t nb_12_1;
144+
uint64_t nb_12_2;
145+
uint64_t nb_12_3;
146+
uint64_t nb31;
147+
int32_t ne1;
148+
int32_t ne2;
149+
float scale;
150+
float max_bias;
151+
float m0;
152+
float m1;
153+
uint16_t n_head_log2;
154+
float logit_softcap;
155+
} ggml_metal_kargs_flash_attn_ext;
156+
157+
typedef struct {
158+
int32_t ne00;
159+
int32_t ne02;
160+
uint64_t nb01;
161+
uint64_t nb02;
162+
uint64_t nb03;
163+
int32_t ne12;
164+
uint64_t nb10;
165+
uint64_t nb11;
166+
uint64_t nb12;
167+
uint64_t nb13;
168+
int32_t ne0;
169+
int32_t ne1;
170+
int16_t r2;
171+
int16_t r3;
172+
} ggml_metal_kargs_mul_mm;
173+
174+
typedef struct {
175+
int32_t ne00;
176+
int32_t ne01;
177+
int32_t ne02;
178+
uint64_t nb00;
179+
uint64_t nb01;
180+
uint64_t nb02;
181+
uint64_t nb03;
182+
int32_t ne10;
183+
int32_t ne11;
184+
int32_t ne12;
185+
uint64_t nb10;
186+
uint64_t nb11;
187+
uint64_t nb12;
188+
uint64_t nb13;
189+
int32_t ne0;
190+
int32_t ne1;
191+
int16_t r2;
192+
int16_t r3;
193+
} ggml_metal_kargs_mul_mv;
194+
195+
typedef struct {
196+
int32_t nei0;
197+
int32_t nei1;
198+
uint64_t nbi1;
199+
int32_t ne00;
200+
int32_t ne02;
201+
uint64_t nb01;
202+
uint64_t nb02;
203+
int32_t ne11;
204+
int32_t ne12;
205+
int32_t ne13;
206+
uint64_t nb10;
207+
uint64_t nb11;
208+
uint64_t nb12;
209+
int32_t ne0;
210+
int32_t ne1;
211+
} ggml_metal_kargs_mul_mm_id;
212+
213+
typedef struct {
214+
int32_t nei0;
215+
int32_t nei1;
216+
uint64_t nbi1;
217+
int32_t ne00;
218+
int32_t ne01;
219+
int32_t ne02;
220+
uint64_t nb00;
221+
uint64_t nb01;
222+
uint64_t nb02;
223+
int32_t ne10;
224+
int32_t ne11;
225+
int32_t ne12;
226+
int32_t ne13;
227+
uint64_t nb10;
228+
uint64_t nb11;
229+
uint64_t nb12;
230+
int32_t ne0;
231+
int32_t ne1;
232+
uint64_t nb1;
233+
} ggml_metal_kargs_mul_mv_id;
234+
235+
typedef struct {
236+
int32_t ne00;
237+
int32_t ne00_4;
238+
uint64_t nb01;
239+
float eps;
240+
} ggml_metal_kargs_norm;
241+
242+
typedef struct {
243+
int32_t ne00;
244+
int32_t ne00_4;
245+
uint64_t nb01;
246+
float eps;
247+
} ggml_metal_kargs_rms_norm;
248+
249+
#endif // GGML_METAL_IMPL

‎ggml/src/ggml-metal/ggml-metal.m

+382-299
Large diffs are not rendered by default.

‎ggml/src/ggml-metal/ggml-metal.metal

+1,303-2,289
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)
Please sign in to comment.