High Memory Usage in DataLoader Workers Leading to Out-of-Memory (OOM) #196

Closed
inigopm opened this issue Aug 30, 2024 · 11 comments


inigopm commented Aug 30, 2024

I'm experiencing high memory usage in the DataLoader workers when using a custom dataset class for lazy loading large datasets. This leads to Out-of-Memory (OOM) errors during training. I've observed that the MaxRSS (maximum resident set size) steadily increases during training, indicating potential memory leaks or improper memory management in the DataLoader or dataset preprocessing.

Error Message Example:
RuntimeError: DataLoader worker (pid XXXX) is killed by signal: Killed

Setup: Distributed training with 3 nodes, 4 GPUs per node
Memory: 512 GB RAM

Training Configuration
Here are the relevant training configurations used:

#!/bin/bash
source use_env.sh

NNODES=$SLURM_NNODES
WORLD_SIZE=$((NNODES * NUM_GPUS))
NODE_RANK=$SLURM_NODEID
MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
MASTER_PORT=12802

ACCELERATE_CPU_AFFINITY=1 torchrun --nproc_per_node="${NUM_GPUS}" --nnodes="${NNODES}" --node_rank="${NODE_RANK}" --master_addr="${MASTER_ADDR}" --master_port="${MASTER_PORT}" \
    llava/train/train_mem.py \
    --deepspeed scripts/zero3.json \
    --model_name_or_path ${CKPT_PATH} \
    --version ${PROMPT_VERSION} \
    --data_path ./playground/data/llava_v1_5_mix665k_no-ocr.json \
    --image_folder ./playground/data \
    --pretrain_mm_mlp_adapter="./checkpoints/projectors/${BASE_RUN_NAME}/checkpoint-5/mm_projector.bin" \
    --mm_tunable_parts="mm_vision_tower,mm_mlp_adapter,mm_language_model" \
    --mm_vision_tower_lr=2e-6 \
    --vision_tower ${VISION_MODEL_VERSION} \
    --mm_projector_type mlp2x_gelu \
    --mm_vision_select_layer -2 \
    --mm_use_im_start_end False \
    --mm_use_im_patch_token False \
    --group_by_modality_length True \
    --image_aspect_ratio anyres \
    --image_grid_pinpoints "[(384, 768), (768, 384), (768, 768), (1152, 384), (384, 1152)]" \
    --mm_patch_merge_type spatial_unpad \
    --bf16 True \
    --run_name $MID_RUN_NAME \
    --output_dir "./checkpoints/${MID_RUN_NAME}" \
    --num_train_epochs 1 \
    --per_device_train_batch_size 1 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 1 \
    --dataloader_num_workers 4 \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 5000 \
    --save_total_limit 1 \
    --learning_rate 1e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --tf32 True \
    --model_max_length 32768 \
    --gradient_checkpointing True \
    --lazy_preprocess True \
    --report_to wandb \
    --torch_compile True \
    --torch_compile_backend "inductor" \
    --dataloader_drop_last True

• Could the lazy preprocessing in the dataset be failing to release memory properly?
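
For reference on the MaxRSS observation above, here is a minimal sketch (not part of the training code; the helper name and call sites are my own) of how peak RSS can be logged per process, so the growth can be attributed to the main process versus individual DataLoader workers:

    import os
    import resource

    def log_maxrss(tag: str = "") -> None:
        # ru_maxrss is reported in kilobytes on Linux
        rss_kb = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
        print(f"[pid {os.getpid()}] {tag} MaxRSS: {rss_kb / 1024:.1f} MiB")

    # Example: call this every N samples from Dataset.__getitem__ (runs inside a
    # DataLoader worker) and every N steps from the training loop (runs in the
    # main process) to see which side is actually growing.
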
@mylesgoose

--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 1 \

Why is your gradient accumulation so low? Also, what bucket sizes does your DeepSpeed config use? For example:

    "overlap_comm": true,
    "contiguous_gradients": true,
    "sub_group_size": 1e9,
    "reduce_bucket_size": "auto",
    "stage3_prefetch_bucket_size": "15099494",
    "stage3_param_persistence_threshold": "auto",
    "stage3_max_live_parameters": 1e9,
    "stage3_max_reuse_distance": 1e9,
    "stage3_gather_16bit_weights_on_model_save": true
},

What is stage3_prefetch_bucket_size set to in yours? Try reducing it.

@leexinhao

Same problem here!


Luodian (Contributor) commented Sep 10, 2024

[screenshot of wandb logs]

I checked our wandb logs (randomly selecting runs longer than 300 minutes), and there is indeed a sign of leakage (mainly due to decord).

But we didn't run into this ourselves. One reason is that our instances have 1.8 TB of memory, so it's hard to hit that limit. Another reason may be a coincidence: our infra is unstable and runs are repeatedly stopped by NCCL timeouts, so we have to use shorter checkpointing intervals, and each time a run restarts, memory usage drops back to a normal level.


inigopm commented Sep 10, 2024

Thank you for your response and for checking the logs. I've been investigating further, and I was able to mitigate the issue somewhat by reducing the number of workers in the DataLoader. Although the MaxRSS still increases over time, it's now stable enough that the process doesn't hit OOM before 5000 steps, which is manageable for my experiment.

inigopm closed this as completed Sep 10, 2024

ftgreat commented Oct 13, 2024


@Luodian: “I checked our wandb logs (randomly selecting runs longer than 300 minutes), and there is indeed a sign of leakage (mainly due to decord).”

What does “decord” refer to? Thanks.

@guanyanchu


I found this problem with mid_stage training as well, and did not use decord.


slyforce commented Nov 7, 2024

I don't have time to provide an MR to fix this, but I found the issue on this line:

https://github.com/LLaVA-VL/LLaVA-NeXT/blob/main/llava/train/train.py#L566

Don't deepcopy the tokeniser and instead just add the token on the fly / pass an appropriate tokeniser object that you want to modify.

Good luck!
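
To make that concrete, here is a minimal sketch of the pattern being pointed at, assuming the preprocessing routine only deep-copies the tokenizer in order to register an extra special token; the function names and the "<image>" token are illustrative, not the repository's actual code:

    import copy
    from transformers import AutoTokenizer

    # Leaky pattern: deep-copying the tokenizer on every call allocates a large
    # object inside each DataLoader worker again and again.
    def preprocess_leaky(texts, tokenizer):
        tok = copy.deepcopy(tokenizer)                   # fresh copy per call
        tok.add_tokens(["<image>"], special_tokens=True)
        return tok(texts).input_ids

    # Suggested pattern: mutate (or pre-configure) the tokenizer that is passed
    # in and reuse it, so no per-call copies pile up in the workers.
    def preprocess_reuse(texts, tokenizer):
        if "<image>" not in tokenizer.get_vocab():
            tokenizer.add_tokens(["<image>"], special_tokens=True)
        return tokenizer(texts).input_ids

    # Hypothetical usage:
    # tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B-Instruct")
    # ids = preprocess_reuse(["hello world"], tokenizer)

The second form is what the comment suggests: pass in a tokenizer that is already set up the way the preprocessing needs, instead of cloning it on every call.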

@CuriousCat-7


Are you sure? It seems that the deepcopy is OK there; the copied tokenizer should be released once execution leaves the scope of preprocess_qwen.

@slyforce


It was definitely a source of memory leakage for me. Give it a try; sadly, I'm still unable to prepare the MR :(

@CuriousCat-7


We added:

    del tokenizer

and the issue is solved. I hardly know why, but it works. Ahhh, I guess I have to say "amazing".
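
For readers following along, a sketch of where that workaround sits, assuming the deep-copy pattern discussed above (the function body is illustrative, not the actual preprocess_qwen):

    import copy

    def preprocess_qwen_sketch(texts, tokenizer):
        tokenizer = copy.deepcopy(tokenizer)              # local copy suspected of leaking
        tokenizer.add_tokens(["<image>"], special_tokens=True)
        data_dict = {"input_ids": tokenizer(texts).input_ids}
        del tokenizer                                     # the reported workaround: drop the copy explicitly
        return data_dict
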

@goodstudent9


It seems that del tokenizer doesn't work for me; the memory usage still grows higher and higher.
Do you know why?
I added the del tokenizer at the end of preprocess_qwen(). At first the memory usage is 40 GB, but after 7 hours it reaches 60 GB, so every new epoch adds about 10 GB of memory.
Why? That is so weird!
