diff --git a/flask_resty/filtering.py b/flask_resty/filtering.py index 294800a..55be4dc 100644 --- a/flask_resty/filtering.py +++ b/flask_resty/filtering.py @@ -211,8 +211,7 @@ def maybe_set_arg_name(self, arg_name): if self._column_name and self._column_name != arg_name: raise TypeError( - "cannot use ColumnFilter without explicit column name for " - + "multiple arg names" + "cannot use ColumnFilter without explicit column name for multiple arg names" ) self._column_name = arg_name @@ -357,3 +356,17 @@ def filter_query(self, query, view): raise e.update({"source": {"parameter": arg_name}}) return query + + def __or__(self, other): + """Combine two `Filtering` instances. + + `Filtering` supports view inheritance by implementing the `|` operator. + For example, `Filtering(foo=..., bar=...) | Filtering(baz=...)` will + create a new `Filtering` instance with filters for each `foo`, `bar` + and `baz`. Filters on the right-hand side take precedence where each + `Filtering` instance has the same key. + """ + if not isinstance(other, Filtering): + return NotImplemented + + return self.__class__(**{**self._arg_filters, **other._arg_filters}) diff --git a/flask_resty/related.py b/flask_resty/related.py index f64d41f..69b34c5 100644 --- a/flask_resty/related.py +++ b/flask_resty/related.py @@ -160,3 +160,20 @@ def resolve_field(self, value, resolver): return [resolve_item(item) for item in value] return resolve_item(value) + + def __or__(self, other): + """Combine two `Related` instances. + + `Related` supports view inheritance by implementing the `|` operator. + For example, `Related(foo=..., bar=...) | Related(baz=...)` will create + a new `Related` instance with resolvers for each `foo`, `bar` and + `baz`. Resolvers on the right-hand side take precedence where each + `Related` instance has the same key. + """ + if not isinstance(other, Related): + return NotImplemented + + return self.__class__( + other._item_class or self._item_class, + **{**self._resolvers, **other._resolvers}, + ) diff --git a/tests/test_filtering.py b/tests/test_filtering.py index 744d642..ac5cd47 100644 --- a/tests/test_filtering.py +++ b/tests/test_filtering.py @@ -58,7 +58,6 @@ class WidgetViewBase(GenericModelView): model = models["widget"] schema = schemas["widget"] - class WidgetListView(WidgetViewBase): filtering = Filtering( color=operator.eq, color_allow_empty=ColumnFilter( @@ -81,11 +80,14 @@ class WidgetListView(WidgetViewBase): ), ) + class WidgetListView(WidgetViewBase): def get(self): return self.list() class WidgetSizeRequiredListView(WidgetViewBase): - filtering = Filtering(size=ColumnFilter(operator.eq, required=True)) + filtering = WidgetViewBase.filtering | Filtering( + size=ColumnFilter(operator.eq, required=True) + ) def get(self): return self.list() @@ -206,6 +208,11 @@ def test_column_filter_required_present(client): assert_response(response, 200, [{"id": "1", "color": "red", "size": 1}]) +def test_combine(client): + response = client.get("/widgets_size_required?size=1&color=green") + assert_response(response, 200, []) + + def test_column_filter_unvalidated(client): response = client.get("/widgets?size_min_unvalidated=-1") assert_response( @@ -333,3 +340,8 @@ def test_error_reuse_column_filter(): with pytest.raises(TypeError, match="without explicit column name"): Filtering(foo=implicit_column_filter, bar=implicit_column_filter) + + +def test_error_combine_filtering_type_error(): + with pytest.raises(TypeError): + Filtering() | {} diff --git a/tests/test_related.py b/tests/test_related.py index 2f1b953..92f2ce0 100644 --- a/tests/test_related.py +++ b/tests/test_related.py @@ -17,8 +17,6 @@ class Parent(db.Model): id = Column(Integer, primary_key=True) name = Column(String) - children = relationship("Child", backref="parent", cascade="all") - class Child(db.Model): __tablename__ = "children" @@ -26,6 +24,12 @@ class Child(db.Model): name = Column(String) parent_id = Column(ForeignKey(Parent.id)) + parent = relationship( + Parent, foreign_keys=parent_id, backref="children" + ) + + other_parent_id = Column(ForeignKey(Parent.id)) + other_parent = relationship(Parent, foreign_keys=other_parent_id) db.create_all() @@ -40,13 +44,15 @@ class ParentSchema(Schema): id = fields.Integer(as_string=True) name = fields.String(required=True) - children = RelatedItem("ChildSchema", many=True, exclude=("parent",)) + children = RelatedItem( + "ChildSchema", many=True, exclude=("parent", "other_parent") + ) child_ids = fields.List(fields.Integer(as_string=True), load_only=True) class ChildSchema(Schema): @classmethod def get_query_options(cls, load): - return (load.joinedload("parent"),) + return (load.joinedload("parent"), load.joinedload("other_parent")) id = fields.Integer(as_string=True) name = fields.String(required=True) @@ -58,6 +64,10 @@ def get_query_options(cls, load): as_string=True, allow_none=True, load_only=True ) + other_parent = fields.Nested( + ParentSchema, exclude=("children",), allow_none=True + ) + return {"parent": ParentSchema(), "child": ChildSchema()} @@ -110,12 +120,20 @@ class NestedChildView(GenericModelView): def put(self, id): return self.update(id) + class ChildWithOtherParentView(ChildView): + related = ChildView.related | Related( + other_parent=Related(models["parent"]) + ) + api = Api(app) api.add_resource("/parents/", ParentView) api.add_resource("/nested_parents/", NestedParentView) api.add_resource("/parents_with_create/", ParentWithCreateView) api.add_resource("/children/", ChildView) api.add_resource("/nested_children/", NestedChildView) + api.add_resource( + "/children_with_other_parent/", ChildWithOtherParentView + ) @pytest.fixture(autouse=True) @@ -141,12 +159,16 @@ def test_baseline(client): child_1_response = client.get("/children/1") assert_response( - child_1_response, 200, {"id": "1", "name": "Child 1", "parent": None} + child_1_response, + 200, + {"id": "1", "name": "Child 1", "parent": None, "other_parent": None}, ) child_2_response = client.get("/children/2") assert_response( - child_2_response, 200, {"id": "2", "name": "Child 2", "parent": None} + child_2_response, + 200, + {"id": "2", "name": "Child 2", "parent": None, "other_parent": None}, ) @@ -163,6 +185,7 @@ def test_single(client): "id": "1", "name": "Updated Child", "parent": {"id": "1", "name": "Parent"}, + "other_parent": None, }, ) @@ -180,6 +203,7 @@ def test_single_nested(client): "id": "1", "name": "Updated Child", "parent": {"id": "1", "name": "Parent"}, + "other_parent": None, }, ) @@ -266,6 +290,7 @@ def test_missing(client): "id": "1", "name": "Twice Updated Child", "parent": {"id": "1", "name": "Parent"}, + "other_parent": None, }, ) @@ -280,7 +305,12 @@ def test_null(client): assert_response( response, 200, - {"id": "1", "name": "Twice Updated Child", "parent": None}, + { + "id": "1", + "name": "Twice Updated Child", + "parent": None, + "other_parent": None, + }, ) @@ -294,7 +324,12 @@ def test_null_nested(client): assert_response( response, 200, - {"id": "1", "name": "Twice Updated Child", "parent": None}, + { + "id": "1", + "name": "Twice Updated Child", + "parent": None, + "other_parent": None, + }, ) @@ -313,6 +348,29 @@ def test_many_falsy(client): ) +def test_combine(client): + response = client.put( + "/children_with_other_parent/1", + data={ + "id": "1", + "name": "Updated Child", + "parent_id": "1", + "other_parent": {"id": "2", "name": "Other Parent"}, + }, + ) + + assert_response( + response, + 200, + { + "id": "1", + "name": "Updated Child", + "parent": {"id": "1", "name": "Parent"}, + "other_parent": {"id": "2", "name": "Other Parent"}, + }, + ) + + # ----------------------------------------------------------------------------- @@ -365,3 +423,8 @@ def test_error_missing_id(client): } ], ) + + +def test_error_combine_related_type_error(): + with pytest.raises(TypeError): + Related() | {}