This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

[BlenderBot2] A hybrid mode (skip search when needed) #4221

Merged · 9 commits · Dec 9, 2021
Changes from 6 commits
21 changes: 20 additions & 1 deletion parlai/tasks/wizard_of_internet/agents.py
@@ -19,7 +19,7 @@
import parlai.utils.logging as logging
import parlai.tasks.wizard_of_internet.constants as CONST
from .build import build
import parlai.tasks.wizard_of_internet.mutators
import parlai.tasks.wizard_of_internet.mutators # noqa: F401


def get_dtype(opt):
@@ -350,6 +350,9 @@ class WizardDialogTeacher(WizardOfInternetBaseTeacher):
def __init__(self, opt, shared=None):
self.prepend_gold_knowledge = opt.get('prepend_gold_knowledge')
self.gold_knowledge_delimiter = opt.get('gold_knowledge_delimiter', '\n')
self.add_skip_search_if_gold_prepended = opt.get(
'add_skip_search_if_gold_prepended'
)
super().__init__(opt, shared=shared)
self.id = 'WizInternetWizardTeacher'

@@ -363,6 +366,12 @@ def add_cmdline_args(cls, parser: ParlaiParser, partial_opt=None) -> ParlaiParser:
default=False,
help='If true, prepend text with checked sentences',
)
arg_group.add_argument(
'--add-skip-search-if-gold-prepended',
type='bool',
default=False,
help='If true, add skip search field when prepending text with checked sentences',
)
return parser

def custom_evaluation(
@@ -443,10 +452,16 @@ def teacher_setup_data(self, datafile) -> Message:
f' {self.gold_knowledge_delimiter} {text}'
),
)
if self.add_skip_search_if_gold_prepended:
message[CONST.SKIP_SEARCH] = True
yield message, episode_started


class WizardDialogGoldKnowledgeTeacher(WizardDialogTeacher):
def __init__(self, opt, shared=None):
super().__init__(opt, shared=shared)
self.id = 'WizardDialogGoldKnowledgeTeacher'
[Review comment] @jxmsML (Contributor, Author) · Dec 2, 2021:
ah thanks for adding this, otherwise the train_model always flags this red warning ;)


@classmethod
def add_cmdline_args(cls, parser: ParlaiParser, partial_opt=None) -> ParlaiParser:
super().add_cmdline_args(parser, partial_opt)
@@ -459,6 +474,10 @@ class WizardDialogGoldKnowledgeNoDocsTeacher(WizardDialogGoldKnowledgeTeacher):
Prepends gold (selected knowledge) to the context, and removes the retrieved docs.
"""

def __init__(self, opt, shared=None):
super().__init__(opt, shared=shared)
self.id = 'WizardDialogGoldKnowledgeNoDocsTeacher'

def additional_message_content(self, parlai_message: Message, action: Dict):
super().additional_message_content(parlai_message, action)
remove_retrieved_docs_from_message(parlai_message)
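For orientation, here is a minimal sketch (not part of the diff) of what the new teacher option does to an example. The helper name and `Message` construction are illustrative; only the `skip_search` key mirrors `CONST.SKIP_SEARCH` from the hunk above and the constants change below.

```python
# Illustrative sketch, not the teacher's actual code: with both options on,
# gold knowledge is prepended and the example is tagged to skip live search.
from parlai.core.message import Message

SKIP_SEARCH = 'skip_search'  # the value CONST.SKIP_SEARCH gets below

def tag_example(message: Message, prepend_gold: bool, add_skip_search: bool) -> Message:
    # mirrors `message[CONST.SKIP_SEARCH] = True` in teacher_setup_data
    if prepend_gold and add_skip_search:
        message[SKIP_SEARCH] = True
    return message

msg = tag_example(Message({'text': 'checked sentence\nuser turn'}), True, True)
assert msg.get(SKIP_SEARCH) is True
```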
1 change: 1 addition & 0 deletions parlai/tasks/wizard_of_internet/constants.py
@@ -46,6 +46,7 @@ class HISTORY_TYPE:
IS_SEARCH_QUERY = 'is_search_query'
IS_LAST_SEARCH_QUERY = 'is_last_search_query'
LABELS = 'labels'
SKIP_SEARCH = 'skip_search'

# Message values
NO_SEARCH_QUERY_USED = '__no_search_used__'
29 changes: 28 additions & 1 deletion projects/blenderbot2/agents/blenderbot2.py
@@ -13,13 +13,14 @@
The Memory Decoder examines the context and generates memories to write to
the long-term memory module.
"""
import copy
import torch
import torch.nn
import torch.nn.functional as F
from typing import Union, Dict, List, Tuple, Optional, Any

from parlai.agents.fid.fid import FidAgent, WizIntGoldDocRetrieverFiDAgent
from parlai.agents.rag.args import DPR_ZOO_MODEL, QUERY_MODEL_TYPES
from parlai.agents.rag.args import DPR_ZOO_MODEL, QUERY_MODEL_TYPES, RetrieverType
from parlai.agents.rag.rag import RagAgent
from parlai.agents.rag.model_types import (
RagTurn,
@@ -37,8 +38,10 @@
SELECTED_DOCS_TITLES,
SELECTED_SENTENCES,
NO_SELECTED_DOCS_TOKEN,
SKIP_SEARCH,
)
from parlai.utils.torch import padded_3d
from parlai.utils.typing import TShared

from .modules import (
BlenderBot2RagModel,
@@ -99,6 +102,7 @@ def augment_batch_for_generation(
batch.num_gold_docs,
batch.memory_decoder_vec,
batch.num_memory_decoder_vecs,
batch.skip_search,
)
doc_log_probs = F.log_softmax(doc_scores, dim=1)
batch.src_text_vec = batch.text_vec
@@ -222,6 +226,12 @@ def add_cmdline_args(
default=SELECTED_DOCS_TITLES,
help='Field for selected docs titles.',
)
bb2_group.add_argument(
'--skip-search-key',
type=str,
default=SKIP_SEARCH,
help='Field for whether to skip search or not.',
)
bb2_group.add_argument(
'--insert-gold-docs',
type='bool',
@@ -700,6 +710,7 @@ def batchify(self, obs_batch: List[Message], sort: bool = False) -> Batch:
batch.num_gold_docs = None
batch.memory_decoder_vec = None
batch.num_memory_decoder_vecs = None
batch.skip_search = None
if any(ex.get('memory_vec') is not None for ex in valid_exs):
batch = self._set_batch_memory_vec(valid_exs, batch)
if any(ex.get('query_generator_vec') is not None for ex in valid_exs):
@@ -708,6 +719,8 @@ def batchify(self, obs_batch: List[Message], sort: bool = False) -> Batch:
batch = self._set_batch_gold_doc_vec(valid_exs, batch)
if any(ex.get('memory_decoder_vec') is not None for ex in valid_exs):
batch = self._set_batch_memory_decoder_vec(valid_exs, batch)
if any(ex.get(self.opt['skip_search_key']) is not None for ex in valid_exs):
batch = self._set_batch_skip_search(valid_exs, batch)
return batch

def _set_batch_memory_vec(self, valid_exs: List[Message], batch: Batch) -> Batch:
@@ -780,6 +793,11 @@ def _set_batch_memory_decoder_vec(
batch.num_memory_decoder_vecs = torch.LongTensor(num_memory_dec_toks)
return batch

def _set_batch_skip_search(self, valid_exs: List[Message], batch: Batch) -> Batch:
skip_search = [ex.get(self.opt['skip_search_key'], False) for ex in valid_exs]
batch.skip_search = torch.BoolTensor(skip_search)
return batch
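As a sanity check, a toy illustration of what this batching step produces, assuming the default `skip_search` key; examples that lack the field fall back to `False`:

```python
# Toy run of the _set_batch_skip_search logic (assumes skip_search_key='skip_search').
import torch

valid_exs = [{'skip_search': True}, {'text': 'no flag set'}, {'skip_search': False}]
skip_search = [ex.get('skip_search', False) for ex in valid_exs]
print(torch.BoolTensor(skip_search))  # tensor([ True, False, False])
```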

def eval_step(self, batch):
output = super().eval_step(batch)
if output is None or not hasattr(self.model, 'retriever'):
@@ -807,6 +825,7 @@ def _model_input(
torch.LongTensor,
torch.LongTensor,
torch.LongTensor,
torch.BoolTensor,
]:
"""
Override RagAgent._model_input to include several more input vectors.
@@ -826,6 +845,7 @@
batch.num_gold_docs,
batch.memory_decoder_vec,
batch.num_memory_decoder_vecs,
batch.skip_search,
)

def compute_loss(
@@ -894,6 +914,13 @@ def build_model(self) -> Union[BlenderBot2FidModel, T5BlenderBot2FidModel]:
return model


class BlenderBot2SearchQueryFiDAgent(BlenderBot2FidAgent):
[Review comment] @jxmsML (Contributor, Author):
👍

def __init__(self, opt: Opt, shared: TShared = None):
opt = copy.deepcopy(opt)
opt['rag_retriever_type'] = RetrieverType.SEARCH_ENGINE.value
super().__init__(opt, shared=shared)


class BlenderBot2WizIntGoldDocRetrieverFiDAgent(
WizIntGoldDocRetrieverFiDAgent, BlenderBot2FidAgent
):
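The new `BlenderBot2SearchQueryFiDAgent` above simply pins the retriever type before construction. A rough illustration of that override follows; the concrete enum value is an assumption based on `parlai.agents.rag.args`:

```python
# Rough illustration of the opt override in BlenderBot2SearchQueryFiDAgent.__init__.
import copy
from parlai.agents.rag.args import RetrieverType

opt = {'rag_retriever_type': 'dpr'}  # whatever the user originally passed
opt = copy.deepcopy(opt)             # avoid mutating the caller's opt
opt['rag_retriever_type'] = RetrieverType.SEARCH_ENGINE.value
print(opt['rag_retriever_type'])     # expected: 'search_engine'
```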
9 changes: 8 additions & 1 deletion projects/blenderbot2/agents/modules.py
@@ -177,6 +177,7 @@ def encoder(
num_gold_docs: torch.LongTensor,
memory_decoder_vec: torch.LongTensor,
num_memory_decoder_vecs: torch.LongTensor,
skip_search: torch.BoolTensor,
positions: Optional[torch.LongTensor] = None,
segments: Optional[torch.LongTensor] = None,
) -> Tuple[
@@ -214,6 +215,8 @@ def encoder(
3D [bsz, num_lines, seqlen] text to convert to memories with memory decoder
:param num_memory_decoder_vecs:
1D [bsz] # of memory decoder vectors for each batch item
:param skip_search:
1D [bsz] whether to skip search
"""
# Retrieve, get expanded input
if all([tensor is not None for tensor in [input_lengths, query_vec]]):
@@ -230,6 +233,7 @@
num_gold_docs,
memory_decoder_vec,
num_memory_decoder_vecs,
skip_search,
)
else:
expanded_input = input
@@ -317,6 +321,7 @@ def retrieve_and_concat(
num_gold_docs: torch.LongTensor,
memory_decoder_vec: torch.LongTensor,
num_memory_decoder_vecs: torch.LongTensor,
skip_search: torch.BoolTensor,
) -> Tuple[torch.LongTensor, List[List[Document]], torch.Tensor]:
"""
Override RagModel.retrieve_and_concat to perform different retrieval, depending
@@ -360,7 +365,7 @@
if self.should_generate_query:
assert self.has_query_generator()
retrieval_type, search_queries = self.query_generator.classify_retrieval(
query_generator_vec, num_memories, generated_memories
query_generator_vec, num_memories, generated_memories, skip_search
)
logging.debug(f'Classify Retrieval: {time.time() - start:.2f}')
else:
@@ -844,6 +849,7 @@ def encoder(
num_gold_docs: torch.LongTensor,
memory_decoder_vec: torch.LongTensor,
num_memory_decoder_vecs: torch.LongTensor,
skip_search: torch.BoolTensor,
positions: Optional[torch.LongTensor] = None,
segments: Optional[torch.LongTensor] = None,
) -> Tuple[
@@ -866,6 +872,7 @@
num_gold_docs,
memory_decoder_vec,
num_memory_decoder_vecs,
skip_search,
positions,
segments,
) # type: ignore
5 changes: 4 additions & 1 deletion projects/blenderbot2/agents/sub_modules.py
@@ -181,7 +181,7 @@ def __init__(self, opt: Opt):
)
assert isinstance(base_agent, TorchAgent)
self.agents = [base_agent]
bsz = max(opt.get('batchsize', 1), opt.get('eval_batchsize', 1))
bsz = max(opt.get('batchsize') or 1, opt.get('eval_batchsize') or 1)
[Review comment] @mojtaba-komeili (Contributor) · Dec 1, 2021:
Oddly, I was getting this error when trying this agent. This change was a solution to avoid it.
[screenshot of the error omitted]

[Reply] @jxmsML (Contributor, Author):
hmm interesting, I thought anything returned from opt.get(XXX, 1) would never be None

[Reply] @mojtaba-komeili (Contributor):
it can be None if the arg is in opt but no value is given
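The root cause is that `dict.get` falls back to its default only when the key is absent, not when the key is present with value `None`; a quick demonstration:

```python
# Why opt.get('batchsize', 1) can still return None.
opt = {'batchsize': None}         # key present, but no value supplied
print(opt.get('batchsize', 1))    # None -- the default is ignored
print(opt.get('batchsize') or 1)  # 1 -- the pattern adopted in this PR
```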

rag_turn_n_turns = opt.get('rag_turn_n_turns', 1)
if bsz > 1 or rag_turn_n_turns > 1:
self.agents += [
@@ -196,6 +196,7 @@ def classify_retrieval(
input: torch.LongTensor,
num_memories: torch.LongTensor,
generated_memories: Optional[List[List[str]]],
skip_search: Optional[torch.BoolTensor],
) -> Tuple[torch.LongTensor, List[str]]:
"""
Classify input and get retrieval type.
@@ -242,6 +243,8 @@
self.retrieval_type[i] = RetrievalType.MEMORY.value
elif strip_punc(s) in NONE_STRINGS + MEMORY_STRINGS:
self.retrieval_type[i] = RetrievalType.NONE.value
elif skip_search is not None and skip_search[i]:
self.retrieval_type[i] = RetrievalType.NONE.value
else:
self.retrieval_type[i] = RetrievalType.SEARCH.value
searches.append(s)
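Putting the pieces together, a simplified sketch of the per-example decision order after this change (illustrative, not the full `classify_retrieval`; the prediction strings are placeholders). Note that `skip_search` is consulted only after the memory/none predictions, so it downgrades a would-be search to no retrieval rather than overriding an explicit memory decision:

```python
# Simplified per-example retrieval decision (placeholder constants).
from typing import Optional

MEMORY_PRED = '__select-memory__'  # stand-in for MEMORY_STRINGS
NONE_PRED = '__no-search__'        # stand-in for NONE_STRINGS

def retrieval_type(pred: str, skip_search: Optional[bool]) -> str:
    if pred == MEMORY_PRED:
        return 'MEMORY'
    if pred == NONE_PRED:
        return 'NONE'
    if skip_search:                # new hybrid-mode branch from this PR
        return 'NONE'              # gold knowledge is present; skip live search
    return 'SEARCH'

assert retrieval_type('where is mount fuji', skip_search=True) == 'NONE'
assert retrieval_type('where is mount fuji', skip_search=None) == 'SEARCH'
```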