From 59610225061e75c18b42f5e7bf963df83a224988 Mon Sep 17 00:00:00 2001 From: Jonny Li Date: Fri, 24 May 2024 15:14:09 -0400 Subject: [PATCH] Fix DeepSpeed compatibility with weight_norm (#30881) --- src/transformers/models/hubert/modeling_hubert.py | 10 ++++++++-- .../models/seamless_m4t/modeling_seamless_m4t.py | 10 ++++++++-- src/transformers/models/sew/modeling_sew.py | 10 ++++++++-- src/transformers/models/sew_d/modeling_sew_d.py | 10 ++++++++-- src/transformers/models/speecht5/modeling_speecht5.py | 10 ++++++++-- .../models/unispeech/modeling_unispeech.py | 10 ++++++++-- .../models/unispeech_sat/modeling_unispeech_sat.py | 10 ++++++++-- src/transformers/models/wav2vec2/modeling_wav2vec2.py | 10 ++++++++-- .../wav2vec2_conformer/modeling_wav2vec2_conformer.py | 10 ++++++++-- src/transformers/models/wavlm/modeling_wavlm.py | 10 ++++++++-- 10 files changed, 80 insertions(+), 20 deletions(-) diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 388d16415cd3..c12ed7dd3829 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -295,8 +295,14 @@ def __init__(self, config): with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0): self.conv = weight_norm(self.conv, name="weight", dim=2) - deepspeed.zero.register_external_parameter(self, self.conv.weight_v) - deepspeed.zero.register_external_parameter(self, self.conv.weight_g) + if hasattr(self.conv, "parametrizations"): + weight_g = self.conv.parametrizations.weight.original0 + weight_v = self.conv.parametrizations.weight.original1 + else: + weight_g = self.conv.weight_g + weight_v = self.conv.weight_v + deepspeed.zero.register_external_parameter(self, weight_v) + deepspeed.zero.register_external_parameter(self, weight_g) else: self.conv = weight_norm(self.conv, name="weight", dim=2) diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index d35629938bd6..8a15ba68d1cb 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -325,8 +325,14 @@ def __init__(self, config): with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0): self.conv = weight_norm(self.conv, name="weight", dim=2) - deepspeed.zero.register_external_parameter(self, self.conv.weight_v) - deepspeed.zero.register_external_parameter(self, self.conv.weight_g) + if hasattr(self.conv, "parametrizations"): + weight_g = self.conv.parametrizations.weight.original0 + weight_v = self.conv.parametrizations.weight.original1 + else: + weight_g = self.conv.weight_g + weight_v = self.conv.weight_v + deepspeed.zero.register_external_parameter(self, weight_v) + deepspeed.zero.register_external_parameter(self, weight_g) else: self.conv = weight_norm(self.conv, name="weight", dim=2) diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index ea10b2581814..e758f23740dd 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -294,8 +294,14 @@ def __init__(self, config): with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0): self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2) - deepspeed.zero.register_external_parameter(self, self.conv.weight_v) - deepspeed.zero.register_external_parameter(self, self.conv.weight_g) + if hasattr(self.conv, "parametrizations"): + weight_g = self.conv.parametrizations.weight.original0 + weight_v = self.conv.parametrizations.weight.original1 + else: + weight_g = self.conv.weight_g + weight_v = self.conv.weight_v + deepspeed.zero.register_external_parameter(self, weight_v) + deepspeed.zero.register_external_parameter(self, weight_g) else: self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2) diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py index 2daa6739ef61..f704b8166a89 100644 --- a/src/transformers/models/sew_d/modeling_sew_d.py +++ b/src/transformers/models/sew_d/modeling_sew_d.py @@ -354,8 +354,14 @@ def __init__(self, config): with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0): self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2) - deepspeed.zero.register_external_parameter(self, self.conv.weight_v) - deepspeed.zero.register_external_parameter(self, self.conv.weight_g) + if hasattr(self.conv, "parametrizations"): + weight_g = self.conv.parametrizations.weight.original0 + weight_v = self.conv.parametrizations.weight.original1 + else: + weight_g = self.conv.weight_g + weight_v = self.conv.weight_v + deepspeed.zero.register_external_parameter(self, weight_v) + deepspeed.zero.register_external_parameter(self, weight_g) else: self.conv = nn.utils.weight_norm(self.conv, name="weight", dim=2) diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index 22ee648b97ff..a5e228b6a8de 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -368,8 +368,14 @@ def __init__(self, config): with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0): self.conv = weight_norm(self.conv, name="weight", dim=2) - deepspeed.zero.register_external_parameter(self, self.conv.weight_v) - deepspeed.zero.register_external_parameter(self, self.conv.weight_g) + if hasattr(self.conv, "parametrizations"): + weight_g = self.conv.parametrizations.weight.original0 + weight_v = self.conv.parametrizations.weight.original1 + else: + weight_g = self.conv.weight_g + weight_v = self.conv.weight_v + deepspeed.zero.register_external_parameter(self, weight_v) + deepspeed.zero.register_external_parameter(self, weight_g) else: self.conv = weight_norm(self.conv, name="weight", dim=2) diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index 0db34c1fc8c6..bc88096f2497 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -330,8 +330,14 @@ def __init__(self, config): with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0): self.conv = weight_norm(self.conv, name="weight", dim=2) - deepspeed.zero.register_external_parameter(self, self.conv.weight_v) - deepspeed.zero.register_external_parameter(self, self.conv.weight_g) + if hasattr(self.conv, "parametrizations"): + weight_g = self.conv.parametrizations.weight.original0 + weight_v = self.conv.parametrizations.weight.original1 + else: + weight_g = self.conv.weight_g + weight_v = self.conv.weight_v + deepspeed.zero.register_external_parameter(self, weight_v) + deepspeed.zero.register_external_parameter(self, weight_g) else: self.conv = weight_norm(self.conv, name="weight", dim=2) diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index b78ffb4ac1a2..00e8673739c9 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -347,8 +347,14 @@ def __init__(self, config): with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0): self.conv = weight_norm(self.conv, name="weight", dim=2) - deepspeed.zero.register_external_parameter(self, self.conv.weight_v) - deepspeed.zero.register_external_parameter(self, self.conv.weight_g) + if hasattr(self.conv, "parametrizations"): + weight_g = self.conv.parametrizations.weight.original0 + weight_v = self.conv.parametrizations.weight.original1 + else: + weight_g = self.conv.weight_g + weight_v = self.conv.weight_v + deepspeed.zero.register_external_parameter(self, weight_v) + deepspeed.zero.register_external_parameter(self, weight_g) else: self.conv = weight_norm(self.conv, name="weight", dim=2) diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index bdf2d4f50a35..57d4fffb0835 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -398,8 +398,14 @@ def __init__(self, config): with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0): self.conv = weight_norm(self.conv, name="weight", dim=2) - deepspeed.zero.register_external_parameter(self, self.conv.weight_v) - deepspeed.zero.register_external_parameter(self, self.conv.weight_g) + if hasattr(self.conv, "parametrizations"): + weight_g = self.conv.parametrizations.weight.original0 + weight_v = self.conv.parametrizations.weight.original1 + else: + weight_g = self.conv.weight_g + weight_v = self.conv.weight_v + deepspeed.zero.register_external_parameter(self, weight_v) + deepspeed.zero.register_external_parameter(self, weight_g) else: self.conv = weight_norm(self.conv, name="weight", dim=2) diff --git a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py index 14e80b516ccc..dc9e634e62d8 100644 --- a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +++ b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py @@ -361,8 +361,14 @@ def __init__(self, config): with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0): self.conv = weight_norm(self.conv, name="weight", dim=2) - deepspeed.zero.register_external_parameter(self, self.conv.weight_v) - deepspeed.zero.register_external_parameter(self, self.conv.weight_g) + if hasattr(self.conv, "parametrizations"): + weight_g = self.conv.parametrizations.weight.original0 + weight_v = self.conv.parametrizations.weight.original1 + else: + weight_g = self.conv.weight_g + weight_v = self.conv.weight_v + deepspeed.zero.register_external_parameter(self, weight_v) + deepspeed.zero.register_external_parameter(self, weight_g) else: self.conv = weight_norm(self.conv, name="weight", dim=2) diff --git a/src/transformers/models/wavlm/modeling_wavlm.py b/src/transformers/models/wavlm/modeling_wavlm.py index 320959b439f3..753437e85639 100755 --- a/src/transformers/models/wavlm/modeling_wavlm.py +++ b/src/transformers/models/wavlm/modeling_wavlm.py @@ -287,8 +287,14 @@ def __init__(self, config): with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0): self.conv = weight_norm(self.conv, name="weight", dim=2) - deepspeed.zero.register_external_parameter(self, self.conv.weight_v) - deepspeed.zero.register_external_parameter(self, self.conv.weight_g) + if hasattr(self.conv, "parametrizations"): + weight_g = self.conv.parametrizations.weight.original0 + weight_v = self.conv.parametrizations.weight.original1 + else: + weight_g = self.conv.weight_g + weight_v = self.conv.weight_v + deepspeed.zero.register_external_parameter(self, weight_v) + deepspeed.zero.register_external_parameter(self, weight_g) else: self.conv = weight_norm(self.conv, name="weight", dim=2)