Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

[ReRanker] Add reranker-delimiter, adapt it slightly to work with SeeKeR #4469

Merged
merged 1 commit into from
Apr 1, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions parlai/agents/reranker/reranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ def add_cmdline_args(cls, parser: ParlaiParser, partial_opt: Optional[Opt] = Non
help='Which strategy to use when re-ranking response candidates. '
f"Choices: {','.join(RERANKER_STRATEGIES)}",
)
reranker.add_argument(
'--reranker-delimiter',
type=str,
default=None,
help='delimiter for the retriever',
)
return parser

def __init__(self, opt: Opt, shared=None):
Expand All @@ -69,7 +75,9 @@ def __init__(self, opt: Opt, shared=None):
)
self.reranker_strategy = opt['reranker_strategy']
self.normalize_candidates = opt['normalize_candidates']
self.delimiter = opt.get('delimiter', '\n')
self.delimiter = opt.get('reranker_delimiter', None)
if not self.delimiter:
self.delimiter = opt.get('delimiter', '\n')
self.include_context = True
self.include_label_cand_only = False
self.init_predictor(opt, shared)
Expand Down Expand Up @@ -459,6 +467,12 @@ def add_cmdline_args(
default=False,
help='specify to enable certain debugging procedures.',
)
gen_agent.add_argument(
'--inference-opt-key',
type=str,
default='inference',
help='specify inference opt key for dialogue response model',
)

return parser

Expand All @@ -468,8 +482,9 @@ def __init__(self, opt: Opt, shared=None):
"""
super().__init__(opt, shared)
reranker_class = self.get_reranker_class()
self.inference_opt_key = opt.get('inference_opt_key', 'inference')
self.inference_strategies = (
opt['inference_strategies'] or opt['inference']
opt['inference_strategies'] or opt[self.inference_opt_key]
).split(',')
self.debug_mode = opt.get('debug_mode', False)
if not shared:
Expand All @@ -490,6 +505,11 @@ def set_rerank_strategy(self, strategy: str):
assert strategy in RERANKER_STRATEGIES
self.reranker.reranker_strategy = strategy

def get_observations_for_reranker(
self, observations: List[Message]
) -> List[Message]:
return observations

def share(self):
"""
Share model parameters.
Expand All @@ -510,15 +530,18 @@ def batch_act(self, observations: List[Message]) -> List[Message]:
batch_reply = [Message() for _ in range(len(observations))]
# 1. get all beam texts to consider
for strategy in self.inference_strategies:
self.opt['inference'] = strategy
self.opt[self.inference_opt_key] = strategy
inference_batch_reply = super().batch_act(observations)
for i, resp in enumerate(inference_batch_reply):
beam_texts = batch_reply[i].get('beam_texts', [])
batch_reply[i] = resp # add metrics, other response items
new_beam_texts = [(*b, strategy) for b in resp.get('beam_texts', [])]
batch_reply[i].force_set('beam_texts', beam_texts + new_beam_texts)
# 2. Rerank
for observation, generator_response in zip(observations, batch_reply):
observations_for_reranker = self.get_observations_for_reranker(observations)
for observation, generator_response in zip(
observations_for_reranker, batch_reply
):
if (
'beam_texts' not in generator_response
or not generator_response['beam_texts']
Expand Down