@@ -8204,12 +8204,12 @@ kernel void kernel_mul_mm(
         mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
     }
 #else
-    auto tA = tensor<threadgroup S0, dextents<int32_t, 2>, tensor_inline>(sa, dextents<int32_t, 2>(NK, NR0));
-    auto tB = tensor<threadgroup S1, dextents<int32_t, 2>, tensor_inline>(sb, dextents<int32_t, 2>(NR1, NK));
+    auto tA = tensor<threadgroup S0, dextents<int32_t, 2>, tensor_inline>(sa, dextents<int32_t, 2>(NK, NR0));
+    auto tB = tensor<threadgroup S1, dextents<int32_t, 2>, tensor_inline>(sb, dextents<int32_t, 2>(NR1, NK));
 
-    constexpr auto desc = mpp::tensor_ops::matmul2d_descriptor(NR1, NR0, NK, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate);
-
-    mpp::tensor_ops::matmul2d<desc, execution_simdgroups<4>> mm;
+    mpp::tensor_ops::matmul2d<
+        mpp::tensor_ops::matmul2d_descriptor(NR1, NR0, NK, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate),
+        execution_simdgroups<4>> mm;
 
     auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>();
 #endif
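This first hunk only changes how the `matmul2d` object is declared: instead of binding the `matmul2d_descriptor` to a named `constexpr` variable and passing that variable as a template argument, the descriptor is now constructed inline in the template-argument list. Both spellings rely on class-type non-type template parameters; below is a minimal standalone C++20 analogue of the pattern, where `Matmul2dDesc` and `Matmul2d` are hypothetical stand-ins, not the Metal `mpp::tensor_ops` API.

// C++20 sketch of the "descriptor as template argument" pattern above.
// Matmul2dDesc / Matmul2d are hypothetical, not the Metal API.
#include <cstdio>

struct Matmul2dDesc {            // a structural (literal) type: usable as an NTTP
    int  m, n, k;
    bool transpose_a, transpose_b, relaxed;
};

template <Matmul2dDesc Desc, int NSimdgroups>
struct Matmul2d {
    static void describe() {
        std::printf("%d x %d x %d tile on %d simdgroups\n",
                    Desc.m, Desc.n, Desc.k, NSimdgroups);
    }
};

int main() {
    // before the change: descriptor bound to a constexpr variable first
    constexpr Matmul2dDesc desc{32, 64, 32, false, true, false};
    Matmul2d<desc, 4>::describe();

    // after the change: descriptor constructed inline in the argument list
    Matmul2d<Matmul2dDesc{32, 64, 32, false, true, false}, 4>::describe();
    return 0;
}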
@@ -8522,72 +8522,169 @@ kernel void kernel_mul_mm_id(
         ushort tiitg[[thread_index_in_threadgroup]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]]) {
-
     threadgroup S0 * sa = (threadgroup S0 *)(shmem);
     threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096);
 
-    const int r0 = tgpig.y;
-    const int r1 = tgpig.x;
+    threadgroup float * sc = (threadgroup float *)(shmem);
+
+    constexpr int NR0 = 64;
+    constexpr int NR1 = 32;
+
+    constexpr int NK  = 32;
+    constexpr int NL0 = NK/16;
+    constexpr int NL1 = NK/8;
+
     const int im = tgpig.z; // expert
+    const int r0 = tgpig.y*NR0;
+    const int r1 = tgpig.x*NR1;
 
     device const uint32_t * tpe_u32 = (device const uint32_t *) (htpe);
     device const int32_t  * ids_i32 = (device const int32_t  *) (hids);
 
     const int32_t neh1 = tpe_u32[im];
 
-    if (r1*BLOCK_SIZE_N >= neh1) {
+    if (r1 >= neh1) {
        return;
    }
 
     // if this block is of 64x32 shape or smaller
-    const short n_rows = (args.ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
-    const short n_cols = (    neh1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (    neh1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
+    const short nr0 = (args.ne0 - r0 < NR0) ? (args.ne0 - r0) : NR0;
+    const short nr1 = (    neh1 - r1 < NR1) ? (    neh1 - r1) : NR1;
 
     // a thread shouldn't load data outside of the matrix
-    const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
-    const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
-
-    S0_8x8 ma[4];
-    S1_8x8 mb[2];
-
-    simdgroup_float8x8 mc[8];
+    const short lr0 = ((short)tiitg/NL0) < nr0 ? ((short)tiitg/NL0) : nr0 - 1; // 0 .. 63
+    const short lr1 = ((short)tiitg/NL1) < nr1 ? ((short)tiitg/NL1) : nr1 - 1; // 0 .. 31
 
-    for (short i = 0; i < 8; i++){
-        mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
-    }
+    const short il0 = (tiitg % NL0);
 
-    short il = (tiitg % THREAD_PER_ROW);
+    short il = il0;
 
-    const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + thread_col];
+    const int id = ids_i32[im*args.ne21 + r1 + lr1];
 
     const short i11 = (id % args.ne20) % args.ne11;
     const short i12 = (id / args.ne20);
     const short i13 = 0;
 
     const uint64_t offset0 = im*args.nb02 + i13*args.nb03;
-    const short    offset1 = il/nl;
+    const short    offset1 = il0/nl;
 
-    device const block_q * x = (device const block_q *)(src0
-        + args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1;
+    device const block_q * x = (device const block_q *)(src0 + args.nb01*(r0 + lr0) + offset0) + offset1;
 
-    const short iy = (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL));
+    const short iy = 8*(tiitg % NL1);
 
     device const T1 * y = (device const T1 *)(src1
         + args.nb13*i13
         + args.nb12*i12
         + args.nb11*i11
         + args.nb10*iy);
 
-    for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) {
+#ifndef GGML_METAL_HAS_TENSOR
+    S0_8x8 ma[4];
+    S1_8x8 mb[2];
+
+    simdgroup_float8x8 mc[8];
+
+    for (short i = 0; i < 8; i++){
+        mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
+    }
+#else
+    auto tA = tensor<threadgroup S0, dextents<int32_t, 2>, tensor_inline>(sa, dextents<int32_t, 2>(NK, NR0));
+    auto tB = tensor<threadgroup S1, dextents<int32_t, 2>, tensor_inline>(sb, dextents<int32_t, 2>(NR1, NK));
+
+    mpp::tensor_ops::matmul2d<
+        mpp::tensor_ops::matmul2d_descriptor(NR1, NR0, NK, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate),
+        execution_simdgroups<4>> mm;
+
+    auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>();
+#endif
+
+    for (int loop_k = 0; loop_k < args.ne00; loop_k += NK) {
+#ifndef GGML_METAL_HAS_TENSOR
+        // load data and store to threadgroup memory
+        if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
+            threadgroup_barrier(mem_flags::mem_threadgroup);
+
+            // no need for dequantization
+            for (short i = 0; i < 16; i++) {
+                const short sx = 2*il0 + i/8;
+                const short sy = (tiitg/NL0)/8;
+
+              //const short lx = i%8;
+              //const short ly = (tiitg/NL0)%8;
+                const short lx = (tiitg/NL0)%8;
+                const short ly = i%8;
+
+                const short ib = 8*sx + sy;
+
+                *(sa + 64*ib + 8*ly + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
+            }
+        } else {
+            S0_4x4 temp_a;
+            dequantize_func(x, il, temp_a);
+
+            threadgroup_barrier(mem_flags::mem_threadgroup);
+
+            FOR_UNROLL (short i = 0; i < 16; i++) {
+                const short sx = 2*il0 + i/8;
+                const short sy = (tiitg/NL0)/8;
+
+              //const short lx = i%8;
+              //const short ly = (tiitg/NL0)%8;
+                const short lx = (tiitg/NL0)%8;
+                const short ly = i%8;
+
+                const short ib = 8*sx + sy;
+
+                // NOTE: this is massively slower.. WTF?
+                //sa[64*ib + 8*ly + lx] = temp_a[i/4][i%4];
+
+                *(sa + 64*ib + 8*ly + lx) = temp_a[i/4][i%4];
+            }
+        }
+
+        if (FC_mul_mm_bc_inp) {
+            for (short i = 0; i < 8; ++i) {
+                const short sx = (tiitg%NL1);
+                const short sy = (tiitg/NL1)/8;
+
+                const short lx = i;
+                const short ly = (tiitg/NL1)%8;
+              //const short lx = (tiitg/NL1)%8;
+              //const short ly = i;
+
+                const short ib = 4*sx + sy;
+
+                *(sb + 64*ib + 8*ly + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
+            }
+        } else {
+            const short sx = (tiitg%NL1);
+            const short sy = (tiitg/NL1)/8;
+
+            const short dx = sx;
+            const short dy = sy;
+
+            const short ly = (tiitg/NL1)%8;
+
+            const short ib = 4*sx + sy;
+
+            *(threadgroup S1_2x4 *)(sb + 64*ib + 8*ly) = (S1_2x4)(*((device T1_2x4 *) y));
+        }
+#else
         // load data and store to threadgroup memory
         if (is_same<T0_4x4, block_q>::value && FC_mul_mm_bc_inp) {
             threadgroup_barrier(mem_flags::mem_threadgroup);
 
             // no need for dequantization
             for (short i = 0; i < 16; i++) {
-                *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
-                    + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
-                    + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = loop_k + 16*il + i < args.ne00 ? ((device T0 *) x)[i] : 0;
+                const short sx = 2*il0 + i/8;
+                const short sy = (tiitg/NL0)/8;
+
+                const short lx = i%8;
+                const short ly = (tiitg/NL0)%8;
+              //const short lx = (tiitg/NL0)%8;
+              //const short ly = i%8;
+
+                *(sa + NK*(8*sy + ly) + 8*sx + lx) = loop_k + 16*il + i < args.ne00 ? *((device T0 *) x + i) : 0;
             }
         } else {
             S0_4x4 temp_a;
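In the simdgroup (non-tensor) branch above, the old backslash-continued index expression is replaced by named coordinates: `sa` holds the 64x32 A tile as 32 contiguous 8x8 blocks of 64 elements, `ib = 8*sx + sy` picks the block and `8*ly + lx` the element within it, which is the contiguous layout the later `simdgroup_load(..., 8, 0, false)` calls consume one block at a time. The CPU sketch below, assuming a 128-thread threadgroup (4 simdgroups, matching `execution_simdgroups<4>`), checks that the mapping writes every slot of the tile exactly once.

// CPU check of the simdgroup-path A-tile store: 128 threads x 16 values
// must cover all NR0*NK = 2048 threadgroup slots exactly once.
// Constants mirror the kernel: NR0 = 64, NK = 32, NL0 = NK/16 = 2.
#include <cassert>
#include <vector>

int main() {
    const int NK = 32, NR0 = 64, NL0 = NK/16;
    std::vector<int> hits(NR0*NK, 0);

    for (int tiitg = 0; tiitg < 128; ++tiitg) {
        const int il0 = tiitg % NL0;
        for (int i = 0; i < 16; ++i) {
            const int sx = 2*il0 + i/8;     // 8x8 block index along K
            const int sy = (tiitg/NL0)/8;   // 8x8 block index along the 64 rows
            const int lx = (tiitg/NL0)%8;   // element coordinates inside a block
            const int ly = i%8;
            const int ib = 8*sx + sy;       // 32 blocks of 64 contiguous elements
            hits[64*ib + 8*ly + lx]++;
        }
    }
    for (int h : hits) {
        assert(h == 1);                     // full coverage, no overlap
    }
    return 0;
}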
@@ -8596,85 +8693,120 @@ kernel void kernel_mul_mm_id(
             threadgroup_barrier(mem_flags::mem_threadgroup);
 
             FOR_UNROLL (short i = 0; i < 16; i++) {
-                *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
-                    + (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
-                    + (tiitg/THREAD_PER_ROW)%8 + (i&7)*8) = temp_a[i/4][i%4];
+                const short sx = 2*il0 + i/8;
+                const short sy = (tiitg/NL0)/8;
+
+                const short lx = i%8;
+                const short ly = (tiitg/NL0)%8;
+              //const short lx = (tiitg/NL0)%8;
+              //const short ly = i%8;
+
+                *(sa + NK*(8*sy + ly) + 8*sx + lx) = temp_a[i/4][i%4];
             }
         }
 
         if (FC_mul_mm_bc_inp) {
             for (short i = 0; i < 8; ++i) {
-                sb[32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL) + i] = loop_k + iy + i < args.ne00 ? (S1) ((device T1 *) y)[i] : 0;
+                const short sx = (tiitg%NL1);
+                const short sy = (tiitg/NL1)/8;
+
+                const short lx = i;
+                const short ly = (tiitg/NL1)%8;
+              //const short lx = (tiitg/NL1)%8;
+              //const short ly = i;
+
+                *(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
             }
         } else {
-            *(threadgroup S1_2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = (S1_2x4)(*((device T1_2x4 *) y));
+            const short sx = (tiitg%NL1);
+            const short sy = (tiitg/NL1)/8;
+
+          //const short lx = i;
+            const short ly = (tiitg/NL1)%8;
+          //const short lx = (tiitg/NL1)%8;
+          //const short ly = i;
+
+            *(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y));
         }
+#endif
 
         il = (il + 2 < nl) ? il + 2 : il % 2;
         x  = (il < 2) ? x + (2 + nl - 1)/nl : x;
-        y += BLOCK_SIZE_K;
+
+        y += NK;
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
+#ifndef GGML_METAL_HAS_TENSOR
         // load matrices from threadgroup memory and conduct outer products
-        threadgroup const S0 * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
-        threadgroup const S1 * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));
-
-        #pragma unroll(4)
-        for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) {
-            #pragma unroll(4)
-            for (short i = 0; i < 4; i++) {
-                simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
+        threadgroup const S0 * lsma = (sa + 4*64*(sgitg%2));
+        threadgroup const S1 * lsmb = (sb + 2*64*(sgitg/2));
+
+        FOR_UNROLL (short ik = 0; ik < NK/8; ik++) {
+            simdgroup_barrier(mem_flags::mem_none);
+
+            FOR_UNROLL (short i = 0; i < 4; i++) {
+                simdgroup_load(ma[i], lsma + 64*i, 8, 0, false);
             }
 
             simdgroup_barrier(mem_flags::mem_none);
 
-            #pragma unroll(2)
-            for (short i = 0; i < 2; i++) {
-                simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
+            FOR_UNROLL (short i = 0; i < 2; i++) {
+                simdgroup_load(mb[i], lsmb + 64*i, 8, 0, false);
             }
 
-            #pragma unroll(8)
-            for (short i = 0; i < 8; i++){
+            simdgroup_barrier(mem_flags::mem_none);
+
+            FOR_UNROLL (short i = 0; i < 8; i++){
                 simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
             }
 
-            lsma += (BLOCK_SIZE_M/SG_MAT_ROW)*SG_MAT_SIZE;
-            lsmb += (BLOCK_SIZE_N/SG_MAT_ROW)*SG_MAT_SIZE;
+            lsma += 8*64;
+            lsmb += 4*64;
         }
+#else
+        auto sA = tA.slice(0, 0);
+        auto sB = tB.slice(0, 0);
+
+        mm.run(sB, sA, cT);
+#endif
     }
 
+    // block is smaller than 64x32, we should avoid writing data outside of the matrix
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
-    threadgroup float * temp_str = ((threadgroup float *) shmem) \
-                                  + 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M;
+#ifdef GGML_METAL_HAS_TENSOR
+    auto tC = tensor<threadgroup float, dextents<int32_t, 2>, tensor_inline>(sc, dextents<int32_t, 2>(NR0, NR1));
+    cT.store(tC);
+#else
+    threadgroup float * temp_str = ((threadgroup float *) shmem) + 32*(sgitg&1) + (16*(sgitg >> 1))*NR0;
 
-    #pragma unroll(8)
     for (short i = 0; i < 8; i++) {
-        simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
+        simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*NR0*(i/4), NR0, 0, false);
     }
+#endif
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
-    for (short j = sgitg; j < n_cols; j += 4) {
-        const int id = ids_i32[im*args.ne21 + r1*BLOCK_SIZE_N + j];
+    for (short j = sgitg; j < nr1; j += 4) {
+        const int id = ids_i32[im*args.ne21 + r1 + j];
 
         const short ide = id % args.ne20;
         const short idt = id / args.ne20;
 
-        device float  * D  = (device float  *) dst + (r0*BLOCK_SIZE_M) + ide*args.ne0 + idt*args.ne1*args.ne0;
+        device float  * D  = (device float  *) dst + r0 + ide*args.ne0 + idt*args.ne1*args.ne0;
         device float4 * D4 = (device float4 *) D;
 
-        threadgroup float  * C  = (threadgroup float  *) shmem + (j*BLOCK_SIZE_M);
+        threadgroup float  * C  = (threadgroup float  *) shmem + j*NR0;
         threadgroup float4 * C4 = (threadgroup float4 *) C;
 
         int i = tiisg;
-        for (; i < n_rows/4; i += 32) {
+        for (; i < nr0/4; i += 32) {
             *(D4 + i) = *(C4 + i);
         }
 
-        i = (4*(n_rows/4)) + tiisg;
-        for (; i < n_rows; i += 32) {
+        i = (4*(nr0/4)) + tiisg;
+        for (; i < nr0; i += 32) {
             *(D + i) = *(C + i);
         }
     }