From 1635cfda4b081338d19c643fc3ddf9907e0e051a Mon Sep 17 00:00:00 2001
From: Jing
Date: Fri, 25 Mar 2022 13:55:10 -0400
Subject: [PATCH] [BlenderBot2] clean up reply during interactive mode (#4379)

* reply clean

* comments and typos

* black

* bb2 agent history

* clean

* black
---
 projects/blenderbot2/agents/blenderbot2.py | 62 +++++++++++++++++++++-
 tests/nightly/gpu/test_bb2.py              | 34 ++++++++++++
 2 files changed, 95 insertions(+), 1 deletion(-)

diff --git a/projects/blenderbot2/agents/blenderbot2.py b/projects/blenderbot2/agents/blenderbot2.py
index 0b691e5c714..4d54e03a91e 100644
--- a/projects/blenderbot2/agents/blenderbot2.py
+++ b/projects/blenderbot2/agents/blenderbot2.py
@@ -13,6 +13,8 @@
 The Memory Decoder examines the context and generates memories to write to the
 long-term memory module.
 """
+from abc import abstractmethod
+import re
 import copy
 import torch
 import torch.nn
@@ -32,7 +34,7 @@
 from parlai.core.metrics import AverageMetric
 from parlai.core.opt import Opt
 from parlai.core.params import ParlaiParser
-from parlai.core.torch_agent import Batch
+from parlai.core.torch_agent import Batch, History
 from parlai.tasks.wizard_of_internet.constants import (
     SELECTED_DOCS,
     SELECTED_DOCS_TITLES,
@@ -57,6 +59,54 @@
 ZOO_MEMORY_DECODER = 'zoo:blenderbot2/memory_decoder/model'
 
 
+class HistoryCleanReply(History):
+    def __init__(
+        self,
+        opt,
+        field='text',
+        maxlen=None,
+        size=-1,
+        p1_token='__p1__',
+        p2_token='__p2__',
+        dict_agent=None,
+    ):
+        super().__init__(
+            opt,
+            field=field,
+            maxlen=maxlen,
+            size=size,
+            p1_token=p1_token,
+            p2_token=p2_token,
+            dict_agent=dict_agent,
+        )
+        self.add_cleaned_reply_to_history = opt.get(
+            'add_cleaned_reply_to_history', False
+        )
+
+    @abstractmethod
+    def _clean_text(self, txt):
+        """
+        Clean text; subclasses override this with custom logic.
+        """
+
+    def add_reply(self, text):
+        clean_text = text
+        if self.add_cleaned_reply_to_history:
+            clean_text = self._clean_text(text)
+        super().add_reply(clean_text)
+
+
+class HistoryCleanUnsafeToken(HistoryCleanReply):
+    """
+    Override the history _clean_text to filter out special tokens like
+    _potentially_unsafe.
+    """
+
+    def _clean_text(self, txt):
+        cleaned_txt = re.sub(r'_[\S]*unsafe_*', '', txt, flags=re.IGNORECASE)
+        return cleaned_txt.strip()
+
+
 class BlenderBot2ModelTypeMixin(RagModelInterface):
     """
     Override Normal RAG Model Types, in case we retrieve from both memory and search.
@@ -339,6 +389,12 @@ def add_cmdline_args(
             hidden=True,
             help='model file for memory writer',
         )
+        bb2_group.add_argument(
+            '--add-cleaned-reply-to-history',
+            type='bool',
+            default=False,
+            help='whether to add the cleaned BB2 generated text, without any special tokens, to its history',
+        )
         memory_decoder = parser.add_argument_group('BlenderBot2 Memory Decoder Args')
         memory_decoder.add_argument(
             '--memory-decoder-key',
@@ -391,6 +447,10 @@ def add_cmdline_args(
         )
         return parser
 
+    @classmethod
+    def history_class(cls):
+        return HistoryCleanUnsafeToken
+
     @property
     def rag_model_type(self) -> str:
         return self._rag_model_type
diff --git a/tests/nightly/gpu/test_bb2.py b/tests/nightly/gpu/test_bb2.py
index de4b0119e20..25293448d4b 100644
--- a/tests/nightly/gpu/test_bb2.py
+++ b/tests/nightly/gpu/test_bb2.py
@@ -6,6 +6,7 @@
 import copy
 import torch.cuda
 import unittest
+from parlai.core.agents import create_agent
 
 import parlai.utils.testing as testing_utils
 
@@ -203,6 +204,39 @@ def test_rag(self):
         )
 
 
+@testing_utils.skipUnlessGPU
+@unittest.skipIf(LOCAL, "Skipping test because it's slow and memory intensive")
+class TestBB2CleanText(unittest.TestCase):
+    SPECIAL_TOKEN = '_POTENTIALLY_UNSAFE__'
+
+    def test_bb2_history(self):
+        """
+        Test that special tokens are cleaned from BB2's reply before it enters history.
+        """
+        opt = copy.deepcopy(common_opt)
+        opt.update(
+            {
+                'model_file': ZOO_BB2,
+                'override': {
+                    'search_server': SEARCH_SERVER,
+                    'add_cleaned_reply_to_history': True,
+                },
+            }
+        )
+        bb2 = create_agent(opt)
+
+        text_with_safety_token = f"Don't have a cow, Man! {self.SPECIAL_TOKEN}"
+        obs = {'text': text_with_safety_token}
+        bb2.observe(obs)
+        assert self.SPECIAL_TOKEN in bb2.history.get_history_str()
+
+        bb2.history.reset()
+        obs = {'text': "I am Groot"}
+        bb2.observe(obs)
+        bb2.history.add_reply(text_with_safety_token)
+        assert self.SPECIAL_TOKEN not in bb2.history.get_history_str()
+
+
 @testing_utils.skipUnlessGPU
 @unittest.skipIf(LOCAL, "Skipping Test because its slow and mem intensive")
 class TestBB2AdditionalTruncation(unittest.TestCase):
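
Reviewer note: a minimal standalone sketch of the reply-cleaning behavior that
HistoryCleanUnsafeToken._clean_text implements; the sample string is
illustrative and not taken from the test suite:

    import re

    def clean_reply(txt):
        # Mirror HistoryCleanUnsafeToken._clean_text: drop tokens such as
        # _POTENTIALLY_UNSAFE__ (matched case-insensitively), then trim
        # surrounding whitespace.
        return re.sub(r'_[\S]*unsafe_*', '', txt, flags=re.IGNORECASE).strip()

    print(clean_reply("Don't have a cow, Man! _POTENTIALLY_UNSAFE__"))
    # -> "Don't have a cow, Man!"

With the patch applied, cleaning only happens when the new flag is set, e.g.
(assuming the usual BB2 interactive invocation with a reachable search server;
replace <server> with your own endpoint):

    parlai interactive \
        --model-file zoo:blenderbot2/blenderbot2_3B/model \
        --search-server <server> \
        --add-cleaned-reply-to-history true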