Skip to content
This repository was archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
Implementing respond() and batch_respond() -- agent convenience funct… (
Browse files Browse the repository at this point in the history
#3775)

* Implementing respond() and batch_respond() -- agent convenience functions.

* Added self_observe() to Agent class and fixed lint issues.

* Remove self_observe() from Agent base class.

* Implement batch_act() for Repeat Query agent. (#3776)
  • Loading branch information
kauterry authored Jul 14, 2021
1 parent f0a37fe commit ef1511f
Show file tree
Hide file tree
Showing 6 changed files with 311 additions and 4 deletions.
16 changes: 14 additions & 2 deletions docs/source/tutorial_tipsntricks.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,18 @@ parlai display_data --task light_dialog:light_label_type=speech,light_dialog:lig
That is, by adding a colon ":" followed by the flag name, an equals
sign, and the value. You can add multiple flags, all separated by ":".

Agent Convenience Functions
----------
Tip: Having implemented `batch_act()` and `act()`, you can make use of the agent convenience functions `batch_respond()` and `respond()` which provide the agent's response to messages by internally calling `batch_act()` and `act()` respectively. The function signatures are as follows:

```python
def respond(self, text_or_message: Union[str, Message], **other_message_fields) -> str:
pass

def batch_respond(self, messages: List[Message]) -> List[str]:
pass
```

Self-Chats
----------

Expand All @@ -94,7 +106,7 @@ This will generate 10 self-chats between 2 poly-encoder models with persona cont
Flags to generate and store the self-chat:

- `--num-self-chats` specify the number of self-chats to generate (1 by default).
- `--selfchat-max-turns` specify the number of self-chat turns (6 by default), including context turn, seeded-utterance turns. Some self-chat world includes context information (such as persona; Wizard of Wikipedia(WoW) topics) in addition to the model utterances.
- `--selfchat-max-turns` specify the number of self-chat turns (6 by default), including context turn, seeded-utterance turns. Some self-chat world includes context information (such as persona; Wizard of Wikipedia(WoW) topics) in addition to the model utterances.
- `--selfchat-task` specify whether to create a self-chat version of the task. If True (by default), it allows for loading contexts and openers that seed the self-chat.
- `--outfile` specify file to save self-chat logs.
- `--save-format` specify the format to save self-chat logs in. Use `conversations` for jsonl format, or `parlai` for text format (`conversations` by default).
Expand Down Expand Up @@ -149,7 +161,7 @@ This handy script can prettify the display of json file of chats

```bash
# Display conversation in HTML format.
python parlai/scripts/convo_render.py -i projects/wizard_of_wikipedia/chat_example1.jsonl -o /tmp/chat.html
python parlai/scripts/convo_render.py -i projects/wizard_of_wikipedia/chat_example1.jsonl -o /tmp/chat.html
```

Some additional flags that can be used for convo-render:
Expand Down
15 changes: 13 additions & 2 deletions docs/source/tutorial_worlds.md
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ We continue with the implementation of parley:
# produce a model response
model_act = self.model_copies[i].act()
# compute any metrics of the response
self.teacher_copies[i].observe(model_act)
self.teacher_copies[i].observe(model_act)
```

<center>
Expand Down Expand Up @@ -263,6 +263,18 @@ Tip: if you implement `batch_act()`, your `act()` method can just call
list of length 1.
:::

:::{tip} Agent Convenience Functions
Tip: Having implemented `batch_act()` and `act()`, you can make use of the agent convenience functions `batch_respond()` and `respond()` which provide the agent's response to messages by internally calling `batch_act()` and `act()` respectively. The function signatures are as follows:

```python
def respond(self, text_or_message: Union[str, Message], **other_message_fields) -> str:
pass

def batch_respond(self, messages: List[Message]) -> List[str]:
pass
```
:::

## Dynamic Batching

:::{note}
Expand Down Expand Up @@ -372,4 +384,3 @@ from 4 to only 2! This is the trick of how dynamic batching can provide
:::{tip}
You can use this mode with `-dynb full` or `--dynamic-batching full`.
:::

9 changes: 9 additions & 0 deletions parlai/agents/repeat_query/repeat_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,12 @@ def act(self):
reply['text'] = 'Nothing to repeat yet.'
reply['episode_done'] = False
return Message(reply)

def batch_act(self, observations):
batch_reply = []
original_obs = self.observation
for obs in observations:
self.observation = obs
batch_reply.append(self.act())
self.observation = original_obs
return batch_reply
66 changes: 66 additions & 0 deletions parlai/core/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,12 @@
"""

import copy
from typing import List, Union

from parlai.core.build_data import modelzoo_path
from parlai.core.loader import load_agent_module
from parlai.core.loader import register_agent # noqa: F401
from parlai.core.message import Message
from parlai.core.opt import Opt
from parlai.utils.misc import warn_once
import parlai.utils.logging as logging
Expand Down Expand Up @@ -154,6 +156,70 @@ def shutdown(self):
"""
pass

def respond(
self, text_or_message: Union[str, Message], **other_message_fields
) -> str:
"""
An agent convenience function which calls the act() and provides a string
response to a text or message field.
:param Union[str, Message] text_or_message:
A string for the 'text' field or a message which MUST
comprise of the 'text' field apart from other fields.
:param kwargs other_message_fields:
Provide fields for the message in the form of keyowrd arguments.
:return:
Agent's response to the message.
:rtype:
str
"""
if isinstance(text_or_message, str):
observation = Message(text=text_or_message, **other_message_fields)
else:
observation = Message(**text_or_message, **other_message_fields)
if 'text' not in observation:
raise RuntimeError('The agent needs a \'text\' field in the message.')

if 'episode_done' not in observation:
observation['episode_done'] = True
agent = self.clone()
agent.observe(observation)
response = agent.act()
return response['text']

def batch_respond(self, messages: List[Message]) -> List[str]:
"""
An agent convenience function which calls the batch_act() and provides a batch
response to a list of messages.
:param List[Message] messages:
A list of messages each of which MUST comprise of the 'text' field
apart from other fields.
:return:
Agent's batch response to the messages.
:rtype:
List[str]
"""
observations = []
agents = []
for i, message in enumerate(messages):
if 'text' not in message:
raise RuntimeError(
'The agent needs a \'text\' field in the {}th message.'.format(i)
)
if 'episode_done' not in message:
message['episode_done'] = True
agent = self.clone()
agents.append(agent)
observations.append(agent.observe(message))
agent_acts = self.batch_act(observations)
response = []
for agent, resp in zip(agents, agent_acts):
if hasattr(agent, "self_observe"):
agent.self_observe(resp)
response.append(resp['text'])
return response

@classmethod
def upgrade_opt(cls, opt_from_disk: Opt):
"""
Expand Down
123 changes: 123 additions & 0 deletions tests/test_repeat_query_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Unit tests for RepeatQueryAgent.
"""

import unittest
from parlai.core.agents import create_agent
from parlai.core.message import Message


class TestRepeatQueryAgent(unittest.TestCase):
def test_respond(self):
"""
Tests respond() where the agent provides a string response to a single message.
"""
agent = create_agent(dict(model='repeat_query'))
message = Message(
{
'text': 'hi!',
'label': ['A'],
'episode_done': False,
'label_candidates': ['A', 'B', 'C'],
}
)
response = agent.respond(message)
self.assertEqual(response, 'hi!')
message = Message({'text': 'hello!', 'episode_done': False})
response = agent.respond(message, label=['A'])
self.assertEqual(response, 'hello!')
response = agent.respond(Message(text='no way!'), label=['A'])
self.assertEqual(response, 'no way!')
response = agent.respond('what\'s up?', episode_done=True)
self.assertEqual(response, 'what\'s up?')
response = agent.respond('hey there!')
self.assertEqual(response, 'hey there!')
response = agent.respond('')
self.assertEqual(response, 'Nothing to repeat yet.')
response = agent.respond(Message(episode_done=True), text='I feel infinite.')
self.assertEqual(response, 'I feel infinite.')

def test_respond_error(self):
"""
Tests respond() when it errors out.
"""
agent = create_agent(dict(model='repeat_query'))
error_message = 'The agent needs a \'text\' field in the message.'
with self.assertRaises(Exception) as context:
agent.respond(Message(episode_done=True))
self.assertEqual(str(context.exception), error_message)
with self.assertRaises(Exception) as context:
agent.respond({})
self.assertEqual(str(context.exception), error_message)
with self.assertRaises(Exception) as context:
agent.respond(Message())
self.assertEqual(str(context.exception), error_message)

def test_batch_respond(self):
"""
Tests batch_respond() of Repeat Query agent.
"""
agent = create_agent(dict(model='repeat_query'))
messages = [
Message({'text': 'hello!', 'episode_done': False}),
Message({'text': 'hi!', 'episode_done': False}),
Message({'text': 'what\'s up?', 'episode_done': False}),
Message({'text': '', 'episode_done': False}),
Message({'text': 'I feel infinite.', 'episode_done': False}),
]
expected_response = [
'hello!',
'hi!',
'what\'s up?',
'Nothing to repeat yet.',
'I feel infinite.',
]
batch_response = agent.batch_respond(messages)
self.assertEqual(batch_response, expected_response)

def test_batch_act(self):
"""
Tests batch_act() of Repeat Query agent.
"""
agent = create_agent(dict(model='repeat_query'))
observations = []
batch_reply = agent.batch_act(observations)
self.assertEqual(len(batch_reply), 0)
observations = [
Message({'text': 'hello!', 'episode_done': False}),
Message({'text': '', 'episode_done': False}),
Message({'episode_done': False}),
Message(),
None,
]
original_obs = "Hey there!"
agent.observe(original_obs)
self.assertEqual(agent.observation, original_obs)
batch_reply = agent.batch_act(observations)
# Make sure original observation doesn't change.
self.assertEqual(agent.observation, original_obs)
self.assertEqual(len(batch_reply[0]), 3)
self.assertEqual(batch_reply[0]['text'], 'hello!')
self.assertEqual(batch_reply[0]['episode_done'], False)
self.assertEqual(batch_reply[0]['id'], 'RepeatQueryAgent')
self.assertEqual(len(batch_reply[1]), 3)
self.assertEqual(batch_reply[1]['text'], 'Nothing to repeat yet.')
self.assertEqual(batch_reply[1]['episode_done'], False)
self.assertEqual(batch_reply[1]['id'], 'RepeatQueryAgent')
self.assertEqual(len(batch_reply[2]), 3)
self.assertEqual(batch_reply[2]['text'], "I don't know")
self.assertEqual(batch_reply[2]['episode_done'], False)
self.assertEqual(batch_reply[2]['id'], 'RepeatQueryAgent')
self.assertEqual(len(batch_reply[3]), 3)
self.assertEqual(batch_reply[3]['text'], "I don't know")
self.assertEqual(batch_reply[3]['episode_done'], False)
self.assertEqual(batch_reply[3]['id'], 'RepeatQueryAgent')
self.assertEqual(len(batch_reply[4]), 2)
self.assertEqual(batch_reply[4]['text'], 'Nothing to repeat yet.')
self.assertEqual(batch_reply[4]['episode_done'], False)
86 changes: 86 additions & 0 deletions tests/test_torch_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,6 +896,92 @@ def test_batch_act(self):
for i in range(len(obs_elabs_vecs)):
self.assertIn('Evaluating {}'.format(i), reply[i]['text'])

def test_respond(self):
"""
Tests respond() in the base Agent class, where the agent provides
a string response to a single message.
"""
agent = get_agent()
message = Message(
{
'text': "It's only a flesh wound.",
'labels': ['Yield!'],
'episode_done': True,
}
)
response = agent.respond(message)
self.assertEqual(response, 'Training 0!')
message = Message(
{
'text': "It's only a flesh wound.",
'eval_labels': ['Yield!'],
'episode_done': True,
}
)
response = agent.respond(message)
self.assertIn('Evaluating 0', response)

def test_batch_respond(self):
"""
Tests batch_respond() in the base Agent class, where the agent provides
a batch response to a batch of messages.
"""
agent = get_agent()

obs_labs = [
Message(
{
'text': "It's only a flesh wound.",
'labels': ['Yield!'],
'episode_done': True,
}
),
Message(
{
'text': 'The needs of the many outweigh...',
'labels': ['The needs of the few.'],
'episode_done': True,
}
),
Message(
{
'text': 'Hello there.',
'labels': ['General Kenobi.'],
'episode_done': True,
}
),
]
response = agent.batch_respond(obs_labs)
for i, resp in enumerate(response):
self.assertEqual(resp, 'Training {}!'.format(i))

obs_elabs = [
Message(
{
'text': "It's only a flesh wound.",
'eval_labels': ['Yield!'],
'episode_done': True,
}
),
Message(
{
'text': 'The needs of the many outweigh...',
'eval_labels': ['The needs of the few.'],
'episode_done': True,
}
),
Message(
{
'text': 'Hello there.',
'eval_labels': ['General Kenobi.'],
'episode_done': True,
}
),
]
response = agent.batch_respond(obs_elabs)
for i, resp in enumerate(response):
self.assertIn('Evaluating {}'.format(i), resp)

def test_interactive_mode(self):
"""
Test if conversation history is destroyed in MTurk mode.
Expand Down

0 comments on commit ef1511f

Please sign in to comment.