Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
[ChunkTeacher] Fix a hang (#3549)
Browse files Browse the repository at this point in the history
* Fix a chunk teacher bug

* Mad tests for chunk teacher.

* Skip unless GPU

* Autoformat

* Improve comments.

* Drain shouldn't be there.
  • Loading branch information
stephenroller authored Mar 31, 2021
1 parent c639f06 commit 7da15f7
Show file tree
Hide file tree
Showing 6 changed files with 342 additions and 57 deletions.
125 changes: 125 additions & 0 deletions parlai/agents/test_agents/counter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Test agent which counts its number of unique items.
"""

from __future__ import annotations
from typing import Tuple
from collections import Counter

from parlai.core.torch_agent import TorchAgent
from parlai.core.metrics import Metric, SumMetric
from parlai.core.message import Message


class _CounterMetric(Metric):
    """
    Base metric that wraps a ``collections.Counter``.

    Adding two instances produces a new instance of the same subclass whose
    counter is the sum of both operands' counters. Adding ``None`` returns the
    metric itself.
    """

    __slots__ = ('_counter',)

    def __init__(self, counter: Counter):
        self._counter = counter

    def __add__(self, other: Metric):
        if other is not None:
            merged = self._counter + other._counter
            # type(self) keeps the concrete subclass of the left operand
            return type(self)(merged)
        # nothing to merge with
        return self


class TimesSeenMetric(_CounterMetric):
    """
    Largest count recorded for any single example.
    """

    def value(self) -> int:
        counts = self._counter
        # an empty counter reports 0 instead of calling max() on nothing
        return max(counts.values()) if counts else 0


class UniqueMetric(_CounterMetric):
    """
    Number of distinct keys held in the counter.
    """

    def value(self) -> int:
        seen = self._counter
        return len(seen) if seen else 0


class CounterAgent(TorchAgent):
    """
    Test agent that tallies the distinct examples handed to ``batch_act``.

    It could be simpler, but it is a TorchAgent so that it is happy with
    dynamic batching.
    """

    def __init__(self, opt, shared=None):
        # model and criterion are assigned before the parent constructor runs
        self.model = self.build_model()
        self.criterion = None
        super().__init__(opt, shared)
        # NOTE(review): _to_tuple builds 3-element keys
        # (text, labels, eval_labels); the annotation below says 2 -- confirm.
        self._counter: Counter[Tuple[str, str]]
        if shared is not None:
            # copies created from ``share()`` reuse the same counters
            self._counter = shared['counter']
            self._padding_counter = shared['padding']
        else:
            self._counter = Counter()
            self._padding_counter = Counter()

    def share(self):
        """
        Extend the parent's shared state with both counters.
        """
        state = super().share()
        state['counter'] = self._counter
        state['padding'] = self._padding_counter
        return state

    def save(self, path=None):
        """
        No-op.
        """

    def load(self, path=None):
        """
        No-op.
        """

    def _val(self, val):
        """
        Return the first element of a list or tuple, or the value itself.
        """
        # labels arrive as a sequence; everything else passes through untouched
        return val[0] if isinstance(val, (tuple, list)) else val

    def build_model(self):
        """
        Return ``None``; this agent has no model.
        """
        return None

    def train_step(self):
        """
        No-op.
        """

    def eval_step(self):
        """
        No-op.
        """

    def _to_tuple(self, msg: Message) -> Tuple:
        """
        Build the key under which a message is counted.
        """
        # a tuple is hashable, so it can be used as a Counter key
        fields = ('text', 'labels', 'eval_labels')
        return tuple(self._val(msg.get(name)) for name in fields)

    def batch_act(self, observations):
        """
        Count padding and real observations; reply with one empty Message each.
        """
        padded = [obs for obs in observations if obs.is_padding()]
        real = [obs for obs in observations if not obs.is_padding()]
        # every padding observation is tallied under the single key 'val'
        self._padding_counter.update('val' for _ in padded)
        self._counter.update(self._to_tuple(obs) for obs in real)
        return [Message() for _ in observations]

    def reset(self):
        """
        Forget all counted examples.
        """
        # NOTE(review): only the example counter is cleared here; the padding
        # counter is left as-is -- confirm that is intended.
        self._counter.clear()

    def report(self):
        """
        Return the ``num_pad``, ``unique`` and ``times_seen`` metrics.
        """
        return {
            'num_pad': SumMetric(self._padding_counter.get('val', 0)),
            'unique': UniqueMetric(self._counter),
            'times_seen': TimesSeenMetric(self._counter),
        }
7 changes: 7 additions & 0 deletions parlai/core/teachers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2514,6 +2514,13 @@ def reset(self):
) # reset the count of samples loaded
self._enqueue_request()

def shutdown(self):
    """
    Tell the worker it is done processing data.

    Puts the ``(None, None)`` marker on ``self.chunks`` twice; this is the
    same signal that is used at the end of an epoch.
    """
    done_signal = (None, None)
    for _ in range(2):
        self.chunks.put(done_signal)


def _add_task_flags_to_agent_opt(agent, opt: Opt, flags):
"""
Expand Down
8 changes: 8 additions & 0 deletions parlai/core/worlds.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,6 +1025,14 @@ def __init__(self, opt: Opt, world: Union[DialogPartnerWorld, MultiWorld]):

self.reset()

def shutdown(self):
    """
    Shut down every world in ``self.worlds``, then ``self.world``.
    """
    for sub_world in self.worlds:
        sub_world.shutdown()
    # the wrapped world is shut down last
    self.world.shutdown()

def reset(self):
super().reset()
self._task_acts = [None for _ in range(self._BUFFER_SIZE)]
Expand Down
14 changes: 7 additions & 7 deletions parlai/tasks/integration_tests/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,11 @@ def create_message(self, sample_item, entry_idx=0):
return {'text': text, 'labels': [label], 'episode_done': True}


class ChunkySmallBufferTeacher(ChunkyTeacher):
    """
    ChunkyTeacher variant whose buffer size is half of ``NUM_TEST``.
    """

    def get_buffersize(self):
        half_of_test_set = NUM_TEST // 2
        return half_of_test_set


class InfiniteTrainTeacher(FixedDialogTeacher):
"""
Teacher with an effectively infinite number of training examples.
Expand All @@ -510,19 +515,14 @@ def get(self, episode_idx=0, entry_idx=0):
return Message({'text': '1 2 3 4', field: ['1 2 3 4'], 'episode_done': True})


# NOTE(review): this span looks like a rendered diff whose +/- markers were
# stripped, so removed and added lines appear interleaved (two class headers,
# two return statements). Confirm against the real file which lines are live
# before changing any code here.
class ChunkyUniqueSlowTeacher(ChunkyTeacher):
class ChunkySlowTeacher(ChunkyTeacher):
"""
Unique examples that load slowly.
"""

def load_from_chunk(self, chunk_idx: int):
# builds ten (text, response) pairs numbered from chunk_idx * 10
output = []
for i in range(10):
text = str(i + chunk_idx * 10)
resp = str(i + chunk_idx * 10)
output.append((text, resp))
# the 0.1 second sleep is what makes loading slow
time.sleep(0.1)
return output
# delegates the actual chunk loading to the parent class
return super().load_from_chunk(chunk_idx)


class ShortFixedTeacher(FixedDialogCandidateTeacher):
Expand Down
195 changes: 195 additions & 0 deletions tests/test_chunkteacher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Test correctness of ChunkTeacher in a large number of settings.
"""

from unittest import TestCase
import os
import parlai.utils.testing as testing_utils
from parlai.tasks.integration_tests.agents import NUM_TEST
import torch.distributed as dist
import parlai.scripts.multiprocessing_train as mp_train

# Options common to every test in this file; each test layers its own keyword
# arguments on top of these.
BASE_ARGS = dict(
    model='test_agents/counter',
    dict_file='zoo:unittest/transformer_generator2/model.dict',
    dict_tokenizer='space',
    truncate=8,
    num_epochs=2,
    datatype='train:stream',
)


class TestNumExamples(TestCase):
    """
    Check that the chunk teacher serves every example exactly once.

    Each test trains the counter agent on one of the chunky integration tasks
    and checks that both the valid and the test report contain ``NUM_TEST``
    unique examples with a ``times_seen`` of 1.
    """

    def _check_reports(self, valid_report, test_report):
        """
        Assert that both reports saw NUM_TEST unique examples, once each.
        """
        for report in (valid_report, test_report):
            assert report['unique'] == NUM_TEST
            assert report['times_seen'] == 1

    def _run(self, **kwargs):
        """
        Train in a single process with BASE_ARGS overridden by ``kwargs``.

        :return: the ``(valid_report, test_report)`` pair from training.
        """
        opt = {**BASE_ARGS, **kwargs}
        valid_report, test_report = testing_utils.train_model(opt)
        self._check_reports(valid_report, test_report)
        return valid_report, test_report

    def _run_mp(self, **kwargs):
        """
        Train via multiprocessing with BASE_ARGS overridden by ``kwargs``.

        A temporary ``model_file`` is supplied when the caller gives none.

        :return: the ``(valid_report, test_report)`` pair from training.
        """
        opt = {**BASE_ARGS, **kwargs}
        with testing_utils.tempdir() as tmpdir:
            if 'model_file' not in opt:
                opt['model_file'] = os.path.join(tmpdir, 'model')

            try:
                valid_report, test_report = mp_train.MultiProcessTrain.main(**opt)
            finally:
                # Tear the process group down even when training raises, so a
                # failing test cannot leave it initialized for the next one.
                # The guard avoids a second error when the group never came up.
                if dist.is_initialized():
                    dist.destroy_process_group()
            self._check_reports(valid_report, test_report)
            return valid_report, test_report

    # Regular chunk teacher

    def test_normal_bs1(self):
        self._run(task='integration_tests:chunky')

    def test_normal_bs2(self):
        self._run(task='integration_tests:chunky', batchsize=2)

    def test_normal_bs3(self):
        self._run(task='integration_tests:chunky', batchsize=3)

    def test_normal_dynb(self):
        self._run(task='integration_tests:chunky', batchsize=2, dynamic_batching='full')

    def test_normal_batchsort(self):
        self._run(
            task='integration_tests:chunky', batchsize=2, dynamic_batching='batchsort'
        )

    @testing_utils.skipUnlessGPU
    def test_mp_normal_bs1(self):
        self._run_mp(task='integration_tests:chunky')

    @testing_utils.skipUnlessGPU
    def test_mp_normal_bs2(self):
        self._run_mp(task='integration_tests:chunky', batchsize=2)

    @testing_utils.skipUnlessGPU
    def test_mp_normal_bs3(self):
        self._run_mp(task='integration_tests:chunky', batchsize=3)

    @testing_utils.skipUnlessGPU
    def test_mp_normal_dynb(self):
        self._run_mp(
            task='integration_tests:chunky', batchsize=2, dynamic_batching='full'
        )

    @testing_utils.skipUnlessGPU
    def test_mp_normal_batchsort(self):
        self._run_mp(
            task='integration_tests:chunky', batchsize=2, dynamic_batching='batchsort'
        )

    # Small buffer

    def test_small_buffer_bs1(self):
        self._run(task='integration_tests:chunky_small_buffer')

    def test_small_buffer_bs2(self):
        self._run(task='integration_tests:chunky_small_buffer', batchsize=2)

    def test_small_buffer_bs3(self):
        self._run(task='integration_tests:chunky_small_buffer', batchsize=3)

    def test_small_buffer_dynb(self):
        self._run(
            task='integration_tests:chunky_small_buffer',
            batchsize=2,
            dynamic_batching='full',
        )

    def test_small_buffer_batchsort(self):
        self._run(
            task='integration_tests:chunky_small_buffer',
            batchsize=2,
            dynamic_batching='batchsort',
        )

    @testing_utils.skipUnlessGPU
    def test_mp_small_buffer_bs1(self):
        self._run_mp(task='integration_tests:chunky_small_buffer')

    @testing_utils.skipUnlessGPU
    def test_mp_small_buffer_bs2(self):
        self._run_mp(task='integration_tests:chunky_small_buffer', batchsize=2)

    @testing_utils.skipUnlessGPU
    def test_mp_small_buffer_bs3(self):
        self._run_mp(task='integration_tests:chunky_small_buffer', batchsize=3)

    @testing_utils.skipUnlessGPU
    def test_mp_small_buffer_dynb(self):
        self._run_mp(
            task='integration_tests:chunky_small_buffer',
            batchsize=2,
            dynamic_batching='full',
        )

    @testing_utils.skipUnlessGPU
    def test_mp_small_buffer_batchsort(self):
        self._run_mp(
            task='integration_tests:chunky_small_buffer',
            batchsize=2,
            dynamic_batching='batchsort',
        )

    # Slow chunk

    def test_slow_bs1(self):
        self._run(task='integration_tests:chunky_slow')

    def test_slow_bs2(self):
        self._run(task='integration_tests:chunky_slow', batchsize=2)

    def test_slow_bs3(self):
        self._run(task='integration_tests:chunky_slow', batchsize=3)

    def test_slow_dynb(self):
        self._run(
            task='integration_tests:chunky_slow', batchsize=2, dynamic_batching='full'
        )

    def test_slow_batchsort(self):
        self._run(
            task='integration_tests:chunky_slow',
            batchsize=2,
            dynamic_batching='batchsort',
        )

    @testing_utils.skipUnlessGPU
    def test_mp_slow_bs1(self):
        self._run_mp(task='integration_tests:chunky_slow')

    @testing_utils.skipUnlessGPU
    def test_mp_slow_bs2(self):
        self._run_mp(task='integration_tests:chunky_slow', batchsize=2)

    @testing_utils.skipUnlessGPU
    def test_mp_slow_bs3(self):
        self._run_mp(task='integration_tests:chunky_slow', batchsize=3)

    @testing_utils.skipUnlessGPU
    def test_mp_slow_dynb(self):
        self._run_mp(
            task='integration_tests:chunky_slow', batchsize=2, dynamic_batching='full'
        )

    @testing_utils.skipUnlessGPU
    def test_mp_slow_batchsort(self):
        self._run_mp(
            task='integration_tests:chunky_slow',
            batchsize=2,
            dynamic_batching='batchsort',
        )
Loading

0 comments on commit 7da15f7

Please sign in to comment.