From 66fa1747e5aaeac48afeba0b50c41a90f4998633 Mon Sep 17 00:00:00 2001 From: RF-Tar-Railt Date: Mon, 23 Sep 2024 21:27:37 +0800 Subject: [PATCH 1/5] feat: new ProtoStruct --- lagrange/pb/highway/rsp.py | 4 +- lagrange/pb/message/rich_text/elems.py | 2 +- lagrange/pb/service/friend.py | 2 +- lagrange/pb/service/group.py | 6 +- lagrange/utils/binary/protobuf/models.py | 196 ++++++++++++++--------- 5 files changed, 123 insertions(+), 87 deletions(-) diff --git a/lagrange/pb/highway/rsp.py b/lagrange/pb/highway/rsp.py index 07e79c2..11dff78 100644 --- a/lagrange/pb/highway/rsp.py +++ b/lagrange/pb/highway/rsp.py @@ -41,8 +41,8 @@ class DownloadInfo(ProtoStruct): domain: str = proto_field(1) url_path: str = proto_field(2) https_port: Optional[int] = proto_field(3, default=None) - v4_addrs: list[IPv4] = proto_field(4, default=[]) - v6_addrs: list[IPv6] = proto_field(5, default=[]) + v4_addrs: list[IPv4] = proto_field(4, default_factory=list) + v6_addrs: list[IPv6] = proto_field(5, default_factory=list) pic_info: Optional[PicUrlExtInfo] = proto_field(6, default=None) video_info: Optional[VideoExtInfo] = proto_field(7, default=None) diff --git a/lagrange/pb/message/rich_text/elems.py b/lagrange/pb/message/rich_text/elems.py index 10be308..019b793 100644 --- a/lagrange/pb/message/rich_text/elems.py +++ b/lagrange/pb/message/rich_text/elems.py @@ -110,7 +110,7 @@ class SrcMsg(ProtoStruct): seq: int = proto_field(1) uin: int = proto_field(2, default=0) timestamp: int = proto_field(3) - elems: list[dict] = proto_field(5, default=[{}]) + elems: list[dict] = proto_field(5, default_factory=lambda: [{}]) pb_reserved: Optional[SrcMsgArgs] = proto_field(8, default=None) to_uin: int = proto_field(10, default=0) diff --git a/lagrange/pb/service/friend.py b/lagrange/pb/service/friend.py index 28e0642..8a198c5 100644 --- a/lagrange/pb/service/friend.py +++ b/lagrange/pb/service/friend.py @@ -45,7 +45,7 @@ class PBGetFriendListRequest(ProtoStruct): f7: int = proto_field(7, default=2147483647) # MaxValue body: list[GetFriendBody] = proto_field( 10001, - default=[ + default_factory=lambda: [ GetFriendBody(type=1, f2=GetFriendNumbers(f1=[103, 102, 20002, 27394])), GetFriendBody(type=4, f2=GetFriendNumbers(f1=[100, 101, 102])), ], diff --git a/lagrange/pb/service/group.py b/lagrange/pb/service/group.py index 4696834..a2fe230 100644 --- a/lagrange/pb/service/group.py +++ b/lagrange/pb/service/group.py @@ -396,7 +396,7 @@ class GrpInfo(ProtoStruct): class GetGrpListResponse(ProtoStruct): - grp_list: list[GrpInfo] = proto_field(2, default=[]) + grp_list: list[GrpInfo] = proto_field(2, default_factory=list) class PBGetInfoFromUidReq(ProtoStruct): @@ -425,8 +425,8 @@ def to_str(self) -> str: class GetInfoRspField(ProtoStruct, debug=True): - int_t: list[GetInfoRspF1] = proto_field(1, default=[]) - str_t: list[GetInfoRspF2] = proto_field(2, default=[]) + int_t: list[GetInfoRspF1] = proto_field(1, default_factory=list) + str_t: list[GetInfoRspF2] = proto_field(2, default_factory=list) class GetInfoRspBody(ProtoStruct): diff --git a/lagrange/utils/binary/protobuf/models.py b/lagrange/utils/binary/protobuf/models.py index 3075a0e..d33becc 100644 --- a/lagrange/utils/binary/protobuf/models.py +++ b/lagrange/utils/binary/protobuf/models.py @@ -1,10 +1,10 @@ -import inspect +import sys from dataclasses import MISSING from types import GenericAlias -from typing import cast, TypeVar, Union, Any, Callable, overload +from typing import cast, TypeVar, Union, Any, Callable, overload, get_origin, ForwardRef, get_args from collections.abc import Mapping from typing_extensions import Self, TypeAlias, dataclass_transform -from typing import Optional +from typing import Optional, ClassVar from .coder import Proto, proto_decode, proto_encode @@ -17,16 +17,19 @@ class ProtoField: + name: str + type: Any + def __init__(self, tag: int, default: Any, default_factory: Any): if tag <= 0: raise ValueError("Tag must be a positive integer") - self._tag = tag + self.tag = tag self._default = default self._default_factory = default_factory - @property - def tag(self) -> int: - return self._tag + def ensure_annotation(self, name: str, type_: Any) -> None: + self.name = name + self.type = type_ def get_default(self) -> Any: if self._default is not MISSING: @@ -35,6 +38,12 @@ def get_default(self) -> Any: return self._default_factory() return MISSING + @property + def type_without_optional(self) -> Any: + if get_origin(self.type) is Union: + return get_args(self.type)[0] + return self.type + @overload # `default` and `default_factory` are optional and mutually exclusive. def proto_field( @@ -89,85 +98,116 @@ def proto_field( @dataclass_transform(kw_only_default=True, field_specifiers=(proto_field,)) class ProtoStruct: - _anno_map: dict[str, tuple[type[_ProtoTypes], ProtoField]] - _proto_debug: bool - - def __init__(self, *args, **kwargs): + __fields__: ClassVar[dict[str, ProtoField]] + __proto_debug__: ClassVar[bool] + __proto_evaluated__: ClassVar[bool] = False + + def check_type(self, value: Any, typ: Any) -> bool: + if typ is Any: + return True + if typ is list: + return isinstance(value, list) + if typ is dict: + return isinstance(value, dict) + if isinstance(typ, GenericAlias): + if get_origin(typ) is list: + return all(self.check_type(v, get_args(typ)[0]) for v in value) + if get_origin(typ) is dict: + return all(self.check_type(k, get_args(typ)[0]) and self.check_type(v, get_args(typ)[1]) for k, v in value.items()) + return False + if get_origin(typ) is Union: # Should Only be Optional + return self.check_type(value, get_args(typ)[0]) if value is not None else True + if isinstance(value, typ): + return True + return False # or True if value is None else False + + def _evaluate(self): + for base in reversed(self.__class__.__mro__): + if base in (ProtoStruct, object): + continue + if getattr(base, '__proto_evaluated__', False): + continue + base_globals = getattr(sys.modules.get(base.__module__, None), '__dict__', {}) + base_locals = dict(vars(base)) + base_globals, base_locals = base_locals, base_globals + for field in base.__fields__.values(): + if isinstance(field.type, str): + field.type = ForwardRef(field.type, is_argument=False, is_class=True)._evaluate( + base_globals, base_locals, recursive_guard=frozenset() + ) + base.__proto_evaluated__ = True + + def __init__(self, **kwargs): undefined_params: list[str] = [] - args = list(args) - for name, (typ, field) in self._anno_map.items(): - if args: - self._set_attr(name, typ, args.pop(0)) - elif name in kwargs: - self._set_attr(name, typ, kwargs.pop(name)) + self._evaluate() + for name, field in self.__fields__.items(): + if name in kwargs: + value = kwargs.pop(name) + if not self.check_type(value, field.type): + raise TypeError( + f"'{value}' is not a instance of type '{field.type}'" + ) + setattr(self, name, value) else: - if field.get_default() is not MISSING: - self._set_attr(name, typ, field.get_default()) + if (de := field.get_default()) is not MISSING: + setattr(self, name, de) else: undefined_params.append(name) if undefined_params: raise AttributeError( f"Undefined parameters in '{self}': {undefined_params}" ) + super().__init__(**kwargs) + + @classmethod + def _process_field(cls): + fields = {} + cls_annotations = cls.__dict__.get('__annotations__', {}) + cls_fields: list[ProtoField] = [] + for name, typ in cls_annotations.items(): + field = getattr(cls, name, MISSING) + if field is MISSING: + raise TypeError(f'{name!r} should define its proto_field!') + field.ensure_annotation(name, typ) + cls_fields.append(field) + + for f in cls_fields: + fields[f.name] = f + if f._default is MISSING: + delattr(cls, f.name) + + for name, value in cls.__dict__.items(): + if isinstance(value, ProtoField) and not name in cls_annotations: + raise TypeError(f'{name!r} is a proto_field but has no type annotation') + + cls.__fields__ = fields def __init_subclass__(cls, **kwargs): - cls._anno_map = cls._get_annotations() - cls._proto_debug = kwargs.pop("debug") if "debug" in kwargs else False + cls.__proto_debug__ = kwargs.pop("debug") if "debug" in kwargs else False + cls._process_field() super().__init_subclass__(**kwargs) def __repr__(self) -> str: attrs = "" - for k, v in self._get_stored_mapping().items(): - attrs += f"{k}={v}, " - return f"{self.__class__.__name__}({attrs[:-2]})" - - def _set_attr(self, name: str, data_typ: type[V], value: V) -> None: - # if get_origin(data_typ) is Union: - # data_typ = (typ for typ in get_args(data_typ) if typ is not NoneType) # type: ignore - if isinstance(data_typ, GenericAlias): # force ignore - pass - elif not isinstance(value, data_typ) and value is not None: - raise TypeError( - f"'{value}' is not a instance of type '{data_typ}'" - ) - setattr(self, name, value) - - @classmethod - def _get_annotations( - cls, - ) -> dict[str, tuple[type[_ProtoTypes], "ProtoField"]]: # Name: (ReturnType, ProtoField) - annotations: dict[str, tuple[type[_ProtoTypes], "ProtoField"]] = {} - for obj in reversed(inspect.getmro(cls)): - if obj in (ProtoStruct, object): # base object, ignore + for k, v in vars(self).items(): + if k.startswith("_"): continue - for name, typ in obj.__annotations__.items(): - if name[0] == "_": # ignore internal var - continue - if not hasattr(obj, name): - raise AttributeError(f"attribute ‘{name}' not defined") - field = getattr(obj, name) # type: ProtoField + attrs += f"{k}={v!r}, " + return f"{self.__class__.__name__}({attrs})" - if not isinstance(field, ProtoField): - raise TypeError("attribute '{name}' is not a ProtoField object") - if hasattr(typ, "__origin__"): - typ = typ.__origin__[typ.__args__[0]] - annotations[name] = (typ, field) + # @classmethod + # def _get_field_mapping(cls) -> dict[int, tuple[str, type[_ProtoTypes]]]: # Tag, (Name, Type) + # field_mapping: dict[int, tuple[str, type[_ProtoTypes]]] = {} + # for name, (typ, field) in cls._anno_map.items(): + # field_mapping[field.tag] = (name, typ) + # return field_mapping - return annotations - - @classmethod - def _get_field_mapping(cls) -> dict[int, tuple[str, type[_ProtoTypes]]]: # Tag, (Name, Type) - field_mapping: dict[int, tuple[str, type[_ProtoTypes]]] = {} - for name, (typ, field) in cls._anno_map.items(): - field_mapping[field.tag] = (name, typ) - return field_mapping - - def _get_stored_mapping(self) -> dict[str, NT]: - stored_mapping: dict[str, NT] = {} - for name, (_, _) in self._anno_map.items(): - stored_mapping[name] = getattr(self, name) - return stored_mapping + # def _get_stored_mapping(self) -> dict[str, NT]: + # stored_mapping: dict[str, NT] = {} + # for name, (_, _) in self._anno_map.items(): + # stored_mapping[name] = getattr(self, name) + # return stored_mapping def _encode(self, v: _ProtoTypes) -> NT: if isinstance(v, ProtoStruct): @@ -176,7 +216,7 @@ def _encode(self, v: _ProtoTypes) -> NT: def encode(self) -> bytes: pb_dict: NT = {} - for name, (_, field) in self._anno_map.items(): + for name, field in self.__fields__.items(): tag = field.tag if tag in pb_dict: raise ValueError(f"duplicate tag: {tag}") @@ -221,21 +261,17 @@ def decode(cls, data: bytes) -> Self: if not data: return None # type: ignore pb_dict: Proto = proto_decode(data, 0).proto - mapping = cls._get_field_mapping() kwargs = {} - for tag, (name, typ) in mapping.items(): - if tag not in pb_dict: - _, field = cls._anno_map[name] - if field.get_default() is not MISSING: - kwargs[name] = field.get_default() + for _, field in cls.__fields__.items(): + if field.tag not in pb_dict: + if (de := field.get_default()) is not MISSING: + kwargs[field.name] = de continue - raise KeyError(f"tag {tag} not found in '{cls.__name__}'") - kwargs[name] = cls._decode(typ, pb_dict.pop(tag)) - if pb_dict and cls._proto_debug: # unhandled tags + raise KeyError(f"tag {field.tag} not found in '{cls.__name__}'") + kwargs[field.name] = cls._decode(field.type_without_optional, pb_dict.pop(field.tag)) + if pb_dict and cls.__proto_debug__: # unhandled tags pass return cls(**kwargs) - - From 79b93a1a2ead4a2ce37f8bbd78211a4199c4153d Mon Sep 17 00:00:00 2001 From: Tarrailt Date: Mon, 23 Sep 2024 21:34:50 +0800 Subject: [PATCH 2/5] fix: evaluate before decode --- lagrange/utils/binary/protobuf/models.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/lagrange/utils/binary/protobuf/models.py b/lagrange/utils/binary/protobuf/models.py index d33becc..d3df3ad 100644 --- a/lagrange/utils/binary/protobuf/models.py +++ b/lagrange/utils/binary/protobuf/models.py @@ -121,8 +121,9 @@ def check_type(self, value: Any, typ: Any) -> bool: return True return False # or True if value is None else False - def _evaluate(self): - for base in reversed(self.__class__.__mro__): + @classmethod + def _evaluate(cls): + for base in reversed(cls.__mro__): if base in (ProtoStruct, object): continue if getattr(base, '__proto_evaluated__', False): @@ -263,6 +264,7 @@ def decode(cls, data: bytes) -> Self: pb_dict: Proto = proto_decode(data, 0).proto kwargs = {} + cls._evaluate() for _, field in cls.__fields__.items(): if field.tag not in pb_dict: if (de := field.get_default()) is not MISSING: From 0c412f625bf00b93360d072bb17c852dfca49b87 Mon Sep 17 00:00:00 2001 From: RF-Tar-Railt Date: Mon, 23 Sep 2024 23:42:07 +0800 Subject: [PATCH 3/5] fix: fields init --- lagrange/client/event.py | 2 +- lagrange/pb/highway/comm.py | 6 +- lagrange/pb/highway/req.py | 4 +- lagrange/pb/message/rich_text/elems.py | 4 +- lagrange/utils/binary/protobuf/models.py | 218 ++++++++++++----------- 5 files changed, 119 insertions(+), 115 deletions(-) diff --git a/lagrange/client/event.py b/lagrange/client/event.py index 64175f2..edf7ecc 100644 --- a/lagrange/client/event.py +++ b/lagrange/client/event.py @@ -32,7 +32,7 @@ async def _task_exec(self, client: "Client", event: "BaseEvent", handler: EVENT_ try: await handler(client, event) except Exception as e: - log.root.error( + log.root.exception( f"Unhandled exception on task {event}", exc_info=e ) diff --git a/lagrange/pb/highway/comm.py b/lagrange/pb/highway/comm.py index e928728..6ed5579 100644 --- a/lagrange/pb/highway/comm.py +++ b/lagrange/pb/highway/comm.py @@ -34,9 +34,9 @@ class AudioExtInfo(ProtoStruct): class ExtBizInfo(ProtoStruct): - pic: PicExtInfo = proto_field(1, default=PicExtInfo()) - video: VideoExtInfo = proto_field(2, default=VideoExtInfo()) - audio: AudioExtInfo = proto_field(3, default=AudioExtInfo()) + pic: PicExtInfo = proto_field(1, default_factory=PicExtInfo) + video: VideoExtInfo = proto_field(2, default_factory=VideoExtInfo) + audio: AudioExtInfo = proto_field(3, default_factory=AudioExtInfo) bus_type: Optional[int] = proto_field(4, default=None) diff --git a/lagrange/pb/highway/req.py b/lagrange/pb/highway/req.py index 5740913..7855d4f 100644 --- a/lagrange/pb/highway/req.py +++ b/lagrange/pb/highway/req.py @@ -63,13 +63,13 @@ class DownloadVideoExt(ProtoStruct): class DownloadExt(ProtoStruct): pic_ext: Optional[bytes] = proto_field(1, default=None) - video_ext: DownloadVideoExt = proto_field(2, default=DownloadVideoExt()) + video_ext: DownloadVideoExt = proto_field(2, default_factory=DownloadVideoExt) ptt_ext: Optional[bytes] = proto_field(3, default=None) class DownloadReq(ProtoStruct): node: IndexNode = proto_field(1) - ext: DownloadExt = proto_field(2, default=DownloadExt()) + ext: DownloadExt = proto_field(2, default_factory=DownloadExt) class NTV2RichMediaReq(ProtoStruct): diff --git a/lagrange/pb/message/rich_text/elems.py b/lagrange/pb/message/rich_text/elems.py index 019b793..43b9b3a 100644 --- a/lagrange/pb/message/rich_text/elems.py +++ b/lagrange/pb/message/rich_text/elems.py @@ -53,7 +53,7 @@ class NotOnlineImage(ProtoStruct): width: int = proto_field(9) res_id: str = proto_field(10) origin_path: Optional[str] = proto_field(15, default=None) - args: ImageReserveArgs = proto_field(34, default=ImageReserveArgs()) + args: ImageReserveArgs = proto_field(34, default_factory=ImageReserveArgs) class TransElem(ProtoStruct): @@ -90,7 +90,7 @@ class CustomFace(ProtoStruct): width: int = proto_field(22) height: int = proto_field(23) size: int = proto_field(25) - args: ImageReserveArgs = proto_field(34, default=ImageReserveArgs()) + args: ImageReserveArgs = proto_field(34, default_factory=ImageReserveArgs) class ExtraInfo(ProtoStruct): diff --git a/lagrange/utils/binary/protobuf/models.py b/lagrange/utils/binary/protobuf/models.py index d3df3ad..5451ceb 100644 --- a/lagrange/utils/binary/protobuf/models.py +++ b/lagrange/utils/binary/protobuf/models.py @@ -96,55 +96,73 @@ def proto_field( return ProtoField(tag, default, default_factory) +def _decode(typ: type[_ProtoTypes], raw): + if isinstance(typ, str): + raise ValueError("ForwardRef not resolved. Please call ProtoStruct.update_forwardref() before decoding") + if issubclass(typ, ProtoStruct): + return typ.decode(raw) + elif typ is str: + return raw.decode(errors="ignore") + elif typ is dict: + return proto_decode(raw).proto + elif typ is bool: + return raw == 1 + elif typ is list: + if not isinstance(raw, list): + return [raw] + return raw + elif isinstance(typ, GenericAlias) and get_origin(typ) is list: + real_typ = get_args(typ)[0] + ret = [] + if isinstance(raw, list): + for v in raw: + ret.append(_decode(real_typ, v)) + else: + ret.append(_decode(real_typ, raw)) + return ret + elif isinstance(raw, typ): + return raw + else: + raise NotImplementedError(f"unknown type '{typ}' and data {raw}") + + +def check_type(value: Any, typ: Any) -> bool: + if isinstance(typ, str): + raise ValueError("ForwardRef not resolved. Please call ProtoStruct.update_forwardref() before decoding") + if typ is Any: + return True + if typ is list: + return isinstance(value, list) + if typ is dict: + return isinstance(value, dict) + if isinstance(typ, GenericAlias): + if get_origin(typ) is list: + return all(check_type(v, get_args(typ)[0]) for v in value) + if get_origin(typ) is dict: + return all(check_type(k, get_args(typ)[0]) and check_type(v, get_args(typ)[1]) for k, v in value.items()) + return False + if get_origin(typ) is Union: # Should Only be Optional + return check_type(value, get_args(typ)[0]) if value is not None else True + if isinstance(value, typ): + return True + return False # or True if value is None else False + + @dataclass_transform(kw_only_default=True, field_specifiers=(proto_field,)) class ProtoStruct: - __fields__: ClassVar[dict[str, ProtoField]] + __proto_fields__: ClassVar[dict[str, ProtoField]] __proto_debug__: ClassVar[bool] __proto_evaluated__: ClassVar[bool] = False - def check_type(self, value: Any, typ: Any) -> bool: - if typ is Any: - return True - if typ is list: - return isinstance(value, list) - if typ is dict: - return isinstance(value, dict) - if isinstance(typ, GenericAlias): - if get_origin(typ) is list: - return all(self.check_type(v, get_args(typ)[0]) for v in value) - if get_origin(typ) is dict: - return all(self.check_type(k, get_args(typ)[0]) and self.check_type(v, get_args(typ)[1]) for k, v in value.items()) - return False - if get_origin(typ) is Union: # Should Only be Optional - return self.check_type(value, get_args(typ)[0]) if value is not None else True - if isinstance(value, typ): - return True - return False # or True if value is None else False - - @classmethod - def _evaluate(cls): - for base in reversed(cls.__mro__): - if base in (ProtoStruct, object): - continue - if getattr(base, '__proto_evaluated__', False): - continue - base_globals = getattr(sys.modules.get(base.__module__, None), '__dict__', {}) - base_locals = dict(vars(base)) - base_globals, base_locals = base_locals, base_globals - for field in base.__fields__.values(): - if isinstance(field.type, str): - field.type = ForwardRef(field.type, is_argument=False, is_class=True)._evaluate( - base_globals, base_locals, recursive_guard=frozenset() - ) - base.__proto_evaluated__ = True - - def __init__(self, **kwargs): - undefined_params: list[str] = [] + def __init__(self, __from_raw: bool = False, **kwargs): + undefined_params: list[ProtoField] = [] self._evaluate() - for name, field in self.__fields__.items(): + for name, field in self.__proto_fields__.items(): if name in kwargs: value = kwargs.pop(name) - if not self.check_type(value, field.type): + if __from_raw: + value = _decode(field.type_without_optional, value) + if not check_type(value, field.type): raise TypeError( f"'{value}' is not a instance of type '{field.type}'" ) @@ -153,16 +171,22 @@ def __init__(self, **kwargs): if (de := field.get_default()) is not MISSING: setattr(self, name, de) else: - undefined_params.append(name) + undefined_params.append(field) if undefined_params: raise AttributeError( - f"Undefined parameters in '{self}': {undefined_params}" + f"Missing required parameters: {', '.join(f'{f.name}({f.tag})' for f in undefined_params)}" ) - super().__init__(**kwargs) @classmethod def _process_field(cls): fields = {} + + for b in cls.__mro__[-1:0:-1]: + base_fields = getattr(b, "__proto_fields__", None) + if base_fields is not None: + for f in base_fields.values(): + fields[f.name] = f + cls_annotations = cls.__dict__.get('__annotations__', {}) cls_fields: list[ProtoField] = [] for name, typ in cls_annotations.items(): @@ -181,7 +205,34 @@ def _process_field(cls): if isinstance(value, ProtoField) and not name in cls_annotations: raise TypeError(f'{name!r} is a proto_field but has no type annotation') - cls.__fields__ = fields + cls.__proto_fields__ = fields + + @classmethod + def _evaluate(cls): + for base in reversed(cls.__mro__): + if base in (ProtoStruct, object): + continue + if getattr(base, '__proto_evaluated__', False): + continue + base_globals = getattr(sys.modules.get(base.__module__, None), '__dict__', {}) + base_locals = dict(vars(base)) + base_globals, base_locals = base_locals, base_globals + for field in base.__proto_fields__.values(): + if isinstance(field.type, str): + try: + field.type = ForwardRef(field.type, is_argument=False, is_class=True)._evaluate( + base_globals, base_locals, recursive_guard=frozenset() + ) + except NameError: + pass + base.__proto_evaluated__ = True + + @classmethod + def update_forwardref(cls, mapping: dict[str, "type[ProtoStruct]"]): + """更新 ForwardRef""" + for field in cls.__proto_fields__.values(): + if isinstance(field.type, str) and field.type in mapping: + field.type = mapping[field.type] def __init_subclass__(cls, **kwargs): cls.__proto_debug__ = kwargs.pop("debug") if "debug" in kwargs else False @@ -196,84 +247,37 @@ def __repr__(self) -> str: attrs += f"{k}={v!r}, " return f"{self.__class__.__name__}({attrs})" - - # @classmethod - # def _get_field_mapping(cls) -> dict[int, tuple[str, type[_ProtoTypes]]]: # Tag, (Name, Type) - # field_mapping: dict[int, tuple[str, type[_ProtoTypes]]] = {} - # for name, (typ, field) in cls._anno_map.items(): - # field_mapping[field.tag] = (name, typ) - # return field_mapping - - # def _get_stored_mapping(self) -> dict[str, NT]: - # stored_mapping: dict[str, NT] = {} - # for name, (_, _) in self._anno_map.items(): - # stored_mapping[name] = getattr(self, name) - # return stored_mapping - - def _encode(self, v: _ProtoTypes) -> NT: - if isinstance(v, ProtoStruct): - v = v.encode() - return v # type: ignore - def encode(self) -> bytes: pb_dict: NT = {} - for name, field in self.__fields__.items(): + + def _encode(v: _ProtoTypes) -> NT: + if isinstance(v, ProtoStruct): + v = v.encode() + return v # type: ignore + + for name, field in self.__proto_fields__.items(): tag = field.tag if tag in pb_dict: raise ValueError(f"duplicate tag: {tag}") value: _ProtoTypes = getattr(self, name) if isinstance(value, list): - pb_dict[tag] = [self._encode(v) for v in value] + pb_dict[tag] = [_encode(v) for v in value] else: - pb_dict[tag] = self._encode(value) + pb_dict[tag] = _encode(value) return proto_encode(cast(Proto, pb_dict)) - @classmethod - def _decode(cls, typ: type[_ProtoTypes], value): - if issubclass(typ, ProtoStruct): - return typ.decode(value) - elif typ is str: - return value.decode(errors="ignore") - elif typ is dict: - return proto_decode(value).proto - elif typ is bool: - return value == 1 - elif typ is list: - if not isinstance(value, list): - return [value] - return value - elif isinstance(typ, GenericAlias): - if typ.__name__.lower() == "list": - real_typ = typ.__args__[0] - ret = [] - if isinstance(value, list): - for v in value: - ret.append(cls._decode(real_typ, v)) - else: - ret.append(cls._decode(real_typ, value)) - return ret - elif isinstance(value, typ): - return value - else: - raise NotImplementedError(f"unknown type '{typ}' and data {value}") - @classmethod def decode(cls, data: bytes) -> Self: if not data: return None # type: ignore pb_dict: Proto = proto_decode(data, 0).proto - kwargs = {} - cls._evaluate() - for _, field in cls.__fields__.items(): - if field.tag not in pb_dict: - if (de := field.get_default()) is not MISSING: - kwargs[field.name] = de - continue + kwargs = { + field.name: pb_dict.pop(field.tag) + for field in cls.__proto_fields__.values() + if field.tag in pb_dict + } - raise KeyError(f"tag {field.tag} not found in '{cls.__name__}'") - kwargs[field.name] = cls._decode(field.type_without_optional, pb_dict.pop(field.tag)) if pb_dict and cls.__proto_debug__: # unhandled tags pass - - return cls(**kwargs) + return cls(True, **kwargs) From 47b886932294a4f21187a5014196f2cea8e5ee2d Mon Sep 17 00:00:00 2001 From: RF-Tar-Railt Date: Tue, 24 Sep 2024 10:25:45 +0800 Subject: [PATCH 4/5] chore: improve evaluate --- lagrange/pb/service/group.py | 12 +++--- lagrange/utils/binary/protobuf/models.py | 48 +++++++++++++++--------- 2 files changed, 37 insertions(+), 23 deletions(-) diff --git a/lagrange/pb/service/group.py b/lagrange/pb/service/group.py index a2fe230..373d510 100644 --- a/lagrange/pb/service/group.py +++ b/lagrange/pb/service/group.py @@ -319,6 +319,12 @@ class MemberInfoLevel(ProtoStruct): num: int = proto_field(2) +class GetGrpMemberInfoRsp(ProtoStruct): + grp_id: int = proto_field(1) + body: list["GetGrpMemberInfoRspBody"] = proto_field(2) + next_key: Optional[bytes] = proto_field(15, default=None) # base64(pb) + + class GetGrpMemberInfoRspBody(ProtoStruct): account: AccountInfo = proto_field(1) nickname: str = proto_field(10, default="") @@ -342,12 +348,6 @@ def is_owner(self) -> bool: return not self.is_admin and self.permission == 2 -class GetGrpMemberInfoRsp(ProtoStruct): - grp_id: int = proto_field(1) - body: list[GetGrpMemberInfoRspBody] = proto_field(2) - next_key: Optional[bytes] = proto_field(15, default=None) # base64(pb) - - class GetGrpListReqBody(ProtoStruct): cfg1: bytes = proto_field(1) cfg2: bytes = proto_field(2) diff --git a/lagrange/utils/binary/protobuf/models.py b/lagrange/utils/binary/protobuf/models.py index 5451ceb..2fd2d6d 100644 --- a/lagrange/utils/binary/protobuf/models.py +++ b/lagrange/utils/binary/protobuf/models.py @@ -1,10 +1,10 @@ -import sys from dataclasses import MISSING from types import GenericAlias -from typing import cast, TypeVar, Union, Any, Callable, overload, get_origin, ForwardRef, get_args +from typing import cast, TypeVar, Union, Any, Callable, overload, get_origin, get_args, get_type_hints from collections.abc import Mapping from typing_extensions import Self, TypeAlias, dataclass_transform from typing import Optional, ClassVar +import typing from .coder import Proto, proto_decode, proto_encode @@ -16,6 +16,13 @@ NoneType = type(None) +def _get_all_args(tp): + if args := get_args(tp): + for arg in args: + yield from _get_all_args(arg) + yield from args + + class ProtoField: name: str type: Any @@ -26,10 +33,15 @@ def __init__(self, tag: int, default: Any, default_factory: Any): self.tag = tag self._default = default self._default_factory = default_factory + self._unevaluated = False def ensure_annotation(self, name: str, type_: Any) -> None: self.name = name self.type = type_ + if isinstance(type_, str): + self._unevaluated = True + elif (args := [*_get_all_args(type_)]) and any(isinstance(a, str) for a in args): + self._unevaluated = True def get_default(self) -> Any: if self._default is not MISSING: @@ -98,7 +110,10 @@ def proto_field( def _decode(typ: type[_ProtoTypes], raw): if isinstance(typ, str): - raise ValueError("ForwardRef not resolved. Please call ProtoStruct.update_forwardref() before decoding") + raise ValueError( + f"ForwardRef '{typ}' not resolved. " + f"Please call ProtoStruct.update_forwardref({{'{typ}': {typ}}}) before decoding" + ) if issubclass(typ, ProtoStruct): return typ.decode(raw) elif typ is str: @@ -128,7 +143,10 @@ def _decode(typ: type[_ProtoTypes], raw): def check_type(value: Any, typ: Any) -> bool: if isinstance(typ, str): - raise ValueError("ForwardRef not resolved. Please call ProtoStruct.update_forwardref() before decoding") + raise ValueError( + f"ForwardRef '{typ}' not resolved. " + f"Please call ProtoStruct.update_forwardref({{'{typ}': {typ}}}) before decoding" + ) if typ is Any: return True if typ is list: @@ -154,7 +172,7 @@ class ProtoStruct: __proto_debug__: ClassVar[bool] __proto_evaluated__: ClassVar[bool] = False - def __init__(self, __from_raw: bool = False, **kwargs): + def __init__(self, __from_raw: bool = False, /, **kwargs): undefined_params: list[ProtoField] = [] self._evaluate() for name, field in self.__proto_fields__.items(): @@ -214,25 +232,21 @@ def _evaluate(cls): continue if getattr(base, '__proto_evaluated__', False): continue - base_globals = getattr(sys.modules.get(base.__module__, None), '__dict__', {}) - base_locals = dict(vars(base)) - base_globals, base_locals = base_locals, base_globals + try: + annotations = get_type_hints(base) + except NameError: + annotations = {} for field in base.__proto_fields__.values(): - if isinstance(field.type, str): - try: - field.type = ForwardRef(field.type, is_argument=False, is_class=True)._evaluate( - base_globals, base_locals, recursive_guard=frozenset() - ) - except NameError: - pass + if field._unevaluated and field.name in annotations: + field.type = annotations[field.name] base.__proto_evaluated__ = True @classmethod def update_forwardref(cls, mapping: dict[str, "type[ProtoStruct]"]): """更新 ForwardRef""" for field in cls.__proto_fields__.values(): - if isinstance(field.type, str) and field.type in mapping: - field.type = mapping[field.type] + if field._unevaluated: + field.type = typing._eval_type(field.type, mapping, mapping) # type: ignore def __init_subclass__(cls, **kwargs): cls.__proto_debug__ = kwargs.pop("debug") if "debug" in kwargs else False From f858e05f4954566c36004fcbfd8597fe131b6972 Mon Sep 17 00:00:00 2001 From: RF-Tar-Railt Date: Tue, 24 Sep 2024 17:04:07 +0800 Subject: [PATCH 5/5] chore: evaluate_all --- lagrange/__init__.py | 4 ++ lagrange/pb/service/group.py | 10 ++-- lagrange/pb/status/group.py | 14 ++--- lagrange/utils/binary/protobuf/models.py | 70 ++++++++++++++---------- 4 files changed, 58 insertions(+), 40 deletions(-) diff --git a/lagrange/__init__.py b/lagrange/__init__.py index febc07b..47bf9da 100644 --- a/lagrange/__init__.py +++ b/lagrange/__init__.py @@ -9,6 +9,7 @@ from .utils.sign import sign_provider from .info import InfoManager from .info.app import app_list +from .utils.binary.protobuf.models import evaluate_all class Lagrange: @@ -66,3 +67,6 @@ def launch(self): log.root.info("Program exited by user") else: log.root.info("Program exited normally") + + +evaluate_all() diff --git a/lagrange/pb/service/group.py b/lagrange/pb/service/group.py index 373d510..3ac60ad 100644 --- a/lagrange/pb/service/group.py +++ b/lagrange/pb/service/group.py @@ -321,7 +321,7 @@ class MemberInfoLevel(ProtoStruct): class GetGrpMemberInfoRsp(ProtoStruct): grp_id: int = proto_field(1) - body: list["GetGrpMemberInfoRspBody"] = proto_field(2) + body: "list[GetGrpMemberInfoRspBody]" = proto_field(2) next_key: Optional[bytes] = proto_field(15, default=None) # base64(pb) @@ -429,15 +429,15 @@ class GetInfoRspField(ProtoStruct, debug=True): str_t: list[GetInfoRspF2] = proto_field(2, default_factory=list) +class GetInfoFromUidRsp(ProtoStruct): + body: list["GetInfoRspBody"] = proto_field(1) + + class GetInfoRspBody(ProtoStruct): uid: str = proto_field(1) fields: GetInfoRspField = proto_field(2) -class GetInfoFromUidRsp(ProtoStruct): - body: list[GetInfoRspBody] = proto_field(1) - - class Oidb88D0Args(ProtoStruct): seq: Optional[int] = proto_field(22, default=None) diff --git a/lagrange/pb/status/group.py b/lagrange/pb/status/group.py index fb7d59c..0f17ad3 100644 --- a/lagrange/pb/status/group.py +++ b/lagrange/pb/status/group.py @@ -111,6 +111,13 @@ class GroupSub16Head(ProtoStruct): f44: Optional[PBGroupReaction] = proto_field(44, default=None) # set reaction only +class GroupSub20Head(ProtoStruct): + f1: int = proto_field(1) # 20 + grp_id: int = proto_field(4) + f13: int = proto_field(13) # 19 + body: "GroupSub20Body" = proto_field(26) + + class GroupSub20Body(ProtoStruct): type: int = proto_field(1) # 12: nudge, 14: group_sign # f2: int = proto_field(2) # 1061 @@ -121,13 +128,6 @@ class GroupSub20Body(ProtoStruct): f10: int = proto_field(10) # rand? -class GroupSub20Head(ProtoStruct): - f1: int = proto_field(1) # 20 - grp_id: int = proto_field(4) - f13: int = proto_field(13) # 19 - body: GroupSub20Body = proto_field(26) - - class PBGroupAlbumUpdateBody(ProtoStruct): # f1: 6 args: str = proto_field(2) diff --git a/lagrange/utils/binary/protobuf/models.py b/lagrange/utils/binary/protobuf/models.py index 2fd2d6d..6f7c0a4 100644 --- a/lagrange/utils/binary/protobuf/models.py +++ b/lagrange/utils/binary/protobuf/models.py @@ -1,6 +1,7 @@ +import sys from dataclasses import MISSING from types import GenericAlias -from typing import cast, TypeVar, Union, Any, Callable, overload, get_origin, get_args, get_type_hints +from typing import cast, TypeVar, Union, Any, Callable, overload, get_origin, get_args, ForwardRef from collections.abc import Mapping from typing_extensions import Self, TypeAlias, dataclass_transform from typing import Optional, ClassVar @@ -31,8 +32,8 @@ def __init__(self, tag: int, default: Any, default_factory: Any): if tag <= 0: raise ValueError("Tag must be a positive integer") self.tag = tag - self._default = default - self._default_factory = default_factory + self.default = default + self.default_factory = default_factory self._unevaluated = False def ensure_annotation(self, name: str, type_: Any) -> None: @@ -44,10 +45,10 @@ def ensure_annotation(self, name: str, type_: Any) -> None: self._unevaluated = True def get_default(self) -> Any: - if self._default is not MISSING: - return self._default - elif self._default_factory is not MISSING: - return self._default_factory() + if self.default is not MISSING: + return self.default + elif self.default_factory is not MISSING: + return self.default_factory() return MISSING @property @@ -166,15 +167,16 @@ def check_type(value: Any, typ: Any) -> bool: return False # or True if value is None else False +_unevaluated_classes: set[type["ProtoStruct"]] = set() + + @dataclass_transform(kw_only_default=True, field_specifiers=(proto_field,)) class ProtoStruct: __proto_fields__: ClassVar[dict[str, ProtoField]] __proto_debug__: ClassVar[bool] - __proto_evaluated__: ClassVar[bool] = False def __init__(self, __from_raw: bool = False, /, **kwargs): undefined_params: list[ProtoField] = [] - self._evaluate() for name, field in self.__proto_fields__.items(): if name in kwargs: value = kwargs.pop(name) @@ -212,11 +214,13 @@ def _process_field(cls): if field is MISSING: raise TypeError(f'{name!r} should define its proto_field!') field.ensure_annotation(name, typ) + if field._unevaluated: + _unevaluated_classes.add(cls) cls_fields.append(field) for f in cls_fields: fields[f.name] = f - if f._default is MISSING: + if f.default is MISSING: delattr(cls, f.name) for name, value in cls.__dict__.items(): @@ -225,28 +229,20 @@ def _process_field(cls): cls.__proto_fields__ = fields - @classmethod - def _evaluate(cls): - for base in reversed(cls.__mro__): - if base in (ProtoStruct, object): - continue - if getattr(base, '__proto_evaluated__', False): - continue - try: - annotations = get_type_hints(base) - except NameError: - annotations = {} - for field in base.__proto_fields__.values(): - if field._unevaluated and field.name in annotations: - field.type = annotations[field.name] - base.__proto_evaluated__ = True - @classmethod def update_forwardref(cls, mapping: dict[str, "type[ProtoStruct]"]): """更新 ForwardRef""" for field in cls.__proto_fields__.values(): - if field._unevaluated: - field.type = typing._eval_type(field.type, mapping, mapping) # type: ignore + if not field._unevaluated: + continue + try: + typ = field.type + if isinstance(typ, str): + typ = ForwardRef(typ, is_argument=False, is_class=True) + field.type = typing._eval_type(typ, mapping, mapping) # type: ignore + field._unevaluated = False + except NameError: + pass def __init_subclass__(cls, **kwargs): cls.__proto_debug__ = kwargs.pop("debug") if "debug" in kwargs else False @@ -295,3 +291,21 @@ def decode(cls, data: bytes) -> Self: if pb_dict and cls.__proto_debug__: # unhandled tags pass return cls(True, **kwargs) + + +def evaluate_all(): + modules = set() + for cls in _unevaluated_classes: + modules.add(cls.__module__) + for base in cls.__mro__[-1:0:-1][2:]: + modules.add(base.__module__) + globalns = {} + for module in modules: + globalns.update(getattr(sys.modules.get(module, None), "__dict__", {})) + for cls in _unevaluated_classes: + cls.update_forwardref(globalns) + _unevaluated_classes.clear() + modules.clear() + globalns.clear() + del modules + del globalns