Promote t5 and variants (#2064)
* Promote T5 from prototype to beta

* Add models to README

* Move T5 tests in integration tests

* Fix linting

* Fix formatting

* Actually add t5

* Add the rest of the files you absolute donkey

* Fix linting

* Modify paths for generation tests

* Fix linting
joecummings committed Feb 17, 2023
1 parent 2cd5e12 commit 670e52a
Showing 20 changed files with 21 additions and 39 deletions.
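
For orientation, the practical effect of this promotion is that the T5 entry points move out of the prototype namespace and become importable from torchtext.models. A minimal sketch of the promoted API, mirroring the bundle usage exercised in the tests below (the output dict key is an assumption based on the bundle documentation, not shown in this diff):

    import torch
    from torchtext.models import T5_BASE_ENCODER  # previously torchtext.prototype.models

    # Bundles lazily download pre-trained weights on first use.
    transform = T5_BASE_ENCODER.transform()
    model = T5_BASE_ENCODER.get_model()
    model.eval()

    model_input = transform(["Hello world", "Attention rocks!"])
    with torch.no_grad():
        output = model(model_input)["encoder_output"]  # key name assumed
    print(output.shape)  # (batch, sequence length, embedding dim)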
2 changes: 2 additions & 0 deletions README.rst
@@ -121,6 +121,8 @@ The library currently consist of following pre-trained models:
 * RoBERTa: `Base and Large Architecture <https://github.com/pytorch/fairseq/tree/main/examples/roberta#pre-trained-models>`_
 * `DistilRoBERTa <https://github.com/huggingface/transformers/blob/main/examples/research_projects/distillation/README.md>`_
 * XLM-RoBERTa: `Base and Large Architure <https://github.com/pytorch/fairseq/tree/main/examples/xlmr#pre-trained-models>`_
+* T5: `Small, Base, Large, 3B, and 11B Architecture <https://github.com/google-research/text-to-text-transfer-transformer>`_
+* Flan-T5: `Small, Base, Large, XL, and XXL Architecture <https://github.com/google-research/t5x>`_

 Tokenizers
 ==========
File renamed without changes.
@@ -3,7 +3,8 @@
 import pytest  # noqa: F401
 import torch
 from parameterized import parameterized_class
-from torchtext.prototype.models import (
+from torchtext.models import T5Bundle
+from torchtext.models import (
     T5_BASE,
     T5_BASE_ENCODER,
     T5_BASE_GENERATION,
@@ -14,7 +15,6 @@
     T5_SMALL_ENCODER,
     T5_SMALL_GENERATION,
 )
-from torchtext.prototype.models.t5.bundler import T5Bundle
 from torchtext_unittest.common.assets import get_asset_path
 from torchtext_unittest.common.parameterized_utils import nested_params
 from torchtext_unittest.common.torchtext_test_case import TorchtextTestCase
5 changes: 3 additions & 2 deletions test/torchtext_unittest/models/gpu_tests/models_gpu_test.py
@@ -3,11 +3,12 @@
 import pytest
 import torch
 from torchtext_unittest.common.torchtext_test_case import TorchtextTestCase
-from torchtext_unittest.models.models_test_impl import BaseTestModels
+from torchtext_unittest.models.roberta_models_test_impl import RobertaBaseTestModels
+from torchtext_unittest.models.t5_models_test_impl import T5BaseTestModels


 @pytest.mark.gpu_test
 @unittest.skipIf(not torch.cuda.is_available(), reason="CUDA is not available")
-class TestModels32GPU(BaseTestModels, TorchtextTestCase):
+class TestModels32GPU(RobertaBaseTestModels, T5BaseTestModels, TorchtextTestCase):
     dtype = torch.float32
     device = torch.device("cuda")
5 changes: 3 additions & 2 deletions test/torchtext_unittest/models/models_cpu_test.py
@@ -1,9 +1,10 @@
 import torch

 from ..common.torchtext_test_case import TorchtextTestCase
-from .models_test_impl import BaseTestModels
+from .roberta_models_test_impl import RobertaBaseTestModels
+from .t5_models_test_impl import T5BaseTestModels


-class TestModels32CPU(BaseTestModels, TorchtextTestCase):
+class TestModels32CPU(RobertaBaseTestModels, T5BaseTestModels, TorchtextTestCase):
     dtype = torch.float32
     device = torch.device("cpu")
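
The CPU and GPU test classes above are deliberately thin: each mixes a model-specific test implementation into TorchtextTestCase and pins a dtype/device pair. A simplified illustration of the pattern (not torchtext's actual TestBaseMixin, just the shape of it):

    import unittest

    import torch


    class TestBaseMixin:
        # Concrete test classes override these two knobs.
        dtype = torch.float32
        device = torch.device("cpu")


    class T5BaseTestModels(TestBaseMixin):
        def test_tensor_placement(self) -> None:
            x = torch.zeros(2, 4, dtype=self.dtype, device=self.device)
            self.assertEqual(x.dtype, self.dtype)


    # One concrete class per dtype/device combination, as in the diffs above.
    class TestModels32CPU(T5BaseTestModels, unittest.TestCase):
        dtype = torch.float32
        device = torch.device("cpu")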
@@ -7,7 +7,7 @@
 from ..common.case_utils import TestBaseMixin


-class BaseTestModels(TestBaseMixin):
+class RobertaBaseTestModels(TestBaseMixin):
     def get_model(self, encoder_conf, head=None, freeze_encoder=False, checkpoint=None, override_checkpoint_head=False):
         from torchtext.models import RobertaBundle
@@ -6,9 +6,9 @@
 from torchtext_unittest.common.case_utils import TestBaseMixin


-class BaseTestModels(TestBaseMixin):
+class T5BaseTestModels(TestBaseMixin):
     def test_t5_bundler_build_model(self) -> None:
-        from torchtext.prototype.models import T5Conf, T5Model, T5Bundle
+        from torchtext.models import T5Conf, T5Model, T5Bundle

         # case: user provides encoder checkpoint state dict
         dummy_encoder_conf = T5Conf(
@@ -57,7 +57,7 @@ def test_t5_bundler_build_model(self) -> None:

     @patch("logging.Logger.warning")
     def test_t5_bundler_get_model(self, mock):
-        from torchtext.prototype.models import T5Conf, T5Bundle
+        from torchtext.models import T5Conf, T5Bundle

         # encoder-decoder with generation
         dummy_t5_generation_conf = T5Conf(
@@ -77,7 +77,7 @@ def test_t5_bundler_get_model(self, mock):
         )

     def test_t5_bundler_raise_checkpoint(self) -> None:
-        from torchtext.prototype.models import T5Conf, T5Bundle
+        from torchtext.models import T5Conf, T5Bundle

         # encoder-only
         with self.assertRaises(TypeError):
@@ -132,7 +132,7 @@ def test_t5_bundler_raise_checkpoint(self) -> None:
         )

     def test_t5_bundler_conf_property(self) -> None:
-        from torchtext.prototype.models import T5Conf, T5Bundle
+        from torchtext.models import T5Conf, T5Bundle

         dummy_t5_conf = T5Conf(
             encoder_only=False,
@@ -148,7 +148,7 @@

     def test_t5_bundler_train(self) -> None:
         from torch.optim import SGD
-        from torchtext.prototype.models import T5Conf, T5Model, T5Bundle
+        from torchtext.models import T5Conf, T5Model, T5Bundle

         torch.manual_seed(123)
@@ -1,5 +1,5 @@
 import torch
-from torchtext.prototype.models import T5Transform
+from torchtext.models import T5Transform
 from torchtext_unittest.common.assets import get_asset_path
 from torchtext_unittest.common.torchtext_test_case import TorchtextTestCase
Empty file.

This file was deleted.

This file was deleted.

2 changes: 1 addition & 1 deletion test/torchtext_unittest/prototype/test_generate.py
@@ -1,8 +1,8 @@
 from unittest.mock import patch

 import torch
+from torchtext.models import T5_BASE_GENERATION
 from torchtext.prototype.generate import DEFAULT_MAX_SEQ_LEN, GenerationUtil
-from torchtext.prototype.models import T5_BASE_GENERATION
 from torchtext_unittest.common.torchtext_test_case import TorchtextTestCase
1 change: 1 addition & 0 deletions torchtext/models/__init__.py
@@ -1 +1,2 @@
 from .roberta import *  # noqa: F401, F403
+from .t5 import *  # noqa: F401, F403
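
With the star re-export in place, the T5 names resolve both from the package root and from the renamed t5 subpackage; a quick sanity check, assuming the subpackage exports the same symbols the tests import:

    from torchtext.models import T5Bundle, T5Conf, T5Model, T5Transform
    from torchtext.models.t5 import T5Conf as T5ConfFromSubpackage

    # Both import paths point at the same class object after the re-export.
    assert T5Conf is T5ConfFromSubpackage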
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion torchtext/prototype/generate.py
@@ -25,7 +25,7 @@ class GenerationUtil:
     This means that popular HuggingFace implementation of T5, Bart, and GPT-2 can all be used with these generation utils!

     >>> from transformers import T5Model
     >>> model = T5Model.from_pretrained("t5-base")
-    >>> generative_model = GenerationUtil(model=model, is_huggingface_model=True)
+    >>> generative_model = GenerationUtils(model=model, is_huggingface_model=True)
     >>> generative_model.generate(input_ids, num_beams=1, max_len=100)
     More examples can be found in the `notebooks` directory of this repository.
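
The docstring example above shows the HuggingFace path; the same utility also wraps the newly promoted torchtext bundles. A hedged sketch (the generate call follows the docstring above; transform.decode is an assumption based on the T5Transform interface):

    from torchtext.models import T5_BASE_GENERATION
    from torchtext.prototype.generate import GenerationUtil

    transform = T5_BASE_GENERATION.transform()
    model = T5_BASE_GENERATION.get_model()
    model.eval()

    input_ids = transform(["summarize: studies have shown that owning a dog is good for you"])

    # Native torchtext model, so is_huggingface_model stays at its default (False).
    generator = GenerationUtil(model=model)
    output_ids = generator.generate(input_ids, num_beams=1, max_len=100)
    print(transform.decode(output_ids.tolist()))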
1 change: 0 additions & 1 deletion torchtext/prototype/models/__init__.py

This file was deleted.
