From 8e873591d5cac65d973784b1291cd14636d14cb6 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 16 Sep 2021 22:50:03 +0200 Subject: [PATCH] Replace add_class_choices with set_choices --- pytorch_lightning/utilities/cli.py | 41 +++++++++++++++--------------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 576bd4ae67554..7b73b97baf1cd 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -77,7 +77,7 @@ def __str__(self) -> str: OPTIMIZER_REGISTRY = _Registry() -OPTIMIZER_REGISTRY.register_package(torch.optim, torch.optim.Optimizer) +OPTIMIZER_REGISTRY.register_package(torch.optim, Optimizer) LR_SCHEDULER_REGISTRY = _Registry() LR_SCHEDULER_REGISTRY.register_package(torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler) @@ -86,6 +86,9 @@ def __str__(self) -> str: class LightningArgumentParser(ArgumentParser): """Extension of jsonargparse's ArgumentParser for pytorch-lightning.""" + # use class attribute because `parse_args` is only called on the main parser + _choices: Dict[str, Tuple[Type, ...]] = {} + def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> None: """Initialize argument parser that supports configuration file input. @@ -105,8 +108,6 @@ def __init__(self, *args: Any, parse_as_dict: bool = True, **kwargs: Any) -> Non # separate optimizers and lr schedulers to know which were added self._optimizers: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] = {} self._lr_schedulers: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] = {} - # we need a mutable global argv copy in order to support `add_class_choices` - sys.__argv = sys.argv.copy() def add_lightning_class_args( self, @@ -167,7 +168,8 @@ def add_optimizer_args( assert issubclass(optimizer_class, Optimizer) kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"params"}} if isinstance(optimizer_class, tuple): - self.add_class_choices(optimizer_class, nested_key, **kwargs) + self.add_subclass_arguments(optimizer_class, nested_key, **kwargs) + self.set_choices(nested_key, optimizer_class) else: self.add_class_arguments(optimizer_class, nested_key, **kwargs) self._optimizers[nested_key] = (optimizer_class, link_to) @@ -191,33 +193,32 @@ def add_lr_scheduler_args( assert issubclass(lr_scheduler_class, LRSchedulerTypeTuple) kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"optimizer"}} if isinstance(lr_scheduler_class, tuple): - self.add_class_choices(lr_scheduler_class, nested_key, **kwargs) + self.add_subclass_arguments(lr_scheduler_class, nested_key, **kwargs) + self.set_choices(nested_key, lr_scheduler_class) else: self.add_class_arguments(lr_scheduler_class, nested_key, **kwargs) self._lr_schedulers[nested_key] = (lr_scheduler_class, link_to) def parse_args(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: - with mock.patch("sys.argv", sys.__argv): + argv = sys.argv + for k, classes in self._choices.items(): + if not any(arg.startswith(f"--{k}") for arg in argv): + # the key wasn't passed - maybe defined in a config, maybe it's optional + continue + argv = self._convert_argv_issue_84(classes, k, argv) + self._choices.clear() + with mock.patch("sys.argv", argv): return super().parse_args(*args, **kwargs) - def add_class_choices(self, classes: Tuple[Type, ...], nested_key: str, *args: Any, **kwargs: Any) -> None: + def set_choices(self, nested_key: str, classes: Tuple[Type, ...]) -> None: + self._choices[nested_key] = classes + + @staticmethod + def _convert_argv_issue_84(classes: Tuple[Type, ...], nested_key: str, argv: List[str]) -> List[str]: """Placeholder for https://github.com/omni-us/jsonargparse/issues/84. This should be removed once implemented. """ - if not any(arg.startswith(f"--{nested_key}") for arg in sys.__argv): # the key was passed - if any(arg.startswith("--config") for arg in sys.__argv): # a config was passed - # parsing config files would be too difficult, fall back to what's available - self.add_subclass_arguments(classes, nested_key, *args, **kwargs) - elif kwargs.get("required", False): - raise MisconfigurationException(f"The {nested_key} key is required but wasn't passed") - else: - clean_argv = self._convert_argv_issue_84(classes, nested_key, sys.__argv) - self.add_subclass_arguments(classes, nested_key, *args, **kwargs) - sys.__argv = clean_argv - - @staticmethod - def _convert_argv_issue_84(classes: Tuple[Type, ...], nested_key: str, argv: List[str]) -> List[str]: passed_args, clean_argv = {}, [] argv_key = f"--{nested_key}" # get the argv args for this nested key