Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
[params] Support options depending on options in parse_kwargs. (#3900)
Browse files Browse the repository at this point in the history
  • Loading branch information
stephenroller authored Aug 5, 2021
1 parent aab5370 commit 6ca525c
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 52 deletions.
121 changes: 69 additions & 52 deletions parlai/core/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -1203,62 +1203,79 @@ def _kwargs_to_str_args(self, **kwargs):
), f"No duplicate names! ({kwname}, {kwname_to_action[kwname]}, {action})"
kwname_to_action[kwname] = action

# since we can have options that depend on options, repeat until convergence
string_args = []
for kwname, value in kwargs.items():
if kwname not in kwname_to_action:
# best guess, we need to delay it. hopefully this gets added
# during add_kw_Args
continue
action = kwname_to_action[kwname]
last_option_string = action.option_strings[-1]
if isinstance(action, argparse._StoreTrueAction):
if bool(value):
unparsed_args = set(kwargs.keys())
while unparsed_args:
string_args = []
for kwname, value in kwargs.items():
if kwname not in kwname_to_action:
# best guess, we need to delay it. hopefully this gets added
# during add_kw_Args
continue
action = kwname_to_action[kwname]
last_option_string = action.option_strings[-1]
if isinstance(action, argparse._StoreTrueAction):
if bool(value):
string_args.append(last_option_string)
elif isinstance(action, argparse._StoreAction) and action.nargs is None:
string_args.append(last_option_string)
elif isinstance(action, argparse._StoreAction) and action.nargs is None:
string_args.append(last_option_string)
string_args.append(self._value2argstr(value))
elif isinstance(action, argparse._StoreAction) and action.nargs in '*+':
string_args.append(last_option_string)
string_args.extend([self._value2argstr(value) for v in value])
else:
raise TypeError(f"Don't know what to do with {action}")

# become aware of any extra args that might be specified if the user
# provides something like model="transformer/generator".
self.add_extra_args(string_args)

# do it again, this time knowing about ALL args.
kwname_to_action = {}
for action in self._actions:
if action.dest == 'help':
# no help allowed
continue
for option_string in action.option_strings:
kwname = option_string.lstrip('-').replace('-', '_')
assert (kwname not in kwname_to_action) or (
kwname_to_action[kwname] is action
), f"No duplicate names! ({kwname}, {kwname_to_action[kwname]}, {action})"
kwname_to_action[kwname] = action

string_args = []
for kwname, value in kwargs.items():
# note we don't have the if kwname not in kwname_to_action here.
# it MUST appear, or else we legitimately should be throwing a KeyError
# because user has provided an unspecified option
action = kwname_to_action[kwname]
last_option_string = action.option_strings[-1]
if isinstance(action, argparse._StoreTrueAction):
if bool(value):
string_args.append(self._value2argstr(value))
elif isinstance(action, argparse._StoreAction) and action.nargs in '*+':
string_args.append(last_option_string)
elif isinstance(action, argparse._StoreAction) and action.nargs is None:
string_args.append(last_option_string)
string_args.append(self._value2argstr(value))
elif isinstance(action, argparse._StoreAction) and action.nargs in '*+':
string_args.append(last_option_string)
# Special case: Labels
string_args.extend([str(v) for v in value])
string_args.extend([self._value2argstr(value) for v in value])
else:
raise TypeError(f"Don't know what to do with {action}")

# become aware of any extra args that might be specified if the user
# provides something like model="transformer/generator".
self.add_extra_args(string_args)

# do it again, this time knowing about ALL args.
kwname_to_action = {}
for action in self._actions:
if action.dest == 'help':
# no help allowed
continue
for option_string in action.option_strings:
kwname = option_string.lstrip('-').replace('-', '_')
assert (kwname not in kwname_to_action) or (
kwname_to_action[kwname] is action
), f"No duplicate names! ({kwname}, {kwname_to_action[kwname]}, {action})"
kwname_to_action[kwname] = action

new_unparsed_args = set()
string_args = []
for kwname, value in kwargs.items():
if kwname not in kwname_to_action:
new_unparsed_args.add(kwname)
continue

action = kwname_to_action[kwname]
last_option_string = action.option_strings[-1]
if isinstance(action, argparse._StoreTrueAction):
if bool(value):
string_args.append(last_option_string)
elif isinstance(action, argparse._StoreAction) and action.nargs is None:
string_args.append(last_option_string)
string_args.append(self._value2argstr(value))
elif isinstance(action, argparse._StoreAction) and action.nargs in '*+':
string_args.append(last_option_string)
# Special case: Labels
string_args.extend([str(v) for v in value])
else:
raise TypeError(f"Don't know what to do with {action}")

if new_unparsed_args == unparsed_args:
# if we have converged to a fixed point with no improvements, we
# truly found some unreachable args
raise KeyError(
f'Failed to parse one or more kwargs: {", ".join(new_unparsed_args)}'
)
else:
raise TypeError(f"Don't know what to do with {action}")
# We've seen some improvements on the number of unparsed args,
# iterate again
unparsed_args = new_unparsed_args

return string_args

Expand Down
20 changes: 20 additions & 0 deletions tests/test_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,26 @@ def test_parse_kwargs(self):
parser = ParlaiParser(True, True)
parser.parse_kwargs(model='transformer/generator', fake_arg='foo')

def test_parse_kwargs_multirounds(self):
"""Test parse_kwargs when we have options that depend on options."""
parser = ParlaiParser(True, False)
opt = parser.parse_kwargs(
task='integration_tests', mutators='episode_shuffle', preserve_context=True
)
assert opt['preserve_context'] is True
opt = parser.parse_kwargs(
task='integration_tests', mutators='episode_shuffle', preserve_context=False
)
assert opt['preserve_context'] is False

with self.assertRaises(KeyError):
parser.parse_kwargs(
task='integration_tests', mutators='episode_shuffle', fake_option=False
)

with self.assertRaises(KeyError):
parser.parse_kwargs(task='integration_tests', fake_option=False)

def test_bool(self):
"""
test add_argument(type=bool)
Expand Down

0 comments on commit 6ca525c

Please sign in to comment.