diff --git a/parlai/agents/reranker/reranker.py b/parlai/agents/reranker/reranker.py index 86a71a1aaa2..6b7db5209ed 100644 --- a/parlai/agents/reranker/reranker.py +++ b/parlai/agents/reranker/reranker.py @@ -514,6 +514,7 @@ def batch_act(self, observations: List[Message]) -> List[Message]: inference_batch_reply = super().batch_act(observations) for i, resp in enumerate(inference_batch_reply): beam_texts = batch_reply[i].get('beam_texts', []) + batch_reply[i] = resp # add metrics, other response items new_beam_texts = [(*b, strategy) for b in resp.get('beam_texts', [])] batch_reply[i].force_set('beam_texts', beam_texts + new_beam_texts) # 2. Rerank