diff --git a/parlai/core/worlds.py b/parlai/core/worlds.py index f8cb269d086..0dd7456904a 100644 --- a/parlai/core/worlds.py +++ b/parlai/core/worlds.py @@ -39,17 +39,18 @@ import copy import random -from typing import List, Dict, Union +from typing import Dict, List, Optional, Union +import parlai.utils.logging as logging from parlai.core.agents import create_agents_from_shared from parlai.core.loader import load_task_module, load_world_module from parlai.core.metrics import aggregate_named_reports from parlai.core.opt import Opt +from parlai.core.params import ParlaiParser from parlai.core.teachers import Teacher, create_task_agent_from_taskname from parlai.utils.data import DatatypeHelper from parlai.utils.misc import Timer, display_messages from parlai.tasks.tasks import ids_to_tasks -import parlai.utils.logging as logging def validate(observation): @@ -304,6 +305,17 @@ class DialogPartnerWorld(World): chance to speak per turn and passing that back to the other one. """ + @classmethod + def add_cmdline_args( + cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None + ) -> ParlaiParser: + """ + Return the parser as-is. + + Self-chat-specific world flags can be added here. + """ + return parser + def __init__(self, opt: Opt, agents=None, shared=None): if not ((agents is not None) ^ (shared is not None)): raise ValueError('You must supply either agents or shared, but not both.') diff --git a/parlai/tasks/md_gender/worlds.py b/parlai/tasks/md_gender/worlds.py index 883074b7eca..b8708e79338 100644 --- a/parlai/tasks/md_gender/worlds.py +++ b/parlai/tasks/md_gender/worlds.py @@ -116,14 +116,14 @@ def add_cmdline_args( cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None ) -> ParlaiParser: super().add_cmdline_args(parser, partial_opt) - parser = parser.add_argument_group('Gender Multiclass Interactive World') - parser.add_argument( + group = parser.add_argument_group('Gender Multiclass Interactive World') + group.add_argument( '--self-threshold', type=float, default=0.52, help='Threshold for choosing unknown for self', ) - parser.add_argument( + group.add_argument( '--partner-threshold', type=float, default=0.52, diff --git a/parlai/tasks/self_chat/worlds.py b/parlai/tasks/self_chat/worlds.py index d8c81bd4852..415c10ea321 100644 --- a/parlai/tasks/self_chat/worlds.py +++ b/parlai/tasks/self_chat/worlds.py @@ -10,8 +10,6 @@ from parlai.agents.fixed_response.fixed_response import FixedResponseAgent from parlai.core.agents import Agent -from parlai.core.opt import Opt -from parlai.core.params import ParlaiParser from parlai.core.worlds import create_task, DialogPartnerWorld, validate from parlai.core.message import Message @@ -53,17 +51,6 @@ def load_openers(opt) -> Optional[List[str]]: class SelfChatWorld(DialogPartnerWorld): - @classmethod - def add_cmdline_args( - cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None - ) -> ParlaiParser: - """ - Return the parser as-is. - - Self-chat-specific world flags can be added here. - """ - return parser - def __init__(self, opt, agents, shared=None): super().__init__(opt, agents, shared) self.init_contexts(shared=shared)