Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

[BB3] Option to ignore in session memories #4753

Merged
merged 5 commits into from
Aug 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion parlai/opt_presets/gen/opt_bb3.opt
Original file line number Diff line number Diff line change
Expand Up @@ -192,5 +192,6 @@
"include_prompt": false,
"knowledge_chunk_size": 100,
"max_prompt_len": 1912,
"all_vanilla_prompt": false
"all_vanilla_prompt": false,
"ignore_in_session_memories_mkm": false
}
3 changes: 2 additions & 1 deletion parlai/opt_presets/gen/opt_pt.opt
Original file line number Diff line number Diff line change
Expand Up @@ -192,5 +192,6 @@
"include_prompt": true,
"knowledge_chunk_size": 100,
"max_prompt_len": 1912,
"all_vanilla_prompt": false
"all_vanilla_prompt": false,
"ignore_in_session_memories_mkm": false
}
36 changes: 27 additions & 9 deletions projects/bb3/agents/opt_bb3_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,13 @@ def add_cmdline_args(
help='Number of times to retry on API request failures (< 0 for unlimited retry).',
)
parser.add_argument('--metaseq-server-timeout', default=20.0, type=float)
parser.add_argument(
'--ignore-in-session-memories-mkm',
type='bool',
default=False,
help='If true, we do not look at the in-session memories when '
'generating from the MKM',
)
return parser

