This repository has been archived by the owner on Nov 3, 2023. It is now read-only.
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Fix a chunk teacher bug
* Add tests for chunk teacher.
* Skip unless GPU
* Autoformat
* Improve comments.
* Drain shouldn't be there.
1 parent c639f06 · commit 7da15f7
Showing 6 changed files with 342 additions and 57 deletions. (Only two of the changed diffs are reproduced below.)
@@ -0,0 +1,125 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Test agent which counts its number of unique items.
"""

from __future__ import annotations
from typing import Tuple
from collections import Counter

from parlai.core.torch_agent import TorchAgent
from parlai.core.metrics import Metric, SumMetric
from parlai.core.message import Message

class _CounterMetric(Metric):
    __slots__ = ('_counter',)

    def __init__(self, counter: Counter):
        self._counter = counter

    def __add__(self, other: Metric):
        if other is None:
            return self
        counter = self._counter + other._counter
        return type(self)(counter)


class TimesSeenMetric(_CounterMetric):
    """
    Max number of times any example was seen.
    """

    def value(self) -> int:
        if not self._counter:
            return 0
        return max(self._counter.values())


class UniqueMetric(_CounterMetric):
    """
    Number of unique utterances.
    """

    def value(self) -> int:
        if not self._counter:
            return 0
        return len(self._counter)

class CounterAgent(TorchAgent):
    """
    Simple agent that counts the number of unique things it has seen.

    Could be simpler, but we make it a TorchAgent so it's happy with dynamic batching.
    """

    def __init__(self, opt, shared=None):
        self.model = self.build_model()
        self.criterion = None
        super().__init__(opt, shared)
        self._counter: Counter[Tuple[str, str]]
        if shared is None:
            self._counter = Counter()
            self._padding_counter = Counter()
        else:
            self._counter = shared['counter']
            self._padding_counter = shared['padding']

    def share(self):
        shared = super().share()
        shared['counter'] = self._counter
        shared['padding'] = self._padding_counter
        return shared

    def save(self, path=None):
        pass

    def load(self, path=None):
        pass

    def _val(self, val):
        """
        Pull out a singleton value if provided a list.
        """
        # necessary for labels
        if isinstance(val, (tuple, list)):
            return val[0]
        else:
            return val

    def build_model(self):
        return None

    def train_step(self):
        pass

    def eval_step(self):
        pass

    def _to_tuple(self, msg: Message) -> Tuple:
        # turned into an indexable object
        keys = ['text', 'labels', 'eval_labels']
        return tuple(self._val(msg.get(k)) for k in keys)

    def batch_act(self, observations):
        self._padding_counter.update(['val' for o in observations if o.is_padding()])
        self._counter.update(
            [self._to_tuple(o) for o in observations if not o.is_padding()]
        )
        return [Message() for o in observations]

    def reset(self):
        self._counter.clear()

    def report(self):
        report = {}
        report['num_pad'] = SumMetric(self._padding_counter.get('val', 0))
        report['unique'] = UniqueMetric(self._counter)
        report['times_seen'] = TimesSeenMetric(self._counter)
        return report
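
For reference, a minimal self-contained sketch of the Counter arithmetic the metrics above rely on. This is illustrative only (not part of the commit) and the example data is made up; it uses nothing beyond the Python standard library.

from collections import Counter

# Two workers each count the (text, label) tuples they have seen.
worker_a = Counter({('hi', 'hello'): 1, ('bye', 'goodbye'): 1})
worker_b = Counter({('hi', 'hello'): 1, ('thanks', 'welcome'): 1})

# _CounterMetric.__add__ above merges counters with `+`, so aggregating
# reports across workers sums the per-example counts.
merged = worker_a + worker_b

# UniqueMetric.value() is the number of distinct examples seen ...
assert len(merged) == 3
# ... and TimesSeenMetric.value() is the max count for any single example.
assert max(merged.values()) == 2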
@@ -0,0 +1,195 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Test correctness of ChunkTeacher in a large number of settings.
"""

from unittest import TestCase
import os
import parlai.utils.testing as testing_utils
from parlai.tasks.integration_tests.agents import NUM_TEST
import torch.distributed as dist
import parlai.scripts.multiprocessing_train as mp_train

BASE_ARGS = {
    'model': 'test_agents/counter',
    'dict_file': 'zoo:unittest/transformer_generator2/model.dict',
    'dict_tokenizer': 'space',
    'truncate': 8,
    'num_epochs': 2,
    'datatype': 'train:stream',
}

class TestNumExamples(TestCase):
    def _run(self, **kwargs):
        opt = {**BASE_ARGS, **kwargs}
        valid_report, test_report = testing_utils.train_model(opt)
        assert valid_report['unique'] == NUM_TEST
        assert valid_report['times_seen'] == 1
        assert test_report['unique'] == NUM_TEST
        assert test_report['times_seen'] == 1
        return valid_report, test_report

    def _run_mp(self, **kwargs):
        opt = {**BASE_ARGS, **kwargs}
        with testing_utils.tempdir() as tmpdir:
            if 'model_file' not in opt:
                opt['model_file'] = os.path.join(tmpdir, 'model')

            valid_report, test_report = mp_train.MultiProcessTrain.main(**opt)
            dist.destroy_process_group()
            assert valid_report['unique'] == NUM_TEST
            assert valid_report['times_seen'] == 1
            assert test_report['unique'] == NUM_TEST
            assert test_report['times_seen'] == 1
            return valid_report, test_report

    # Regular chunk teacher

    def test_normal_bs1(self):
        self._run(task='integration_tests:chunky')

    def test_normal_bs2(self):
        self._run(task='integration_tests:chunky', batchsize=2)

    def test_normal_bs3(self):
        self._run(task='integration_tests:chunky', batchsize=3)

    def test_normal_dynb(self):
        self._run(task='integration_tests:chunky', batchsize=2, dynamic_batching='full')

    def test_normal_batchsort(self):
        self._run(
            task='integration_tests:chunky', batchsize=2, dynamic_batching='batchsort'
        )

    @testing_utils.skipUnlessGPU
    def test_mp_normal_bs1(self):
        self._run_mp(task='integration_tests:chunky')

    @testing_utils.skipUnlessGPU
    def test_mp_normal_bs2(self):
        self._run_mp(task='integration_tests:chunky', batchsize=2)

    @testing_utils.skipUnlessGPU
    def test_mp_normal_bs3(self):
        self._run_mp(task='integration_tests:chunky', batchsize=3)

    @testing_utils.skipUnlessGPU
    def test_mp_normal_dynb(self):
        self._run_mp(
            task='integration_tests:chunky', batchsize=2, dynamic_batching='full'
        )

    @testing_utils.skipUnlessGPU
    def test_mp_normal_batchsort(self):
        self._run_mp(
            task='integration_tests:chunky', batchsize=2, dynamic_batching='batchsort'
        )

    # Small buffer

    def test_small_buffer_bs1(self):
        self._run(task='integration_tests:chunky_small_buffer')

    def test_small_buffer_bs2(self):
        self._run(task='integration_tests:chunky_small_buffer', batchsize=2)

    def test_small_buffer_bs3(self):
        self._run(task='integration_tests:chunky_small_buffer', batchsize=3)

    def test_small_buffer_dynb(self):
        self._run(
            task='integration_tests:chunky_small_buffer',
            batchsize=2,
            dynamic_batching='full',
        )

    def test_small_buffer_batchsort(self):
        self._run(
            task='integration_tests:chunky_small_buffer',
            batchsize=2,
            dynamic_batching='batchsort',
        )

    @testing_utils.skipUnlessGPU
    def test_mp_small_buffer_bs1(self):
        self._run_mp(task='integration_tests:chunky_small_buffer')

    @testing_utils.skipUnlessGPU
    def test_mp_small_buffer_bs2(self):
        self._run_mp(task='integration_tests:chunky_small_buffer', batchsize=2)

    @testing_utils.skipUnlessGPU
    def test_mp_small_buffer_bs3(self):
        self._run_mp(task='integration_tests:chunky_small_buffer', batchsize=3)

    @testing_utils.skipUnlessGPU
    def test_mp_small_buffer_dynb(self):
        self._run_mp(
            task='integration_tests:chunky_small_buffer',
            batchsize=2,
            dynamic_batching='full',
        )

    @testing_utils.skipUnlessGPU
    def test_mp_small_buffer_batchsort(self):
        self._run_mp(
            task='integration_tests:chunky_small_buffer',
            batchsize=2,
            dynamic_batching='batchsort',
        )

    # Slow chunk

    def test_slow_bs1(self):
        self._run(task='integration_tests:chunky_slow')

    def test_slow_bs2(self):
        self._run(task='integration_tests:chunky_slow', batchsize=2)

    def test_slow_bs3(self):
        self._run(task='integration_tests:chunky_slow', batchsize=3)

    def test_slow_dynb(self):
        self._run(
            task='integration_tests:chunky_slow', batchsize=2, dynamic_batching='full'
        )

    def test_slow_batchsort(self):
        self._run(
            task='integration_tests:chunky_slow',
            batchsize=2,
            dynamic_batching='batchsort',
        )

    @testing_utils.skipUnlessGPU
    def test_mp_slow_bs1(self):
        self._run_mp(task='integration_tests:chunky_slow')

    @testing_utils.skipUnlessGPU
    def test_mp_slow_bs2(self):
        self._run_mp(task='integration_tests:chunky_slow', batchsize=2)

    @testing_utils.skipUnlessGPU
    def test_mp_slow_bs3(self):
        self._run_mp(task='integration_tests:chunky_slow', batchsize=3)

    @testing_utils.skipUnlessGPU
    def test_mp_slow_dynb(self):
        self._run_mp(
            task='integration_tests:chunky_slow', batchsize=2, dynamic_batching='full'
        )

    @testing_utils.skipUnlessGPU
    def test_mp_slow_batchsort(self):
        self._run_mp(
            task='integration_tests:chunky_slow',
            batchsize=2,
            dynamic_batching='batchsort',
        )
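
For reference, a rough sketch of what one configuration above boils down to when run by hand, mirroring _run(task='integration_tests:chunky', batchsize=2). It is illustrative only, not part of the commit, and assumes a ParlAI environment where the imports above resolve and the unittest dictionary file can be fetched.

# Single-process variant of one ChunkTeacher correctness check.
import parlai.utils.testing as testing_utils
from parlai.tasks.integration_tests.agents import NUM_TEST

opt = {
    'model': 'test_agents/counter',
    'dict_file': 'zoo:unittest/transformer_generator2/model.dict',
    'dict_tokenizer': 'space',
    'truncate': 8,
    'num_epochs': 2,
    'datatype': 'train:stream',
    'task': 'integration_tests:chunky',
    'batchsize': 2,
}
valid_report, test_report = testing_utils.train_model(opt)

# The chunk teacher should deliver every example exactly once per pass:
# nothing dropped (unique == NUM_TEST) and nothing repeated (times_seen == 1).
assert valid_report['unique'] == NUM_TEST
assert valid_report['times_seen'] == 1
assert test_report['unique'] == NUM_TEST
assert test_report['times_seen'] == 1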