diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py
index c81e4fb557e889..1ea4bfceaf5a7c 100644
--- a/src/transformers/models/electra/modeling_electra.py
+++ b/src/transformers/models/electra/modeling_electra.py
@@ -139,9 +139,8 @@ def load_tf_weights_in_electra(model, config, tf_checkpoint_path, discriminator_
         elif m_name == "kernel":
             array = np.transpose(array)
         try:
-            assert (
-                pointer.shape == array.shape
-            ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
-        except AssertionError as e:
+            if pointer.shape != array.shape:
+                raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
+        except ValueError as e:
             e.args += (pointer.shape, array.shape)
             raise
@@ -447,7 +446,8 @@ def __init__(self, config):
         self.is_decoder = config.is_decoder
         self.add_cross_attention = config.add_cross_attention
         if self.add_cross_attention:
-            assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"
+            if not self.is_decoder:
+                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
             self.crossattention = ElectraAttention(config)
         self.intermediate = ElectraIntermediate(config)
         self.output = ElectraOutput(config)
@@ -482,9 +482,8 @@ def forward(
 
         cross_attn_present_key_value = None
         if self.is_decoder and encoder_hidden_states is not None:
-            assert hasattr(
-                self, "crossattention"
-            ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
+            if not hasattr(self, "crossattention"):
+                raise ValueError(f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`")
             # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
             cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
 
diff --git a/src/transformers/models/electra/modeling_tf_electra.py b/src/transformers/models/electra/modeling_tf_electra.py
index aad2a787d49d4d..0f384002ddb3b7 100644
--- a/src/transformers/models/electra/modeling_tf_electra.py
+++ b/src/transformers/models/electra/modeling_tf_electra.py
@@ -404,7 +404,9 @@ def call(
 
         Returns:
             final_embeddings (:obj:`tf.Tensor`): output embedding tensor.
         """
-        assert not (input_ids is None and inputs_embeds is None)
+        if input_ids is None and inputs_embeds is None:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
         if input_ids is not None:
             inputs_embeds = tf.gather(params=self.weight, indices=input_ids)