Skip to content

Commit

Permalink
Replace all assert by ValueError in src/transformers/models/electra
Browse files Browse the repository at this point in the history
  • Loading branch information
AkechiShiro committed Oct 10, 2021
1 parent db3133b commit f183373
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 8 deletions.
13 changes: 6 additions & 7 deletions src/transformers/models/electra/modeling_electra.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,8 @@ def load_tf_weights_in_electra(model, config, tf_checkpoint_path, discriminator_
elif m_name == "kernel":
array = np.transpose(array)
try:
assert (
pointer.shape == array.shape
), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
if not pointer.shape == array.shape:
raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
except AssertionError as e:
e.args += (pointer.shape, array.shape)
raise
Expand Down Expand Up @@ -447,7 +446,8 @@ def __init__(self, config):
self.is_decoder = config.is_decoder
self.add_cross_attention = config.add_cross_attention
if self.add_cross_attention:
assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"
if not self.is_decoder:
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
self.crossattention = ElectraAttention(config)
self.intermediate = ElectraIntermediate(config)
self.output = ElectraOutput(config)
Expand Down Expand Up @@ -482,9 +482,8 @@ def forward(

cross_attn_present_key_value = None
if self.is_decoder and encoder_hidden_states is not None:
assert hasattr(
self, "crossattention"
), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
if not hasattr(self, "crossattention"):
raise ValueError(f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`")

# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
Expand Down
4 changes: 3 additions & 1 deletion src/transformers/models/electra/modeling_tf_electra.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,9 @@ def call(
Returns:
final_embeddings (:obj:`tf.Tensor`): output embedding tensor.
"""
assert not (input_ids is None and inputs_embeds is None)
if input_ids is None and inputs_embeds is None:
raise ValueError()


if input_ids is not None:
inputs_embeds = tf.gather(params=self.weight, indices=input_ids)
Expand Down

0 comments on commit f183373

Please sign in to comment.