def __init__(self, opt, shared=None):
Expand All @@ -173,6 +180,9 @@ def __init__(self, opt, shared=None):
self.dictionary = top_agent.dictionary
# continue
self.max_prompt_len = opt.get('max_prompt_len', PROMPT.MAX_PROMPT_LEN)
self.ignore_in_session_memories_mkm = opt.get(
'ignore_in_session_memories_mkm', False
)
self.search_agent = SearchAgent(
{
'server': self.opt.get('search_server', 'default'),
Expand Down Expand Up @@ -415,10 +425,10 @@ def batch_act_knowledge(
for module in Module:
obs = all_obs[module]
if module is Module.MEMORY_KNOWLEDGE and i in memory_indices:
memories = MemoryUtils.get_available_memory(
all_obs['raw'], self.dictionary
memories = MemoryUtils.maybe_reduce_memories(
all_obs['raw']['text'], available_memory[i], self.dictionary
)
memories = '\n'.join(available_memory[i])
memories = '\n'.join(memories)
new_prompt = self._check_and_limit_len(
obs['prompt'].replace(module.opt_pre_context_tok(), memories)
)
Expand Down Expand Up @@ -772,7 +782,15 @@ def batch_act(
for _ in range(len(observations))
]
# Step 1: determine whether we're searching or accessing memory
available_memory = [o['raw']['memories'] for o in observations]
all_memory = [o['raw']['memories'] for o in observations]
available_memory = [
MemoryUtils.get_available_memories(
o['raw']['memories'],
o['raw']['in_session_memories'],
self.ignore_in_session_memories_mkm,
)
for o in observations
]

batch_reply_sdm, search_indices = self.batch_act_decision(
observations,
Expand Down Expand Up @@ -866,7 +884,7 @@ def batch_act(
batch_reply_mgm_partner,
batch_reply_knowledge,
batch_reply_dialogue,
available_memory,
all_memory,
)
for i, reply in enumerate(batch_reply_final):
reply.force_set('id', 'BlenderBot3')
Expand Down Expand Up @@ -900,8 +918,8 @@ def self_observe(self, self_message: Message):
memory_candidate,
MemoryUtils.get_memory_prefix(person, self.MODEL_TYPE),
):
self.memories.append(
MemoryUtils.add_memory_prefix(
memory_candidate, person, self.MODEL_TYPE
)
memory_to_add = MemoryUtils.add_memory_prefix(
memory_candidate, person, self.MODEL_TYPE
)
self.memories.append(memory_to_add)
self.in_session_memories.add(memory_to_add)
18 changes: 10 additions & 8 deletions projects/bb3/agents/r2c2_bb3_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,7 @@ def __init__(self, opt, shared=None):
self.agents[Module.SEARCH_KNOWLEDGE] = agent

self.memories = []
self.in_session_memories = set()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to have this as a set? Why not rely on the memory deduplication algorithm we had before?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The set is just for quick O(1) membership checking in the `get_available_memories` function in `MemoryUtils`:

memory in list --> O(n)
memory in set --> O(1)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I see. I thought you had that for deduplication. That makes sense.

self.search_knowledge_responses = ['__SILENCE__']
self.memory_knowledge_responses = ['__SILENCE__']
self.contextual_knowledge_responses = ['__SILENCE__']
Expand Down Expand Up @@ -529,6 +530,7 @@ def reset(self, clones_only: bool = False):
self.contextual_knowledge_responses = ['__SILENCE__']
self.memory_knowledge_responses = ['__SILENCE__']
self.memories = []
self.in_session_memories = set()

def _construct_subagent_opts(self, opt: Opt):
"""
Expand Down Expand Up @@ -657,6 +659,7 @@ def observe(self, observation: Message) -> Dict[Module, Message]:

raw_observation = copy.deepcopy(observation)
raw_observation['memories'] = self.memories
raw_observation['in_session_memories'] = self.in_session_memories
observations['raw'] = raw_observation

if observation.get('episode_done'):
Expand Down Expand Up @@ -1402,15 +1405,14 @@ def self_observe(self, self_message: Message):
),
MemoryUtils.get_memory_prefix(person, self.MODEL_TYPE),
):
self.memories.append(
MemoryUtils.add_memory_prefix(
self_message[
f'{Module.MEMORY_GENERATOR.message_name()}_{person}'
],
person,
self.MODEL_TYPE,
)
memory_to_add = MemoryUtils.add_memory_prefix(
self_message[f'{Module.MEMORY_GENERATOR.message_name()}_{person}'],
person,
self.MODEL_TYPE,
)

self.memories.append(memory_to_add)
self.in_session_memories.add(memory_to_add)
observation = {
'text': clean_text(
self.agents[Module.SEARCH_KNOWLEDGE].history.get_history_str() or ''
Expand Down
39 changes: 31 additions & 8 deletions projects/bb3/agents/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import os
import string
import time
from typing import List, Tuple, Optional, Dict, Any
from typing import List, Tuple, Optional, Dict, Any, Set

from parlai.agents.ir_baseline.ir_baseline import score_match, MaxPriorityQueue
from parlai.core.dict import DictionaryAgent
Expand Down Expand Up @@ -332,8 +332,8 @@ def _build_query_representation(
return rep

@staticmethod
def get_available_memory(
observation: Message, dictionary: DictionaryAgent
def maybe_reduce_memories(
text: str, memories: List[str], dictionary: DictionaryAgent
) -> List[str]:
"""
TFIDF-Match memories with the textual input to reduce num memories.
Expand All @@ -347,17 +347,40 @@ def get_available_memory(
return - potentially shortened - list of memories
"""
new_memories = []
mems = observation['memories']
if not mems or len(mems) < 32: # 512 / 16, assuming 16 tokens max per memory
return mems
if (
not memories or len(memories) < 32
): # 512 / 16, assuming 16 tokens max per memory
return memories
mpq = MaxPriorityQueue(1000)
query = MemoryUtils._build_query_representation(observation['text'], dictionary)
for m in mems:
query = MemoryUtils._build_query_representation(text, dictionary)
for m in memories:
score = score_match(query, m, 0, dictionary)
mpq.add(m, score)
new_memories = list(reversed(mpq))[:32]
return new_memories

@staticmethod
def get_available_memories(
    memories: List[str],
    in_session_memories: Set[str],
    ignore_in_session_memories: bool,
) -> List[str]:
    """
    Return available memories.

    :param memories:
        list of all memories
    :param in_session_memories:
        set of memories generated within the current conversation session
    :param ignore_in_session_memories:
        whether to ignore memories generated within the session

    :return:
        a new list, in the original order, holding every memory when
        ``ignore_in_session_memories`` is False, and only the memories that
        are not in ``in_session_memories`` when it is True.
    """
    # The flag does not change per memory, so decide once up front instead
    # of re-testing it (and probing the set) for every element.
    if not ignore_in_session_memories:
        # Still hand back a fresh list so callers never alias the input.
        return list(memories)
    return [m for m in memories if m not in in_session_memories]


#################
# OPT API UTILS #
Expand Down
74 changes: 70 additions & 4 deletions tests/nightly/gpu/test_bb3.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,15 +411,81 @@ def test_memory_tfidf(self):
agent = create_agent(self.opt)
dictionary = agent.dictionary
memories = self.memories * 100
new_memories = MemoryUtils.get_available_memory(
{'text': 'I wish I could see my cats again!', 'memories': memories},
new_memories = MemoryUtils.maybe_reduce_memories(
'I wish I could see my cats again!',
memories,
dictionary,
)
assert "cats" in new_memories[0]
assert len(new_memories) <= 32
new_memories = MemoryUtils.get_available_memory(
{'text': 'I hope the horses are faster today!', 'memories': memories},
new_memories = MemoryUtils.maybe_reduce_memories(
'I hope the horses are faster today!',
memories,
dictionary,
)
assert "horses" in new_memories[0]
assert len(new_memories) <= 32


class TestIgnoreInSessionMemories(TestOptFtBase):
    """
    Tests for the ``ignore_in_session_memories_mkm`` option and the
    ``MemoryUtils.get_available_memories`` helper.
    """

    def _separate_conditioning_opt(self, ignore_in_session: bool = False):
        """
        Copy the base opt and set 'separate' knowledge conditioning.

        Each setting is written both at the top level and under 'override'.
        When ``ignore_in_session`` is True, ``ignore_in_session_memories_mkm``
        is switched on the same way.
        """
        settings = {'knowledge_conditioning': 'separate'}
        if ignore_in_session:
            settings['ignore_in_session_memories_mkm'] = True
        new_opt = copy.deepcopy(self.opt)
        for key, value in settings.items():
            new_opt[key] = value
            new_opt['override'][key] = value
        return new_opt

    def test_in_session_memories(self):
        keeper = create_agent(self._separate_conditioning_opt())
        ignorer = create_agent(
            self._separate_conditioning_opt(ignore_in_session=True)
        )
        memory_dialogue_key = Module.MEMORY_DIALOGUE.message_name()

        # first, check with normal messages
        keeper_acts, ignorer_acts = [], []
        for _turn in range(5):
            keeper.observe(self.message)
            ignorer.observe(self.message)
            keeper_acts.append(keeper.act())
            ignorer_acts.append(ignorer.act())

        # ignore first message for the default agent since there aren't any
        # memories yet
        for reply in keeper_acts[1:]:
            assert reply[memory_dialogue_key]
        for reply in ignorer_acts:
            assert not reply[memory_dialogue_key]

        # Check that in session memories is strict subset of memories
        # when using opening message
        keeper.reset()
        original_memories = copy.deepcopy(self.memories)
        keeper.observe(self.opening_message)
        keeper.act()
        for session_memory in keeper.in_session_memories:
            assert session_memory in keeper.memories
        for memory in keeper.memories:
            assert memory not in keeper.in_session_memories

        # set ignore in session memories to True; ensure that final returned
        # memories still have all the memories, but that we don't use the
        # memory module
        keeper.in_session_memories = set()
        keeper.ignore_in_session_memories_mkm = True
        keeper.observe(copy.deepcopy(self.message))
        final_act = keeper.act()
        returned_memories = final_act['memories']
        expected_memories = original_memories + list(keeper.in_session_memories)
        for memory in expected_memories:
            assert memory in returned_memories
        newly_added = len(returned_memories) - len(original_memories)
        assert len(keeper.in_session_memories) == newly_added

    def test_memory_utils(self):
        session_only = ['in session memory 1', 'in session memory 2']
        everything = self.memories + session_only
        session_set = set(session_only)

        # flag off: nothing is filtered out
        kept = MemoryUtils.get_available_memories(
            everything, session_set, ignore_in_session_memories=False
        )
        assert kept == everything

        # flag on: only the pre-existing memories survive
        filtered = MemoryUtils.get_available_memories(
            everything, session_set, ignore_in_session_memories=True
        )
        assert filtered == self.memories