Replace assert by ValueError of src/transformers/models/electra/modeling_{electra,tf_electra}.py and all other models that had copies (#13955)

* Replace all assert by ValueError in src/transformers/models/electra
* Reformat with black to pass the check_code_quality test
* Change some assert to ValueError in modeling_bert & modeling_tf_albert
* Change some assert in multiple models
* Change assertions in multiple models to ValueError in order to pass the check_code_style test and the models template test
* Black reformat
* Change some more asserts in multiple models
* Change assert to ValueError in modeling_layoutlm.py to fix the copy error in code_style_check
* Add proper message to ValueError in modeling_tf_albert.py
* Simplify logic in models/bert/modeling_bert.py
* Add ValueError message to models/convbert/modeling_tf_convbert.py
* Add error message for ValueError to modeling_tf_electra.py
* Simplify logic in models/tapas/modeling_tapas.py
* Simplify logic in models/electra/modeling_electra.py
* Add ValueError message in src/transformers/models/bert/modeling_tf_bert.py
* Simplify logic in src/transformers/models/rembert/modeling_rembert.py
* Simplify logic in src/transformers/models/albert/modeling_albert.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
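For context, every change in this commit applies the same mechanical rewrite: a runtime assert becomes an explicit if check that raises ValueError. A minimal sketch of the pattern; the check_shapes function and its arrays are illustrative, not from the committed code. Note that asserts are stripped entirely under python -O, while the ValueError always fires and is a more appropriate exception type for bad user input:

    import numpy as np

    def check_shapes(pointer: np.ndarray, array: np.ndarray) -> None:
        # Old style, skipped when Python runs with -O:
        #     assert pointer.shape == array.shape, (
        #         f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
        #     )
        # New style: always checked, raises a catchable ValueError.
        if pointer.shape != array.shape:
            raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")

    check_shapes(np.zeros((2, 3)), np.zeros((2, 3)))  # passes silently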
1 parent 64743d0, commit 3499728
Showing 13 changed files with 64 additions and 51 deletions.
src/transformers/models/bert/modeling_bert.py
@@ -153,9 +153,8 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
     elif m_name == "kernel":
         array = np.transpose(array)
     try:
-        assert (
-            pointer.shape == array.shape
-        ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
+        if pointer.shape != array.shape:
+            raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
     except AssertionError as e:
         e.args += (pointer.shape, array.shape)
         raise
@@ -450,7 +449,8 @@ def __init__(self, config):
     self.is_decoder = config.is_decoder
     self.add_cross_attention = config.add_cross_attention
     if self.add_cross_attention:
-        assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"
+        if not self.is_decoder:
+            raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
         self.crossattention = BertAttention(config)
     self.intermediate = BertIntermediate(config)
     self.output = BertOutput(config)
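One practical consequence of this hunk: a misconfigured model now fails with a ValueError instead of an AssertionError. A hedged usage sketch, assuming the guard above lives in BertLayer.__init__ (as the surrounding crossattention/intermediate/output assignments suggest) and a transformers version that includes this commit:

    from transformers import BertConfig
    from transformers.models.bert.modeling_bert import BertLayer

    # add_cross_attention without is_decoder is the misconfiguration the guard rejects
    config = BertConfig(is_decoder=False, add_cross_attention=True)
    try:
        BertLayer(config)
    except ValueError as err:
        print(err)  # "... should be used as a decoder model if cross attention is added"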
@@ -485,9 +485,10 @@ def forward(

     cross_attn_present_key_value = None
     if self.is_decoder and encoder_hidden_states is not None:
-        assert hasattr(
-            self, "crossattention"
-        ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
+        if not hasattr(self, "crossattention"):
+            raise ValueError(
+                f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
+            )

         # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
         cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
src/transformers/models/electra/modeling_electra.py
@@ -139,9 +139,8 @@ def load_tf_weights_in_electra(model, config, tf_checkpoint_path, discriminator_
     elif m_name == "kernel":
         array = np.transpose(array)
     try:
-        assert (
-            pointer.shape == array.shape
-        ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
+        if pointer.shape != array.shape:
+            raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
     except AssertionError as e:
         e.args += (pointer.shape, array.shape)
         raise
@@ -447,7 +446,8 @@ def __init__(self, config):
     self.is_decoder = config.is_decoder
     self.add_cross_attention = config.add_cross_attention
     if self.add_cross_attention:
-        assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"
+        if not self.is_decoder:
+            raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
         self.crossattention = ElectraAttention(config)
     self.intermediate = ElectraIntermediate(config)
     self.output = ElectraOutput(config)
@@ -482,9 +482,10 @@ def forward(

     cross_attn_present_key_value = None
     if self.is_decoder and encoder_hidden_states is not None:
-        assert hasattr(
-            self, "crossattention"
-        ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
+        if not hasattr(self, "crossattention"):
+            raise ValueError(
+                f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
+            )

         # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
         cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
src/transformers/models/roformer/modeling_roformer.py
@@ -167,9 +167,8 @@ def load_tf_weights_in_roformer(model, config, tf_checkpoint_path):
     elif m_name == "kernel":
         array = np.transpose(array)
     try:
-        assert (
-            pointer.shape == array.shape
-        ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
+        if not pointer.shape == array.shape:
+            raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
     except AssertionError as e:
         e.args += (pointer.shape, array.shape)
         raise
@@ -463,7 +462,8 @@ def __init__(self, config):
     self.is_decoder = config.is_decoder
     self.add_cross_attention = config.add_cross_attention
     if self.add_cross_attention:
-        assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"
+        if not self.is_decoder:
+            raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
         self.crossattention = RoFormerAttention(config)
     self.intermediate = RoFormerIntermediate(config)
     self.output = RoFormerOutput(config)
src/transformers/models/tapas/modeling_tapas.py
@@ -252,9 +252,8 @@ def load_tf_weights_in_tapas(model, config, tf_checkpoint_path):
     elif m_name == "kernel":
         array = np.transpose(array)
     try:
-        assert (
-            pointer.shape == array.shape
-        ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
+        if pointer.shape != array.shape:
+            raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
     except AssertionError as e:
         e.args += (pointer.shape, array.shape)
         raise
@@ -548,7 +547,8 @@ def __init__(self, config):
     self.is_decoder = config.is_decoder
     self.add_cross_attention = config.add_cross_attention
     if self.add_cross_attention:
-        assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"
+        if not self.is_decoder:
+            raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
         self.crossattention = TapasAttention(config)
     self.intermediate = TapasIntermediate(config)
     self.output = TapasOutput(config)
@@ -583,9 +583,10 @@ def forward(

     cross_attn_present_key_value = None
     if self.is_decoder and encoder_hidden_states is not None:
-        assert hasattr(
-            self, "crossattention"
-        ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
+        if not hasattr(self, "crossattention"):
+            raise ValueError(
+                f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
+            )

         # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
         cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
milyiyo (Contributor) commented on the load_tf_weights hunks:
Since the previous assert was changed to raise a ValueError, maybe the except should also change, to "except ValueError as e:".
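The reviewer's point is easy to demonstrate in isolation: ValueError is not a subclass of AssertionError, so after this change the "except AssertionError" branch can never fire for the shape check, and the e.args-annotating logic becomes dead code. A minimal, self-contained sketch, where pointer and array are stand-ins for the loader's NumPy arrays:

    import numpy as np

    pointer = np.zeros((2, 3))
    array = np.zeros((3, 2))

    try:
        try:
            if pointer.shape != array.shape:
                raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
        except AssertionError as e:
            # Unreachable for the shape check after the change:
            # a ValueError sails straight past an AssertionError handler.
            e.args += (pointer.shape, array.shape)
            raise
    except ValueError as e:
        print("caught:", e.args)  # only the original message; the shapes were never appended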