Skip to content

Commit f1f873b

Browse files
committed
refactor(protobuf): support forward references in ProtoStruct
1 parent 2f4a835 commit f1f873b

File tree

1 file changed

+46
-9
lines changed

1 file changed

+46
-9
lines changed

lagrange/utils/binary/protobuf/models.py

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import inspect
2+
import importlib
23
from types import GenericAlias
3-
from typing import cast, Dict, List, Tuple, Type, TypeVar, Union, Generic, Any, Callable, Mapping, overload
4-
from typing_extensions import Optional, Self, TypeAlias, dataclass_transform, get_origin, get_args
4+
from typing import cast, Dict, List, Tuple, Type, TypeVar, Union, Generic, Any, Callable, Mapping, overload, ForwardRef
5+
from typing_extensions import Optional, Self, TypeAlias, dataclass_transform
56

67
from .coder import Proto, proto_decode, proto_encode
78

@@ -10,6 +11,9 @@
1011
T = TypeVar("T", str, list, dict, bytes, int, float, bool, "ProtoStruct")
1112
V = TypeVar("V")
1213
NT: TypeAlias = Dict[int, Union[_ProtoTypes, "NT"]]
14+
AMT: TypeAlias = Dict[str, Tuple[Type[_ProtoTypes], "ProtoField"]]
15+
DAMT: TypeAlias = Dict[str, "DelayAnnoType"]
16+
DelayAnnoType = Union[str, type(List[str])]
1317
NoneType = type(None)
1418

1519

@@ -82,11 +86,13 @@ def proto_field(
8286
@dataclass_transform(kw_only_default=True, field_specifiers=(proto_field,))
8387
class ProtoStruct:
8488
_anno_map: Dict[str, Tuple[Type[_ProtoTypes], ProtoField[Any]]]
89+
_delay_anno_map: Dict[str, DelayAnnoType]
8590
_proto_debug: bool
8691

8792
def __init__(self, *args, **kwargs):
8893
undefined_params: List[str] = []
8994
args = list(args)
95+
self._resolve_annotations(self)
9096
for name, (typ, field) in self._anno_map.items():
9197
if args:
9298
self._set_attr(name, typ, args.pop(0))
@@ -103,8 +109,8 @@ def __init__(self, *args, **kwargs):
103109
)
104110

105111
def __init_subclass__(cls, **kwargs):
106-
cls._anno_map = cls._get_annotations()
107112
cls._proto_debug = kwargs.pop("debug") if "debug" in kwargs else False
113+
cls._anno_map, cls._delay_anno_map = cls._get_annotations()
108114
super().__init_subclass__(**kwargs)
109115

110116
def __repr__(self) -> str:
@@ -127,8 +133,9 @@ def _set_attr(self, name: str, data_typ: Type[V], value: V) -> None:
127133
@classmethod
128134
def _get_annotations(
129135
cls,
130-
) -> Dict[str, Tuple[Type[_ProtoTypes], "ProtoField"]]: # Name: (ReturnType, ProtoField)
131-
annotations: Dict[str, Tuple[Type[_ProtoTypes], "ProtoField"]] = {}
136+
) -> Tuple[AMT, DAMT]: # Name: (ReturnType, ProtoField)
137+
annotations: AMT = {}
138+
delay_annotations: DAMT = {}
132139
for obj in reversed(inspect.getmro(cls)):
133140
if obj in (ProtoStruct, object): # base object, ignore
134141
continue
@@ -142,15 +149,35 @@ def _get_annotations(
142149
if not isinstance(field, ProtoField):
143150
raise TypeError("attribute '{name}' is not a ProtoField object")
144151

152+
_typ = typ
153+
annotations[name] = (_typ, field)
154+
if isinstance(typ, str):
155+
delay_annotations[name] = typ
145156
if hasattr(typ, "__origin__"):
146-
typ = typ.__origin__[typ.__args__[0]]
147-
annotations[name] = (typ, field)
148-
149-
return annotations
157+
typ = cast(GenericAlias, typ)
158+
_inner = typ.__args__[0]
159+
_typ = typ.__origin__[typ.__args__[0]]
160+
annotations[name] = (_typ, field)
161+
162+
if isinstance(_inner, type):
163+
continue
164+
if isinstance(_inner, GenericAlias) and isinstance(_inner.__args__[0], type):
165+
continue
166+
if isinstance(_inner, str):
167+
delay_annotations[name] = _typ.__origin__[_inner]
168+
if isinstance(_inner, ForwardRef):
169+
delay_annotations[name] = _inner.__forward_arg__
170+
if isinstance(_inner, GenericAlias):
171+
delay_annotations[name] = _typ
172+
173+
return annotations, delay_annotations
150174

151175
@classmethod
152176
def _get_field_mapping(cls) -> Dict[int, Tuple[str, Type[_ProtoTypes]]]: # Tag, (Name, Type)
153177
field_mapping: Dict[int, Tuple[str, Type[_ProtoTypes]]] = {}
178+
if cls._delay_anno_map:
179+
print(f"WARNING: '{cls.__name__}' has delay annotations: {cls._delay_anno_map}")
180+
cls._resolve_annotations(cls)
154181
for name, (typ, field) in cls._anno_map.items():
155182
field_mapping[field.tag] = (name, typ)
156183
return field_mapping
@@ -161,6 +188,16 @@ def _get_stored_mapping(self) -> Dict[str, NT]:
161188
stored_mapping[name] = getattr(self, name)
162189
return stored_mapping
163190

191+
@staticmethod
192+
def _resolve_annotations(arg: Union[Type["ProtoStruct"], "ProtoStruct"]) -> None:
193+
for k, v in arg._delay_anno_map.copy().items():
194+
module = importlib.import_module(arg.__module__)
195+
if hasattr(v, "__origin__"): # resolve GenericAlias, such as list[str]
196+
arg._anno_map[k] = (v.__origin__[module.__getattribute__(v.__args__[0])], arg._anno_map[k][1])
197+
else:
198+
arg._anno_map[k] = (module.__getattribute__(v), arg._anno_map[k][1])
199+
arg._delay_anno_map.pop(k)
200+
164201
def _encode(self, v: _ProtoTypes) -> NT:
165202
if isinstance(v, ProtoStruct):
166203
v = v.encode()

0 commit comments

Comments
 (0)