@@ -964,9 +964,9 @@ void ggml_metal_graph_compute(
964964 const int64_t nb = ne00;
965965
966966 [encoder setComputePipelineState: ctx->pipeline_concat];
967- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
968- [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
969- [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
967+ if (id_src0) [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
968+ if (id_src1) [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
969+ if (id_dst) [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
970970 [encoder setBytes: &ne00 length: sizeof (ne00) atIndex: 3 ];
971971 [encoder setBytes: &ne01 length: sizeof (ne01) atIndex: 4 ];
972972 [encoder setBytes: &ne02 length: sizeof (ne02) atIndex: 5 ];
@@ -1029,9 +1029,9 @@ void ggml_metal_graph_compute(
10291029 default : GGML_ASSERT (false );
10301030 }
10311031 }
1032- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1033- [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
1034- [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
1032+ if (id_src0) [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1033+ if (id_src1) [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
1034+ if (id_dst) [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
10351035 [encoder setBytes: &ne00 length: sizeof (ne00) atIndex: 3 ];
10361036 [encoder setBytes: &ne01 length: sizeof (ne01) atIndex: 4 ];
10371037 [encoder setBytes: &ne02 length: sizeof (ne02) atIndex: 5 ];
@@ -1083,8 +1083,8 @@ void ggml_metal_graph_compute(
10831083 [encoder setComputePipelineState: ctx->pipeline_scale];
10841084 }
10851085
1086- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1087- [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1086+ if (id_src0) [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1087+ if (id_dst) [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
10881088 [encoder setBytes: &scale length: sizeof (scale) atIndex: 2 ];
10891089
10901090 [encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
@@ -1094,8 +1094,8 @@ void ggml_metal_graph_compute(
10941094 case GGML_UNARY_OP_SILU:
10951095 {
10961096 [encoder setComputePipelineState: ctx->pipeline_silu];
1097- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1098- [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1097+ if (id_src0) [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1098+ if (id_dst) [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
10991099
11001100 const int64_t n = ggml_nelements (dst);
11011101 GGML_ASSERT (n % 4 == 0 );
@@ -1105,8 +1105,8 @@ void ggml_metal_graph_compute(
11051105 case GGML_UNARY_OP_RELU:
11061106 {
11071107 [encoder setComputePipelineState: ctx->pipeline_relu];
1108- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1109- [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1108+ if (id_src0) [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1109+ if (id_dst) [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
11101110
11111111 const int64_t n = ggml_nelements (dst);
11121112
@@ -1115,8 +1115,8 @@ void ggml_metal_graph_compute(
11151115 case GGML_UNARY_OP_GELU:
11161116 {
11171117 [encoder setComputePipelineState: ctx->pipeline_gelu];
1118- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1119- [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1118+ if (id_src0) [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1119+ if (id_dst) [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
11201120
11211121 const int64_t n = ggml_nelements (dst);
11221122 GGML_ASSERT (n % 4 == 0 );
@@ -1134,8 +1134,8 @@ void ggml_metal_graph_compute(
11341134 GGML_ASSERT (ggml_is_contiguous (src0));
11351135
11361136 [encoder setComputePipelineState: ctx->pipeline_sqr];
1137- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1138- [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1137+ if (id_src0) [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1138+ if (id_dst) [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
11391139
11401140 const int64_t n = ggml_nelements (dst);
11411141 [encoder dispatchThreadgroups: MTLSizeMake (n, 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (1 , 1 , 1 )];
@@ -1145,8 +1145,8 @@ void ggml_metal_graph_compute(
11451145 GGML_ASSERT (src0->nb [0 ] == ggml_type_size (src0->type ));
11461146
11471147 [encoder setComputePipelineState: ctx->pipeline_sum_rows];
1148- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1149- [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1148+ if (id_src0) [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1149+ if (id_dst) [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
11501150 [encoder setBytes: &ne00 length: sizeof (ne00) atIndex: 2 ];
11511151 [encoder setBytes: &ne01 length: sizeof (ne01) atIndex: 3 ];
11521152 [encoder setBytes: &ne02 length: sizeof (ne02) atIndex: 4 ];
@@ -1192,9 +1192,9 @@ void ggml_metal_graph_compute(
11921192
11931193 const float scale = ((float *) dst->op_params )[0 ];
11941194
1195- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1196- [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
1197- [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
1195+ if (id_src0) [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1196+ if (id_src1) [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
1197+ if (id_dst) [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
11981198 [encoder setBytes: &ne00 length: sizeof (ne00) atIndex: 3 ];
11991199 [encoder setBytes: &ne01 length: sizeof (ne01) atIndex: 4 ];
12001200 [encoder setBytes: &ne02 length: sizeof (ne02) atIndex: 5 ];
@@ -1212,8 +1212,8 @@ void ggml_metal_graph_compute(
12121212 } else {
12131213 [encoder setComputePipelineState: ctx->pipeline_diag_mask_inf];
12141214 }
1215- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1216- [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1215+ if (id_src0) [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1216+ if (id_dst) [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
12171217 [encoder setBytes: &ne00 length: sizeof (ne00) atIndex: 2 ];
12181218 [encoder setBytes: &ne01 length: sizeof (ne01) atIndex: 3 ];
12191219 [encoder setBytes: &n_past length: sizeof (int ) atIndex: 4 ];
@@ -1286,9 +1286,9 @@ void ggml_metal_graph_compute(
12861286 case GGML_TYPE_Q6_K: [encoder setComputePipelineState: ctx->pipeline_mul_mm_q6_K_f32]; break ;
12871287 default : GGML_ASSERT (false && " MUL MAT-MAT not implemented" );
12881288 }
1289- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1290- [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
1291- [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
1289+ if (id_src0) [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1290+ if (id_src1) [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
1291+ if (id_dst) [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
12921292 [encoder setBytes: &ne00 length: sizeof (ne00) atIndex: 3 ];
12931293 [encoder setBytes: &ne02 length: sizeof (ne02) atIndex: 4 ];
12941294 [encoder setBytes: &nb01 length: sizeof (nb01) atIndex: 5 ];
@@ -1403,9 +1403,9 @@ void ggml_metal_graph_compute(
14031403 }
14041404 };
14051405
1406- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1407- [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
1408- [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
1406+ if (id_src0) [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1407+ if (id_src1) [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
1408+ if (id_dst) [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
14091409 [encoder setBytes: &ne00 length: sizeof (ne00) atIndex: 3 ];
14101410 [encoder setBytes: &ne01 length: sizeof (ne01) atIndex: 4 ];
14111411 [encoder setBytes: &ne02 length: sizeof (ne02) atIndex: 5 ];
@@ -1511,9 +1511,9 @@ void ggml_metal_graph_compute(
15111511 case GGML_TYPE_Q6_K: [encoder setComputePipelineState: ctx->pipeline_mul_mm_id_q6_K_f32]; break ;
15121512 default : GGML_ASSERT (false && " MUL_MAT_ID not implemented" );
15131513 }
1514- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1515- [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
1516- [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
1514+ if (id_src0) [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1515+ if (id_src1) [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
1516+ if (id_dst) [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
15171517 [encoder setBytes: &ne20 length: sizeof (ne20) atIndex: 3 ];
15181518 [encoder setBytes: &ne22 length: sizeof (ne22) atIndex: 4 ];
15191519 [encoder setBytes: &nb21 length: sizeof (nb21) atIndex: 5 ];
@@ -1559,9 +1559,9 @@ void ggml_metal_graph_compute(
15591559 default : GGML_ASSERT (false && " not implemented" );
15601560 }
15611561
1562- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1563- [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
1564- [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
1562+ if (id_src0) [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1563+ if (id_src1) [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
1564+ if (id_dst) [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
15651565 [encoder setBytes: &ne00 length: sizeof ( int64_t ) atIndex: 3 ];
15661566 [encoder setBytes: &nb01 length: sizeof (uint64_t ) atIndex: 4 ];
15671567 [encoder setBytes: &nb1 length: sizeof (uint64_t ) atIndex: 5 ];
@@ -1584,8 +1584,8 @@ void ggml_metal_graph_compute(
15841584 }
15851585
15861586 [encoder setComputePipelineState: ctx->pipeline_rms_norm];
1587- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1588- [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1587+ if (id_src0) [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1588+ if (id_dst) [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
15891589 [encoder setBytes: &ne00 length: sizeof ( int64_t ) atIndex: 2 ];
15901590 [encoder setBytes: &nb01 length: sizeof (uint64_t ) atIndex: 3 ];
15911591 [encoder setBytes: &eps length: sizeof ( float ) atIndex: 4 ];
@@ -1603,8 +1603,8 @@ void ggml_metal_graph_compute(
16031603 const int nth = MIN (256 , ne00);
16041604
16051605 [encoder setComputePipelineState: ctx->pipeline_norm];
1606- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1607- [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1606+ if (id_src0) [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1607+ if (id_dst) [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
16081608 [encoder setBytes: &ne00 length: sizeof ( int64_t ) atIndex: 2 ];
16091609 [encoder setBytes: &nb01 length: sizeof (uint64_t ) atIndex: 3 ];
16101610 [encoder setBytes: &eps length: sizeof ( float ) atIndex: 4 ];
@@ -1630,8 +1630,8 @@ void ggml_metal_graph_compute(
16301630 const float m1 = powf (2 .0f , -(max_bias / 2 .0f ) / n_heads_log2_floor);
16311631
16321632 [encoder setComputePipelineState: ctx->pipeline_alibi_f32];
1633- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1634- [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1633+ if (id_src0) [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1634+ if (id_dst) [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
16351635 [encoder setBytes: &ne00 length: sizeof ( int64_t ) atIndex: 2 ];
16361636 [encoder setBytes: &ne01 length: sizeof ( int64_t ) atIndex: 3 ];
16371637 [encoder setBytes: &ne02 length: sizeof ( int64_t ) atIndex: 4 ];
@@ -1680,9 +1680,9 @@ void ggml_metal_graph_compute(
16801680 default : GGML_ASSERT (false );
16811681 };
16821682
1683- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1684- [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
1685- [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
1683+ if (id_src0) [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1684+ if (id_src1) [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
1685+ if (id_dst) [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
16861686 [encoder setBytes: &ne00 length: sizeof ( int64_t ) atIndex: 3 ];
16871687 [encoder setBytes: &ne01 length: sizeof ( int64_t ) atIndex: 4 ];
16881688 [encoder setBytes: &ne02 length: sizeof ( int64_t ) atIndex: 5 ];
@@ -1748,8 +1748,8 @@ void ggml_metal_graph_compute(
17481748 default : GGML_ASSERT (false );
17491749 };
17501750
1751- [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 0 ];
1752- [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1751+ if (id_src1) [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 0 ];
1752+ if (id_dst) [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
17531753 [encoder setBytes: &ofs0 length: sizeof ( int32_t ) atIndex: 2 ];
17541754 [encoder setBytes: &ofs1 length: sizeof ( int32_t ) atIndex: 3 ];
17551755 [encoder setBytes: &IW length: sizeof ( int32_t ) atIndex: 4 ];
@@ -1779,8 +1779,8 @@ void ggml_metal_graph_compute(
17791779 default : GGML_ASSERT (false );
17801780 };
17811781
1782- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1783- [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1782+ if (id_src0) [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1783+ if (id_dst) [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
17841784 [encoder setBytes: &ne00 length: sizeof ( int64_t ) atIndex: 2 ];
17851785
17861786 [encoder dispatchThreadgroups: MTLSizeMake (1 , nrows, 1 ) threadsPerThreadgroup: MTLSizeMake (ne00, 1 , 1 )];
@@ -1820,8 +1820,8 @@ void ggml_metal_graph_compute(
18201820 default : GGML_ASSERT (false && " not implemented" );
18211821 }
18221822
1823- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1824- [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
1823+ if (id_src0) [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
1824+ if (id_dst) [encoder setBuffer: id_dst offset: offs_dst atIndex: 1 ];
18251825 [encoder setBytes: &ne00 length: sizeof ( int64_t ) atIndex: 2 ];
18261826 [encoder setBytes: &ne01 length: sizeof ( int64_t ) atIndex: 3 ];
18271827 [encoder setBytes: &ne02 length: sizeof ( int64_t ) atIndex: 4 ];
0 commit comments