From 06731b8e4075a075600079873a1285aa1db147d3 Mon Sep 17 00:00:00 2001 From: Jing Xu Date: Wed, 23 Feb 2022 11:41:47 -0800 Subject: [PATCH 1/6] reply clean --- projects/blenderbot2/agents/blenderbot2.py | 40 ++++++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/projects/blenderbot2/agents/blenderbot2.py b/projects/blenderbot2/agents/blenderbot2.py index a9d897f24af..6be6c2da572 100644 --- a/projects/blenderbot2/agents/blenderbot2.py +++ b/projects/blenderbot2/agents/blenderbot2.py @@ -13,6 +13,7 @@ The Memory Decoder examines the context and generates memories to write to the long-term memory module. """ +import re import copy import torch import torch.nn @@ -253,6 +254,12 @@ def add_cmdline_args( help='filter input to the global knowledge retriever such that any utterance containing ' 'the phrase will not be given as input.', ) + bb2_group.add_argument( + '--clean-reply', + type=bool, + default=False, + help='filter reply during self_observe', + ) q_gen_group = parser.add_argument_group('BlenderBot2 Query Generator Args') q_gen_group.add_argument( '--query-generator-ignore-phrase', @@ -898,6 +905,39 @@ def compute_loss( else: return loss + def _clean_text(self, txt): + cleaned_txt = re.sub(r'_[\S]*unsafe_*', '', txt, flags=re.IGNORECASE) + return cleaned_txt.strip() + + def self_observe(self, self_message: Message) -> None: + """ + Observe one's own utterance. + Override TorchAgent.self_observe with the optional cleaned text + + :param self_message: + The message corresponding to the output from batch_act. 
+ """ + + if ( + self.opt('clean_reply', False) + or self.observation['episode_done'] # last example in the episode + or use_reply == 'none' # not including our own responses anyway + or ( + use_reply == 'label' + and any([x in self.observation for l in ['labels', 'eval_labels']]) + ) # has true label + ): + return super().self_observe(self_message) + + # otherwise, we use the CLEANED last output the model generated + if self_message is not None: + last_reply = self_message['text'] + clean_reply = self._clean_text(last_reply) + self.history.add_reply(clean_reply) + return + + raise RuntimeError("Unexpected case in self_observe.") + class BlenderBot2FidAgent(FidAgent, BlenderBot2RagAgent): model: BlenderBot2FidModel From 466490be31a7604ba6bc46e1850b8f67bf6c73d5 Mon Sep 17 00:00:00 2001 From: Jing Xu Date: Wed, 23 Feb 2022 11:50:29 -0800 Subject: [PATCH 2/6] comments and typos --- projects/blenderbot2/agents/blenderbot2.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/projects/blenderbot2/agents/blenderbot2.py b/projects/blenderbot2/agents/blenderbot2.py index 6be6c2da572..eb9043d0685 100644 --- a/projects/blenderbot2/agents/blenderbot2.py +++ b/projects/blenderbot2/agents/blenderbot2.py @@ -254,12 +254,6 @@ def add_cmdline_args( help='filter input to the global knowledge retriever such that any utterance containing ' 'the phrase will not be given as input.', ) - bb2_group.add_argument( - '--clean-reply', - type=bool, - default=False, - help='filter reply during self_observe', - ) q_gen_group = parser.add_argument_group('BlenderBot2 Query Generator Args') q_gen_group.add_argument( '--query-generator-ignore-phrase', @@ -917,14 +911,14 @@ def self_observe(self, self_message: Message) -> None: :param self_message: The message corresponding to the output from batch_act. 
""" + use_reply = self.opt.get('use_reply', 'label') if ( - self.opt('clean_reply', False) - or self.observation['episode_done'] # last example in the episode + self.observation['episode_done'] # last example in the episode or use_reply == 'none' # not including our own responses anyway or ( use_reply == 'label' - and any([x in self.observation for l in ['labels', 'eval_labels']]) + and any([l in self.observation for l in ['labels', 'eval_labels']]) ) # has true label ): return super().self_observe(self_message) From 70349c314cbeef1449d73c8f739f111ebfd5b7a1 Mon Sep 17 00:00:00 2001 From: Jing Xu Date: Wed, 23 Feb 2022 12:43:40 -0800 Subject: [PATCH 3/6] black --- projects/blenderbot2/agents/blenderbot2.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/projects/blenderbot2/agents/blenderbot2.py b/projects/blenderbot2/agents/blenderbot2.py index eb9043d0685..a72b331c353 100644 --- a/projects/blenderbot2/agents/blenderbot2.py +++ b/projects/blenderbot2/agents/blenderbot2.py @@ -905,8 +905,7 @@ def _clean_text(self, txt): def self_observe(self, self_message: Message) -> None: """ - Observe one's own utterance. - Override TorchAgent.self_observe with the optional cleaned text + Observe one's own utterance. Override TorchAgent.self_observe with cleaned text. :param self_message: The message corresponding to the output from batch_act. 
# History subclasses for BlenderBot2 (projects/blenderbot2/agents/blenderbot2.py).
# Consolidated, corrected form of PATCH 4/6: the first revision defaulted
# ``add_cleaned_reply_to_history`` to True, contradicting the
# ``--add-cleaned-reply-to-history`` command-line default of False added in
# the same patch (cleaning would run even when the flag was never set).
# PATCH 5/6 fixed the default to False, matched here.

class HistoryCleanReply(History):
    """
    History subclass that optionally cleans the bot's own replies before they
    are written into the conversation history.

    Subclasses must implement ``_clean_text`` with the concrete cleaning
    logic; cleaning is applied only when ``add_cleaned_reply_to_history``
    is enabled in the opt.
    """

    def __init__(
        self,
        opt,
        field='text',
        maxlen=None,
        size=-1,
        p1_token='__p1__',
        p2_token='__p2__',
        dict_agent=None,
    ):
        super().__init__(
            opt,
            field=field,
            maxlen=maxlen,
            size=size,
            p1_token=p1_token,
            p2_token=p2_token,
            dict_agent=dict_agent,
        )
        # Off by default, mirroring the --add-cleaned-reply-to-history flag.
        self.add_cleaned_reply_to_history = opt.get(
            'add_cleaned_reply_to_history', False
        )

    @abstractmethod
    def _clean_text(self, txt):
        """
        Clean text; to be overridden with custom logic.
        """

    def add_reply(self, text):
        """
        Add the model's reply to history, cleaning it first when enabled.
        """
        clean_text = text
        if self.add_cleaned_reply_to_history:
            clean_text = self._clean_text(text)
        super().add_reply(clean_text)


class HistoryCleanUnsafeToken(HistoryCleanReply):
    """
    Override the history _clean_text to filter out special tokens like
    _potentially_unsafe.
    """

    def _clean_text(self, txt):
        cleaned_txt = re.sub(r'_[\S]*unsafe_*', '', txt, flags=re.IGNORECASE)
        return cleaned_txt.strip()
- """ - use_reply = self.opt.get('use_reply', 'label') - - if ( - self.observation['episode_done'] # last example in the episode - or use_reply == 'none' # not including our own responses anyway - or ( - use_reply == 'label' - and any([l in self.observation for l in ['labels', 'eval_labels']]) - ) # has true label - ): - return super().self_observe(self_message) - - # otherwise, we use the CLEANED last output the model generated - if self_message is not None: - last_reply = self_message['text'] - clean_reply = self._clean_text(last_reply) - self.history.add_reply(clean_reply) - return - - raise RuntimeError("Unexpected case in self_observe.") - class BlenderBot2FidAgent(FidAgent, BlenderBot2RagAgent): model: BlenderBot2FidModel diff --git a/tests/nightly/gpu/test_bb2.py b/tests/nightly/gpu/test_bb2.py index de4b0119e20..25293448d4b 100644 --- a/tests/nightly/gpu/test_bb2.py +++ b/tests/nightly/gpu/test_bb2.py @@ -6,6 +6,7 @@ import copy import torch.cuda import unittest +from parlai.core.agents import create_agent import parlai.utils.testing as testing_utils @@ -203,6 +204,39 @@ def test_rag(self): ) +@testing_utils.skipUnlessGPU +@unittest.skipIf(LOCAL, "Skipping Test because its slow and mem intensive") +class TestBB2CleanText(unittest.TestCase): + SPEICIAL_TOKEN = '_POTENTIALLY_UNSAFE__' + + def test_bb2_history(self): + """ + Test out-of-the-box BB2 generation. + """ + opt = copy.deepcopy(common_opt) + opt.update( + { + 'model_file': ZOO_BB2, + 'override': { + 'search_server': SEARCH_SERVER, + 'add_cleaned_reply_to_history': True, + }, + } + ) + bb2 = create_agent(opt) + + text_with_safety_token = f"Don't have a cow, Man! 
{self.SPEICIAL_TOKEN}" + obs = {'text': text_with_safety_token} + bb2.observe(obs) + assert self.SPEICIAL_TOKEN in bb2.history.get_history_str() + + bb2.history.reset() + obs = {'text': "I am Groot"} + bb2.observe(obs) + bb2.history.add_reply(text_with_safety_token) + assert self.SPEICIAL_TOKEN not in bb2.history.get_history_str() + + @testing_utils.skipUnlessGPU @unittest.skipIf(LOCAL, "Skipping Test because its slow and mem intensive") class TestBB2AdditionalTruncation(unittest.TestCase): From 2a7a502879d93052d3779bcaba19f62bd02d20a9 Mon Sep 17 00:00:00 2001 From: Jing Xu Date: Tue, 22 Mar 2022 11:47:17 -0700 Subject: [PATCH 5/6] clean --- projects/blenderbot2/agents/blenderbot2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/projects/blenderbot2/agents/blenderbot2.py b/projects/blenderbot2/agents/blenderbot2.py index 74b805ebf7c..fe8d2b02ec4 100644 --- a/projects/blenderbot2/agents/blenderbot2.py +++ b/projects/blenderbot2/agents/blenderbot2.py @@ -80,7 +80,7 @@ def __init__( dict_agent=dict_agent, ) self.add_cleaned_reply_to_history = opt.get( - 'add_cleaned_reply_to_history', True + 'add_cleaned_reply_to_history', False ) @abstractmethod From d9840133a6b8ae22140a99b4eac8a5a2250ad1ad Mon Sep 17 00:00:00 2001 From: Jing Xu Date: Tue, 22 Mar 2022 12:10:45 -0700 Subject: [PATCH 6/6] black --- projects/blenderbot2/agents/blenderbot2.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/projects/blenderbot2/agents/blenderbot2.py b/projects/blenderbot2/agents/blenderbot2.py index fe8d2b02ec4..68cbbf85279 100644 --- a/projects/blenderbot2/agents/blenderbot2.py +++ b/projects/blenderbot2/agents/blenderbot2.py @@ -86,7 +86,7 @@ def __init__( @abstractmethod def _clean_text(self, txt): """ - Clean text to be override with custom logic + Clean text to be override with custom logic. 
""" def add_reply(self, text): @@ -98,7 +98,8 @@ def add_reply(self, text): class HistoryCleanUnsafeToken(HistoryCleanReply): """ - Override the history _clean_text to filter out special tokens like _potentially_unsafe + Override the history _clean_text to filter out special tokens like + _potentially_unsafe. """ def _clean_text(self, txt):