Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

[Reranker] Set decoding method #4473

Merged
merged 4 commits into from
Apr 5, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions parlai/agents/reranker/reranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def add_cmdline_args(cls, parser: ParlaiParser, partial_opt: Optional[Opt] = Non
'--reranker-delimiter',
type=str,
default=None,
help='delimiter for the retriever',
help='delimiter for the reranker',
)
return parser

Expand Down Expand Up @@ -505,11 +505,6 @@ def set_rerank_strategy(self, strategy: str):
assert strategy in RERANKER_STRATEGIES
self.reranker.reranker_strategy = strategy

def get_observations_for_reranker(
self, observations: List[Message]
) -> List[Message]:
return observations

def share(self):
"""
Share model parameters.
Expand All @@ -520,6 +515,14 @@ def share(self):
shared['reranker'] = self.reranker.share()
return shared

def set_decoding_method(self, strategy):
self.opt[self.inference_opt_key] = strategy

def get_observations_for_reranker(
self, observations: List[Message], batch_reply: List[Message]
) -> List[Message]:
return observations

def batch_act(self, observations: List[Message]) -> List[Message]:
"""
Batch process a list of observations.
Expand All @@ -530,15 +533,17 @@ def batch_act(self, observations: List[Message]) -> List[Message]:
batch_reply = [Message() for _ in range(len(observations))]
# 1. get all beam texts to consider
for strategy in self.inference_strategies:
self.opt[self.inference_opt_key] = strategy
self.set_decoding_method(strategy)
inference_batch_reply = super().batch_act(observations)
for i, resp in enumerate(inference_batch_reply):
beam_texts = batch_reply[i].get('beam_texts', [])
batch_reply[i] = resp # add metrics, other response items
new_beam_texts = [(*b, strategy) for b in resp.get('beam_texts', [])]
batch_reply[i].force_set('beam_texts', beam_texts + new_beam_texts)
# 2. Rerank
observations_for_reranker = self.get_observations_for_reranker(observations)
observations_for_reranker = self.get_observations_for_reranker(
observations, batch_reply
)
for observation, generator_response in zip(
observations_for_reranker, batch_reply
):
Expand Down