
Memory consumption for inference with Llama2-7B is weird #28651

@c3ianwu

Description

System Info

  • transformers version: 4.36.2
  • Platform: Linux-5.15.107+-x86_64-with-glibc2.31
  • Python version: 3.10.13
  • Huggingface_hub version: 0.20.1
  • Safetensors version: 0.4.1
  • Accelerate version: 0.22.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.1.2+cu118 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed

Who can help?

@ArthurZucker @younesbelkada @gan

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I am trying to track GPU memory consumption when doing inference with Llama2-7B. This is my set-up:

import json
import tqdm
import warnings
warnings.filterwarnings('ignore')
import time

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import datasets
import matplotlib.pyplot as plt

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf", torch_dtype=torch.bfloat16)
model.to(device=0)

prompt_data = datasets.load_from_disk("/data/metamath_100k_2048/train") # this is just some supervised training text data
prompts = prompt_data["inputs"] # this is a list of strings


class LocalModel:

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def generate(self, prompts, do_sample=False, temperature=0, top_k=0, top_p=0, repetition_penalty=1.0, max_new_tokens=128):
        self.tokenizer.pad_token = self.tokenizer.eos_token
        tokenized_inputs = self.tokenizer(prompts, return_tensors="pt", padding=True).to(self.model.device)
        inputs = tokenized_inputs["input_ids"]
        attention_mask = tokenized_inputs["attention_mask"]
        tic = time.time()
        # generate() returns token ids, not logits
        output_ids = self.model.generate(input_ids=inputs,
                                         attention_mask=attention_mask,
                                         do_sample=do_sample,
                                         temperature=temperature,
                                         top_k=top_k,
                                         top_p=top_p,
                                         repetition_penalty=repetition_penalty,
                                         max_new_tokens=max_new_tokens)
        toc = time.time()
        max_alloc = torch.cuda.max_memory_allocated(0) / 1e9  # peak memory allocated since the last reset, converted to GB
        print("Peak GPU Memory Consumption: {}".format(max_alloc))
        torch.cuda.reset_peak_memory_stats(0)  # so the next call measures its own peak
        print("Time for generation: {}".format(toc - tic))
        return max_alloc

I ran

local_model = LocalModel(model, tokenizer)

alloc = []
x = [0, 2, 4, 6, 8, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160]
for i in x:
    alloc.append(local_model.generate(prompts[:64], max_new_tokens=i))


plt.scatter(x, alloc)
plt.xlabel("Max New Tokens")
plt.ylabel("Peak Mem Usage / GB")
plt.show()

This is the plot:

[Plot: Peak Mem Usage / GB against Max New Tokens]

Expected behavior

I tried to compute theoretical numbers. I estimated the number of input tokens:

def calculate_prompt_tokens(tokenizer, prompts, batch_size):
    tokenizer.pad_token = tokenizer.eos_token
    tokens = tokenizer(prompts[:batch_size], return_tensors="pt", padding=True)
    return tokens["input_ids"].shape[0] * tokens["input_ids"].shape[1]

calculate_prompt_tokens(tokenizer, prompts, batch_size=64)

which returns 12992 (64 sequences padded to a common length of 203 tokens). Taking the model to be 7B params ~ 14GB in bf16, and assuming that the kv cache consumes 2 (K and V) * 2 bytes (bf16) * num_layers * d_model = 4*32*4096 = 524,288 bytes/token, we get an estimated 14 + (12992*524288)*1e-9 ≈ 20.8GB before anything is generated, which looks about right from the graph.
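
For reference, here is that estimate as a short script. The layer count, hidden size and byte sizes are the assumptions stated above rather than values read from the model config:

# Back-of-envelope check of the numbers above (assumes bf16, 32 layers, d_model = 4096).
bytes_per_param = 2                                    # bf16
num_layers, d_model = 32, 4096
kv_bytes_per_token = 2 * 2 * num_layers * d_model      # K and V, 2 bytes each -> 524,288

weights_gb = 7e9 * bytes_per_param / 1e9               # ~14 GB of parameters
prompt_tokens = 12992                                  # from calculate_prompt_tokens above
prompt_cache_gb = prompt_tokens * kv_bytes_per_token / 1e9  # ~6.8 GB of kv cache
print(weights_gb + prompt_cache_gb)                    # ~20.8 GB before any tokens are generated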

Using the same logic, we know that each additional generation step should cost (via the kv cache) 524,288*64 ≈ 0.034GB / step of memory. Looking at the gradient of the linear portion of the plot, we get ~0.067GB / step instead, which is around double that.
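
For concreteness, this is roughly how I read the slope off the data (x and alloc are the lists from the loop in the Reproduction section; the index-6 cut is just an arbitrary choice to skip the early jump):

import numpy as np

batch_size = 64
kv_bytes_per_token = 524288
print(batch_size * kv_bytes_per_token / 1e9)   # ~0.034 GB of kv cache per generation step

# Slope of the roughly linear tail of the measurements above.
slope, _ = np.polyfit(x[6:], alloc[6:], 1)
print(slope)                                   # ~0.067 GB / step, about double the estimate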

  1. Why is the memory consumed for generation greater than expected?
  2. What's going on in the early portion of the plot? Why is there a big jump at the start?
