Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

[ChunkTeacher] Fix a hang #3549

Merged
merged 6 commits into from
Mar 31, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 120 additions & 0 deletions parlai/agents/test_agents/counter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Test agent which counts its number of unique items.
"""

from __future__ import annotations
from typing import Tuple
from collections import Counter

from parlai.core.torch_agent import TorchAgent
from parlai.core.metrics import Metric, SumMetric
from parlai.core.message import Message


class _CounterMetric(Metric):
__slots__ = ('_counter',)

def __init__(self, counter: Counter):
self._counter = counter

def __add__(self, other: Metric):
if other is None:
return self
counter = self._counter + other._counter
return type(self)(counter)


class TimesSeenMetric(_CounterMetric):
"""
Max number of times any example was seen.
"""

def value(self) -> int:
if not self._counter:
return 0
return max(self._counter.values())


class UniqueMetric(_CounterMetric):
"""
Number of unique utterances.
"""

def value(self) -> int:
if not self._counter:
return 0
return len(self._counter)


class CounterAgent(TorchAgent):
"""
Simple agent that counts the number of unique things it has seen.

Could be simpler, but we make it a TorchAgent so it's happy with dynamic batching.
"""

def __init__(self, opt, shared=None):
self.model = self.build_model()
self.criterion = None
super().__init__(opt, shared)
self._counter: Counter[Tuple[str, str]]
if shared is None:
self._counter = Counter()
self._padding_counter = Counter()
else:
self._counter = shared['counter']
self._padding_counter = shared['padding']

def share(self):
shared = super().share()
shared['counter'] = self._counter
shared['padding'] = self._padding_counter
return shared

def save(self, path=None):
pass

def load(self, path=None):
pass

def _val(self, val):
if isinstance(val, list):
stephenroller marked this conversation as resolved.
Show resolved Hide resolved
return val[0]
else:
return val

def build_model(self):
return None

def train_step(self):
pass

def eval_step(self):
pass

def _to_tuple(self, msg: Message) -> Tuple:
keys = ['text', 'labels', 'eval_labels']
return tuple(self._val(msg.get(k)) for k in keys)

def batch_act(self, observations):
self._padding_counter.update(['val' for o in observations if o.is_padding()])
self._counter.update(
[self._to_tuple(o) for o in observations if not o.is_padding()]
)
return [Message() for o in observations]

def reset(self):
self._counter.clear()

def report(self):
report = {}
report['num_pad'] = SumMetric(self._padding_counter.get('val', 0))
report['unique'] = UniqueMetric(self._counter)
report['times_seen'] = TimesSeenMetric(self._counter)
return report
5 changes: 5 additions & 0 deletions parlai/core/teachers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2514,6 +2514,11 @@ def reset(self):
) # reset the count of samples loaded
self._enqueue_request()

def shutdown(self):
# self._drain(self.chunks)
stephenroller marked this conversation as resolved.
Show resolved Hide resolved
self.chunks.put((None, None))
self.chunks.put((None, None))


def _add_task_flags_to_agent_opt(agent, opt: Opt, flags):
"""
Expand Down
8 changes: 8 additions & 0 deletions parlai/core/worlds.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,6 +1025,14 @@ def __init__(self, opt: Opt, world: Union[DialogPartnerWorld, MultiWorld]):

self.reset()

def shutdown(self):
"""
Shutdown each world.
"""
for w in self.worlds:
w.shutdown()
self.world.shutdown()

def reset(self):
super().reset()
self._task_acts = [None for _ in range(self._BUFFER_SIZE)]
Expand Down
14 changes: 7 additions & 7 deletions parlai/tasks/integration_tests/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,11 @@ def create_message(self, sample_item, entry_idx=0):
return {'text': text, 'labels': [label], 'episode_done': True}


class ChunkySmallBufferTeacher(ChunkyTeacher):
def get_buffersize(self):
return NUM_TEST // 2


class InfiniteTrainTeacher(FixedDialogTeacher):
"""
Teacher with an effectively infinite number of training examples.
Expand All @@ -510,19 +515,14 @@ def get(self, episode_idx=0, entry_idx=0):
return Message({'text': '1 2 3 4', field: ['1 2 3 4'], 'episode_done': True})


class ChunkyUniqueSlowTeacher(ChunkyTeacher):
class ChunkySlowTeacher(ChunkyTeacher):
"""
Unique examples that load slowly.
"""

def load_from_chunk(self, chunk_idx: int):
output = []
for i in range(10):
text = str(i + chunk_idx * 10)
resp = str(i + chunk_idx * 10)
output.append((text, resp))
time.sleep(0.1)
stephenroller marked this conversation as resolved.
Show resolved Hide resolved
return output
return super().load_from_chunk(chunk_idx)


class ShortFixedTeacher(FixedDialogCandidateTeacher):
Expand Down
195 changes: 195 additions & 0 deletions tests/test_chunkteacher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Test correctness of ChunkTeacher in a large number of settings.
"""

from unittest import TestCase
import os
import parlai.utils.testing as testing_utils
from parlai.tasks.integration_tests.agents import NUM_TEST
import torch.distributed as dist
import parlai.scripts.multiprocessing_train as mp_train

BASE_ARGS = {
'model': 'test_agents/counter',
'dict_file': 'zoo:unittest/transformer_generator2/model.dict',
'dict_tokenizer': 'space',
'truncate': 8,
'num_epochs': 2,
'datatype': 'train:stream',
}


