
Llama 2 FP8 quantization OOM #288

Closed · 0xymoro opened this issue Nov 6, 2023 · 12 comments
Labels: triaged (Issue has been triaged by maintainers)

Comments
@0xymoro (Contributor) commented Nov 6, 2023

Hi, I'm playing around with quantizing a 70B model. Even with 4x A100s (80 GB each) it is OOM'ing. Is this normal? It seems to be splitting the model across the GPUs correctly, but I'm not sure whether it's also splitting the inference memory or just using the 1st GPU. Is there a general guide on model size -> VRAM needed to quantize to FP8, or some way for the quantization code to use all the GPU memory available?

To reproduce:

  1. Build TRT-LLM into a Docker container
  2. Install the ammo requirements from the documentation
  3. Run the default example quantize 70B command from examples/llama
@byshiue (Collaborator) commented Nov 6, 2023

Can you share your scripts to build the engine?

byshiue added the triaged label Nov 6, 2023
byshiue self-assigned this Nov 6, 2023
@0xymoro (Contributor, Author) commented Nov 6, 2023

I used the default one from the installation guide, make -C docker release_build, after cloning what I think is the default branch for release 0.5.0.

For the quantization install, I added this as the runtime command when launching the built container in k8s (to later ssh into):

/bin/bash -c "cuda_version=$(nvcc --version | grep 'release' | awk '{print $6}' | awk -F'[V.]' '{print $2$3}');
python_version=$(python3 --version 2>&1 | awk '{print $2}' | awk -F. '{print $1$2}');
wget https://developer.nvidia.com/downloads/assets/cuda/files/nvidia-ammo/nvidia_ammo-0.3.0.tar.gz;
tar -xzf nvidia_ammo-0.3.0.tar.gz;
pip install nvidia_ammo-0.3.0/nvidia_ammo-0.3.0+cu$cuda_version-cp$python_version-cp$python_version-linux_x86_64.whl;
cd /app/tensorrt_llm/examples/quantization;
pip install -r requirements.txt;
ulimit -n 100000;
while true; do sleep 3600; done"

Which is basically copied & formatted almost exactly from the guide at https://github.com/NVIDIA/TensorRT-LLM/blob/d8ebeee2f6fcb219e6efc541ccc914765799fa3a/examples/quantization/README.md.

The only thing I didn't follow exactly was this:

--gpus all --ipc=host --ulimit memlock=-1 --shm-size=20g

But in a k8s environment I don't always have control over host IPC; I tried to set the ulimit in the command above, and all the GPUs are being used (the model was split evenly across the 4 A100s).

The actual error is the OOM error during the first forward pass: OutOfMemoryError: CUDA out of memory. Tried to allocate 1.74 GiB. GPU 0 has a total capacty of 79.15 GiB of which 1.02 GiB is free. Process 1528546 has 78.11 GiB memory in use. Of the allocated memory 76.24 GiB is allocated by PyTorch, and 1.38 GiB is
reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
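
(An aside on the allocator hint in that traceback, not the fix that eventually resolved this issue: PYTORCH_CUDA_ALLOC_CONF can be set before the first CUDA allocation, for example at the very top of the quantization script; the 512 MiB value below is an arbitrary example.)

import os

# Must be set before PyTorch's CUDA caching allocator initializes
# (easiest: before importing torch), otherwise it is silently ignored.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"

import torch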

@byshiue (Collaborator) commented Nov 6, 2023

@0xymoro I don't see the scripts to build the engine. Am I missing anything?

@0xymoro (Contributor, Author) commented Nov 6, 2023

@byshiue I haven't gotten to building the engine yet. I'm on the first command of the FP8 quantization, which builds the npz that will later be used to build the engine (from the link below). The OOM is in the quantization script, not in building/running the engine. I plan on using the FP8 output to then build the engine, but this first quantization step isn't going through.

My actual command for quantization is exactly the same as in the documentation; I tried to keep it as standard as possible since I'm just getting started.

FP8 Post-Training Quantization

@Tracin (Collaborator) commented Nov 6, 2023

Hi, I noticed this issue recently. It is due to LlamaForCausalLM.from_pretrained autocasting the fp16 model to fp32.
I still don't know why transformers does this.
You can try adding .half() after this line. I tried it and it works.
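
(A minimal sketch of that suggestion, assuming the checkpoint on disk is already fp16; the model path is a placeholder and this is not the exact code in quantize.py.)

import torch
from transformers import LlamaForCausalLM

MODEL_DIR = "/path/to/llama-2-70b-hf"  # placeholder path

# Ask transformers for fp16 weights up front instead of the default fp32 upcast...
model = LlamaForCausalLM.from_pretrained(MODEL_DIR, torch_dtype=torch.float16)

# ...or, equivalently for this purpose, cast right after loading as suggested above:
# model = LlamaForCausalLM.from_pretrained(MODEL_DIR).half()

print(model.dtype)  # torch.float16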

@0xymoro (Contributor, Author) commented Nov 6, 2023

@Tracin Where exactly is hf_llama_convert called for quantization? I'm looking at the quantize.py code: https://github.com/NVIDIA/TensorRT-LLM/blob/d8ebeee2f6fcb219e6efc541ccc914765799fa3a/examples/llama/quantize.py

And it seems like it just loads the model from transformers? It's a bit confusing, as quantize.py seems to only interface with ammo and just uses transformers to load the model, not to do anything else with it.

@Tracin (Collaborator) commented Nov 6, 2023

> @Tracin Where exactly is hf_llama_convert called for quantization? I'm looking at the quantize.py code: https://github.com/NVIDIA/TensorRT-LLM/blob/d8ebeee2f6fcb219e6efc541ccc914765799fa3a/examples/llama/quantize.py
>
> And it seems like it just loads the model from transformers? It's a bit confusing, as quantize.py seems to only interface with ammo and just uses transformers to load the model, not to do anything else with it.

Sorry, I linked the wrong file. However, this OOM is caused by transformers only; you can check model.dtype after from_pretrained and cast it.
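
(For reference, a quick way to verify what from_pretrained actually produced; the path is a placeholder.)

import torch
from transformers import LlamaForCausalLM

model = LlamaForCausalLM.from_pretrained("/path/to/llama-2-70b-hf")  # placeholder path
print(model.dtype)                # if this prints torch.float32, the upcast happened
if model.dtype == torch.float32:
    model = model.half()          # cast back to fp16 before calibration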

@0xymoro (Contributor, Author) commented Nov 6, 2023

Actually, looking at this again, it seems to already load with the right dtype. Investigating this further.

@0xymoro (Contributor, Author) commented Nov 6, 2023

OK, so after looking at it, it already does the float16 load. So it may just be that ammo takes so much GPU memory that even 320 GB of combined GPU memory is not enough to quantize a 70B... I am trying 480 GB now with 10x A40s, and will go up to 640 GB with 8x H100s if that doesn't work. This does seem a bit excessive for quantization, though.

@0xymoro (Contributor, Author) commented Nov 6, 2023

480 GB does not work either with 10x A40s. @byshiue, what configuration was used to quantize 70B to FP8? Running this quantization script seems very prohibitive as it currently stands; I'm looking into it at the code level now.

@0xymoro (Contributor, Author) commented Nov 6, 2023

I figured it out. This code is wrong (and hugely memory inefficient): the padding argument essentially isn't doing anything useful because the call uses the Transformers API incorrectly (https://huggingface.co/docs/transformers/v4.35.0/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.batch_encode_plus). With padding=True and no truncation=True, max_length is ignored and every sample is padded to the longest sequence in the calibration set.

batch_encoded = tokenizer.batch_encode_plus(dataset,
                                            return_tensors="pt",
                                            padding=True,
                                            max_length=block_size)

This works and pads/truncates to the length specified:

batch_encoded = tokenizer.batch_encode_plus(dataset,
                                            return_tensors="pt",
                                            padding="max_length",
                                            truncation=True,
                                            max_length=block_size)

With this change I'm able to run it normally. On my 5x A100 setup it's using ~160 GB of GPU RAM, which is perfectly reasonable for a 70B model; 2x A100 would barely cut it (might run into OOM), but 3x A100 is definitely enough, or 4x A6000. Going to make a PR for this for review.
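
(A small sketch of the difference, assuming a Llama tokenizer is available locally; the checkpoint path and sample strings are placeholders, not the calibration data used by quantize.py.)

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("/path/to/llama-2-70b-hf")  # placeholder path
tokenizer.pad_token = tokenizer.eos_token  # Llama tokenizers ship without a pad token

dataset = ["a short calibration sample", "a much longer calibration sample " * 1000]
block_size = 512

# padding=True pads to the longest sample in the batch; without truncation=True,
# max_length is ignored, so one long sample inflates every row of the tensor.
loose = tokenizer.batch_encode_plus(dataset, return_tensors="pt",
                                    padding=True, max_length=block_size)

# padding="max_length" plus truncation=True caps every row at block_size tokens.
tight = tokenizer.batch_encode_plus(dataset, return_tensors="pt",
                                    padding="max_length", truncation=True,
                                    max_length=block_size)

print(loose["input_ids"].shape)  # second dimension tracks the longest sample, far above 512
print(tight["input_ids"].shape)  # torch.Size([2, 512])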

@byshiue (Collaborator) commented Dec 12, 2023

Thank you for the feedback. We have fixed it internally and will update GitHub in the near future. Closing this issue. Feel free to ask here if you still have a question or issue, and we will reopen it.
