Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Self chat from messages in file #3580

Merged
merged 2 commits into from
Apr 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions parlai/scripts/self_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ def setup_args(parser=None):
default=False,
help='Automatically seed conversation with messages from task dataset.',
)
parser.add_argument(
'--seed-messages-from-file',
default=None,
help='If specified, loads newline-separated strings from the file as conversation starters.',
)
parser.add_argument(
'--outfile', type=str, default=None, help='File to save self chat logs'
)
Expand Down
9 changes: 9 additions & 0 deletions parlai/tasks/self_chat/worlds.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,13 @@ def load_openers(opt) -> Optional[List[str]]:
return list(openers)


def load_openers_from_file(filepath: str) -> List[str]:
openers = []
with open(filepath, 'r') as f:
openers = [l.strip() for l in f]
return openers


class SelfChatWorld(DialogPartnerWorld):
def __init__(self, opt, agents, shared=None):
super().__init__(opt, agents, shared)
Expand Down Expand Up @@ -81,6 +88,8 @@ def init_openers(self) -> None:
"""
if self.opt.get('seed_messages_from_task'):
self._openers = load_openers(self.opt)
elif self.opt.get('seed_messages_from_file'):
self._openers = load_openers_from_file(self.opt['seed_messages_from_file'])

def get_openers(self, episode_num: int) -> Optional[List[str]]:
"""
Expand Down
13 changes: 12 additions & 1 deletion tests/tasks/self_chat/test_worlds.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,12 @@
from parlai.agents.repeat_label.repeat_label import RepeatLabelAgent
from parlai.core.worlds import create_task
from parlai.scripts.display_data import setup_args
from parlai.tasks.self_chat.worlds import SelfChatWorld as SelfChatBaseWorld
from parlai.tasks.self_chat.worlds import (
load_openers_from_file,
SelfChatWorld as SelfChatBaseWorld,
)

from tempfile import NamedTemporaryFile
import unittest
from unittest.mock import MagicMock

Expand Down Expand Up @@ -61,6 +65,13 @@ def assert_contexts_match(contexts):
assert_contexts_match(['you are a seal', 'you are an ostrich'])
assert_contexts_match([])

def test_load_openers_from_file(self):
with NamedTemporaryFile() as tmpfile:
tmpfile.write(b'hey\nhowdy')
tmpfile.seek(0)
op = load_openers_from_file(tmpfile.name)
self.assertListEqual(op, ['hey', 'howdy'])


if __name__ == '__main__':
unittest.main()
19 changes: 19 additions & 0 deletions tests/test_self_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,11 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from tempfile import NamedTemporaryFile
import unittest

from parlai.scripts.self_chat import SelfChat
import parlai.utils.testing as testing_utils


class TestSelfChat(unittest.TestCase):
Expand All @@ -30,3 +33,19 @@ def test_no_plain_teacher(self):

with self.assertRaises(RuntimeError):
DisplayData.main(task='self_chat')

def test_seed_messages_from_file(self):
with testing_utils.capture_output() as output:
with NamedTemporaryFile() as tmpfile:
tmpfile.write(b'howdy\nunique message')
tmpfile.seek(0)
SelfChat.main(
model='fixed_response',
fixed_response='hi',
seed_messages_from_file=tmpfile.name,
num_self_chats=10,
selfchat_max_turns=2,
)
output = output.getvalue()
assert 'howdy' in output
assert 'unique message' in output