diff --git a/parlai/scripts/multiprocessing_train.py b/parlai/scripts/multiprocessing_train.py index 302bb416908..e6dad8a29d9 100644 --- a/parlai/scripts/multiprocessing_train.py +++ b/parlai/scripts/multiprocessing_train.py @@ -82,10 +82,15 @@ def setup_args(): class MultiProcessTrain(ParlaiScript): @classmethod def setup_args(cls): - return setup_args() + argparser = setup_args() + argparser.add_argument('--port', type=int, default=None) + return argparser def run(self): - port = random.randint(32000, 48000) + if self.opt['port'] is None: + port = random.randint(32000, 48000) + else: + port = self.opt['port'] return launch_and_train(self.opt, port) diff --git a/tests/test_chunkteacher.py b/tests/test_chunkteacher.py index f98f5fa7009..8f85f156aa3 100644 --- a/tests/test_chunkteacher.py +++ b/tests/test_chunkteacher.py @@ -15,13 +15,13 @@ import parlai.scripts.multiprocessing_train as mp_train -class TestNumExamples(TestCase): +class _Abstract(TestCase): BASE_ARGS = { 'model': 'test_agents/counter', 'dict_file': 'zoo:unittest/transformer_generator2/model.dict', 'dict_tokenizer': 'space', 'truncate': 8, - 'num_epochs': 2, + 'max_train_steps': 10, 'datatype': 'train:stream', } @@ -47,14 +47,12 @@ def _run_mp(self, **kwargs): assert test_report['times_seen'] == 1 return valid_report, test_report - # Regular chunk teacher +class TestNumExamples(_Abstract): + # Regular chunk teacher def test_normal_bs1(self): self._run(task='integration_tests:chunky') - def test_normal_bs2(self): - self._run(task='integration_tests:chunky', batchsize=2) - def test_normal_bs3(self): self._run(task='integration_tests:chunky', batchsize=3) @@ -68,11 +66,7 @@ def test_normal_batchsort(self): @testing_utils.skipUnlessGPU def test_mp_normal_bs1(self): - self._run_mp(task='integration_tests:chunky') - - @testing_utils.skipUnlessGPU - def test_mp_normal_bs2(self): - self._run_mp(task='integration_tests:chunky', batchsize=2) + self._run_mp(task='integration_tests:chunky', batchsize=1) @testing_utils.skipUnlessGPU def test_mp_normal_bs3(self): @@ -84,20 +78,12 @@ def test_mp_normal_dynb(self): task='integration_tests:chunky', batchsize=2, dynamic_batching='full' ) - @testing_utils.skipUnlessGPU - def test_mp_normal_batchsort(self): - self._run_mp( - task='integration_tests:chunky', batchsize=2, dynamic_batching='batchsort' - ) +class TestSmallBuffer(_Abstract): # Small buffer - def test_small_buffer_bs1(self): self._run(task='integration_tests:chunky_small_buffer') - def test_small_buffer_bs2(self): - self._run(task='integration_tests:chunky_small_buffer', batchsize=2) - def test_small_buffer_bs3(self): self._run(task='integration_tests:chunky_small_buffer', batchsize=3) @@ -119,10 +105,6 @@ def test_small_buffer_batchsort(self): def test_mp_small_buffer_bs1(self): self._run_mp(task='integration_tests:chunky_small_buffer') - @testing_utils.skipUnlessGPU - def test_mp_small_buffer_bs2(self): - self._run_mp(task='integration_tests:chunky_small_buffer', batchsize=2) - @testing_utils.skipUnlessGPU def test_mp_small_buffer_bs3(self): self._run_mp(task='integration_tests:chunky_small_buffer', batchsize=3) @@ -143,14 +125,9 @@ def test_mp_small_buffer_batchsort(self): dynamic_batching='batchsort', ) - # Slow chunk - - def test_slow_bs1(self): - self._run(task='integration_tests:chunky_slow') - - def test_slow_bs2(self): - self._run(task='integration_tests:chunky_slow', batchsize=2) +class TestSlowChunk(_Abstract): + # Slow chunk def test_slow_bs3(self): self._run(task='integration_tests:chunky_slow', batchsize=3) @@ -159,39 +136,6 @@ def test_slow_dynb(self): task='integration_tests:chunky_slow', batchsize=2, dynamic_batching='full' ) - def test_slow_batchsort(self): - self._run( - task='integration_tests:chunky_slow', - batchsize=2, - dynamic_batching='batchsort', - ) - - @testing_utils.skipUnlessGPU - def test_mp_slow_bs1(self): - self._run_mp(task='integration_tests:chunky_slow') - - @testing_utils.skipUnlessGPU - def test_mp_slow_bs2(self): - self._run_mp(task='integration_tests:chunky_slow', batchsize=2) - - @testing_utils.skipUnlessGPU - def test_mp_slow_bs3(self): - self._run_mp(task='integration_tests:chunky_slow', batchsize=3) - - @testing_utils.skipUnlessGPU - def test_mp_slow_dynb(self): - self._run_mp( - task='integration_tests:chunky_slow', batchsize=2, dynamic_batching='full' - ) - - @testing_utils.skipUnlessGPU - def test_mp_slow_batchsort(self): - self._run_mp( - task='integration_tests:chunky_slow', - batchsize=2, - dynamic_batching='batchsort', - ) - class TestBackgroundPreprocessorNumExamples(TestNumExamples): BASE_ARGS = { @@ -199,7 +143,7 @@ class TestBackgroundPreprocessorNumExamples(TestNumExamples): 'dict_file': 'zoo:unittest/transformer_generator2/model.dict', 'dict_tokenizer': 'space', 'truncate': 8, - 'num_epochs': 2, + 'max_train_steps': 10, 'datatype': 'train:stream', 'num_workers': 4, }