Skip to content

Commit da7fe28

Browse files
bottlerfacebook-github-bot
authored andcommitted
small fixes to config
Summary: - indicate location of OmegaConf.structured failures - split the data gathering from enable_get_default_args to ease experimenting with it. - comment fixes. - nicer error when a_class_type has weird type. Reviewed By: kjchalup Differential Revision: D39434447 fbshipit-source-id: b80c7941547ca450e848038ef5be95b7ebbe8f3e
1 parent cb7bd33 commit da7fe28

File tree

2 files changed

+55
-17
lines changed

2 files changed

+55
-17
lines changed

Diff for: pytorch3d/implicitron/tools/config.py

+52-15
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,10 @@ class type
287287
raise ValueError(
288288
f"Cannot look up {base_class_wanted}. Cannot tell what it is."
289289
)
290+
if not isinstance(name, str):
291+
raise ValueError(
292+
f"Cannot look up a {type(name)} in the registry. Got {name}."
293+
)
290294
result = self._mapping[base_class].get(name)
291295
if result is None:
292296
raise ValueError(f"{name} has not been registered.")
@@ -446,6 +450,11 @@ def create_pluggable(self, type_name, args):
446450
setattr(self, name, None)
447451
return
448452

453+
if not isinstance(type_name, str):
454+
raise ValueError(
455+
f"A {type(type_name)} was received as the type of {name}."
456+
+ f" Perhaps this is from {name}{TYPE_SUFFIX}?"
457+
)
449458
chosen_class = registry.get(type_, type_name)
450459
if self._known_implementations.get(type_name, chosen_class) is not chosen_class:
451460
# If this warning is raised, it means that a new definition of
@@ -514,7 +523,10 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig
514523
# because in practice get_default_args_field is used for
515524
# separate types than the outer type.
516525

517-
out: DictConfig = OmegaConf.structured(C)
526+
try:
527+
out: DictConfig = OmegaConf.structured(C)
528+
except Exception as e:
529+
raise ValueError(f"OmegaConf.structured({C}) failed") from e
518530
exclude = getattr(C, "_processed_members", ())
519531
with open_dict(out):
520532
for field in exclude:
@@ -534,7 +546,11 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig
534546
f"Cannot get args for {C}. Was enable_get_default_args forgotten?"
535547
)
536548

537-
return OmegaConf.structured(dataclass)
549+
try:
550+
out: DictConfig = OmegaConf.structured(dataclass)
551+
except Exception as e:
552+
raise ValueError(f"OmegaConf.structured failed for {dataclass_name}") from e
553+
return out
538554

539555

540556
def _dataclass_name_for_function(C: Any) -> str:
@@ -546,22 +562,21 @@ def _dataclass_name_for_function(C: Any) -> str:
546562
return name
547563

548564

549-
def enable_get_default_args(C: Any, *, overwrite: bool = True) -> None:
565+
def _field_annotations_for_default_args(
566+
C: Any,
567+
) -> List[Tuple[str, Any, dataclasses.Field]]:
550568
"""
551569
If C is a function or a plain class with an __init__ function,
552-
and you want get_default_args(C) to work, then add
553-
`enable_get_default_args(C)` straight after the definition of C.
554-
This makes a dataclass corresponding to the default arguments of C
555-
and stores it in the same module as C.
570+
return the fields which `enable_get_default_args(C)` will need
571+
to make a dataclass with.
556572
557573
Args:
558574
C: a function, or a class with an __init__ function. Must
559575
have types for all its defaulted args.
560-
overwrite: whether to allow calling this a second time on
561-
the same function.
576+
577+
Returns:
578+
a list of fields for a dataclass.
562579
"""
563-
if not inspect.isfunction(C) and not inspect.isclass(C):
564-
raise ValueError(f"Unexpected {C}")
565580

566581
field_annotations = []
567582
for pname, defval in _params_iter(C):
@@ -572,8 +587,8 @@ def enable_get_default_args(C: Any, *, overwrite: bool = True) -> None:
572587

