From f1f873bcdfdd5573b62fd02c8cf06e3785904d86 Mon Sep 17 00:00:00 2001 From: pk5ls20 Date: Wed, 14 Aug 2024 09:20:55 +0800 Subject: [PATCH 1/4] refactor(protobuf): support forward references in `ProtoStruct` --- lagrange/utils/binary/protobuf/models.py | 55 ++++++++++++++++++++---- 1 file changed, 46 insertions(+), 9 deletions(-) diff --git a/lagrange/utils/binary/protobuf/models.py b/lagrange/utils/binary/protobuf/models.py index 8bce841..48920e3 100644 --- a/lagrange/utils/binary/protobuf/models.py +++ b/lagrange/utils/binary/protobuf/models.py @@ -1,7 +1,8 @@ import inspect +import importlib from types import GenericAlias -from typing import cast, Dict, List, Tuple, Type, TypeVar, Union, Generic, Any, Callable, Mapping, overload -from typing_extensions import Optional, Self, TypeAlias, dataclass_transform, get_origin, get_args +from typing import cast, Dict, List, Tuple, Type, TypeVar, Union, Generic, Any, Callable, Mapping, overload, ForwardRef +from typing_extensions import Optional, Self, TypeAlias, dataclass_transform from .coder import Proto, proto_decode, proto_encode @@ -10,6 +11,9 @@ T = TypeVar("T", str, list, dict, bytes, int, float, bool, "ProtoStruct") V = TypeVar("V") NT: TypeAlias = Dict[int, Union[_ProtoTypes, "NT"]] +AMT: TypeAlias = Dict[str, Tuple[Type[_ProtoTypes], "ProtoField"]] +DAMT: TypeAlias = Dict[str, "DelayAnnoType"] +DelayAnnoType = Union[str, type(List[str])] NoneType = type(None) @@ -82,11 +86,13 @@ def proto_field( @dataclass_transform(kw_only_default=True, field_specifiers=(proto_field,)) class ProtoStruct: _anno_map: Dict[str, Tuple[Type[_ProtoTypes], ProtoField[Any]]] + _delay_anno_map: Dict[str, DelayAnnoType] _proto_debug: bool def __init__(self, *args, **kwargs): undefined_params: List[str] = [] args = list(args) + self._resolve_annotations(self) for name, (typ, field) in self._anno_map.items(): if args: self._set_attr(name, typ, args.pop(0)) @@ -103,8 +109,8 @@ def __init__(self, *args, **kwargs): ) def __init_subclass__(cls, **kwargs): - cls._anno_map = cls._get_annotations() cls._proto_debug = kwargs.pop("debug") if "debug" in kwargs else False + cls._anno_map, cls._delay_anno_map = cls._get_annotations() super().__init_subclass__(**kwargs) def __repr__(self) -> str: @@ -127,8 +133,9 @@ def _set_attr(self, name: str, data_typ: Type[V], value: V) -> None: @classmethod def _get_annotations( cls, - ) -> Dict[str, Tuple[Type[_ProtoTypes], "ProtoField"]]: # Name: (ReturnType, ProtoField) - annotations: Dict[str, Tuple[Type[_ProtoTypes], "ProtoField"]] = {} + ) -> Tuple[AMT, DAMT]: # Name: (ReturnType, ProtoField) + annotations: AMT = {} + delay_annotations: DAMT = {} for obj in reversed(inspect.getmro(cls)): if obj in (ProtoStruct, object): # base object, ignore continue @@ -142,15 +149,35 @@ def _get_annotations( if not isinstance(field, ProtoField): raise TypeError("attribute '{name}' is not a ProtoField object") + _typ = typ + annotations[name] = (_typ, field) + if isinstance(typ, str): + delay_annotations[name] = typ if hasattr(typ, "__origin__"): - typ = typ.__origin__[typ.__args__[0]] - annotations[name] = (typ, field) - - return annotations + typ = cast(GenericAlias, typ) + _inner = typ.__args__[0] + _typ = typ.__origin__[typ.__args__[0]] + annotations[name] = (_typ, field) + + if isinstance(_inner, type): + continue + if isinstance(_inner, GenericAlias) and isinstance(_inner.__args__[0], type): + continue + if isinstance(_inner, str): + delay_annotations[name] = _typ.__origin__[_inner] + if isinstance(_inner, ForwardRef): + delay_annotations[name] = _inner.__forward_arg__ + if isinstance(_inner, GenericAlias): + delay_annotations[name] = _typ + + return annotations, delay_annotations @classmethod def _get_field_mapping(cls) -> Dict[int, Tuple[str, Type[_ProtoTypes]]]: # Tag, (Name, Type) field_mapping: Dict[int, Tuple[str, Type[_ProtoTypes]]] = {} + if cls._delay_anno_map: + print(f"WARNING: '{cls.__name__}' has delay annotations: {cls._delay_anno_map}") + cls._resolve_annotations(cls) for name, (typ, field) in cls._anno_map.items(): field_mapping[field.tag] = (name, typ) return field_mapping @@ -161,6 +188,16 @@ def _get_stored_mapping(self) -> Dict[str, NT]: stored_mapping[name] = getattr(self, name) return stored_mapping + @staticmethod + def _resolve_annotations(arg: Union[Type["ProtoStruct"], "ProtoStruct"]) -> None: + for k, v in arg._delay_anno_map.copy().items(): + module = importlib.import_module(arg.__module__) + if hasattr(v, "__origin__"): # resolve GenericAlias, such as list[str] + arg._anno_map[k] = (v.__origin__[module.__getattribute__(v.__args__[0])], arg._anno_map[k][1]) + else: + arg._anno_map[k] = (module.__getattribute__(v), arg._anno_map[k][1]) + arg._delay_anno_map.pop(k) + def _encode(self, v: _ProtoTypes) -> NT: if isinstance(v, ProtoStruct): v = v.encode() From 74d3ccf007cf91b7ddc7a16bc057e0e4b8593754 Mon Sep 17 00:00:00 2001 From: pk5ls20 Date: Wed, 14 Aug 2024 09:31:45 +0800 Subject: [PATCH 2/4] chore(protobuf): remove unused debug code --- lagrange/utils/binary/protobuf/models.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lagrange/utils/binary/protobuf/models.py b/lagrange/utils/binary/protobuf/models.py index 48920e3..1ad9fb1 100644 --- a/lagrange/utils/binary/protobuf/models.py +++ b/lagrange/utils/binary/protobuf/models.py @@ -176,7 +176,6 @@ def _get_annotations( def _get_field_mapping(cls) -> Dict[int, Tuple[str, Type[_ProtoTypes]]]: # Tag, (Name, Type) field_mapping: Dict[int, Tuple[str, Type[_ProtoTypes]]] = {} if cls._delay_anno_map: - print(f"WARNING: '{cls.__name__}' has delay annotations: {cls._delay_anno_map}") cls._resolve_annotations(cls) for name, (typ, field) in cls._anno_map.items(): field_mapping[field.tag] = (name, typ) From e1e24ff819c17c4c0cff62fec0baf5f9b1a97c8e Mon Sep 17 00:00:00 2001 From: pk5ls20 Date: Wed, 14 Aug 2024 19:44:10 +0800 Subject: [PATCH 3/4] feat(protobuf): enhanced type annotations (mypy) --- lagrange/utils/binary/protobuf/models.py | 32 +++++++++++------------- 1 file changed, 14 insertions(+), 18 deletions(-) diff --git a/lagrange/utils/binary/protobuf/models.py b/lagrange/utils/binary/protobuf/models.py index 1ad9fb1..ac47b93 100644 --- a/lagrange/utils/binary/protobuf/models.py +++ b/lagrange/utils/binary/protobuf/models.py @@ -6,14 +6,15 @@ from .coder import Proto, proto_decode, proto_encode -_ProtoTypes = Union[str, list, dict, bytes, int, float, bool, "ProtoStruct"] +_ProtoBasicTypes = Union[str, list, dict, bytes, int, float, bool] +_ProtoTypes = Union[_ProtoBasicTypes, "ProtoStruct"] T = TypeVar("T", str, list, dict, bytes, int, float, bool, "ProtoStruct") V = TypeVar("V") NT: TypeAlias = Dict[int, Union[_ProtoTypes, "NT"]] AMT: TypeAlias = Dict[str, Tuple[Type[_ProtoTypes], "ProtoField"]] DAMT: TypeAlias = Dict[str, "DelayAnnoType"] -DelayAnnoType = Union[str, type(List[str])] +DelayAnnoType = Union[str, List[str]] NoneType = type(None) @@ -21,8 +22,8 @@ class ProtoField(Generic[T]): def __init__(self, tag: int, default: T): if tag <= 0: raise ValueError("Tag must be a positive integer") - self._tag = tag - self._default = default + self._tag: int = tag + self._default: T = default @property def tag(self) -> int: @@ -91,11 +92,11 @@ class ProtoStruct: def __init__(self, *args, **kwargs): undefined_params: List[str] = [] - args = list(args) + args_list = list(args) self._resolve_annotations(self) for name, (typ, field) in self._anno_map.items(): if args: - self._set_attr(name, typ, args.pop(0)) + self._set_attr(name, typ, args_list.pop(0)) elif name in kwargs: self._set_attr(name, typ, kwargs.pop(name)) else: @@ -104,9 +105,7 @@ def __init__(self, *args, **kwargs): else: undefined_params.append(name) if undefined_params: - raise AttributeError( - "Undefined parameters in '{}': {}".format(self, undefined_params) - ) + raise AttributeError(f"Undefined parameters in {self}: {undefined_params}") def __init_subclass__(cls, **kwargs): cls._proto_debug = kwargs.pop("debug") if "debug" in kwargs else False @@ -125,9 +124,7 @@ def _set_attr(self, name: str, data_typ: Type[V], value: V) -> None: if isinstance(data_typ, GenericAlias): # force ignore pass elif not isinstance(value, data_typ) and value is not None: - raise TypeError( - "'{}' is not a instance of type '{}'".format(value, data_typ) - ) + raise TypeError("{value} is not a instance of type {data_typ}") setattr(self, name, value) @classmethod @@ -191,13 +188,13 @@ def _get_stored_mapping(self) -> Dict[str, NT]: def _resolve_annotations(arg: Union[Type["ProtoStruct"], "ProtoStruct"]) -> None: for k, v in arg._delay_anno_map.copy().items(): module = importlib.import_module(arg.__module__) - if hasattr(v, "__origin__"): # resolve GenericAlias, such as list[str] - arg._anno_map[k] = (v.__origin__[module.__getattribute__(v.__args__[0])], arg._anno_map[k][1]) - else: - arg._anno_map[k] = (module.__getattribute__(v), arg._anno_map[k][1]) + if isinstance(v, GenericAlias): # resolve GenericAlias, such as list[str] + arg._anno_map[k] = (v.__origin__[getattr(module, v.__args__[0])], arg._anno_map[k][1]) + if isinstance(v, str): + arg._anno_map[k] = (getattr(module, v), arg._anno_map[k][1]) arg._delay_anno_map.pop(k) - def _encode(self, v: _ProtoTypes) -> NT: + def _encode(self, v: _ProtoTypes) -> _ProtoBasicTypes: if isinstance(v, ProtoStruct): v = v.encode() return v @@ -266,4 +263,3 @@ def decode(cls, data: bytes) -> Self: return cls(**kwargs) - From 086b60f55db57a7573f9131ff781795384192c64 Mon Sep 17 00:00:00 2001 From: pk5ls20 Date: Thu, 15 Aug 2024 14:49:36 +0800 Subject: [PATCH 4/4] chore(protobuf): use `ForwardRef._evaluate` to resolve forward references --- lagrange/utils/binary/protobuf/models.py | 49 +++++++++++++++--------- 1 file changed, 31 insertions(+), 18 deletions(-) diff --git a/lagrange/utils/binary/protobuf/models.py b/lagrange/utils/binary/protobuf/models.py index ac47b93..4e02e29 100644 --- a/lagrange/utils/binary/protobuf/models.py +++ b/lagrange/utils/binary/protobuf/models.py @@ -11,10 +11,6 @@ T = TypeVar("T", str, list, dict, bytes, int, float, bool, "ProtoStruct") V = TypeVar("V") -NT: TypeAlias = Dict[int, Union[_ProtoTypes, "NT"]] -AMT: TypeAlias = Dict[str, Tuple[Type[_ProtoTypes], "ProtoField"]] -DAMT: TypeAlias = Dict[str, "DelayAnnoType"] -DelayAnnoType = Union[str, List[str]] NoneType = type(None) @@ -84,10 +80,18 @@ def proto_field( return ProtoField(tag, default) +NT: TypeAlias = Dict[int, Union[_ProtoTypes, "NT"]] +AMT: TypeAlias = Dict[str, Tuple[Type[_ProtoTypes], "ProtoField"]] +PS = TypeVar("PS", bound=ProtoField) +DAMT: Union[Type[list[ForwardRef]], ForwardRef] +DAMDT: TypeAlias = Dict[str, Union[Type[list[ForwardRef]], ForwardRef]] + + +# noinspection PyProtectedMember @dataclass_transform(kw_only_default=True, field_specifiers=(proto_field,)) class ProtoStruct: _anno_map: Dict[str, Tuple[Type[_ProtoTypes], ProtoField[Any]]] - _delay_anno_map: Dict[str, DelayAnnoType] + _delay_anno_map: DAMDT _proto_debug: bool def __init__(self, *args, **kwargs): @@ -128,11 +132,15 @@ def _set_attr(self, name: str, data_typ: Type[V], value: V) -> None: setattr(self, name, value) @classmethod - def _get_annotations( - cls, - ) -> Tuple[AMT, DAMT]: # Name: (ReturnType, ProtoField) + def _handle_inner_generic(cls, inner: GenericAlias) -> GenericAlias: + if inner.__origin__ is list: + return GenericAlias(list, ForwardRef(inner.__args__[0])) + raise NotImplementedError(f"unknown inner generic type '{inner}'") + + @classmethod + def _get_annotations(cls) -> Tuple[AMT, DAMDT]: # Name: (ReturnType, ProtoField) annotations: AMT = {} - delay_annotations: DAMT = {} + delay_annotations: DAMDT = {} for obj in reversed(inspect.getmro(cls)): if obj in (ProtoStruct, object): # base object, ignore continue @@ -149,7 +157,7 @@ def _get_annotations( _typ = typ annotations[name] = (_typ, field) if isinstance(typ, str): - delay_annotations[name] = typ + delay_annotations[name] = ForwardRef(typ) if hasattr(typ, "__origin__"): typ = cast(GenericAlias, typ) _inner = typ.__args__[0] @@ -161,11 +169,11 @@ def _get_annotations( if isinstance(_inner, GenericAlias) and isinstance(_inner.__args__[0], type): continue if isinstance(_inner, str): - delay_annotations[name] = _typ.__origin__[_inner] + delay_annotations[name] = _typ.__origin__[ForwardRef(_inner)] if isinstance(_inner, ForwardRef): - delay_annotations[name] = _inner.__forward_arg__ + delay_annotations[name] = _inner if isinstance(_inner, GenericAlias): - delay_annotations[name] = _typ + delay_annotations[name] = cast(Type[list[ForwardRef]], cls._handle_inner_generic(_inner)) return annotations, delay_annotations @@ -186,12 +194,17 @@ def _get_stored_mapping(self) -> Dict[str, NT]: @staticmethod def _resolve_annotations(arg: Union[Type["ProtoStruct"], "ProtoStruct"]) -> None: + if not arg._delay_anno_map: + return + local = importlib.import_module(arg.__module__).__dict__ for k, v in arg._delay_anno_map.copy().items(): - module = importlib.import_module(arg.__module__) - if isinstance(v, GenericAlias): # resolve GenericAlias, such as list[str] - arg._anno_map[k] = (v.__origin__[getattr(module, v.__args__[0])], arg._anno_map[k][1]) - if isinstance(v, str): - arg._anno_map[k] = (getattr(module, v), arg._anno_map[k][1]) + casted_forward: Type["ProtoStruct"] + if isinstance(v, GenericAlias): + casted_forward = v.__origin__[v.__args__[0]._evaluate(globals(), local, recursive_guard=frozenset())] + arg._anno_map[k] = (casted_forward, arg._anno_map[k][1]) + if isinstance(v, ForwardRef): + casted_forward = v._evaluate(globals(), local, recursive_guard=frozenset()) # type: ignore + arg._anno_map[k] = (casted_forward, arg._anno_map[k][1]) arg._delay_anno_map.pop(k) def _encode(self, v: _ProtoTypes) -> _ProtoBasicTypes: