Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

[BlenderBot2] A hybrid mode (skip search when needed) #4221

Merged
merged 9 commits into from
Dec 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ jobs:
- runtests:
cachename: teacher
marker: teacher
pytest_flags: -v -s

build_website:
executor: small_cpu37
Expand Down
21 changes: 20 additions & 1 deletion parlai/tasks/wizard_of_internet/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import parlai.utils.logging as logging
import parlai.tasks.wizard_of_internet.constants as CONST
from .build import build
import parlai.tasks.wizard_of_internet.mutators
import parlai.tasks.wizard_of_internet.mutators # noqa: F401


def get_dtype(opt):
Expand Down Expand Up @@ -350,6 +350,9 @@ class WizardDialogTeacher(WizardOfInternetBaseTeacher):
def __init__(self, opt, shared=None):
# Read the teacher-specific options before delegating to the base class.
# NOTE(review): presumably the parent constructor may already need these
# attributes (e.g. while loading data) -- confirm before reordering.
self.prepend_gold_knowledge = opt.get('prepend_gold_knowledge')
# Delimiter inserted between the prepended gold knowledge and the dialogue
# text; falls back to a newline when the option is absent.
self.gold_knowledge_delimiter = opt.get('gold_knowledge_delimiter', '\n')
# When truthy, teacher_setup_data sets message[CONST.SKIP_SEARCH] = True.
# No default is passed to opt.get, so this is None if the option is missing.
self.add_skip_search_if_gold_prepended = opt.get(
'add_skip_search_if_gold_prepended'
)
super().__init__(opt, shared=shared)
self.id = 'WizInternetWizardTeacher'

Expand All @@ -363,6 +366,12 @@ def add_cmdline_args(cls, parser: ParlaiParser, partial_opt=None) -> ParlaiParse
default=False,
help='If true, prepend text with checked sentences',
)
arg_group.add_argument(
'--add-skip-search-if-gold-prepended',
type='bool',
default=False,
help='If true, add skip search field when prepending text with checked sentences',
)
return parser

def custom_evaluation(
Expand Down Expand Up @@ -443,10 +452,16 @@ def teacher_setup_data(self, datafile) -> Message:
f' {self.gold_knowledge_delimiter} {text}'
),
)
if self.add_skip_search_if_gold_prepended:
message[CONST.SKIP_SEARCH] = True
yield message, episode_started


class WizardDialogGoldKnowledgeTeacher(WizardDialogTeacher):
def __init__(self, opt, shared=None):
# Delegate all setup to WizardDialogTeacher, then override the agent id so
# this teacher no longer reports the inherited 'WizInternetWizardTeacher'.
super().__init__(opt, shared=shared)
self.id = 'WizardDialogGoldKnowledgeTeacher'
Copy link
Contributor Author

@jxmsML jxmsML Dec 2, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah thanks for adding this, otherwise the train_model always flags this red warning ;)


