Handle MoE models with DeepSpeed #2662

Merged: pacman100 merged 7 commits into huggingface:main on Apr 16, 2024

Conversation

pacman100 (Contributor) commented on Apr 12, 2024:

What does this PR do?

  1. Changes to support MoE models with DeepSpeed. For DeepSpeed ZeRO-3 to work properly without hanging, the MoE blocks need to be treated as ZeRO-3 leaf modules, following Delay reduce-scatter for ZeRO3 leaf modules (microsoft/DeepSpeed#5008). This keeps DeepSpeed from splitting the MoE layers further into submodules; those submodules can be called in different orders depending on the inputs, which leads to invalid traces as well as hangs during reduce_scatter calls when an expert is unused on one rank but used on another.
  2. This PR enables this via a new config argument, deepspeed_moe_layer_cls_names. An example config is given below (a minimal sketch of the underlying DeepSpeed call follows the output logs at the end of this description):
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  deepspeed_multinode_launcher: standard
  deepspeed_moe_layer_cls_names: Qwen2MoeSparseMoeBlock
  offload_optimizer_device: none
  offload_param_device: none
  zero3_init_flag: true
  zero3_save_16bit_model: true
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
  3. An end-to-end example is as follows:
    a. Config as given above, which can be found at https://github.com/pacman100/DHS-LLM-Workshop/blob/main/chat_assistant/sft/training/configs/deepspeed_config_moe.yaml
    b. MoE training command (https://github.com/pacman100/DHS-LLM-Workshop/blob/main/chat_assistant/sft/training/run_deepspeed_moe.sh):
accelerate launch --config_file "configs/deepspeed_config_moe.yaml" train.py \
--seed 100 \
--model_name_or_path "Qwen/Qwen1.5-MoE-A2.7B" \
--dataset_name "smangrul/ultrachat-10k-chatml" \
--chat_template_format "chatml" \
--add_special_tokens False \
--append_concat_token False \
--splits "train,test" \
--max_seq_len 2048 \
--num_train_epochs 1 \
--logging_steps 5 \
--log_level "info" \
--logging_strategy "steps" \
--evaluation_strategy "epoch" \
--save_strategy "epoch" \
--push_to_hub \
--hub_private_repo True \
--hub_strategy "every_save" \
--bf16 True \
--packing True \
--learning_rate 1e-4 \
--lr_scheduler_type "cosine" \
--weight_decay 1e-4 \
--warmup_ratio 0.0 \
--max_grad_norm 1.0 \
--output_dir "qwen-moe-sft-qlora-ds" \
--per_device_train_batch_size 2 \
--per_device_eval_batch_size 2 \
--gradient_accumulation_steps 2 \
--gradient_checkpointing True \
--use_reentrant True \
--dataset_text_field "content" \
--use_flash_attn True

    c. Output logs:

***** Running training *****
  Num examples = 5,742
  Num Epochs = 1
  Instantaneous batch size per device = 2
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 2
  Total optimization steps = 359
  Number of trainable parameters = 14,314,637,312
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
wandb: Currently logged in as: smangrul. Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.16.6
wandb: Run data is saved locally in /raid/sourab/DHS-LLM-Workshop/chat_assistant/sft/training/wandb/run-20240412_141844-4nz3s3bp
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run desert-violet-549
wandb: ⭐️ View project at https://wandb.ai/smangrul/huggingface
wandb: 🚀 View run at https://wandb.ai/smangrul/huggingface/runs/4nz3s3bp
  0%|▎                                                                                                                                     | 1/359 [00:38<3:47:30, 38.13s/it]
  1%|▋                                                                                                                                     | 2/359 [00:50<2:16:18, 22.91s/it]
{'loss': 1.5661, 'grad_norm': 4.012563807969557, 'learning_rate': 9.995214563286675e-05, 'epoch': 0.01}                                                                      
  1%|█▊                                                                                                                                    | 5/359 [02:33<3:11:31, 32.46s/it]
  2%|██▏                                                                                                                                   | 6/359 [03:07<3:14:00, 32.98s/it]
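
For reference, the deepspeed_moe_layer_cls_names option maps onto DeepSpeed's leaf-module API from microsoft/DeepSpeed#5008. The snippet below is a minimal sketch of that underlying call, not the exact wiring inside Accelerate; the model instantiation is illustrative only (structure-only, randomly initialized weights).

from deepspeed.utils import set_z3_leaf_modules
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock

# Build only the model structure (random weights) so the MoE blocks can be marked.
config = AutoConfig.from_pretrained("Qwen/Qwen1.5-MoE-A2.7B")
model = AutoModelForCausalLM.from_config(config)

# Mark the whole sparse-MoE block as a single ZeRO-3 "leaf": DeepSpeed then stops
# hooking its child experts individually, which avoids invalid traces and
# reduce_scatter hangs when an expert is used on one rank but not on another.
set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])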

HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

BenjaminBossan (Member) left a comment:

Thanks, looks good from my point of view. Only some nits.

Review threads (all resolved):
  src/accelerate/commands/config/cluster.py (outdated)
  src/accelerate/utils/dataclasses.py (outdated)
  src/accelerate/utils/dataclasses.py
muellerzr (Collaborator) left a comment:

Thanks! Overall LGTM, bar the documentation nit of at least mentioning MoE in full somewhere :)

pacman100 and others added 3 commits April 16, 2024 15:32
Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
pacman100 merged commit 701e24c into huggingface:main on Apr 16, 2024
23 checks passed
fabianlim (Contributor) commented:

@pacman100, just curious: I noticed that FullyShardedDataParallelPlugin.get_module_class_from_name is removed in this PR, but it is used in other places, such as in the peft package.

pacman100 (Contributor, Author) commented:

@fabianlim,

Thank you for bringing this to our attention, nice catch! PR huggingface/peft#1694 fixes this by using the standalone function instead of the removed static method.
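
For anyone hitting this, a minimal sketch of the replacement pattern is below. The import path of the standalone helper is an assumption here and may differ from what the peft fix actually uses; the model is instantiated structure-only just to have something to look the class up on.

# Old (removed): FullyShardedDataParallelPlugin.get_module_class_from_name(model, name)
from accelerate.utils.dataclasses import get_module_class_from_name  # assumed import path
from transformers import AutoConfig, AutoModelForCausalLM

# Structure-only model with random weights, purely for the lookup.
model = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained("Qwen/Qwen1.5-MoE-A2.7B"))

# Look up a submodule class by its name, e.g. the MoE block class used in this PR.
moe_block_cls = get_module_class_from_name(model, "Qwen2MoeSparseMoeBlock")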
