Skip computation of much of last layer & unused logits during prompt eval / large N #2700

Closed · wants to merge 15 commits
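
For a rough sense of why this helps at large N: during prompt evaluation with logits_all == false, only the last token's logits are ever read, yet the baseline graph multiplies the full (n_embd, N) hidden state by the output matrix (and runs most of the last layer) for every position. A back-of-the-envelope sketch in C++ for the output projection alone, using assumed LLaMA-7B-like sizes (n_embd = 4096, n_vocab = 32000) and an illustrative prompt length:

```cpp
// Rough FLOP count for the final output projection (hidden state -> logits),
// baseline vs. the "skip unused logits" path. Sizes are illustrative
// assumptions (LLaMA-7B-like), not taken from the PR.
#include <cstdio>

int main() {
    const long long n_embd  = 4096;   // hidden size (assumed)
    const long long n_vocab = 32000;  // vocabulary size (assumed)
    const long long N       = 512;    // example prompt length

    const long long flops_full = 2 * n_embd * n_vocab * N; // logits for every position
    const long long flops_last = 2 * n_embd * n_vocab * 1; // logits for the last position only

    printf("output projection, all positions : %.2f GFLOPs\n", flops_full / 1e9);
    printf("output projection, last position : %.2f GFLOPs\n", flops_last / 1e9);
    printf("ratio                            : %lldx\n", flops_full / flops_last);
    return 0;
}
```

The skipped portions of the last transformer layer (Q projection, attention, FFN for positions other than the last) come on top of this output-projection saving.
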
5 changes: 5 additions & 0 deletions CMakeLists.txt
@@ -89,6 +89,7 @@ option(LLAMA_METAL_NDEBUG "llama: disable Metal debugging"
option(LLAMA_MPI "llama: use MPI" OFF)
option(LLAMA_K_QUANTS "llama: use k-quants" ON)
option(LLAMA_QKK_64 "llama: use super-block size of 64 for k-quants" OFF)
option(LLAMA_SKIP_UNUSED_LOGITS "llama: skip computation of unused logits" ON)

option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
@@ -410,6 +411,10 @@ if (LLAMA_HIPBLAS)
endif()
endif()

if (LLAMA_SKIP_UNUSED_LOGITS)
add_compile_definitions(LLAMA_SKIP_UNUSED_LOGITS)
endif()

if (LLAMA_ALL_WARNINGS)
if (NOT MSVC)
set(c_flags
4 changes: 4 additions & 0 deletions Makefile
@@ -172,6 +172,10 @@ ifdef LLAMA_DISABLE_LOGS
MK_CPPFLAGS += -DLOG_DISABLE_LOGS
endif # LLAMA_DISABLE_LOGS

ifndef LLAMA_NO_SKIP_UNUSED_LOGITS
MK_CPPFLAGS += -DLLAMA_SKIP_UNUSED_LOGITS
endif # LLAMA_NO_SKIP_UNUSED_LOGITS

# warnings
MK_CFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith \
-Wmissing-prototypes -Werror=implicit-int -Wno-unused-function
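
Build-time toggle (a reading of the two hunks above, not text from the PR): the skip is enabled by default. With CMake it can be disabled by configuring with -DLLAMA_SKIP_UNUSED_LOGITS=OFF; with the plain Makefile build, defining LLAMA_NO_SKIP_UNUSED_LOGITS (for example, make LLAMA_NO_SKIP_UNUSED_LOGITS=1) leaves -DLLAMA_SKIP_UNUSED_LOGITS out of MK_CPPFLAGS.
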
4 changes: 4 additions & 0 deletions ggml-cuda.cu
@@ -6943,6 +6943,10 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
return;
}

if (tensor->backend != GGML_BACKEND_CPU) {
return;
}

// recursively assign CUDA buffers until a compute tensor is found
if (tensor->src[0] != nullptr && tensor->src[0]->backend == GGML_BACKEND_CPU) {
const ggml_op src0_op = tensor->src[0]->op;
49 changes: 40 additions & 9 deletions llama.cpp
@@ -2433,7 +2433,8 @@ static struct ggml_cgraph * llm_build_llama(

GGML_ASSERT((!tokens && embd) || (tokens && !embd)); // NOLINT

const int N = n_tokens;
// Non-const to allow short-circuiting to the last token in the last layer in prompt eval mode.
int N = n_tokens;

const auto & model = lctx.model;
const auto & hparams = model.hparams;
@@ -2561,18 +2562,10 @@
offload_func_kq(tmpk);
ggml_set_name(tmpk, "tmpk");

struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
offload_func_kq(tmpq);
ggml_set_name(tmpq, "tmpq");

struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale);
offload_func_kq(Kcur);
ggml_set_name(Kcur, "Kcur");

struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale);
offload_func_kq(Qcur);
ggml_set_name(Qcur, "Qcur");

// store key and value to memory
{
// compute the transposed [N, n_embd] V matrix
@@ -2600,6 +2593,37 @@
ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
}

#ifdef LLAMA_SKIP_UNUSED_LOGITS
if (il == n_layer - 1 && !lctx.logits_all) {
// From here on, we only care about the last token and its logits.
// We do as if N = 1 (from the end), which means we only keep
// the last column of cur and inpSA ((n_embd, N) -> (n_embd, 1)).
//
// Note that we do this even when N==1 so that we don't change the # nodes in the graph,
// otherwise for Metal we'd have to rebuild the concurrency list.

cur = ggml_view_2d(ctx0, cur, n_embd, 1, cur->nb[1], (N - 1)*ggml_element_size(cur)*n_embd);
offload_func(cur);
ggml_set_name(cur, "cur-lastpos");

offload_func(inpSA);
inpSA = ggml_view_2d(ctx0, inpSA, n_embd, 1, inpSA->nb[1], (N - 1)*ggml_element_size(inpSA)*n_embd);
offload_func(inpSA);
ggml_set_name(inpSA, "inpSA-lastpos");

n_past += N - 1;
N = 1;
}
#endif // LLAMA_SKIP_UNUSED_LOGITS

struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
offload_func_kq(tmpq);
ggml_set_name(tmpq, "tmpq");

struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, N), n_past, n_embd_head, 0, 0, freq_base, freq_scale);
offload_func_kq(Qcur);
ggml_set_name(Qcur, "Qcur");

struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
offload_func_kq(Q);
ggml_set_name(Q, "Q");
@@ -3815,11 +3839,18 @@ static bool llama_eval_internal(

if (lctx.logits_all) {
logits_out.resize(n_vocab * N);
GGML_ASSERT(ggml_nelements(res) == n_vocab * N);
memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab*N);
} else {
// return result for just the last token
logits_out.resize(n_vocab);
#ifdef LLAMA_SKIP_UNUSED_LOGITS
GGML_ASSERT(ggml_nelements(res) == n_vocab);
memcpy(logits_out.data(), (float *) ggml_get_data(res), sizeof(float)*n_vocab);
#else
GGML_ASSERT(ggml_nelements(res) == n_vocab * N);
memcpy(logits_out.data(), (float *) ggml_get_data(res) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
#endif
}
}

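Note that the short-circuit in llm_build_llama is guarded by !lctx.logits_all: when the caller asks for logits at every position (as perplexity measurement typically does), the graph is built as before and the first branch above still copies the full n_vocab * N result.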