Skip to content

Commit

Permalink
Move annotations into a dict under a single named attribute, to suppo…
Browse files Browse the repository at this point in the history
…rt additional annotations
  • Loading branch information
j4mie committed Apr 15, 2024
1 parent 41a1b3e commit 3d0628c
Showing 1 changed file with 25 additions and 15 deletions.
40 changes: 25 additions & 15 deletions django_readers/rest_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,15 @@
from rest_framework.utils import model_meta


def add_annotation(obj, key, value):
obj._readers_annotation = getattr(obj, "_readers_annotation", None) or {}
obj._readers_annotation[key] = value


def get_annotation(obj, key):
return getattr(obj, "_readers_annotation", {}).get(key)


class ProjectionSerializer:
def __init__(self, data=None, many=False, context=None):
self.many = many
Expand Down Expand Up @@ -88,10 +97,11 @@ def _prepare_field(self, field):
def _get_out_value(self, item):
# Either the item itself or (if this is a pair) just the
# producer/projector function may have been decorated
if hasattr(item, "out"):
return item.out
if isinstance(item, tuple) and hasattr(item[1], "out"):
return item[1].out
if out := get_annotation(item, "out"):
return out
if isinstance(item, tuple):
if out := get_annotation(item[1], "out"):
return out
return None

def visit_str(self, item):
Expand All @@ -100,8 +110,8 @@ def visit_str(self, item):
def visit_dict_item_str(self, key, value):
# This is a model field name. First, check if the
# field has been explicitly overridden
if hasattr(value, "out"):
field = self._prepare_field(value.out)
if out := get_annotation(value, "out"):
field = self._prepare_field(out)
self.fields[str(key)] = field
return key, field

Expand Down Expand Up @@ -231,12 +241,12 @@ def serializer_class_for_view(view):
return serializer_class_for_spec(name_prefix, model, view.spec)


class PairWithOutAttribute(tuple):
out = None
class PairWithAnnotation(tuple):
_readers_annotation = None


class StringWithOutAttribute(str):
out = None
class StringWithAnnotation(str):
_readers_annotation = None


def out(field_or_dict):
Expand All @@ -257,15 +267,15 @@ def wrapper(*args, **kwargs):
result = item(*args, **kwargs)
return self(result)

wrapper.out = field_or_dict
add_annotation(wrapper, "out", field_or_dict)
return wrapper
else:
if isinstance(item, str):
item = StringWithOutAttribute(item)
item.out = field_or_dict
item = StringWithAnnotation(item)
add_annotation(item, "out", field_or_dict)
if isinstance(item, tuple):
item = PairWithOutAttribute(item)
item.out = field_or_dict
item = PairWithAnnotation(item)
add_annotation(item, "out", field_or_dict)
return item

def __rrshift__(self, other):
Expand Down

0 comments on commit 3d0628c

Please sign in to comment.