
Commit

revert bad changes

ArthurZucker committed Oct 14, 2024
1 parent 675892f commit d91bbbc
Showing 5 changed files with 10 additions and 100 deletions.
85 changes: 0 additions & 85 deletions scripts/deberta_scrtipt.py

This file was deleted.

10 changes: 0 additions & 10 deletions scripts/from transformers import pipeline.py

This file was deleted.

4 changes: 2 additions & 2 deletions src/transformers/modeling_utils.py
@@ -1643,11 +1643,11 @@ def _set_default_torch_dtype(cls, dtype: torch.dtype) -> torch.dtype:
         return dtype_orig
 
     @property
-    def base_model(self):
+    def base_model(self) -> nn.Module:
         """
         `torch.nn.Module`: The main body of the model.
         """
-        return getattr(self, "base_model_prefix", self)
+        return getattr(self, self.base_model_prefix, self)
 
     @classmethod
     def can_generate(cls) -> bool:
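Why the revert: `getattr(self, "base_model_prefix", self)` looks up the attribute literally named `base_model_prefix`, which is a string such as `"deberta"`, not the backbone module that string names. The restored `getattr(self, self.base_model_prefix, self)` first reads the prefix, then fetches the submodule registered under that name. A minimal sketch of the difference (toy class, not from the repository):

    import torch.nn as nn

    class Toy(nn.Module):
        # transformers models register the backbone under the name in base_model_prefix
        base_model_prefix = "encoder"

        def __init__(self):
            super().__init__()
            self.encoder = nn.Linear(4, 4)

    m = Toy()
    print(getattr(m, m.base_model_prefix, m))   # restored behavior: the nn.Linear backbone
    print(getattr(m, "base_model_prefix", m))   # reverted bug: the string "encoder"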
7 changes: 6 additions & 1 deletion src/transformers/models/deberta/modeling_deberta.py
@@ -930,7 +930,12 @@ def forward(self, sequence_output, word_embeddings):
 
 @add_start_docstrings("""DeBERTa Model with a `language modeling` head on top.""", DEBERTA_START_DOCSTRING)
 class DebertaForMaskedLM(DebertaPreTrainedModel):
-    _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
+    _tied_weights_keys = [
+        "cls.predictions.decoder.weight",
+        "cls.predictions.decoder.bias",
+        "deberta.embeddings.word_embeddings.weight",
+        "lm_predictions.lm_head.weight",
+    ]
 
     def __init__(self, config):
         super().__init__(config)
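`_tied_weights_keys` names parameters whose storage is shared with another parameter, so saving can skip the duplicates and loading will not report them as missing; the revert restores the two embedding/LM-head keys dropped by the parent commit. A hedged sketch of the tying idea itself (toy modules, not the transformers implementation):

    import torch.nn as nn

    vocab_size, hidden = 100, 16
    embeddings = nn.Embedding(vocab_size, hidden)
    lm_head = nn.Linear(hidden, vocab_size, bias=False)

    # Tie the output projection to the input embedding matrix: both names now
    # point at one tensor, so only one copy needs to live in the checkpoint.
    lm_head.weight = embeddings.weight
    assert lm_head.weight.data_ptr() == embeddings.weight.data_ptr()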
4 changes: 2 additions & 2 deletions src/transformers/trainer.py
@@ -3512,8 +3512,8 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
             `torch.Tensor`: The tensor with training loss on this batch.
         """
         model.train()
-        # if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
-        #     self.optimizer.train()
+        if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
+            self.optimizer.train()
 
         inputs = self._prepare_inputs(inputs)
         if is_sagemaker_mp_enabled():
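The restored guard exists because some optimizers, schedule-free optimizers being the usual example, keep distinct train- and eval-time parameter states and expose `train()`/`eval()` methods; standard torch optimizers have no such methods, hence the `hasattr`/`callable` check. A sketch of the pattern with a hypothetical optimizer (the class below is illustrative, not a real library API):

    import torch
    import torch.nn as nn

    class ToyModeAwareSGD(torch.optim.SGD):
        # Hypothetical optimizer that, like schedule-free optimizers, must be
        # switched between train and eval mode together with the model.
        def train(self):
            self._train_mode = True

        def eval(self):
            self._train_mode = False

    model = nn.Linear(2, 2)
    optimizer = ToyModeAwareSGD(model.parameters(), lr=0.1)

    model.train()
    # The restored Trainer logic: call optimizer.train() only when it exists.
    if hasattr(optimizer, "train") and callable(optimizer.train):
        optimizer.train()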