573588
if defval.annotation == inspect._empty:
574589
raise ValueError(
575-
"All arguments of the input callable have to be typed."
576-
+ f" Argument '{pname}' does not have a type annotation."
590+
"All arguments of the input to enable_get_default_args have to"
591+
f" be typed. Argument '{pname}' does not have a type annotation."
577592
)
578593

579594
_, annotation = _resolve_optional(defval.annotation)
@@ -591,6 +606,28 @@ def enable_get_default_args(C: Any, *, overwrite: bool = True) -> None:
591606
field_ = dataclasses.field(default=default)
592607
field_annotations.append((pname, defval.annotation, field_))
593608

609+
return field_annotations
610+
611+
612+
def enable_get_default_args(C: Any, *, overwrite: bool = True) -> None:
613+
"""
614+
If C is a function or a plain class with an __init__ function,
615+
and you want get_default_args(C) to work, then add
616+
`enable_get_default_args(C)` straight after the definition of C.
617+
This makes a dataclass corresponding to the default arguments of C
618+
and stores it in the same module as C.
619+
620+
Args:
621+
C: a function, or a class with an __init__ function. Must
622+
have types for all its defaulted args.
623+
overwrite: whether to allow calling this a second time on
624+
the same function.
625+
"""
626+
if not inspect.isfunction(C) and not inspect.isclass(C):
627+
raise ValueError(f"Unexpected {C}")
628+
629+
field_annotations = _field_annotations_for_default_args(C)
630+
594631
name = _dataclass_name_for_function(C)
595632
module = sys.modules[C.__module__]
596633
if hasattr(module, name):
@@ -767,7 +804,7 @@ def create_x_impl(self, enabled, args):
767804
768805
Also adds the following class members, unannotated so that dataclass
769806
ignores them.
770-
- _creation_functions: Tuple[str] of all the create_ functions,
807+
- _creation_functions: Tuple[str, ...] of all the create_ functions,
771808
including those from base classes (not the create_x_impl ones).
772809
- _known_implementations: Dict[str, Type] containing the classes which
773810
have been found from the registry.
@@ -945,7 +982,7 @@ def _get_type_to_process(type_) -> Optional[Tuple[Type, _ProcessType]]:
945982
return underlying, _ProcessType.OPTIONAL_CONFIGURABLE
946983

947984
if not isinstance(type_, type):
948-
# e.g. any other Union or Tuple
985+
# e.g. any other Union or Tuple. Or ClassVar.
949986
return
950987

951988
if issubclass(type_, ReplaceableBase) and ReplaceableBase in type_.__bases__:

Diff for: tests/implicitron/test_config.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def test_registry_entries(self):
168168
self.assertIn(Banana, all_fruit)
169169
self.assertIn(Pear, all_fruit)
170170
self.assertIn(LargePear, all_fruit)
171-
self.assertEqual(set(registry.get_all(Pear)), {LargePear})
171+
self.assertEqual(registry.get_all(Pear), [LargePear])
172172

173173
@registry.register
174174
class Apple(Fruit):
@@ -178,7 +178,7 @@ class Apple(Fruit):
178178
class CrabApple(Apple):
179179
pass
180180

181-
self.assertEqual(set(registry.get_all(Apple)), {CrabApple})
181+
self.assertEqual(registry.get_all(Apple), [CrabApple])
182182

183183
self.assertIs(registry.get(Fruit, "CrabApple"), CrabApple)
184184

@@ -601,6 +601,7 @@ def __init__(self, a: A = A.B1) -> None:
601601

602602
for C_ in [C, C_fn, C_cl]:
603603
base = get_default_args(C_)
604+
self.assertEqual(OmegaConf.to_yaml(base), "a: B1\n")
604605
self.assertEqual(base.a, A.B1)
605606
replaced = OmegaConf.merge(base, {"a": "B2"})
606607
self.assertEqual(replaced.a, A.B2)

0 commit comments

Comments
 (0)