Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

[tests] Speed up test_chunkteacher #3606

Merged
merged 3 commits into from
Apr 19, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions parlai/scripts/multiprocessing_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,15 @@ def setup_args():
class MultiProcessTrain(ParlaiScript):
@classmethod
def setup_args(cls):
return setup_args()
argparser = setup_args()
argparser.add_argument('--port', type=int, default=None)
return argparser

def run(self):
port = random.randint(32000, 48000)
if self.opt['port'] is None:
port = random.randint(32000, 48000)
else:
port = self.opt['port']
return launch_and_train(self.opt, port)


Expand Down
77 changes: 9 additions & 68 deletions tests/test_chunkteacher.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,21 @@
Test correctness of ChunkTeacher in a large number of settings.
"""

import random
from unittest import TestCase
import os
import parlai.utils.testing as testing_utils
from parlai.tasks.integration_tests.agents import NUM_TEST
import parlai.scripts.multiprocessing_train as mp_train


class TestNumExamples(TestCase):
class _Abstract(TestCase):
BASE_ARGS = {
'model': 'test_agents/counter',
'dict_file': 'zoo:unittest/transformer_generator2/model.dict',
'dict_tokenizer': 'space',
'truncate': 8,
'num_epochs': 2,
'max_train_steps': 10,
'datatype': 'train:stream',
}

Expand All @@ -47,14 +48,12 @@ def _run_mp(self, **kwargs):
assert test_report['times_seen'] == 1
return valid_report, test_report

# Regular chunk teacher

class TestNumExamples(_Abstract):
# Regular chunk teacher
def test_normal_bs1(self):
self._run(task='integration_tests:chunky')

def test_normal_bs2(self):
self._run(task='integration_tests:chunky', batchsize=2)

def test_normal_bs3(self):
self._run(task='integration_tests:chunky', batchsize=3)

Expand All @@ -66,14 +65,6 @@ def test_normal_batchsort(self):
task='integration_tests:chunky', batchsize=2, dynamic_batching='batchsort'
)

@testing_utils.skipUnlessGPU
def test_mp_normal_bs1(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this removed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe it's redundant with the others. We have a non-multiprocessing batchsize 1, and we have multiprocessing with other batchsizes. The explicit combination seems unnecessary.

Same goes for some of the other things I filtered out. I decided to test some things in isolation, rather than the combinatorial explosion of options.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think the remaining tests would catch something like distributed_world_size > 1 while batchsize = 1? (Is this something that is necessary to test?)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, I don't think they will. I can bring that one back if you'd like.

self._run_mp(task='integration_tests:chunky')

@testing_utils.skipUnlessGPU
def test_mp_normal_bs2(self):
self._run_mp(task='integration_tests:chunky', batchsize=2)

@testing_utils.skipUnlessGPU
def test_mp_normal_bs3(self):
self._run_mp(task='integration_tests:chunky', batchsize=3)
Expand All @@ -84,20 +75,12 @@ def test_mp_normal_dynb(self):
task='integration_tests:chunky', batchsize=2, dynamic_batching='full'
)

@testing_utils.skipUnlessGPU
def test_mp_normal_batchsort(self):
self._run_mp(
task='integration_tests:chunky', batchsize=2, dynamic_batching='batchsort'
)

class TestSmallBuffer(_Abstract):
# Small buffer

def test_small_buffer_bs1(self):
self._run(task='integration_tests:chunky_small_buffer')

def test_small_buffer_bs2(self):
self._run(task='integration_tests:chunky_small_buffer', batchsize=2)

def test_small_buffer_bs3(self):
self._run(task='integration_tests:chunky_small_buffer', batchsize=3)

Expand All @@ -119,10 +102,6 @@ def test_small_buffer_batchsort(self):
def test_mp_small_buffer_bs1(self):
self._run_mp(task='integration_tests:chunky_small_buffer')

@testing_utils.skipUnlessGPU
def test_mp_small_buffer_bs2(self):
self._run_mp(task='integration_tests:chunky_small_buffer', batchsize=2)

@testing_utils.skipUnlessGPU
def test_mp_small_buffer_bs3(self):
self._run_mp(task='integration_tests:chunky_small_buffer', batchsize=3)
Expand All @@ -143,14 +122,9 @@ def test_mp_small_buffer_batchsort(self):
dynamic_batching='batchsort',
)

# Slow chunk

def test_slow_bs1(self):
self._run(task='integration_tests:chunky_slow')

def test_slow_bs2(self):
self._run(task='integration_tests:chunky_slow', batchsize=2)

class TestSlowChunk(_Abstract):
# Slow chunk
def test_slow_bs3(self):
self._run(task='integration_tests:chunky_slow', batchsize=3)

Expand All @@ -159,47 +133,14 @@ def test_slow_dynb(self):
task='integration_tests:chunky_slow', batchsize=2, dynamic_batching='full'
)

def test_slow_batchsort(self):
self._run(
task='integration_tests:chunky_slow',
batchsize=2,
dynamic_batching='batchsort',
)

@testing_utils.skipUnlessGPU
def test_mp_slow_bs1(self):
self._run_mp(task='integration_tests:chunky_slow')

@testing_utils.skipUnlessGPU
def test_mp_slow_bs2(self):
self._run_mp(task='integration_tests:chunky_slow', batchsize=2)

@testing_utils.skipUnlessGPU
def test_mp_slow_bs3(self):
self._run_mp(task='integration_tests:chunky_slow', batchsize=3)

@testing_utils.skipUnlessGPU
def test_mp_slow_dynb(self):
self._run_mp(
task='integration_tests:chunky_slow', batchsize=2, dynamic_batching='full'
)

@testing_utils.skipUnlessGPU
def test_mp_slow_batchsort(self):
self._run_mp(
task='integration_tests:chunky_slow',
batchsize=2,
dynamic_batching='batchsort',
)


class TestBackgroundPreprocessorNumExamples(TestNumExamples):
BASE_ARGS = {
'model': 'test_agents/counter',
'dict_file': 'zoo:unittest/transformer_generator2/model.dict',
'dict_tokenizer': 'space',
'truncate': 8,
'num_epochs': 2,
'max_train_steps': 10,
'datatype': 'train:stream',
'num_workers': 4,
}