@@ -622,7 +622,9 @@ class tinyBLAS {
622622 D Cv[RN][RM] = {};
623623 D Ce[RN][RM] = {};
624624 for (long l = 0 ; l < k; l += KN)
625+ #pragma GCC unroll 100
625626 for (int j = 0 ; j < RN; ++j)
627+ #pragma GCC unroll 100
626628 for (int i = 0 ; i < RM; ++i)
627629 if (PRECISE)
628630 Cv[j][i] = madder (load<V>(INDEX (A, lda, ii + i, l)), //
@@ -632,7 +634,9 @@ class tinyBLAS {
632634 Cv[j][i] = madd (load<V>(INDEX (A, lda, ii + i, l)), //
633635 load<V>(INDEX (B, ldb, jj + j, l)), //
634636 Cv[j][i]);
637+ #pragma GCC unroll 100
635638 for (int j = 0 ; j < RN; ++j)
639+ #pragma GCC unroll 100
636640 for (int i = 0 ; i < RM; ++i)
637641 store (INDEX (C, ldc, jj + j, ii + i), hsum (Cv[j][i]));
638642 }
@@ -670,7 +674,7 @@ class tinyBLAS_Q0_ARM {
670674 NOINLINE void mnpack (long m0, long m, long n0, long n) {
671675 long mc, nc, mp, np;
672676
673- if (!FLAG_precise || (!FLAG_precision_specified && sizeof (TB) == sizeof (block_q4_0)) ) {
677+ if (!FLAG_precise) {
674678 switch ((MIN (m - m0, 3 ) << 4 ) | MIN (n - n0, 3 )) {
675679 case 0x33 :
676680 mc = 3 ;
@@ -762,7 +766,9 @@ class tinyBLAS_Q0_ARM {
762766 float32x4_t Cv[RN][RM] = {};
763767 float32x4_t Ce[RN][RM] = {};
764768 for (int l = 0 ; l < k; ++l)
769+ #pragma GCC unroll 100
765770 for (int j = 0 ; j < RN; ++j)
771+ #pragma GCC unroll 100
766772 for (int i = 0 ; i < RM; ++i) {
767773 float32x4_t a = vcvtq_f32_s32 (vdotq_s32 (
768774 vdotq_s32 (vdupq_n_s32 (0 ), load_lo (INDEX (A, lda, ii + i, l)),
@@ -775,7 +781,9 @@ class tinyBLAS_Q0_ARM {
775781 else
776782 Cv[j][i] = vmlaq_n_f32 (Cv[j][i], a, b);
777783 }
784+ #pragma GCC unroll 100
778785 for (int j = 0 ; j < RN; ++j)
786+ #pragma GCC unroll 100
779787 for (int i = 0 ; i < RM; ++i)
780788 store (INDEX (C, ldc, jj + j, ii + i), hsum (Cv[j][i]));
781789 }
@@ -829,7 +837,7 @@ class tinyBLAS_Q0_AVX2 {
829837 long mc, nc, mp, np;
830838
831839#if VECTOR_REGISTERS == 32
832- if (!FLAG_precise || (!FLAG_precision_specified && sizeof (TB) == sizeof (block_q4_0)) ) {
840+ if (!FLAG_precise) {
833841 switch ((MIN (m - m0, 3 ) << 4 ) | MIN (n - n0, 3 )) {
834842 case 0x33 :
835843 mc = 3 ;
@@ -901,7 +909,7 @@ class tinyBLAS_Q0_AVX2 {
901909#endif
902910
903911#if VECTOR_REGISTERS == 16
904- if (!FLAG_precise || (!FLAG_precision_specified && sizeof (TB) == sizeof (block_q4_0)) ) {
912+ if (!FLAG_precise) {
905913 switch ((MIN (m - m0, 3 ) << 4 ) | MIN (n - n0, 2 )) {
906914 case 0x32 :
907915 mc = 3 ;
@@ -982,7 +990,9 @@ class tinyBLAS_Q0_AVX2 {
982990 __m256 Cv[RN][RM] = {};
983991 __m256 Ce[RN][RM] = {};
984992 for (long l = 0 ; l < k; ++l)
993+ #pragma GCC unroll 100
985994 for (int j = 0 ; j < RN; ++j)
995+ #pragma GCC unroll 100
986996 for (int i = 0 ; i < RM; ++i) {
987997 __m256 a = _mm256_set1_ps (unhalf (INDEX (A, lda, ii + i, l)->d ) *
988998 unhalf (INDEX (B, ldb, jj + j, l)->d ));
@@ -995,7 +1005,9 @@ class tinyBLAS_Q0_AVX2 {
9951005 else
9961006 Cv[j][i] = madd (a, b, Cv[j][i]);
9971007 }
1008+ #pragma GCC unroll 100
9981009 for (int j = 0 ; j < RN; ++j)
1010+ #pragma GCC unroll 100
9991011 for (int i = 0 ; i < RM; ++i)
10001012 store (INDEX (C, ldc, jj + j, ii + i), hsum (Cv[j][i]));
10011013 }
0 commit comments