Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
add 2 flags
Browse files Browse the repository at this point in the history
  • Loading branch information
Jing Xu committed Mar 12, 2021
1 parent c7f1d67 commit 5787711
Showing 1 changed file with 97 additions and 4 deletions.
101 changes: 97 additions & 4 deletions parlai/tasks/convai2/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
from .build import build
from parlai.utils.strings import normalize_reply
import parlai.utils.logging as logging
from parlai.core.params import ParlaiParser
from typing import Optional
from parlai.core.opt import Opt

import copy
import os
Expand Down Expand Up @@ -89,40 +92,130 @@ def __init__(self, opt, shared=None):


class NormalizedTeacherTrait(object):
@classmethod
def add_cmdline_args(
cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None
) -> ParlaiParser:
super().add_cmdline_args(parser, partial_opt)
agent = parser.add_argument_group('NormalizedBothTeacher Caption arguments')
agent.add_argument(
'--your-persona-first',
type='bool',
default=True,
help="whether to prepend your persona followed by their persona, True by default to be consistent with the BothTeach",
)
agent.add_argument(
'--max-num-turns',
type=int,
default=-1,
help="first X turns per episode to show. If -1 then the whole episode is shown",
)
return agent

def __init__(self, opt, shared=None):
self.max_num_turns = opt["max_num_turns"]
self.your_persona_first = opt["your_persona_first"]
super().__init__(opt, shared)

def normalize_replies(self, x):
xs = x.split('\n')
xs2 = []
your_personas = []
partner_personas = []
non_personas = []
for x in xs:
if x.startswith('your persona: '):
# Normalize the sentence appearing after 'your persona:'
x = x[len('your persona: ') :]
x = normalize_reply(x)
x = 'your persona: ' + x
your_personas.append(x)
elif x.startswith("partner's persona: "):
x = x[len("partner's persona: ") :]
x = normalize_reply(x)
x = "partner's persona: " + x
partner_personas.append(x)
else:
x = normalize_reply(x)

xs2.append(x)
non_personas.append(x)
xs2 = []
if self.your_persona_first:
xs2.extend(your_personas)
xs2.extend(partner_personas)
else:
xs2.extend(partner_personas)
xs2.extend(your_personas)
xs2.extend(non_personas)
return '\n'.join(xs2)

def setup_data(self, path):
logging.info(f"loading normalized fbdialog data: {path}")
exs_counter = 0
for (text, labels, reward, candidates), new_episode in super().setup_data(path):
if new_episode:
exs_counter = 0
if self.max_num_turns > 0 and exs_counter >= self.max_num_turns:
continue
text = self.normalize_replies(text)
labels = [self.normalize_replies(l) for l in labels]
candidates = [self.normalize_replies(c) for c in candidates]
exs_counter += 1
yield (text, labels, reward, candidates), new_episode


class NormalizedTheirTeacher(NormalizedTeacherTrait, BothTeacher):
def normalize_replies(self, x):
xs = x.split('\n')
xs2 = []
for x in xs:
if x.startswith('your persona: '):
continue
elif x.startswith("partner's persona: "):
x = x[len("partner's persona: ") :]
x = normalize_reply(x)
x = "partner's persona: " + x
else:
x = normalize_reply(x)
xs2.append(x)
return '\n'.join(xs2)


class NormalizedTeacher(NormalizedTeacherTrait, SelfOriginalTeacher):
pass


class NormalizedBothTeacher(NormalizedTeacherTrait, BothTeacher):
pass
def __init__(self, opt, shared=None):
self.your_persona_first = opt.get("your_persona_first", True)
super().__init__(opt, shared)

def normalize_replies(self, x):
xs = x.split('\n')
xs2 = []
your_personas = []
partner_personas = []
others = []
for x in xs:
if x.startswith('your persona: '):
x = x[len("your persona: ") :]
x = normalize_reply(x)
x = "your persona: " + x
your_personas.append(x)
elif x.startswith("partner's persona: "):
x = x[len("partner's persona: ") :]
x = normalize_reply(x)
x = "partner's persona: " + x
partner_personas.append(x)
else:
x = normalize_reply(x)
others.append(x)
if self.your_persona_first:
xs2.extend(your_personas)
xs2.extend(partner_personas)
else:
xs2.extend(partner_personas)
xs2.extend(your_personas)
xs2.extend(others)
return '\n'.join(xs2)


class NormalizedNoneTeacher(NormalizedTeacherTrait, NoneTeacher):
Expand Down

0 comments on commit 5787711

Please sign in to comment.