diff --git a/parlai/agents/hugging_face/t5.py b/parlai/agents/hugging_face/t5.py
index 90ec7a1e336..331a97ec43c 100644
--- a/parlai/agents/hugging_face/t5.py
+++ b/parlai/agents/hugging_face/t5.py
@@ -54,10 +54,14 @@ def set_device(func):
     """

     def wrap(*args, **kwargs):
-        if torch.cuda.is_available():
+        self = args[0]
+        # self.paralleled indicates whether the model has been parallelized.
+        # It is set to the opposite of `opt['t5_model_parallel']`.
+        parallel = hasattr(self, 'paralleled') and not self.paralleled
+        if torch.cuda.is_available() and parallel:
             torch.cuda.set_device('cuda:0')
         ret = func(*args, **kwargs)
-        if torch.cuda.is_available():
+        if torch.cuda.is_available() and parallel:
             torch.cuda.set_device('cuda:0')
         return ret

@@ -293,6 +297,7 @@ def __init__(self, opt, dictionary):
         self.t5 = build_t5(opt)
         self.encoder = ParlaiT5Encoder(opt, self.t5.get_encoder(), self.pad_idx)
         self.decoder = ParlaiT5Decoder(opt, self.t5.get_decoder(), self.pad_idx)
+        self.paralleled = not opt['t5_model_parallel']

     @set_device
     def _get_initial_forced_decoder_input(self, bsz: int, inputs: torch.LongTensor):
diff --git a/parlai/agents/rag/modules.py b/parlai/agents/rag/modules.py
index 573f4995161..4f4a0bc4f0b 100644
--- a/parlai/agents/rag/modules.py
+++ b/parlai/agents/rag/modules.py
@@ -535,6 +535,7 @@ def __init__(self, opt, dictionary, retriever_shared=None):
         super().__init__(opt, dictionary, retriever_shared)
         self.embedding_size = opt['t5'].model_dim
         self.t5 = opt.pop('t5', None)
+        self.paralleled = not opt['t5_model_parallel']

     @classmethod
     def build_encoder(
diff --git a/tests/nightly/gpu/test_t5.py b/tests/nightly/gpu/test_t5.py
index 4f01443bed4..1f888a3a73c 100644
--- a/tests/nightly/gpu/test_t5.py
+++ b/tests/nightly/gpu/test_t5.py
@@ -34,6 +34,8 @@
 from parlai.utils.torch import padded_tensor
 from parlai.utils.testing import tempdir

+from tests.test_distributed import _AbstractTest
+
 device = 'cpu' if not torch.cuda.is_available() else 'cuda:0'


@@ -262,5 +264,31 @@ def test_t5_model_parallel(self):
         )


+@testing_utils.skipUnlessGPU
+class TestT5Distributed(_AbstractTest):
+    base_config = dict(
+        task='integration_tests:overfit',
+        model='hugging_face/t5',
+        optimizer='adam',
+        batchsize=1,
+        num_epochs=50,
+        short_final_eval=True,
+        validation_max_exs=12,
+        t5_model_arch='t5-small',
+        validation_metric='ppl',
+        skip_generation=True,
+        learningrate=1e-2,
+        validation_every_n_epochs=25,
+        verbose=True,
+        save_after_valid=False,
+    )
+
+    def test_t5_distributed(self):
+        valid, test = self._distributed_train_model()
+
+        self.assertLessEqual(valid['ppl'], 1.60)
+        self.assertLessEqual(test['ppl'], 1.60)
+
+
 if __name__ == '__main__':
     unittest.main()
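
For context, here is a minimal, self-contained sketch of the decorator behavior this patch introduces: `set_device` pins the default CUDA device to `cuda:0` around the wrapped method only when T5 model parallelism is in use (`self.paralleled` is initialized to the opposite of `opt['t5_model_parallel']`), which presumably lets distributed data-parallel workers keep their own devices, hence the new `TestT5Distributed` case. `ToyT5Model` and its `opt` dict are hypothetical stand-ins for `ParlaiT5Model` and the full ParlAI option set; only the decorator body and the `paralleled` attribute come from the patch itself.

import functools

import torch


def set_device(func):
    """Pin cuda:0 around `func`, but only in the model-parallel case."""

    @functools.wraps(func)
    def wrap(*args, **kwargs):
        self = args[0]
        # `self.paralleled` is initialized to `not opt['t5_model_parallel']`,
        # so `parallel` is True exactly when model parallelism was requested.
        parallel = hasattr(self, 'paralleled') and not self.paralleled
        if torch.cuda.is_available() and parallel:
            torch.cuda.set_device('cuda:0')
        ret = func(*args, **kwargs)
        if torch.cuda.is_available() and parallel:
            torch.cuda.set_device('cuda:0')
        return ret

    return wrap


class ToyT5Model:
    """Hypothetical stand-in for ParlaiT5Model (illustration only)."""

    def __init__(self, opt):
        self.paralleled = not opt['t5_model_parallel']

    @set_device
    def forward(self, xs):
        return xs


# With model parallelism off, the wrapper leaves the current device alone;
# with it on, both set_device calls fire (given an available GPU).
model = ToyT5Model({'t5_model_parallel': False})
model.forward(torch.zeros(1))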