Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

deepspeed resume from ckpt fixes and adding support for deepspeed optimizer and HF scheduler #25863

Merged
merged 13 commits into from
Sep 5, 2023

Conversation

pacman100
Copy link
Contributor

@pacman100 pacman100 commented Aug 30, 2023

What does this PR do?

  1. Add support for deep speed optimizer and HF scheduler. Should be merged after Add support for deepspeed optimizer and custom scheduler accelerate#1909
  2. Fixing the lr scheduler not being saved by passing them to the DeepSpeed engine for all schedulers that are instances of LRScheduler. Should be merged after Add support for deepspeed optimizer and custom scheduler accelerate#1909

Below we will run the 4 combinations of optimizer and schedulers for the run_glue.py transformers example
Initial setup:

cd transformers
export CUDA_VISISBLE_DEVICES=0,1
export TASK_NAME=mrpc

a. HF Optimizer + HF Scheduler Case:

i. ds config ds_config_z3_hf_optim_hf_scheduler.json:

{
  "fp16": {
    "enabled": "auto",
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "initial_scale_power": 16,
    "hysteresis": 2,
    "min_loss_scale": 1
  },

  "bf16": {
    "enabled": "auto"
  },
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    },
    "offload_param": {
      "device": "cpu",
      "pin_memory": true
    },
    "overlap_comm": true,
    "contiguous_gradients": true,
    "sub_group_size": 1e9,
    "reduce_bucket_size": "auto",
    "stage3_prefetch_bucket_size": "auto",
    "stage3_param_persistence_threshold": "auto",
    "stage3_max_live_parameters": 1e9,
    "stage3_max_reuse_distance": 1e9,
    "stage3_gather_16bit_weights_on_model_save": true
  },

  "gradient_accumulation_steps": "auto",
  "gradient_clipping": "auto",
  "steps_per_print": 2000,
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "wall_clock_breakdown": false
}

ii. command to run:

torchrun --nnodes 1 --nproc-per-node 2 ./examples/pytorch/text-classification/run_glue.py --model_name_or_path bert-base-cased  --task_name $TASK_NAME  --do_train  --do_eval  --max_seq_length 128  --per_device_train_batch_size 16  --learning_rate 5e-5  --num_train_epochs 3  --output_dir /tmp/$TASK_NAME/ --overwrite_output_dir --deepspeed ~/transformers/tests/deepspeed/ds_config_z3_hf_optim_hf_scheduler.json --lr_scheduler_type cosine --save_strategy "epoch" --evaluation_strategy "epoch" --logging_steps 1

Kill the process after epoch 1. run the above command with --resume_from_checkpoint as below:

torchrun --nnodes 1 --nproc-per-node 2 ./examples/pytorch/text-classification/run_glue.py --model_name_or_path bert-base-cased  --task_name $TASK_NAME  --do_train  --do_eval  --max_seq_length 128  --per_device_train_batch_size 16  --learning_rate 5e-5  --num_train_epochs 3  --output_dir /tmp/$TASK_NAME/ --overwrite_output_dir --deepspeed ~/transformers/tests/deepspeed/ds_config_z3_hf_optim_hf_scheduler.json --lr_scheduler_type cosine --save_strategy "epoch" --evaluation_strategy "epoch" --logging_steps 1 --resume_from_checkpoint /tmp/$TASK_NAME/checkpoint-115/

iii. Plots of loss and learning rate:
Screenshot 2023-09-02 at 2 23 02 AM

a. DS Optimizer + DS Scheduler Case:
i. ds config ds_config_z3_ds_optim_ds_scheduler.json:

{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },

    "bf16": {
        "enabled": "auto"
    },

    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "betas": "auto",
            "eps": "auto",
            "weight_decay": "auto"
        }
    },

    "scheduler": {
        "type": "WarmupDecayLR",
        "params": {
	    "total_num_steps": "auto", 
            "warmup_min_lr": "auto",
            "warmup_max_lr": "auto",
            "warmup_num_steps": "auto"
        }
    },

    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "offload_param": {
            "device": "cpu",
            "pin_memory": true
        },
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": true
    },

    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false
}

rest of the steps as above. Plots:
Screenshot 2023-09-02 at 2 24 47 AM

c. HF Optimizer + DS Scheduler Case:
i. ds config ds_config_z3_hf_optim_ds_scheduler.json:

{
  "fp16": {
    "enabled": "auto",
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "initial_scale_power": 16,
    "hysteresis": 2,
    "min_loss_scale": 1
  },

  "bf16": {
    "enabled": "auto"
  },

  "scheduler": {
    "type": "WarmupDecayLR",
    "params": {
	  "total_num_steps": "auto", 
      "warmup_min_lr": "auto",
      "warmup_max_lr": "auto",
      "warmup_num_steps": "auto"
    }
  },

  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    },
    "offload_param": {
      "device": "cpu",
      "pin_memory": true
    },
    "overlap_comm": true,
    "contiguous_gradients": true,
    "sub_group_size": 1e9,
    "reduce_bucket_size": "auto",
    "stage3_prefetch_bucket_size": "auto",
    "stage3_param_persistence_threshold": "auto",
    "stage3_max_live_parameters": 1e9,
    "stage3_max_reuse_distance": 1e9,
    "stage3_gather_16bit_weights_on_model_save": true
  },

  "gradient_accumulation_steps": "auto",
  "gradient_clipping": "auto",
  "steps_per_print": 2000,
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "wall_clock_breakdown": false
}

