diff --git a/README.md b/README.md
index 1ef05855c..f47beac77 100644
--- a/README.md
+++ b/README.md
@@ -861,7 +861,7 @@ tokens:
 fsdp:
 fsdp_config:
 
-# Deepspeed config path. e.g., deepspeed/zero3.json
+# Deepspeed config path. e.g., deepspeed_configs/zero3.json
 deepspeed:
 
 # Advanced DDP Arguments
@@ -982,11 +982,11 @@ for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usa
 We provide several default deepspeed JSON configurations for ZeRO stage 1, 2, and 3.
 
 ```yaml
-deepspeed: deepspeed/zero1.json
+deepspeed: deepspeed_configs/zero1.json
 ```
 
 ```shell
-accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed/zero1.json
+accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed_configs/zero1.json
 ```
 
 ##### FSDP
diff --git a/deepspeed/zero1.json b/deepspeed_configs/zero1.json
similarity index 100%
rename from deepspeed/zero1.json
rename to deepspeed_configs/zero1.json
diff --git a/deepspeed/zero2.json b/deepspeed_configs/zero2.json
similarity index 100%
rename from deepspeed/zero2.json
rename to deepspeed_configs/zero2.json
diff --git a/deepspeed/zero3.json b/deepspeed_configs/zero3.json
similarity index 100%
rename from deepspeed/zero3.json
rename to deepspeed_configs/zero3.json
diff --git a/deepspeed/zero3_bf16.json b/deepspeed_configs/zero3_bf16.json
similarity index 100%
rename from deepspeed/zero3_bf16.json
rename to deepspeed_configs/zero3_bf16.json
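The renamed JSON files differ mainly in the ZeRO stage they select via `zero_optimization.stage`, with `zero3_bf16.json` being the ZeRO-3 variant that also enables bfloat16. A minimal sketch for checking which stage a given file requests before passing it to `--deepspeed`; the helper below is illustrative and not part of axolotl:

```python
import json
from pathlib import Path


def zero_stage(config_path: str) -> int:
    """Return the ZeRO stage declared in a DeepSpeed JSON config (0 if unset)."""
    cfg = json.loads(Path(config_path).read_text())
    # `zero_optimization.stage` is the standard DeepSpeed key for the ZeRO stage.
    return int(cfg.get("zero_optimization", {}).get("stage", 0))


if __name__ == "__main__":
    for name in ("zero1.json", "zero2.json", "zero3.json", "zero3_bf16.json"):
        path = Path("deepspeed_configs") / name
        if path.exists():
            print(f"{name}: ZeRO stage {zero_stage(str(path))}")
```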
diff --git a/examples/llama-2/fft_optimized.yml b/examples/llama-2/fft_optimized.yml
index e2388ec67..c2639050e 100644
--- a/examples/llama-2/fft_optimized.yml
+++ b/examples/llama-2/fft_optimized.yml
@@ -62,7 +62,7 @@ evals_per_epoch: 4
 eval_table_size:
 saves_per_epoch: 1
 debug:
-deepspeed: #deepspeed/zero2.json # multi-gpu only
+deepspeed: #deepspeed_configs/zero2.json # multi-gpu only
 weight_decay: 0.1
 fsdp:
 fsdp_config:
diff --git a/examples/mistral/Mistral-7b-example/code.ipynb b/examples/mistral/Mistral-7b-example/code.ipynb
index 756988006..7e84d8124 100644
--- a/examples/mistral/Mistral-7b-example/code.ipynb
+++ b/examples/mistral/Mistral-7b-example/code.ipynb
@@ -942,7 +942,7 @@ "not only optimizer states but also gradients and parameters across GPUs.
     "The bf16 indicate mixed precision training using bfloat16.\n",
     "For more information read axolotl's readme\n",
     "\"\"\"\n",
-    "!accelerate launch -m axolotl.cli.train /folder/config.yml --deepspeed deepspeed/zero3_bf16.json"
+    "!accelerate launch -m axolotl.cli.train /folder/config.yml --deepspeed deepspeed_configs/zero3_bf16.json"
    ]
   }
  ],
diff --git a/examples/mistral/Mistral-7b-example/config.yml b/examples/mistral/Mistral-7b-example/config.yml
index 84be18d15..d28d8f6b7 100644
--- a/examples/mistral/Mistral-7b-example/config.yml
+++ b/examples/mistral/Mistral-7b-example/config.yml
@@ -65,7 +65,7 @@ eval_table_max_new_tokens: 128
 saves_per_epoch: 1
 debug:
 #default deepspeed, can use more aggresive if needed like zero2, zero3
-deepspeed: deepspeed/zero1.json
+deepspeed: deepspeed_configs/zero1.json
 weight_decay: 0.0
 fsdp:
 fsdp_config:
diff --git a/examples/mistral/README.md b/examples/mistral/README.md
index d1efb2cab..462c2d3e7 100644
--- a/examples/mistral/README.md
+++ b/examples/mistral/README.md
@@ -8,5 +8,5 @@ accelerate launch -m axolotl.cli.train examples/mistral/config.yml
 If you run into CUDA OOM, use deepspeed with config zero2.json:
 
 ```shell
-accelerate launch -m axolotl.cli.train examples/mistral/config.yml --deepspeed deepspeed/zero2.json
+accelerate launch -m axolotl.cli.train examples/mistral/config.yml --deepspeed deepspeed_configs/zero2.json
 ```
diff --git a/examples/mistral/mixtral.yml b/examples/mistral/mixtral.yml
index cb14c6745..4489a272a 100644
--- a/examples/mistral/mixtral.yml
+++ b/examples/mistral/mixtral.yml
@@ -84,7 +84,7 @@ eval_table_size:
 eval_table_max_new_tokens: 128
 saves_per_epoch: 1
 debug:
-deepspeed: deepspeed/zero2.json
+deepspeed: deepspeed_configs/zero2.json
 weight_decay: 0.0
 fsdp:
 fsdp_config:
diff --git a/examples/phi/README.md b/examples/phi/README.md
index 1109db0b5..1b9e8022e 100644
--- a/examples/phi/README.md
+++ b/examples/phi/README.md
@@ -3,7 +3,7 @@
 Due to some nuances with the phi code, please use deepspeed when training phi for full finetune.
 
 ```shell
-accelerate launch -m axolotl.cli.train examples/phi/phi-ft.yml --deepspeed deepspeed/zero1.json
+accelerate launch -m axolotl.cli.train examples/phi/phi-ft.yml --deepspeed deepspeed_configs/zero1.json
 
 # OR
 
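Any local configs that still point at the old `deepspeed/` directory need the same one-line change applied to the examples above. A throwaway migration sketch (stdlib only, not part of this change set; adjust the glob to match your config layout):

```python
import re
from pathlib import Path

# Rewrites `deepspeed/zeroX.json` to `deepspeed_configs/zeroX.json` in place,
# leaving paths that already use the new directory name untouched.
PATTERN = re.compile(r"(?<![\w/])deepspeed/(zero[123](?:_bf16)?\.json)")


def migrate(path: Path) -> bool:
    text = path.read_text()
    new_text = PATTERN.sub(r"deepspeed_configs/\1", text)
    if new_text != text:
        path.write_text(new_text)
        return True
    return False


if __name__ == "__main__":
    changed = [p for p in Path(".").rglob("*.yml") if migrate(p)]
    print(f"updated {len(changed)} file(s)")
```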
diff --git a/src/axolotl/monkeypatch/mixtral/__init__.py b/src/axolotl/monkeypatch/mixtral/__init__.py
index fb4017230..112b0dee9 100644
--- a/src/axolotl/monkeypatch/mixtral/__init__.py
+++ b/src/axolotl/monkeypatch/mixtral/__init__.py
@@ -1,12 +1,61 @@
 """
 Patches to support multipack for mixtral
 """
+import torch
 import transformers
 
 from axolotl.monkeypatch.utils import get_unpad_data
 
 
-def replace_mixtral_attn_with_multipack_flash_attn():
+def patch_mixtral_moe_forward_zero3() -> None:
+    import torch.nn.functional as F
+
+    def mlp_forward(self, hidden_states):
+        current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(
+            hidden_states
+        )
+        current_hidden_states = self.w2(current_hidden_states)
+        return current_hidden_states
+
+    # Ref. https://huggingface.co/deepseek-ai/deepseek-moe-16b-base/blob/main/modeling_deepseek.py
+    def moe_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        batch_size, sequence_length, hidden_dim = hidden_states.shape
+        hidden_states = hidden_states.view(-1, hidden_dim)
+        # router_logits: (batch * sequence_length, n_experts)
+        router_logits = self.gate(hidden_states)
+
+        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
+        topk_weight, topk_idx = torch.topk(
+            routing_weights, self.top_k, dim=-1, sorted=False
+        )
+        topk_weight /= topk_weight.sum(dim=-1, keepdim=True)
+        # we cast back to the input dtype
+        topk_weight = topk_weight.to(hidden_states.dtype)
+
+        hidden_states = hidden_states.repeat_interleave(self.top_k, dim=0)
+        y = torch.empty_like(hidden_states)  # pylint: disable=invalid-name
+        flat_topk_idx = topk_idx.view(-1)
+        for i in range(self.num_experts):
+            expert = self.experts[i]
+            y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
+        y = (  # pylint: disable=invalid-name
+            y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)
+        ).sum(dim=1)
+        final_hidden_states = y.reshape(batch_size, sequence_length, hidden_dim)
+        return final_hidden_states, router_logits
+
+    from transformers.models.mixtral.modeling_mixtral import (
+        MixtralBLockSparseTop2MLP,
+        MixtralSparseMoeBlock,
+    )
+
+    MixtralBLockSparseTop2MLP.forward = mlp_forward
+    MixtralSparseMoeBlock.forward = moe_forward
+
+
+def replace_mixtral_attn_with_multipack_flash_attn(for_zero3=False):
     transformers.models.mixtral.modeling_mixtral._get_unpad_data = (  # pylint: disable=protected-access
         get_unpad_data
     )
+    if for_zero3:
+        patch_mixtral_moe_forward_zero3()
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index 79b880234..9d7255fed 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -15,7 +15,7 @@
 from peft import PeftModel
 from pkg_resources import get_distribution  # type: ignore
 from transformers import PreTrainedModel, PreTrainedTokenizer
-from transformers.deepspeed import is_deepspeed_zero3_enabled
+from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
 
 from axolotl.common.cli import TrainerCliArgs
 from axolotl.logging_config import configure_logging
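The `moe_forward` patch above reimplements Mixtral's top-2 routing in the style of the referenced DeepSeek-MoE code: softmax the router logits, keep the top-k experts per token, renormalize those weights, run the tokens through their selected experts, and sum the weighted outputs. Every expert is indexed on every step (possibly on an empty slice of tokens), which keeps parameter access uniform, the behaviour ZeRO-3 partitioning expects. A toy, self-contained illustration of that routing math (plain PyTorch, not axolotl code; shapes and names are illustrative):

```python
import torch
import torch.nn.functional as F

# Toy top-k MoE routing with the same weighting scheme as the patched moe_forward:
# per-token softmax over experts, top-k selection, renormalized weights, and a
# weighted sum of expert outputs.
torch.manual_seed(0)
num_experts, top_k, hidden = 4, 2, 8
experts = [torch.nn.Linear(hidden, hidden, bias=False) for _ in range(num_experts)]

tokens = torch.randn(5, hidden)              # (batch * seq_len, hidden)
router_logits = torch.randn(5, num_experts)  # stand-in for self.gate(tokens)

routing_weights = F.softmax(router_logits, dim=-1)
topk_weight, topk_idx = torch.topk(routing_weights, top_k, dim=-1)
topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)  # renormalize over kept experts

out = torch.zeros_like(tokens)
for expert_id in range(num_experts):  # every expert runs, even when it receives no tokens
    for slot in range(top_k):
        mask = topk_idx[:, slot] == expert_id
        out[mask] += topk_weight[mask, slot].unsqueeze(-1) * experts[expert_id](tokens[mask])

print(out.shape)  # torch.Size([5, 8])
```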
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 25b575686..bd0ce6c0d 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -21,7 +21,7 @@
     PreTrainedModel,
     PreTrainedTokenizerBase,
 )
-from transformers.deepspeed import is_deepspeed_zero3_enabled
+from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
 
 from axolotl.models.mamba import fix_mamba_attn_for_loss
 from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
@@ -333,7 +333,10 @@ def load_model(
         )
 
         LOG.info("patching mixtral with flash attention")
-        replace_mixtral_attn_with_multipack_flash_attn()
+        mixtral_patch_kwargs = {}
+        if is_deepspeed_zero3_enabled():
+            mixtral_patch_kwargs["for_zero3"] = True
+        replace_mixtral_attn_with_multipack_flash_attn(**mixtral_patch_kwargs)
 
     if cfg.model_config_type == "falcon" and cfg.flash_attention and cfg.sample_packing:
         from axolotl.monkeypatch.falcon import (
@@ -646,6 +649,12 @@ def load_model(
     )
     needs_fa2_dtype = cfg.adapter or cfg.fsdp
     skip_prepare_model_for_kbit_training = False
+    if cfg.model_config_type == "mixtral" and is_deepspeed_zero3_enabled():
+        from deepspeed.utils import set_z3_leaf_modules
+        from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
+
+        set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
+
     if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
         # Qwen doesn't play nicely with LoRA if this is enabled
         skip_prepare_model_for_kbit_training = True
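Taken together, the `models.py` hunks do two things for Mixtral under ZeRO-3: they pass `for_zero3=True` into the attention/MoE monkeypatch, and, once the model object exists, they register `MixtralSparseMoeBlock` as a ZeRO-3 leaf module so DeepSpeed gathers each block's parameters as a unit rather than hooking the individual experts, which may not all run on a given step. A condensed sketch of that wiring; the wrapper name is hypothetical and the body paraphrases the diff rather than reproducing `load_model`:

```python
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled


def apply_mixtral_zero3_fixes(model, cfg):
    """Hypothetical wrapper mirroring the two hunks above."""
    if cfg.model_config_type != "mixtral":
        return model

    if cfg.flash_attention and cfg.sample_packing:
        from axolotl.monkeypatch.mixtral import (
            replace_mixtral_attn_with_multipack_flash_attn,
        )

        # Patch multipack flash-attn; under ZeRO-3 also swap in the MoE forward
        # that touches every expert on every step.
        replace_mixtral_attn_with_multipack_flash_attn(
            for_zero3=is_deepspeed_zero3_enabled()
        )

    if is_deepspeed_zero3_enabled():
        from deepspeed.utils import set_z3_leaf_modules
        from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

        # Treat each sparse MoE block as a single leaf for ZeRO-3 parameter
        # gathering, rather than partitioning/hooking experts individually.
        set_z3_leaf_modules(model, [MixtralSparseMoeBlock])

    return model
```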