Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

[BlenderBot2] clean up reply during interactive mode #4379

Merged
merged 7 commits into from
Mar 25, 2022
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions projects/blenderbot2/agents/blenderbot2.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
The Memory Decoder examines the context and generates memories to write to
the long-term memory module.
"""
import re
import copy
import torch
import torch.nn
Expand Down Expand Up @@ -898,6 +899,38 @@ def compute_loss(
else:
return loss

def _clean_text(self, txt):
cleaned_txt = re.sub(r'_[\S]*unsafe_*', '', txt, flags=re.IGNORECASE)
return cleaned_txt.strip()

def self_observe(self, self_message: Message) -> None:
"""
Observe one's own utterance. Override TorchAgent.self_observe with cleaned text.

:param self_message:
The message corresponding to the output from batch_act.
"""
use_reply = self.opt.get('use_reply', 'label')

if (
self.observation['episode_done'] # last example in the episode
or use_reply == 'none' # not including our own responses anyway
jxmsML marked this conversation as resolved.
Show resolved Hide resolved
or (
use_reply == 'label'
and any([l in self.observation for l in ['labels', 'eval_labels']])
) # has true label
):
return super().self_observe(self_message)

# otherwise, we use the CLEANED last output the model generated
if self_message is not None:
last_reply = self_message['text']
clean_reply = self._clean_text(last_reply)
self.history.add_reply(clean_reply)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

perhaps we should just override the History class and do this logic there?

return

raise RuntimeError("Unexpected case in self_observe.")


class BlenderBot2FidAgent(FidAgent, BlenderBot2RagAgent):
model: BlenderBot2FidModel
Expand Down