From d39f1c7108662361be451232f337fbc7256ff6a0 Mon Sep 17 00:00:00 2001 From: Jasha <8935917+Jasha10@users.noreply.github.com> Date: Sun, 16 Jul 2023 07:01:31 +0000 Subject: [PATCH] handle SCMode.INSTANTIATE with throw_on_missing=False --- omegaconf/__init__.py | 21 ++++++++++----------- omegaconf/_utils.py | 3 ++- omegaconf/base.py | 2 ++ omegaconf/basecontainer.py | 2 +- omegaconf/dictconfig.py | 21 ++++++++++++--------- omegaconf/omegaconf.py | 2 -- tests/test_to_container.py | 12 ++++++++++++ 7 files changed, 39 insertions(+), 24 deletions(-) diff --git a/omegaconf/__init__.py b/omegaconf/__init__.py index e8b9e369e..a432bf416 100644 --- a/omegaconf/__init__.py +++ b/omegaconf/__init__.py @@ -1,4 +1,12 @@ -from .base import Container, DictKeyType, ListMergeMode, Node, SCMode, UnionNode +from .base import ( + MISSING, + Container, + DictKeyType, + ListMergeMode, + Node, + SCMode, + UnionNode, +) from .dictconfig import DictConfig from .errors import ( KeyValidationError, @@ -19,16 +27,7 @@ StringNode, ValueNode, ) -from .omegaconf import ( - II, - MISSING, - SI, - OmegaConf, - Resolver, - flag_override, - open_dict, - read_write, -) +from .omegaconf import II, SI, OmegaConf, Resolver, flag_override, open_dict, read_write from .version import __version__ __all__ = [ diff --git a/omegaconf/_utils.py b/omegaconf/_utils.py index 3452f48ca..202c2e70e 100644 --- a/omegaconf/_utils.py +++ b/omegaconf/_utils.py @@ -378,7 +378,8 @@ def get_dataclass_fields(obj: Any) -> List["dataclasses.Field[Any]"]: def get_dataclass_data( obj: Any, allow_objects: Optional[bool] = None ) -> Dict[str, Any]: - from omegaconf.omegaconf import MISSING, OmegaConf, _maybe_wrap + from omegaconf import MISSING, OmegaConf + from omegaconf.omegaconf import _maybe_wrap flags = {"allow_objects": allow_objects} if allow_objects is not None else {} d = {} diff --git a/omegaconf/base.py b/omegaconf/base.py index 77e951058..83d441bf2 100644 --- a/omegaconf/base.py +++ b/omegaconf/base.py @@ -38,6 +38,8 @@ from .grammar_parser import parse from .grammar_visitor import GrammarVisitor +MISSING: Any = "???" + DictKeyType = Union[str, bytes, int, Enum, float, bool] diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index 156b1ca30..65b28f2bb 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -286,7 +286,7 @@ def get_node_value(key: Union[DictKeyType, int]) -> Any: if structured_config_mode == SCMode.INSTANTIATE and is_structured_config( conf._metadata.object_type ): - return conf._to_object() + return conf._to_object(throw_on_missing=throw_on_missing) retdict: Dict[DictKeyType, Any] = {} for key in conf.keys(): diff --git a/omegaconf/dictconfig.py b/omegaconf/dictconfig.py index 12c1ebde9..79662359a 100644 --- a/omegaconf/dictconfig.py +++ b/omegaconf/dictconfig.py @@ -38,7 +38,7 @@ is_structured_config_frozen, type_str, ) -from .base import Box, Container, ContainerMetadata, DictKeyType, Node +from .base import MISSING, Box, Container, ContainerMetadata, DictKeyType, Node from .basecontainer import BaseContainer from .errors import ( ConfigAttributeError, @@ -716,7 +716,7 @@ def _dict_conf_eq(d1: "DictConfig", d2: "DictConfig") -> bool: return True - def _to_object(self) -> Any: + def _to_object(self, throw_on_missing: bool) -> Any: """ Instantiate an instance of `self._metadata.object_type`. This requires `self` to be a structured config. @@ -741,13 +741,16 @@ def _to_object(self) -> Any: if node._is_missing(): if k not in init_field_names: continue # MISSING is ignored for init=False fields - self._format_and_raise( - key=k, - value=None, - cause=MissingMandatoryValue( - "Structured config of type `$OBJECT_TYPE` has missing mandatory value: $KEY" - ), - ) + if throw_on_missing: + self._format_and_raise( + key=k, + value=None, + cause=MissingMandatoryValue( + "Structured config of type `$OBJECT_TYPE` has missing mandatory value: $KEY" + ), + ) + else: + v = MISSING if isinstance(node, Container): v = OmegaConf.to_object(node) else: diff --git a/omegaconf/omegaconf.py b/omegaconf/omegaconf.py index 041602879..5315a8949 100644 --- a/omegaconf/omegaconf.py +++ b/omegaconf/omegaconf.py @@ -73,8 +73,6 @@ ValueNode, ) -MISSING: Any = "???" - Resolver = Callable[..., Any] diff --git a/tests/test_to_container.py b/tests/test_to_container.py index f3c5f910a..cb875ee14 100644 --- a/tests/test_to_container.py +++ b/tests/test_to_container.py @@ -6,6 +6,7 @@ from pytest import fixture, mark, param, raises from omegaconf import ( + MISSING, DictConfig, ListConfig, MissingMandatoryValue, @@ -407,6 +408,17 @@ def test_to_container_INSTANTIATE_enum_to_str_True(self, module: Any) -> None: assert container["color"] == "BLUE" assert container["obj"].not_optional is Color.BLUE + def test_to_container_INSTANTIATE_throw_on_missing_False(self, module: Any) -> None: + """Test the lower level `to_container` API with SCMode.INSTANTIATE and throw_on_missing=False""" + src = module.User("Bond") # age: MISSING + cfg = OmegaConf.create(src) + container = OmegaConf.to_container( + cfg, throw_on_missing=False, structured_config_mode=SCMode.INSTANTIATE + ) + assert isinstance(container, module.User) + assert container.name == "Bond" + assert container.age is MISSING + def test_to_object_InterpolationResolutionError(self, module: Any) -> None: with raises(InterpolationResolutionError): cfg = OmegaConf.structured(module.NestedWithAny)