diff --git a/parlai/crowdsourcing/tasks/model_chat/model_chat_blueprint.py b/parlai/crowdsourcing/tasks/model_chat/model_chat_blueprint.py index 156cc07140f..004a7ba3eba 100644 --- a/parlai/crowdsourcing/tasks/model_chat/model_chat_blueprint.py +++ b/parlai/crowdsourcing/tasks/model_chat/model_chat_blueprint.py @@ -425,15 +425,22 @@ def __init__( run_statistics = {r: 0 for (r, v) in self.conversations_needed.items()} shared_state.run_statistics = run_statistics - context_generator: Optional[ContextGenerator] = None if ( args.blueprint.include_persona - # 'hi' mode does not use a context generator and instead just displays "Hi!" at the start of the conversation + # 'hi' mode does not use a context generator and instead just displays "Hi!" + # at the start of the conversation or args.blueprint.conversation_start_mode != 'hi' ): + if args.blueprint.conversation_start_mode == 'hi': + # Default to using the context from BlendedSkillTalk + task = 'blended_skill_talk' + else: + task = args.blueprint.conversation_start_mode context_generator = get_context_generator( - args.blueprint.override_opt, args.blueprint.conversation_start_mode + override_opt=args.blueprint.override_opt, task=task ) + else: + context_generator: Optional[ContextGenerator] = None shared_state.context_generator = context_generator # Lock for editing run statistics between threads diff --git a/parlai/crowdsourcing/tasks/model_chat/utils.py b/parlai/crowdsourcing/tasks/model_chat/utils.py index 49f9c0985f0..baed9406fb3 100644 --- a/parlai/crowdsourcing/tasks/model_chat/utils.py +++ b/parlai/crowdsourcing/tasks/model_chat/utils.py @@ -413,7 +413,7 @@ def _check_final_chat_data( def get_context_generator( override_opt: Optional[Dict[str, Any]] = None, - conversation_start_mode: Optional[str] = 'blended_skill_talk', + task: Optional[str] = 'blended_skill_talk', **kwargs, ) -> ContextGenerator: """ @@ -424,7 +424,7 @@ def get_context_generator( if override_opt is not None: argparser.set_params(**override_opt) opt = argparser.parse_args([]) - task_module = load_task_module(conversation_start_mode) + task_module = load_task_module(task) context_generator_class = getattr(task_module, 'ContextGenerator', None) context_generator = context_generator_class(opt, datatype='test', seed=0, **kwargs) # We pull from the test set so that the model can't regurgitate