@@ -640,7 +640,7 @@ @implementation ggml_metal_heap_ptr
640640@end 
641641
642642// 
643- //  ggml_metal_mem_pool
643+ //  ggml_metal_mem_pool [TAG_MEM_POOL_REMOVE] 
644644// 
645645
646646struct  ggml_metal_mem_pool {
@@ -4112,6 +4112,14 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
41124112                        default : break ;
41134113                    }
41144114
4115+                     //  TODO: using mem pool allocations with enabled concurrency is not safe because the mem pool
4116+                     //  reuses buffers. this can result in 2 concurrent MUL_MAT_ID ops using the same mem pool buffer.
4117+                     //  so we add this extra barrier to prevent the race.
4118+                     //  the correct solution is to remove mem pools and then remove this barrier [TAG_MEM_POOL_REMOVE]
4119+                     if  (ctx_dev->use_concurrency ) {
4120+                         ggml_metal_encode_mem_ranges_reset (ctx_enc);
4121+                     }
4122+ 
41154123                    //  tokens per expert
41164124                    const  size_t  s_tpe = ggml_type_size (GGML_TYPE_I32)*ne02;
41174125                    id <MTLBuffer > h_tpe = ggml_metal_mem_pool_alloc (mem_pool, s_tpe);
@@ -4172,6 +4180,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
41724180                        [encoder dispatchThreadgroups: MTLSizeMake (1 , 1 , 1 ) threadsPerThreadgroup: MTLSizeMake (ne02, 1 , 1 )];
41734181                    }
41744182
4183+                     //  this barrier is always needed because the next kernel has to wait for the id maps to be computed
41754184                    if  (ctx_dev->use_concurrency ) {
41764185                        ggml_metal_encode_mem_ranges_reset (ctx_enc);
41774186                    }
@@ -5561,6 +5570,12 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
55615570                        GGML_ASSERT (ne01*ne02*ne03 == ne1*ne2*ne3);
55625571                        GGML_ASSERT (ne1*ne2*ne3 <= (1u  << 31 ));
55635572
5573+                         //  using mem pool allocations with enabled concurrency is not safe [TAG_MEM_POOL_REMOVE]
5574+                         //  still, we assume that concurrent FA won't happen before we do the refactor
5575+                         // if (ctx_dev->use_concurrency) {
5576+                         //     ggml_metal_encode_mem_ranges_reset(ctx_enc);
5577+                         // }
5578+ 
55645579                        const  int32_t  nrows = ne1*ne2*ne3;
55655580
55665581                        //  temp buffer for writing the results from each workgroup
@@ -5939,6 +5954,7 @@ static enum ggml_status ggml_metal_graph_compute(
59395954            //  cannot use commandBufferWithUnretainedReferences because the buffers from the memory pool can get destroyed
59405955            //  TODO: when the memory pools are removed, we can again use commandBufferWithUnretainedReferences
59415956            //        https://github.com/ggml-org/llama.cpp/pull/15832#discussion_r2334215009
5957+             //  [TAG_MEM_POOL_REMOVE]
59425958            // id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
59435959            id <MTLCommandBuffer > cmd_buf = [ctx->queue commandBuffer ];
59445960            [cmd_buf retain ];
0 commit comments