Skip to content

Commit

Permalink
Reduce the memory usage of logits from O(context_length) to O(1) (#4688)
Browse files Browse the repository at this point in the history
Summary:
The logits size is big, with size [context_length x vocab_size]. But we always use the last (new) logits, because the model generates one new token in each Transformer inference. 

This PR changes the transformer to return the logits of the last token only. In the runner code, we don't have to fetch the logits for the last token specifically, but directly use the output .

Test command:
```
python -m examples.models.llama2.export_llama --checkpoint /Users/myuan/data/llama/story110m/checkpoint.pt --params /Users/myuan/data/llama/story110m/params.json -kv --use_sdpa_with_kv_cache -X -qmode 8da4w --group_size 128 -d fp32 --max_seq_length 1024 --profile_memory
```
Before: 284 MB activation, with 262 MB on logits
After: 162 MB activation, with 0.128 MB on logits

Verified with llamma_runner, before and after it generates the same text with temperature=0. 

Now the dominant memory usage would be KV cache. 

TODO: 
- Improve KV cache memory usage using pf16 or quantization.
- This PR only fixes logits. Further activation memory optimization with one token output.


Reviewed By: larryliu0820

Differential Revision: D61246566

Pulled By: iseeyuan
  • Loading branch information
Martin Yuan authored and facebook-github-bot committed Aug 22, 2024
1 parent d7c069f commit be438eb
Show file tree
Hide file tree
Showing 10 changed files with 59 additions and 39 deletions.
4 changes: 2 additions & 2 deletions examples/models/llama2/export_llama_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def build_args_parser() -> argparse.ArgumentParser:
"--generate_full_logits",
action="store_true",
required=False,
default=True,
default=False,
help="Generate logits for all inputs.",
)
return parser
Expand Down Expand Up @@ -598,7 +598,7 @@ def _load_llama_model(
params_path: str,
use_kv_cache: bool = False,
use_sdpa_with_kv_cache: bool = False,
generate_full_logits: bool = True,
generate_full_logits: bool = False,
weight_type: WeightType = WeightType.LLAMA,
enable_dynamic_shape: bool = False,
verbose: bool = False,
Expand Down
2 changes: 1 addition & 1 deletion examples/models/llama2/llama_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class ModelArgs:
# Generate logits for all inputs. When it's True, it would take big memory usage
# at runtime. Enable it only necessary (e.g., use perplexity tools that requires
# logits for all input tokens.)
generate_full_logits: bool = True
generate_full_logits: bool = False
enable_dynamic_shape: bool = False # export model with dynamic shape support
use_hf_rope: bool = False # Use HuggingFace's RoPE implementation
rope_theta: Optional[float] = (
Expand Down
2 changes: 1 addition & 1 deletion examples/models/llama2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(self, **kwargs):

self.use_kv_cache = kwargs.get("use_kv_cache", False)
self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
self.generate_full_logits = kwargs.get("generate_full_logits", True)
self.generate_full_logits = kwargs.get("generate_full_logits", False)
self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)

self.max_seq_len = kwargs.get("max_seq_len", 128)
Expand Down
7 changes: 5 additions & 2 deletions examples/models/llava/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,16 +216,19 @@ def prefill_embedding(
result = torch.cat((embeds_before_img, image_embeds, embeds_after_img), dim=1)
return result

# prefill using the in house text_model of llama transformer
def prefill(
self,
prompt_before_image: torch.Tensor,
images: torch.Tensor,
prompt_after_image: torch.Tensor,
) -> torch.Tensor:
) -> (int, torch.Tensor):
"""Avoiding the torch.where() call to find <image> placeholder and insert image embedding. Taking 3 inputs instead."""
embeds = self.prefill_embedding(prompt_before_image, images, prompt_after_image)
return self.text_model.forward(None, torch.tensor([0]), embeds)
# returns the prefilled token length too, because the text model generates one logits in each forward call.
return embeds.shape[1], self.text_model.forward(None, torch.tensor([0]), embeds)

# reference prefill using the text model in HF
def prefill_ref(
self,
prompt_before_image: torch.Tensor,
Expand Down
6 changes: 5 additions & 1 deletion examples/models/llava/runner/llava_image_prefiller.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class LlavaImagePrefiller : public ImagePrefiller {
* @param start_pos The starting position in KV cache of the input in the LLM
* @return logits of the image prefill.
*/
inline Result<exec_aten::Tensor> prefill(Image& image, int64_t start_pos = 0)
inline Result<exec_aten::Tensor> prefill(Image& image, int64_t& start_pos)
override {
ManagedTensor managed_images(
image.data.data(), {3, image.height, image.width}, ScalarType::Byte);
Expand All @@ -43,6 +43,10 @@ class LlavaImagePrefiller : public ImagePrefiller {
outputs_res[0].isTensor(),
"Non Tensor Output returned from executing image prefill");

// Update the start_pos, which is only available inside this function.
// outputs_res can have only one logits.
start_pos += image_encoder_outputs[0].toTensor().size(1);

return outputs_res[0].toTensor();
}

Expand Down
4 changes: 2 additions & 2 deletions examples/models/llava/runner/llava_runner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ Error LlavaRunner::generate(

// prefill images
for (auto& image : images) {
auto logits = ET_UNWRAP(image_prefiller_->prefill(image, pos));
pos += logits.size(1);
// pos is updated inside image prefill.
ET_UNWRAP(image_prefiller_->prefill(image, pos));
}

// prefill user prompt. No BOS because preset prompt already has it.
Expand Down
34 changes: 21 additions & 13 deletions examples/models/llava/test/test_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,14 @@ def setUp(self):
)

def test_prefill_logits(self):
prefill_logits = self.llava.prefill(
# For efficiency, the implemented prefill function only outputs the last logits.
_, prefill_logits = self.llava.prefill(
self.prompt_before_image, self.resized, self.prompt_after_image
)
# The reference implementation in HF genetates the full logits. Get the last one.
prefill_logits_ref = self.llava.prefill_ref(
self.prompt_before_image, self.resized, self.prompt_after_image
)[0]
)[0][:, -1, :]
self.assertTrue(torch.allclose(prefill_logits, prefill_logits_ref, atol=3e-2))

def test_generated_output(self):
Expand All @@ -62,11 +64,11 @@ def test_generated_output(self):
)[0].strip()

# being tested, using llama_transformer
prefill_logits = self.llava.prefill(
context_len, prefill_logits = self.llava.prefill(
self.prompt_before_image, self.resized, self.prompt_after_image
)
context_len = prefill_logits.shape[1]
new_tokens = [torch.argmax(prefill_logits[..., -1, :]).item()]
# Always generate one token at a time.
new_tokens = [torch.argmax(prefill_logits).item()]
for i in range(4):
logits = self.llava.step(
torch.tensor([new_tokens[i]]), torch.tensor([context_len + i])
Expand All @@ -93,24 +95,27 @@ def test_llava_export(self):
pte_embeds_before_img = llava_module.run_method(
"token_embedding", (prompt_before_image,)
)[0]
pte_prefill_before_img = llava_module.run_method(
llava_module.run_method(
"text_model",
(torch.tensor([start_pos], dtype=torch.int64), pte_embeds_before_img),
)[0]
)

start_pos += pte_prefill_before_img.shape[1]
# Update the start_pos. start_pos is used in kv cache. The source of truth
# of the delta length is from the embeddings, not from the logits.
start_pos += pte_embeds_before_img.shape[1]

# pte prefill image
pte_embeds_img = llava_module.run_method("image_encoder", (resized,))[0]
pte_prefill_img = llava_module.run_method(
llava_module.run_method(
"text_model",
(
torch.tensor([start_pos], dtype=torch.int64),
pte_embeds_img,
),
)[0]
)

start_pos += pte_prefill_img.shape[1]
# Update the logits for each prefill (kv cache) step.
start_pos += pte_embeds_img.shape[1]

# pte prefill prompt after img
pte_embeds_after_img = llava_module.run_method(
Expand All @@ -121,8 +126,11 @@ def test_llava_export(self):
(torch.tensor([start_pos], dtype=torch.int64), pte_embeds_after_img),
)[0]

# Update the logits for each prefill (kv cache) step.
start_pos += pte_embeds_after_img.shape[1]

# being tested, using llama_transformer
new_tokens = [torch.argmax(pte_prefill_after_img[..., -1, :]).item()]
new_tokens = [torch.argmax(pte_prefill_after_img).item()]
# TODO: uncomment this line
# self.assertEquals(new_tokens[0], 1932) # When
for i in range(4):
Expand All @@ -134,7 +142,7 @@ def test_llava_export(self):
"text_model",
(torch.tensor([start_pos + i], dtype=torch.int64), token_embeds),
)[0]
new_tokens.append(torch.argmax(logits[..., -1, :]).item())
new_tokens.append(torch.argmax(logits).item())

outputs = llava_model.tokenizer.batch_decode(
torch.tensor([new_tokens]), skip_special_tokens=True
Expand Down
5 changes: 3 additions & 2 deletions extension/llm/runner/image_prefiller.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@ class ImagePrefiller {
/**
* Prefill an LLM Module with the given image input.
* @param image The image input to the multimodal LLM.
* @param start_pos The starting position in KV cache of the input in the LLM
* @param start_pos The starting position in KV cache of the input in the LLM.
* It's passed as reference and will be updated inside this function.
* @return The next token of the LLM Module after prefill.
*/
virtual ::executorch::runtime::Result<exec_aten::Tensor> prefill(
Image& image,
int64_t start_pos = 0) = 0;
int64_t& start_pos) = 0;

virtual ::executorch::runtime::Error load() = 0;
virtual bool is_method_loaded() = 0;
Expand Down
29 changes: 19 additions & 10 deletions extension/llm/runner/text_decoder_runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,23 +67,32 @@ class TextDecoderRunner {
* @return The next token.
*/
inline int32_t logits_to_token(const exec_aten::Tensor& logits_tensor) {
ET_CHECK_MSG(logits_tensor.dim() == 3, "Logits tensor must be 3D");
auto num_tokens = logits_tensor.size(1);
auto vocab_size = logits_tensor.size(2);

switch (logits_tensor.scalar_type()) {
// If the logit_tensor rank is 3, the shape is [batch, seq_length,
// vocab_size], get the last logits, sample and return. Else the model
// outputs the last logit, directly sample and return.
case exec_aten::ScalarType::Float: {
float* logits = logits_tensor.mutable_data_ptr<float>();
float* logits_last = logits;
logits_last += (num_tokens - 1) * vocab_size;
return sampler_->sample(logits_last);
if (logits_tensor.dim() == 3) {
auto num_tokens = logits_tensor.size(1);
auto vocab_size = logits_tensor.size(2);
float* logits_last = logits;
logits_last += (num_tokens - 1) * vocab_size;
return sampler_->sample(logits_last);
}
return sampler_->sample(logits);
}
case exec_aten::ScalarType::Half: {
exec_aten::Half* logits =
logits_tensor.mutable_data_ptr<exec_aten::Half>();
exec_aten::Half* logits_last = logits;
logits_last += (num_tokens - 1) * vocab_size;
return sampler_->sample(logits_last);
if (logits_tensor.dim() == 3) {
auto num_tokens = logits_tensor.size(1);
auto vocab_size = logits_tensor.size(2);
exec_aten::Half* logits_last = logits;
logits_last += (num_tokens - 1) * vocab_size;
return sampler_->sample(logits_last);
}
return sampler_->sample(logits);
}
default:
ET_CHECK_MSG(
Expand Down
5 changes: 0 additions & 5 deletions extension/llm/runner/text_prefiller.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,6 @@ ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
ET_CHECK_OK_OR_RETURN_ERROR(outputs_res.error());
ET_LOG(
Info, "Prefill token result numel(): %zu", outputs_res.get().numel());
ET_CHECK_MSG(
outputs_res.get().size(1) == num_prompt_tokens,
"Expected number of output tokens %d does not match returned value %zu.",
num_prompt_tokens,
outputs_res.get().size(1));
// insert new token into prompt_tokens
// NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
uint64_t prev = prompt_tokens[0];
Expand Down

0 comments on commit be438eb

Please sign in to comment.