62
62
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS,
63
63
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS,
64
64
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S,
65
+ GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S,
65
66
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S,
66
67
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
67
68
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
87
88
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,
88
89
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,
89
90
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32,
91
+ GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,
90
92
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,
91
93
GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,
92
94
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,
108
110
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,
109
111
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,
110
112
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32,
113
+ GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,
111
114
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,
112
115
GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,
113
116
GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,
126
129
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,
127
130
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,
128
131
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32,
132
+ GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,
129
133
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,
130
134
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
131
135
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,
144
148
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,
145
149
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,
146
150
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,
151
+ GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,
147
152
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,
148
153
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,
149
154
GGML_METAL_KERNEL_TYPE_ROPE_F32,
@@ -458,6 +463,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
458
463
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true );
459
464
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true );
460
465
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S, get_rows_iq3_s, true );
466
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S, get_rows_iq2_s, true );
461
467
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S, get_rows_iq1_s, true );
462
468
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true );
463
469
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true );
@@ -483,6 +489,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
483
489
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction );
484
490
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction );
485
491
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32, mul_mv_iq3_s_f32, ctx->support_simdgroup_reduction );
492
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32, mul_mv_iq2_s_f32, ctx->support_simdgroup_reduction );
486
493
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32, mul_mv_iq1_s_f32, ctx->support_simdgroup_reduction );
487
494
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, mul_mv_iq4_nl_f32, ctx->support_simdgroup_reduction );
488
495
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction );
@@ -504,6 +511,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
504
511
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction );
505
512
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction );
506
513
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32, mul_mv_id_iq3_s_f32, ctx->support_simdgroup_reduction );
514
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32, mul_mv_id_iq2_s_f32, ctx->support_simdgroup_reduction );
507
515
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32, mul_mv_id_iq1_s_f32, ctx->support_simdgroup_reduction );
508
516
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32, mul_mv_id_iq4_nl_f32, ctx->support_simdgroup_reduction );
509
517
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm );
@@ -522,6 +530,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
522
530
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm );
523
531
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm );
524
532
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32, mul_mm_iq3_s_f32, ctx->support_simdgroup_mm );
533
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32, mul_mm_iq2_s_f32, ctx->support_simdgroup_mm );
525
534
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32, mul_mm_iq1_s_f32, ctx->support_simdgroup_mm );
526
535
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, ctx->support_simdgroup_mm );
527
536
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm );
@@ -540,6 +549,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
540
549
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm );
541
550
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm );
542
551
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32, mul_mm_id_iq3_s_f32, ctx->support_simdgroup_mm );
552
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32, mul_mm_id_iq2_s_f32, ctx->support_simdgroup_mm );
543
553
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32, mul_mm_id_iq1_s_f32, ctx->support_simdgroup_mm );
544
554
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32, mul_mm_id_iq4_nl_f32, ctx->support_simdgroup_mm );
545
555
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true );
@@ -1358,6 +1368,7 @@ static bool ggml_metal_graph_compute(
1358
1368
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline ; break ;
1359
1369
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline ; break ;
1360
1370
case GGML_TYPE_IQ3_S: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32 ].pipeline ; break ;
1371
+ case GGML_TYPE_IQ2_S: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32 ].pipeline ; break ;
1361
1372
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32 ].pipeline ; break ;
1362
1373
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline ; break ;
1363
1374
default : GGML_ASSERT (false && " MUL MAT-MAT not implemented" );
@@ -1500,6 +1511,12 @@ static bool ggml_metal_graph_compute(
1500
1511
nth1 = 16 ;
1501
1512
pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline ;
1502
1513
} break ;
1514
+ case GGML_TYPE_IQ2_S:
1515
+ {
1516
+ nth0 = 4 ;
1517
+ nth1 = 16 ;
1518
+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline ;
1519
+ } break ;
1503
1520
case GGML_TYPE_IQ1_S:
1504
1521
{
1505
1522
nth0 = 4 ;
@@ -1544,9 +1561,9 @@ static bool ggml_metal_graph_compute(
1544
1561
[encoder setBytes: &r2 length: sizeof (r2) atIndex: 17 ];
1545
1562
[encoder setBytes: &r3 length: sizeof (r3) atIndex: 18 ];
1546
1563
1547
- if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
1548
- src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
1549
- src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_IQ1_S) { // || src0t == GGML_TYPE_Q4_K ) {
1564
+ if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
1565
+ src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
1566
+ src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ2_S ) {
1550
1567
[encoder dispatchThreadgroups: MTLSizeMake ((ne01 + 7 )/8 , ne11, ne12*ne13) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
1551
1568
}
1552
1569
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
@@ -1658,6 +1675,7 @@ static bool ggml_metal_graph_compute(
1658
1675
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32 ].pipeline ; break ;
1659
1676
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32].pipeline ; break ;
1660
1677
case GGML_TYPE_IQ3_S: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32 ].pipeline ; break ;
1678
+ case GGML_TYPE_IQ2_S: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32 ].pipeline ; break ;
1661
1679
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32 ].pipeline ; break ;
1662
1680
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32 ].pipeline ; break ;
1663
1681
default : GGML_ASSERT (false && " MUL_MAT_ID not implemented" );
@@ -1803,6 +1821,12 @@ static bool ggml_metal_graph_compute(
1803
1821
nth1 = 16 ;
1804
1822
pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline ;
1805
1823
} break ;
1824
+ case GGML_TYPE_IQ2_S:
1825
+ {
1826
+ nth0 = 4 ;
1827
+ nth1 = 16 ;
1828
+ pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline ;
1829
+ } break ;
1806
1830
case GGML_TYPE_IQ1_S:
1807
1831
{
1808
1832
nth0 = 4 ;
@@ -1863,9 +1887,9 @@ static bool ggml_metal_graph_compute(
1863
1887
[encoder setBuffer: id_src_cur offset: offs_src_cur atIndex: 23 + j];
1864
1888
}
1865
1889
1866
- if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
1867
- src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
1868
- src2t == GGML_TYPE_Q2_K || src2t == GGML_TYPE_IQ1_S) { // || src2t == GGML_TYPE_Q4_K ) {
1890
+ if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
1891
+ src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
1892
+ src2t == GGML_TYPE_Q2_K || src2t == GGML_TYPE_IQ1_S || src2t == GGML_TYPE_IQ2_S ) {
1869
1893
[encoder dispatchThreadgroups: MTLSizeMake ((ne21 + 7 )/8 , _ne1, ne01*ne12*ne13) threadsPerThreadgroup: MTLSizeMake (nth0, nth1, 1 )];
1870
1894
}
1871
1895
else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
@@ -1925,6 +1949,7 @@ static bool ggml_metal_graph_compute(
1925
1949
case GGML_TYPE_IQ2_XS: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS ].pipeline ; break ;
1926
1950
case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS].pipeline ; break ;
1927
1951
case GGML_TYPE_IQ3_S: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_S ].pipeline ; break ;
1952
+ case GGML_TYPE_IQ2_S: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_S ].pipeline ; break ;
1928
1953
case GGML_TYPE_IQ1_S: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ1_S ].pipeline ; break ;
1929
1954
case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL ].pipeline ; break ;
1930
1955
case GGML_TYPE_I32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_GET_ROWS_I32 ].pipeline ; break ;
0 commit comments