Skip to content

Commit

Permalink
feat: Implement | operator on Filtering and Related. (#338)
Browse files Browse the repository at this point in the history
Co-authored-by: Peter Schutt <peter@topsport.com.au>
Co-authored-by: Jimmy Jia <tesrin@gmail.com>
  • Loading branch information
3 people authored Jul 28, 2020
1 parent 0bf3f7e commit ac2e407
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 12 deletions.
17 changes: 15 additions & 2 deletions flask_resty/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,7 @@ def maybe_set_arg_name(self, arg_name):

if self._column_name and self._column_name != arg_name:
raise TypeError(
"cannot use ColumnFilter without explicit column name for "
+ "multiple arg names"
"cannot use ColumnFilter without explicit column name for multiple arg names"
)

self._column_name = arg_name
Expand Down Expand Up @@ -357,3 +356,17 @@ def filter_query(self, query, view):
raise e.update({"source": {"parameter": arg_name}})

return query

def __or__(self, other):
"""Combine two `Filtering` instances.
`Filtering` supports view inheritance by implementing the `|` operator.
For example, `Filtering(foo=..., bar=...) | Filtering(baz=...)` will
create a new `Filtering` instance with filters for each `foo`, `bar`
and `baz`. Filters on the right-hand side take precedence where each
`Filtering` instance has the same key.
"""
if not isinstance(other, Filtering):
return NotImplemented

return self.__class__(**{**self._arg_filters, **other._arg_filters})
17 changes: 17 additions & 0 deletions flask_resty/related.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,20 @@ def resolve_field(self, value, resolver):
return [resolve_item(item) for item in value]

return resolve_item(value)

def __or__(self, other):
"""Combine two `Related` instances.
`Related` supports view inheritance by implementing the `|` operator.
For example, `Related(foo=..., bar=...) | Related(baz=...)` will create
a new `Related` instance with resolvers for each `foo`, `bar` and
`baz`. Resolvers on the right-hand side take precedence where each
`Related` instance has the same key.
"""
if not isinstance(other, Related):
return NotImplemented

return self.__class__(
other._item_class or self._item_class,
**{**self._resolvers, **other._resolvers},
)
16 changes: 14 additions & 2 deletions tests/test_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ class WidgetViewBase(GenericModelView):
model = models["widget"]
schema = schemas["widget"]

class WidgetListView(WidgetViewBase):
filtering = Filtering(
color=operator.eq,
color_allow_empty=ColumnFilter(
Expand All @@ -81,11 +80,14 @@ class WidgetListView(WidgetViewBase):
),
)

class WidgetListView(WidgetViewBase):
def get(self):
return self.list()

class WidgetSizeRequiredListView(WidgetViewBase):
filtering = Filtering(size=ColumnFilter(operator.eq, required=True))
filtering = WidgetViewBase.filtering | Filtering(
size=ColumnFilter(operator.eq, required=True)
)

def get(self):
return self.list()
Expand Down Expand Up @@ -206,6 +208,11 @@ def test_column_filter_required_present(client):
assert_response(response, 200, [{"id": "1", "color": "red", "size": 1}])


def test_combine(client):
response = client.get("/widgets_size_required?size=1&color=green")
assert_response(response, 200, [])


def test_column_filter_unvalidated(client):
response = client.get("/widgets?size_min_unvalidated=-1")
assert_response(
Expand Down Expand Up @@ -333,3 +340,8 @@ def test_error_reuse_column_filter():

with pytest.raises(TypeError, match="without explicit column name"):
Filtering(foo=implicit_column_filter, bar=implicit_column_filter)


def test_error_combine_filtering_type_error():
with pytest.raises(TypeError):
Filtering() | {}
79 changes: 71 additions & 8 deletions tests/test_related.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,19 @@ class Parent(db.Model):
id = Column(Integer, primary_key=True)
name = Column(String)

children = relationship("Child", backref="parent", cascade="all")

class Child(db.Model):
__tablename__ = "children"

id = Column(Integer, primary_key=True)
name = Column(String)

parent_id = Column(ForeignKey(Parent.id))
parent = relationship(
Parent, foreign_keys=parent_id, backref="children"
)

other_parent_id = Column(ForeignKey(Parent.id))
other_parent = relationship(Parent, foreign_keys=other_parent_id)

db.create_all()

Expand All @@ -40,13 +44,15 @@ class ParentSchema(Schema):
id = fields.Integer(as_string=True)
name = fields.String(required=True)

children = RelatedItem("ChildSchema", many=True, exclude=("parent",))
children = RelatedItem(
"ChildSchema", many=True, exclude=("parent", "other_parent")
)
child_ids = fields.List(fields.Integer(as_string=True), load_only=True)

class ChildSchema(Schema):
@classmethod
def get_query_options(cls, load):
return (load.joinedload("parent"),)
return (load.joinedload("parent"), load.joinedload("other_parent"))

id = fields.Integer(as_string=True)
name = fields.String(required=True)
Expand All @@ -58,6 +64,10 @@ def get_query_options(cls, load):
as_string=True, allow_none=True, load_only=True
)

other_parent = fields.Nested(
ParentSchema, exclude=("children",), allow_none=True
)

return {"parent": ParentSchema(), "child": ChildSchema()}


Expand Down Expand Up @@ -110,12 +120,20 @@ class NestedChildView(GenericModelView):
def put(self, id):
return self.update(id)

class ChildWithOtherParentView(ChildView):
related = ChildView.related | Related(
other_parent=Related(models["parent"])
)

api = Api(app)
api.add_resource("/parents/<int:id>", ParentView)
api.add_resource("/nested_parents/<int:id>", NestedParentView)
api.add_resource("/parents_with_create/<int:id>", ParentWithCreateView)
api.add_resource("/children/<int:id>", ChildView)
api.add_resource("/nested_children/<int:id>", NestedChildView)
api.add_resource(
"/children_with_other_parent/<int:id>", ChildWithOtherParentView
)


@pytest.fixture(autouse=True)
Expand All @@ -141,12 +159,16 @@ def test_baseline(client):

child_1_response = client.get("/children/1")
assert_response(
child_1_response, 200, {"id": "1", "name": "Child 1", "parent": None}
child_1_response,
200,
{"id": "1", "name": "Child 1", "parent": None, "other_parent": None},
)

child_2_response = client.get("/children/2")
assert_response(
child_2_response, 200, {"id": "2", "name": "Child 2", "parent": None}
child_2_response,
200,
{"id": "2", "name": "Child 2", "parent": None, "other_parent": None},
)


Expand All @@ -163,6 +185,7 @@ def test_single(client):
"id": "1",
"name": "Updated Child",
"parent": {"id": "1", "name": "Parent"},
"other_parent": None,
},
)

