Skip to content

Commit

Permalink
Add version checks for the import of DeepSpeed moe utils (#2705)
Browse files Browse the repository at this point in the history
* fix import for moe utils

* Apply suggestions from code review

Co-authored-by: Zach Mueller <muellerzr@gmail.com>

---------

Co-authored-by: Zach Mueller <muellerzr@gmail.com>
  • Loading branch information
pacman100 and muellerzr authored Apr 24, 2024
1 parent 3e944c5 commit 092c3af
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
6 changes: 4 additions & 2 deletions src/accelerate/utils/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -975,11 +975,13 @@ def _deepspeed_config_checks(self):
)

def set_moe_leaf_modules(self, model):
from deepspeed.utils import set_z3_leaf_modules

if self.transformer_moe_cls_names is None:
self.transformer_moe_cls_names = os.environ.get("ACCELERATE_DEEPSPEED_MOE_LAYER_CLS_NAMES", None)
if self.transformer_moe_cls_names is not None:
if compare_versions("deepspeed", "<", "0.14.0"):
raise ImportError("DeepSpeed version must be >= 0.14.0 to use MOE support. Please update DeepSpeed.")
from deepspeed.utils import set_z3_leaf_modules

class_names = self.transformer_moe_cls_names.split(",")
transformer_moe_cls = []
for layer_class in class_names:
Expand Down
4 changes: 2 additions & 2 deletions tests/deepspeed/test_deepspeed.py
Original file line number Diff line number Diff line change
Expand Up @@ -813,8 +813,8 @@ def test_ds_config(self, stage):
)
assert deepspeed_plugin.zero_stage == int(stage.replace("zero", ""))

def test_prepare_deepspeed_preapre_moe(self):
if compare_versions("transformers", "<", "4.40"):
def test_prepare_deepspeed_prepare_moe(self):
if compare_versions("transformers", "<", "4.40") and compare_versions("deepspeed", "<", "0.14"):
return
deepspeed_plugin = DeepSpeedPlugin(
zero3_init_flag=True,
Expand Down

0 comments on commit 092c3af

Please sign in to comment.