Skip to content

Commit

Permalink
Replace add_class_choices with set_choices
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca committed Sep 16, 2021
1 parent a3a791f commit 8e87359
Showing 1 changed file with 21 additions and 20 deletions.
41 changes: 21 additions & 20 deletions pytorch_lightning/utilities/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def __str__(self) -> str:


OPTIMIZER_REGISTRY = _Registry()
OPTIMIZER_REGISTRY.register_package(torch.optim, torch.optim.Optimizer)
OPTIMIZER_REGISTRY.register_package(torch.optim, Optimizer)

LR_SCHEDULER_REGISTRY = _Registry()
LR_SCHEDULER_REGISTRY.register_package(torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler)
Expand All @@ -86,6 +86,9 @@ def __str__(self) -> str:
class LightningArgumentParser(ArgumentParser):
"""Extension of jsonargparse's ArgumentParser for pytorch-lightning."""

# use class attribute because `parse_args` is only called on the main parser
_choices: Dict[str, Tuple[Type, ...]] = {}

def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> None:
"""Initialize argument parser that supports configuration file input.
Expand All @@ -105,8 +108,6 @@ def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> Non
# separate optimizers and lr schedulers to know which were added
self._optimizers: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] = {}
self._lr_schedulers: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] = {}
# we need a mutable global argv copy in order to support `add_class_choices`
sys.__argv = sys.argv.copy()

def add_lightning_class_args(
self,
Expand Down Expand Up @@ -167,7 +168,8 @@ def add_optimizer_args(
assert issubclass(optimizer_class, Optimizer)
kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"params"}}
if isinstance(optimizer_class, tuple):
self.add_class_choices(optimizer_class, nested_key, **kwargs)
self.add_subclass_arguments(optimizer_class, nested_key, **kwargs)
self.set_choices(nested_key, optimizer_class)
else:
self.add_class_arguments(optimizer_class, nested_key, **kwargs)
self._optimizers[nested_key] = (optimizer_class, link_to)
Expand All @@ -191,33 +193,32 @@ def add_lr_scheduler_args(
assert issubclass(lr_scheduler_class, LRSchedulerTypeTuple)
kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"optimizer"}}
if isinstance(lr_scheduler_class, tuple):
self.add_class_choices(lr_scheduler_class, nested_key, **kwargs)
self.add_subclass_arguments(lr_scheduler_class, nested_key, **kwargs)
self.set_choices(nested_key, lr_scheduler_class)
else:
self.add_class_arguments(lr_scheduler_class, nested_key, **kwargs)
self._lr_schedulers[nested_key] = (lr_scheduler_class, link_to)

def parse_args(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
with mock.patch("sys.argv", sys.__argv):
argv = sys.argv
for k, classes in self._choices.items():
if not any(arg.startswith(f"--{k}") for arg in argv):
# the key wasn't passed - maybe defined in a config, maybe it's optional
continue
argv = self._convert_argv_issue_84(classes, k, argv)
self._choices.clear()
with mock.patch("sys.argv", argv):
return super().parse_args(*args, **kwargs)

def add_class_choices(self, classes: Tuple[Type, ...], nested_key: str, *args: Any, **kwargs: Any) -> None:
def set_choices(self, nested_key: str, classes: Tuple[Type, ...]) -> None:
self._choices[nested_key] = classes

@staticmethod
def _convert_argv_issue_84(classes: Tuple[Type, ...], nested_key: str, argv: List[str]) -> List[str]:
"""Placeholder for https://github.com/omni-us/jsonargparse/issues/84.
This should be removed once implemented.
"""
if not any(arg.startswith(f"--{nested_key}") for arg in sys.__argv): # the key was passed
if any(arg.startswith("--config") for arg in sys.__argv): # a config was passed
# parsing config files would be too difficult, fall back to what's available
self.add_subclass_arguments(classes, nested_key, *args, **kwargs)
elif kwargs.get("required", False):
raise MisconfigurationException(f"The {nested_key} key is required but wasn't passed")
else:
clean_argv = self._convert_argv_issue_84(classes, nested_key, sys.__argv)
self.add_subclass_arguments(classes, nested_key, *args, **kwargs)
sys.__argv = clean_argv

@staticmethod
def _convert_argv_issue_84(classes: Tuple[Type, ...], nested_key: str, argv: List[str]) -> List[str]:
passed_args, clean_argv = {}, []
argv_key = f"--{nested_key}"
# get the argv args for this nested key
Expand Down

0 comments on commit 8e87359

Please sign in to comment.