Expand All @@ -180,6 +203,7 @@ def test_single_nested(client):
"id": "1",
"name": "Updated Child",
"parent": {"id": "1", "name": "Parent"},
"other_parent": None,
},
)

Expand Down Expand Up @@ -266,6 +290,7 @@ def test_missing(client):
"id": "1",
"name": "Twice Updated Child",
"parent": {"id": "1", "name": "Parent"},
"other_parent": None,
},
)

Expand All @@ -280,7 +305,12 @@ def test_null(client):
assert_response(
response,
200,
{"id": "1", "name": "Twice Updated Child", "parent": None},
{
"id": "1",
"name": "Twice Updated Child",
"parent": None,
"other_parent": None,
},
)


Expand All @@ -294,7 +324,12 @@ def test_null_nested(client):
assert_response(
response,
200,
{"id": "1", "name": "Twice Updated Child", "parent": None},
{
"id": "1",
"name": "Twice Updated Child",
"parent": None,
"other_parent": None,
},
)


Expand All @@ -313,6 +348,29 @@ def test_many_falsy(client):
)


def test_combine(client):
response = client.put(
"/children_with_other_parent/1",
data={
"id": "1",
"name": "Updated Child",
"parent_id": "1",
"other_parent": {"id": "2", "name": "Other Parent"},
},
)

assert_response(
response,
200,
{
"id": "1",
"name": "Updated Child",
"parent": {"id": "1", "name": "Parent"},
"other_parent": {"id": "2", "name": "Other Parent"},
},
)


# -----------------------------------------------------------------------------


Expand Down Expand Up @@ -365,3 +423,8 @@ def test_error_missing_id(client):
}
],
)


def test_error_combine_related_type_error():
with pytest.raises(TypeError):
Related() | {}

0 comments on commit ac2e407

Please sign in to comment.