Skip to content

Commit b749326

Browse files
committed
Change the way the new fast/precise flags work
1 parent e6532f7 commit b749326

File tree

7 files changed

+37
-27
lines changed

7 files changed

+37
-27
lines changed

llama.cpp/common.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -291,12 +291,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
291291
}
292292
if (arg == "--fast") {
293293
FLAG_precise = false;
294-
FLAG_precision_specified = true;
295294
return true;
296295
}
297296
if (arg == "--precise") {
298297
FLAG_precise = true;
299-
FLAG_precision_specified = true;
300298
return true;
301299
}
302300
if (arg == "--trap") {

llama.cpp/server/server.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2525,12 +2525,10 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
25252525
else if (arg == "--fast")
25262526
{
25272527
FLAG_precise = false;
2528-
FLAG_precision_specified = true;
25292528
}
25302529
else if (arg == "--precise")
25312530
{
25322531
FLAG_precise = true;
2533-
FLAG_precision_specified = true;
25342532
}
25352533
else if (arg == "--trap")
25362534
{

llamafile/flags.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,4 @@
1717

1818
#include "llamafile.h"
1919

20-
bool FLAG_precise = true;
21-
bool FLAG_precision_specified;
20+
bool FLAG_precise;

llamafile/llamafile.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ void llamafile_launch_browser(const char *);
3232
extern bool FLAG_trap;
3333
extern bool FLAG_precise;
3434
extern bool FLAG_unsecure;
35-
extern bool FLAG_precision_specified;
3635

3736
#define LLAMAFILE_GPU_ERROR -2
3837
#define LLAMAFILE_GPU_DISABLE -1

llamafile/numba.h

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,27 +23,31 @@ inline float float01(unsigned x) { // (0,1)
2323
return 1.f / 8388608 * ((x >> 9) + .5f);
2424
}
2525

26-
inline float numba(void) { // (-1,1)
27-
return float01(rand32()) * 2 - 1;
26+
inline float numba(void) { // (-1,1) — float01 yields (0,1); *2 - 1 maps it to (-1,1)
27+
return float01(rand32()) * 2.f - 1.f;
2828
}
2929

30-
template <typename T> void randomize(T *A, int n) {
30+
template <typename T>
31+
void randomize(T *A, int n) {
3132
for (int i = 0; i < n; ++i)
3233
A[i] = numba();
3334
}
3435

35-
template <typename T> void randomize(int m, int n, T *A, int lda) {
36+
template <typename T>
37+
void randomize(int m, int n, T *A, int lda) {
3638
for (int j = 0; j < n; ++j)
3739
for (int i = 0; i < m; ++i)
3840
A[lda * j + i] = numba();
3941
}
4042

41-
template <typename T, typename U> void broadcast(T *A, int n, U x) {
43+
template <typename T, typename U>
44+
void broadcast(T *A, int n, U x) {
4245
for (int i = 0; i < n; ++i)
4346
A[i] = x;
4447
}
4548

46-
template <typename T, typename U> void broadcast(int m, int n, T *A, int lda, U x) {
49+
template <typename T, typename U>
50+
void broadcast(int m, int n, T *A, int lda, U x) {
4751
for (int j = 0; j < n; ++j)
4852
for (int i = 0; i < m; ++i)
4953
A[lda * j + i] = x;

llamafile/tinyblas_cpu.h

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -622,7 +622,9 @@ class tinyBLAS {
622622
D Cv[RN][RM] = {};
623623
D Ce[RN][RM] = {};
624624
for (long l = 0; l < k; l += KN)
625+
#pragma GCC unroll 100
625626
for (int j = 0; j < RN; ++j)
627+
#pragma GCC unroll 100
626628
for (int i = 0; i < RM; ++i)
627629
if (PRECISE)
628630
Cv[j][i] = madder(load<V>(INDEX(A, lda, ii + i, l)), //
@@ -632,7 +634,9 @@ class tinyBLAS {
632634
Cv[j][i] = madd(load<V>(INDEX(A, lda, ii + i, l)), //
633635
load<V>(INDEX(B, ldb, jj + j, l)), //
634636
Cv[j][i]);
637+
#pragma GCC unroll 100
635638
for (int j = 0; j < RN; ++j)
639+
#pragma GCC unroll 100
636640
for (int i = 0; i < RM; ++i)
637641
store(INDEX(C, ldc, jj + j, ii + i), hsum(Cv[j][i]));
638642
}
@@ -670,7 +674,7 @@ class tinyBLAS_Q0_ARM {
670674
NOINLINE void mnpack(long m0, long m, long n0, long n) {
671675
long mc, nc, mp, np;
672676

673-
if (!FLAG_precise || (!FLAG_precision_specified && sizeof(TB) == sizeof(block_q4_0))) {
677+
if (!FLAG_precise) {
674678
switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3)) {
675679
case 0x33:
676680
mc = 3;
@@ -762,7 +766,9 @@ class tinyBLAS_Q0_ARM {
762766
float32x4_t Cv[RN][RM] = {};
763767
float32x4_t Ce[RN][RM] = {};
764768
for (int l = 0; l < k; ++l)
769+
#pragma GCC unroll 100
765770
for (int j = 0; j < RN; ++j)
771+
#pragma GCC unroll 100
766772
for (int i = 0; i < RM; ++i) {
767773
float32x4_t a = vcvtq_f32_s32(vdotq_s32(
768774
vdotq_s32(vdupq_n_s32(0), load_lo(INDEX(A, lda, ii + i, l)),
@@ -775,7 +781,9 @@ class tinyBLAS_Q0_ARM {
775781
else
776782
Cv[j][i] = vmlaq_n_f32(Cv[j][i], a, b);
777783
}
784+
#pragma GCC unroll 100
778785
for (int j = 0; j < RN; ++j)
786+
#pragma GCC unroll 100
779787
for (int i = 0; i < RM; ++i)
780788
store(INDEX(C, ldc, jj + j, ii + i), hsum(Cv[j][i]));
781789
}
@@ -829,7 +837,7 @@ class tinyBLAS_Q0_AVX2 {
829837
long mc, nc, mp, np;
830838

831839
#if VECTOR_REGISTERS == 32
832-
if (!FLAG_precise || (!FLAG_precision_specified && sizeof(TB) == sizeof(block_q4_0))) {
840+
if (!FLAG_precise) {
833841
switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3)) {
834842
case 0x33:
835843
mc = 3;
@@ -901,7 +909,7 @@ class tinyBLAS_Q0_AVX2 {
901909
#endif
902910

903911
#if VECTOR_REGISTERS == 16
904-
if (!FLAG_precise || (!FLAG_precision_specified && sizeof(TB) == sizeof(block_q4_0))) {
912+
if (!FLAG_precise) {
905913
switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 2)) {
906914
case 0x32:
907915
mc = 3;
@@ -982,7 +990,9 @@ class tinyBLAS_Q0_AVX2 {
982990
__m256 Cv[RN][RM] = {};
983991
__m256 Ce[RN][RM] = {};
984992
for (long l = 0; l < k; ++l)
993+
#pragma GCC unroll 100
985994
for (int j = 0; j < RN; ++j)
995+
#pragma GCC unroll 100
986996
for (int i = 0; i < RM; ++i) {
987997
__m256 a = _mm256_set1_ps(unhalf(INDEX(A, lda, ii + i, l)->d) *
988998
unhalf(INDEX(B, ldb, jj + j, l)->d));
@@ -995,7 +1005,9 @@ class tinyBLAS_Q0_AVX2 {
9951005
else
9961006
Cv[j][i] = madd(a, b, Cv[j][i]);
9971007
}
1008+
#pragma GCC unroll 100
9981009
for (int j = 0; j < RN; ++j)
1010+
#pragma GCC unroll 100
9991011
for (int i = 0; i < RM; ++i)
10001012
store(INDEX(C, ldc, jj + j, ii + i), hsum(Cv[j][i]));
10011013
}

llamafile/tinyblas_mnpack.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,19 +12,19 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
# # tinyBLAS
16-
# MAX_M = 5
17-
# MAX_N = 5
18-
# EDGE_M = 2
19-
# EDGE_N = 2
20-
# OVERHEAD = 1
21-
22-
# tinyBLAS_Q0
23-
MAX_M = 3
24-
MAX_N = 3
15+
# tinyBLAS
16+
MAX_M = 5
17+
MAX_N = 5
2518
EDGE_M = 2
2619
EDGE_N = 2
27-
OVERHEAD = 8
20+
OVERHEAD = 1
21+
22+
# # tinyBLAS_Q0
23+
# MAX_M = 3
24+
# MAX_N = 3
25+
# EDGE_M = 2
26+
# EDGE_N = 2
27+
# OVERHEAD = 8
2828

2929
def doit(VECTOR_REGISTERS, PRECISE):
3030
# choose tile size that exploits all vector registers

0 commit comments

Comments
 (0)