
Out of memory with AOTI using Llama 3.1 8B on RTX 4090 #1302

Closed
byjlw opened this issue Oct 15, 2024 · 4 comments · Fixed by #1337
Labels: Compile / AOTI (Issues related to AOT Inductor and torch compile)

byjlw (Contributor) commented Oct 15, 2024

🐛 Describe the bug

I'm getting an out-of-memory error when trying to use an exported .so file. The .so is 16 GB and my GPU has 24 GB of memory.

(.venv) warden@Vikander:~/source/torchchat$ python3 torchchat.py generate llama3.1 --dso-path exportedModels/llama3.1.so --prompt "Hello my name is" --device cuda
Warning: checkpoint path ignored because an exported DSO or PTE path specified
Warning: checkpoint path ignored because an exported DSO or PTE path specified
Using device=cuda NVIDIA GeForce RTX 4090
Loading model...
Time to load model: 1.85 seconds
Error: CUDA error: out of memory
Traceback (most recent call last):
  File "/home/warden/source/torchchat/torchchat/cli/builder.py", line 536, in _initialize_model
    model.forward = torch._export.aot_load(
                    ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/warden/source/torchchat/.venv/lib/python3.11/site-packages/torch/_export/__init__.py", line 320, in aot_load
    runner = torch._C._aoti.AOTIModelContainerRunnerCuda(so_path, 1, device)  # type: ignore[assignment, call-arg]
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: create_func_( &container_handle_, num_models, device_str.c_str(), cubin_dir.empty() ? nullptr : cubin_dir.c_str()) API call failed at ../torch/csrc/inductor/aoti_runner/model_container_runner.cpp, line 85

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/warden/source/torchchat/torchchat.py", line 88, in <module>
    generate_main(args)
  File "/home/warden/source/torchchat/torchchat/generate.py", line 1215, in main
    gen = Generator(
          ^^^^^^^^^^
  File "/home/warden/source/torchchat/torchchat/generate.py", line 290, in __init__
    self.model = _initialize_model(self.builder_args, self.quantize, self.tokenizer)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/warden/source/torchchat/torchchat/cli/builder.py", line 540, in _initialize_model
    raise RuntimeError(f"Failed to load AOTI compiled {builder_args.dso_path}")
RuntimeError: Failed to load AOTI compiled exportedModels/llama3.1.so
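For context on the failure mode (confirmed by the fix further down): the AOTI DSO embeds its own copy of the model weights, while torchchat's builder also materializes the eager model, so two full copies must coexist at load time. A minimal sketch of that load path; `build_eager_model` is a hypothetical stand-in for the builder logic in torchchat/cli/builder.py, not a real API:

```python
import torch

# Hypothetical stand-in for the eager model construction that torchchat's
# _initialize_model performs before swapping in the AOTI forward.
model = build_eager_model("llama3.1", device="cuda")  # ~16 GB of fp16 weights

# The DSO embeds its own copy of the weights, so aot_load allocates
# roughly another 16 GB; both copies cannot fit on a 24 GB card.
model.forward = torch._export.aot_load("exportedModels/llama3.1.so", "cuda")
```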

Versions

Operating System Information
Linux Vikander 6.8.0-45-generic #45-Ubuntu SMP PREEMPT_DYNAMIC Fri Aug 30 12:02:04 UTC 2024 x86_64 x86_64 x86_64 GNU/Linux

PRETTY_NAME="Ubuntu 24.04.1 LTS"
NAME="Ubuntu"
VERSION_ID="24.04"
VERSION="24.04.1 LTS (Noble Numbat)"
VERSION_CODENAME=noble
ID=ubuntu
ID_LIKE=debian
HOME_URL="https://www.ubuntu.com/"
SUPPORT_URL="https://help.ubuntu.com/"
BUG_REPORT_URL="https://bugs.launchpad.net/ubuntu/"
PRIVACY_POLICY_URL="https://www.ubuntu.com/legal/terms-and-policies/privacy-policy"
UBUNTU_CODENAME=noble
LOGO=ubuntu-logo

Python Version
Python 3.11.10

PIP Version
pip 24.0 from /home/warden/source/torchchat/.venv/lib/python3.11/site-packages/pip (python 3.11)

Installed Packages
absl-py==2.1.0
accelerate==1.0.1
aiohappyeyeballs==2.4.3
aiohttp==3.10.10
aiosignal==1.3.1
altair==5.4.1
annotated-types==0.7.0
antlr4-python3-runtime==4.9.3
anyio==4.6.2.post1
attrs==24.2.0
blinker==1.8.2
blobfile==3.0.0
cachetools==5.5.0
certifi==2024.8.30
chardet==5.2.0
charset-normalizer==3.4.0
click==8.1.7
cmake==3.30.4
colorama==0.4.6
DataProperty==1.0.1
datasets==3.0.1
dill==0.3.8
distro==1.9.0
evaluate==0.4.3
filelock==3.16.1
Flask==3.0.3
frozenlist==1.4.1
fsspec==2024.6.1
gguf==0.10.0
gitdb==4.0.11
GitPython==3.1.43
h11==0.14.0
httpcore==1.0.6
httpx==0.27.2
huggingface-hub==0.25.2
idna==3.10
itsdangerous==2.2.0
Jinja2==3.1.4
jiter==0.6.1
joblib==1.4.2
jsonlines==4.0.0
jsonschema==4.23.0
jsonschema-specifications==2024.10.1
lm_eval==0.4.2
lxml==5.3.0
markdown-it-py==3.0.0
MarkupSafe==3.0.1
mbstrdecoder==1.1.3
mdurl==0.1.2
more-itertools==10.5.0
mpmath==1.3.0
multidict==6.1.0
multiprocess==0.70.16
narwhals==1.9.3
networkx==3.4.1
ninja==1.11.1.1
nltk==3.9.1
numexpr==2.10.1
numpy==1.26.4
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.21.5
nvidia-nvjitlink-cu12==12.6.77
nvidia-nvtx-cu12==12.1.105
omegaconf==2.3.0
openai==1.51.2
packaging==24.1
pandas==2.2.3
pathvalidate==3.2.1
peft==0.13.2
pillow==10.4.0
portalocker==2.10.1
propcache==0.2.0
protobuf==5.28.2
psutil==6.0.0
pyarrow==17.0.0
pybind11==2.13.6
pycryptodomex==3.21.0
pydantic==2.9.2
pydantic_core==2.23.4
pydeck==0.9.1
Pygments==2.18.0
pytablewriter==1.2.0
python-dateutil==2.9.0.post0
pytorch-triton==3.1.0+cf34004b8a
pytz==2024.2
PyYAML==6.0.2
referencing==0.35.1
regex==2024.9.11
requests==2.32.3
rich==13.9.2
rouge-score==0.1.2
rpds-py==0.20.0
sacrebleu==2.4.3
safetensors==0.4.5
scikit-learn==1.5.2
scipy==1.14.1
sentencepiece==0.2.0
six==1.16.0
smmap==5.0.1
snakeviz==2.2.0
sniffio==1.3.1
sqlitedict==2.1.0
streamlit==1.39.0
sympy==1.13.1
tabledata==1.3.3
tabulate==0.9.0
tcolorpy==0.1.6
tenacity==9.0.0
threadpoolctl==3.5.0
tiktoken==0.8.0
tokenizers==0.20.1
toml==0.10.2
torch==2.6.0.dev20241002+cu121
torchao==0.5.0
torchtune==0.3.0.dev20240928+cu121
torchvision==0.20.0.dev20241002+cu121
tornado==6.4.1
tqdm==4.66.5
tqdm-multiprocess==0.0.11
transformers==4.45.2
typepy==1.3.2
typing_extensions==4.12.2
tzdata==2024.2
urllib3==2.2.3
watchdog==5.0.3
Werkzeug==3.0.4
word2number==1.1
xxhash==3.5.0
yarl==1.15.2
zstandard==0.23.0
zstd==1.5.5.1

PyTorch Version
2.6.0.dev20241002+cu121

desertfire (Contributor) commented

Just to double check, does torch.compile work?

byjlw (Contributor, Author) commented Oct 16, 2024

> Just to double check, does torch.compile work?

Yes, --compile works, but it's 3x slower than eager.

Also, if I quantize the model during export for AOTI, I don't OOM.
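The arithmetic is consistent with that observation. A rough back-of-the-envelope sketch (the parameter count and dtype sizes are standard figures, not measurements from this machine):

```python
params = 8.03e9                   # Llama 3.1 8B parameter count (approximate)
gib = 1024 ** 3

fp16 = params * 2 / gib           # ~15 GiB: matches the ~16 GB unquantized DSO
int4 = params * 0.5 / gib         # ~3.7 GiB after 4-bit weight quantization

# Unquantized: eager fp16 copy + fp16 DSO copy ~ 30 GiB -> OOM on 24 GB.
# Quantized:   eager fp16 copy + int4 DSO      ~ 19 GiB -> fits.
print(f"fp16 weights: {fp16:.1f} GiB, int4 weights: {int4:.1f} GiB")
```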

@byjlw byjlw added the Compile / AOTI Issues related to AOT Inductor and torch compile label Oct 23, 2024
desertfire (Contributor) commented

Also reported in #996.

@desertfire desertfire self-assigned this Nov 1, 2024
desertfire added a commit that referenced this issue Nov 1, 2024
Summary: Fixes #1302. Because the AOTI-compiled model contains a copy of the model weights, we need to release the corresponding eager model weights in the Python deployment path.
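The idea behind the fix, as a minimal sketch (this is not the actual #1337 diff, and `swap_in_aoti_forward` is a hypothetical helper): free the eager parameters before the AOTI runner maps its own copy of the weights.

```python
import gc
import torch

def swap_in_aoti_forward(model, so_path, device="cuda"):
    # Once the DSO is loaded, its embedded weights are used, so the eager
    # parameters are redundant; moving them to the meta device frees VRAM.
    model.to("meta")
    gc.collect()
    torch.cuda.empty_cache()
    model.forward = torch._export.aot_load(so_path, device)
    return model
```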
desertfire (Contributor) commented

#1337 should fix the issue. @byjlw , please give it a try and let me know the result.

Jack-Khuu added a commit that referenced this issue Nov 6, 2024
* [AOTI] Remove the original model weights in Python deployment

Summary: Fixes #1302. Because the AOTI-compiled model contains a copy of the model weights, we need to release the corresponding eager model weights in the Python deployment path.

* Revert "[AOTI] Remove the original model weights in Python deployment"

This reverts commit 962ec0d.

* Refactor the code

* Add setup_cache for aoti_package_path

---------

Co-authored-by: Jack-Khuu <jack.khuu.7@gmail.com>