11import inspect
2+ import importlib
23from types import GenericAlias
3- from typing import cast , Dict , List , Tuple , Type , TypeVar , Union , Generic , Any , Callable , Mapping , overload
4- from typing_extensions import Optional , Self , TypeAlias , dataclass_transform , get_origin , get_args
4+ from typing import cast , Dict , List , Tuple , Type , TypeVar , Union , Generic , Any , Callable , Mapping , overload , ForwardRef
5+ from typing_extensions import Optional , Self , TypeAlias , dataclass_transform
56
67from .coder import Proto , proto_decode , proto_encode
78
1011T = TypeVar ("T" , str , list , dict , bytes , int , float , bool , "ProtoStruct" )
1112V = TypeVar ("V" )
1213NT : TypeAlias = Dict [int , Union [_ProtoTypes , "NT" ]]
14+ AMT : TypeAlias = Dict [str , Tuple [Type [_ProtoTypes ], "ProtoField" ]]
15+ DAMT : TypeAlias = Dict [str , "DelayAnnoType" ]
16+ DelayAnnoType = Union [str , type (List [str ])]
1317NoneType = type (None )
1418
1519
@@ -82,11 +86,13 @@ def proto_field(
8286@dataclass_transform (kw_only_default = True , field_specifiers = (proto_field ,))
8387class ProtoStruct :
8488 _anno_map : Dict [str , Tuple [Type [_ProtoTypes ], ProtoField [Any ]]]
89+ _delay_anno_map : Dict [str , DelayAnnoType ]
8590 _proto_debug : bool
8691
8792 def __init__ (self , * args , ** kwargs ):
8893 undefined_params : List [str ] = []
8994 args = list (args )
95+ self ._resolve_annotations (self )
9096 for name , (typ , field ) in self ._anno_map .items ():
9197 if args :
9298 self ._set_attr (name , typ , args .pop (0 ))
@@ -103,8 +109,8 @@ def __init__(self, *args, **kwargs):
103109 )
104110
105111 def __init_subclass__ (cls , ** kwargs ):
106- cls ._anno_map = cls ._get_annotations ()
107112 cls ._proto_debug = kwargs .pop ("debug" ) if "debug" in kwargs else False
113+ cls ._anno_map , cls ._delay_anno_map = cls ._get_annotations ()
108114 super ().__init_subclass__ (** kwargs )
109115
110116 def __repr__ (self ) -> str :
@@ -127,8 +133,9 @@ def _set_attr(self, name: str, data_typ: Type[V], value: V) -> None:
127133 @classmethod
128134 def _get_annotations (
129135 cls ,
130- ) -> Dict [str , Tuple [Type [_ProtoTypes ], "ProtoField" ]]: # Name: (ReturnType, ProtoField)
131- annotations : Dict [str , Tuple [Type [_ProtoTypes ], "ProtoField" ]] = {}
136+ ) -> Tuple [AMT , DAMT ]: # Name: (ReturnType, ProtoField)
137+ annotations : AMT = {}
138+ delay_annotations : DAMT = {}
132139 for obj in reversed (inspect .getmro (cls )):
133140 if obj in (ProtoStruct , object ): # base object, ignore
134141 continue
@@ -142,15 +149,35 @@ def _get_annotations(
142149 if not isinstance (field , ProtoField ):
143150 raise TypeError ("attribute '{name}' is not a ProtoField object" )
144151
152+ _typ = typ
153+ annotations [name ] = (_typ , field )
154+ if isinstance (typ , str ):
155+ delay_annotations [name ] = typ
145156 if hasattr (typ , "__origin__" ):
146- typ = typ .__origin__ [typ .__args__ [0 ]]
147- annotations [name ] = (typ , field )
148-
149- return annotations
157+ typ = cast (GenericAlias , typ )
158+ _inner = typ .__args__ [0 ]
159+ _typ = typ .__origin__ [typ .__args__ [0 ]]
160+ annotations [name ] = (_typ , field )
161+
162+ if isinstance (_inner , type ):
163+ continue
164+ if isinstance (_inner , GenericAlias ) and isinstance (_inner .__args__ [0 ], type ):
165+ continue
166+ if isinstance (_inner , str ):
167+ delay_annotations [name ] = _typ .__origin__ [_inner ]
168+ if isinstance (_inner , ForwardRef ):
169+ delay_annotations [name ] = _inner .__forward_arg__
170+ if isinstance (_inner , GenericAlias ):
171+ delay_annotations [name ] = _typ
172+
173+ return annotations , delay_annotations
150174
151175 @classmethod
152176 def _get_field_mapping (cls ) -> Dict [int , Tuple [str , Type [_ProtoTypes ]]]: # Tag, (Name, Type)
153177 field_mapping : Dict [int , Tuple [str , Type [_ProtoTypes ]]] = {}
178+ if cls ._delay_anno_map :
179+ print (f"WARNING: '{ cls .__name__ } ' has delay annotations: { cls ._delay_anno_map } " )
180+ cls ._resolve_annotations (cls )
154181 for name , (typ , field ) in cls ._anno_map .items ():
155182 field_mapping [field .tag ] = (name , typ )
156183 return field_mapping
@@ -161,6 +188,16 @@ def _get_stored_mapping(self) -> Dict[str, NT]:
161188 stored_mapping [name ] = getattr (self , name )
162189 return stored_mapping
163190
191+ @staticmethod
192+ def _resolve_annotations (arg : Union [Type ["ProtoStruct" ], "ProtoStruct" ]) -> None :
193+ for k , v in arg ._delay_anno_map .copy ().items ():
194+ module = importlib .import_module (arg .__module__ )
195+ if hasattr (v , "__origin__" ): # resolve GenericAlias, such as list[str]
196+ arg ._anno_map [k ] = (v .__origin__ [module .__getattribute__ (v .__args__ [0 ])], arg ._anno_map [k ][1 ])
197+ else :
198+ arg ._anno_map [k ] = (module .__getattribute__ (v ), arg ._anno_map [k ][1 ])
199+ arg ._delay_anno_map .pop (k )
200+
164201 def _encode (self , v : _ProtoTypes ) -> NT :
165202 if isinstance (v , ProtoStruct ):
166203 v = v .encode ()
0 commit comments