Skip to content

Commit

Permalink
Get rid of __raw_get
Browse files Browse the repository at this point in the history
  • Loading branch information
Gobot1234 committed Apr 11, 2022
1 parent fea05f2 commit c942fe1
Showing 1 changed file with 37 additions and 38 deletions.
75 changes: 37 additions & 38 deletions src/betterproto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ def decode_varint(buffer: bytes, pos: int) -> Tuple[int, int]:
"""
result = 0
shift = 0
while 1:
while True:
b = buffer[pos]
result |= (b & 0x7F) << shift
pos += 1
Expand Down Expand Up @@ -629,6 +629,8 @@ class Message:
Calls :meth:`__bool__`.
"""

__slots__ = ("_serialized_on_wire", "_unknown_fields", "_group_current")

_serialized_on_wire: bool
_unknown_fields: bytes
_group_current: Dict[str, str]
Expand All @@ -644,8 +646,8 @@ def __post_init__(self) -> None:
if meta.group:
group_current.setdefault(meta.group)

value = self.__raw_get(field_name)
if value != PLACEHOLDER and not (meta.optional and value is None):
value = object.__getattribute__(self, field_name)
if value is not PLACEHOLDER and not (meta.optional and value is None):
# Found a non-sentinel value
all_sentinel = False

Expand All @@ -654,47 +656,42 @@ def __post_init__(self) -> None:
group_current[meta.group] = field_name

# Now that all the defaults are set, reset it!
self.__dict__["_serialized_on_wire"] = not all_sentinel
self.__dict__["_unknown_fields"] = b""
self.__dict__["_group_current"] = group_current

__raw_get = object.__getattribute__
super().__setattr__("_serialized_on_wire", not all_sentinel)
super().__setattr__("_unknown_fields", b"")
super().__setattr__("_group_current", group_current)

def __eq__(self, other) -> bool:
if type(self) is not type(other):
if not isinstance(other, type(self)):
return False

for field_name in self._betterproto.meta_by_field_name:
self_val = self.__raw_get(field_name)
other_val = other.__raw_get(field_name)
self_val = object.__getattribute__(self, field_name)
other_val = object.__getattribute__(other, field_name)
if self_val is PLACEHOLDER:
if other_val is PLACEHOLDER:
continue
self_val = self._get_field_default(field_name)
elif other_val is PLACEHOLDER:
other_val = other._get_field_default(field_name)

if self_val != other_val:
if self_val != other_val and (
not isinstance(self_val, float)
or not isinstance(other_val, float)
or not math.isnan(self_val)
or not math.isnan(other_val)
):
# We consider two nan values to be the same for the
# purposes of comparing messages (otherwise a message
# is not equal to itself)
if (
isinstance(self_val, float)
and isinstance(other_val, float)
and math.isnan(self_val)
and math.isnan(other_val)
):
continue
else:
return False
return False

return True

def __repr__(self) -> str:
parts = [
f"{field_name}={value!r}"
for field_name in self._betterproto.sorted_field_names
for value in (self.__raw_get(field_name),)
for value in (object.__getattribute__(self, field_name),)
if value is not PLACEHOLDER
]
return f"{self.__class__.__name__}({', '.join(parts)})"
Expand All @@ -715,31 +712,33 @@ def __getattribute__(self, name: str) -> Any:
def __setattr__(self, attr: str, value: Any) -> None:
if attr != "_serialized_on_wire":
# Track when a field has been set.
self.__dict__["_serialized_on_wire"] = True

if hasattr(self, "_group_current"): # __post_init__ had already run
if attr in self._betterproto.oneof_group_by_field:
group = self._betterproto.oneof_group_by_field[attr]
for field in self._betterproto.oneof_field_by_group[group]:
if field.name == attr:
self._group_current[group] = field.name
else:
super().__setattr__(field.name, PLACEHOLDER)
super().__setattr__("_serialized_on_wire", True)

if (
hasattr(self, "_group_current")
and attr in self._betterproto.oneof_group_by_field
): # __post_init__ had already run
group = self._betterproto.oneof_group_by_field[attr]
for field in self._betterproto.oneof_field_by_group[group]:
if field.name == attr:
self._group_current[group] = field.name
else:
super().__setattr__(field.name, PLACEHOLDER)

super().__setattr__(attr, value)

def __bool__(self) -> bool:
"""True if the Message has any fields with non-default values."""
return any(
self.__raw_get(field_name)
object.__getattribute__(self, field_name)
not in (PLACEHOLDER, self._get_field_default(field_name))
for field_name in self._betterproto.meta_by_field_name
)

def __deepcopy__(self: T, _: Any = {}) -> T:
kwargs = {}
for name in self._betterproto.sorted_field_names:
value = self.__raw_get(name)
value = object.__getattribute__(self, name)
if value is not PLACEHOLDER:
kwargs[name] = deepcopy(value)
return self.__class__(**kwargs) # type: ignore
Expand Down Expand Up @@ -879,9 +878,9 @@ def _type_hints(cls) -> Dict[str, Type]:
def _cls_for(cls, field: dataclasses.Field, index: int = 0) -> Type:
"""Get the message class for a field from the type hints."""
field_cls = cls._type_hint(field.name)
if hasattr(field_cls, "__args__") and index >= 0:
if field_cls.__args__ is not None:
field_cls = field_cls.__args__[index]
args = getattr(field_cls, "__args__", None)
if args and index >= 0 and args is not None:
field_cls = field_cls.__args__[index]
return field_cls

def _get_field_default(self, field_name: str) -> Any:
Expand Down Expand Up @@ -1325,7 +1324,7 @@ def is_set(self, name: str) -> bool:
:class:`bool`
`True` if field has been set, otherwise `False`.
"""
return self.__raw_get(name) is not PLACEHOLDER
return object.__getattribute__(self, name) is not PLACEHOLDER


def serialized_on_wire(message: Message) -> bool:
Expand Down

0 comments on commit c942fe1

Please sign in to comment.