OOM error running the demo script sft_llama2.py on A100 GPU #824
Comments
Hi @Emerald01 !
Yes, with that flag on plus the setting I made, the issue seems to be resolved. Now I run into another problem when I try to run the same script with DeepSpeed ZeRO-3. It reports OOM again:
I believe DeepSpeed ZeRO-3 tries to partition a single model across all available devices, so obviously this does not work as expected here. I think trl claims that these scripts should support DeepSpeed without any extra code, so I guess it is some configuration issue again... Actually, I am a little confused here. The original sft_llama2.py uses 4-bit quantization to load the entire model onto each GPU (which seems pretty efficient: each GPU only used about 10G of memory). When DeepSpeed takes over, what should be expected? Is it 4-bit quantization + ZeRO-3, or just ZeRO-3? My final question: does trl support torch.distributed.FSDP? To my mind, torch provides a much cleaner ZeRO-3 solution. Anyway, more detailed documentation or blog posts discussing these things would be very helpful.
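For anyone comparing the two setups discussed above, here is a minimal sketch (not from the script; the model name is a placeholder and exact behavior depends on your transformers/accelerate/bitsandbytes versions) of how loading typically differs between the per-GPU 4-bit pattern used by sft_llama2.py and a DeepSpeed ZeRO-3 run, where the launcher's config is usually expected to handle parameter sharding and a fixed device_map is normally not passed:

```python
# Sketch only: illustrates the two loading patterns discussed above.
import torch
from accelerate import Accelerator
from transformers import AutoModelForCausalLM

model_name = "meta-llama/Llama-2-7b-hf"  # placeholder

# Pattern A (what sft_llama2.py does): every process loads a full 4-bit
# copy of the model and pins it to its own GPU via device_map.
model_per_gpu = AutoModelForCausalLM.from_pretrained(
    model_name,
    load_in_4bit=True,
    device_map={"": Accelerator().local_process_index},
    torch_dtype=torch.bfloat16,
)

# Pattern B (DeepSpeed ZeRO-3): the model is typically loaded without a
# device_map, and the DeepSpeed config passed to `accelerate launch`
# decides how the (unquantized) parameters are partitioned across GPUs.
model_zero3 = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
)
```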
It looks like that will be set automatically. I still have the issue with the following config:

compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  gradient_accumulation_steps: 1
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: false
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
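A small sketch (an assumption, not from this thread) that can help when debugging configs like the one above: print what Accelerate actually picked up inside the training script, so you can confirm the DeepSpeed plugin and ZeRO stage in effect.

```python
# Sketch only: confirm at runtime which distributed setup Accelerate is using.
from accelerate import Accelerator

accelerator = Accelerator()
print(accelerator.distributed_type)        # expect DistributedType.DEEPSPEED
print(accelerator.state.deepspeed_plugin)  # shows zero_stage, offload settings, etc.
```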
It seems something happened here, making the GPU memory explode?

if not is_gptq_quantized:
    # cast all non INT8 parameters to fp32
    for param in model.parameters():
        if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16):
            param.data = param.data.to(torch.float32)
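If you want to check whether this fp32 upcast is what eats the memory, here is a rough sketch (an assumption, not code from the thread) that measures allocated GPU memory around peft's prepare_model_for_kbit_training, which, in the versions I have seen, contains the cast quoted above:

```python
# Sketch only: measure how much allocated GPU memory grows across the
# k-bit preparation step (which includes the fp32 upcast quoted above).
import torch
from peft import prepare_model_for_kbit_training


def report_upcast_cost(model):
    """Return the prepared model and print allocated GPU memory before/after."""
    before = torch.cuda.memory_allocated() / 1024**3
    model = prepare_model_for_kbit_training(model)
    after = torch.cuda.memory_allocated() / 1024**3
    print(f"allocated GPU memory: {before:.2f} GiB -> {after:.2f} GiB")
    return model
```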
My model initialization is as follows:

model = AutoModelForCausalLM.from_pretrained(
    script_args.model_name_or_path,
    low_cpu_mem_usage=True,
    quantization_config=bnb_config,
    device_map={"": Accelerator().local_process_index},
    torch_dtype=torch.bfloat16,
    load_in_4bit=True,
)

If I remove this ...
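For context, the `bnb_config` referenced above is usually a `BitsAndBytesConfig`; a sketch of a typical 4-bit setup follows (values are illustrative, so check the script you are actually running; `load_in_4bit` is typically specified in only one place, either in the config or as a `from_pretrained` kwarg):

```python
# Sketch only: a typical 4-bit BitsAndBytesConfig; values are illustrative.
import torch
from transformers import BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # quantize weights to 4 bits at load time
    bnb_4bit_quant_type="nf4",              # NormalFloat4 quantization
    bnb_4bit_compute_dtype=torch.bfloat16,  # run compute in bf16
)
```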
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Hello:

I am testing a run of trl/examples/research_projects/stack_llama_2/scripts/sft_llama2.py. I pulled the latest main branch and pip installed it locally.

Running env: GCP with eight A100 GPUs (40G memory each).

I just followed the README without any change:

accelerate launch sft_llama2.py --output_dir=XXX

The accelerate configuration is the example configuration with 8 GPUs, but I got the OOM error.

I checked the memory usage. After the model load finishes, i.e., base_model = AutoModelForCausalLM.from_pretrained(...), the GPU memory usage is only about 4707MiB / 40960MiB, so the script's use of load_in_4bit to load the model is pretty effective. Right before calling trainer.train(), the memory usage is still pretty reasonable, about 8071MiB / 40960MiB. But when it starts to execute trainer.train(), it quickly blows up the memory.

I am wondering if there is any obvious problem right here? Any configuration issue or some bug that eats up the memory? Since I just followed the demo code without a single line of change, I hope someone could answer this question so I can gain confidence in this codebase and move on. Thank you so much for your help!
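Not a confirmed fix from this thread, but for anyone hitting the same blow-up at trainer.train(), these are the usual memory knobs to try first; the values below are only a sketch:

```python
# Sketch only: common settings that reduce per-GPU memory during training.
# Values are illustrative, not the ones used in the original script.
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./sft_llama2_out",   # placeholder path
    per_device_train_batch_size=1,   # smaller micro-batches
    gradient_accumulation_steps=16,  # keep the effective batch size up
    gradient_checkpointing=True,     # trade extra compute for activation memory
    bf16=True,                       # bf16 mixed precision on A100
    optim="paged_adamw_32bit",       # paged optimizer states via bitsandbytes
    logging_steps=10,
)
```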