From 267103750e973e6a65c5a6f9708da2ba64b74a47 Mon Sep 17 00:00:00 2001 From: Peter Schutt Date: Fri, 24 Jul 2020 12:21:03 +1000 Subject: [PATCH 1/4] feat: Implement | operator on Filtering and Related. --- flask_resty/filtering.py | 15 +++++++++++++++ flask_resty/related.py | 15 +++++++++++++++ tests/test_filtering.py | 13 +++++++++++++ tests/test_related.py | 12 ++++++++++++ 4 files changed, 55 insertions(+) diff --git a/flask_resty/filtering.py b/flask_resty/filtering.py index 294800a..73b899a 100644 --- a/flask_resty/filtering.py +++ b/flask_resty/filtering.py @@ -316,6 +316,12 @@ def wrapper(func): class Filtering: """Container for the arg filters on a :py:class:`ModelView`. + `Filtering` supports view inheritance by implementing the `|` operator. For + example, `Filtering(foo=..., bar=...) | Filtering(baz=...)` will create a + new `Filtering` instance with filters for each `foo`, `bar` and `baz`. + Filters on the right-hand side take precedence where each `Filtering` + instance has the same key. + :param dict kwargs: A mapping from filter field names to filters. """ @@ -325,6 +331,15 @@ def __init__(self, **kwargs): for arg_name, arg_filter in kwargs.items() } + def __or__(self, other): + if not isinstance(other, Filtering): + return NotImplemented + + new = Filtering() + new._arg_filters = dict(**self._arg_filters) + new._arg_filters.update(other._arg_filters) + return new + def make_arg_filter(self, arg_name, arg_filter): if callable(arg_filter): arg_filter = ColumnFilter(arg_name, arg_filter) diff --git a/flask_resty/related.py b/flask_resty/related.py index f64d41f..aaf672a 100644 --- a/flask_resty/related.py +++ b/flask_resty/related.py @@ -83,6 +83,12 @@ class Related: through the sequence and resolve each item in turn, using the rules as above. + `Related` supports view inheritance by implementing the `|` operator. For + example, `Related(foo=..., bar=...) | Related(baz=...)` will create a new + `Related` instance with resolvers for each `foo`, `bar` and `baz`. + Resolvers on the right-hand side take precedence where each `Related` + instance has the same key. + :param item_class: The SQLAlchemy mapper corresponding to the related item. :param dict kwargs: A mapping from related fields to a callable resolver. """ @@ -91,6 +97,15 @@ def __init__(self, item_class=None, **kwargs): self._item_class = item_class self._resolvers = kwargs + def __or__(self, other): + if not isinstance(other, Related): + return NotImplemented + + new = Related() + new._resolvers = dict(**self._resolvers) + new._resolvers.update(other._resolvers) + return new + def resolve_related(self, data): """Resolve the related values in the request data. diff --git a/tests/test_filtering.py b/tests/test_filtering.py index 744d642..cb3b897 100644 --- a/tests/test_filtering.py +++ b/tests/test_filtering.py @@ -333,3 +333,16 @@ def test_error_reuse_column_filter(): with pytest.raises(TypeError, match="without explicit column name"): Filtering(foo=implicit_column_filter, bar=implicit_column_filter) + + +def test_filter__or__(): + column_filter_foo = ColumnFilter("foo", operator.eq) + column_filter_bar = ColumnFilter("bar", operator.eq) + column_filter_baz = ColumnFilter("baz", operator.eq) + left = Filtering(foo=column_filter_foo, bar=column_filter_bar) + right = Filtering(bar=column_filter_baz) + union = left | right + assert isinstance(union, Filtering) + assert len(union._arg_filters) == 2 + assert union._arg_filters["foo"] is column_filter_foo + assert union._arg_filters["bar"] is column_filter_baz diff --git a/tests/test_related.py b/tests/test_related.py index 2f1b953..e58912f 100644 --- a/tests/test_related.py +++ b/tests/test_related.py @@ -365,3 +365,15 @@ def test_error_missing_id(client): } ], ) + + +def test_related__or__(models): + related_foo = RelatedId(GenericModelView, "view_id") + related_bar = GenericModelView + related_baz = Related(models["parent"]) + left = Related(foo=related_foo, bar=related_bar) + right = Related(bar=related_baz) + union = left | right + assert len(union._resolvers) == 2 + assert union._resolvers["foo"] is related_foo + assert union._resolvers["bar"] is related_baz From a46650f8d7a4af1c958c98eb187c0cc43c030dcd Mon Sep 17 00:00:00 2001 From: Peter Schutt Date: Sun, 26 Jul 2020 08:56:09 +1000 Subject: [PATCH 2/4] Tests for NotImplemented on different type. --- tests/test_filtering.py | 7 +++++++ tests/test_related.py | 7 +++++++ 2 files changed, 14 insertions(+) diff --git a/tests/test_filtering.py b/tests/test_filtering.py index cb3b897..89f14d8 100644 --- a/tests/test_filtering.py +++ b/tests/test_filtering.py @@ -346,3 +346,10 @@ def test_filter__or__(): assert len(union._arg_filters) == 2 assert union._arg_filters["foo"] is column_filter_foo assert union._arg_filters["bar"] is column_filter_baz + + +def test_filtering__or__typeerror(): + left = Filtering() + right = object() + with pytest.raises(TypeError): + left | right diff --git a/tests/test_related.py b/tests/test_related.py index e58912f..94320c2 100644 --- a/tests/test_related.py +++ b/tests/test_related.py @@ -377,3 +377,10 @@ def test_related__or__(models): assert len(union._resolvers) == 2 assert union._resolvers["foo"] is related_foo assert union._resolvers["bar"] is related_baz + + +def test_related__or__typeerror(): + left = Related() + right = object() + with pytest.raises(TypeError): + left | right From 2e858eddff90c4dd30fd5b4658fa1422050332f1 Mon Sep 17 00:00:00 2001 From: Jimmy Jia Date: Mon, 27 Jul 2020 23:32:19 -0400 Subject: [PATCH 3/4] fixes --- flask_resty/filtering.py | 32 +++++++-------- flask_resty/related.py | 32 ++++++++------- tests/test_filtering.py | 30 ++++++-------- tests/test_related.py | 84 +++++++++++++++++++++++++++++----------- 4 files changed, 105 insertions(+), 73 deletions(-) diff --git a/flask_resty/filtering.py b/flask_resty/filtering.py index 73b899a..55be4dc 100644 --- a/flask_resty/filtering.py +++ b/flask_resty/filtering.py @@ -211,8 +211,7 @@ def maybe_set_arg_name(self, arg_name): if self._column_name and self._column_name != arg_name: raise TypeError( - "cannot use ColumnFilter without explicit column name for " - + "multiple arg names" + "cannot use ColumnFilter without explicit column name for multiple arg names" ) self._column_name = arg_name @@ -316,12 +315,6 @@ def wrapper(func): class Filtering: """Container for the arg filters on a :py:class:`ModelView`. - `Filtering` supports view inheritance by implementing the `|` operator. For - example, `Filtering(foo=..., bar=...) | Filtering(baz=...)` will create a - new `Filtering` instance with filters for each `foo`, `bar` and `baz`. - Filters on the right-hand side take precedence where each `Filtering` - instance has the same key. - :param dict kwargs: A mapping from filter field names to filters. """ @@ -331,15 +324,6 @@ def __init__(self, **kwargs): for arg_name, arg_filter in kwargs.items() } - def __or__(self, other): - if not isinstance(other, Filtering): - return NotImplemented - - new = Filtering() - new._arg_filters = dict(**self._arg_filters) - new._arg_filters.update(other._arg_filters) - return new - def make_arg_filter(self, arg_name, arg_filter): if callable(arg_filter): arg_filter = ColumnFilter(arg_name, arg_filter) @@ -372,3 +356,17 @@ def filter_query(self, query, view): raise e.update({"source": {"parameter": arg_name}}) return query + + def __or__(self, other): + """Combine two `Filtering` instances. + + `Filtering` supports view inheritance by implementing the `|` operator. + For example, `Filtering(foo=..., bar=...) | Filtering(baz=...)` will + create a new `Filtering` instance with filters for each `foo`, `bar` + and `baz`. Filters on the right-hand side take precedence where each + `Filtering` instance has the same key. + """ + if not isinstance(other, Filtering): + return NotImplemented + + return self.__class__(**{**self._arg_filters, **other._arg_filters}) diff --git a/flask_resty/related.py b/flask_resty/related.py index aaf672a..69b34c5 100644 --- a/flask_resty/related.py +++ b/flask_resty/related.py @@ -83,12 +83,6 @@ class Related: through the sequence and resolve each item in turn, using the rules as above. - `Related` supports view inheritance by implementing the `|` operator. For - example, `Related(foo=..., bar=...) | Related(baz=...)` will create a new - `Related` instance with resolvers for each `foo`, `bar` and `baz`. - Resolvers on the right-hand side take precedence where each `Related` - instance has the same key. - :param item_class: The SQLAlchemy mapper corresponding to the related item. :param dict kwargs: A mapping from related fields to a callable resolver. """ @@ -97,15 +91,6 @@ def __init__(self, item_class=None, **kwargs): self._item_class = item_class self._resolvers = kwargs - def __or__(self, other): - if not isinstance(other, Related): - return NotImplemented - - new = Related() - new._resolvers = dict(**self._resolvers) - new._resolvers.update(other._resolvers) - return new - def resolve_related(self, data): """Resolve the related values in the request data. @@ -175,3 +160,20 @@ def resolve_field(self, value, resolver): return [resolve_item(item) for item in value] return resolve_item(value) + + def __or__(self, other): + """Combine two `Related` instances. + + `Related` supports view inheritance by implementing the `|` operator. + For example, `Related(foo=..., bar=...) | Related(baz=...)` will create + a new `Related` instance with resolvers for each `foo`, `bar` and + `baz`. Resolvers on the right-hand side take precedence where each + `Related` instance has the same key. + """ + if not isinstance(other, Related): + return NotImplemented + + return self.__class__( + other._item_class or self._item_class, + **{**self._resolvers, **other._resolvers}, + ) diff --git a/tests/test_filtering.py b/tests/test_filtering.py index 89f14d8..ac5cd47 100644 --- a/tests/test_filtering.py +++ b/tests/test_filtering.py @@ -58,7 +58,6 @@ class WidgetViewBase(GenericModelView): model = models["widget"] schema = schemas["widget"] - class WidgetListView(WidgetViewBase): filtering = Filtering( color=operator.eq, color_allow_empty=ColumnFilter( @@ -81,11 +80,14 @@ class WidgetListView(WidgetViewBase): ), ) + class WidgetListView(WidgetViewBase): def get(self): return self.list() class WidgetSizeRequiredListView(WidgetViewBase): - filtering = Filtering(size=ColumnFilter(operator.eq, required=True)) + filtering = WidgetViewBase.filtering | Filtering( + size=ColumnFilter(operator.eq, required=True) + ) def get(self): return self.list() @@ -206,6 +208,11 @@ def test_column_filter_required_present(client): assert_response(response, 200, [{"id": "1", "color": "red", "size": 1}]) +def test_combine(client): + response = client.get("/widgets_size_required?size=1&color=green") + assert_response(response, 200, []) + + def test_column_filter_unvalidated(client): response = client.get("/widgets?size_min_unvalidated=-1") assert_response( @@ -335,21 +342,6 @@ def test_error_reuse_column_filter(): Filtering(foo=implicit_column_filter, bar=implicit_column_filter) -def test_filter__or__(): - column_filter_foo = ColumnFilter("foo", operator.eq) - column_filter_bar = ColumnFilter("bar", operator.eq) - column_filter_baz = ColumnFilter("baz", operator.eq) - left = Filtering(foo=column_filter_foo, bar=column_filter_bar) - right = Filtering(bar=column_filter_baz) - union = left | right - assert isinstance(union, Filtering) - assert len(union._arg_filters) == 2 - assert union._arg_filters["foo"] is column_filter_foo - assert union._arg_filters["bar"] is column_filter_baz - - -def test_filtering__or__typeerror(): - left = Filtering() - right = object() +def test_error_combine_filtering_type_error(): with pytest.raises(TypeError): - left | right + Filtering() | {} diff --git a/tests/test_related.py b/tests/test_related.py index 94320c2..6d27e45 100644 --- a/tests/test_related.py +++ b/tests/test_related.py @@ -17,8 +17,6 @@ class Parent(db.Model): id = Column(Integer, primary_key=True) name = Column(String) - children = relationship("Child", backref="parent", cascade="all") - class Child(db.Model): __tablename__ = "children" @@ -26,6 +24,12 @@ class Child(db.Model): name = Column(String) parent_id = Column(ForeignKey(Parent.id)) + parent = relationship( + Parent, foreign_keys=parent_id, backref="children" + ) + + other_parent_id = Column(ForeignKey(Parent.id)) + other_parent = relationship(Parent, foreign_keys=other_parent_id) db.create_all() @@ -40,13 +44,15 @@ class ParentSchema(Schema): id = fields.Integer(as_string=True) name = fields.String(required=True) - children = RelatedItem("ChildSchema", many=True, exclude=("parent",)) + children = RelatedItem( + "ChildSchema", many=True, exclude=("parent", "other_parent") + ) child_ids = fields.List(fields.Integer(as_string=True), load_only=True) class ChildSchema(Schema): @classmethod def get_query_options(cls, load): - return (load.joinedload("parent"),) + return (load.joinedload("parent"), load.joinedload("other_parent")) id = fields.Integer(as_string=True) name = fields.String(required=True) @@ -58,6 +64,10 @@ def get_query_options(cls, load): as_string=True, allow_none=True, load_only=True ) + other_parent = fields.Nested( + ParentSchema, exclude=("children",), allow_none=True + ) + return {"parent": ParentSchema(), "child": ChildSchema()} @@ -110,12 +120,20 @@ class NestedChildView(GenericModelView): def put(self, id): return self.update(id) + class ChildWithOtherParentView(ChildView): + related = ChildView.related | Related( + other_parent=Related(models["parent"]) + ) + api = Api(app) api.add_resource("/parents/", ParentView) api.add_resource("/nested_parents/", NestedParentView) api.add_resource("/parents_with_create/", ParentWithCreateView) api.add_resource("/children/", ChildView) api.add_resource("/nested_children/", NestedChildView) + api.add_resource( + "/children_with_other_parent/", ChildWithOtherParentView + ) @pytest.fixture(autouse=True) @@ -163,6 +181,7 @@ def test_single(client): "id": "1", "name": "Updated Child", "parent": {"id": "1", "name": "Parent"}, + "other_parent": None, }, ) @@ -180,6 +199,7 @@ def test_single_nested(client): "id": "1", "name": "Updated Child", "parent": {"id": "1", "name": "Parent"}, + "other_parent": None, }, ) @@ -266,6 +286,7 @@ def test_missing(client): "id": "1", "name": "Twice Updated Child", "parent": {"id": "1", "name": "Parent"}, + "other_parent": None, }, ) @@ -280,7 +301,12 @@ def test_null(client): assert_response( response, 200, - {"id": "1", "name": "Twice Updated Child", "parent": None}, + { + "id": "1", + "name": "Twice Updated Child", + "parent": None, + "other_parent": None, + }, ) @@ -294,7 +320,12 @@ def test_null_nested(client): assert_response( response, 200, - {"id": "1", "name": "Twice Updated Child", "parent": None}, + { + "id": "1", + "name": "Twice Updated Child", + "parent": None, + "other_parent": None, + }, ) @@ -313,6 +344,29 @@ def test_many_falsy(client): ) +def test_combine(client): + response = client.put( + "/children_with_other_parent/1", + data={ + "id": "1", + "name": "Updated Child", + "parent_id": "1", + "other_parent": {"id": "2", "name": "Other Parent"}, + }, + ) + + assert_response( + response, + 200, + { + "id": "1", + "name": "Updated Child", + "parent": {"id": "1", "name": "Parent"}, + "other_parent": {"id": "2", "name": "Other Parent"}, + }, + ) + + # ----------------------------------------------------------------------------- @@ -367,20 +421,6 @@ def test_error_missing_id(client): ) -def test_related__or__(models): - related_foo = RelatedId(GenericModelView, "view_id") - related_bar = GenericModelView - related_baz = Related(models["parent"]) - left = Related(foo=related_foo, bar=related_bar) - right = Related(bar=related_baz) - union = left | right - assert len(union._resolvers) == 2 - assert union._resolvers["foo"] is related_foo - assert union._resolvers["bar"] is related_baz - - -def test_related__or__typeerror(): - left = Related() - right = object() +def test_error_combine_related_type_error(): with pytest.raises(TypeError): - left | right + Related() | {} From b881a8fe197802dd46e55ce63404c99ae803326f Mon Sep 17 00:00:00 2001 From: Jimmy Jia Date: Mon, 27 Jul 2020 23:34:02 -0400 Subject: [PATCH 4/4] oops --- tests/test_related.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/test_related.py b/tests/test_related.py index 6d27e45..92f2ce0 100644 --- a/tests/test_related.py +++ b/tests/test_related.py @@ -159,12 +159,16 @@ def test_baseline(client): child_1_response = client.get("/children/1") assert_response( - child_1_response, 200, {"id": "1", "name": "Child 1", "parent": None} + child_1_response, + 200, + {"id": "1", "name": "Child 1", "parent": None, "other_parent": None}, ) child_2_response = client.get("/children/2") assert_response( - child_2_response, 200, {"id": "2", "name": "Child 2", "parent": None} + child_2_response, + 200, + {"id": "2", "name": "Child 2", "parent": None, "other_parent": None}, )