From 4e1f7415b7f61d6764b458b48e16a2936ec53ba6 Mon Sep 17 00:00:00 2001
From: Martin Yuan
Date: Tue, 13 Aug 2024 17:38:47 -0700
Subject: [PATCH] Reduce the memory usage of logits from O(context_length) to
 O(1) (#4688)

Summary:
The logits tensor is large, with shape [context_length x vocab_size], but we
only ever use the logits of the last (new) token, because the model generates
one new token per Transformer inference. This PR changes the transformer to
return the logits of the last token only. In the runner code, we no longer
have to fetch the logits for the last token specifically; we use the output
directly. (A standalone sketch of the idea is included after the diff.)

Test command:
```
python -m examples.models.llama2.export_llama --checkpoint /Users/myuan/data/llama/story110m/checkpoint.pt --params /Users/myuan/data/llama/story110m/params.json -kv --use_sdpa_with_kv_cache -X -qmode 8da4w --group_size 128 -d fp32 --max_seq_length 1024 --profile_memory
```

Before: 284 MB activation, with 262 MB on logits
After: 162 MB activation, with 0.128 MB on logits

Verified with llama_runner: before and after this change it generates the same
text with temperature=0.

The dominant memory usage is now the KV cache.

TODO:
- Improve KV cache memory usage using fp16 or quantization.
- This PR only addresses the logits; further activation memory optimization is
  possible with a single-token output.

Differential Revision: D61246566
---
 examples/models/llama2/llama_transformer.py |  3 +++
 extension/llm/runner/text_decoder_runner.h  | 12 ++----------
 extension/llm/runner/text_prefiller.cpp     |  5 -----
 3 files changed, 5 insertions(+), 15 deletions(-)

diff --git a/examples/models/llama2/llama_transformer.py b/examples/models/llama2/llama_transformer.py
index 99544426fd3..3c688133b01 100644
--- a/examples/models/llama2/llama_transformer.py
+++ b/examples/models/llama2/llama_transformer.py
@@ -524,6 +524,9 @@ def forward(
             input_pos,
         )
 
+        # Only the last logit is used for the new generated token
+        h = h[:, -1, :]
+
         h = self.norm(h)
 
         logits = self.output(h)
diff --git a/extension/llm/runner/text_decoder_runner.h b/extension/llm/runner/text_decoder_runner.h
index 31b8c1b983f..6019e7ce481 100644
--- a/extension/llm/runner/text_decoder_runner.h
+++ b/extension/llm/runner/text_decoder_runner.h
@@ -63,23 +63,15 @@ class TextDecoderRunner {
    * @return The next token.
   */
  inline int32_t logits_to_token(const exec_aten::Tensor& logits_tensor) {
-    ET_CHECK_MSG(logits_tensor.dim() == 3, "Logits tensor must be 3D");
-    auto num_tokens = logits_tensor.size(1);
-    auto vocab_size = logits_tensor.size(2);
-
     switch (logits_tensor.scalar_type()) {
       case ScalarType::Float: {
         float* logits = logits_tensor.mutable_data_ptr<float>();
-        float* logits_last = logits;
-        logits_last += (num_tokens - 1) * vocab_size;
-        return sampler_->sample(logits_last);
+        return sampler_->sample(logits);
       }
       case ScalarType::Half: {
         exec_aten::Half* logits =
             logits_tensor.mutable_data_ptr<exec_aten::Half>();
-        exec_aten::Half* logits_last = logits;
-        logits_last += (num_tokens - 1) * vocab_size;
-        return sampler_->sample(logits_last);
+        return sampler_->sample(logits);
       }
       default:
         ET_CHECK_MSG(
diff --git a/extension/llm/runner/text_prefiller.cpp b/extension/llm/runner/text_prefiller.cpp
index a5aa668e73a..fa084cbe016 100644
--- a/extension/llm/runner/text_prefiller.cpp
+++ b/extension/llm/runner/text_prefiller.cpp
@@ -50,11 +50,6 @@ Result<uint64_t> TextPrefiller::prefill(
   ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
   ET_LOG(
       Info, "Prefill token result numel(): %zu", outputs_res.get().numel());
-  ET_CHECK_MSG(
-      outputs_res.get().size(1) == num_prompt_tokens,
-      "Expected number of output tokens %d does not match returned value %zu.",
-      num_prompt_tokens,
-      outputs_res.get().size(1));
   // insert new token into prompt_tokens
   // NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
   uint64_t prev = prompt_tokens[0];
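
Standalone sketch of the idea behind the llama_transformer.py change. This is
illustrative only: dim=768 and vocab_size=32000 are assumed (roughly
stories110M-sized), LayerNorm stands in for the model's RMSNorm, and nn.Linear
stands in for the output projection; only max_seq_length=1024 comes from the
test command above.
```
import torch
import torch.nn as nn

dim, vocab_size, seq_len = 768, 32000, 1024  # assumed sizes for illustration

norm = nn.LayerNorm(dim)                          # stand-in for RMSNorm
output = nn.Linear(dim, vocab_size, bias=False)   # stand-in for the output projection

h = torch.randn(1, seq_len, dim)  # hidden states for the whole context

# Before: project every position -> logits of shape [1, seq_len, vocab_size]
logits_all = output(norm(h))
print(logits_all.shape, logits_all.numel() * 4 / 1e6, "MB")   # ~131 MB in fp32

# After: slice the last position before the projection -> [1, vocab_size]
h_last = h[:, -1, :]
logits_last = output(norm(h_last))
print(logits_last.shape, logits_last.numel() * 4 / 1e6, "MB")  # ~0.128 MB

# Decoding only ever reads the last position, so the two paths agree
# (up to float rounding).
assert torch.allclose(logits_all[:, -1, :], logits_last, atol=1e-4)
```
The 0.128 MB figure matches the "after" logits footprint reported above
(32000 fp32 values per step); the runner then samples directly from this
single-token logits buffer instead of offsetting to the last row.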