@classmethod
def add_cmdline_args(cls, parser: ParlaiParser, partial_opt=None) -> ParlaiParser:
super().add_cmdline_args(parser, partial_opt)
Expand All @@ -459,6 +474,10 @@ class WizardDialogGoldKnowledgeNoDocsTeacher(WizardDialogGoldKnowledgeTeacher):
Prepends gold (selected knowledge) to the context, and removes the retrieved docs.
"""

def __init__(self, opt, shared=None):
# Delegate all setup to WizardDialogGoldKnowledgeTeacher, then give this
# subclass its own agent id instead of the parent's.
super().__init__(opt, shared=shared)
self.id = 'WizardDialogGoldKnowledgeNoDocsTeacher'

def additional_message_content(self, parlai_message: Message, action: Dict):
super().additional_message_content(parlai_message, action)
remove_retrieved_docs_from_message(parlai_message)
Expand Down
1 change: 1 addition & 0 deletions parlai/tasks/wizard_of_internet/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class HISTORY_TYPE:
IS_SEARCH_QUERY = 'is_search_query'
IS_LAST_SEARCH_QUERY = 'is_last_search_query'
LABELS = 'labels'
SKIP_SEARCH = 'skip_search'

# Message values
NO_SEARCH_QUERY_USED = '__no_search_used__'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1057,7 +1057,7 @@ acts:
episode_done: false
eval_labels:
- Have you seen any good movies lately?
id: WizInternetWizardTeacher
id: WizardDialogGoldKnowledgeTeacher
search_query: upcomming movies
text: "__knowledge__ __no_passages_used__ __endknowledge__ \n I live in New York,\
\ New York.\nI like going to the moives.\nI love to get the overpriced snacks."
Expand Down Expand Up @@ -4617,7 +4617,7 @@ acts:
- All four movies in the Toy Story series are great. Tom Hanks, Tim Allen, and
Don Rickles are all truly great actors and the creativity of the Pixar Animation
Studios team is really top notch.
id: WizInternetWizardTeacher
id: WizardDialogGoldKnowledgeTeacher
search_query: movie toy story
text: "__knowledge__ Tom Hanks, Tim Allen, Don Rickles | See full cast & crew\
\ »\nPixar Animation Studios - 1200 Park Avenue, Emeryville, California, USA\
Expand Down Expand Up @@ -8678,7 +8678,7 @@ acts:
work together on multiple projects. Further, Pixar, as a company, cares more
about creating high quality movies than they do making the most money possible.
This is their company's culture and the reason for their continuing success.
id: WizInternetWizardTeacher
id: WizardDialogGoldKnowledgeTeacher
search_query: pixar animation studios technology
text: "__knowledge__ Pixar makes money by unifying art and technology to produce\
\ original animated films that motivate audiences to buy movie tickets, DVDs,\
Expand Down Expand Up @@ -9831,7 +9831,7 @@ acts:
files on their server). But, the project was saved because a woman, who was
on maternit6y leave, had saved the movie onto her laptop. I am sure thy have
fixed their back up systems after this.
id: WizInternetWizardTeacher
id: WizardDialogGoldKnowledgeTeacher
search_query: laptop lost files pixar
text: "__knowledge__ Fun Facts › TV and Film Facts › Toy Story 2 Was Accidentally\
\ Deleted During Development and Almost Lost← NextRandomPrevious →\nDuring development\
Expand Down Expand Up @@ -11196,7 +11196,7 @@ acts:
- Pixar has produced 20 movies since its founding and all of the big names have
worked with them. Julia Louis-Dreyfus, Tom Holland, Jon Batiste, Phylicia Rashad,
Questlove, Tina Fey. Have you seen their newest movie Onward with Chris Pratt?
id: WizInternetWizardTeacher
id: WizardDialogGoldKnowledgeTeacher
search_query: pixar onward
text: "__knowledge__ ONWARD introduces young (and young-at-heart) viewers to a\
\ wonderful land where mythical creatures like elves and unicorns aren t merely\
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -600,7 +600,7 @@ acts:
of mystery ingredients into an extraordinary three-course meal. Course by course,
the chefs will be chopped from the competition until only one winner remains.
episode_done: false
id: WizInternetWizardTeacher
id: WizardDialogGoldKnowledgeTeacher
labels:
- ' Baskets? I don''t know anything about this show. Is it something about cooking?'
search_query: Chopped
Expand Down Expand Up @@ -1612,7 +1612,7 @@ acts:
- New Episodes Mondays 9|8c
- Watch a New Episode Now
episode_done: false
id: WizInternetWizardTeacher
id: WizardDialogGoldKnowledgeTeacher
labels:
- It sounds like quite a challenge! Do you watch every episode? Have you tried
to make the dishes yourself?
Expand Down Expand Up @@ -3692,7 +3692,7 @@ acts:
__selected-sentences__:
- Types of Cuisine From Around the World With Their Popular Foods
episode_done: false
id: WizInternetWizardTeacher
id: WizardDialogGoldKnowledgeTeacher
labels:
- 'I think it would be a big challenge to do what Chefs on the program do. Is
there a type of cuisine that they normally feature? Or it it cuisine from around
Expand Down Expand Up @@ -4688,7 +4688,7 @@ acts:
__selected-sentences__:
- Stop calling yourself a ‘foodie’
episode_done: false
id: WizInternetWizardTeacher
id: WizardDialogGoldKnowledgeTeacher
labels:
- That sounds interesting! I bet the show is pretty popular with foodies! Do you
consider yourself a foodie?
Expand All @@ -4713,7 +4713,7 @@ acts:
__selected-sentences__:
- __no_passages_used__
episode_done: true
id: WizInternetWizardTeacher
id: WizardDialogGoldKnowledgeTeacher
labels:
- All true foodies are picky. Do you like to cook in general, or do you prefer
to go to restaurants?
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ acts:
episode_done: false
eval_labels:
- Same here! What kind of books do you read?
id: WizInternetWizardTeacher
id: WizardDialogGoldKnowledgeTeacher
search_query: __no_search_used__
text: "__knowledge__ __no_passages_used__ __endknowledge__ \n I work as a freelance\
\ accountant.\nI enjoy reading books. "
Expand All @@ -39,7 +39,7 @@ acts:
episode_done: false
eval_labels:
- I am a big Harry Potter nerd! What do you do for work?
id: WizInternetWizardTeacher
id: WizardDialogGoldKnowledgeTeacher
search_query: __no_search_used__
text: "__knowledge__ __no_passages_used__ __endknowledge__ \n All fiction! With\
\ COVID this past summer, I read 16 books in 3 months. You?"
Expand Down Expand Up @@ -1386,7 +1386,7 @@ acts:
episode_done: false
eval_labels:
- I do taxes actually!
id: WizInternetWizardTeacher
id: WizardDialogGoldKnowledgeTeacher
search_query: bank audit
text: "__knowledge__ __no_passages_used__ __endknowledge__ \n Audit banks. You?"
- - __retrieved-doc-sentences__:
Expand Down Expand Up @@ -3748,7 +3748,7 @@ acts:
episode_done: false
eval_labels:
- I will! It is on the 7th right ?
id: WizInternetWizardTeacher
id: WizardDialogGoldKnowledgeTeacher
search_query: Superbowl 2021
text: "__knowledge__ Super Bowl LV, the 55th Super Bowl and the 51st modern-era\
\ National Football League (NFL) championship game, will decide the league champion\
Expand Down Expand Up @@ -15953,7 +15953,7 @@ acts:
eval_labels:
- Some where in the middle. That guy is good, he has most games won by a quarterback/.
He seems huble
id: WizInternetWizardTeacher
id: WizardDialogGoldKnowledgeTeacher
search_query: Tom Brady records
text: "__knowledge__ Most games won by a quarterback: 237[2] __endknowledge__\
\ \n It is. Are you a Brady fan or foe?"
Expand Down
29 changes: 28 additions & 1 deletion projects/blenderbot2/agents/blenderbot2.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@
The Memory Decoder examines the context and generates memories to write to
the long-term memory module.
"""
import copy
import torch
import torch.nn
import torch.nn.functional as F
from typing import Union, Dict, List, Tuple, Optional, Any

