From de65783ba22f8a5423684b37aa56a3df0b324f0d Mon Sep 17 00:00:00 2001 From: JohannesGaessler Date: Tue, 16 May 2023 09:26:03 +0200 Subject: [PATCH 01/34] Broadcasting for ggml_mul --- ggml.c | 52 ++++++++++++++++++++++++++++++---------------------- llama.cpp | 18 +++++++++++++++--- 2 files changed, 45 insertions(+), 25 deletions(-) diff --git a/ggml.c b/ggml.c index 4311ce7cf9dbe..e937c0a33c1eb 100644 --- a/ggml.c +++ b/ggml.c @@ -4643,7 +4643,7 @@ struct ggml_tensor * ggml_mul_impl( struct ggml_tensor * a, struct ggml_tensor * b, bool inplace) { - GGML_ASSERT(ggml_are_same_shape(a, b)); + GGML_ASSERT(a->ne[0] == b->ne[0] && ggml_can_repeat(b, a)); bool is_node = false; @@ -7945,7 +7945,16 @@ static void ggml_compute_forward_mul_f32( const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { - assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + const int nr = ggml_nrows(src0); + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + const int64_t ne12 = src1->ne[2]; + const int64_t ne13 = src0->ne[3]; + + GGML_ASSERT(ne00 == ne10 && ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; @@ -7953,11 +7962,6 @@ static void ggml_compute_forward_mul_f32( const int ith = params->ith; const int nth = params->nth; - const int nr = ggml_nrows(src0); - const int64_t ne0 = src0->ne[0]; - const int64_t ne1 = src0->ne[1]; - const int64_t ne2 = src0->ne[2]; - const size_t nb00 = src0->nb[0]; const size_t nb01 = src0->nb[1]; const size_t nb02 = src0->nb[2]; @@ -7976,12 +7980,12 @@ static void ggml_compute_forward_mul_f32( GGML_ASSERT( nb0 == sizeof(float)); GGML_ASSERT(nb00 == sizeof(float)); - if (nb10 == sizeof(float)) { + if (nb10 == sizeof(float) && ggml_are_same_shape(src0, src1)) { for (int ir = ith; ir < nr; ir += nth) { // src0, src1 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); #ifdef GGML_USE_ACCELERATE @@ -7991,9 +7995,9 @@ static void ggml_compute_forward_mul_f32( (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1, (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1, (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1, - ne0); + ne00); #else - ggml_vec_mul_f32(ne0, + ggml_vec_mul_f32(ne00, (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11)); @@ -8004,15 +8008,19 @@ static void ggml_compute_forward_mul_f32( } else { // src1 is not contiguous for (int ir = ith; ir < nr; ir += nth) { - // src0, src1 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); - - float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ); - float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); - for (int i0 = 0; i0 < ne0; i0++) { - float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10); + // src0 and dst are same shape => same indices + // src1 is broadcastable across src0 and dst in i1, i2, i3 + const int i03 = ir/(ne02*ne01); + const int i02 = (ir - i03*ne02*ne01)/ne01; + const int i01 = (ir - i03*ne02*ne01 - i02*ne01); + const int i13 = i03 % ne13; + const int i12 = i02 % ne12; + const int i11 = i01 % ne11; + + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + for (int i0 = 0; i0 < ne00; i0++) { + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i0*nb10); dst_ptr[i0] = src0_ptr[i0] * (*src1_ptr); } diff --git a/llama.cpp b/llama.cpp index 98f49abd7cf48..b594789fb3135 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1154,10 +1154,14 @@ static bool llama_eval_internal( { cur = ggml_rms_norm(ctx0, inpL); - // cur = attention_norm*cur + // cur = cur*attention_norm(broadcasted) +#ifdef GGML_USE_CUBLAS + cur = ggml_mul(ctx0, cur, model.layers[il].attention_norm); +#else cur = ggml_mul(ctx0, ggml_repeat(ctx0, model.layers[il].attention_norm, cur), cur); +#endif } // self-attention @@ -1264,10 +1268,14 @@ static bool llama_eval_internal( { cur = ggml_rms_norm(ctx0, inpFF); - // cur = ffn_norm*cur + // cur = cur*ffn_norm(broadcasted) +#ifdef GGML_USE_CUBLAS + cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm); +#else cur = ggml_mul(ctx0, ggml_repeat(ctx0, model.layers[il].ffn_norm, cur), cur); +#endif } struct ggml_tensor * tmp = ggml_mul_mat(ctx0, @@ -1304,10 +1312,14 @@ static bool llama_eval_internal( inpL = ggml_rms_norm(ctx0, inpL); - // inpL = norm*inpL + // inpL = inpL*norm(broadcasted) +#ifdef GGML_USE_CUBLAS + inpL = ggml_mul(ctx0, inpL, model.norm); +#else inpL = ggml_mul(ctx0, ggml_repeat(ctx0, model.norm, inpL), inpL); +#endif embeddings = inpL; } From 2365a2a97038aa9060e056a4f2ca470e84935004 Mon Sep 17 00:00:00 2001 From: JohannesGaessler Date: Tue, 16 May 2023 15:20:36 +0200 Subject: [PATCH 02/34] CUDA kernel for ggml_mul, norms in VRAM --- ggml-cuda.cu | 93 +++++++++++++++++++++++++++++++++++++++++++++++++--- ggml-cuda.h | 1 + ggml.c | 8 +++++ llama.cpp | 2 ++ 4 files changed, 100 insertions(+), 4 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index f2630ec8ee1a1..b93c3ea8c7411 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -83,9 +83,19 @@ typedef struct { } block_q8_0; static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding"); +#define CUDA_MUL_BLOCK_SIZE 256 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256 #define CUDA_DMMV_BLOCK_SIZE 32 // dmmv = dequantize_mul_mat_vec +static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= kx) { + return; + } + dst[i] = x[i] * y[i%ky]; +} + static __device__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){ const block_q4_0 * x = (const block_q4_0 *) vx; @@ -228,6 +238,11 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, } } +static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) { + const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE; + mul_f32<<>>(x, y, dst, kx, ky); +} + static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; dequantize_block<<>>(vx, y, k); @@ -467,6 +482,67 @@ static cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor } } +static void ggml_cuda_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src1->backend == GGML_BACKEND_CUDA); + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[2]; + const int64_t ne0 = ne00 * ne01 * ne02 * ne03; + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + const int64_t ne12 = src1->ne[2]; + const int64_t ne13 = src1->ne[3]; + const int nb2 = dst->nb[2]; + const int nb3 = dst->nb[3]; + size_t x_size, d_size; + + float * d_X = (float *) ggml_cuda_pool_malloc(ne0 * sizeof(float), &x_size); // src0 + float * d_Y = (float *) src1->data; // src1 is already on device, broadcasted. + float * d_D = (float *) ggml_cuda_pool_malloc(ne0 * sizeof(float), &d_size); // dst + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + const int i0 = i03*ne02 + i02; + float * c_X2 = d_X + i0*ne01*ne00; + float * c_D2 = d_D + i0*ne01*ne00; + + cudaStream_t cudaStream = g_cudaStreams[i0 % GGML_CUDA_MAX_STREAMS]; + cudaStream_t cudaStream2 = g_cudaStreams2[i0 % GGML_CUDA_MAX_STREAMS]; + cudaEvent_t cudaEvent = g_cudaEvents[i0 % GGML_CUDA_MAX_EVENTS]; + + // copy src0 to device + CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X2, src0, i03, i02, cudaStream2)); + CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2)); + + // wait for data + CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0)); + + for (int64_t i01 = 0; i01 < ne01; i01++) { + const int64_t i13 = i03%ne13; + const int64_t i12 = i02%ne12; + const int64_t i11 = i01%ne11; + const int i1 = i13*ne12*ne11 + i12*ne11 + i11; + + float * c_X1 = c_X2 + i01*ne00; + float * c_Y = d_Y + i1*ne10; + float * c_D1 = c_D2 + i01*ne00; + + // compute + mul_f32_cuda(c_X1, c_Y, c_D1, ne00, ne10, cudaStream); + CUDA_CHECK(cudaGetLastError()); + } + + // copy dst to host + float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); + CUDA_CHECK(cudaMemcpyAsync(d, c_D2, sizeof(float)*ne00*ne01, cudaMemcpyDeviceToHost, cudaStream)); + } + } + CUDA_CHECK(cudaDeviceSynchronize()); + ggml_cuda_pool_free(d_X, x_size); + ggml_cuda_pool_free(d_D, d_size); +} + static void ggml_cuda_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const int64_t ne00 = src0->ne[0]; const int64_t ne01 = src0->ne[1]; @@ -724,6 +800,11 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor ggml_cuda_pool_free(d_Q, q_size); } +void ggml_cuda_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32); + ggml_cuda_mul_f32(src0, src1, dst); +} + bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { const int64_t ne10 = src1->ne[0]; @@ -797,14 +878,18 @@ void ggml_cuda_transform_tensor(ggml_tensor * tensor) { const size_t q_sz = ggml_type_size(type) * ne0 * ne1 * ne2 * ne3 / ggml_blck_size(type); size_t q_size; - char * d_Q = (char *) ggml_cuda_pool_malloc(q_sz, &q_size); + char * dst = (char *) ggml_cuda_pool_malloc(q_sz, &q_size); cudaStream_t cudaStream2 = g_cudaStreams2[0]; // copy tensor to device - CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, tensor, 0, 0, cudaStream2)); - CUDA_CHECK(cudaDeviceSynchronize()); + for (int64_t i3 = 0; i3 < ne3; i3++) { + for (int64_t i2 = 0; i2 < ne2; i2++) { + int i = i3*ne2 + i2; + CUDA_CHECK(ggml_cuda_h2d_tensor_2d(dst + i*ne0*ne1, tensor, i3, i2, cudaStream2)); + } + } - tensor->data = d_Q; + tensor->data = dst; tensor->backend = GGML_BACKEND_CUDA; } diff --git a/ggml-cuda.h b/ggml-cuda.h index 4e2c24283ccf4..682a2ce0ab450 100644 --- a/ggml-cuda.h +++ b/ggml-cuda.h @@ -6,6 +6,7 @@ extern "C" { void ggml_init_cublas(void); +void ggml_cuda_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst); bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst); size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst); void ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize); diff --git a/ggml.c b/ggml.c index e937c0a33c1eb..c089a815e0986 100644 --- a/ggml.c +++ b/ggml.c @@ -7961,6 +7961,14 @@ static void ggml_compute_forward_mul_f32( } const int ith = params->ith; const int nth = params->nth; +#ifdef GGML_USE_CUBLAS + if (src1->backend == GGML_BACKEND_CUDA) { + if (ith == 0) { + ggml_cuda_mul(src0, src1, dst); + } + return; + } +#endif const size_t nb00 = src0->nb[0]; const size_t nb01 = src0->nb[1]; diff --git a/llama.cpp b/llama.cpp index b594789fb3135..0b237a3d90f66 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1038,10 +1038,12 @@ static void llama_model_load_internal( for (int i = 0; i < n_gpu; ++i) { const auto & layer = model.layers[i]; + ggml_cuda_transform_tensor(layer.attention_norm); vram_total += ggml_nbytes(layer.attention_norm); ggml_cuda_transform_tensor(layer.wq); vram_total += ggml_nbytes(layer.wq); ggml_cuda_transform_tensor(layer.wk); vram_total += ggml_nbytes(layer.wk); ggml_cuda_transform_tensor(layer.wv); vram_total += ggml_nbytes(layer.wv); ggml_cuda_transform_tensor(layer.wo); vram_total += ggml_nbytes(layer.wo); + ggml_cuda_transform_tensor(layer.ffn_norm); vram_total += ggml_nbytes(layer.ffn_norm); ggml_cuda_transform_tensor(layer.w1); vram_total += ggml_nbytes(layer.w1); ggml_cuda_transform_tensor(layer.w2); vram_total += ggml_nbytes(layer.w2); ggml_cuda_transform_tensor(layer.w3); vram_total += ggml_nbytes(layer.w3); From fa1a29f36f7bf336b55d36cc1917a60373fb9971 Mon Sep 17 00:00:00 2001 From: JohannesGaessler Date: Wed, 17 May 2023 16:35:50 +0200 Subject: [PATCH 03/34] GPU weights not in RAM, direct loading with cuFile --- .gitignore | 9 +++ Makefile | 2 +- ggml-cuda.cu | 49 +++++++++++++++- ggml-cuda.h | 3 + llama.cpp | 160 +++++++++++++++++++++++++++++++-------------------- 5 files changed, 159 insertions(+), 64 deletions(-) diff --git a/.gitignore b/.gitignore index d231f3ff8ed36..07528b5c6cc08 100644 --- a/.gitignore +++ b/.gitignore @@ -48,3 +48,12 @@ qnt-*.txt perf-*.txt examples/jeopardy/results.txt + +/prompts +*.sh +*.log +*.py +*.txt +/wikitext-2-raw/ +*.org +/libllama.so diff --git a/Makefile b/Makefile index f9ec8797a40dc..93c149edc9593 100644 --- a/Makefile +++ b/Makefile @@ -125,7 +125,7 @@ endif ifdef LLAMA_CUBLAS CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include CXXFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include - LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib + LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lcufile -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib OBJS += ggml-cuda.o NVCC = nvcc NVCCFLAGS = --forward-unknown-to-host-compiler -arch=native diff --git a/ggml-cuda.cu b/ggml-cuda.cu index b93c3ea8c7411..893c7d4a92407 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -2,11 +2,13 @@ #include #include #include +#include #include #include #include #include +#include #include "ggml-cuda.h" #include "ggml.h" @@ -32,6 +34,15 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); } \ } while (0) +#define CUFILE_CHECK(status) \ + do { \ + CUfileError_t status_ = (status); \ + if (status_.err != CU_FILE_SUCCESS) { \ + fprintf(stderr, "cuFile error %d at %s:%d\n", status_.err, __FILE__, __LINE__); \ + exit(1); \ + } \ + } while (0) + typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, float & v0, float & v1); typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream); typedef void (*dequantize_mul_mat_vec_cuda_t)(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream); @@ -372,7 +383,7 @@ struct cuda_buffer { static cuda_buffer g_cuda_buffer_pool[MAX_CUDA_BUFFERS]; static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT; -static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) { +void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) { scoped_spin_lock lock(g_cuda_pool_lock); for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) { @@ -391,7 +402,7 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) { return ptr; } -static void ggml_cuda_pool_free(void * ptr, size_t size) { +void ggml_cuda_pool_free(void * ptr, size_t size) { scoped_spin_lock lock(g_cuda_pool_lock); for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) { @@ -431,6 +442,9 @@ void ggml_init_cublas() { // configure logging to stdout // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr)); + + // initialize cuFile for loading model parameters directly to VRAM + CUFILE_CHECK(cuFileDriverOpen()); } } @@ -893,3 +907,34 @@ void ggml_cuda_transform_tensor(ggml_tensor * tensor) { tensor->data = dst; tensor->backend = GGML_BACKEND_CUDA; } + +bool ggml_cuda_load_data_cufile(const char * fname, struct ggml_tensor ** tensors, const int num_tensors, const size_t * offsets) { + CUfileDescr_t cf_descr; + memset((void *)&cf_descr, 0, sizeof(CUfileDescr_t)); + const int fd_cf = open(fname, O_RDONLY|O_DIRECT, 0644); + cf_descr.handle.fd = fd_cf; + cf_descr.type = CU_FILE_HANDLE_TYPE_OPAQUE_FD; + + CUfileHandle_t cf_handle; + CUfileError_t status = cuFileHandleRegister(&cf_handle, &cf_descr); + + if (status.err == CU_FILE_INTERNAL_ERROR) { + fprintf(stderr, "WARNING: cuFile experienced an internal error while loading weights from \"%s\". Using a workaround (slower). " + "This happens with weight files on Btrfs partitions. ext4 and NTFS are confirmed to work.\n", fname); + } + if (status.err != CU_FILE_SUCCESS) { + return false; + } + + for (int i = 0; i < num_tensors; ++i) { + ggml_tensor * tensor = tensors[i]; + const size_t size = ggml_nbytes(tensor); + const size_t offset = offsets[i]; + + size_t actual_size; + void * buf = ggml_cuda_pool_malloc(size, &actual_size); + cuFileRead(cf_handle, buf, size, offset, 0); + tensor->data = buf; + } + return true; +} diff --git a/ggml-cuda.h b/ggml-cuda.h index 682a2ce0ab450..7485af9f73595 100644 --- a/ggml-cuda.h +++ b/ggml-cuda.h @@ -14,8 +14,11 @@ void ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tens // TODO: export these with GGML_API void * ggml_cuda_host_malloc(size_t size); void ggml_cuda_host_free(void * ptr); +void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size); +void ggml_cuda_pool_free(void * ptr, size_t size); void ggml_cuda_transform_tensor(struct ggml_tensor * tensor); +bool ggml_cuda_load_data_cufile(const char * fname, struct ggml_tensor ** tensors, int num_tensors, const size_t * offsets); #ifdef __cplusplus } diff --git a/llama.cpp b/llama.cpp index 0b237a3d90f66..644b6a57b5787 100644 --- a/llama.cpp +++ b/llama.cpp @@ -10,6 +10,7 @@ #include "ggml.h" #ifdef GGML_USE_CUBLAS +#include #include "ggml-cuda.h" #endif @@ -641,7 +642,7 @@ struct llama_model_loader { } } - struct ggml_tensor * get_tensor(const std::string & name, const std::vector & ne) { + struct ggml_tensor * get_tensor(const std::string & name, const std::vector & ne, ggml_backend backend) { auto it = tensors_map.name_to_idx.find(name); if (it == tensors_map.name_to_idx.end()) { throw format("llama.cpp: tensor '%s' is missing from model", name.c_str()); @@ -652,10 +653,10 @@ struct llama_model_loader { name.c_str(), llama_format_tensor_shape(ne).c_str(), llama_format_tensor_shape(lt.ne).c_str()); } - return get_tensor_for(lt); + return get_tensor_for(lt, backend); } - struct ggml_tensor * get_tensor_for(llama_load_tensor & lt) { + struct ggml_tensor * get_tensor_for(llama_load_tensor & lt, ggml_backend backend) { struct ggml_tensor * tensor; if (lt.ne.size() == 2) { tensor = ggml_new_tensor_2d(ggml_ctx, lt.type, lt.ne.at(0), lt.ne.at(1)); @@ -665,6 +666,7 @@ struct llama_model_loader { } ggml_set_name(tensor, lt.name.c_str()); LLAMA_ASSERT(lt.ggml_tensor == NULL); // if this fails, we called get_tensor twice on the same tensor + tensor->backend = backend; lt.ggml_tensor = tensor; num_ggml_tensors_created++; return tensor; @@ -683,7 +685,7 @@ struct llama_model_loader { } if (use_mmap) { - mapping.reset(new llama_mmap(&file_loaders.at(0)->file)); + mapping.reset(new llama_mmap(&file_loaders.at(0)->file, false)); if (!lmlock) { // Don't call the callback since the actual loading will be lazy // and we can't measure it. @@ -696,6 +698,9 @@ struct llama_model_loader { size_t done_size = 0; for (llama_load_tensor & lt : tensors_map.tensors) { + if (lt.ggml_tensor->backend != GGML_BACKEND_CPU) { + continue; + } if (progress_callback) { progress_callback((float) done_size / data_size, progress_callback_user_data); } @@ -944,26 +949,6 @@ static void llama_model_load_internal( ml->calc_sizes(&ctx_size, &mmapped_size); fprintf(stderr, "%s: ggml ctx size = %6.2f KB\n", __func__, ctx_size/1024.0); - // print memory requirements - { - const size_t scale = memory_type == GGML_TYPE_F32 ? 2 : 1; - - // this is the total memory required to run the inference - const size_t mem_required = - ctx_size + - mmapped_size + - MEM_REQ_SCRATCH0().at(model.type) + - MEM_REQ_SCRATCH1().at(model.type) + - MEM_REQ_EVAL().at(model.type); - - // this is the memory required by one llama_state - const size_t mem_required_state = - scale*MEM_REQ_KV_SELF().at(model.type); - - fprintf(stderr, "%s: mem required = %7.2f MB (+ %7.2f MB per state)\n", __func__, - mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0); - } - // create the ggml context { lctx.model.buf.resize(ctx_size); @@ -985,6 +970,7 @@ static void llama_model_load_internal( } // prepare memory for the weights + size_t vram_total = 0; { const uint32_t n_embd = hparams.n_embd; const uint32_t n_layer = hparams.n_layer; @@ -992,33 +978,78 @@ static void llama_model_load_internal( ml->ggml_ctx = ctx; - model.tok_embeddings = ml->get_tensor("tok_embeddings.weight", {n_embd, n_vocab}); - model.norm = ml->get_tensor("norm.weight", {n_embd}); - model.output = ml->get_tensor("output.weight", {n_embd, n_vocab}); + model.tok_embeddings = ml->get_tensor("tok_embeddings.weight", {n_embd, n_vocab}, GGML_BACKEND_CPU); + model.norm = ml->get_tensor("norm.weight", {n_embd}, GGML_BACKEND_CPU); + ggml_backend backend_output; + if (n_gpu_layers > int(n_layer)) { + backend_output = GGML_BACKEND_CUDA; + } else { + backend_output = GGML_BACKEND_CPU; + } + model.output = ml->get_tensor("output.weight", {n_embd, n_vocab}, backend_output); model.layers.resize(n_layer); + const int i_gpu_start = n_layer - n_gpu_layers; for (uint32_t i = 0; i < n_layer; ++i) { auto & layer = model.layers[i]; + const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : GGML_BACKEND_CUDA; std::string layers_i = "layers." + std::to_string(i); - layer.attention_norm = ml->get_tensor(layers_i + ".attention_norm.weight", {n_embd}); + layer.attention_norm = ml->get_tensor(layers_i + ".attention_norm.weight", {n_embd}, backend); - layer.wq = ml->get_tensor(layers_i + ".attention.wq.weight", {n_embd, n_embd}); - layer.wk = ml->get_tensor(layers_i + ".attention.wk.weight", {n_embd, n_embd}); - layer.wv = ml->get_tensor(layers_i + ".attention.wv.weight", {n_embd, n_embd}); - layer.wo = ml->get_tensor(layers_i + ".attention.wo.weight", {n_embd, n_embd}); + layer.wq = ml->get_tensor(layers_i + ".attention.wq.weight", {n_embd, n_embd}, backend); + layer.wk = ml->get_tensor(layers_i + ".attention.wk.weight", {n_embd, n_embd}, backend); + layer.wv = ml->get_tensor(layers_i + ".attention.wv.weight", {n_embd, n_embd}, backend); + layer.wo = ml->get_tensor(layers_i + ".attention.wo.weight", {n_embd, n_embd}, backend); - layer.ffn_norm = ml->get_tensor(layers_i + ".ffn_norm.weight", {n_embd}); + layer.ffn_norm = ml->get_tensor(layers_i + ".ffn_norm.weight", {n_embd}, backend); - layer.w1 = ml->get_tensor(layers_i + ".feed_forward.w1.weight", {n_embd, n_ff}); - layer.w2 = ml->get_tensor(layers_i + ".feed_forward.w2.weight", { n_ff, n_embd}); - layer.w3 = ml->get_tensor(layers_i + ".feed_forward.w3.weight", {n_embd, n_ff}); + layer.w1 = ml->get_tensor(layers_i + ".feed_forward.w1.weight", {n_embd, n_ff}, backend); + layer.w2 = ml->get_tensor(layers_i + ".feed_forward.w2.weight", { n_ff, n_embd}, backend); + layer.w3 = ml->get_tensor(layers_i + ".feed_forward.w3.weight", {n_embd, n_ff}, backend); + if (backend == GGML_BACKEND_CUDA) { + vram_total += ggml_nbytes(layer.attention_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) + + ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + ggml_nbytes(layer.attention_norm) + + ggml_nbytes(layer.w1) + ggml_nbytes(layer.w2) + ggml_nbytes(layer.w3); + } } } ml->done_getting_tensors(); + // print memory requirements + { + const size_t scale = memory_type == GGML_TYPE_F32 ? 2 : 1; + + // this is the total memory required to run the inference + const size_t mem_required = + ctx_size + + mmapped_size - vram_total + // weights in VRAM not in memory + MEM_REQ_SCRATCH0().at(model.type) + + MEM_REQ_SCRATCH1().at(model.type) + + MEM_REQ_EVAL().at(model.type); + + // this is the memory required by one llama_state + const size_t mem_required_state = + scale*MEM_REQ_KV_SELF().at(model.type); + + fprintf(stderr, "%s: mem required = %7.2f MB (+ %7.2f MB per state)\n", __func__, + mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0); + +#ifdef GGML_USE_CUBLAS + const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer)); + + fprintf(stderr, "%s: [cublas] offloading %d layers to GPU\n", __func__, n_gpu); + if (n_gpu_layers > (int) hparams.n_layer) { + fprintf(stderr, "%s: [cublas] offloading output layer to GPU\n", __func__); + } + fprintf(stderr, "%s: [cublas] total VRAM used: %zu MB\n", __func__, vram_total / 1024 / 1024); +#else + (void) n_gpu_layers; +#endif + } + // populate `tensors_by_name` for (llama_load_tensor & lt : ml->tensors_map.tensors) { model.tensors_by_name.emplace_back(lt.name, lt.ggml_tensor); @@ -1026,38 +1057,44 @@ static void llama_model_load_internal( ml->load_all_data(progress_callback, progress_callback_user_data, use_mlock ? &lctx.model.mlock_mmap : NULL); - model.mapping = std::move(ml->mapping); #ifdef GGML_USE_CUBLAS { - const int n_gpu = std::min(n_gpu_layers, int(hparams.n_layer)); + std::vector tensors; + std::vector offsets; + for (llama_load_tensor & lt : ml->tensors_map.tensors) { + if (lt.ggml_tensor->backend != GGML_BACKEND_CUDA) { + continue; + } + tensors.emplace_back(lt.ggml_tensor); + LLAMA_ASSERT(lt.shards.size() == 1); + offsets.emplace_back(lt.shards.at(0).file_off); + } + bool cufile_success = ggml_cuda_load_data_cufile(fname.c_str(), tensors.data(), tensors.size(), offsets.data()); - fprintf(stderr, "%s: [cublas] offloading %d layers to GPU\n", __func__, n_gpu); + if (!cufile_success) { + for (llama_load_tensor & lt : ml->tensors_map.tensors) { + if (lt.ggml_tensor->backend != GGML_BACKEND_CUDA) { + continue; + } + size_t actual_size; + void * buf = ggml_cuda_pool_malloc(lt.size, &actual_size); + void * buf_host = ggml_cuda_host_malloc(lt.size); - size_t vram_total = 0; + llama_file & file = ml->file_loaders.at(lt.shards.at(0).file_idx)->file; + file.seek(lt.shards.at(0).file_off, SEEK_SET); + file.read_raw(buf_host, lt.size); - for (int i = 0; i < n_gpu; ++i) { - const auto & layer = model.layers[i]; + cudaMemcpy(buf, buf_host, lt.size, cudaMemcpyHostToDevice); + cudaDeviceSynchronize(); - ggml_cuda_transform_tensor(layer.attention_norm); vram_total += ggml_nbytes(layer.attention_norm); - ggml_cuda_transform_tensor(layer.wq); vram_total += ggml_nbytes(layer.wq); - ggml_cuda_transform_tensor(layer.wk); vram_total += ggml_nbytes(layer.wk); - ggml_cuda_transform_tensor(layer.wv); vram_total += ggml_nbytes(layer.wv); - ggml_cuda_transform_tensor(layer.wo); vram_total += ggml_nbytes(layer.wo); - ggml_cuda_transform_tensor(layer.ffn_norm); vram_total += ggml_nbytes(layer.ffn_norm); - ggml_cuda_transform_tensor(layer.w1); vram_total += ggml_nbytes(layer.w1); - ggml_cuda_transform_tensor(layer.w2); vram_total += ggml_nbytes(layer.w2); - ggml_cuda_transform_tensor(layer.w3); vram_total += ggml_nbytes(layer.w3); - } - if (n_gpu_layers > (int) hparams.n_layer) { - fprintf(stderr, "%s: [cublas] offloading output layer to GPU\n", __func__); - ggml_cuda_transform_tensor(model.output); vram_total += ggml_nbytes(model.output); + lt.ggml_tensor->data = buf; + ggml_cuda_host_free(buf_host); + } } - - fprintf(stderr, "%s: [cublas] total VRAM used: %zu MB\n", __func__, vram_total / 1024 / 1024); } -#else - (void) n_gpu_layers; -#endif +#endif // GGML_USE_CUBLAS + + model.mapping = std::move(ml->mapping); // loading time will be recalculate after the first eval, so // we take page faults deferred by mmap() into consideration @@ -2395,7 +2432,8 @@ int llama_apply_lora_from_file_internal(struct llama_context * ctx, const char * } size_t idx = model_loader->tensors_map.name_to_idx[base_name]; llama_load_tensor & lt = model_loader->tensors_map.tensors[idx]; - base_t = model_loader->get_tensor(base_name, { (uint32_t)dest_t->ne[0], (uint32_t)dest_t->ne[1] }); + base_t = model_loader->get_tensor( + base_name, { (uint32_t)dest_t->ne[0], (uint32_t)dest_t->ne[1] }, GGML_BACKEND_CPU); lt.data = (uint8_t *) lt.ggml_tensor->data; model_loader->load_data_for(lt); lt.ggml_tensor->data = lt.data; From 1bfe5a98867a6057396b6ca4829cedc5c3ae0a10 Mon Sep 17 00:00:00 2001 From: JohannesGaessler Date: Thu, 18 May 2023 23:57:13 +0200 Subject: [PATCH 04/34] fixup! GPU weights not in RAM, direct loading with cuFile --- ggml-cuda.cu | 48 ++++++++++++++++++++++++++++++++++++++---------- ggml-cuda.h | 4 +--- llama.cpp | 24 +----------------------- 3 files changed, 40 insertions(+), 36 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 893c7d4a92407..6e6e12aaac555 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -383,7 +383,7 @@ struct cuda_buffer { static cuda_buffer g_cuda_buffer_pool[MAX_CUDA_BUFFERS]; static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT; -void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) { +static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) { scoped_spin_lock lock(g_cuda_pool_lock); for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) { @@ -402,7 +402,7 @@ void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) { return ptr; } -void ggml_cuda_pool_free(void * ptr, size_t size) { +static void ggml_cuda_pool_free(void * ptr, size_t size) { scoped_spin_lock lock(g_cuda_pool_lock); for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) { @@ -908,7 +908,7 @@ void ggml_cuda_transform_tensor(ggml_tensor * tensor) { tensor->backend = GGML_BACKEND_CUDA; } -bool ggml_cuda_load_data_cufile(const char * fname, struct ggml_tensor ** tensors, const int num_tensors, const size_t * offsets) { +void ggml_cuda_load_data(const char * fname, struct ggml_tensor ** tensors, const int num_tensors, const size_t * offsets) { CUfileDescr_t cf_descr; memset((void *)&cf_descr, 0, sizeof(CUfileDescr_t)); const int fd_cf = open(fname, O_RDONLY|O_DIRECT, 0644); @@ -921,20 +921,48 @@ bool ggml_cuda_load_data_cufile(const char * fname, struct ggml_tensor ** tensor if (status.err == CU_FILE_INTERNAL_ERROR) { fprintf(stderr, "WARNING: cuFile experienced an internal error while loading weights from \"%s\". Using a workaround (slower). " "This happens with weight files on Btrfs partitions. ext4 and NTFS are confirmed to work.\n", fname); - } - if (status.err != CU_FILE_SUCCESS) { - return false; + } else { + for (int i = 0; i < num_tensors; ++i) { + ggml_tensor * tensor = tensors[i]; + const size_t size = ggml_nbytes(tensor); + const size_t offset = offsets[i]; + + size_t actual_size; + void * buf = ggml_cuda_pool_malloc(size, &actual_size); + cuFileRead(cf_handle, buf, size, offset, 0); + tensor->data = buf; + } + return; } + FILE * fp = fopen(fname, "rb"); + for (int i = 0; i < num_tensors; ++i) { ggml_tensor * tensor = tensors[i]; const size_t size = ggml_nbytes(tensor); const size_t offset = offsets[i]; - size_t actual_size; - void * buf = ggml_cuda_pool_malloc(size, &actual_size); - cuFileRead(cf_handle, buf, size, offset, 0); + void * buf; + CUDA_CHECK(cudaMalloc(&buf, size)); + void * buf_host = malloc(size); + +#ifdef _WIN32 + int ret = _fseeki64(fp, (__int64) offset, SEEK_SET); +#else + int ret = fseek(fp, (long) offset, SEEK_SET); +#endif + GGML_ASSERT(ret == 0); // same + + size_t ret2 = fread(buf_host, size, 1, fp); + if (ret2 != 1) { + fprintf(stderr, "unexpectedly reached end of file"); + exit(1); + } + + cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice); + cudaDeviceSynchronize(); + tensor->data = buf; + free(buf_host); } - return true; } diff --git a/ggml-cuda.h b/ggml-cuda.h index 7485af9f73595..ab2c690b082b1 100644 --- a/ggml-cuda.h +++ b/ggml-cuda.h @@ -14,11 +14,9 @@ void ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tens // TODO: export these with GGML_API void * ggml_cuda_host_malloc(size_t size); void ggml_cuda_host_free(void * ptr); -void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size); -void ggml_cuda_pool_free(void * ptr, size_t size); void ggml_cuda_transform_tensor(struct ggml_tensor * tensor); -bool ggml_cuda_load_data_cufile(const char * fname, struct ggml_tensor ** tensors, int num_tensors, const size_t * offsets); +void ggml_cuda_load_data(const char * fname, struct ggml_tensor ** tensors, int num_tensors, const size_t * offsets); #ifdef __cplusplus } diff --git a/llama.cpp b/llama.cpp index 644b6a57b5787..8312ffedfc8df 100644 --- a/llama.cpp +++ b/llama.cpp @@ -10,7 +10,6 @@ #include "ggml.h" #ifdef GGML_USE_CUBLAS -#include #include "ggml-cuda.h" #endif @@ -1069,28 +1068,7 @@ static void llama_model_load_internal( LLAMA_ASSERT(lt.shards.size() == 1); offsets.emplace_back(lt.shards.at(0).file_off); } - bool cufile_success = ggml_cuda_load_data_cufile(fname.c_str(), tensors.data(), tensors.size(), offsets.data()); - - if (!cufile_success) { - for (llama_load_tensor & lt : ml->tensors_map.tensors) { - if (lt.ggml_tensor->backend != GGML_BACKEND_CUDA) { - continue; - } - size_t actual_size; - void * buf = ggml_cuda_pool_malloc(lt.size, &actual_size); - void * buf_host = ggml_cuda_host_malloc(lt.size); - - llama_file & file = ml->file_loaders.at(lt.shards.at(0).file_idx)->file; - file.seek(lt.shards.at(0).file_off, SEEK_SET); - file.read_raw(buf_host, lt.size); - - cudaMemcpy(buf, buf_host, lt.size, cudaMemcpyHostToDevice); - cudaDeviceSynchronize(); - - lt.ggml_tensor->data = buf; - ggml_cuda_host_free(buf_host); - } - } + ggml_cuda_load_data(fname.c_str(), tensors.data(), tensors.size(), offsets.data()); } #endif // GGML_USE_CUBLAS From 24d5ddf67c2b0cbfa49bd2b29be4cd0bb44fa099 Mon Sep 17 00:00:00 2001 From: JohannesGaessler Date: Fri, 19 May 2023 10:13:20 +0200 Subject: [PATCH 05/34] fixup! GPU weights not in RAM, direct loading with cuFile --- llama-util.h | 6 +++--- llama.cpp | 8 ++++++-- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/llama-util.h b/llama-util.h index 88ec28dca051a..e775f64ccc475 100644 --- a/llama-util.h +++ b/llama-util.h @@ -172,7 +172,7 @@ struct llama_mmap { #ifdef _POSIX_MAPPED_FILES static constexpr bool SUPPORTED = true; - llama_mmap(struct llama_file * file, bool prefetch = true) { + llama_mmap(struct llama_file * file, size_t prefetch = (size_t) -1 /* -1 = max value */) { size = file->size; int fd = fileno(file->fp); int flags = MAP_SHARED; @@ -184,9 +184,9 @@ struct llama_mmap { throw std::runtime_error(format("mmap failed: %s", strerror(errno))); } - if (prefetch) { + if (prefetch > 0) { // Advise the kernel to preload the mapped memory - if (madvise(addr, file->size, MADV_WILLNEED)) { + if (madvise(addr, std::min(file->size, prefetch), MADV_WILLNEED)) { fprintf(stderr, "warning: madvise(.., MADV_WILLNEED) failed: %s\n", strerror(errno)); } diff --git a/llama.cpp b/llama.cpp index 8312ffedfc8df..6909db1ac241e 100644 --- a/llama.cpp +++ b/llama.cpp @@ -679,12 +679,16 @@ struct llama_model_loader { void load_all_data(llama_progress_callback progress_callback, void * progress_callback_user_data, llama_mlock * lmlock) { size_t data_size = 0; + size_t prefetch_size = 0; for (const llama_load_tensor & lt : tensors_map.tensors) { data_size += lt.size; + if (lt.ggml_tensor->backend == GGML_BACKEND_CPU) { + prefetch_size += lt.size; + } } if (use_mmap) { - mapping.reset(new llama_mmap(&file_loaders.at(0)->file, false)); + mapping.reset(new llama_mmap(&file_loaders.at(0)->file, prefetch_size)); if (!lmlock) { // Don't call the callback since the actual loading will be lazy // and we can't measure it. @@ -2317,7 +2321,7 @@ int llama_apply_lora_from_file_internal(struct llama_context * ctx, const char * // maybe this should in llama_model_loader if (model_loader->use_mmap) { - model_loader->mapping.reset(new llama_mmap(&model_loader->file_loaders.at(0)->file, /* prefetch */ false)); + model_loader->mapping.reset(new llama_mmap(&model_loader->file_loaders.at(0)->file, /* prefetch */ 0)); } } From 09d82511d4a5569f1551d79672400a37a94d7bfa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A1s=20Salamon?= Date: Tue, 16 May 2023 16:46:34 +0100 Subject: [PATCH 06/34] define default model path once, sync path with readme (#1366) --- examples/common.h | 2 +- examples/embedding/embedding.cpp | 1 - examples/main/main.cpp | 1 - examples/perplexity/perplexity.cpp | 1 - examples/save-load-state/save-load-state.cpp | 1 - 5 files changed, 1 insertion(+), 5 deletions(-) diff --git a/examples/common.h b/examples/common.h index 717838f06e064..f4e07a2525723 100644 --- a/examples/common.h +++ b/examples/common.h @@ -45,7 +45,7 @@ struct gpt_params { float mirostat_tau = 5.00f; // target entropy float mirostat_eta = 0.10f; // learning rate - std::string model = "models/lamma-7B/ggml-model.bin"; // model path + std::string model = "models/7B/ggml-model.bin"; // model path std::string prompt = ""; std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state std::string input_prefix = ""; // string to prefix user inputs with diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index bb3fd50a9f981..c24f7f8207157 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -6,7 +6,6 @@ int main(int argc, char ** argv) { gpt_params params; - params.model = "models/llama-7B/ggml-model.bin"; if (gpt_params_parse(argc, argv, params) == false) { return 1; diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 8543414dd0fbb..fe1c847a7f490 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -50,7 +50,6 @@ void sigint_handler(int signo) { int main(int argc, char ** argv) { gpt_params params; - params.model = "models/llama-7B/ggml-model.bin"; if (gpt_params_parse(argc, argv, params) == false) { return 1; diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 9212dee5c0868..9d38626cbd4f4 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -116,7 +116,6 @@ void perplexity(llama_context * ctx, const gpt_params & params) { int main(int argc, char ** argv) { gpt_params params; - params.model = "models/llama-7B/ggml-model.bin"; params.n_batch = 512; if (gpt_params_parse(argc, argv, params) == false) { diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index ea0a984d93816..35596957974ca 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -8,7 +8,6 @@ int main(int argc, char ** argv) { gpt_params params; - params.model = "models/llama-7B/ggml-model.bin"; params.seed = 42; params.n_threads = 4; params.repeat_last_n = 64; From 230018d11ccedf5c71c4c4f1729eb7b42b752f49 Mon Sep 17 00:00:00 2001 From: Ilya Kurdyukov <59548320+ilyakurdyukov@users.noreply.github.com> Date: Wed, 17 May 2023 01:36:47 +0700 Subject: [PATCH 07/34] ~7% faster Q5_1 AVX2 code (#1477) --- ggml.c | 39 ++++++++++++++++++++++++++++++--------- 1 file changed, 30 insertions(+), 9 deletions(-) diff --git a/ggml.c b/ggml.c index c089a815e0986..156c0c1eef2d7 100644 --- a/ggml.c +++ b/ggml.c @@ -543,12 +543,7 @@ static inline __m256 sum_i16_pairs_float(const __m256i x) { return _mm256_cvtepi32_ps(summed_pairs); } -// multiply int8_t, add results pairwise twice and return as float vector -static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { - // Get absolute values of x vectors - const __m256i ax = _mm256_sign_epi8(x, x); - // Sign the values of the y vectors - const __m256i sy = _mm256_sign_epi8(y, x); +static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { #if __AVXVNNI__ const __m256i zero = _mm256_setzero_si256(); const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy); @@ -560,6 +555,21 @@ static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { #endif } +// multiply int8_t, add results pairwise twice and return as float vector +static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { +#if __AVXVNNIINT8__ + const __m256i zero = _mm256_setzero_si256(); + const __m256i summed_pairs = _mm256_dpbssd_epi32(zero, x, y); + return _mm256_cvtepi32_ps(summed_pairs); +#else + // Get absolute values of x vectors + const __m256i ax = _mm256_sign_epi8(x, x); + // Sign the values of the y vectors + const __m256i sy = _mm256_sign_epi8(y, x); + return mul_sum_us8_pairs_float(ax, sy); +#endif +} + static inline __m128i packNibbles( __m256i bytes ) { // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh @@ -619,6 +629,17 @@ static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) { return _mm256_cvtepi32_ps(summed_pairs); } +static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) { + const __m128i axl = _mm256_castsi256_si128(ax); + const __m128i axh = _mm256_extractf128_si256(ax, 1); + const __m128i syl = _mm256_castsi256_si128(sy); + const __m128i syh = _mm256_extractf128_si256(sy, 1); + // Perform multiplication and create 16-bit values + const __m128i dotl = _mm_maddubs_epi16(axl, syl); + const __m128i doth = _mm_maddubs_epi16(axh, syh); + return sum_i16_pairs_float(doth, dotl); +} + // multiply int8_t, add results pairwise twice and return as float vector static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) { const __m128i xl = _mm256_castsi256_si128(x); @@ -2434,7 +2455,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * const __m256i bx = bytes_from_nibbles_32(x[i].qs); const __m256i by = _mm256_loadu_si256( (const __m256i *)y[i].qs ); - const __m256 xy = mul_sum_i8_pairs_float(bx, by); + const __m256 xy = mul_sum_us8_pairs_float(bx, by); // Accumulate d0*d1*x*y #if defined(__AVX2__) @@ -2906,7 +2927,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * const __m256 dy = _mm256_broadcast_ss(&y[i].d); const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); - const __m256 q = mul_sum_i8_pairs_float(bx, by); + const __m256 q = mul_sum_us8_pairs_float(bx, by); acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc); } @@ -2940,7 +2961,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * const __m256 dy = _mm256_broadcast_ss(&y[i].d); const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); - const __m256 q = mul_sum_i8_pairs_float(bx, by); + const __m256 q = mul_sum_us8_pairs_float(bx, by); acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc); } From 1af2844e4694437c32afd4d78b949886ff2fabbd Mon Sep 17 00:00:00 2001 From: Tom Jobbins <784313+TheBloke@users.noreply.github.com> Date: Tue, 16 May 2023 23:04:35 +0100 Subject: [PATCH 08/34] convert.py: Support models which are stored in a single pytorch_model.bin (#1469) * Support models in a single pytorch_model.bin * Remove spurious line with typo --- convert.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/convert.py b/convert.py index 8f4f0399e1c52..ece5a02668365 100644 --- a/convert.py +++ b/convert.py @@ -121,7 +121,6 @@ def make_tensors_list() -> List[str]: f'layers.{i}.feed_forward.w1.weight', f'layers.{i}.feed_forward.w2.weight', f'layers.{i}.feed_forward.w3.weight', - f'layers.{i}.atttention_norm.weight', f'layers.{i}.ffn_norm.weight', ] return ret @@ -1055,7 +1054,7 @@ def load_some_model(path: Path) -> ModelPlus: files = list(path.glob("model-00001-of-*.safetensors")) if not files: # Try the PyTorch patterns too, with lower priority - globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt"] + globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt", "pytorch_model.bin" ] files = [file for glob in globs for file in path.glob(glob)] if not files: # Try GGML too, but with lower priority, since if both a non-GGML From d5207bf3073f6357bc7f725f3061f86173c613de Mon Sep 17 00:00:00 2001 From: rankaiyx Date: Wed, 17 May 2023 22:47:58 +0800 Subject: [PATCH 09/34] benchmark-matmul: Print the average of the test results (#1490) --- examples/benchmark/benchmark-matmult.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/examples/benchmark/benchmark-matmult.cpp b/examples/benchmark/benchmark-matmult.cpp index 7d237be02112b..446b8e8fb5ef2 100644 --- a/examples/benchmark/benchmark-matmult.cpp +++ b/examples/benchmark/benchmark-matmult.cpp @@ -211,6 +211,7 @@ int main(int argc, char ** argv) { printf("Iteration;NThreads; SizeX; SizeY; SizeZ; Required_FLOPS; Elapsed_u_Seconds; gigaFLOPS\n"); printf("=====================================================================================\n"); + double gflops_sum = 0; for (int i=0;i Date: Wed, 17 May 2023 22:12:01 +0000 Subject: [PATCH 10/34] Remove unused n_parts parameter (#1509) --- examples/common.cpp | 8 -------- examples/common.h | 1 - examples/quantize-stats/quantize-stats.cpp | 1 - examples/save-load-state/save-load-state.cpp | 1 - llama.cpp | 1 - llama.h | 1 - 6 files changed, 13 deletions(-) diff --git a/examples/common.cpp b/examples/common.cpp index 259880a7cc64f..a6abc4977bc1d 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -321,12 +321,6 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { invalid_param = true; break; } - } else if (arg == "--n-parts") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.n_parts = std::stoi(argv[i]); } else if (arg == "-h" || arg == "--help") { gpt_print_usage(argc, argv, default_params); exit(0); @@ -418,7 +412,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " --no-penalize-nl do not penalize newline token\n"); fprintf(stderr, " --memory-f32 use f32 instead of f16 for memory key+value\n"); fprintf(stderr, " --temp N temperature (default: %.1f)\n", (double)params.temp); - fprintf(stderr, " --n-parts N number of model parts (default: -1 = determine from dimensions)\n"); fprintf(stderr, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); fprintf(stderr, " --perplexity compute perplexity over the prompt\n"); fprintf(stderr, " --keep number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep); @@ -473,7 +466,6 @@ struct llama_context * llama_init_from_gpt_params(const gpt_params & params) { auto lparams = llama_context_default_params(); lparams.n_ctx = params.n_ctx; - lparams.n_parts = params.n_parts; lparams.n_gpu_layers = params.n_gpu_layers; lparams.seed = params.seed; lparams.f16_kv = params.memory_f16; diff --git a/examples/common.h b/examples/common.h index f4e07a2525723..2ad20ba504e6d 100644 --- a/examples/common.h +++ b/examples/common.h @@ -24,7 +24,6 @@ struct gpt_params { int32_t seed = -1; // RNG seed int32_t n_threads = get_num_physical_cores(); int32_t n_predict = -1; // new tokens to predict - int32_t n_parts = -1; // amount of model parts (-1 = determine from model dimensions) int32_t n_ctx = 512; // context size int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS) int32_t n_keep = 0; // number of tokens to keep from initial prompt diff --git a/examples/quantize-stats/quantize-stats.cpp b/examples/quantize-stats/quantize-stats.cpp index 9a2aa7c6474fb..085fdde3caf1e 100644 --- a/examples/quantize-stats/quantize-stats.cpp +++ b/examples/quantize-stats/quantize-stats.cpp @@ -321,7 +321,6 @@ int main(int argc, char ** argv) { auto lparams = llama_context_default_params(); lparams.n_ctx = 256; - lparams.n_parts = 1; lparams.seed = 1; lparams.f16_kv = false; lparams.use_mlock = false; diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 35596957974ca..91f04b6c7bcb2 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -26,7 +26,6 @@ int main(int argc, char ** argv) { auto lparams = llama_context_default_params(); lparams.n_ctx = params.n_ctx; - lparams.n_parts = params.n_parts; lparams.seed = params.seed; lparams.f16_kv = params.memory_f16; lparams.use_mmap = params.use_mmap; diff --git a/llama.cpp b/llama.cpp index 6909db1ac241e..a856dd3a5ec2e 100644 --- a/llama.cpp +++ b/llama.cpp @@ -820,7 +820,6 @@ static bool kv_cache_init( struct llama_context_params llama_context_default_params() { struct llama_context_params result = { /*.n_ctx =*/ 512, - /*.n_parts =*/ -1, /*.gpu_layers =*/ 0, /*.seed =*/ -1, /*.f16_kv =*/ false, diff --git a/llama.h b/llama.h index 21cba8cf61061..f955fa23db048 100644 --- a/llama.h +++ b/llama.h @@ -55,7 +55,6 @@ extern "C" { struct llama_context_params { int n_ctx; // text context - int n_parts; // -1 for default int n_gpu_layers; // number of layers to store in VRAM int seed; // RNG seed, -1 for random From a94b334591536319ebc77f1354aaf97b88d31e6a Mon Sep 17 00:00:00 2001 From: DannyDaemonic Date: Thu, 18 May 2023 10:30:40 -0700 Subject: [PATCH 11/34] Fixes #1511 lambda issue for w64devkit (mingw) (#1513) * Fix for w64devkit and mingw --- examples/main/main.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index fe1c847a7f490..18673ed2e00f2 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -241,7 +241,7 @@ int main(int argc, char ** argv) { sigint_action.sa_flags = 0; sigaction(SIGINT, &sigint_action, NULL); #elif defined (_WIN32) - auto console_ctrl_handler = [](DWORD ctrl_type) -> BOOL { + auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL { return (ctrl_type == CTRL_C_EVENT) ? (sigint_handler(SIGINT), true) : false; }; SetConsoleCtrlHandler(static_cast(console_ctrl_handler), true); From e22541a49e1ffefff49cced180ec8fe92043e91e Mon Sep 17 00:00:00 2001 From: Erik Scholz Date: Thu, 18 May 2023 19:31:01 +0200 Subject: [PATCH 12/34] make kv_f16 the default for api users (#1517) --- llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama.cpp b/llama.cpp index a856dd3a5ec2e..99e34afb5bbf2 100644 --- a/llama.cpp +++ b/llama.cpp @@ -822,7 +822,7 @@ struct llama_context_params llama_context_default_params() { /*.n_ctx =*/ 512, /*.gpu_layers =*/ 0, /*.seed =*/ -1, - /*.f16_kv =*/ false, + /*.f16_kv =*/ true, /*.logits_all =*/ false, /*.vocab_only =*/ false, /*.use_mmap =*/ true, From 6b5776b0a7cdaa23f1dbfef25e0d1688b9233ac7 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 19 May 2023 20:14:51 +0300 Subject: [PATCH 13/34] minor : fix compile warnings --- examples/common.cpp | 4 ++-- examples/common.h | 6 +++--- llama.cpp | 2 +- tests/test-sampling.cpp | 10 ++++++---- 4 files changed, 12 insertions(+), 10 deletions(-) diff --git a/examples/common.cpp b/examples/common.cpp index a6abc4977bc1d..a4fea4af48076 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -749,7 +749,7 @@ bool console_readline(console_state & con_st, std::string & line) { break; } - if (input_char == WEOF || input_char == 0x04 /* Ctrl+D*/) { + if (input_char == (char32_t) WEOF || input_char == 0x04 /* Ctrl+D*/) { end_of_stream = true; break; } @@ -764,7 +764,7 @@ bool console_readline(console_state & con_st, std::string & line) { char32_t code = getchar32(); if (code == '[' || code == 0x1B) { // Discard the rest of the escape sequence - while ((code = getchar32()) != WEOF) { + while ((code = getchar32()) != (char32_t) WEOF) { if ((code >= 'A' && code <= 'Z') || (code >= 'a' && code <= 'z') || code == '~') { break; } diff --git a/examples/common.h b/examples/common.h index 2ad20ba504e6d..2b66382a6a5e0 100644 --- a/examples/common.h +++ b/examples/common.h @@ -44,15 +44,15 @@ struct gpt_params { float mirostat_tau = 5.00f; // target entropy float mirostat_eta = 0.10f; // learning rate - std::string model = "models/7B/ggml-model.bin"; // model path - std::string prompt = ""; + std::string model = "models/7B/ggml-model.bin"; // model path + std::string prompt = ""; std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state std::string input_prefix = ""; // string to prefix user inputs with std::string input_suffix = ""; // string to suffix user inputs with std::vector antiprompt; // string upon seeing which more user input is prompted std::string lora_adapter = ""; // lora adapter path - std::string lora_base = ""; // base model path for the lora adapter + std::string lora_base = ""; // base model path for the lora adapter bool memory_f16 = true; // use f16 instead of f32 for memory kv bool random_prompt = false; // do not randomize prompt if none provided diff --git a/llama.cpp b/llama.cpp index 99e34afb5bbf2..a0ea038dd7d06 100644 --- a/llama.cpp +++ b/llama.cpp @@ -949,7 +949,7 @@ static void llama_model_load_internal( size_t ctx_size; size_t mmapped_size; ml->calc_sizes(&ctx_size, &mmapped_size); - fprintf(stderr, "%s: ggml ctx size = %6.2f KB\n", __func__, ctx_size/1024.0); + fprintf(stderr, "%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/1024.0/1024.0); // create the ggml context { diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index 9174c1e373915..ebfc17c1822bf 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -1,14 +1,16 @@ -#include "llama.h" #include "ggml.h" -#include -#include +#include "llama.h" + +#ifdef NDEBUG +#undef NDEBUG +#endif + #include #include #include #include #include - void dump(const llama_token_data_array * candidates) { for (size_t i = 0; i < candidates->size; i++) { printf("%d: %f (%f)\n", candidates->data[i].id, candidates->data[i].p, candidates->data[i].logit); From 75c017fc5ac6ba81fc99206cf9aa4d2acfa75100 Mon Sep 17 00:00:00 2001 From: David Kennedy Date: Fri, 19 May 2023 13:16:30 -0400 Subject: [PATCH 14/34] readme : adds WizardLM to the list of supported models (#1485) --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 1d84a5e6d0664..6a67765aa5daf 100644 --- a/README.md +++ b/README.md @@ -80,6 +80,7 @@ as the main playground for developing new features for the [ggml](https://github - [X] [Koala](https://bair.berkeley.edu/blog/2023/04/03/koala/) - [X] [OpenBuddy 🐶 (Multilingual)](https://github.com/OpenBuddy/OpenBuddy) - [X] [Pygmalion 7B / Metharme 7B](#using-pygmalion-7b--metharme-7b) +- [X] [WizardLM](https://github.com/nlpxucan/WizardLM) **Bindings:** From c51c64a8fe5a5e4a24757f8ac1daf78a2853b0bb Mon Sep 17 00:00:00 2001 From: Jason McCartney Date: Fri, 19 May 2023 10:24:59 -0700 Subject: [PATCH 15/34] main : make reverse prompt option act as a stop token in non-interactive mode (#1032) * Make reverse prompt option act as a stop token in non-interactive scenarios * Making requested review changes * Update gpt_params_parse and fix a merge error * Revert "Update gpt_params_parse and fix a merge error" This reverts commit 2bb2ff1748513591ad45b175a75ed1d8089d84c8. * Update gpt_params_parse and fix a merge error take 2 --- examples/common.cpp | 6 +++--- examples/main/main.cpp | 26 ++++++++++++++++++-------- 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/examples/common.cpp b/examples/common.cpp index a4fea4af48076..e89df537eca49 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -351,7 +351,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { } if (params.prompt_cache_all && (params.interactive || params.interactive_first || - params.instruct || params.antiprompt.size())) { + params.instruct)) { fprintf(stderr, "error: --prompt-cache-all not supported in interactive mode yet\n"); gpt_print_usage(argc, argv, default_params); exit(1); @@ -373,8 +373,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " -ins, --instruct run in instruction mode (use with Alpaca models)\n"); fprintf(stderr, " --multiline-input allows you to write or paste multiple lines without ending each in '\\'\n"); fprintf(stderr, " -r PROMPT, --reverse-prompt PROMPT\n"); - fprintf(stderr, " run in interactive mode and poll user input upon seeing PROMPT (can be\n"); - fprintf(stderr, " specified more than once for multiple prompts).\n"); + fprintf(stderr, " halt generation at PROMPT, return control in interactive mode\n"); + fprintf(stderr, " (can be specified more than once for multiple prompts).\n"); fprintf(stderr, " --color colorise output to distinguish prompt and user input from generations\n"); fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for < 0)\n"); fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 18673ed2e00f2..4d886f8defe56 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -208,8 +208,8 @@ int main(int argc, char ** argv) { params.antiprompt.push_back("### Instruction:\n\n"); } - // enable interactive mode if reverse prompt or interactive start is specified - if (params.antiprompt.size() != 0 || params.interactive_first) { + // enable interactive mode if interactive start is specified + if (params.interactive_first) { params.interactive = true; } @@ -305,7 +305,7 @@ int main(int argc, char ** argv) { std::vector embd; - while (n_remain != 0 || params.interactive) { + while ((n_remain != 0 && !is_antiprompt) || params.interactive) { // predict if (embd.size() > 0) { // infinite text generation via context swapping @@ -503,9 +503,8 @@ int main(int argc, char ** argv) { console_set_color(con_st, CONSOLE_COLOR_DEFAULT); } - // in interactive mode, and not currently processing queued inputs; - // check if we should prompt the user for more - if (params.interactive && (int) embd_inp.size() <= n_consumed) { + // if not currently processing queued inputs; + if ((int) embd_inp.size() <= n_consumed) { // check for reverse prompt if (params.antiprompt.size()) { @@ -516,10 +515,21 @@ int main(int argc, char ** argv) { is_antiprompt = false; // Check if each of the reverse prompts appears at the end of the output. + // If we're not running interactively, the reverse prompt might be tokenized with some following characters + // so we'll compensate for that by widening the search window a bit. for (std::string & antiprompt : params.antiprompt) { - if (last_output.find(antiprompt.c_str(), last_output.length() - antiprompt.length(), antiprompt.length()) != std::string::npos) { - is_interacting = true; + size_t extra_padding = params.interactive ? 0 : 2; + size_t search_start_pos = last_output.length() > static_cast(antiprompt.length() + extra_padding) + ? last_output.length() - static_cast(antiprompt.length() + extra_padding) + : 0; + + if (last_output.find(antiprompt.c_str(), search_start_pos) != std::string::npos) { + if (params.interactive) { + is_interacting = true; + console_set_color(con_st, CONSOLE_COLOR_USER_INPUT); + } is_antiprompt = true; + fflush(stdout); break; } } From 0226d491af156f9d40749f219eb93816f6c6cc94 Mon Sep 17 00:00:00 2001 From: Evan Jones Date: Fri, 19 May 2023 13:39:51 -0400 Subject: [PATCH 16/34] examples : add persistent chat (#1495) * examples : add persistent chat * examples : fix whitespace --------- Co-authored-by: Georgi Gerganov --- examples/chat-persistent.sh | 151 ++++++++++++++++++++++++++++++++++++ 1 file changed, 151 insertions(+) create mode 100755 examples/chat-persistent.sh diff --git a/examples/chat-persistent.sh b/examples/chat-persistent.sh new file mode 100755 index 0000000000000..b32284b49d80b --- /dev/null +++ b/examples/chat-persistent.sh @@ -0,0 +1,151 @@ +#!/bin/bash + +set -euo pipefail + +cd "$(dirname "$0")/.." || exit + +if [[ -z "${PROMPT_CACHE_FILE+x}" || -z "${CHAT_SAVE_DIR+x}" ]]; then + echo >&2 "error: PROMPT_CACHE_FILE and CHAT_SAVE_DIR must be provided" + exit 1 +fi + +MODEL="${MODEL:-./models/13B/ggml-model-q4_0.bin}" +PROMPT_TEMPLATE="${PROMPT_TEMPLATE:-./prompts/chat.txt}" +USER_NAME="${USER_NAME:-User}" +AI_NAME="${AI_NAME:-ChatLLaMa}" +DATE_TIME="$(date +%H:%M)" +DATE_YEAR="$(date +%Y)" + +LOG="${CHAT_SAVE_DIR}/main.log" +LOG_BG="${CHAT_SAVE_DIR}/main-bg.log" +CUR_PROMPT_FILE="${CHAT_SAVE_DIR}/current-prompt.txt" +CUR_PROMPT_CACHE="${CHAT_SAVE_DIR}/current-cache.bin" +NEXT_PROMPT_FILE="${CHAT_SAVE_DIR}/next-prompt.txt" +NEXT_PROMPT_CACHE="${CHAT_SAVE_DIR}/next-cache.bin" + +SESSION_SIZE_MSG_PATTERN='main: session file matches \d+ / \d+' +SAMPLE_TIME_MSG_PATTERN='sample time =\s+\d+.\d+ ms /\s+\d+' +SED_DELETE_MESSAGES="/^(${USER_NAME}:|${AI_NAME}:|\\.\\.\\.)/,\$d" + +CTX_SIZE=2048 +CTX_ROTATE_POINT=$((CTX_SIZE * 3 / 5)) # REVIEW +OPTS=(--model "$MODEL" --ctx_size "$CTX_SIZE" --repeat_last_n 256 "$@") + +# An unbuffered `tail -c+N` +skip_bytes() { + LANG=C IFS= read -r -n "$1" -d '' c + while LANG=C IFS= read -r -n 1 -d '' c; do + printf '%s' "$c" + done +} + +mkdir -p "$CHAT_SAVE_DIR" +echo >"$LOG" +trap "tail -n100 ${LOG}" EXIT + +if [[ ! -e "$CUR_PROMPT_FILE" ]]; then + sed -e "s/\[\[USER_NAME\]\]/${USER_NAME}/g" \ + -e "s/\[\[AI_NAME\]\]/${AI_NAME}/g" \ + -e "s/\[\[DATE_TIME\]\]/${DATE_TIME}/g" \ + -e "s/\[\[DATE_YEAR\]\]/${DATE_YEAR}/g" \ + "$PROMPT_TEMPLATE" >"$CUR_PROMPT_FILE" +fi + +if [[ ! -e "$NEXT_PROMPT_FILE" ]]; then + sed -r "$SED_DELETE_MESSAGES" "$CUR_PROMPT_FILE" >"$NEXT_PROMPT_FILE" +fi + +if [[ "$(tail -c4 "$NEXT_PROMPT_FILE")" != "..." ]]; then + echo '...' >>"$NEXT_PROMPT_FILE" +fi + +if [[ ! -e "$PROMPT_CACHE_FILE" ]]; then + echo 'Prompt cache does not exist, building...' + # Default batch_size to 8 here for better user feedback during initial prompt processing + ./main 2>>"$LOG" \ + --batch_size 8 \ + "${OPTS[@]}" \ + --prompt-cache "$PROMPT_CACHE_FILE" \ + --file "$CUR_PROMPT_FILE" \ + --n_predict 1 + echo + echo 'Done!' +fi + +if [[ ! -e "$CUR_PROMPT_CACHE" ]]; then + cp "$PROMPT_CACHE_FILE" "$CUR_PROMPT_CACHE" +fi +if [[ ! -e "$NEXT_PROMPT_CACHE" ]]; then + cp "$PROMPT_CACHE_FILE" "$NEXT_PROMPT_CACHE" +fi + +printf '%s ' "$(< "$CUR_PROMPT_FILE")" +n_tokens=0 + +while read -e line; do + # Limit generation to remaining context, with a buffer and estimating 2 chars/token for input + n_predict=$((CTX_SIZE - n_tokens - ${#line} / 2 - 32)) + + # Swap prompts when we're about to run out of context + if ((n_predict <= 0)); then + wait # for background main (below) to finish with next prompt + mv "$NEXT_PROMPT_FILE" "$CUR_PROMPT_FILE" + mv "$NEXT_PROMPT_CACHE" "$CUR_PROMPT_CACHE" + + sed -r "$SED_DELETE_MESSAGES" "$CUR_PROMPT_FILE" >"$NEXT_PROMPT_FILE" + echo '...' >>"$NEXT_PROMPT_FILE" + cp "$PROMPT_CACHE_FILE" "$NEXT_PROMPT_CACHE" + + n_tokens=0 + n_predict=$((CTX_SIZE / 2)) + fi + + echo " ${line}" >>"$CUR_PROMPT_FILE" + if ((n_tokens > CTX_ROTATE_POINT)); then + echo " ${line}" >>"$NEXT_PROMPT_FILE" + fi + + n_prompt_len_pre=$(($(wc -c <"$CUR_PROMPT_FILE"))) + + printf '%s: ' "$AI_NAME" >>"$CUR_PROMPT_FILE" + + ./main 2>>"$LOG" "${OPTS[@]}" \ + --prompt-cache "$CUR_PROMPT_CACHE" \ + --prompt-cache-all \ + --file "$CUR_PROMPT_FILE" \ + --reverse-prompt "${USER_NAME}:" \ + --n_predict "$n_predict" | + skip_bytes 1 | # skip BOS token added by ./main + tee "$CUR_PROMPT_FILE.tmp" | # save prompt + generation to tmp file + skip_bytes "$n_prompt_len_pre" # print generation + + mv "$CUR_PROMPT_FILE.tmp" "$CUR_PROMPT_FILE" + + # if we hit n_predict instead of reverse-prompt, we need to add the prompt + if [[ "$(tail -n1 "$CUR_PROMPT_FILE")" != "${USER_NAME}:" ]]; then + printf '\n%s:' "$USER_NAME" + printf '\n%s:' "$USER_NAME" >> "$CUR_PROMPT_FILE" + fi + + printf ' ' + + # HACK get num tokens from debug message + # TODO get both messages in one go + if ! session_size_msg="$(tail -n30 "$LOG" | grep -oE "$SESSION_SIZE_MSG_PATTERN")" || + ! sample_time_msg="$( tail -n10 "$LOG" | grep -oE "$SAMPLE_TIME_MSG_PATTERN")"; then + echo >&2 "Couldn't get number of tokens from ./main output!" + exit 1 + fi + + n_tokens=$(($(cut -d/ -f2 <<<"$session_size_msg") + $(cut -d/ -f2 <<<"$sample_time_msg"))) + + if ((n_tokens > CTX_ROTATE_POINT)); then + tail -c+$((n_prompt_len_pre + 1)) "$CUR_PROMPT_FILE" >>"$NEXT_PROMPT_FILE" + fi + + # Update cache for next prompt in background, ideally during user input + ./main >>"$LOG_BG" 2>&1 "${OPTS[@]}" \ + --prompt-cache "$NEXT_PROMPT_CACHE" \ + --file "$NEXT_PROMPT_FILE" \ + --n_predict 1 & +done From 9fd81872150d3758d14e90d5371c747ed10d872f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 19 May 2023 21:17:28 +0300 Subject: [PATCH 17/34] tests : add missing header --- tests/test-sampling.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index ebfc17c1822bf..0e675127f9970 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -5,6 +5,7 @@ #undef NDEBUG #endif +#include #include #include #include From 211aa6aff0b4daee1dd23615e3bc5075f1bd8aec Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 19 May 2023 22:17:18 +0300 Subject: [PATCH 18/34] ggml : use F16 instead of F32 in Q4_0, Q4_1, Q8_0 (#1508) * ggml : use F16 instead of F32 in Q4_0, Q4_1 and Q8_0 * llama : bump LLAMA_FILE_VERSION to 3 * cuda : update Q4 and Q8 dequantize kernels * ggml : fix AVX dot products * readme : update performance table + hot topics --- README.md | 21 +++---- ggml-cuda.cu | 14 ++--- ggml.c | 154 +++++++++++++++++++++++++-------------------------- ggml.h | 2 +- llama.cpp | 18 +++++- llama.h | 2 +- 6 files changed, 109 insertions(+), 102 deletions(-) diff --git a/README.md b/README.md index 6a67765aa5daf..762f4aa0349e5 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,7 @@ Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++ **Hot topics:** +- Quantization formats `Q4` and `Q8` have changed again (19 May) - [(info)](https://github.com/ggerganov/llama.cpp/pull/1508) - Quantization formats `Q4` and `Q5` have changed - requantize any old models [(info)](https://github.com/ggerganov/llama.cpp/pull/1405) - [Roadmap May 2023](https://github.com/ggerganov/llama.cpp/discussions/1220) @@ -334,16 +335,16 @@ Several quantization methods are supported. They differ in the resulting model d | Model | Measure | F16 | Q4_0 | Q4_1 | Q5_0 | Q5_1 | Q8_0 | |------:|--------------|-------:|-------:|-------:|-------:|-------:|-------:| -| 7B | perplexity | 5.9066 | 6.1565 | 6.0910 | 5.9862 | 5.9481 | 5.9069 | -| 7B | file size | 13.0G | 4.0G | 4.8G | 4.4G | 4.8G | 7.1G | -| 7B | ms/tok @ 4th | 128 | 50 | 54 | 75 | 83 | 75 | -| 7B | ms/tok @ 8th | 123 | 44 | 52 | 53 | 58 | 72 | -| 7B | bits/weight | 16.0 | 5.0 | 6.0 | 5.5 | 6.0 | 9.0 | -| 13B | perplexity | 5.2543 | 5.3860 | 5.3607 | 5.2856 | 5.2706 | 5.2548 | -| 13B | file size | 25.0G | 7.6G | 9.1G | 8.4G | 9.1G | 14G | -| 13B | ms/tok @ 4th | 239 | 93 | 101 | 150 | 164 | 141 | -| 13B | ms/tok @ 8th | 240 | 81 | 96 | 96 | 104 | 136 | -| 13B | bits/weight | 16.0 | 5.0 | 6.0 | 5.5 | 6.0 | 9.0 | +| 7B | perplexity | 5.9066 | 6.1565 | 6.0912 | 5.9862 | 5.9481 | 5.9070 | +| 7B | file size | 13.0G | 3.5G | 3.9G | 4.3G | 4.7G | 6.7G | +| 7B | ms/tok @ 4th | 127 | 55 | 54 | 76 | 83 | 72 | +| 7B | ms/tok @ 8th | 122 | 43 | 45 | 52 | 56 | 67 | +| 7B | bits/weight | 16.0 | 4.5 | 5.0 | 5.5 | 6.0 | 8.5 | +| 13B | perplexity | 5.2543 | 5.3860 | 5.3608 | 5.2856 | 5.2706 | 5.2548 | +| 13B | file size | 25.0G | 6.8G | 7.6G | 8.3G | 9.1G | 13G | +| 13B | ms/tok @ 4th | - | 103 | 105 | 148 | 160 | 131 | +| 13B | ms/tok @ 8th | - | 73 | 82 | 98 | 105 | 128 | +| 13B | bits/weight | 16.0 | 4.5 | 5.0 | 5.5 | 6.0 | 8.5 | ### Perplexity (measuring model quality) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 6e6e12aaac555..a1fec0fd735a0 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -53,19 +53,19 @@ typedef void (*dequantize_mul_mat_vec_cuda_t)(const void * vx, const float * y, #define QK4_0 32 #define QR4_0 2 typedef struct { - float d; // delta + half d; // delta uint8_t qs[QK4_0 / 2]; // nibbles / quants } block_q4_0; -static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding"); +static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding"); #define QK4_1 32 #define QR4_1 2 typedef struct { - float d; // delta - float m; // min + half d; // delta + half m; // min uint8_t qs[QK4_1 / 2]; // nibbles / quants } block_q4_1; -static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding"); +static_assert(sizeof(block_q4_1) == sizeof(ggml_fp16_t) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding"); #define QK5_0 32 #define QR5_0 2 @@ -89,10 +89,10 @@ static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + #define QK8_0 32 #define QR8_0 1 typedef struct { - float d; // delta + half d; // delta int8_t qs[QK8_0]; // quants } block_q8_0; -static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding"); +static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding"); #define CUDA_MUL_BLOCK_SIZE 256 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256 diff --git a/ggml.c b/ggml.c index 156c0c1eef2d7..f08c541384315 100644 --- a/ggml.c +++ b/ggml.c @@ -769,18 +769,18 @@ int32x4_t vcvtnq_s32_f32(float32x4_t v) { #define QK4_0 32 typedef struct { - float d; // delta + ggml_fp16_t d; // delta uint8_t qs[QK4_0 / 2]; // nibbles / quants } block_q4_0; -static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding"); +static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding"); #define QK4_1 32 typedef struct { - float d; // delta - float m; // min + ggml_fp16_t d; // delta + ggml_fp16_t m; // min uint8_t qs[QK4_1 / 2]; // nibbles / quants } block_q4_1; -static_assert(sizeof(block_q4_1) == 2 * sizeof(float) + QK4_1 / 2, "wrong q4_1 block size/padding"); +static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_fp16_t) + QK4_1 / 2, "wrong q4_1 block size/padding"); #define QK5_0 32 typedef struct { @@ -801,16 +801,16 @@ static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + #define QK8_0 32 typedef struct { - float d; // delta - int8_t qs[QK8_0]; // quants + ggml_fp16_t d; // delta + int8_t qs[QK8_0]; // quants } block_q8_0; -static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding"); +static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding"); #define QK8_1 32 typedef struct { - float d; // delta - float s; // d * sum(qs[i]) - int8_t qs[QK8_1]; // quants + float d; // delta + float s; // d * sum(qs[i]) + int8_t qs[QK8_1]; // quants } block_q8_1; static_assert(sizeof(block_q8_1) == 2*sizeof(float) + QK8_1, "wrong q8_1 block size/padding"); @@ -837,7 +837,7 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r const float d = max / -8; const float id = d ? 1.0f/d : 0.0f; - y[i].d = d; + y[i].d = GGML_FP32_TO_FP16(d); for (int j = 0; j < qk/2; ++j) { const float x0 = x[i*qk + 0 + j]*id; @@ -877,8 +877,8 @@ static void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * r const float d = (max - min) / ((1 << 4) - 1); const float id = d ? 1.0f/d : 0.0f; - y[i].d = d; - y[i].m = min; + y[i].d = GGML_FP32_TO_FP16(d); + y[i].m = GGML_FP32_TO_FP16(min); for (int j = 0; j < qk/2; ++j) { const float x0 = (x[i*qk + 0 + j] - min)*id; @@ -1009,7 +1009,7 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r const float d = amax / ((1 << 7) - 1); const float id = d ? 1.0f/d : 0.0f; - y[i].d = d; + y[i].d = GGML_FP32_TO_FP16(d); for (int j = 0; j < QK8_0; ++j) { const float x0 = x[i*QK8_0 + j]*id; @@ -1044,7 +1044,7 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int const float d = amax / ((1 << 7) - 1); const float id = d ? 1.0f/d : 0.0f; - y[i].d = d; + y[i].d = GGML_FP32_TO_FP16(d); for (int j = 0; j < 8; j++) { const float32x4_t v = vmulq_n_f32(srcv[j], id); @@ -1079,7 +1079,7 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int // Quantize these floats const float d = maxScalar / 127.f; - y[i].d = d; + y[i].d = GGML_FP32_TO_FP16(d); const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; const __m256 mul = _mm256_set1_ps( id ); @@ -1178,7 +1178,7 @@ static void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * r sum += y[i].qs[QK8_1/2 + j]; } - y[i].s = d * sum; + y[i].s = sum*d; } } @@ -1330,7 +1330,7 @@ static void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict const int nb = k / qk; for (int i = 0; i < nb; i++) { - const float d = x[i].d; + const float d = GGML_FP16_TO_FP32(x[i].d); for (int j = 0; j < qk/2; ++j) { const int x0 = (x[i].qs[j] & 0x0F) - 8; @@ -1350,8 +1350,8 @@ static void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict const int nb = k / qk; for (int i = 0; i < nb; i++) { - const float d = x[i].d; - const float m = x[i].m; + const float d = GGML_FP16_TO_FP32(x[i].d); + const float m = GGML_FP16_TO_FP32(x[i].m); for (int j = 0; j < qk/2; ++j) { const int x0 = (x[i].qs[j] & 0x0F); @@ -1426,7 +1426,7 @@ static void dequantize_row_q8_0(const void * restrict vx, float * restrict y, in const block_q8_0 * restrict x = vx; for (int i = 0; i < nb; i++) { - const float d = x[i].d; + const float d = GGML_FP16_TO_FP32(x[i].d); for (int j = 0; j < qk; ++j) { y[i*qk + j] = x[i].qs[j]*d; @@ -1690,8 +1690,9 @@ quantize_fns_t ggml_internal_get_quantize_fn(size_t i) { static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) { float tmp[8]; - for (int i = 0; i < 8; i++) + for (int i = 0; i < 8; i++) { tmp[i] = GGML_FP16_TO_FP32(x[i]); + } return _mm256_loadu_ps(tmp); } @@ -2111,8 +2112,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * const block_q8_0 * restrict y0 = &y[i + 0]; const block_q8_0 * restrict y1 = &y[i + 1]; - const uint8x16_t m4b = vdupq_n_u8(0x0F); - const int8x16_t s8b = vdupq_n_s8(0x8); + const uint8x16_t m4b = vdupq_n_u8(0x0F); + const int8x16_t s8b = vdupq_n_s8(0x8); const uint8x16_t v0_0 = vld1q_u8(x0->qs); const uint8x16_t v0_1 = vld1q_u8(x1->qs); @@ -2140,8 +2141,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h); const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h); - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d); + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); #else const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0l)); const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l)); @@ -2158,8 +2159,8 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d); + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); #endif } @@ -2171,7 +2172,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * // Main loop for (int i = 0; i < nb; ++i) { /* Compute combined scale for the block */ - const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) ); + const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) ); __m256i bx = bytes_from_nibbles_32(x[i].qs); @@ -2195,7 +2196,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * // Main loop for (int i = 0; i < nb; ++i) { // Compute combined scale for the block - const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) ); + const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) ); const __m128i lowMask = _mm_set1_epi8(0xF); const __m128i off = _mm_set1_epi8(8); @@ -2237,7 +2238,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * _mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0); // Compute combined scale for the block 0 and 1 - const __m128 d_0_1 = _mm_mul_ps( _mm_set1_ps( x[0].d ), _mm_set1_ps( y[0].d ) ); + const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[0].d) * GGML_FP16_TO_FP32(y[0].d) ); const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[0].qs); @@ -2255,7 +2256,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * _mm_prefetch(&y[1] + sizeof(block_q8_0), _MM_HINT_T0); // Compute combined scale for the block 2 and 3 - const __m128 d_2_3 = _mm_mul_ps( _mm_set1_ps( x[1].d ), _mm_set1_ps( y[1].d ) ); + const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[1].d) * GGML_FP16_TO_FP32(y[1].d) ); const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[1].qs); @@ -2288,7 +2289,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * _mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0); // Compute combined scale for the block 0 and 1 - const __m128 d_0_1 = _mm_mul_ps( _mm_set1_ps( x[i].d ), _mm_set1_ps( y[i].d ) ); + const __m128 d_0_1 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d) ); const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[i].qs); @@ -2306,7 +2307,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * _mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0); // Compute combined scale for the block 2 and 3 - const __m128 d_2_3 = _mm_mul_ps( _mm_set1_ps( x[i + 1].d ), _mm_set1_ps( y[i + 1].d ) ); + const __m128 d_2_3 = _mm_set1_ps( GGML_FP16_TO_FP32(x[i + 1].d) * GGML_FP16_TO_FP32(y[i + 1].d) ); const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[i + 1].qs); @@ -2354,7 +2355,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]); } - sumf += (x[i].d*y[i].d)*sumi; + sumf += sumi*GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d); } *s = sumf; @@ -2384,7 +2385,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * const block_q8_1 * restrict y0 = &y[i + 0]; const block_q8_1 * restrict y1 = &y[i + 1]; - summs += x0->m * y0->s + x1->m * y1->s; + summs += GGML_FP16_TO_FP32(x0->m) * y0->s + GGML_FP16_TO_FP32(x1->m) * y1->s; const uint8x16_t m4b = vdupq_n_u8(0x0F); @@ -2408,8 +2409,8 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h); const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h); - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d); + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), GGML_FP16_TO_FP32(x0->d)*y0->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), GGML_FP16_TO_FP32(x1->d)*y1->d); #else const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0l)); const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0l)); @@ -2426,8 +2427,8 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d); + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d); #endif } @@ -2440,13 +2441,13 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * // Main loop for (int i = 0; i < nb; ++i) { - const float * d0 = &x[i].d; - const float * d1 = &y[i].d; + const float d0 = GGML_FP16_TO_FP32(x[i].d); + const float d1 = y[i].d; - summs += x[i].m * y[i].s; + summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s; - const __m256 d0v = _mm256_broadcast_ss( d0 ); - const __m256 d1v = _mm256_broadcast_ss( d1 ); + const __m256 d0v = _mm256_set1_ps( d0 ); + const __m256 d1v = _mm256_set1_ps( d1 ); // Compute combined scales const __m256 d0d1 = _mm256_mul_ps( d0v, d1v ); @@ -2480,7 +2481,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]); } - sumf += (x[i].d*y[i].d)*sumi + x[i].m*y[i].s; + sumf += (GGML_FP16_TO_FP32(x[i]).d*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s; } *s = sumf; @@ -2556,16 +2557,13 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * const int8x16_t v1_1l = vld1q_s8(y1->qs); const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); - const float x0d = GGML_FP16_TO_FP32(x0->d); - const float x1d = GGML_FP16_TO_FP32(x1->d); - #if defined(__ARM_FEATURE_DOTPROD) sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l), - vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), x0d*y0->d); + vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l), - vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), x1d*y1->d); + vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); #else const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l)); const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l)); @@ -2582,8 +2580,8 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1d*y1->d); + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); #endif } @@ -2658,7 +2656,7 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * // Main loop for (int i = 0; i < nb; i++) { /* Compute combined scale for the block */ - const __m256 d = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)), _mm256_broadcast_ss(&y[i].d)); + const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d)); __m256i bx = bytes_from_nibbles_32(x[i].qs); __m256i bxhi = bytes_from_bits_32(x[i].qh); @@ -2682,7 +2680,7 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * // Main loop for (int i = 0; i < nb; i++) { /* Compute combined scale for the block */ - const __m256 d = _mm256_mul_ps(_mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d)), _mm256_broadcast_ss(&y[i].d)); + const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d)); __m256i bx = bytes_from_nibbles_32(x[i].qs); const __m256i bxhi = bytes_from_bits_32(x[i].qh); @@ -2725,7 +2723,7 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void * sumi += (x0 * y[i].qs[j]) + (x1 * y[i].qs[j + qk/2]); } - sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi; + sumf += (GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)) * sumi; } *s = sumf; @@ -2807,16 +2805,13 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * const int8x16_t v1_1l = vld1q_s8(y1->qs); const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); - const float x0d = GGML_FP16_TO_FP32(x0->d); - const float x1d = GGML_FP16_TO_FP32(x1->d); - #if defined(__ARM_FEATURE_DOTPROD) sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l), - vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), x0d*y0->d); + vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), GGML_FP16_TO_FP32(x0->d)*y0->d); sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l), - vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), x1d*y1->d); + vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), GGML_FP16_TO_FP32(x1->d)*y1->d); #else const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l)); const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l)); @@ -2833,8 +2828,8 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1d*y1->d); + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), GGML_FP16_TO_FP32(x0->d)*y0->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), GGML_FP16_TO_FP32(x1->d)*y1->d); #endif } @@ -2894,15 +2889,14 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * const v128_t v1hl = wasm_i16x8_extend_low_i8x16 (v1h); const v128_t v1hh = wasm_i16x8_extend_high_i8x16(v1h); - const float x0d = GGML_FP16_TO_FP32(x0->d); - // dot product - sumv = wasm_f32x4_add(sumv, wasm_f32x4_mul(wasm_f32x4_convert_i32x4( - wasm_i32x4_add( - wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll), - wasm_i32x4_dot_i16x8(v0lfh, v1lh)), - wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl), - wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), wasm_f32x4_splat(x0d*y0->d))); + sumv = wasm_f32x4_add(sumv, + wasm_f32x4_mul(wasm_f32x4_convert_i32x4(wasm_i32x4_add( + wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0lfl, v1ll), + wasm_i32x4_dot_i16x8(v0lfh, v1lh)), + wasm_i32x4_add(wasm_i32x4_dot_i16x8(v0hfl, v1hl), + wasm_i32x4_dot_i16x8(v0hfh, v1hh)))), + wasm_f32x4_splat(GGML_FP16_TO_FP32(x0->d) * y0->d)); } *s = wasm_f32x4_extract_lane(sumv, 0) + wasm_f32x4_extract_lane(sumv, 1) + @@ -2924,7 +2918,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * bxhi = _mm256_and_si256(bxhi, _mm256_set1_epi8(0x10)); bx = _mm256_or_si256(bx, bxhi); - const __m256 dy = _mm256_broadcast_ss(&y[i].d); + const __m256 dy = _mm256_set1_ps(y[i].d); const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); const __m256 q = mul_sum_us8_pairs_float(bx, by); @@ -2958,7 +2952,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void * bxh = _mm_or_si128(bxh, bxhih); bx = _mm256_set_m128i(bxh, bxl); - const __m256 dy = _mm256_broadcast_ss(&y[i].d); + const __m256 dy = _mm256_set1_ps(y[i].d); const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); const __m256 q = mul_sum_us8_pairs_float(bx, by); @@ -3028,11 +3022,11 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * #if defined(__ARM_FEATURE_DOTPROD) sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32( vdotq_s32(vdupq_n_s32(0), x0_0, y0_0), - vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), x0->d*y0->d); + vdotq_s32(vdupq_n_s32(0), x0_1, y0_1))), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32( vdotq_s32(vdupq_n_s32(0), x1_0, y1_0), - vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), x1->d*y1->d); + vdotq_s32(vdupq_n_s32(0), x1_1, y1_1))), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); #else const int16x8_t p0_0 = vmull_s8(vget_low_s8 (x0_0), vget_low_s8 (y0_0)); @@ -3050,8 +3044,8 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * const int32x4_t p2 = vaddq_s32(vpaddlq_s16(p1_0), vpaddlq_s16(p1_1)); const int32x4_t p3 = vaddq_s32(vpaddlq_s16(p1_2), vpaddlq_s16(p1_3)); - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), x0->d*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), x1->d*y1->d); + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(p0, p1)), GGML_FP16_TO_FP32(x0->d)*GGML_FP16_TO_FP32(y0->d)); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(p2, p3)), GGML_FP16_TO_FP32(x1->d)*GGML_FP16_TO_FP32(y1->d)); #endif } @@ -3063,7 +3057,7 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * // Main loop for (int i = 0; i < nb; ++i) { // Compute combined scale for the block - const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) ); + const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d)); __m256i bx = _mm256_loadu_si256((const __m256i *)x[i].qs); __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); @@ -3089,7 +3083,7 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * sumi += x[i].qs[j]*y[i].qs[j]; } - sumf += (x[i].d*y[i].d)*sumi; + sumf += sumi*(GGML_FP16_TO_FP32(x[i].d)*GGML_FP16_TO_FP32(y[i].d)); } *s = sumf; diff --git a/ggml.h b/ggml.h index 255541d0257e3..dce5ca1e7cb61 100644 --- a/ggml.h +++ b/ggml.h @@ -190,7 +190,7 @@ #define GGML_FILE_MAGIC 0x67676d6c // "ggml" #define GGML_FILE_VERSION 1 -#define GGML_QNT_VERSION 1 // bump this on quantization format changes +#define GGML_QNT_VERSION 2 // bump this on quantization format changes #define GGML_QNT_VERSION_FACTOR 1000 // do not change this #define GGML_MAX_DIMS 4 diff --git a/llama.cpp b/llama.cpp index a0ea038dd7d06..e042e22d5a4b8 100644 --- a/llama.cpp +++ b/llama.cpp @@ -406,6 +406,7 @@ enum llama_file_version { LLAMA_FILE_VERSION_GGMF_V1, // added version field and scores in vocab LLAMA_FILE_VERSION_GGJT_V1, // added padding LLAMA_FILE_VERSION_GGJT_V2, // changed quantization format + LLAMA_FILE_VERSION_GGJT_V3, // changed Q4 and Q8 quantization format }; struct llama_file_loader { @@ -438,6 +439,8 @@ struct llama_file_loader { file_version = LLAMA_FILE_VERSION_GGJT_V1; } else if (magic == 'ggjt' && version == 2) { file_version = LLAMA_FILE_VERSION_GGJT_V2; + } else if (magic == 'ggjt' && version == 3) { + file_version = LLAMA_FILE_VERSION_GGJT_V3; } else { throw format("unknown (magic, version) combination: %08x, %08x; is this really a GGML file?", magic, version); @@ -852,7 +855,8 @@ static const char *llama_file_version_name(llama_file_version version) { case LLAMA_FILE_VERSION_GGML: return "'ggml' (old version with low tokenizer quality and no mmap support)"; case LLAMA_FILE_VERSION_GGMF_V1: return "ggmf v1 (old version with no mmap support)"; case LLAMA_FILE_VERSION_GGJT_V1: return "ggjt v1 (pre #1405)"; - case LLAMA_FILE_VERSION_GGJT_V2: return "ggjt v2 (latest)"; + case LLAMA_FILE_VERSION_GGJT_V2: return "ggjt v2 (pre #1508)"; + case LLAMA_FILE_VERSION_GGJT_V3: return "ggjt v3 (latest)"; } return "unknown"; @@ -932,11 +936,19 @@ static void llama_model_load_internal( fprintf(stderr, "%s: model size = %s\n", __func__, llama_model_type_name(model.type)); } - if (file_version != LLAMA_FILE_VERSION_GGJT_V2) { + if (file_version < LLAMA_FILE_VERSION_GGJT_V2) { if (hparams.ftype != LLAMA_FTYPE_ALL_F32 && hparams.ftype != LLAMA_FTYPE_MOSTLY_F16 && hparams.ftype != LLAMA_FTYPE_MOSTLY_Q8_0) { - throw format("this format is no longer supported (see https://github.com/ggerganov/llama.cpp/pull/1305)"); + throw format("this format is no longer supported (see https://github.com/ggerganov/llama.cpp/pull/1405)"); + } + } + + if (file_version < LLAMA_FILE_VERSION_GGJT_V3) { + if (hparams.ftype == LLAMA_FTYPE_MOSTLY_Q4_0 || + hparams.ftype == LLAMA_FTYPE_MOSTLY_Q4_1 || + hparams.ftype == LLAMA_FTYPE_MOSTLY_Q8_0) { + throw format("this format is no longer supported (see https://github.com/ggerganov/llama.cpp/pull/1508)"); } } diff --git a/llama.h b/llama.h index f955fa23db048..fd3f21e5fa30a 100644 --- a/llama.h +++ b/llama.h @@ -19,7 +19,7 @@ # define LLAMA_API #endif -#define LLAMA_FILE_VERSION 2 +#define LLAMA_FILE_VERSION 3 #define LLAMA_FILE_MAGIC 'ggjt' #define LLAMA_FILE_MAGIC_UNVERSIONED 'ggml' #define LLAMA_SESSION_MAGIC 'ggsn' From 9a7af6c2a5137d1821526dbb3799d2bef12520de Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 20 May 2023 10:13:19 +0300 Subject: [PATCH 19/34] ggml : fix scalar implementation of Q4_1 dot --- ggml.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml.c b/ggml.c index f08c541384315..48e6c2cbc8de5 100644 --- a/ggml.c +++ b/ggml.c @@ -2481,7 +2481,7 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + qk/2]); } - sumf += (GGML_FP16_TO_FP32(x[i]).d*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s; + sumf += (GGML_FP16_TO_FP32(x[i].d)*y[i].d)*sumi + GGML_FP16_TO_FP32(x[i].m)*y[i].s; } *s = sumf; From f14673ad56c4c53d2ef4b2fdb38f208099bd5496 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 20 May 2023 10:14:31 +0300 Subject: [PATCH 20/34] llama : fix compile warnings in llama_set_state_data() --- llama.cpp | 4 ++-- llama.h | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/llama.cpp b/llama.cpp index e042e22d5a4b8..ad69141d380c6 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2652,8 +2652,8 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) { } // Sets the state reading from the specified source address -size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { - const uint8_t * inp = src; +size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { + uint8_t * inp = src; // set rng { diff --git a/llama.h b/llama.h index fd3f21e5fa30a..8623e08ce5a12 100644 --- a/llama.h +++ b/llama.h @@ -138,7 +138,7 @@ extern "C" { // Set the state reading from the specified address // Returns the number of bytes read - LLAMA_API size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src); + LLAMA_API size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src); // Save/load session file LLAMA_API bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out); From df512bbb49949e1cba4b7a5fb07f0698e0dec544 Mon Sep 17 00:00:00 2001 From: Maxime <672982+maximegmd@users.noreply.github.com> Date: Sat, 20 May 2023 09:22:37 +0200 Subject: [PATCH 21/34] llama : fix name shadowing and C4146 (#1526) * Fix name shadowing and C4146 * Fix if macros not using defined when required * Update llama-util.h Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update llama-util.h Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Code style Co-authored-by: Georgi Gerganov --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Georgi Gerganov --- ggml.c | 4 ++-- llama-util.h | 40 ++++++++++++++++++++-------------------- llama.cpp | 7 ++++--- 3 files changed, 26 insertions(+), 25 deletions(-) diff --git a/ggml.c b/ggml.c index 48e6c2cbc8de5..bcef056f455ee 100644 --- a/ggml.c +++ b/ggml.c @@ -512,7 +512,7 @@ static inline int hsum_i32_4(const __m128i a) { return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32)); } -#if __AVX2__ || __AVX512F__ +#if defined(__AVX2__) || defined(__AVX512F__) // spread 32 bits to 32 bytes { 0x00, 0xFF } static inline __m256i bytes_from_bits_32(const uint8_t * x) { uint32_t x32; @@ -688,7 +688,7 @@ static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 #endif // __AVX__ || __AVX2__ || __AVX512F__ #endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) -#if __ARM_NEON +#if defined(__ARM_NEON) #if !defined(__aarch64__) diff --git a/llama-util.h b/llama-util.h index e775f64ccc475..3cac9f681800b 100644 --- a/llama-util.h +++ b/llama-util.h @@ -101,12 +101,12 @@ struct llama_file { LLAMA_ASSERT(ret == 0); // same } - void read_raw(void * ptr, size_t size) { - if (size == 0) { + void read_raw(void * ptr, size_t len) const { + if (len == 0) { return; } errno = 0; - std::size_t ret = std::fread(ptr, size, 1, fp); + std::size_t ret = std::fread(ptr, len, 1, fp); if (ferror(fp)) { throw std::runtime_error(format("read error: %s", strerror(errno))); } @@ -127,12 +127,12 @@ struct llama_file { return std::string(chars.data(), len); } - void write_raw(const void * ptr, size_t size) { - if (size == 0) { + void write_raw(const void * ptr, size_t len) const { + if (len == 0) { return; } errno = 0; - size_t ret = std::fwrite(ptr, size, 1, fp); + size_t ret = std::fwrite(ptr, len, 1, fp); if (ret != 1) { throw std::runtime_error(format("write error: %s", strerror(errno))); } @@ -267,9 +267,9 @@ struct llama_mlock { } } - void init(void * addr) { - LLAMA_ASSERT(this->addr == NULL && this->size == 0); - this->addr = addr; + void init(void * ptr) { + LLAMA_ASSERT(addr == NULL && size == 0); + addr = ptr; } void grow_to(size_t target_size) { @@ -340,14 +340,14 @@ struct llama_mlock { return (size_t) si.dwPageSize; } - bool raw_lock(void * addr, size_t size) { + bool raw_lock(void * ptr, size_t len) { for (int tries = 1; ; tries++) { - if (VirtualLock(addr, size)) { + if (VirtualLock(ptr, len)) { return true; } if (tries == 2) { fprintf(stderr, "warning: failed to VirtualLock %zu-byte buffer (after previously locking %zu bytes): %s\n", - size, this->size, llama_format_win_err(GetLastError()).c_str()); + len, size, llama_format_win_err(GetLastError()).c_str()); return false; } @@ -363,7 +363,7 @@ struct llama_mlock { // is equal to the number of pages in its minimum working set minus // a small overhead." // Hopefully a megabyte is enough overhead: - size_t increment = size + 1048576; + size_t increment = len + 1048576; // The minimum must be <= the maximum, so we need to increase both: min_ws_size += increment; max_ws_size += increment; @@ -375,8 +375,8 @@ struct llama_mlock { } } - void raw_unlock(void * addr, size_t size) { - if (!VirtualUnlock(addr, size)) { + void raw_unlock(void * ptr, size_t len) { + if (!VirtualUnlock(ptr, len)) { fprintf(stderr, "warning: failed to VirtualUnlock buffer: %s\n", llama_format_win_err(GetLastError()).c_str()); } @@ -388,12 +388,12 @@ struct llama_mlock { return (size_t) 65536; } - bool raw_lock(const void * addr, size_t size) { + bool raw_lock(const void * addr, size_t len) { fprintf(stderr, "warning: mlock not supported on this system\n"); return false; } - void raw_unlock(const void * addr, size_t size) {} + void raw_unlock(const void * addr, size_t len) {} #endif }; @@ -404,10 +404,10 @@ struct llama_buffer { llama_buffer() = default; - void resize(size_t size) { + void resize(size_t len) { delete[] addr; - addr = new uint8_t[size]; - this->size = size; + addr = new uint8_t[len]; + size = len; } ~llama_buffer() { diff --git a/llama.cpp b/llama.cpp index ad69141d380c6..5de544600d8e2 100644 --- a/llama.cpp +++ b/llama.cpp @@ -45,6 +45,7 @@ enum e_model { MODEL_65B, }; + static const size_t MB = 1024*1024; // computed for n_ctx == 2048 @@ -110,7 +111,7 @@ struct llama_hparams { enum llama_ftype ftype = LLAMA_FTYPE_MOSTLY_F16; bool operator!=(const llama_hparams & other) const { - return memcmp(this, &other, sizeof(llama_hparams)); + return static_cast(memcmp(this, &other, sizeof(llama_hparams))); } }; @@ -502,7 +503,7 @@ struct llama_file_loader { if (file_version >= LLAMA_FILE_VERSION_GGJT_V1) { // skip to the next multiple of 32 bytes - file.seek(-file.tell() & 31, SEEK_CUR); + file.seek(-static_cast(file.tell()) & 31, SEEK_CUR); } shard.file_idx = file_idx; shard.file_off = file.tell(); @@ -577,7 +578,7 @@ struct llama_file_saver { file.write_u32(new_type); file.write_raw(tensor.ne.data(), sizeof(tensor.ne[0]) * tensor.ne.size()); file.write_raw(tensor.name.data(), tensor.name.size()); - file.seek(-file.tell() & 31, SEEK_CUR); + file.seek(-static_cast(file.tell()) & 31, SEEK_CUR); LLAMA_ASSERT(new_size == llama_calc_tensor_size(tensor.ne, new_type)); file.write_raw(new_data, new_size); } From f401d5ffa2a6dd2865de62b7a4c02f4bc646acf1 Mon Sep 17 00:00:00 2001 From: DannyDaemonic Date: Sat, 20 May 2023 00:40:02 -0700 Subject: [PATCH 22/34] Fix for mingw (#1462) --- examples/common.cpp | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/examples/common.cpp b/examples/common.cpp index e89df537eca49..1308f84109519 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -578,6 +578,37 @@ void console_set_color(console_state & con_st, console_color_t color) { } char32_t getchar32() { +#if defined(_WIN32) + HANDLE hConsole = GetStdHandle(STD_INPUT_HANDLE); + wchar_t high_surrogate = 0; + + while (true) { + INPUT_RECORD record; + DWORD count; + if (!ReadConsoleInputW(hConsole, &record, 1, &count) || count == 0) { + return WEOF; + } + + if (record.EventType == KEY_EVENT && record.Event.KeyEvent.bKeyDown) { + wchar_t wc = record.Event.KeyEvent.uChar.UnicodeChar; + if (wc == 0) { + continue; + } + + if ((wc >= 0xD800) && (wc <= 0xDBFF)) { // Check if wc is a high surrogate + high_surrogate = wc; + continue; + } else if ((wc >= 0xDC00) && (wc <= 0xDFFF)) { // Check if wc is a low surrogate + if (high_surrogate != 0) { // Check if we have a high surrogate + return ((high_surrogate - 0xD800) << 10) + (wc - 0xDC00) + 0x10000; + } + } + + high_surrogate = 0; // Reset the high surrogate + return static_cast(wc); + } + } +#else wchar_t wc = getwchar(); if (static_cast(wc) == WEOF) { return WEOF; @@ -596,6 +627,7 @@ char32_t getchar32() { #endif return static_cast(wc); +#endif } void pop_cursor(console_state & con_st) { From 54ec8a963bc39de0278e08482f018ac2a3d26d4f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 20 May 2023 11:06:11 +0300 Subject: [PATCH 23/34] llama : add llama_init_backend() API (close #1527) --- examples/benchmark/benchmark-matmult.cpp | 3 ++- examples/embedding/embedding.cpp | 2 ++ examples/main/main.cpp | 3 +-- examples/perplexity/perplexity.cpp | 2 ++ examples/quantize/quantize.cpp | 21 ++++++---------- llama.cpp | 15 ++++++++++++ llama.h | 31 +++++++++++++++--------- 7 files changed, 48 insertions(+), 29 deletions(-) diff --git a/examples/benchmark/benchmark-matmult.cpp b/examples/benchmark/benchmark-matmult.cpp index 446b8e8fb5ef2..9f9ed9db0c588 100644 --- a/examples/benchmark/benchmark-matmult.cpp +++ b/examples/benchmark/benchmark-matmult.cpp @@ -1,6 +1,7 @@ -#include #include "ggml.h" #include "build-info.h" + +#include #include #include #include diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index c24f7f8207157..03603b10fe3f9 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -31,6 +31,8 @@ int main(int argc, char ** argv) { params.prompt = gpt_random_prompt(rng); } + llama_init_backend(); + llama_context * ctx; // load the model diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 4d886f8defe56..47b418d972bbc 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -96,8 +96,7 @@ int main(int argc, char ** argv) { params.prompt = gpt_random_prompt(rng); } -// params.prompt = R"(// this function checks if the number n is prime -//bool is_prime(int n) {)"; + llama_init_backend(); llama_context * ctx; g_ctx = &ctx; diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 9d38626cbd4f4..e19c6825f2446 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -143,6 +143,8 @@ int main(int argc, char ** argv) { params.prompt = gpt_random_prompt(rng); } + llama_init_backend(); + llama_context * ctx; // load the model and apply lora adapter, if any diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 115d8fb1ba36b..769dd36a468e3 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -1,7 +1,7 @@ -#include "ggml.h" -#include "llama.h" #include "build-info.h" +#include "llama.h" + #include #include #include @@ -42,8 +42,6 @@ bool try_parse_ftype(const std::string & ftype_str, llama_ftype & ftype, std::st // ./quantize models/llama/ggml-model.bin [models/llama/ggml-model-quant.bin] type [nthreads] // int main(int argc, char ** argv) { - ggml_time_init(); - if (argc < 3) { fprintf(stderr, "usage: %s model-f32.bin [model-quant.bin] type [nthreads]\n", argv[0]); for (auto it = LLAMA_FTYPE_MAP.begin(); it != LLAMA_FTYPE_MAP.end(); it++) { @@ -52,12 +50,7 @@ int main(int argc, char ** argv) { return 1; } - // needed to initialize f16 tables - { - struct ggml_init_params params = { 0, NULL, false }; - struct ggml_context * ctx = ggml_init(params); - ggml_free(ctx); - } + llama_init_backend(); // parse command line arguments const std::string fname_inp = argv[1]; @@ -116,25 +109,25 @@ int main(int argc, char ** argv) { } fprintf(stderr, "\n"); - const int64_t t_main_start_us = ggml_time_us(); + const int64_t t_main_start_us = llama_time_us(); int64_t t_quantize_us = 0; // load the model { - const int64_t t_start_us = ggml_time_us(); + const int64_t t_start_us = llama_time_us(); if (llama_model_quantize(fname_inp.c_str(), fname_out.c_str(), ftype, nthread)) { fprintf(stderr, "%s: failed to quantize model from '%s'\n", __func__, fname_inp.c_str()); return 1; } - t_quantize_us = ggml_time_us() - t_start_us; + t_quantize_us = llama_time_us() - t_start_us; } // report timing { - const int64_t t_main_end_us = ggml_time_us(); + const int64_t t_main_end_us = llama_time_us(); printf("\n"); printf("%s: quantize time = %8.2f ms\n", __func__, t_quantize_us/1000.0); diff --git a/llama.cpp b/llama.cpp index 5de544600d8e2..91653f374a6e3 100644 --- a/llama.cpp +++ b/llama.cpp @@ -847,6 +847,21 @@ bool llama_mlock_supported() { return llama_mlock::SUPPORTED; } +void llama_init_backend() { + ggml_time_init(); + + // needed to initialize f16 tables + { + struct ggml_init_params params = { 0, NULL, false }; + struct ggml_context * ctx = ggml_init(params); + ggml_free(ctx); + } +} + +int64_t llama_time_us() { + return ggml_time_us(); +} + // // model loading // diff --git a/llama.h b/llama.h index 8623e08ce5a12..0a63d034b38ed 100644 --- a/llama.h +++ b/llama.h @@ -40,9 +40,9 @@ extern "C" { typedef int llama_token; typedef struct llama_token_data { - llama_token id; // token id - float logit; // log-odds of the token - float p; // probability of the token + llama_token id; // token id + float logit; // log-odds of the token + float p; // probability of the token } llama_token_data; typedef struct llama_token_data_array { @@ -73,16 +73,16 @@ extern "C" { // model file types enum llama_ftype { - LLAMA_FTYPE_ALL_F32 = 0, - LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors + LLAMA_FTYPE_ALL_F32 = 0, + LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16 - // LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed - // LLAMA_FTYPE_MOSTLY_Q4_3 (6) support has been removed - LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors + // LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed + // LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed + LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors }; LLAMA_API struct llama_context_params llama_context_default_params(); @@ -90,6 +90,13 @@ extern "C" { LLAMA_API bool llama_mmap_supported(); LLAMA_API bool llama_mlock_supported(); + // TODO: not great API - very likely to change + // Initialize the llama + ggml backend + // Call once at the start of the program + LLAMA_API void llama_init_backend(); + + LLAMA_API int64_t llama_time_us(); + // Various functions for loading a ggml llama model. // Allocate (almost) all memory needed for the model. // Return NULL on failure From 667c57f11a3c8bbd85c3549f93c0524ce63e01c1 Mon Sep 17 00:00:00 2001 From: Zenix Date: Sat, 20 May 2023 18:02:48 +0900 Subject: [PATCH 24/34] feature : add blis and other BLAS implementation support (#1502) * feature: add blis support * feature: allow all BLA_VENDOR to be assigned in cmake arguments. align with whisper.cpp pr 927 * fix: version detection for BLA_SIZEOF_INTEGER, recover min version of cmake * Fix typo in INTEGER Co-authored-by: Georgi Gerganov --------- Co-authored-by: Georgi Gerganov --- BLIS.md | 67 ++++++++++++++++++++++++++++++++++++++++++++++++++ CMakeLists.txt | 39 ++++++++++++----------------- Makefile | 4 +++ README.md | 19 ++++++++++++-- 4 files changed, 104 insertions(+), 25 deletions(-) create mode 100644 BLIS.md diff --git a/BLIS.md b/BLIS.md new file mode 100644 index 0000000000000..9b3c3060515db --- /dev/null +++ b/BLIS.md @@ -0,0 +1,67 @@ +BLIS Installation Manual +------------------------ + +BLIS is a portable software framework for high-performance BLAS-like dense linear algebra libraries. It has received awards and recognition, including the 2023 James H. Wilkinson Prize for Numerical Software and the 2020 SIAM Activity Group on Supercomputing Best Paper Prize. BLIS provides a new BLAS-like API and a compatibility layer for traditional BLAS routine calls. It offers features such as object-based API, typed API, BLAS and CBLAS compatibility layers. + +Project URL: https://github.com/flame/blis + +### Prepare: + +Compile BLIS: + +```bash +git clone https://github.com/flame/blis +cd blis +./configure --enable-cblas -t openmp,pthreads auto +# will install to /usr/local/ by default. +make -j +``` + +Install BLIS: + +```bash +sudo make install +``` + +We recommend using openmp since it's easier to modify the cores been used. + +### llama.cpp compilation + +Makefile: + +```bash +make LLAMA_BLIS=1 -j +# make LLAMA_BLIS=1 benchmark-matmult +``` + +CMake: + +```bash +mkdir build +cd build +cmake -DLLAMA_BLAS=ON -DLLAMA_BLAS_VENDOR=FLAME .. +make -j +``` + +### llama.cpp execution + +According to the BLIS documentation, we could set the following +environment variables to modify the behavior of openmp: + +``` +export GOMP_GPU_AFFINITY="0-19" +export BLIS_NUM_THREADS=14 +``` + +And then run the binaries as normal. + + +### Intel specific issue + +Some might get the error message saying that `libimf.so` cannot be found. +Please follow this [stackoverflow page](https://stackoverflow.com/questions/70687930/intel-oneapi-2022-libimf-so-no-such-file-or-directory-during-openmpi-compila). + +### Reference: + +1. https://github.com/flame/blis#getting-started +2. https://github.com/flame/blis/blob/master/docs/Multithreading.md diff --git a/CMakeLists.txt b/CMakeLists.txt index 48e3238dfa52e..0876ab90a2208 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -65,7 +65,8 @@ endif() # 3rd party libs option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON) -option(LLAMA_OPENBLAS "llama: use OpenBLAS" OFF) +option(LLAMA_BLAS "llama: use BLAS" OFF) +option(LLAMA_BLAS_VENDOR "llama: BLA_VENDOR from https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" Generic) option(LLAMA_CUBLAS "llama: use cuBLAS" OFF) option(LLAMA_CLBLAST "llama: use CLBlast" OFF) @@ -145,36 +146,28 @@ if (APPLE AND LLAMA_ACCELERATE) endif() endif() -if (LLAMA_OPENBLAS) +if (LLAMA_BLAS) if (LLAMA_STATIC) set(BLA_STATIC ON) endif() - - set(BLA_VENDOR OpenBLAS) + if ($(CMAKE_VERSION) VERSION_GREATER_EQUAL 3.22) + set(BLA_SIZEOF_INTEGER 8) + endif() + set(BLA_VENDOR ${LLAMA_BLAS_VENDOR}) find_package(BLAS) if (BLAS_FOUND) - message(STATUS "OpenBLAS found") + message(STATUS "BLAS found, Libraries: ${BLAS_LIBRARIES}") + add_compile_options(${BLAS_LINKER_FLAGS}) add_compile_definitions(GGML_USE_OPENBLAS) - add_link_options(${BLAS_LIBRARIES}) - set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} openblas) - - # find header file - set(OPENBLAS_INCLUDE_SEARCH_PATHS - /usr/include - /usr/include/openblas - /usr/include/openblas-base - /usr/local/include - /usr/local/include/openblas - /usr/local/include/openblas-base - /opt/OpenBLAS/include - $ENV{OpenBLAS_HOME} - $ENV{OpenBLAS_HOME}/include - ) - find_path(OPENBLAS_INC NAMES cblas.h PATHS ${OPENBLAS_INCLUDE_SEARCH_PATHS}) - add_compile_options(-I${OPENBLAS_INC}) + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${BLAS_LIBRARIES}) + + message("${BLAS_LIBRARIES}") + include_directories(${BLAS_INCLUDE_DIRS}) else() - message(WARNING "OpenBLAS not found") + message(WARNING "BLAS not found, please refer to " + "https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" + " to set correct LLAMA_BLAS_VENDOR") endif() endif() diff --git a/Makefile b/Makefile index 93c149edc9593..0a34a4f5f9eaf 100644 --- a/Makefile +++ b/Makefile @@ -122,6 +122,10 @@ ifdef LLAMA_OPENBLAS LDFLAGS += -lopenblas endif endif +ifdef LLAMA_BLIS + CFLAGS += -DGGML_USE_OPENBLAS -I/usr/local/include/blis -I/usr/include/blis + LDFLAGS += -lblis -L/usr/local/lib +endif ifdef LLAMA_CUBLAS CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include CXXFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include diff --git a/README.md b/README.md index 762f4aa0349e5..102cde43fb457 100644 --- a/README.md +++ b/README.md @@ -56,7 +56,7 @@ The main goal of `llama.cpp` is to run the LLaMA model using 4-bit integer quant - Mixed F16 / F32 precision - 4-bit, 5-bit and 8-bit integer quantization support - Runs on the CPU -- OpenBLAS support +- Supports OpenBLAS/Apple BLAS/ARM Performance Lib/ATLAS/BLIS/Intel MKL/NVHPC/ACML/SCSL/SGIMATH and [more](https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors) in BLAS - cuBLAS and CLBlast support The original implementation of `llama.cpp` was [hacked in an evening](https://github.com/ggerganov/llama.cpp/issues/33#issuecomment-1465108022). @@ -274,10 +274,25 @@ Building the program with BLAS support may lead to some performance improvements ```bash mkdir build cd build - cmake .. -DLLAMA_OPENBLAS=ON + cmake .. -DLLAMA_BLAS=ON -DLLAMA_BLAS_VENDOR=OpenBLAS cmake --build . --config Release ``` +- BLIS + + Check [BLIS.md](BLIS.md) for more information. + +- Intel MKL + + By default, `LLAMA_BLAS_VENDOR` is set to `Generic`, so if you already sourced intel environment script and assign `-DLLAMA_BLAS=ON` in cmake, the mkl version of Blas will automatically been selected. You may also specify it by: + + ```bash + mkdir build + cd build + cmake .. -DLLAMA_BLAS=ON -DLLAMA_BLAS_VENDOR=Intel10_64lp -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx + cmake --build . -config Release + ``` + - cuBLAS This provides BLAS acceleration using the CUDA cores of your Nvidia GPU. Make sure to have the CUDA toolkit installed. You can download it from your Linux distro's package manager or from here: [CUDA Toolkit](https://developer.nvidia.com/cuda-downloads). From 977e74d70e15912f84730a53fe8110ea46d57fa2 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 20 May 2023 12:03:48 +0300 Subject: [PATCH 25/34] Revert "feature : add blis and other BLAS implementation support (#1502)" This reverts commit 07e9ace0f9da424d82e75df969642522880feb92. --- BLIS.md | 67 -------------------------------------------------- CMakeLists.txt | 39 +++++++++++++++++------------ Makefile | 4 --- README.md | 19 ++------------ 4 files changed, 25 insertions(+), 104 deletions(-) delete mode 100644 BLIS.md diff --git a/BLIS.md b/BLIS.md deleted file mode 100644 index 9b3c3060515db..0000000000000 --- a/BLIS.md +++ /dev/null @@ -1,67 +0,0 @@ -BLIS Installation Manual ------------------------- - -BLIS is a portable software framework for high-performance BLAS-like dense linear algebra libraries. It has received awards and recognition, including the 2023 James H. Wilkinson Prize for Numerical Software and the 2020 SIAM Activity Group on Supercomputing Best Paper Prize. BLIS provides a new BLAS-like API and a compatibility layer for traditional BLAS routine calls. It offers features such as object-based API, typed API, BLAS and CBLAS compatibility layers. - -Project URL: https://github.com/flame/blis - -### Prepare: - -Compile BLIS: - -```bash -git clone https://github.com/flame/blis -cd blis -./configure --enable-cblas -t openmp,pthreads auto -# will install to /usr/local/ by default. -make -j -``` - -Install BLIS: - -```bash -sudo make install -``` - -We recommend using openmp since it's easier to modify the cores been used. - -### llama.cpp compilation - -Makefile: - -```bash -make LLAMA_BLIS=1 -j -# make LLAMA_BLIS=1 benchmark-matmult -``` - -CMake: - -```bash -mkdir build -cd build -cmake -DLLAMA_BLAS=ON -DLLAMA_BLAS_VENDOR=FLAME .. -make -j -``` - -### llama.cpp execution - -According to the BLIS documentation, we could set the following -environment variables to modify the behavior of openmp: - -``` -export GOMP_GPU_AFFINITY="0-19" -export BLIS_NUM_THREADS=14 -``` - -And then run the binaries as normal. - - -### Intel specific issue - -Some might get the error message saying that `libimf.so` cannot be found. -Please follow this [stackoverflow page](https://stackoverflow.com/questions/70687930/intel-oneapi-2022-libimf-so-no-such-file-or-directory-during-openmpi-compila). - -### Reference: - -1. https://github.com/flame/blis#getting-started -2. https://github.com/flame/blis/blob/master/docs/Multithreading.md diff --git a/CMakeLists.txt b/CMakeLists.txt index 0876ab90a2208..48e3238dfa52e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -65,8 +65,7 @@ endif() # 3rd party libs option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON) -option(LLAMA_BLAS "llama: use BLAS" OFF) -option(LLAMA_BLAS_VENDOR "llama: BLA_VENDOR from https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" Generic) +option(LLAMA_OPENBLAS "llama: use OpenBLAS" OFF) option(LLAMA_CUBLAS "llama: use cuBLAS" OFF) option(LLAMA_CLBLAST "llama: use CLBlast" OFF) @@ -146,28 +145,36 @@ if (APPLE AND LLAMA_ACCELERATE) endif() endif() -if (LLAMA_BLAS) +if (LLAMA_OPENBLAS) if (LLAMA_STATIC) set(BLA_STATIC ON) endif() - if ($(CMAKE_VERSION) VERSION_GREATER_EQUAL 3.22) - set(BLA_SIZEOF_INTEGER 8) - endif() - set(BLA_VENDOR ${LLAMA_BLAS_VENDOR}) + + set(BLA_VENDOR OpenBLAS) find_package(BLAS) if (BLAS_FOUND) - message(STATUS "BLAS found, Libraries: ${BLAS_LIBRARIES}") + message(STATUS "OpenBLAS found") - add_compile_options(${BLAS_LINKER_FLAGS}) add_compile_definitions(GGML_USE_OPENBLAS) - set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${BLAS_LIBRARIES}) - - message("${BLAS_LIBRARIES}") - include_directories(${BLAS_INCLUDE_DIRS}) + add_link_options(${BLAS_LIBRARIES}) + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} openblas) + + # find header file + set(OPENBLAS_INCLUDE_SEARCH_PATHS + /usr/include + /usr/include/openblas + /usr/include/openblas-base + /usr/local/include + /usr/local/include/openblas + /usr/local/include/openblas-base + /opt/OpenBLAS/include + $ENV{OpenBLAS_HOME} + $ENV{OpenBLAS_HOME}/include + ) + find_path(OPENBLAS_INC NAMES cblas.h PATHS ${OPENBLAS_INCLUDE_SEARCH_PATHS}) + add_compile_options(-I${OPENBLAS_INC}) else() - message(WARNING "BLAS not found, please refer to " - "https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" - " to set correct LLAMA_BLAS_VENDOR") + message(WARNING "OpenBLAS not found") endif() endif() diff --git a/Makefile b/Makefile index 0a34a4f5f9eaf..93c149edc9593 100644 --- a/Makefile +++ b/Makefile @@ -122,10 +122,6 @@ ifdef LLAMA_OPENBLAS LDFLAGS += -lopenblas endif endif -ifdef LLAMA_BLIS - CFLAGS += -DGGML_USE_OPENBLAS -I/usr/local/include/blis -I/usr/include/blis - LDFLAGS += -lblis -L/usr/local/lib -endif ifdef LLAMA_CUBLAS CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include CXXFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include diff --git a/README.md b/README.md index 102cde43fb457..762f4aa0349e5 100644 --- a/README.md +++ b/README.md @@ -56,7 +56,7 @@ The main goal of `llama.cpp` is to run the LLaMA model using 4-bit integer quant - Mixed F16 / F32 precision - 4-bit, 5-bit and 8-bit integer quantization support - Runs on the CPU -- Supports OpenBLAS/Apple BLAS/ARM Performance Lib/ATLAS/BLIS/Intel MKL/NVHPC/ACML/SCSL/SGIMATH and [more](https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors) in BLAS +- OpenBLAS support - cuBLAS and CLBlast support The original implementation of `llama.cpp` was [hacked in an evening](https://github.com/ggerganov/llama.cpp/issues/33#issuecomment-1465108022). @@ -274,25 +274,10 @@ Building the program with BLAS support may lead to some performance improvements ```bash mkdir build cd build - cmake .. -DLLAMA_BLAS=ON -DLLAMA_BLAS_VENDOR=OpenBLAS + cmake .. -DLLAMA_OPENBLAS=ON cmake --build . --config Release ``` -- BLIS - - Check [BLIS.md](BLIS.md) for more information. - -- Intel MKL - - By default, `LLAMA_BLAS_VENDOR` is set to `Generic`, so if you already sourced intel environment script and assign `-DLLAMA_BLAS=ON` in cmake, the mkl version of Blas will automatically been selected. You may also specify it by: - - ```bash - mkdir build - cd build - cmake .. -DLLAMA_BLAS=ON -DLLAMA_BLAS_VENDOR=Intel10_64lp -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx - cmake --build . -config Release - ``` - - cuBLAS This provides BLAS acceleration using the CUDA cores of your Nvidia GPU. Make sure to have the CUDA toolkit installed. You can download it from your Linux distro's package manager or from here: [CUDA Toolkit](https://developer.nvidia.com/cuda-downloads). From ffe9652bc1ed0172446c15355684e1cef28daca2 Mon Sep 17 00:00:00 2001 From: JohannesGaessler Date: Wed, 17 May 2023 16:35:50 +0200 Subject: [PATCH 26/34] GPU weights not in RAM, direct loading with cuFile --- ggml-cuda.cu | 4 ++-- ggml-cuda.h | 2 ++ llama.cpp | 1 + 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index a1fec0fd735a0..6dbbc7950da0f 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -383,7 +383,7 @@ struct cuda_buffer { static cuda_buffer g_cuda_buffer_pool[MAX_CUDA_BUFFERS]; static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT; -static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) { +void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) { scoped_spin_lock lock(g_cuda_pool_lock); for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) { @@ -402,7 +402,7 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) { return ptr; } -static void ggml_cuda_pool_free(void * ptr, size_t size) { +void ggml_cuda_pool_free(void * ptr, size_t size) { scoped_spin_lock lock(g_cuda_pool_lock); for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) { diff --git a/ggml-cuda.h b/ggml-cuda.h index ab2c690b082b1..b46a804c7adf0 100644 --- a/ggml-cuda.h +++ b/ggml-cuda.h @@ -14,6 +14,8 @@ void ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tens // TODO: export these with GGML_API void * ggml_cuda_host_malloc(size_t size); void ggml_cuda_host_free(void * ptr); +void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size); +void ggml_cuda_pool_free(void * ptr, size_t size); void ggml_cuda_transform_tensor(struct ggml_tensor * tensor); void ggml_cuda_load_data(const char * fname, struct ggml_tensor ** tensors, int num_tensors, const size_t * offsets); diff --git a/llama.cpp b/llama.cpp index 91653f374a6e3..38de0e39d8bfa 100644 --- a/llama.cpp +++ b/llama.cpp @@ -10,6 +10,7 @@ #include "ggml.h" #ifdef GGML_USE_CUBLAS +#include #include "ggml-cuda.h" #endif From f67bc3c363279ad18d430c533b6d133c998d06c0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 20 May 2023 12:29:08 +0300 Subject: [PATCH 27/34] llama : code style fixes + progress print fix --- .gitignore | 31 +++++++++++++++---------------- llama.cpp | 44 ++++++++++++++++++++++++++------------------ 2 files changed, 41 insertions(+), 34 deletions(-) diff --git a/.gitignore b/.gitignore index 07528b5c6cc08..1aabe82fc2967 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,15 @@ -*.o -*.a +/*.o +/*.a +/*.sh +/*.log +/*.org + +/ppl-*.txt +/qnt-*.txt +/perf-*.txt + +*.bin + .DS_Store .build/ .cache/ @@ -21,8 +31,9 @@ build-no-accel/ build-sanitize-addr/ build-sanitize-thread/ +prompts/ models/* -*.bin +wikitext-2-raw/ /main /quantize @@ -33,6 +44,7 @@ models/* /benchmark-matmult /vdot /Pipfile +/libllama.so build-info.h arm_neon.h @@ -43,17 +55,4 @@ __pycache__ zig-out/ zig-cache/ -ppl-*.txt -qnt-*.txt -perf-*.txt - examples/jeopardy/results.txt - -/prompts -*.sh -*.log -*.py -*.txt -/wikitext-2-raw/ -*.org -/libllama.so diff --git a/llama.cpp b/llama.cpp index 38de0e39d8bfa..cbc6f8b40588c 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1010,21 +1010,28 @@ static void llama_model_load_internal( ml->ggml_ctx = ctx; model.tok_embeddings = ml->get_tensor("tok_embeddings.weight", {n_embd, n_vocab}, GGML_BACKEND_CPU); - model.norm = ml->get_tensor("norm.weight", {n_embd}, GGML_BACKEND_CPU); - ggml_backend backend_output; - if (n_gpu_layers > int(n_layer)) { - backend_output = GGML_BACKEND_CUDA; - } else { - backend_output = GGML_BACKEND_CPU; + model.norm = ml->get_tensor("norm.weight", {n_embd}, GGML_BACKEND_CPU); + + // "output" tensor + { + ggml_backend backend_output; + if (n_gpu_layers > int(n_layer)) { + backend_output = GGML_BACKEND_CUDA; + } else { + backend_output = GGML_BACKEND_CPU; + } + + model.output = ml->get_tensor("output.weight", {n_embd, n_vocab}, backend_output); } - model.output = ml->get_tensor("output.weight", {n_embd, n_vocab}, backend_output); - model.layers.resize(n_layer); const int i_gpu_start = n_layer - n_gpu_layers; + + model.layers.resize(n_layer); for (uint32_t i = 0; i < n_layer; ++i) { - auto & layer = model.layers[i]; const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : GGML_BACKEND_CUDA; + auto & layer = model.layers[i]; + std::string layers_i = "layers." + std::to_string(i); layer.attention_norm = ml->get_tensor(layers_i + ".attention_norm.weight", {n_embd}, backend); @@ -1036,13 +1043,15 @@ static void llama_model_load_internal( layer.ffn_norm = ml->get_tensor(layers_i + ".ffn_norm.weight", {n_embd}, backend); - layer.w1 = ml->get_tensor(layers_i + ".feed_forward.w1.weight", {n_embd, n_ff}, backend); + layer.w1 = ml->get_tensor(layers_i + ".feed_forward.w1.weight", {n_embd, n_ff}, backend); layer.w2 = ml->get_tensor(layers_i + ".feed_forward.w2.weight", { n_ff, n_embd}, backend); - layer.w3 = ml->get_tensor(layers_i + ".feed_forward.w3.weight", {n_embd, n_ff}, backend); + layer.w3 = ml->get_tensor(layers_i + ".feed_forward.w3.weight", {n_embd, n_ff}, backend); + if (backend == GGML_BACKEND_CUDA) { - vram_total += ggml_nbytes(layer.attention_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) - + ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + ggml_nbytes(layer.attention_norm) - + ggml_nbytes(layer.w1) + ggml_nbytes(layer.w2) + ggml_nbytes(layer.w3); + vram_total += + ggml_nbytes(layer.attention_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) + + ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + ggml_nbytes(layer.attention_norm) + + ggml_nbytes(layer.w1) + ggml_nbytes(layer.w2) + ggml_nbytes(layer.w3); } } } @@ -1077,7 +1086,7 @@ static void llama_model_load_internal( } fprintf(stderr, "%s: [cublas] total VRAM used: %zu MB\n", __func__, vram_total / 1024 / 1024); #else - (void) n_gpu_layers; + (void) n_gpu_layers; #endif } @@ -2192,7 +2201,7 @@ struct llama_context * llama_init_from_file( unsigned * cur_percentage_p = (unsigned *) ctx; unsigned percentage = (unsigned) (100 * progress); while (percentage > *cur_percentage_p) { - ++*cur_percentage_p; + *cur_percentage_p = percentage; fprintf(stderr, "."); fflush(stderr); if (percentage >= 100) { @@ -2442,8 +2451,7 @@ int llama_apply_lora_from_file_internal(struct llama_context * ctx, const char * } size_t idx = model_loader->tensors_map.name_to_idx[base_name]; llama_load_tensor & lt = model_loader->tensors_map.tensors[idx]; - base_t = model_loader->get_tensor( - base_name, { (uint32_t)dest_t->ne[0], (uint32_t)dest_t->ne[1] }, GGML_BACKEND_CPU); + base_t = model_loader->get_tensor(base_name, { (uint32_t)dest_t->ne[0], (uint32_t)dest_t->ne[1] }, GGML_BACKEND_CPU); lt.data = (uint8_t *) lt.ggml_tensor->data; model_loader->load_data_for(lt); lt.ggml_tensor->data = lt.data; From 3ec7941bad08c39bc2fc113e9ec1b27351e33f12 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 20 May 2023 13:09:01 +0300 Subject: [PATCH 28/34] ggml : ggml_mul better broadcast support --- ggml.c | 84 +++++++++++++++++++++++++++++++++---------------------- llama.cpp | 32 +++++++-------------- 2 files changed, 60 insertions(+), 56 deletions(-) diff --git a/ggml.c b/ggml.c index bcef056f455ee..d86e5942ab46e 100644 --- a/ggml.c +++ b/ggml.c @@ -3776,6 +3776,12 @@ static inline bool ggml_can_repeat(const struct ggml_tensor * t0, const struct g (t1->ne[3]%t0->ne[3] == 0); } +static inline bool ggml_can_repeat_rows(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { + static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); + + return (t0->ne[0] == t1->ne[0]) && ggml_can_repeat(t0, t1); +} + static inline int ggml_up32(int n) { return (n + 31) & ~31; } @@ -4658,11 +4664,15 @@ struct ggml_tensor * ggml_mul_impl( struct ggml_tensor * a, struct ggml_tensor * b, bool inplace) { - GGML_ASSERT(a->ne[0] == b->ne[0] && ggml_can_repeat(b, a)); + // TODO: support less-strict constraint + // GGML_ASSERT(ggml_can_repeat(b, a)); + GGML_ASSERT(ggml_can_repeat_rows(b, a)); bool is_node = false; if (!inplace && (a->grad || b->grad)) { + // TODO: support backward pass for broadcasting + GGML_ASSERT(ggml_are_same_shape(a, b)); is_node = true; } @@ -7960,22 +7970,14 @@ static void ggml_compute_forward_mul_f32( const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { - const int nr = ggml_nrows(src0); - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - const int64_t ne10 = src1->ne[0]; - const int64_t ne11 = src1->ne[1]; - const int64_t ne12 = src1->ne[2]; - const int64_t ne13 = src0->ne[3]; - - GGML_ASSERT(ne00 == ne10 && ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_can_repeat_rows(src1, src0) && ggml_are_same_shape(src0, dst)); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; } const int ith = params->ith; const int nth = params->nth; + #ifdef GGML_USE_CUBLAS if (src1->backend == GGML_BACKEND_CUDA) { if (ith == 0) { @@ -7985,6 +7987,17 @@ static void ggml_compute_forward_mul_f32( } #endif + const int64_t nr = ggml_nrows(src0); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + const int64_t ne12 = src1->ne[2]; + const int64_t ne13 = src1->ne[3]; + const size_t nb00 = src0->nb[0]; const size_t nb01 = src0->nb[1]; const size_t nb02 = src0->nb[2]; @@ -8002,47 +8015,50 @@ static void ggml_compute_forward_mul_f32( GGML_ASSERT( nb0 == sizeof(float)); GGML_ASSERT(nb00 == sizeof(float)); + GGML_ASSERT(ne00 == ne10); - if (nb10 == sizeof(float) && ggml_are_same_shape(src0, src1)) { - for (int ir = ith; ir < nr; ir += nth) { - // src0, src1 and dst are same shape => same indices - const int i3 = ir/(ne02*ne01); - const int i2 = (ir - i3*ne02*ne01)/ne01; - const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + if (nb10 == sizeof(float)) { + for (int64_t ir = ith; ir < nr; ir += nth) { + // src0 and dst are same shape => same indices + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); #ifdef GGML_USE_ACCELERATE UNUSED(ggml_vec_mul_f32); - vDSP_vmul( - (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1, - (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1, - (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1, - ne00); + vDSP_vmul( src0_ptr, 1, src1_ptr, 1, dst_ptr, 1, ne00); #else - ggml_vec_mul_f32(ne00, - (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), - (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), - (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11)); + ggml_vec_mul_f32(ne00, dst_ptr, src0_ptr, src1_ptr); #endif // } // } } } else { // src1 is not contiguous - for (int ir = ith; ir < nr; ir += nth) { + for (int64_t ir = ith; ir < nr; ir += nth) { // src0 and dst are same shape => same indices // src1 is broadcastable across src0 and dst in i1, i2, i3 - const int i03 = ir/(ne02*ne01); - const int i02 = (ir - i03*ne02*ne01)/ne01; - const int i01 = (ir - i03*ne02*ne01 - i02*ne01); - const int i13 = i03 % ne13; - const int i12 = i02 % ne12; - const int i11 = i01 % ne11; + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); - for (int i0 = 0; i0 < ne00; i0++) { + + for (int64_t i0 = 0; i0 < ne00; i0++) { float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i0*nb10); dst_ptr[i0] = src0_ptr[i0] * (*src1_ptr); diff --git a/llama.cpp b/llama.cpp index cbc6f8b40588c..431c7eaf6355c 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1000,6 +1000,12 @@ static void llama_model_load_internal( } } +#ifdef GGML_USE_CUBLAS +#define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_CUDA +#else +#define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_CPU +#endif + // prepare memory for the weights size_t vram_total = 0; { @@ -1016,19 +1022,19 @@ static void llama_model_load_internal( { ggml_backend backend_output; if (n_gpu_layers > int(n_layer)) { - backend_output = GGML_BACKEND_CUDA; + backend_output = LLAMA_BACKEND_OFFLOAD; } else { backend_output = GGML_BACKEND_CPU; } - model.output = ml->get_tensor("output.weight", {n_embd, n_vocab}, backend_output); + model.output = ml->get_tensor("output.weight", {n_embd, n_vocab}, backend_output); } const int i_gpu_start = n_layer - n_gpu_layers; model.layers.resize(n_layer); for (uint32_t i = 0; i < n_layer; ++i) { - const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : GGML_BACKEND_CUDA; + const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; auto & layer = model.layers[i]; @@ -1047,7 +1053,7 @@ static void llama_model_load_internal( layer.w2 = ml->get_tensor(layers_i + ".feed_forward.w2.weight", { n_ff, n_embd}, backend); layer.w3 = ml->get_tensor(layers_i + ".feed_forward.w3.weight", {n_embd, n_ff}, backend); - if (backend == GGML_BACKEND_CUDA) { + if (backend == LLAMA_BACKEND_OFFLOAD) { vram_total += ggml_nbytes(layer.attention_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) + ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + ggml_nbytes(layer.attention_norm) + @@ -1213,13 +1219,7 @@ static bool llama_eval_internal( cur = ggml_rms_norm(ctx0, inpL); // cur = cur*attention_norm(broadcasted) -#ifdef GGML_USE_CUBLAS cur = ggml_mul(ctx0, cur, model.layers[il].attention_norm); -#else - cur = ggml_mul(ctx0, - ggml_repeat(ctx0, model.layers[il].attention_norm, cur), - cur); -#endif } // self-attention @@ -1327,13 +1327,7 @@ static bool llama_eval_internal( cur = ggml_rms_norm(ctx0, inpFF); // cur = cur*ffn_norm(broadcasted) -#ifdef GGML_USE_CUBLAS cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm); -#else - cur = ggml_mul(ctx0, - ggml_repeat(ctx0, model.layers[il].ffn_norm, cur), - cur); -#endif } struct ggml_tensor * tmp = ggml_mul_mat(ctx0, @@ -1371,13 +1365,7 @@ static bool llama_eval_internal( inpL = ggml_rms_norm(ctx0, inpL); // inpL = inpL*norm(broadcasted) -#ifdef GGML_USE_CUBLAS inpL = ggml_mul(ctx0, inpL, model.norm); -#else - inpL = ggml_mul(ctx0, - ggml_repeat(ctx0, model.norm, inpL), - inpL); -#endif embeddings = inpL; } From a3586c526f5124c1e5f9a61476183826b3045ccb Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 20 May 2023 13:22:47 +0300 Subject: [PATCH 29/34] cmake : workarounds for cufile when CMake version < 3.25 --- CMakeLists.txt | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 48e3238dfa52e..b46c2adf9dfa9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -182,6 +182,7 @@ if (LLAMA_CUBLAS) cmake_minimum_required(VERSION 3.17) find_package(CUDAToolkit) + if (CUDAToolkit_FOUND) message(STATUS "cuBLAS found") @@ -197,6 +198,23 @@ if (LLAMA_CUBLAS) set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt) endif() + if (CMAKE_VERSION VERSION_LESS 3.25) + if (NOT MSVC) + if (LLAMA_STATIC) + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} -lcufile_static) + else() + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} -lcufile) + endif() + else() + message(FATAL "TODO: cufile on Windows") + endif() + else() + if (LLAMA_STATIC) + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cufile_static) + else() + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cufile) + endif() + endif() else() message(WARNING "cuBLAS not found") endif() From fee87f6558123c056bbc2ad52aea05433cb0a50b Mon Sep 17 00:00:00 2001 From: JohannesGaessler Date: Sat, 20 May 2023 13:27:21 +0200 Subject: [PATCH 30/34] gg rebase fixup --- .gitignore | 22 +++++++--------------- CMakeLists.txt | 18 ------------------ Makefile | 2 +- ggml-cuda.cu | 44 ++------------------------------------------ ggml-cuda.h | 2 -- llama.cpp | 1 - 6 files changed, 10 insertions(+), 79 deletions(-) diff --git a/.gitignore b/.gitignore index 1aabe82fc2967..d231f3ff8ed36 100644 --- a/.gitignore +++ b/.gitignore @@ -1,15 +1,5 @@ -/*.o -/*.a -/*.sh -/*.log -/*.org - -/ppl-*.txt -/qnt-*.txt -/perf-*.txt - -*.bin - +*.o +*.a .DS_Store .build/ .cache/ @@ -31,9 +21,8 @@ build-no-accel/ build-sanitize-addr/ build-sanitize-thread/ -prompts/ models/* -wikitext-2-raw/ +*.bin /main /quantize @@ -44,7 +33,6 @@ wikitext-2-raw/ /benchmark-matmult /vdot /Pipfile -/libllama.so build-info.h arm_neon.h @@ -55,4 +43,8 @@ __pycache__ zig-out/ zig-cache/ +ppl-*.txt +qnt-*.txt +perf-*.txt + examples/jeopardy/results.txt diff --git a/CMakeLists.txt b/CMakeLists.txt index b46c2adf9dfa9..48e3238dfa52e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -182,7 +182,6 @@ if (LLAMA_CUBLAS) cmake_minimum_required(VERSION 3.17) find_package(CUDAToolkit) - if (CUDAToolkit_FOUND) message(STATUS "cuBLAS found") @@ -198,23 +197,6 @@ if (LLAMA_CUBLAS) set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt) endif() - if (CMAKE_VERSION VERSION_LESS 3.25) - if (NOT MSVC) - if (LLAMA_STATIC) - set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} -lcufile_static) - else() - set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} -lcufile) - endif() - else() - message(FATAL "TODO: cufile on Windows") - endif() - else() - if (LLAMA_STATIC) - set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cufile_static) - else() - set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cufile) - endif() - endif() else() message(WARNING "cuBLAS not found") endif() diff --git a/Makefile b/Makefile index 93c149edc9593..f9ec8797a40dc 100644 --- a/Makefile +++ b/Makefile @@ -125,7 +125,7 @@ endif ifdef LLAMA_CUBLAS CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include CXXFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include - LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lcufile -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib + LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib OBJS += ggml-cuda.o NVCC = nvcc NVCCFLAGS = --forward-unknown-to-host-compiler -arch=native diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 6dbbc7950da0f..2be6ece7e114e 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -2,13 +2,11 @@ #include #include #include -#include #include #include #include #include -#include #include "ggml-cuda.h" #include "ggml.h" @@ -34,15 +32,6 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); } \ } while (0) -#define CUFILE_CHECK(status) \ - do { \ - CUfileError_t status_ = (status); \ - if (status_.err != CU_FILE_SUCCESS) { \ - fprintf(stderr, "cuFile error %d at %s:%d\n", status_.err, __FILE__, __LINE__); \ - exit(1); \ - } \ - } while (0) - typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, float & v0, float & v1); typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream); typedef void (*dequantize_mul_mat_vec_cuda_t)(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream); @@ -383,7 +372,7 @@ struct cuda_buffer { static cuda_buffer g_cuda_buffer_pool[MAX_CUDA_BUFFERS]; static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT; -void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) { +static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) { scoped_spin_lock lock(g_cuda_pool_lock); for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) { @@ -402,7 +391,7 @@ void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) { return ptr; } -void ggml_cuda_pool_free(void * ptr, size_t size) { +static void ggml_cuda_pool_free(void * ptr, size_t size) { scoped_spin_lock lock(g_cuda_pool_lock); for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) { @@ -442,9 +431,6 @@ void ggml_init_cublas() { // configure logging to stdout // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr)); - - // initialize cuFile for loading model parameters directly to VRAM - CUFILE_CHECK(cuFileDriverOpen()); } } @@ -909,32 +895,6 @@ void ggml_cuda_transform_tensor(ggml_tensor * tensor) { } void ggml_cuda_load_data(const char * fname, struct ggml_tensor ** tensors, const int num_tensors, const size_t * offsets) { - CUfileDescr_t cf_descr; - memset((void *)&cf_descr, 0, sizeof(CUfileDescr_t)); - const int fd_cf = open(fname, O_RDONLY|O_DIRECT, 0644); - cf_descr.handle.fd = fd_cf; - cf_descr.type = CU_FILE_HANDLE_TYPE_OPAQUE_FD; - - CUfileHandle_t cf_handle; - CUfileError_t status = cuFileHandleRegister(&cf_handle, &cf_descr); - - if (status.err == CU_FILE_INTERNAL_ERROR) { - fprintf(stderr, "WARNING: cuFile experienced an internal error while loading weights from \"%s\". Using a workaround (slower). " - "This happens with weight files on Btrfs partitions. ext4 and NTFS are confirmed to work.\n", fname); - } else { - for (int i = 0; i < num_tensors; ++i) { - ggml_tensor * tensor = tensors[i]; - const size_t size = ggml_nbytes(tensor); - const size_t offset = offsets[i]; - - size_t actual_size; - void * buf = ggml_cuda_pool_malloc(size, &actual_size); - cuFileRead(cf_handle, buf, size, offset, 0); - tensor->data = buf; - } - return; - } - FILE * fp = fopen(fname, "rb"); for (int i = 0; i < num_tensors; ++i) { diff --git a/ggml-cuda.h b/ggml-cuda.h index b46a804c7adf0..ab2c690b082b1 100644 --- a/ggml-cuda.h +++ b/ggml-cuda.h @@ -14,8 +14,6 @@ void ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tens // TODO: export these with GGML_API void * ggml_cuda_host_malloc(size_t size); void ggml_cuda_host_free(void * ptr); -void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size); -void ggml_cuda_pool_free(void * ptr, size_t size); void ggml_cuda_transform_tensor(struct ggml_tensor * tensor); void ggml_cuda_load_data(const char * fname, struct ggml_tensor ** tensors, int num_tensors, const size_t * offsets); diff --git a/llama.cpp b/llama.cpp index 431c7eaf6355c..6a822e6d30782 100644 --- a/llama.cpp +++ b/llama.cpp @@ -10,7 +10,6 @@ #include "ggml.h" #ifdef GGML_USE_CUBLAS -#include #include "ggml-cuda.h" #endif From b81f662e9dcbc4a6b1bbfea199ba315d4154c709 Mon Sep 17 00:00:00 2001 From: JohannesGaessler Date: Sat, 20 May 2023 13:42:19 +0200 Subject: [PATCH 31/34] Loop in llama.cpp, fixed progress callback --- ggml-cuda.cu | 38 +++++++++++++++++--------------------- ggml-cuda.h | 2 +- llama.cpp | 27 ++++++++++++++++++--------- 3 files changed, 36 insertions(+), 31 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 2be6ece7e114e..385aa2fd62ec7 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -894,35 +894,31 @@ void ggml_cuda_transform_tensor(ggml_tensor * tensor) { tensor->backend = GGML_BACKEND_CUDA; } -void ggml_cuda_load_data(const char * fname, struct ggml_tensor ** tensors, const int num_tensors, const size_t * offsets) { +void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensor, const size_t offset) { FILE * fp = fopen(fname, "rb"); - for (int i = 0; i < num_tensors; ++i) { - ggml_tensor * tensor = tensors[i]; - const size_t size = ggml_nbytes(tensor); - const size_t offset = offsets[i]; + const size_t size = ggml_nbytes(tensor); - void * buf; - CUDA_CHECK(cudaMalloc(&buf, size)); - void * buf_host = malloc(size); + void * buf; + CUDA_CHECK(cudaMalloc(&buf, size)); + void * buf_host = malloc(size); #ifdef _WIN32 - int ret = _fseeki64(fp, (__int64) offset, SEEK_SET); + int ret = _fseeki64(fp, (__int64) offset, SEEK_SET); #else - int ret = fseek(fp, (long) offset, SEEK_SET); + int ret = fseek(fp, (long) offset, SEEK_SET); #endif - GGML_ASSERT(ret == 0); // same + GGML_ASSERT(ret == 0); // same - size_t ret2 = fread(buf_host, size, 1, fp); - if (ret2 != 1) { - fprintf(stderr, "unexpectedly reached end of file"); - exit(1); - } + size_t ret2 = fread(buf_host, size, 1, fp); + if (ret2 != 1) { + fprintf(stderr, "unexpectedly reached end of file"); + exit(1); + } - cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice); - cudaDeviceSynchronize(); + cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice); + cudaDeviceSynchronize(); - tensor->data = buf; - free(buf_host); - } + tensor->data = buf; + free(buf_host); } diff --git a/ggml-cuda.h b/ggml-cuda.h index ab2c690b082b1..6a04dde6c37a9 100644 --- a/ggml-cuda.h +++ b/ggml-cuda.h @@ -16,7 +16,7 @@ void * ggml_cuda_host_malloc(size_t size); void ggml_cuda_host_free(void * ptr); void ggml_cuda_transform_tensor(struct ggml_tensor * tensor); -void ggml_cuda_load_data(const char * fname, struct ggml_tensor ** tensors, int num_tensors, const size_t * offsets); +void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensors, size_t offset); #ifdef __cplusplus } diff --git a/llama.cpp b/llama.cpp index 6a822e6d30782..258537a922613 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1,6 +1,7 @@ // Defines fileno on msys: #ifndef _GNU_SOURCE #define _GNU_SOURCE +#include #include #include #endif @@ -720,9 +721,6 @@ struct llama_model_loader { lmlock->grow_to(done_size); } } - if (progress_callback) { - progress_callback(1.0f, progress_callback_user_data); - } } void load_data_for(llama_load_tensor & lt) { @@ -1104,20 +1102,31 @@ static void llama_model_load_internal( #ifdef GGML_USE_CUBLAS { - std::vector tensors; - std::vector offsets; + size_t done_size = 0; + size_t data_size = 0; + for (llama_load_tensor & lt : ml->tensors_map.tensors) { + data_size += lt.size; + if (lt.ggml_tensor->backend == GGML_BACKEND_CPU) { + done_size += lt.size; + } + } for (llama_load_tensor & lt : ml->tensors_map.tensors) { if (lt.ggml_tensor->backend != GGML_BACKEND_CUDA) { continue; } - tensors.emplace_back(lt.ggml_tensor); - LLAMA_ASSERT(lt.shards.size() == 1); - offsets.emplace_back(lt.shards.at(0).file_off); + if (progress_callback) { + progress_callback((float) done_size / data_size, progress_callback_user_data); + } + ggml_cuda_load_data(fname.c_str(), lt.ggml_tensor, lt.shards.at(0).file_off); + done_size += lt.size; } - ggml_cuda_load_data(fname.c_str(), tensors.data(), tensors.size(), offsets.data()); } #endif // GGML_USE_CUBLAS + if (progress_callback) { + progress_callback(1.0f, progress_callback_user_data); + } + model.mapping = std::move(ml->mapping); // loading time will be recalculate after the first eval, so From fadcd583fcc081e0cb5822421b6cc15457478c2c Mon Sep 17 00:00:00 2001 From: JohannesGaessler Date: Sat, 20 May 2023 14:07:46 +0200 Subject: [PATCH 32/34] Attempt clang-tidy fix --- llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama.cpp b/llama.cpp index 258537a922613..d677b0af2ed11 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1018,7 +1018,7 @@ static void llama_model_load_internal( // "output" tensor { ggml_backend backend_output; - if (n_gpu_layers > int(n_layer)) { + if (n_gpu_layers > int(n_layer)) { // NOLINT backend_output = LLAMA_BACKEND_OFFLOAD; } else { backend_output = GGML_BACKEND_CPU; From a4da072d394edc1a38a99e4a35841b0032e0221c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 20 May 2023 15:14:53 +0300 Subject: [PATCH 33/34] llama : fix vram size computation --- llama.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llama.cpp b/llama.cpp index d677b0af2ed11..b38d55dda0e27 100644 --- a/llama.cpp +++ b/llama.cpp @@ -975,7 +975,7 @@ static void llama_model_load_internal( size_t ctx_size; size_t mmapped_size; ml->calc_sizes(&ctx_size, &mmapped_size); - fprintf(stderr, "%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/1024.0/1024.0); + fprintf(stderr, "%s: ggml ctx size = %7.2f MB\n", __func__, ctx_size/1024.0/1024.0); // create the ggml context { @@ -1050,7 +1050,7 @@ static void llama_model_load_internal( layer.w2 = ml->get_tensor(layers_i + ".feed_forward.w2.weight", { n_ff, n_embd}, backend); layer.w3 = ml->get_tensor(layers_i + ".feed_forward.w3.weight", {n_embd, n_ff}, backend); - if (backend == LLAMA_BACKEND_OFFLOAD) { + if (backend == GGML_BACKEND_CUDA) { vram_total += ggml_nbytes(layer.attention_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk) + ggml_nbytes(layer.wv) + ggml_nbytes(layer.wo) + ggml_nbytes(layer.attention_norm) + From 37f2c6c2519fe3b568978840be82bbdc93c4f089 Mon Sep 17 00:00:00 2001 From: JohannesGaessler Date: Sat, 20 May 2023 14:16:40 +0200 Subject: [PATCH 34/34] Add forgotten fclose() --- ggml-cuda.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 385aa2fd62ec7..35d2e457cbf84 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -921,4 +921,5 @@ void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensor, const tensor->data = buf; free(buf_host); + fclose(fp); }