Skip to content
This repository was archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
[tests] Speed up test_chunkteacher (#3606)
Browse files Browse the repository at this point in the history
  • Loading branch information
stephenroller authored Apr 19, 2021
1 parent 1fa2029 commit d83cd22
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 67 deletions.
9 changes: 7 additions & 2 deletions parlai/scripts/multiprocessing_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,15 @@ def setup_args():
class MultiProcessTrain(ParlaiScript):
@classmethod
def setup_args(cls):
return setup_args()
argparser = setup_args()
argparser.add_argument('--port', type=int, default=None)
return argparser

def run(self):
port = random.randint(32000, 48000)
if self.opt['port'] is None:
port = random.randint(32000, 48000)
else:
port = self.opt['port']
return launch_and_train(self.opt, port)


Expand Down
74 changes: 9 additions & 65 deletions tests/test_chunkteacher.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
import parlai.scripts.multiprocessing_train as mp_train


class TestNumExamples(TestCase):
class _Abstract(TestCase):
BASE_ARGS = {
'model': 'test_agents/counter',
'dict_file': 'zoo:unittest/transformer_generator2/model.dict',
'dict_tokenizer': 'space',
'truncate': 8,
'num_epochs': 2,
'max_train_steps': 10,
'datatype': 'train:stream',
}

Expand All @@ -47,14 +47,12 @@ def _run_mp(self, **kwargs):
assert test_report['times_seen'] == 1
return valid_report, test_report

# Regular chunk teacher

class TestNumExamples(_Abstract):
# Regular chunk teacher
def test_normal_bs1(self):
self._run(task='integration_tests:chunky')

def test_normal_bs2(self):
self._run(task='integration_tests:chunky', batchsize=2)

def test_normal_bs3(self):
self._run(task='integration_tests:chunky', batchsize=3)

Expand All @@ -68,11 +66,7 @@ def test_normal_batchsort(self):

@testing_utils.skipUnlessGPU
def test_mp_normal_bs1(self):
self._run_mp(task='integration_tests:chunky')

@testing_utils.skipUnlessGPU
def test_mp_normal_bs2(self):
self._run_mp(task='integration_tests:chunky', batchsize=2)
self._run_mp(task='integration_tests:chunky', batchsize=1)

@testing_utils.skipUnlessGPU
def test_mp_normal_bs3(self):
Expand All @@ -84,20 +78,12 @@ def test_mp_normal_dynb(self):
task='integration_tests:chunky', batchsize=2, dynamic_batching='full'
)

@testing_utils.skipUnlessGPU
def test_mp_normal_batchsort(self):
self._run_mp(
task='integration_tests:chunky', batchsize=2, dynamic_batching='batchsort'
)

class TestSmallBuffer(_Abstract):
# Small buffer

def test_small_buffer_bs1(self):
self._run(task='integration_tests:chunky_small_buffer')

def test_small_buffer_bs2(self):
self._run(task='integration_tests:chunky_small_buffer', batchsize=2)

def test_small_buffer_bs3(self):
self._run(task='integration_tests:chunky_small_buffer', batchsize=3)

Expand All @@ -119,10 +105,6 @@ def test_small_buffer_batchsort(self):
def test_mp_small_buffer_bs1(self):
self._run_mp(task='integration_tests:chunky_small_buffer')

@testing_utils.skipUnlessGPU
def test_mp_small_buffer_bs2(self):
self._run_mp(task='integration_tests:chunky_small_buffer', batchsize=2)

@testing_utils.skipUnlessGPU
def test_mp_small_buffer_bs3(self):
self._run_mp(task='integration_tests:chunky_small_buffer', batchsize=3)
Expand All @@ -143,14 +125,9 @@ def test_mp_small_buffer_batchsort(self):
dynamic_batching='batchsort',
)

# Slow chunk

def test_slow_bs1(self):
self._run(task='integration_tests:chunky_slow')

def test_slow_bs2(self):
self._run(task='integration_tests:chunky_slow', batchsize=2)

class TestSlowChunk(_Abstract):
# Slow chunk
def test_slow_bs3(self):
self._run(task='integration_tests:chunky_slow', batchsize=3)

Expand All @@ -159,47 +136,14 @@ def test_slow_dynb(self):
task='integration_tests:chunky_slow', batchsize=2, dynamic_batching='full'
)

def test_slow_batchsort(self):
self._run(
task='integration_tests:chunky_slow',
batchsize=2,
dynamic_batching='batchsort',
)

@testing_utils.skipUnlessGPU
def test_mp_slow_bs1(self):
self._run_mp(task='integration_tests:chunky_slow')

@testing_utils.skipUnlessGPU
def test_mp_slow_bs2(self):
self._run_mp(task='integration_tests:chunky_slow', batchsize=2)

@testing_utils.skipUnlessGPU
def test_mp_slow_bs3(self):
self._run_mp(task='integration_tests:chunky_slow', batchsize=3)

@testing_utils.skipUnlessGPU
def test_mp_slow_dynb(self):
self._run_mp(
task='integration_tests:chunky_slow', batchsize=2, dynamic_batching='full'
)

@testing_utils.skipUnlessGPU
def test_mp_slow_batchsort(self):
self._run_mp(
task='integration_tests:chunky_slow',
batchsize=2,
dynamic_batching='batchsort',
)


class TestBackgroundPreprocessorNumExamples(TestNumExamples):
BASE_ARGS = {
'model': 'test_agents/counter',
'dict_file': 'zoo:unittest/transformer_generator2/model.dict',
'dict_tokenizer': 'space',
'truncate': 8,
'num_epochs': 2,
'max_train_steps': 10,
'datatype': 'train:stream',
'num_workers': 4,
}

0 comments on commit d83cd22

Please sign in to comment.