class TestNumExamples(TestCase):
def _run(self, **kwargs):
opt = {**BASE_ARGS, **kwargs}
valid_report, test_report = testing_utils.train_model(opt)
assert valid_report['unique'] == NUM_TEST
assert valid_report['times_seen'] == 1
assert test_report['unique'] == NUM_TEST
assert test_report['times_seen'] == 1
return valid_report, test_report

def _run_mp(self, **kwargs):
opt = {**BASE_ARGS, **kwargs}
with testing_utils.tempdir() as tmpdir:
if 'model_file' not in opt:
opt['model_file'] = os.path.join(tmpdir, 'model')

valid_report, test_report = mp_train.MultiProcessTrain.main(**opt)
dist.destroy_process_group()
assert valid_report['unique'] == NUM_TEST
assert valid_report['times_seen'] == 1
assert test_report['unique'] == NUM_TEST
assert test_report['times_seen'] == 1
return valid_report, test_report

# Regular chunk teacher

def test_normal_bs1(self):
self._run(task='integration_tests:chunky')

def test_normal_bs2(self):
self._run(task='integration_tests:chunky', batchsize=2)

def test_normal_bs3(self):
self._run(task='integration_tests:chunky', batchsize=3)

def test_normal_dynb(self):
self._run(task='integration_tests:chunky', batchsize=2, dynamic_batching='full')

def test_normal_batchsort(self):
self._run(
task='integration_tests:chunky', batchsize=2, dynamic_batching='batchsort'
)

@testing_utils.skipUnlessGPU
def test_mp_normal_bs1(self):
self._run_mp(task='integration_tests:chunky')

@testing_utils.skipUnlessGPU
def test_mp_normal_bs2(self):
self._run_mp(task='integration_tests:chunky', batchsize=2)

@testing_utils.skipUnlessGPU
def test_mp_normal_bs3(self):
self._run_mp(task='integration_tests:chunky', batchsize=3)

@testing_utils.skipUnlessGPU
def test_mp_normal_dynb(self):
self._run_mp(
task='integration_tests:chunky', batchsize=2, dynamic_batching='full'
)

@testing_utils.skipUnlessGPU
def test_mp_normal_batchsort(self):
self._run_mp(
task='integration_tests:chunky', batchsize=2, dynamic_batching='batchsort'
)

# Small buffer

def test_small_buffer_bs1(self):
self._run(task='integration_tests:chunky_small_buffer')

def test_small_buffer_bs2(self):
self._run(task='integration_tests:chunky_small_buffer', batchsize=2)

def test_small_buffer_bs3(self):
self._run(task='integration_tests:chunky_small_buffer', batchsize=3)

def test_small_buffer_dynb(self):
self._run(
task='integration_tests:chunky_small_buffer',
batchsize=2,
dynamic_batching='full',
)

def test_small_buffer_batchsort(self):
self._run(
task='integration_tests:chunky_small_buffer',
batchsize=2,
dynamic_batching='batchsort',
)

@testing_utils.skipUnlessGPU
def test_mp_small_buffer_bs1(self):
self._run_mp(task='integration_tests:chunky_small_buffer')

@testing_utils.skipUnlessGPU
def test_mp_small_buffer_bs2(self):
self._run_mp(task='integration_tests:chunky_small_buffer', batchsize=2)

@testing_utils.skipUnlessGPU
def test_mp_small_buffer_bs3(self):
self._run_mp(task='integration_tests:chunky_small_buffer', batchsize=3)

@testing_utils.skipUnlessGPU
def test_mp_small_buffer_dynb(self):
self._run_mp(
task='integration_tests:chunky_small_buffer',
batchsize=2,
dynamic_batching='full',
)

@testing_utils.skipUnlessGPU
def test_mp_small_buffer_batchsort(self):
self._run_mp(
task='integration_tests:chunky_small_buffer',
batchsize=2,
dynamic_batching='batchsort',
)

# Slow chunk

def test_slow_bs1(self):
self._run(task='integration_tests:chunky_slow')

def test_slow_bs2(self):
self._run(task='integration_tests:chunky_slow', batchsize=2)

def test_slow_bs3(self):
self._run(task='integration_tests:chunky_slow', batchsize=3)

def test_slow_dynb(self):
self._run(
task='integration_tests:chunky_slow', batchsize=2, dynamic_batching='full'
)

def test_slow_batchsort(self):
self._run(
task='integration_tests:chunky_slow',
batchsize=2,
dynamic_batching='batchsort',
)

@testing_utils.skipUnlessGPU
def test_mp_slow_bs1(self):
self._run_mp(task='integration_tests:chunky_slow')

@testing_utils.skipUnlessGPU
def test_mp_slow_bs2(self):
self._run_mp(task='integration_tests:chunky_slow', batchsize=2)

@testing_utils.skipUnlessGPU
def test_mp_slow_bs3(self):
self._run_mp(task='integration_tests:chunky_slow', batchsize=3)

@testing_utils.skipUnlessGPU
def test_mp_slow_dynb(self):
self._run_mp(
task='integration_tests:chunky_slow', batchsize=2, dynamic_batching='full'
)

@testing_utils.skipUnlessGPU
def test_mp_slow_batchsort(self):
self._run_mp(
task='integration_tests:chunky_slow',
batchsize=2,
dynamic_batching='batchsort',
)
Loading