From fef88fe33c18dd9965e6cf7e29291d6b060b40b2 Mon Sep 17 00:00:00 2001 From: klshuster Date: Mon, 10 May 2021 12:14:54 -0400 Subject: [PATCH 1/2] mtator fix --- parlai/core/teachers.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/parlai/core/teachers.py b/parlai/core/teachers.py index 25a153b9686..b81b1a74249 100644 --- a/parlai/core/teachers.py +++ b/parlai/core/teachers.py @@ -465,8 +465,10 @@ def next_example(self): break buffer_entry_idx += 1 # apply mutators - for mutator in self.mutators: - episode_buffer = mutator(episode_buffer) + if self.mutators: + episode_buffer = [m.copy() for m in episode_buffer] + for mutator in self.mutators: + episode_buffer = mutator(episode_buffer) self.episode_buffer = list(episode_buffer) if not self.episode_buffer: @@ -766,8 +768,10 @@ def next_example(self): self._saw_epoch_done = epoch_done break # perform any mutations there are - for mutator in self.mutators: - episode_buffer = mutator(episode_buffer) + if self.mutators: + episode_buffer = [m.copy() for m in episode_buffer] + for mutator in self.mutators: + episode_buffer = mutator(episode_buffer) # make sure mutations are fully realized (not generators) self.episode_buffer = list(episode_buffer) # The recursive call has dual purpose: From ccdc79a783f5e6cc4a1002f7126a4a7c7e7eb985 Mon Sep 17 00:00:00 2001 From: klshuster Date: Mon, 10 May 2021 12:49:19 -0400 Subject: [PATCH 2/2] add test --- tests/test_mutators.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/test_mutators.py b/tests/test_mutators.py index 1f92b23dc40..c26a2712987 100644 --- a/tests/test_mutators.py +++ b/tests/test_mutators.py @@ -226,3 +226,27 @@ def test_word_shuffle(self): assert set(ex2['text'].split()) == set(EXAMPLE2['text'].split()) assert set(ex3['text'].split()) == set(EXAMPLE3['text'].split()) assert set(ex4['text'].split()) == set(EXAMPLE4['text'].split()) + + +class TestMutatorStickiness(unittest.TestCase): + """ + Test that mutations DO NOT stick with episode. + """ + + def test_not_sticky(self): + pp = ParlaiParser(True, False) + opt = pp.parse_kwargs( + task='integration_tests:multiturn', + mutators='flatten', + datatype='train:ordered', + ) + teacher = create_task_agent_from_taskname(opt)[0] + first_epoch = [] + second_epoch = [] + for _ in range(teacher.num_examples()): + first_epoch.append(teacher.act()) + teacher.reset() + for _ in range(teacher.num_examples()): + second_epoch.append(teacher.act()) + + assert all(f == s for f, s in zip(first_epoch, second_epoch))