This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Coverage for more scripts (#3110)
* Test profile_ scripts

* Add test for distributed_eval

* More scripts.

* Lint.

* Also check -t self_chat

* Whoops, gotta de-init

* Update docstring
stephenroller authored Sep 24, 2020
1 parent 62a3f45 commit 1c43b85
Showing 7 changed files with 114 additions and 26 deletions.
29 changes: 5 additions & 24 deletions parlai/scripts/profile_train.py
@@ -12,7 +12,7 @@
 few of them:
 ```shell
-parlai profile_train -t babi:task1k:1 -m seq2seq -e 0.1 --dict-file /tmp/dict
+parlai profile_train -t babi:task1k:1 -m seq2seq --dict-file /tmp/dict
 ```
 """

@@ -23,7 +23,6 @@
 import parlai.utils.logging as logging
 import cProfile
 import io
-import pdb
 import pstats
 
 try:
@@ -55,35 +54,19 @@ def setup_args(parser=None):
         default=False,
         help='If true, enter debugger at end of run.',
     )
+    profile.set_defaults(num_epochs=1)
     return parser
 
 
 def profile(opt):
     if opt['torch'] or opt['torch_cuda']:
         with torch.autograd.profiler.profile(use_cuda=opt['torch_cuda']) as prof:
             TrainLoop(opt).train()
         print(prof.total_average())
 
-        sort_cpu = sorted(prof.key_averages(), key=lambda k: k.cpu_time)
-        sort_cuda = sorted(prof.key_averages(), key=lambda k: k.cuda_time)
+        key = 'cpu_time_total' if opt['torch'] else 'cuda_time_total'
+        print(prof.key_averages().table(sort_by=key, row_limit=25))
-
-        def cpu():
-            for e in sort_cpu:
-                print(e)
-
-        def cuda():
-            for e in sort_cuda:
-                print(e)
-
-        cpu()
-
-        if opt['debug']:
-            print(
-                '`cpu()` prints out cpu-sorted list, '
-                '`cuda()` prints cuda-sorted list'
-            )
-
-            pdb.set_trace()
+        return prof
     else:
         pr = cProfile.Profile()
         pr.enable()
@@ -94,8 +77,6 @@ def cuda():
         ps = pstats.Stats(pr, stream=s).sort_stats(sortby)
         ps.print_stats()
         print(s.getvalue())
-        if opt['debug']:
-            pdb.set_trace()
 
 
 @register_script('profile_train', hidden=True)
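Aside: the rewritten torch branch above uses PyTorch's autograd profiler directly. A minimal standalone sketch of the same pattern, with a toy model standing in for the training loop (the model and input below are illustrative, not part of this commit):

```python
import torch

# Toy stand-in for the work that TrainLoop(opt).train() does above.
model = torch.nn.Linear(8, 8)

with torch.autograd.profiler.profile(use_cuda=False) as prof:
    model(torch.randn(4, 8))

print(prof.total_average())
# Same aggregation profile_train now prints: group events by name,
# sort by total CPU time, and show the top 25 rows.
print(prof.key_averages().table(sort_by='cpu_time_total', row_limit=25))
```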
4 changes: 3 additions & 1 deletion parlai/scripts/token_stats.py
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import numpy as np
-from parlai.core.script import ParlaiScript
+from parlai.core.script import ParlaiScript, register_script
 from parlai.core.agents import create_agent
 from parlai.core.torch_agent import TorchAgent
 from parlai.core.worlds import create_task
@@ -14,6 +14,7 @@
 import parlai.utils.logging as logging
 
 
+@register_script("token_stats", hidden=True)
 class TokenStats(ParlaiScript):
     @classmethod
     def setup_args(cls):
@@ -90,6 +91,7 @@ def run(self):
 
         report = self._compute_stats(lengths)
         print(nice_report(report))
+        return report
 
 
 if __name__ == '__main__':
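Since token_stats now registers itself and returns its report, the script can be driven programmatically. A short sketch mirroring the new test in tests/test_other_scripts.py (the task name is taken from that test):

```python
from parlai.core.metrics import dict_report
from parlai.scripts.token_stats import TokenStats

# main() now hands back the report instead of only printing it.
report = TokenStats.main(task='integration_tests:multiturn')
print(dict_report(report))  # plain dict: min/max/mean and length percentiles
```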
2 changes: 1 addition & 1 deletion parlai/scripts/vacuum.py
@@ -21,7 +21,7 @@
 import parlai.utils.logging as logging
 
 
-@register_script("vacuum")
+@register_script("vacuum", hidden=True)
 class Vacuum(ParlaiScript):
     @classmethod
     def setup_args(cls):
24 changes: 24 additions & 0 deletions tests/test_distributed.py
@@ -206,5 +206,29 @@ def test_no_model_parallel(self):
         dist.destroy_process_group()
 
 
+@testing_utils.skipUnlessGPU
+class TestDistributedEval(unittest.TestCase):
+    def test_mp_eval(self):
+        args = dict(
+            task='integration_tests:multiturn_nocandidate',
+            model='seq2seq',
+            model_file='zoo:unittest/seq2seq/model',
+            dict_file='zoo:unittest/seq2seq/model.dict',
+            skip_generation=False,
+            batchsize=8,
+        )
+        valid, _ = testing_utils.eval_model(args, skip_test=True)
+
+        from parlai.scripts.multiprocessing_eval import MultiProcessEval
+
+        valid_mp = MultiProcessEval.main(**args)
+
+        for key in ['exs', 'ppl', 'token_acc', 'f1', 'bleu-4', 'accuracy']:
+            self.assertAlmostEquals(
+                valid[key].value(), valid_mp[key].value(), delta=0.001
+            )
+        dist.destroy_process_group()
 
 
 if __name__ == '__main__':
     unittest.main()
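The trailing dist.destroy_process_group() is the "Whoops, gotta de-init" fix from the commit message: torch.distributed keeps the default process group as global state, so a test must tear it down before the next test initializes a new one. A minimal sketch of the init/de-init pairing (backend, address, and world size here are illustrative):

```python
import torch.distributed as dist

# Illustrative single-process group; the real test spawns several workers.
dist.init_process_group(
    backend='gloo',
    init_method='tcp://127.0.0.1:29500',
    rank=0,
    world_size=1,
)
try:
    pass  # distributed eval would run here
finally:
    # De-init, so a later init_process_group() doesn't see stale state.
    dist.destroy_process_group()
```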
9 changes: 9 additions & 0 deletions tests/test_interactive.py
@@ -84,5 +84,14 @@ def _run_test_repeat(self, tmpdir: str, fake_input: FakeInput):
         self.assertEqual(len(entry), 2 * fake_input.max_turns)
 
 
+class TestProfileInteractive(unittest.TestCase):
+    def test_profile_interactive(self):
+        from parlai.scripts.profile_interactive import ProfileInteractive
+
+        fake_input = FakeInput(max_episodes=2)
+        with mock.patch('builtins.input', new=fake_input):
+            ProfileInteractive.main(model='repeat_query')
 
 
 if __name__ == '__main__':
     unittest.main()
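The new test drives the interactive loop by patching the input builtin with this file's FakeInput helper. The same pattern in isolation, with a plain iterator standing in for FakeInput (illustrative only, not the helper itself):

```python
from unittest import mock

replies = iter(['hello world', '[EXIT]'])
# While the patch is active, every input() call returns the next canned reply.
with mock.patch('builtins.input', new=lambda *args: next(replies)):
    assert input('Enter your message: ') == 'hello world'
    assert input('Enter your message: ') == '[EXIT]'
```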
66 changes: 66 additions & 0 deletions tests/test_other_scripts.py
@@ -113,3 +113,69 @@ def test_party(self):
         from parlai.scripts.party import Party
 
         Party.main(seconds=0.01)
+
+
+class TestProfileTrain(unittest.TestCase):
+    """
+    Test profile_train doesn't crash.
+    """
+
+    def test_cprofile(self):
+        from parlai.scripts.profile_train import ProfileTrain
+
+        with testing_utils.tempdir() as tmpdir:
+            ProfileTrain.main(
+                task='integration_tests:overfit',
+                model='test_agents/unigram',
+                model_file=os.path.join(tmpdir, 'model'),
+                skip_generation=True,
+            )
+
+    def test_torch(self):
+        from parlai.scripts.profile_train import ProfileTrain
+
+        with testing_utils.tempdir() as tmpdir:
+            ProfileTrain.main(
+                task='integration_tests:overfit',
+                model='test_agents/unigram',
+                torch=True,
+                model_file=os.path.join(tmpdir, 'model'),
+                skip_generation=True,
+            )
+
+    @testing_utils.skipUnlessGPU
+    def test_torch_cuda(self):
+        from parlai.scripts.profile_train import ProfileTrain
+
+        with testing_utils.tempdir() as tmpdir:
+            ProfileTrain.main(
+                task='integration_tests:overfit',
+                model='test_agents/unigram',
+                torch_cuda=True,
+                model_file=os.path.join(tmpdir, 'model'),
+                skip_generation=True,
+            )
+
+
+class TestTokenStats(unittest.TestCase):
+    def test_token_stats(self):
+        from parlai.scripts.token_stats import TokenStats
+        from parlai.core.metrics import dict_report
+
+        results = dict_report(TokenStats.main(task='integration_tests:multiturn'))
+        assert results == {
+            'exs': 2000,
+            'max': 16,
+            'mean': 7.5,
+            'min': 1,
+            'p01': 1,
+            'p05': 1,
+            'p10': 1,
+            'p25': 4,
+            'p50': 7.5,
+            'p75': 11.5,
+            'p90': 16,
+            'p95': 16,
+            'p99': 16,
+            'p@128': 1,
+        }
6 changes: 6 additions & 0 deletions tests/test_self_chat.py
@@ -29,3 +29,9 @@ def test_convai2(self):
             ]
         )
         self_chat.self_chat(opt)
+
+    def test_no_plain_teacher(self):
+        from parlai.scripts.display_data import DisplayData
+
+        with self.assertRaises(RuntimeError):
+            DisplayData.main(task='self_chat')
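The new test pins down that the bare self_chat teacher cannot be used as a dataset. For contrast, a sketch of a working invocation through the registered script; the agent and flags here are assumptions based on common ParlAI usage, not part of this diff:

```python
from parlai.scripts.self_chat import SelfChat

# repeat_query is a trivial built-in agent; run one short self-chat episode.
# num_self_chats and selfchat_max_turns are assumed flag names.
SelfChat.main(
    model='repeat_query',
    task='convai2',
    num_self_chats=1,
    selfchat_max_turns=2,
)
```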
