@@ -287,6 +287,10 @@ class type
287
287
raise ValueError (
288
288
f"Cannot look up { base_class_wanted } . Cannot tell what it is."
289
289
)
290
+ if not isinstance (name , str ):
291
+ raise ValueError (
292
+ f"Cannot look up a { type (name )} in the registry. Got { name } ."
293
+ )
290
294
result = self ._mapping [base_class ].get (name )
291
295
if result is None :
292
296
raise ValueError (f"{ name } has not been registered." )
@@ -446,6 +450,11 @@ def create_pluggable(self, type_name, args):
446
450
setattr (self , name , None )
447
451
return
448
452
453
+ if not isinstance (type_name , str ):
454
+ raise ValueError (
455
+ f"A { type (type_name )} was received as the type of { name } ."
456
+ + f" Perhaps this is from { name } { TYPE_SUFFIX } ?"
457
+ )
449
458
chosen_class = registry .get (type_ , type_name )
450
459
if self ._known_implementations .get (type_name , chosen_class ) is not chosen_class :
451
460
# If this warning is raised, it means that a new definition of
@@ -514,7 +523,10 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig
514
523
# because in practice get_default_args_field is used for
515
524
# separate types than the outer type.
516
525
517
- out : DictConfig = OmegaConf .structured (C )
526
+ try :
527
+ out : DictConfig = OmegaConf .structured (C )
528
+ except Exception as e :
529
+ raise ValueError (f"OmegaConf.structured({ C } ) failed" ) from e
518
530
exclude = getattr (C , "_processed_members" , ())
519
531
with open_dict (out ):
520
532
for field in exclude :
@@ -534,7 +546,11 @@ def get_default_args(C, *, _do_not_process: Tuple[type, ...] = ()) -> DictConfig
534
546
f"Cannot get args for { C } . Was enable_get_default_args forgotten?"
535
547
)
536
548
537
- return OmegaConf .structured (dataclass )
549
+ try :
550
+ out : DictConfig = OmegaConf .structured (dataclass )
551
+ except Exception as e :
552
+ raise ValueError (f"OmegaConf.structured failed for { dataclass_name } " ) from e
553
+ return out
538
554
539
555
540
556
def _dataclass_name_for_function (C : Any ) -> str :
@@ -546,22 +562,21 @@ def _dataclass_name_for_function(C: Any) -> str:
546
562
return name
547
563
548
564
549
- def enable_get_default_args (C : Any , * , overwrite : bool = True ) -> None :
565
+ def _field_annotations_for_default_args (
566
+ C : Any ,
567
+ ) -> List [Tuple [str , Any , dataclasses .Field ]]:
550
568
"""
551
569
If C is a function or a plain class with an __init__ function,
552
- and you want get_default_args(C) to work, then add
553
- `enable_get_default_args(C)` straight after the definition of C.
554
- This makes a dataclass corresponding to the default arguments of C
555
- and stores it in the same module as C.
570
+ return the fields which `enable_get_default_args(C)` will need
571
+ to make a dataclass with.
556
572
557
573
Args:
558
574
C: a function, or a class with an __init__ function. Must
559
575
have types for all its defaulted args.
560
- overwrite: whether to allow calling this a second time on
561
- the same function.
576
+
577
+ Returns:
578
+ a list of fields for a dataclass.
562
579
"""
563
- if not inspect .isfunction (C ) and not inspect .isclass (C ):
564
- raise ValueError (f"Unexpected { C } " )
565
580
566
581
field_annotations = []
567
582
for pname , defval in _params_iter (C ):
@@ -572,8 +587,8 @@ def enable_get_default_args(C: Any, *, overwrite: bool = True) -> None:
572
587
573
588
if defval .annotation == inspect ._empty :
574
589
raise ValueError (
575
- "All arguments of the input callable have to be typed. "
576
- + f" Argument '{ pname } ' does not have a type annotation."
590
+ "All arguments of the input to enable_get_default_args have to"
591
+ f" be typed. Argument '{ pname } ' does not have a type annotation."
577
592
)
578
593
579
594
_ , annotation = _resolve_optional (defval .annotation )
@@ -591,6 +606,28 @@ def enable_get_default_args(C: Any, *, overwrite: bool = True) -> None:
591
606
field_ = dataclasses .field (default = default )
592
607
field_annotations .append ((pname , defval .annotation , field_ ))
593
608
609
+ return field_annotations
610
+
611
+
612
+ def enable_get_default_args (C : Any , * , overwrite : bool = True ) -> None :
613
+ """
614
+ If C is a function or a plain class with an __init__ function,
615
+ and you want get_default_args(C) to work, then add
616
+ `enable_get_default_args(C)` straight after the definition of C.
617
+ This makes a dataclass corresponding to the default arguments of C
618
+ and stores it in the same module as C.
619
+
620
+ Args:
621
+ C: a function, or a class with an __init__ function. Must
622
+ have types for all its defaulted args.
623
+ overwrite: whether to allow calling this a second time on
624
+ the same function.
625
+ """
626
+ if not inspect .isfunction (C ) and not inspect .isclass (C ):
627
+ raise ValueError (f"Unexpected { C } " )
628
+
629
+ field_annotations = _field_annotations_for_default_args (C )
630
+
594
631
name = _dataclass_name_for_function (C )
595
632
module = sys .modules [C .__module__ ]
596
633
if hasattr (module , name ):
@@ -767,7 +804,7 @@ def create_x_impl(self, enabled, args):
767
804
768
805
Also adds the following class members, unannotated so that dataclass
769
806
ignores them.
770
- - _creation_functions: Tuple[str] of all the create_ functions,
807
+ - _creation_functions: Tuple[str, ... ] of all the create_ functions,
771
808
including those from base classes (not the create_x_impl ones).
772
809
- _known_implementations: Dict[str, Type] containing the classes which
773
810
have been found from the registry.
@@ -945,7 +982,7 @@ def _get_type_to_process(type_) -> Optional[Tuple[Type, _ProcessType]]:
945
982
return underlying , _ProcessType .OPTIONAL_CONFIGURABLE
946
983
947
984
if not isinstance (type_ , type ):
948
- # e.g. any other Union or Tuple
985
+ # e.g. any other Union or Tuple. Or ClassVar.
949
986
return
950
987
951
988
if issubclass (type_ , ReplaceableBase ) and ReplaceableBase in type_ .__bases__ :
0 commit comments