from parlai.agents.fid.fid import FidAgent, WizIntGoldDocRetrieverFiDAgent
from parlai.agents.rag.args import DPR_ZOO_MODEL, QUERY_MODEL_TYPES
from parlai.agents.rag.args import DPR_ZOO_MODEL, QUERY_MODEL_TYPES, RetrieverType
from parlai.agents.rag.rag import RagAgent
from parlai.agents.rag.model_types import (
RagTurn,
Expand All @@ -37,8 +38,10 @@
SELECTED_DOCS_TITLES,
SELECTED_SENTENCES,
NO_SELECTED_DOCS_TOKEN,
SKIP_SEARCH,
)
from parlai.utils.torch import padded_3d
from parlai.utils.typing import TShared

from .modules import (
BlenderBot2RagModel,
Expand Down Expand Up @@ -99,6 +102,7 @@ def augment_batch_for_generation(
batch.num_gold_docs,
batch.memory_decoder_vec,
batch.num_memory_decoder_vecs,
batch.skip_search,
)
doc_log_probs = F.log_softmax(doc_scores, dim=1)
batch.src_text_vec = batch.text_vec
Expand Down Expand Up @@ -222,6 +226,12 @@ def add_cmdline_args(
default=SELECTED_DOCS_TITLES,
help='Field for selected docs titles.',
)
bb2_group.add_argument(
'--skip-search-key',
type=str,
default=SKIP_SEARCH,
help='Field for whether to skip search or not.',
)
bb2_group.add_argument(
'--insert-gold-docs',
type='bool',
Expand Down Expand Up @@ -700,6 +710,7 @@ def batchify(self, obs_batch: List[Message], sort: bool = False) -> Batch:
batch.num_gold_docs = None
batch.memory_decoder_vec = None
batch.num_memory_decoder_vecs = None
batch.skip_search = None
if any(ex.get('memory_vec') is not None for ex in valid_exs):
batch = self._set_batch_memory_vec(valid_exs, batch)
if any(ex.get('query_generator_vec') is not None for ex in valid_exs):
Expand All @@ -708,6 +719,8 @@ def batchify(self, obs_batch: List[Message], sort: bool = False) -> Batch:
batch = self._set_batch_gold_doc_vec(valid_exs, batch)
if any(ex.get('memory_decoder_vec') is not None for ex in valid_exs):
batch = self._set_batch_memory_decoder_vec(valid_exs, batch)
if any(ex.get(self.opt['skip_search_key']) is not None for ex in valid_exs):
batch = self._set_batch_skip_search(valid_exs, batch)
return batch

def _set_batch_memory_vec(self, valid_exs: List[Message], batch: Batch) -> Batch:
Expand Down Expand Up @@ -780,6 +793,11 @@ def _set_batch_memory_decoder_vec(
batch.num_memory_decoder_vecs = torch.LongTensor(num_memory_dec_toks)
return batch

def _set_batch_skip_search(self, valid_exs: List[Message], batch: Batch) -> Batch:
    """
    Attach the per-example skip-search flags to the batch.

    :param valid_exs:
        observations in the batch; each may carry a flag under the field
        named by ``self.opt['skip_search_key']``.
    :param batch:
        batch to update in place.

    :return batch:
        the same batch, with ``batch.skip_search`` set to a 1D BoolTensor
        holding one entry per element of ``valid_exs``.
    """
    skip_search_key = self.opt['skip_search_key']
    # Coerce with bool(): a missing key and an explicit None both mean
    # "do not skip", and torch.BoolTensor cannot be built from None entries.
    skip_search = [bool(ex.get(skip_search_key, False)) for ex in valid_exs]
    batch.skip_search = torch.BoolTensor(skip_search)
    return batch

def eval_step(self, batch):
output = super().eval_step(batch)
if output is None or not hasattr(self.model, 'retriever'):
Expand Down Expand Up @@ -807,6 +825,7 @@ def _model_input(
torch.LongTensor,
torch.LongTensor,
torch.LongTensor,
torch.BoolTensor,
]:
"""
Override RagAgent._model_input to include several more input vectors.
Expand All @@ -826,6 +845,7 @@ def _model_input(
batch.num_gold_docs,
batch.memory_decoder_vec,
batch.num_memory_decoder_vecs,
batch.skip_search,
)

def compute_loss(
Expand Down Expand Up @@ -894,6 +914,13 @@ def build_model(self) -> Union[BlenderBot2FidModel, T5BlenderBot2FidModel]:
return model


class BlenderBot2SearchQueryFiDAgent(BlenderBot2FidAgent):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

def __init__(self, opt: Opt, shared: TShared = None):
# Work on a deep copy so the override below does not mutate the caller's opt.
opt = copy.deepcopy(opt)
# Force the search-engine retriever, whatever value was configured.
opt['rag_retriever_type'] = RetrieverType.SEARCH_ENGINE.value
super().__init__(opt, shared=shared)


class BlenderBot2WizIntGoldDocRetrieverFiDAgent(
WizIntGoldDocRetrieverFiDAgent, BlenderBot2FidAgent
):
Expand Down
9 changes: 8 additions & 1 deletion projects/blenderbot2/agents/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ def encoder(
num_gold_docs: torch.LongTensor,
memory_decoder_vec: torch.LongTensor,
num_memory_decoder_vecs: torch.LongTensor,
skip_search: torch.BoolTensor,
positions: Optional[torch.LongTensor] = None,
segments: Optional[torch.LongTensor] = None,
) -> Tuple[
Expand Down Expand Up @@ -214,6 +215,8 @@ def encoder(
3D [bsz, num_lines, seqlen] text to convert to memories with memory decoder
:param num_memory_decoder_vecs:
1D [bsz] # of memory decoder vectors for each batch item
:param skip_search:
1D [bsz] whether to skip search
"""
# Retrieve, get expanded input
if all([tensor is not None for tensor in [input_lengths, query_vec]]):
Expand All @@ -230,6 +233,7 @@ def encoder(
num_gold_docs,
memory_decoder_vec,
num_memory_decoder_vecs,
skip_search,
)
else:
expanded_input = input
Expand Down Expand Up @@ -317,6 +321,7 @@ def retrieve_and_concat(
num_gold_docs: torch.LongTensor,
memory_decoder_vec: torch.LongTensor,
num_memory_decoder_vecs: torch.LongTensor,
skip_search: torch.BoolTensor,
) -> Tuple[torch.LongTensor, List[List[Document]], torch.Tensor]:
"""
Override RagModel.retrieve_and_concat to perform different retrieval, depending
Expand Down Expand Up @@ -360,7 +365,7 @@ def retrieve_and_concat(
if self.should_generate_query:
assert self.has_query_generator()
retrieval_type, search_queries = self.query_generator.classify_retrieval(
query_generator_vec, num_memories, generated_memories
query_generator_vec, num_memories, generated_memories, skip_search
)
logging.debug(f'Classify Retrieval: {time.time() - start:.2f}')
else:
Expand Down Expand Up @@ -844,6 +849,7 @@ def encoder(
num_gold_docs: torch.LongTensor,
memory_decoder_vec: torch.LongTensor,
num_memory_decoder_vecs: torch.LongTensor,
skip_search: torch.BoolTensor,
positions: Optional[torch.LongTensor] = None,
segments: Optional[torch.LongTensor] = None,
) -> Tuple[
Expand All @@ -866,6 +872,7 @@ def encoder(
num_gold_docs,
memory_decoder_vec,
num_memory_decoder_vecs,
skip_search,
positions,
segments,
) # type: ignore
Expand Down
Loading