diff --git a/parlai/core/teachers.py b/parlai/core/teachers.py index 159ac0f729d..26fbfa5f5f8 100644 --- a/parlai/core/teachers.py +++ b/parlai/core/teachers.py @@ -33,7 +33,7 @@ structures for accessing textual dialog data and utilized by ``DialogTeacher`` """ import copy -from typing import List, Tuple, Optional +from typing import List, Tuple, Optional, TypeVar from parlai.core.agents import Agent, create_agent_from_shared from parlai.core.image_featurizers import ImageLoader @@ -63,6 +63,9 @@ import argparse +ChunkOutput = TypeVar('ChunkOutput') + + class DataLoader(Thread): """ A worker thread that provides a threadpool for data loading. @@ -288,7 +291,7 @@ def reset(self): self.metrics.clear() self.lastY = None self.last_act = None - self.episode_done = True + self._episode_done = True self.epochDone = False self.data_queue = queue.Queue() @@ -370,7 +373,7 @@ def next_example(self): episode. If that episode is over, gets a new episode index and returns the first example of that episode. """ - if self.episode_done: + if self._episode_done: self.episode_idx = self.next_episode_idx() self.entry_idx = 0 else: @@ -380,11 +383,11 @@ def next_example(self): return {'episode_done': True}, True ex = self.get(self.episode_idx, self.entry_idx) - self.episode_done = ex.get('episode_done', False) + self._episode_done = ex.get('episode_done', False) if ( not self.cycle - and self.episode_done + and self._episode_done and self.episode_idx + self.opt.get("batchsize", 1) >= self.num_episodes() ): epoch_done = True @@ -2129,7 +2132,7 @@ def __init__(self, opt, shared=None): if opt['numthreads'] > 1: raise ValueError('Chunk teacher is not compatible with Hogwild.') - self.set_datasettings(opt['datatype']) + self.set_datasettings(opt) self.dws = int(self.opt.get('distributed_world_size', 1)) self.rank = int(self.opt.get('rank', 0)) @@ -2160,7 +2163,8 @@ def __init__(self, opt, shared=None): # launch queue loader on the main thread self._enqueue_request() - self.episode_done = True + self._episode_done = True + self.last_queue_output = None def _get_data_folder(self): if not self.opt.get('datafile'): @@ -2172,7 +2176,7 @@ def _get_data_folder(self): return self.opt['datafile'] @abstractmethod - def get_num_samples(self, datatype: str) -> Tuple[int, int]: + def get_num_samples(self, opt: Opt) -> Tuple[int, int]: """ [Abstract] Return the number of samples. @@ -2181,7 +2185,7 @@ def get_num_samples(self, datatype: str) -> Tuple[int, int]: pass @abstractmethod - def get_fold_chunks(self, datatype: str) -> List[int]: # type: ignore + def get_fold_chunks(self, opt: Opt) -> List[int]: # type: ignore """ [Abstract] Return a list of chunk IDs (integer). @@ -2198,12 +2202,12 @@ def get_buffersize(self): """ return 100000 - def set_datasettings(self, datatype): + def set_datasettings(self, opt: Opt): self.folder = self._get_data_folder() - self.num_exs, self.num_eps = self.get_num_samples(datatype) - self.fold_chunks = self.get_fold_chunks(datatype) + self.num_exs, self.num_eps = self.get_num_samples(opt) + self.fold_chunks = self.get_fold_chunks(opt) - self.is_train = DatatypeHelper.is_training(datatype) + self.is_train = DatatypeHelper.is_training(opt['datatype']) def share(self): shared = super().share() @@ -2267,7 +2271,7 @@ def _enqueue_chunks(self): self.chunks.put(c) @abstractmethod - def load_from_chunk(self, chunk_idx: int): + def load_from_chunk(self, chunk_idx: int) -> List[ChunkOutput]: """ [Abstract] Given the chunk index, load examples from that chunk. @@ -2277,9 +2281,11 @@ def load_from_chunk(self, chunk_idx: int): pass @abstractmethod - def create_message(self, queue_output) -> 'Message': + def create_message(self, queue_output: ChunkOutput, entry_idx=0) -> 'Message': """ [Abstract] Given the tuple output of the queue, return an act. + + May depend on entry index if queue output is a multi-turn episode. """ pass @@ -2304,12 +2310,21 @@ def get_chunk(self): return output def get(self, episode_idx, entry_idx=0): - queue_output = self.samples.get() - if queue_output is None: - return None + if self._episode_done: + # Get the next episode or example + queue_output = self.samples.get() + if queue_output is None: + return None + + # Update the last queue output in the case + # of multi-turn episodes + self.last_queue_output = queue_output # create a Message object from the queue output - return self.create_message(queue_output) + msg = self.create_message(self.last_queue_output, entry_idx) + self._episode_done = msg['episode_done'] + + return msg def _drain(self, q): while not q.empty(): diff --git a/parlai/tasks/integration_tests/agents.py b/parlai/tasks/integration_tests/agents.py index 6167ceda9bb..e733d378352 100644 --- a/parlai/tasks/integration_tests/agents.py +++ b/parlai/tasks/integration_tests/agents.py @@ -551,7 +551,8 @@ class ChunkyTeacher(ChunkTeacher): def _get_data_folder(self): return None - def get_num_samples(self, datatype: str) -> Tuple[int, int]: + def get_num_samples(self, opt) -> Tuple[int, int]: + datatype = opt['datatype'] if 'train' in datatype: return NUM_TRAIN, NUM_TRAIN elif 'valid' in datatype: @@ -559,7 +560,8 @@ def get_num_samples(self, datatype: str) -> Tuple[int, int]: elif 'test' in datatype: return NUM_TEST, NUM_TEST - def get_fold_chunks(self, datatype: str) -> List[int]: + def get_fold_chunks(self, opt) -> List[int]: + datatype = opt['datatype'] if 'train' in datatype: return list(range(50)) elif 'valid' in datatype: @@ -575,7 +577,7 @@ def load_from_chunk(self, chunk_idx: int): output.append((text, resp)) return output - def create_message(self, sample_item): + def create_message(self, sample_item, entry_idx=0): text, label = sample_item return {'text': text, 'labels': [label], 'episode_done': True} @@ -585,7 +587,8 @@ class InfiniteTrainTeacher(ChunkyTeacher): Chunk teacher with an effectively infinite number of training examples. """ - def get_num_samples(self, datatype: str) -> Tuple[int, int]: + def get_num_samples(self, opt) -> Tuple[int, int]: + datatype = opt['datatype'] if 'train' in datatype: return INFINITE, INFINITE elif 'valid' in datatype: diff --git a/parlai/tasks/wikipedia/agents.py b/parlai/tasks/wikipedia/agents.py index 335a6a719d1..68708188c7f 100644 --- a/parlai/tasks/wikipedia/agents.py +++ b/parlai/tasks/wikipedia/agents.py @@ -13,7 +13,7 @@ To put the article in the labels and the title in the text, specify ':key-value' at the end (for a title/content key-value association) """ -from parlai.core.teachers import DialogTeacher, ChunkTeacher +from parlai.core.teachers import DialogTeacher, ChunkTeacher, ChunkOutput from parlai.core.message import Message from .build import build @@ -99,10 +99,11 @@ def __init__(self, opt, shared=None): def _get_data_folder(self): return os.path.join(self.opt['datapath'], 'wikipedia/full/wiki_full_extracted') - def get_num_samples(self, datatype) -> Tuple[int, int]: + def get_num_samples(self, opt) -> Tuple[int, int]: """ Return the number of samples given the datatype. """ + datatype = opt['datatype'] if 'train' in datatype: return self.TRAINSIZE, self.TRAINSIZE elif 'valid' in datatype: @@ -116,13 +117,14 @@ def _set_chunk_idx_to_file(self): all_subdirs = sorted([x for x in os.listdir(folder) if 'README' not in x]) self.chunk_idx_to_file = {i: x for i, x in enumerate(all_subdirs)} - def get_fold_chunks(self, datatype) -> List[int]: # type: ignore + def get_fold_chunks(self, opt) -> List[int]: # type: ignore """ Return a list of chunk IDs (integer). Given the datatype (train/test/valid), return the list of chunk IDs that correspond to that split. """ + datatype = opt['datatype'] all_chunk_idxs = list(self.chunk_idx_to_file.keys()) if 'train' in datatype: return all_chunk_idxs[:-2] @@ -131,7 +133,7 @@ def get_fold_chunks(self, datatype) -> List[int]: # type: ignore else: return [all_chunk_idxs[-1]] - def load_from_chunk(self, chunk_idx: int) -> List[Tuple[str, str]]: + def load_from_chunk(self, chunk_idx: int): """ Given the chunk index, load examples from that chunk. @@ -151,7 +153,7 @@ def load_from_chunk(self, chunk_idx: int) -> List[Tuple[str, str]]: return output - def create_message(self, queue_output: Tuple[str, ...]) -> 'Message': + def create_message(self, queue_output: ChunkOutput, entry_idx=0) -> 'Message': """ Given the tuple output of the queue, return an act. """