Skip to content

Commit

Permalink
Refactor test classes (#492)
Browse files Browse the repository at this point in the history
  • Loading branch information
calpt authored Feb 13, 2023
1 parent ee5f02b commit 1540ab0
Show file tree
Hide file tree
Showing 34 changed files with 251 additions and 254 deletions.
54 changes: 38 additions & 16 deletions .github/workflows/tests_torch.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,23 @@ name: Tests

on:
push:
branches: [ 'master' ]
branches: ["master"]
paths:
- 'src/**'
- 'examples/**'
- 'templates/**'
- 'tests/**'
- 'tests_adapters/**'
- 'utils/**'
- "src/**"
- "examples/**"
- "templates/**"
- "tests/**"
- "tests_adapters/**"
- "utils/**"
pull_request:
branches: [ 'master', 'develop' ]
branches: ["master", "develop"]
paths:
- 'src/**'
- 'examples/**'
- 'templates/**'
- 'tests/**'
- 'tests_adapters/**'
- 'utils/**'
- "src/**"
- "examples/**"
- "templates/**"
- "tests/**"
- "tests_adapters/**"
- "utils/**"
workflow_dispatch:

jobs:
Expand All @@ -41,7 +41,7 @@ jobs:
run: |
make quality
make repo-consistency
run_reduced_tests_torch:
test_adapter_methods:
timeout-minutes: 60
runs-on: ubuntu-latest
steps:
Expand All @@ -62,4 +62,26 @@ jobs:
pip install datasets
- name: Test
run: |
make test-adapters
make test-adapter-methods
test_adapter_models:
timeout-minutes: 60
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
with:
python-version: 3.8
- uses: actions/cache@v2
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }}
restore-keys: |
${{ runner.os }}-pip-
- name: Install
run: |
pip install torch==1.12.1
pip install .[sklearn,testing,sentencepiece,vision]
pip install datasets
- name: Test
run: |
make test-adapter-models
6 changes: 6 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@ test:
test-adapters:
python -m pytest -n auto --dist=loadfile -s -v ./tests_adapters/

test-adapter-methods:
python -m pytest --ignore ./tests_adapters/models -n auto --dist=loadfile -s -v ./tests_adapters/

test-adapter-models:
python -m pytest -n auto --dist=loadfile -s -v ./tests_adapters/models

# Run tests for examples

test-examples:
Expand Down
65 changes: 0 additions & 65 deletions tests_adapters/conftest.py

This file was deleted.

12 changes: 8 additions & 4 deletions tests_adapters/methods/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import torch

from transformers import AutoTokenizer, TrainingArguments
from transformers import TrainingArguments
from transformers.adapters import ADAPTER_MODEL_MAPPING, AdapterSetup, AdapterTrainer, AutoAdapterModel
from transformers.adapters.utils import WEIGHTS_NAME
from transformers.testing_utils import require_torch, torch_device
Expand Down Expand Up @@ -188,7 +188,7 @@ def run_full_model_load_test(self, adapter_config):
self.assertEqual(len(output1), len(output2))
self.assertTrue(torch.equal(output1[0], output2[0]))

