Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

use should_shuffle instead of is_training to determine whether to ran… #4425

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions parlai/core/worlds.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
from parlai.core.params import ParlaiParser
from parlai.core.teachers import Teacher, create_task_agent_from_taskname
from parlai.utils.data import DatatypeHelper
from parlai.utils.misc import Timer, display_messages
from parlai.utils.misc import Timer, display_messages, warn_once
from parlai.tasks.tasks import ids_to_tasks
from parlai.utils.misc import error_once

Expand Down Expand Up @@ -562,10 +562,17 @@ def __init__(self, opt: Opt, agents=None, shared=None, default_world=None):
self.parleys = -1
# Check to see if we are training
self.is_training = DatatypeHelper.is_training(opt.get('datatype'))
# Check to see if we should shuffle
self.should_shuffle = DatatypeHelper.should_shuffle(opt.get('datatype'))
# Make multi-task task probabilities.
self.cum_task_weights = [1] * len(self.worlds)
self.task_choices = range(len(self.worlds))
weights = self.opt.get('multitask_weights', [1])
# Warn about multi-task weights being ignored if we are in a datatype that doesn't involve shuffling
if weights != [1] and not self.should_shuffle:
warn_once(
f"WARNING: multitask weights are ignored for datatype {opt.get('datatype')} as we iterate through tasks in a round robin"
)
if weights == 'stochastic':
weights = [w.num_episodes() for w in self.worlds]
sum = 0
Expand Down Expand Up @@ -672,7 +679,7 @@ def parley_init(self):
if self.new_world:
self.new_world = False
self.parleys = 0
if self.is_training:
if self.should_shuffle:
# select random world
self.world_idx = random.choices(
self.task_choices, cum_weights=self.cum_task_weights
Expand Down
28 changes: 28 additions & 0 deletions tests/test_multiworld.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,31 @@ def test_with_stream(self):
exs = report[f'{task}/exs'].value()
assert exs > 0, err
world.reset_metrics()

def test_with_ordered(self):
    """
    Verify multi-task iteration is deterministic under datatype train:ordered.

    Builds two independently constructed but identically configured
    multi-task worlds and steps them in lockstep: at every parley both
    worlds must emit equal acts, proving the round-robin task order is
    reproducible rather than randomly shuffled.
    """
    opt = ParlaiParser(True, True).parse_kwargs(
        task='teacher1,teacher2',
        multitask_weights='1,1',
        model='fixed_response',
        fixed_response='None',
        datatype='train:ordered',
        batchsize=1,
    )
    world_a = create_task(opt, create_agent(opt))
    world_b = create_task(opt, create_agent(opt))

    while not (world_a.epoch_done() or world_b.epoch_done()):
        world_a.parley()
        acts_a = world_a.get_acts()

        world_b.parley()
        acts_b = world_b.get_acts()

        self.assertEqual(len(acts_a), len(acts_b))
        for act_a, act_b in zip(acts_a, acts_b):
            assert act_a == act_b

    # Both worlds must exhaust their epochs together; a lone finisher
    # would mean the two runs diverged in task ordering.
    assert world_a.epoch_done() and world_b.epoch_done()