diff --git a/deepspeed/module_inject/replace_policy.py b/deepspeed/module_inject/replace_policy.py
index b0f5be283b371..0f8045dbd2735 100755
--- a/deepspeed/module_inject/replace_policy.py
+++ b/deepspeed/module_inject/replace_policy.py
@@ -444,13 +444,12 @@ def __init__(self, client_module, inference=True):
         try:
             import transformers
             HFOPTLayerPolicy._orig_layer_class = transformers.models.opt.modeling_opt.OPTDecoderLayer
+            if isinstance(DSPolicy.hf_model_config,
+                          transformers.models.opt.configuration_opt.OPTConfig):
+                self.pre_attn_norm = self.hf_model_config.do_layer_norm_before
         except:
             HFOPTLayerPolicy._orig_layer_class = None
 
-        if isinstance(DSPolicy.hf_model_config,
-                      transformers.models.opt.configuration_opt.OPTConfig):
-            self.pre_attn_norm = self.hf_model_config.do_layer_norm_before
-
     def get_hidden_heads(self):
         return self.client_module.self_attn.embed_dim, \
                 self.client_module.self_attn.num_heads