Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Normalized all convai2 teachers #3509

Merged
merged 3 commits into from
Mar 18, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 83 additions & 5 deletions parlai/tasks/convai2/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from .build import build
from parlai.utils.strings import normalize_reply
import parlai.utils.logging as logging
from parlai.core.params import ParlaiParser
from typing import Optional
from parlai.core.opt import Opt

import copy
import os
Expand Down Expand Up @@ -88,31 +91,106 @@ def __init__(self, opt, shared=None):
super().__init__(opt, shared)


class NormalizedTeacher(SelfOriginalTeacher):
class NormalizedTeacherTrait(object):
@classmethod
def add_cmdline_args(
cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None
) -> ParlaiParser:
super().add_cmdline_args(parser, partial_opt)
agent = parser.add_argument_group('NormalizedTeacher arguments')
agent.add_argument(
'--your-persona-first',
type='bool',
default=True,
help="whether to prepend your persona followed by partner's persona. True by default to be consistent with the BothTeach",
)
agent.add_argument(
'--max-num-turns',
type=int,
default=-1,
help="first X turns per episode to show. If -1 then the whole episode is shown",
)
return agent

def __init__(self, opt, shared=None):
self.max_num_turns = opt["max_num_turns"]
self.your_persona_first = opt["your_persona_first"]
super().__init__(opt, shared)

def normalize_replies(self, x):
xs = x.split('\n')
xs2 = []
your_personas = []
partner_personas = []
non_personas = []
for x in xs:
if 'your persona:' in x:
if x.startswith('your persona: '):
# Normalize the sentence appearing after 'your persona:'
x = x[len('your persona: ') :]
x = normalize_reply(x)
x = 'your persona: ' + x
your_personas.append(x)
elif x.startswith("partner's persona: "):
x = x[len("partner's persona: ") :]
x = normalize_reply(x)
x = "partner's persona: " + x
partner_personas.append(x)
else:
x = normalize_reply(x)

xs2.append(x)
non_personas.append(x)
xs2 = []
if self.your_persona_first:
xs2.extend(your_personas)
xs2.extend(partner_personas)
else:
xs2.extend(partner_personas)
xs2.extend(your_personas)
xs2.extend(non_personas)
return '\n'.join(xs2)

def setup_data(self, path):
logging.info(f"loading normalized fbdialog data: {path}")
exs_counter = 0
for (text, labels, reward, candidates), new_episode in super().setup_data(path):
if new_episode:
exs_counter = 0
if self.max_num_turns > 0 and exs_counter >= self.max_num_turns:
continue
text = self.normalize_replies(text)
labels = [self.normalize_replies(l) for l in labels]
candidates = [self.normalize_replies(c) for c in candidates]
exs_counter += 1
yield (text, labels, reward, candidates), new_episode


class NormalizedTeacher(NormalizedTeacherTrait, SelfOriginalTeacher):
pass


class NormalizedBothTeacher(NormalizedTeacherTrait, BothTeacher):
pass


class NormalizedTheirTeacher(NormalizedTeacherTrait, BothTeacher):
def normalize_replies(self, x):
xs = x.split('\n')
xs2 = []
for x in xs:
if x.startswith('your persona: '):
continue
elif x.startswith("partner's persona: "):
x = x[len("partner's persona: ") :]
x = normalize_reply(x)
x = "partner's persona: " + x
else:
x = normalize_reply(x)
xs2.append(x)
return '\n'.join(xs2)


class NormalizedNoneTeacher(NormalizedTeacherTrait, NoneTeacher):
pass


class DefaultTeacher(SelfOriginalTeacher):
pass

Expand Down