Skip to content

Commit

Permalink
add unit tests to for testing model training (#1449)
Browse files Browse the repository at this point in the history
  • Loading branch information
parmeet committed Nov 23, 2021
1 parent c7110b5 commit aea6ad6
Showing 1 changed file with 48 additions and 0 deletions.
48 changes: 48 additions & 0 deletions test/models/test_models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import torchtext
import torch
from torch.nn import functional as torch_F
import copy
from ..common.torchtext_test_case import TorchtextTestCase
from ..common.assets import get_asset_path

Expand Down Expand Up @@ -126,3 +128,49 @@ def test_roberta_bundler_from_config(self):
encoder_state_dict['encoder.' + k] = v
model = torchtext.models.RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf, head=dummy_classifier_head, checkpoint=encoder_state_dict)
self.assertEqual(model.state_dict(), dummy_classifier.state_dict())

def test_roberta_bundler_train(self):
from torchtext.models import RobertaEncoderConf, RobertaClassificationHead, RobertaModel, RobertaModelBundle
dummy_encoder_conf = RobertaEncoderConf(vocab_size=10, embedding_dim=16, ffn_dimension=64, num_attention_heads=2, num_encoder_layers=2)
from torch.optim import SGD

def _train(model):
optim = SGD(model.parameters(), lr=1)
model_input = torch.tensor([[0, 1, 2, 3, 4, 5]])
target = torch.tensor([0])
logits = model(model_input)
loss = torch_F.cross_entropy(logits, target)
loss.backward()
optim.step()

# does not freeze encoder
dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
dummy_classifier = RobertaModel(dummy_encoder_conf, dummy_classifier_head)
model = RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf,
head=dummy_classifier_head,
freeze_encoder=False,
checkpoint=dummy_classifier.state_dict())

encoder_current_state_dict = copy.deepcopy(model.encoder.state_dict())
head_current_state_dict = copy.deepcopy(model.head.state_dict())

_train(model)

self.assertNotEqual(model.encoder.state_dict(), encoder_current_state_dict)
self.assertNotEqual(model.head.state_dict(), head_current_state_dict)

# freeze encoder
dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
dummy_classifier = RobertaModel(dummy_encoder_conf, dummy_classifier_head)
model = RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf,
head=dummy_classifier_head,
freeze_encoder=True,
checkpoint=dummy_classifier.state_dict())

encoder_current_state_dict = copy.deepcopy(model.encoder.state_dict())
head_current_state_dict = copy.deepcopy(model.head.state_dict())

_train(model)

self.assertEqual(model.encoder.state_dict(), encoder_current_state_dict)
self.assertNotEqual(model.head.state_dict(), head_current_state_dict)

0 comments on commit aea6ad6

Please sign in to comment.