rest of the steps as above. Plots:
Screenshot 2023-09-02 at 2 27 02 AM

c. DS Optimizer + HF Scheduler Case:
i. ds config ds_config_z3_ds_optim_hf_scheduler.json:

{
    "fp16": {
        "enabled": "auto",
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },

    "bf16": {
        "enabled": "auto"
    },

    "optimizer": {
        "type": "AdamW",
        "params": {
            "lr": "auto",
            "betas": "auto",
            "eps": "auto",
            "weight_decay": "auto"
        }
    },
    "zero_optimization": {
        "stage": 3,
        "offload_optimizer": {
            "device": "cpu",
            "pin_memory": true
        },
        "offload_param": {
            "device": "cpu",
            "pin_memory": true
        },
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": true
    },

    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false
}

rest of the steps as above. Plots:
Screenshot 2023-09-02 at 2 30 10 AM

  1. Adding tests to check the resume from ckpt is working properly with DeepSpeed.

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Aug 30, 2023

The documentation is not available anymore as the PR was closed or merged.

@pacman100 pacman100 mentioned this pull request Aug 31, 2023
4 tasks
@pacman100 pacman100 changed the title Add support for deepspeed optimizer and HF scheduler deepspeed resume from ckpt fixes and adding support for deepspeed optimizer and HF scheduler Aug 31, 2023
@pacman100 pacman100 marked this pull request as ready for review September 4, 2023 14:42
Copy link
Member

@BenjaminBossan BenjaminBossan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks a lot for tackling the issue you described and conducting (+ showing) the experiments you ran to prove that it works. Personally, I miss the experience with deepspeed required to understand the bigger picture, so I cannot provide a full on review, only some small comments.

tests/deepspeed/test_deepspeed.py Outdated Show resolved Hide resolved
src/transformers/integrations/deepspeed.py Show resolved Hide resolved
Copy link
Contributor

@muellerzr muellerzr left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me, thanks! Let's definitely keep an eye out for pickle problems, and be prepared to move that to a util if needed

@pacman100 pacman100 merged commit 6bc517c into main Sep 5, 2023
@pacman100 pacman100 deleted the smangrul/ds-fixes-and-tests branch September 5, 2023 17:01
parambharat pushed a commit to parambharat/transformers that referenced this pull request Sep 26, 2023
…imizer and HF scheduler (huggingface#25863)

* Add support for deepspeed optimizer and HF scheduler

* fix bug

* fix the import

* fix issue with deepspeed scheduler saving for hf optim + hf scheduler scenario

* fix loading of hf scheduler when loading deepspeed checkpoint

* fix import of `DeepSpeedSchedulerWrapper`

* add tests

* add the comment and skip the failing tests

* address comment
LysandreJik pushed a commit that referenced this pull request Sep 26, 2023
…imizer and HF scheduler (#25863)

* Add support for deepspeed optimizer and HF scheduler

* fix bug

* fix the import

* fix issue with deepspeed scheduler saving for hf optim + hf scheduler scenario

* fix loading of hf scheduler when loading deepspeed checkpoint

* fix import of `DeepSpeedSchedulerWrapper`

* add tests

* add the comment and skip the failing tests

* address comment
blbadger pushed a commit to blbadger/transformers that referenced this pull request Nov 8, 2023
…imizer and HF scheduler (huggingface#25863)

* Add support for deepspeed optimizer and HF scheduler

* fix bug

* fix the import

* fix issue with deepspeed scheduler saving for hf optim + hf scheduler scenario

* fix loading of hf scheduler when loading deepspeed checkpoint

* fix import of `DeepSpeedSchedulerWrapper`

* add tests

* add the comment and skip the failing tests

* address comment
EduardoPach pushed a commit to EduardoPach/transformers that referenced this pull request Nov 18, 2023
…imizer and HF scheduler (huggingface#25863)

* Add support for deepspeed optimizer and HF scheduler

* fix bug

* fix the import

* fix issue with deepspeed scheduler saving for hf optim + hf scheduler scenario

* fix loading of hf scheduler when loading deepspeed checkpoint

* fix import of `DeepSpeedSchedulerWrapper`

* add tests

* add the comment and skip the failing tests

* address comment
@sunyclj
Copy link

sunyclj commented Dec 21, 2023

DS Optimizer + HF Scheduler
Use the configuration in DS Optimizer + HF Scheduler ,
error : ValueError: You cannot create a DummyScheduler without specifying a scheduler in the config file.
It may have something to do with the version of accelerate. my version of accelerate is 0.22.0,what's your version, please.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants