diff --git a/test/models/test_models.py b/test/models/test_models.py
index 488b9fc561..67876c75ff 100644
--- a/test/models/test_models.py
+++ b/test/models/test_models.py
@@ -1,5 +1,7 @@
 import torchtext
 import torch
+from torch.nn import functional as torch_F
+import copy
 
 from ..common.torchtext_test_case import TorchtextTestCase
 from ..common.assets import get_asset_path
@@ -126,3 +128,49 @@ def test_roberta_bundler_from_config(self):
             encoder_state_dict['encoder.' + k] = v
         model = torchtext.models.RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf, head=dummy_classifier_head, checkpoint=encoder_state_dict)
         self.assertEqual(model.state_dict(), dummy_classifier.state_dict())
+
+    def test_roberta_bundler_train(self):
+        from torchtext.models import RobertaEncoderConf, RobertaClassificationHead, RobertaModel, RobertaModelBundle
+        dummy_encoder_conf = RobertaEncoderConf(vocab_size=10, embedding_dim=16, ffn_dimension=64, num_attention_heads=2, num_encoder_layers=2)
+        from torch.optim import SGD
+
+        def _train(model):
+            optim = SGD(model.parameters(), lr=1)
+            model_input = torch.tensor([[0, 1, 2, 3, 4, 5]])
+            target = torch.tensor([0])
+            logits = model(model_input)
+            loss = torch_F.cross_entropy(logits, target)
+            loss.backward()
+            optim.step()
+
+        # does not freeze encoder
+        dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
+        dummy_classifier = RobertaModel(dummy_encoder_conf, dummy_classifier_head)
+        model = RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf,
+                                               head=dummy_classifier_head,
+                                               freeze_encoder=False,
+                                               checkpoint=dummy_classifier.state_dict())
+
+        encoder_current_state_dict = copy.deepcopy(model.encoder.state_dict())
+        head_current_state_dict = copy.deepcopy(model.head.state_dict())
+
+        _train(model)
+
+        self.assertNotEqual(model.encoder.state_dict(), encoder_current_state_dict)
+        self.assertNotEqual(model.head.state_dict(), head_current_state_dict)
+
+        # freeze encoder
+        dummy_classifier_head = RobertaClassificationHead(num_classes=2, input_dim=16)
+        dummy_classifier = RobertaModel(dummy_encoder_conf, dummy_classifier_head)
+        model = RobertaModelBundle.from_config(encoder_conf=dummy_encoder_conf,
+                                               head=dummy_classifier_head,
+                                               freeze_encoder=True,
+                                               checkpoint=dummy_classifier.state_dict())
+
+        encoder_current_state_dict = copy.deepcopy(model.encoder.state_dict())
+        head_current_state_dict = copy.deepcopy(model.head.state_dict())
+
+        _train(model)
+
+        self.assertEqual(model.encoder.state_dict(), encoder_current_state_dict)
+        self.assertNotEqual(model.head.state_dict(), head_current_state_dict)
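
For context, a minimal sketch of how the `freeze_encoder` flag exercised by this test would typically be used outside the test suite. It relies only on the `RobertaModelBundle.from_config` API and the config values already shown in the diff; the hyperparameters are illustrative, not a recommended configuration.

```python
# Usage sketch, assuming the torchtext API exercised in the test above.
import torch
from torchtext.models import (
    RobertaEncoderConf,
    RobertaClassificationHead,
    RobertaModelBundle,
)

# Same toy encoder configuration as the test; real models would use
# the published bundle configs instead.
encoder_conf = RobertaEncoderConf(
    vocab_size=10,
    embedding_dim=16,
    ffn_dimension=64,
    num_attention_heads=2,
    num_encoder_layers=2,
)
head = RobertaClassificationHead(num_classes=2, input_dim=16)

# freeze_encoder=True keeps the encoder weights fixed, so an optimizer
# over model.parameters() only updates the classification head.
model = RobertaModelBundle.from_config(
    encoder_conf=encoder_conf,
    head=head,
    freeze_encoder=True,
)

tokens = torch.tensor([[0, 1, 2, 3, 4, 5]])
logits = model(tokens)  # shape: (1, num_classes)
```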