diff --git a/projects/seeker/agents/seeker.py b/projects/seeker/agents/seeker.py index 0000465d7c3..08c98283261 100644 --- a/projects/seeker/agents/seeker.py +++ b/projects/seeker/agents/seeker.py @@ -81,6 +81,12 @@ def add_cmdline_args( default=False, help='Whether to make model act output fully serializable.', ) + combo_fid.add_argument( + '--force-skip-retrieval', + type='bool', + default=False, + help='If True, we force skip retrieval on any/all incoming examples', + ) def build_model(self) -> ComboFidModel: """ @@ -107,9 +113,12 @@ def batchify(self, obs_batch: List[Message], sort: bool = False) -> Batch: batch = super().batchify(obs_batch, sort) valid_exs = [ex for ex in obs_batch if self.is_valid(ex)] if valid_exs: - skip_retrieval = [ - ex.get(self.opt['skip_retrieval_key'], False) for ex in valid_exs - ] + if self.opt.get('force_skip_retrieval', False): + skip_retrieval = [True] * len(valid_exs) + else: + skip_retrieval = [ + ex.get(self.opt['skip_retrieval_key'], False) for ex in valid_exs + ] batch.skip_retrieval_vec = torch.BoolTensor(skip_retrieval) if any(ex.get('prior_knowledge_responses') for ex in valid_exs): vecs, _lens = self._pad_tensor( @@ -563,7 +572,9 @@ def observe(self, observation: Message) -> Dict[str, Message]: for key in ['label_candidates', 'knowledge']: # Delete unnecessarily large keys observation.pop(key, '') - observation['knowledge_response'] = observation.get('checked_sentence', '') + observation.force_set( + 'knowledge_response', observation.get('checked_sentence', '') + ) raw_observation = copy.deepcopy(observation) # This part is *specifically* for document chunking. diff --git a/tests/nightly/gpu/test_seeker.py b/tests/nightly/gpu/test_seeker.py index cf0caa0cf61..1d4084847db 100644 --- a/tests/nightly/gpu/test_seeker.py +++ b/tests/nightly/gpu/test_seeker.py @@ -6,6 +6,7 @@ import unittest import parlai.scripts.eval_model as ems +from parlai.scripts.self_chat import SelfChat import parlai.utils.testing as testing_utils R2C2_BASE_400M = 'zoo:seeker/r2c2_base_400M/model' @@ -111,3 +112,15 @@ def test_blenderbot(self): 'datatype': 'valid', } ems.EvalModel.main(**opt) + + +class TestSeekerSelfChat(unittest.TestCase): + def test_400m(self): + SelfChat.main( + model_file='zoo:seeker/seeker_dialogue_400M/model', + num_self_chats=1, + init_opt='gen/seeker_dialogue', + search_decision='never', + search_server='none', + krm_force_skip_retrieval=True, + )