Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add unit tests for testing model training #1449

Merged
merged 2 commits into from
Nov 23, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions test/models/test_models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import torchtext
import torch
from torch.nn import functional as torch_F
import copy
from ..common.torchtext_test_case import TorchtextTestCase
from ..common.assets import get_asset_path

Expand Down Expand Up @@ -126,3 +128,49 @@ def test_roberta_bundler_from_config(self):
encoder_state_dict['encoder.' + k] = v
model = torchtext.models.RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf, head=dummy_classifier_head, checkpoint=encoder_state_dict)
self.assertEqual(model.state_dict(), dummy_classifier.state_dict())

def test_roberta_bundler_train(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is a rather complicated for unit test, can you add doc string?
There are many style for writing docstring for tests, but one should be able to describe a test like "Given A, B should be met."

If the purpose of the test is just ensuring no error happens, then I call that kind of tests as "smoke test". (Could be different from what the "smoke test" means in industry though)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @mthrok, I agree on all the points mentioned. This would go in integration/smoke tests. Let me organize all the model tests in follow-up PR. There are other model test (not introduced in this PR) too that accordingly need to be separated.

from torchtext.models import RobertaEncoderConf, RobertaClassificationHead, RobertaModel, RobertaModelBundle
dummy_encoder_conf = RobertaEncoderConf(vocab_size=10, embedding_dim=16, ffn_dimension=64, num_attention_heads=2, num_encoder_layers=2)
from torch.optim import SGD

def _train(model):
optim = SGD(model.parameters(), lr=1)
model_input = torch.tensor([[0, 1, 2, 3, 4, 5]])
target = torch.tensor([0])
logits = model(model_input)
loss = torch_F.cross_entropy(logits, target)
loss.backward()
optim.step()

# does not freeze encoder
dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
dummy_classifier = RobertaModel(dummy_encoder_conf, dummy_classifier_head)
model = RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf,
head=dummy_classifier_head,
freeze_encoder=False,
checkpoint=dummy_classifier.state_dict())

encoder_current_state_dict = copy.deepcopy(model.encoder.state_dict())
head_current_state_dict = copy.deepcopy(model.head.state_dict())

_train(model)

self.assertNotEqual(model.encoder.state_dict(), encoder_current_state_dict)
self.assertNotEqual(model.head.state_dict(), head_current_state_dict)

# freeze encoder
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would split the test into two separate tests.

dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
dummy_classifier = RobertaModel(dummy_encoder_conf, dummy_classifier_head)
model = RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf,
head=dummy_classifier_head,
freeze_encoder=True,
checkpoint=dummy_classifier.state_dict())

encoder_current_state_dict = copy.deepcopy(model.encoder.state_dict())
head_current_state_dict = copy.deepcopy(model.head.state_dict())

_train(model)

self.assertEqual(model.encoder.state_dict(), encoder_current_state_dict)
self.assertNotEqual(model.head.state_dict(), head_current_state_dict)