Skip to content
This repository was archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
Self chat from messages in file (#3580)
Browse files Browse the repository at this point in the history
* generate self chat from messages in file

* tests
  • Loading branch information
spencerp authored Apr 14, 2021
1 parent 6216f34 commit cac3675
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 1 deletion.
5 changes: 5 additions & 0 deletions parlai/scripts/self_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ def setup_args(parser=None):
default=False,
help='Automatically seed conversation with messages from task dataset.',
)
parser.add_argument(
'--seed-messages-from-file',
default=None,
help='If specified, loads newline-separated strings from the file as conversation starters.',
)
parser.add_argument(
'--outfile', type=str, default=None, help='File to save self chat logs'
)
Expand Down
9 changes: 9 additions & 0 deletions parlai/tasks/self_chat/worlds.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@ def load_openers(opt) -> Optional[List[str]]:
return list(openers)


def load_openers_from_file(filepath: str) -> List[str]:
    """
    Load conversation openers from a newline-separated text file.

    Every line is stripped of surrounding whitespace. Lines that are empty
    after stripping (blank lines, or a stray trailing newline at the end of
    the file) are skipped so that they never become empty-string openers.

    :param filepath:
        path to a text file holding one opener per line.

    :return:
        the non-empty openers, in the order they appear in the file.

    :raises OSError:
        if the file cannot be opened for reading.
    """
    # Pin the encoding so the result does not depend on the platform locale.
    with open(filepath, 'r', encoding='utf-8') as f:
        stripped_lines = (line.strip() for line in f)
        return [opener for opener in stripped_lines if opener]


class SelfChatWorld(DialogPartnerWorld):
def __init__(self, opt, agents, shared=None):
super().__init__(opt, agents, shared)
Expand Down Expand Up @@ -81,6 +88,8 @@ def init_openers(self) -> None:
"""
if self.opt.get('seed_messages_from_task'):
self._openers = load_openers(self.opt)
elif self.opt.get('seed_messages_from_file'):
self._openers = load_openers_from_file(self.opt['seed_messages_from_file'])

def get_openers(self, episode_num: int) -> Optional[List[str]]:
"""
Expand Down
13 changes: 12 additions & 1 deletion tests/tasks/self_chat/test_worlds.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,12 @@
from parlai.agents.repeat_label.repeat_label import RepeatLabelAgent
from parlai.core.worlds import create_task
from parlai.scripts.display_data import setup_args
from parlai.tasks.self_chat.worlds import SelfChatWorld as SelfChatBaseWorld
from parlai.tasks.self_chat.worlds import (
load_openers_from_file,
SelfChatWorld as SelfChatBaseWorld,
)

from tempfile import NamedTemporaryFile
import unittest
from unittest.mock import MagicMock

Expand Down Expand Up @@ -61,6 +65,13 @@ def assert_contexts_match(contexts):
assert_contexts_match(['you are a seal', 'you are an ostrich'])
assert_contexts_match([])

def test_load_openers_from_file(self):
    """
    Check that openers written to a file, one per line, are read back as a list.
    """
    expected_openers = ['hey', 'howdy']
    with NamedTemporaryFile() as seed_file:
        seed_file.write(b'hey\nhowdy')
        # Rewind before the file is reopened by name; presumably this also
        # pushes the buffered bytes out so the reader sees them.
        seed_file.seek(0)
        loaded_openers = load_openers_from_file(seed_file.name)
        self.assertListEqual(loaded_openers, expected_openers)


if __name__ == '__main__':
unittest.main()
19 changes: 19 additions & 0 deletions tests/test_self_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from tempfile import NamedTemporaryFile
import unittest

from parlai.scripts.self_chat import SelfChat
import parlai.utils.testing as testing_utils


class TestSelfChat(unittest.TestCase):
Expand All @@ -30,3 +33,19 @@ def test_no_plain_teacher(self):

with self.assertRaises(RuntimeError):
DisplayData.main(task='self_chat')

def test_seed_messages_from_file(self):
    """
    Run self chat seeded from a file and check each seed shows up in the output.
    """
    with testing_utils.capture_output() as captured:
        with NamedTemporaryFile() as seed_file:
            seed_file.write(b'howdy\nunique message')
            seed_file.seek(0)
            chat_kwargs = {
                'model': 'fixed_response',
                'fixed_response': 'hi',
                'seed_messages_from_file': seed_file.name,
                'num_self_chats': 10,
                'selfchat_max_turns': 2,
            }
            SelfChat.main(**chat_kwargs)
    transcript = captured.getvalue()
    # Both seed lines from the file must appear somewhere in the captured output.
    for seed_message in ('howdy', 'unique message'):
        assert seed_message in transcript

0 comments on commit cac3675

Please sign in to comment.