
OOM when finetuning Llama3.2-90B on 8xA100 80GB #2294

Open · 2 of 4 tasks
Labels: 🐛 bug Something isn't working · 🏋 SFT Related to SFT · 👁️ VLM Related to Visual Language Models

maximilianmordig commented Oct 29, 2024

System Info

trl, transformers: most recent versions from GitHub
Python 3.10.11
Ubuntu 22

package versions:

accelerate==1.0.1
addict==2.4.0
aiohappyeyeballs==2.4.3
aiohttp==3.10.10
aiosignal==1.3.1
anyio==4.6.2.post1
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==2.4.1
async-lru==2.0.4
async-timeout==4.0.3
attrs==24.2.0
babel==2.16.0
beautifulsoup4==4.12.3
bitsandbytes==0.44.1
black==24.10.0
bleach==6.1.0
certifi==2024.8.30
cffi==1.17.1
cfgv==3.4.0
charset-normalizer==3.4.0
click==8.1.7
cloudpickle==3.1.0
comm==0.2.2
contourpy==1.3.0
cycler==0.12.1
dask==2024.10.0
datasets==3.0.2
debugpy==1.8.7
decorator==5.1.1
defusedxml==0.7.1
dill==0.3.8
distlib==0.3.9
docker-pycreds==0.4.0
docstring_parser==0.16
et_xmlfile==2.0.0
evaluate==0.4.3
exceptiongroup==1.2.2
executing==2.1.0
fastjsonschema==2.20.0
filelock==3.16.1
fonttools==4.54.1
fqdn==1.5.1
frozenlist==1.5.0
fsspec==2024.9.0
gitdb==4.0.11
GitPython==3.1.43
h11==0.14.0
httpcore==1.0.6
httpx==0.27.2
huggingface-hub==0.26.2
identify==2.6.1
idna==3.10
importlib_metadata==8.5.0
iniconfig==2.0.0
ipykernel==6.29.5
ipython==8.29.0
ipywidgets==8.1.5
isoduration==20.11.0
isort==5.13.2
jedi==0.19.1
Jinja2==3.1.4
joblib==1.4.2
json5==0.9.25
jsonpointer==3.0.0
jsonschema==4.23.0
jsonschema-specifications==2024.10.1
jupyter==1.1.1
jupyter-console==6.6.3
jupyter-events==0.10.0
jupyter-lsp==2.2.5
jupyter_client==8.6.3
jupyter_core==5.7.2
jupyter_server==2.14.2
jupyter_server_terminals==0.5.3
jupyterlab==4.2.5
jupyterlab_pygments==0.3.0
jupyterlab_server==2.27.3
jupyterlab_widgets==3.0.13
kiwisolver==1.4.7
lightgbm==4.5.0
locket==1.0.0
markdown-it-py==3.0.0
MarkupSafe==3.0.2
matplotlib==3.9.2
matplotlib-inline==0.1.7
mdurl==0.1.2
mistune==3.0.2
more-itertools==10.5.0
mpmath==1.3.0
multidict==6.1.0
multiprocess==0.70.16
mypy-extensions==1.0.0
nbclient==0.10.0
nbconvert==7.16.4
nbformat==5.10.4
nest-asyncio==1.6.0
networkx==3.4.2
nodeenv==1.9.1
notebook==7.2.2
notebook_shim==0.2.4
numpy==2.1.2
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.21.5
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.1.105
openpyxl==3.1.5
overrides==7.7.0
packaging==24.1
pandas==2.2.3
pandocfilters==1.5.1
parso==0.8.4
partd==1.4.2
pathspec==0.12.1
peft==0.13.2
pexpect==4.9.0
pillow==11.0.0
platformdirs==4.3.6
pluggy==1.5.0
pre_commit==4.0.1
prometheus_client==0.21.0
prompt_toolkit==3.0.48
propcache==0.2.0
protobuf==5.28.3
psutil==6.1.0
ptyprocess==0.7.0
pure_eval==0.2.3
pyarrow==18.0.0
pycparser==2.22
Pygments==2.18.0
pyparsing==3.2.0
pytest==8.3.3
python-dateutil==2.9.0.post0
python-json-logger==2.0.7
pytz==2024.2
PyYAML==6.0.2
pyzmq==26.2.0
referencing==0.35.1
regex==2024.9.11
requests==2.32.3
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==13.9.3
rpds-py==0.20.0
safetensors==0.4.5
scikit-learn==1.5.2
scipy==1.14.1
seaborn==0.13.2
Send2Trash==1.8.3
sentry-sdk==2.17.0
setproctitle==1.3.3
shtab==1.7.1
six==1.16.0
smmap==5.0.1
sniffio==1.3.1
soupsieve==2.6
stack-data==0.6.3
sympy==1.13.1
terminado==0.18.1
threadpoolctl==3.5.0
timm==1.0.11
tinycss2==1.4.0
tokenizers==0.20.1
tomli==2.0.2
toolz==1.0.0
torch==2.5.1+cu121
torchaudio==2.5.1+cu121
torchvision==0.20.1+cu121
tornado==6.4.1
tqdm==4.66.6
traitlets==5.14.3
transformers==4.46.0
triton==3.1.0
trl @ git+https://github.com/huggingface/trl@b2696578ce6db1749a250661b507bf8b90e14dd5
types-python-dateutil==2.9.0.20241003
typing_extensions==4.12.2
tyro==0.8.14
tzdata==2024.2
uri-template==1.3.0
urllib3==2.2.3
virtualenv==20.27.1
wandb==0.18.5
wcwidth==0.2.13
webcolors==24.8.0
webencodings==0.5.1
websocket-client==1.8.0
wget==3.2
widgetsnbextension==4.0.13
xxhash==3.5.0
yarl==1.17.0
zipp==3.20.2

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

Using FSDP with accelerate launch:

accelerate launch --num_processes 8 \
    --config_file "$FSDP_CONFIG_FILE" \
    examples/scripts/sft_vlm.py \
    --model_name_or_path "meta-llama/Llama-3.2-90B-Vision-Instruct" \
    --dataset_name HuggingFaceH4/llava-instruct-mix-vsft \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 8 \
    --output_dir test_trl_llama32 \
    --bf16 \
    --torch_dtype bfloat16 \
    --gradient_checkpointing

FSDP config file:

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
# dynamo_config:
#   dynamo_backend: INDUCTOR
#enable_cpu_affinity: false
fsdp_config:
  #fsdp_activation_checkpointing: false
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  #fsdp_offload_params: true
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  #fsdp_use_orig_params: true
  fsdp_use_orig_params: false
machine_rank: 0
main_process_ip: 1.1.1.1
main_process_port: 29500
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 0
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
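
A rough back-of-the-envelope estimate of the sharded per-GPU footprint (my own arithmetic, not a measurement; the parameter count and optimizer-state dtype are assumptions):

# Per-GPU memory under FSDP FULL_SHARD across 8 GPUs, before activations.
params = 90e9                  # ~90B parameters (approximate)
n_gpus = 8
weights = params * 2           # bf16 weights, bytes (sharded)
grads = params * 2             # bf16 gradients, bytes (sharded)
adam = params * 4 * 2          # exp_avg + exp_avg_sq in fp32, bytes
per_gpu_gib = (weights + grads + adam) / n_gpus / 2**30
print(f"~{per_gpu_gib:.0f} GiB per GPU")   # ~126 GiB
# Even if AdamW keeps its states in bf16 (they follow the param dtype),
# this drops to ~84 GiB per GPU -- above an A100's 80 GB either way,
# before counting activations and the vision tower.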

Expected behavior

Training should not run out of memory. It also OOMs with LoRA enabled (see the sketch of the LoRA invocation below).
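
The LoRA variant used the same launch with PEFT flags along these lines (a sketch, not the exact command from the report; --use_peft and the lora_* options follow trl's ModelConfig and exact names/values may differ by version):

accelerate launch --num_processes 8 \
    --config_file "$FSDP_CONFIG_FILE" \
    examples/scripts/sft_vlm.py \
    --model_name_or_path "meta-llama/Llama-3.2-90B-Vision-Instruct" \
    --dataset_name HuggingFaceH4/llava-instruct-mix-vsft \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 8 \
    --output_dir test_trl_llama32_lora \
    --bf16 \
    --torch_dtype bfloat16 \
    --gradient_checkpointing \
    --use_peft \
    --lora_r 16 \
    --lora_alpha 32 \
    --lora_target_modules q_proj k_proj v_proj o_proj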

tanaybaswa commented

I have a similar issue trying to fine-tune a 12B model on 8x H100s.

qgallouedec added the 🐛 bug, 🏋 SFT, and 👁️ VLM labels on Dec 13, 2024