def trainings_run(self, model, lr=1.0, steps=20):
def trainings_run(self, model, lr=1.0, steps=8):
# setup dataset
train_dataset = self.dataset()
training_args = TrainingArguments(
Expand Down Expand Up @@ -242,11 +242,15 @@ def run_train_test(self, adapter_config, filter_keys):

self.trainings_run(model)

# check that the adapters have changed, but the base model has not
adapters_with_change, base_with_change = False, False
for ((k1, v1), (k2, v2)) in zip(state_dict_pre.items(), model.state_dict().items()):
if "mrpc" in k1:
self.assertFalse(torch.equal(v1, v2), k1)
adapters_with_change |= not torch.equal(v1, v2)
else:
self.assertTrue(torch.equal(v1, v2), k1)
base_with_change |= not torch.equal(v1, v2)
self.assertTrue(adapters_with_change)
self.assertFalse(base_with_change)

def run_merge_test(self, adapter_config):
model = self.get_model()
Expand Down
37 changes: 17 additions & 20 deletions tests_adapters/methods/test_adapter_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,9 +245,7 @@ def test_forward_with_past(self):
self.skipTest("No causal lm class.")

static_model = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING[self.config_class](self.config())
flex_model = AutoAdapterModel.from_pretrained(
None, config=self.config(), state_dict=static_model.state_dict()
)
flex_model = AutoAdapterModel.from_pretrained(None, config=self.config(), state_dict=static_model.state_dict())
static_model.add_adapter("dummy")
static_model.set_active_adapters("dummy")
static_model.eval()
Expand Down Expand Up @@ -325,7 +323,10 @@ def patched_fusion_reg_loss():
model.base_model.get_fusion_regularization_loss = patched_fusion_reg_loss

self.trainings_run(model)
self.assertTrue(regularization_called)

# check that the adapters have changed, but the base model has not
adapters_with_change, base_with_change = False, False
for ((k1, v1), (k2, v2)) in zip(state_dict_pre.items(), model.state_dict().items()):
if (
"adapter_fusion_layer" in k1
Expand All @@ -334,10 +335,11 @@ def patched_fusion_reg_loss():
or "score" in k1
or "heads" in k1
):
self.assertFalse(torch.equal(v1, v2), k1)
adapters_with_change |= not torch.equal(v1, v2)
else:
self.assertTrue(torch.equal(v1, v2), k1)
self.assertTrue(regularization_called)
base_with_change |= not torch.equal(v1, v2)
self.assertTrue(adapters_with_change)
self.assertFalse(base_with_change)

def test_batch_split_training(self):
if self.config_class not in ADAPTER_MODEL_MAPPING:
Expand Down Expand Up @@ -366,17 +368,12 @@ def test_batch_split_training(self):

self.trainings_run(model)

self.assertFalse(
all(
torch.equal(v1, v2)
for ((k1, v1), (k2, v2)) in zip(state_dict_pre.items(), model.state_dict().items())
if "mrpc" in k1
)
)
self.assertTrue(
all(
torch.equal(v1, v2)
for ((k1, v1), (k2, v2)) in zip(state_dict_pre.items(), model.state_dict().items())
if "mrpc" not in k1
)
)
# check that the adapters have changed, but the base model has not
adapters_with_change, base_with_change = False, False
for ((k1, v1), (k2, v2)) in zip(state_dict_pre.items(), model.state_dict().items()):
if "mrpc" in k1:
adapters_with_change |= not torch.equal(v1, v2)
else:
base_with_change |= not torch.equal(v1, v2)
self.assertTrue(adapters_with_change)
self.assertFalse(base_with_change)
Empty file.
File renamed without changes.
11 changes: 11 additions & 0 deletions tests_adapters/models/test_bart.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from tests.models.bart.test_modeling_bart import *
from transformers import BartAdapterModel
from transformers.testing_utils import require_torch

from .base import AdapterModelTesterMixin


@require_torch
class BartAdapterModelTest(AdapterModelTesterMixin, BartModelTest):
all_model_classes = (BartAdapterModel,)
fx_compatible = False
11 changes: 11 additions & 0 deletions tests_adapters/models/test_beit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from tests.models.beit.test_modeling_beit import *
from transformers import BeitAdapterModel
from transformers.testing_utils import require_torch

from .base import AdapterModelTesterMixin


@require_torch
class BeitAdapterModelTest(AdapterModelTesterMixin, BeitModelTest):
all_model_classes = (BeitAdapterModel,)
fx_compatible = False
11 changes: 11 additions & 0 deletions tests_adapters/models/test_bert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from tests.models.bert.test_modeling_bert import *
from transformers import BertAdapterModel
from transformers.testing_utils import require_torch

from .base import AdapterModelTesterMixin


@require_torch
class BertAdapterModelTest(AdapterModelTesterMixin, BertModelTest):
all_model_classes = (BertAdapterModel,)
fx_compatible = False
2 changes: 2 additions & 0 deletions tests_adapters/models/test_clip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# flake8: noqa
from tests.models.clip.test_modeling_clip import * # Imported to execute model tests
11 changes: 11 additions & 0 deletions tests_adapters/models/test_deberta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from tests.models.deberta.test_modeling_deberta import *
from transformers import DebertaAdapterModel
from transformers.testing_utils import require_torch

from .base import AdapterModelTesterMixin


@require_torch
class DebertaAdapterModelTest(AdapterModelTesterMixin, DebertaModelTest):
all_model_classes = (DebertaAdapterModel,)
fx_compatible = False
11 changes: 11 additions & 0 deletions tests_adapters/models/test_debertaV2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from tests.models.deberta_v2.test_modeling_deberta_v2 import *
from transformers import DebertaV2AdapterModel
from transformers.testing_utils import require_torch

from .base import AdapterModelTesterMixin


@require_torch
class DebertaV2AdapterModelTest(AdapterModelTesterMixin, DebertaV2ModelTest):
all_model_classes = (DebertaV2AdapterModel,)
fx_compatible = False
11 changes: 11 additions & 0 deletions tests_adapters/models/test_distilbert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from tests.models.distilbert.test_modeling_distilbert import *
from transformers import DistilBertAdapterModel
from transformers.testing_utils import require_torch

from .base import AdapterModelTesterMixin


@require_torch
class DistilBertAdapterModelTest(AdapterModelTesterMixin, DistilBertModelTest):
all_model_classes = (DistilBertAdapterModel,)
fx_compatible = False
2 changes: 2 additions & 0 deletions tests_adapters/models/test_encoder_decoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# flake8: noqa
from tests.models.encoder_decoder.test_modeling_encoder_decoder import * # Imported to execute model tests
11 changes: 11 additions & 0 deletions tests_adapters/models/test_gpt2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from tests.models.gpt2.test_modeling_gpt2 import *
from transformers import GPT2AdapterModel
from transformers.testing_utils import require_torch

from .base import AdapterModelTesterMixin


@require_torch
class GPT2AdapterModelTest(AdapterModelTesterMixin, GPT2ModelTest):
all_model_classes = (GPT2AdapterModel,)
fx_compatible = False
11 changes: 11 additions & 0 deletions tests_adapters/models/test_gptj.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from tests.models.gptj.test_modeling_gptj import *
from transformers import GPTJAdapterModel
from transformers.testing_utils import require_torch

from .base import AdapterModelTesterMixin


@require_torch
class GPTJAdapterModelTest(AdapterModelTesterMixin, GPTJModelTest):
all_model_classes = (GPTJAdapterModel,)
fx_compatible = False
11 changes: 11 additions & 0 deletions tests_adapters/models/test_mbart.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from tests.models.mbart.test_modeling_mbart import *
from transformers import MBartAdapterModel
from transformers.testing_utils import require_torch

from .base import AdapterModelTesterMixin


@require_torch
class MBartAdapterModelTest(AdapterModelTesterMixin, MBartModelTest):
all_model_classes = (MBartAdapterModel,)
fx_compatible = False
11 changes: 11 additions & 0 deletions tests_adapters/models/test_roberta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from tests.models.roberta.test_modeling_roberta import *
from transformers import RobertaAdapterModel
from transformers.testing_utils import require_torch

from .base import AdapterModelTesterMixin


@require_torch
class RobertaAdapterModelTest(AdapterModelTesterMixin, RobertaModelTest):
all_model_classes = (RobertaAdapterModel,)
fx_compatible = False
11 changes: 11 additions & 0 deletions tests_adapters/models/test_t5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from tests.models.t5.test_modeling_t5 import *
from transformers import T5AdapterModel
from transformers.testing_utils import require_torch

from .base import AdapterModelTesterMixin


@require_torch
class T5AdapterModelTest(AdapterModelTesterMixin, T5ModelTest):
all_model_classes = (T5AdapterModel,)
fx_compatible = False
Loading

0 comments on commit 1540ab0

Please sign in to comment.