Skip to content

Commit

Permalink
[T5] Bug correction & Refactor (huggingface#8518)
Browse files Browse the repository at this point in the history
* fix bug

* T5 refactor

* refactor tf

* apply sylvains suggestions
  • Loading branch information
patrickvonplaten authored and fabiocapsouza committed Nov 15, 2020
1 parent 1bcec4c commit 796e7c4
Show file tree
Hide file tree
Showing 7 changed files with 316 additions and 268 deletions.
6 changes: 0 additions & 6 deletions src/transformers/configuration_t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ class T5Config(PretrainedConfig):
def __init__(
self,
vocab_size=32128,
n_positions=512,
d_model=512,
d_kv=64,
d_ff=2048,
Expand All @@ -98,7 +97,6 @@ def __init__(
**kwargs,
)
self.vocab_size = vocab_size
self.n_positions = n_positions
self.d_model = d_model
self.d_kv = d_kv
self.d_ff = d_ff
Expand All @@ -112,10 +110,6 @@ def __init__(
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_factor = initializer_factor

@property
def max_position_embeddings(self):
return self.n_positions

@property
def hidden_size(self):
return self.d_model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

import argparse

import torch

from transformers import T5Config, T5Model, load_tf_weights_in_t5
from transformers.utils import logging

Expand All @@ -37,7 +35,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du

# Save pytorch-model
print("Save PyTorch model to {}".format(pytorch_dump_path))
torch.save(model.state_dict(), pytorch_dump_path)
model.save_pretrained(pytorch_dump_path)


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit 796e7c4

Please sign in to comment.