Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
t5 ddp gen (#4505)
Browse files Browse the repository at this point in the history
  • Loading branch information
klshuster authored May 2, 2022
1 parent 5dc40b4 commit 97cf663
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 1 deletion.
6 changes: 5 additions & 1 deletion parlai/agents/hugging_face/t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from parlai.core.params import ParlaiParser
from parlai.core.torch_agent import Batch, TorchAgent
from parlai.core.torch_generator_agent import TorchGeneratorAgent, TorchGeneratorModel
from parlai.utils.fsdp import is_fsdp


def check_hf_version(v: Tuple[int, int]) -> bool:
Expand Down Expand Up @@ -198,7 +199,10 @@ def _generate(
if overrides:
generation_params.update(overrides)

outputs = self.model.t5.generate(**generation_params)
model = self.model
if hasattr(self.model, 'module') and not is_fsdp(self.model):
model = self.model.module
outputs = model.t5.generate(**generation_params)
outputs = [(outputs[i], 0, None) for i in range(outputs.size(0))]
return outputs, []

Expand Down
23 changes: 23 additions & 0 deletions tests/nightly/gpu/test_t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,5 +290,28 @@ def test_t5_distributed(self):
self.assertLessEqual(test['ppl'], 1.60)


@testing_utils.skipUnlessGPU
class TestT5DistributedWithGen(_AbstractTest):
    """
    Distributed T5 training test with generation enabled.

    Unlike the skip_generation variant above, this runs the generation path
    (``skip_generation=False``) under distributed training to make sure it
    completes without error (e.g. when the model is wrapped in DDP).
    """

    # Minimal overfit config: one epoch, batch size 1, t5-small, with
    # generation turned on during validation.
    base_config = {
        'task': 'integration_tests:overfit',
        'model': 'hugging_face/t5',
        'optimizer': 'adam',
        'batchsize': 1,
        'num_epochs': 1,
        'short_final_eval': True,
        'validation_max_exs': 12,
        't5_model_arch': 't5-small',
        'validation_metric': 'ppl',
        'skip_generation': False,
        'learningrate': 1e-3,
        'verbose': True,
        'save_after_valid': False,
    }

    def test_t5_distributed(self):
        # Smoke test only: we assert nothing about metrics, just that
        # distributed training + generation runs to completion.
        valid, test = self._distributed_train_model()


# Allow running this test module directly (outside a pytest/CI harness).
if __name__ == '__main__':
    unittest.main()

0 comments on commit 97cf663

Please sign in to comment.