From 77847a2bcce797b4a1165ab0f034d05f19620dce Mon Sep 17 00:00:00 2001
From: Lukas Plank
Date: Mon, 16 Dec 2024 11:47:00 +0100
Subject: [PATCH 1/3] test: test adaptations for QueryConstructor rewrite

Adapt tests for the query construction rewrite and move to httpx for
running SPARQL queries. The change also restructures parts of the test
suite to bring test parameters closer to the test functions that use
them.

Closes #173.
---
 tests/data/models/dummy_model.py | 14 ---
 .../test_adapter_grouped_pagination.py | 95 +++++++++++-----
 .../params}/count_query_parameters.py | 13 ++-
 .../test_query_constructor_items_query.py | 105 ++++++++++++++++++
 ...l_bindings_mapper_model_bool_parameters.py | 16 +++
 .../model_bindings_mapper_parameters.py | 14 ++-
 .../models/author_array_collection_model.py | 0
 .../params}/models/author_work_title_model.py | 0
 .../params}/models/basic_model.py | 0
 .../params}/models/grouping_model.py | 0
 .../params}/models/nested_grouping_model.py | 0
 .../test_model_bindings_mapper.py | 6 +-
 .../test_model_bindings_mapper_model_bool.py | 2 +-
 tests/unit/test_construct_count_query.py | 36 ------
 ...sad_path_get_bindings_from_query_result.py | 17 ---
 .../test_add_solution_modifier.py | 72 ++++++++++++
 .../test_get_query_projection.py | 63 +++++++++++
 .../test_inject_subquery.py | 6 +-
 .../test_remove_sparql_prefixes.py | 61 ++++++++++
 .../test_replace_query_select_clause.py | 0
 ...st_sad_path_replace_query_select_clause.py | 0
 .../tests_utils/test_field_bindings_map.py | 25 +++++
 tests/utils/utils.py | 11 ++
 23 files changed, 451 insertions(+), 105 deletions(-)
 delete mode 100644 tests/data/models/dummy_model.py
 rename tests/{data/parameters => tests_constructor/params}/count_query_parameters.py (90%)
 create mode 100644 tests/tests_constructor/test_query_constructor_items_query.py
 rename tests/{data/parameters => tests_mapper/params}/model_bindings_mapper_model_bool_parameters.py (88%)
 rename tests/{data/parameters => tests_mapper/params}/model_bindings_mapper_parameters.py (94%)
 rename tests/{data => tests_mapper/params}/models/author_array_collection_model.py (100%)
 rename tests/{data => tests_mapper/params}/models/author_work_title_model.py (100%)
 rename tests/{data => tests_mapper/params}/models/basic_model.py (100%)
 rename tests/{data => tests_mapper/params}/models/grouping_model.py (100%)
 rename tests/{data => tests_mapper/params}/models/nested_grouping_model.py (100%)
 delete mode 100644 tests/unit/test_construct_count_query.py
 delete mode 100644 tests/unit/test_sad_path_get_bindings_from_query_result.py
 create mode 100644 tests/unit/tests_sparql_utils/test_add_solution_modifier.py
 create mode 100644 tests/unit/tests_sparql_utils/test_get_query_projection.py
 rename tests/unit/{ => tests_sparql_utils}/test_inject_subquery.py (95%)
 create mode 100644 tests/unit/tests_sparql_utils/test_remove_sparql_prefixes.py
 rename tests/unit/{ => tests_sparql_utils}/test_replace_query_select_clause.py (100%)
 rename tests/unit/{ => tests_sparql_utils}/test_sad_path_replace_query_select_clause.py (100%)
 create mode 100644 tests/unit/tests_utils/test_field_bindings_map.py
 create mode 100644 tests/utils/utils.py
diff --git a/tests/data/models/dummy_model.py b/tests/data/models/dummy_model.py
deleted file mode 100644
index 041bf4d..0000000
--- a/tests/data/models/dummy_model.py
+++ /dev/null
@@ -1,14 +0,0 @@
-"""Simple dummy models e.g.
for count query constructor testing.""" - -from pydantic import BaseModel -from rdfproxy import ConfigDict - - -class Dummy(BaseModel): - pass - - -class GroupedDummy(BaseModel): - model_config = ConfigDict(group_by="x") - - x: int diff --git a/tests/tests_adapter/test_adapter_grouped_pagination.py b/tests/tests_adapter/test_adapter_grouped_pagination.py index 7fd68a0..8ac6706 100644 --- a/tests/tests_adapter/test_adapter_grouped_pagination.py +++ b/tests/tests_adapter/test_adapter_grouped_pagination.py @@ -2,15 +2,16 @@ from typing import Annotated, Any, NamedTuple -import pytest - from pydantic import BaseModel +import pytest from rdfproxy import ( ConfigDict, + HttpxStrategy, Page, QueryParameters, SPARQLBinding, SPARQLModelAdapter, + SPARQLWrapperStrategy, ) @@ -57,28 +58,43 @@ class Parent(BaseModel): children: list[Child] -binding_adapter = SPARQLModelAdapter( - target="https://graphdb.r11.eu/repositories/RELEVEN", - query=binding_query, - model=BindingParent, -) +@pytest.fixture(params=[HttpxStrategy, SPARQLWrapperStrategy]) +def adapter(request): + return SPARQLModelAdapter( + target="https://graphdb.r11.eu/repositories/RELEVEN", + query=query, + model=Parent, + sparql_strategy=request.param, + ) + + +@pytest.fixture(params=[HttpxStrategy, SPARQLWrapperStrategy]) +def binding_adapter(request): + return SPARQLModelAdapter( + target="https://graphdb.r11.eu/repositories/RELEVEN", + query=binding_query, + model=BindingParent, + sparql_strategy=request.param, + ) -adapter = SPARQLModelAdapter( - target="https://graphdb.r11.eu/repositories/RELEVEN", - query=query, - model=Parent, -) + +@pytest.fixture(params=[HttpxStrategy, SPARQLWrapperStrategy]) +def ungrouped_adapter(request): + return SPARQLModelAdapter( + target="https://graphdb.r11.eu/repositories/RELEVEN", + query=query, + model=Child, + sparql_strategy=request.param, + ) class AdapterParameter(NamedTuple): - adapter: SPARQLModelAdapter query_parameters: dict[str, Any] expected: Page -adapter_parameters = [ +binding_adapter_parameters = [ AdapterParameter( - adapter=binding_adapter, query_parameters={"page": 1, "size": 2}, expected=Page[BindingParent]( items=[ @@ -92,7 +108,6 @@ class AdapterParameter(NamedTuple): ), ), AdapterParameter( - adapter=binding_adapter, query_parameters={"page": 2, "size": 2}, expected=Page[BindingParent]( items=[{"parent": "z", "children": []}], @@ -103,7 +118,6 @@ class AdapterParameter(NamedTuple): ), ), AdapterParameter( - adapter=binding_adapter, query_parameters={"page": 1, "size": 1}, expected=Page[BindingParent]( items=[{"parent": "x", "children": [{"name": "foo"}]}], @@ -114,22 +128,21 @@ class AdapterParameter(NamedTuple): ), ), AdapterParameter( - adapter=binding_adapter, query_parameters={"page": 2, "size": 1}, expected=Page[BindingParent]( items=[{"parent": "y", "children": []}], page=2, size=1, total=3, pages=3 ), ), AdapterParameter( - adapter=binding_adapter, query_parameters={"page": 3, "size": 1}, expected=Page[BindingParent]( items=[{"parent": "z", "children": []}], page=3, size=1, total=3, pages=3 ), ), - # +] +# +adapter_parameters = [ AdapterParameter( - adapter=adapter, query_parameters={"page": 1, "size": 2}, expected=Page[Parent]( items=[ @@ -143,7 +156,6 @@ class AdapterParameter(NamedTuple): ), ), AdapterParameter( - adapter=adapter, query_parameters={"page": 2, "size": 2}, expected=Page[Parent]( items=[{"parent": "z", "children": []}], @@ -154,7 +166,6 @@ class AdapterParameter(NamedTuple): ), ), AdapterParameter( - adapter=adapter, query_parameters={"page": 1, "size": 1}, 
expected=Page[Parent]( items=[{"parent": "x", "children": [{"name": "foo"}]}], @@ -165,14 +176,12 @@ class AdapterParameter(NamedTuple): ), ), AdapterParameter( - adapter=adapter, query_parameters={"page": 2, "size": 1}, expected=Page[Parent]( items=[{"parent": "y", "children": []}], page=2, size=1, total=3, pages=3 ), ), AdapterParameter( - adapter=adapter, query_parameters={"page": 3, "size": 1}, expected=Page[Parent]( items=[{"parent": "z", "children": []}], page=3, size=1, total=3, pages=3 @@ -180,11 +189,45 @@ class AdapterParameter(NamedTuple): ), ] +ungrouped_adapter_parameters = [ + AdapterParameter( + query_parameters={"page": 1, "size": 100}, + expected=Page[Child]( + items=[{"name": "foo"}], page=1, size=100, total=1, pages=1 + ), + ), +] + @pytest.mark.remote @pytest.mark.parametrize( - ["adapter", "query_parameters", "expected"], adapter_parameters + ["query_parameters", "expected"], + adapter_parameters, ) def test_basic_adapter_grouped_pagination(adapter, query_parameters, expected): parameters = QueryParameters(**query_parameters) assert adapter.query(parameters) == expected + + +@pytest.mark.remote +@pytest.mark.parametrize( + ["query_parameters", "expected"], + binding_adapter_parameters, +) +def test_basic_binding_adapter_grouped_pagination( + binding_adapter, query_parameters, expected +): + parameters = QueryParameters(**query_parameters) + assert binding_adapter.query(parameters) == expected + + +@pytest.mark.xfail +@pytest.mark.remote +@pytest.mark.parametrize( + ["query_parameters", "expected"], + ungrouped_adapter_parameters, +) +def test_basic_ungrouped_pagination(ungrouped_adapter, query_parameters, expected): + """This shows a possible pagination count bug that needs investigating.""" + parameters = QueryParameters(**query_parameters) + assert ungrouped_adapter.query(parameters) == expected diff --git a/tests/data/parameters/count_query_parameters.py b/tests/tests_constructor/params/count_query_parameters.py similarity index 90% rename from tests/data/parameters/count_query_parameters.py rename to tests/tests_constructor/params/count_query_parameters.py index e7aa9ad..7d93b32 100644 --- a/tests/data/parameters/count_query_parameters.py +++ b/tests/tests_constructor/params/count_query_parameters.py @@ -1,7 +1,18 @@ -from tests.data.models.dummy_model import Dummy, GroupedDummy +from pydantic import BaseModel +from rdfproxy import ConfigDict from tests.utils._types import CountQueryParameter +class Dummy(BaseModel): + pass + + +class GroupedDummy(BaseModel): + model_config = ConfigDict(group_by="x") + + x: int + + construct_count_query_parameters = [ CountQueryParameter( query=""" diff --git a/tests/tests_constructor/test_query_constructor_items_query.py b/tests/tests_constructor/test_query_constructor_items_query.py new file mode 100644 index 0000000..79ad389 --- /dev/null +++ b/tests/tests_constructor/test_query_constructor_items_query.py @@ -0,0 +1,105 @@ +"""Basic tests for the QueryConstructor class.""" + +from typing import NamedTuple + +import pytest + +from pydantic import BaseModel +from rdfproxy.constructor import QueryConstructor +from rdfproxy.utils._types import ConfigDict +from rdfproxy.utils.models import QueryParameters + + +class UngroupedModel(BaseModel): + x: int + y: int + + +class GroupedModel(BaseModel): + model_config = ConfigDict(group_by="x") + + x: int + y: list[int] + + +class Expected(NamedTuple): + count_query: str + items_query: str + + +class QueryConstructorParameters(NamedTuple): + query: str + query_parameters: QueryParameters + 
model: type[BaseModel] + + expected: Expected + + +parameters = [ + # ungrouped + QueryConstructorParameters( + query="select * where {?s ?p ?o}", + query_parameters=QueryParameters(), + model=UngroupedModel, + expected=Expected( + count_query="select (count(*) as ?cnt) where {?s ?p ?o}", + items_query="select * where {?s ?p ?o} order by ?s limit 100 offset 0", + ), + ), + QueryConstructorParameters( + query="select ?p ?o where {?s ?p ?o}", + query_parameters=QueryParameters(), + model=UngroupedModel, + expected=Expected( + count_query="select (count(*) as ?cnt) where {?s ?p ?o}", + items_query="select ?p ?o where {?s ?p ?o} order by ?p limit 100 offset 0", + ), + ), + QueryConstructorParameters( + query="select * where {?s ?p ?o}", + query_parameters=QueryParameters(page=2, size=2), + model=UngroupedModel, + expected=Expected( + count_query="select (count(*) as ?cnt) where {?s ?p ?o}", + items_query="select * where {?s ?p ?o} order by ?s limit 2 offset 2", + ), + ), + # grouped + QueryConstructorParameters( + query="select * where {?x a ?y}", + query_parameters=QueryParameters(), + model=GroupedModel, + expected=Expected( + count_query="select (count(distinct ?x) as ?cnt) where {?x a ?y}", + items_query="select * where {?x a ?y {select distinct ?x where {?x a ?y} order by ?x limit 100 offset 0} }", + ), + ), + QueryConstructorParameters( + query="select ?x ?y where {?x a ?y}", + query_parameters=QueryParameters(), + model=GroupedModel, + expected=Expected( + count_query="select (count(distinct ?x) as ?cnt) where {?x a ?y}", + items_query="select ?x ?y where {?x a ?y {select distinct ?x where {?x a ?y} order by ?x limit 100 offset 0} }", + ), + ), + QueryConstructorParameters( + query="select ?x ?y where {?x a ?y}", + query_parameters=QueryParameters(page=2, size=2), + model=GroupedModel, + expected=Expected( + count_query="select (count(distinct ?x) as ?cnt) where {?x a ?y}", + items_query="select ?x ?y where {?x a ?y {select distinct ?x where {?x a ?y} order by ?x limit 2 offset 2} }", + ), + ), +] + + +@pytest.mark.parametrize(["query", "query_parameters", "model", "expected"], parameters) +def test_query_constructor_items_query(query, query_parameters, model, expected): + constructor = QueryConstructor( + query=query, query_parameters=query_parameters, model=model + ) + + assert constructor.get_count_query() == expected.count_query + assert constructor.get_items_query() == expected.items_query diff --git a/tests/data/parameters/model_bindings_mapper_model_bool_parameters.py b/tests/tests_mapper/params/model_bindings_mapper_model_bool_parameters.py similarity index 88% rename from tests/data/parameters/model_bindings_mapper_model_bool_parameters.py rename to tests/tests_mapper/params/model_bindings_mapper_model_bool_parameters.py index 5721292..a4c23b5 100644 --- a/tests/data/parameters/model_bindings_mapper_model_bool_parameters.py +++ b/tests/tests_mapper/params/model_bindings_mapper_model_bool_parameters.py @@ -43,6 +43,13 @@ class Child5(BaseModel): child: str | None = Field(default=None, exclude=True) +class Child6(BaseModel): + model_config = ConfigDict(model_bool=["name", "child"]) + + name: str | None = None + child: str | None = None + + def _create_parent_with_child(child: type[BaseModel]) -> type[BaseModel]: model = create_model( "Parent", @@ -112,4 +119,13 @@ def _create_parent_with_child(child: type[BaseModel]) -> type[BaseModel]: {"parent": "z", "children": []}, ], ), + ModelBindingsMapperParameter( + model=_create_parent_with_child(Child6), + bindings=bindings, + 
expected=[ + {"parent": "x", "children": [{"name": "foo", "child": "c"}]}, + {"parent": "y", "children": []}, + {"parent": "z", "children": []}, + ], + ), ] diff --git a/tests/data/parameters/model_bindings_mapper_parameters.py b/tests/tests_mapper/params/model_bindings_mapper_parameters.py similarity index 94% rename from tests/data/parameters/model_bindings_mapper_parameters.py rename to tests/tests_mapper/params/model_bindings_mapper_parameters.py index 5d04fd4..8f9b70b 100644 --- a/tests/data/parameters/model_bindings_mapper_parameters.py +++ b/tests/tests_mapper/params/model_bindings_mapper_parameters.py @@ -1,12 +1,16 @@ -from tests.data.models.author_array_collection_model import Author as ArrayAuthor -from tests.data.models.author_work_title_model import Author -from tests.data.models.basic_model import ( +from tests.tests_mapper.params.models.author_array_collection_model import ( + Author as ArrayAuthor, +) +from tests.tests_mapper.params.models.author_work_title_model import Author +from tests.tests_mapper.params.models.basic_model import ( BasicComplexModel, BasicNestedModel, BasicSimpleModel, ) -from tests.data.models.grouping_model import GroupingComplexModel -from tests.data.models.nested_grouping_model import NestedGroupingComplexModel +from tests.tests_mapper.params.models.grouping_model import GroupingComplexModel +from tests.tests_mapper.params.models.nested_grouping_model import ( + NestedGroupingComplexModel, +) from tests.utils._types import ModelBindingsMapperParameter diff --git a/tests/data/models/author_array_collection_model.py b/tests/tests_mapper/params/models/author_array_collection_model.py similarity index 100% rename from tests/data/models/author_array_collection_model.py rename to tests/tests_mapper/params/models/author_array_collection_model.py diff --git a/tests/data/models/author_work_title_model.py b/tests/tests_mapper/params/models/author_work_title_model.py similarity index 100% rename from tests/data/models/author_work_title_model.py rename to tests/tests_mapper/params/models/author_work_title_model.py diff --git a/tests/data/models/basic_model.py b/tests/tests_mapper/params/models/basic_model.py similarity index 100% rename from tests/data/models/basic_model.py rename to tests/tests_mapper/params/models/basic_model.py diff --git a/tests/data/models/grouping_model.py b/tests/tests_mapper/params/models/grouping_model.py similarity index 100% rename from tests/data/models/grouping_model.py rename to tests/tests_mapper/params/models/grouping_model.py diff --git a/tests/data/models/nested_grouping_model.py b/tests/tests_mapper/params/models/nested_grouping_model.py similarity index 100% rename from tests/data/models/nested_grouping_model.py rename to tests/tests_mapper/params/models/nested_grouping_model.py diff --git a/tests/tests_mapper/test_model_bindings_mapper.py b/tests/tests_mapper/test_model_bindings_mapper.py index 4824674..52c47ed 100644 --- a/tests/tests_mapper/test_model_bindings_mapper.py +++ b/tests/tests_mapper/test_model_bindings_mapper.py @@ -1,10 +1,10 @@ """Pytest entry point for basic rdfproxy.mapper.ModelBindingsMapper.""" -from pydantic import BaseModel import pytest -from rdfproxy.mapper import ModelBindingsMapper -from tests.data.parameters.model_bindings_mapper_parameters import ( +from pydantic import BaseModel +from rdfproxy.mapper import ModelBindingsMapper +from tests.tests_mapper.params.model_bindings_mapper_parameters import ( author_array_collection_parameters, author_work_title_parameters, basic_parameters, diff --git 
a/tests/tests_mapper/test_model_bindings_mapper_model_bool.py b/tests/tests_mapper/test_model_bindings_mapper_model_bool.py index f9bde75..a14077b 100644 --- a/tests/tests_mapper/test_model_bindings_mapper_model_bool.py +++ b/tests/tests_mapper/test_model_bindings_mapper_model_bool.py @@ -4,7 +4,7 @@ from pydantic import BaseModel from rdfproxy.mapper import ModelBindingsMapper -from tests.data.parameters.model_bindings_mapper_model_bool_parameters import ( +from tests.tests_mapper.params.model_bindings_mapper_model_bool_parameters import ( parent_child_parameters, ) diff --git a/tests/unit/test_construct_count_query.py b/tests/unit/test_construct_count_query.py deleted file mode 100644 index 94f7d10..0000000 --- a/tests/unit/test_construct_count_query.py +++ /dev/null @@ -1,36 +0,0 @@ -"""Unit tests for rdfproxy.utils.sparql_utils.construct_count_query.""" - -import pytest - -from rdflib import Graph -from rdflib.plugins.sparql.processor import SPARQLResult -from rdfproxy.utils.sparql_utils import construct_count_query -from tests.data.parameters.count_query_parameters import ( - construct_count_query_parameters, -) - - -def _get_cnt_value_from_sparql_result( - result: SPARQLResult, count_var: str = "cnt" -) -> int: - """Get the 'cnt' binding of a count query from a SPARQLResult object.""" - return int(result.bindings[0][count_var]) - - -@pytest.mark.parametrize( - ["query", "model", "expected"], construct_count_query_parameters -) -def test_basic_construct_count_query(query, model, expected): - """Check the count of a grouped model. - - The count query constructed based on a grouped value must only count - distinct values according to the grouping specified in the model. - """ - - graph: Graph = Graph() - count_query: str = construct_count_query(query, model) - query_result: SPARQLResult = graph.query(count_query) - - cnt: int = _get_cnt_value_from_sparql_result(query_result) - - assert cnt == expected diff --git a/tests/unit/test_sad_path_get_bindings_from_query_result.py b/tests/unit/test_sad_path_get_bindings_from_query_result.py deleted file mode 100644 index 2bc655c..0000000 --- a/tests/unit/test_sad_path_get_bindings_from_query_result.py +++ /dev/null @@ -1,17 +0,0 @@ -"""Sad path tests for rdfprox.utils.sparql_utils.get_bindings_from_query_result.""" - -from unittest import mock - -import pytest - -from rdfproxy.utils.sparql_utils import get_bindings_from_query_result - - -def test_basic_sad_path_get_bindings_from_query_result(): - with mock.patch("SPARQLWrapper.QueryResult") as mock_query_result: - mock_query_result.return_value.requestedFormat = "xml" - exception_message = ( - "Only QueryResult objects with JSON format are currently supported." 
- ) - with pytest.raises(Exception, match=exception_message): - get_bindings_from_query_result(mock_query_result) diff --git a/tests/unit/tests_sparql_utils/test_add_solution_modifier.py b/tests/unit/tests_sparql_utils/test_add_solution_modifier.py new file mode 100644 index 0000000..a276286 --- /dev/null +++ b/tests/unit/tests_sparql_utils/test_add_solution_modifier.py @@ -0,0 +1,72 @@ +from typing import NamedTuple + +import pytest +from rdfproxy.utils.sparql_utils import add_solution_modifier + + +class AddSolutionModifierParameter(NamedTuple): + query: str + parameters: dict + expected: str + + +parameters = [ + # basics + AddSolutionModifierParameter( + query="prefix ns: select * where {?s ?p ?o }", + parameters={"order_by": None, "limit": None, "offset": None}, + expected="prefix ns: select * where {?s ?p ?o }", + ), + AddSolutionModifierParameter( + query="prefix ns: select * where {?s ?p ?o }", + parameters={"order_by": None, "limit": None, "offset": 1}, + expected="prefix ns: select * where {?s ?p ?o } offset 1", + ), + AddSolutionModifierParameter( + query="prefix ns: select * where {?s ?p ?o }", + parameters={"order_by": None, "limit": 1, "offset": None}, + expected="prefix ns: select * where {?s ?p ?o } limit 1", + ), + AddSolutionModifierParameter( + query="prefix ns: select * where {?s ?p ?o }", + parameters={"order_by": "x", "limit": None, "offset": None}, + expected="prefix ns: select * where {?s ?p ?o } order by ?x", + ), + AddSolutionModifierParameter( + query="prefix ns: select * where {?s ?p ?o }", + parameters={"order_by": "x", "limit": 1, "offset": 1}, + expected="prefix ns: select * where {?s ?p ?o } order by ?x limit 1 offset 1", + ), + # order + AddSolutionModifierParameter( + query="prefix ns: select * where {?s ?p ?o }", + parameters={"offset": 1, "limit": 1, "order_by": None}, + expected="prefix ns: select * where {?s ?p ?o } limit 1 offset 1", + ), + AddSolutionModifierParameter( + query="prefix ns: select * where {?s ?p ?o }", + parameters={"offset": 1, "limit": 1, "order_by": None}, + expected="prefix ns: select * where {?s ?p ?o } limit 1 offset 1", + ), + AddSolutionModifierParameter( + query="prefix ns: select * where {?s ?p ?o }", + parameters={"offset": 1, "limit": None, "order_by": "x"}, + expected="prefix ns: select * where {?s ?p ?o } order by ?x offset 1", + ), + AddSolutionModifierParameter( + query="prefix ns: select * where {?s ?p ?o }", + parameters={"offset": None, "limit": 1, "order_by": "x"}, + expected="prefix ns: select * where {?s ?p ?o } order by ?x limit 1", + ), + AddSolutionModifierParameter( + query="prefix ns: select * where {?s ?p ?o }", + parameters={"offset": 1, "limit": 1, "order_by": "x"}, + expected="prefix ns: select * where {?s ?p ?o } order by ?x limit 1 offset 1", + ), +] + + +@pytest.mark.parametrize(["query", "parameters", "expected"], parameters) +def test_add_solution_modifier(query, parameters, expected): + modified_query = add_solution_modifier(query, **parameters) + assert modified_query == expected diff --git a/tests/unit/tests_sparql_utils/test_get_query_projection.py b/tests/unit/tests_sparql_utils/test_get_query_projection.py new file mode 100644 index 0000000..2459950 --- /dev/null +++ b/tests/unit/tests_sparql_utils/test_get_query_projection.py @@ -0,0 +1,63 @@ +"""Unit tests for sparql_utils.get_query_projection.""" + +from typing import NamedTuple + +import pytest + +from rdfproxy.utils.sparql_utils import get_query_projection + + +class QueryProjectionParameter(NamedTuple): + query: str + expected: list[str] 
+ + +parameters = [ + # explicit projection + QueryProjectionParameter( + query="select ?s ?p ?o where {?s ?p ?o}", expected=["s", "p", "o"] + ), + QueryProjectionParameter( + query=""" + PREFIX crm: + select ?s ?o where {?s ?p ?o}""", + expected=["s", "o"], + ), + # implicit projection + QueryProjectionParameter( + query="select * where {?s ?p ?o}", + expected=["s", "p", "o"], + ), + QueryProjectionParameter( + query=""" + PREFIX crm: + select * where {?s ?p ?o}""", + expected=["s", "p", "o"], + ), + # implicit projection with values clause + QueryProjectionParameter( + query=""" + select * where { + values (?s ?p ?o) + { (1 2 3) } + } + """, + expected=["s", "p", "o"], + ), + QueryProjectionParameter( + query=""" + PREFIX crm: + select * where { + values (?s ?p ?o) + { (1 2 3) } + } + """, + expected=["s", "p", "o"], + ), +] + + +@pytest.mark.parametrize(["query", "expected"], parameters) +def test_get_query_projection(query, expected): + projection = [str(binding) for binding in get_query_projection(query)] + assert projection == expected diff --git a/tests/unit/test_inject_subquery.py b/tests/unit/tests_sparql_utils/test_inject_subquery.py similarity index 95% rename from tests/unit/test_inject_subquery.py rename to tests/unit/tests_sparql_utils/test_inject_subquery.py index e95494d..a80765c 100644 --- a/tests/unit/test_inject_subquery.py +++ b/tests/unit/tests_sparql_utils/test_inject_subquery.py @@ -3,7 +3,8 @@ from typing import NamedTuple import pytest -from rdfproxy.utils.sparql_utils import inject_subquery + +from rdfproxy.utils.sparql_utils import inject_into_query, remove_sparql_prefixes class InjectSubqueryParameter(NamedTuple): @@ -118,5 +119,6 @@ class InjectSubqueryParameter(NamedTuple): @pytest.mark.parametrize(["query", "subquery", "expected"], inject_subquery_parameters) def test_inject_subquery(query, subquery, expected): - injected = inject_subquery(query=query, subquery=subquery) + injectant = remove_sparql_prefixes(subquery) + injected = inject_into_query(query=query, injectant=injectant) assert injected == expected diff --git a/tests/unit/tests_sparql_utils/test_remove_sparql_prefixes.py b/tests/unit/tests_sparql_utils/test_remove_sparql_prefixes.py new file mode 100644 index 0000000..e26e7b0 --- /dev/null +++ b/tests/unit/tests_sparql_utils/test_remove_sparql_prefixes.py @@ -0,0 +1,61 @@ +from typing import NamedTuple + +import pytest + +from rdfproxy.utils.sparql_utils import remove_sparql_prefixes +from tests.utils.utils import normalize_query + + +class SPARQLRemovePrefixParameter(NamedTuple): + query: str + expected: str + + +parameters = [ + SPARQLRemovePrefixParameter( + query=""" + select * where { ?s ?p ?o .} + """, + expected="select * where { ?s ?p ?o . }", + ), + SPARQLRemovePrefixParameter( + query=""" + prefix ns: + prefix other_ns: + select * where { ?s ?p ?o .} + """, + expected="select * where { ?s ?p ?o . }", + ), + SPARQLRemovePrefixParameter( + query=""" + prefix ns: prefix other_ns: + select * where { ?s ?p ?o .} + """, + expected="select * where { ?s ?p ?o . }", + ), + SPARQLRemovePrefixParameter( + query=""" + prefix ns: + + prefix other_ns: + select * where { ?s ?p ?o .} + """, + expected="select * where { ?s ?p ?o . }", + ), + SPARQLRemovePrefixParameter( + query=""" + prefix ns: + + prefix other_ns: + + select * where { ?s ?p ?o .} + """, + expected="select * where { ?s ?p ?o . 
}", + ), +] + + +@pytest.mark.parametrize(["query", "expected"], parameters) +def test_remove_sparql_prefixes(query, expected): + modified_query = remove_sparql_prefixes(query) + assert normalize_query(modified_query) == expected diff --git a/tests/unit/test_replace_query_select_clause.py b/tests/unit/tests_sparql_utils/test_replace_query_select_clause.py similarity index 100% rename from tests/unit/test_replace_query_select_clause.py rename to tests/unit/tests_sparql_utils/test_replace_query_select_clause.py diff --git a/tests/unit/test_sad_path_replace_query_select_clause.py b/tests/unit/tests_sparql_utils/test_sad_path_replace_query_select_clause.py similarity index 100% rename from tests/unit/test_sad_path_replace_query_select_clause.py rename to tests/unit/tests_sparql_utils/test_sad_path_replace_query_select_clause.py diff --git a/tests/unit/tests_utils/test_field_bindings_map.py b/tests/unit/tests_utils/test_field_bindings_map.py new file mode 100644 index 0000000..0364e7c --- /dev/null +++ b/tests/unit/tests_utils/test_field_bindings_map.py @@ -0,0 +1,25 @@ +"""Basic unit tests for FieldBindingsMap""" + +from typing import Annotated + +from pydantic import BaseModel +from rdfproxy.utils._types import SPARQLBinding +from rdfproxy.utils.utils import FieldsBindingsMap + + +class Point(BaseModel): + x: int + y: Annotated[int, SPARQLBinding("Y_ALIAS")] + z: Annotated[list[int], SPARQLBinding("Z_ALIAS")] + + +def test_basic_fields_bindings_map(): + mapping = FieldsBindingsMap(model=Point) + + assert mapping["x"] == "x" + assert mapping["y"] == "Y_ALIAS" + assert mapping["z"] == "Z_ALIAS" + + assert mapping.reverse["x"] == "x" + assert mapping.reverse["Y_ALIAS"] == "y" + assert mapping.reverse["Z_ALIAS"] == "z" diff --git a/tests/utils/utils.py b/tests/utils/utils.py new file mode 100644 index 0000000..deb354d --- /dev/null +++ b/tests/utils/utils.py @@ -0,0 +1,11 @@ +"""Testing utils.""" + +import re + + +def normalize_query(select_query: str) -> str: + """Normalize whitespace chars in a SPARQL query.""" + normalized_select_query = re.sub( + r"(? Date: Mon, 16 Dec 2024 11:47:40 +0100 Subject: [PATCH 2/3] feat: rewrite query construction functionality The change introduces a QueryConstructor class that encapsulates all SPARQL query construction functionality. This also leads to a significant cleanup of rdfproxy.utils.sparql_utils module (utils in general) and the SPARQLModelAdapter class. Query result ordering for ungrouped models is now implemented to default to the first binding of the projection as ORDER BY value. This might still be discussed in the future, but the decision seems reasonable at this point. Closes #128. Closes #134. Closes #168. 
--- rdfproxy/adapter.py | 33 ++---- rdfproxy/constructor.py | 124 +++++++++++++++++++++ rdfproxy/mapper.py | 3 +- rdfproxy/sparql_strategies.py | 7 +- rdfproxy/utils/mapper_utils.py | 127 +++++++++++++++++++++ rdfproxy/utils/sparql_utils.py | 196 +++++++++++---------------------- rdfproxy/utils/utils.py | 155 +++++++++----------------- 7 files changed, 384 insertions(+), 261 deletions(-) create mode 100644 rdfproxy/constructor.py create mode 100644 rdfproxy/utils/mapper_utils.py diff --git a/rdfproxy/adapter.py b/rdfproxy/adapter.py index ad98717..ebe7da3 100644 --- a/rdfproxy/adapter.py +++ b/rdfproxy/adapter.py @@ -4,15 +4,11 @@ import math from typing import Generic +from rdfproxy.constructor import QueryConstructor from rdfproxy.mapper import ModelBindingsMapper from rdfproxy.sparql_strategies import HttpxStrategy, SPARQLStrategy from rdfproxy.utils._types import _TModelInstance from rdfproxy.utils.models import Page, QueryParameters -from rdfproxy.utils.sparql_utils import ( - calculate_offset, - construct_count_query, - construct_items_query, -) class SPARQLModelAdapter(Generic[_TModelInstance]): @@ -24,10 +20,12 @@ class SPARQLModelAdapter(Generic[_TModelInstance]): SPARQLModelAdapter.query returns a Page model object with a default pagination size of 100 results. SPARQL bindings are implicitly assigned to model fields of the same name, - explicit SPARQL binding to model field allocation is available with typing.Annotated and rdfproxy.SPARQLBinding. + explicit SPARQL binding to model field allocation is available with rdfproxy.SPARQLBinding. Result grouping is controlled through the model, i.e. grouping is triggered when a field of list[pydantic.BaseModel] is encountered. + + See https://github.com/acdh-oeaw/rdfproxy/tree/main/examples for examples. """ def __init__( @@ -44,20 +42,21 @@ def __init__( def query(self, query_parameters: QueryParameters) -> Page[_TModelInstance]: """Run a query against an endpoint and return a Page model object.""" - count_query: str = construct_count_query(query=self._query, model=self._model) - items_query: str = construct_items_query( + query_constructor = QueryConstructor( query=self._query, + query_parameters=query_parameters, model=self._model, - limit=query_parameters.size, - offset=calculate_offset(query_parameters.page, query_parameters.size), ) - items_query_bindings: Iterator[dict] = self.sparql_strategy.query(items_query) + count_query = query_constructor.get_count_query() + items_query = query_constructor.get_items_query() + items_query_bindings: Iterator[dict] = self.sparql_strategy.query(items_query) mapper = ModelBindingsMapper(self._model, *items_query_bindings) - items: list[_TModelInstance] = mapper.get_models() - total: int = self._get_count(count_query) + + count_query_bindings: Iterator[dict] = self.sparql_strategy.query(count_query) + total: int = int(next(count_query_bindings)["cnt"]) pages: int = math.ceil(total / query_parameters.size) return Page( @@ -67,11 +66,3 @@ def query(self, query_parameters: QueryParameters) -> Page[_TModelInstance]: total=total, pages=pages, ) - - def _get_count(self, query: str) -> int: - """Run a count query and return the count result. - - Helper for SPARQLModelAdapter.query. 
- """ - result: Iterator[dict] = self.sparql_strategy.query(query) - return int(next(result)["cnt"]) diff --git a/rdfproxy/constructor.py b/rdfproxy/constructor.py new file mode 100644 index 0000000..8054bd3 --- /dev/null +++ b/rdfproxy/constructor.py @@ -0,0 +1,124 @@ +from rdfproxy.utils._types import _TModelInstance +from rdfproxy.utils.models import QueryParameters +from rdfproxy.utils.sparql_utils import ( + add_solution_modifier, + get_query_projection, + inject_into_query, + remove_sparql_prefixes, + replace_query_select_clause, +) +from rdfproxy.utils.utils import ( + FieldsBindingsMap, + QueryConstructorComponent as component, + compose_left, +) + + +class QueryConstructor: + """The class encapsulates dynamic SPARQL query modification logic + for implementing purely SPARQL-based, deterministic pagination. + + Public methods get_items_query and get_count_query are used in rdfproxy.SPARQLModelAdapter + to construct queries for retrieving arguments for Page object instantiation. + """ + + def __init__( + self, + query: str, + query_parameters: QueryParameters, + model: type[_TModelInstance], + ) -> None: + self.query = query + self.query_parameters = query_parameters + self.model = model + + self.bindings_map = FieldsBindingsMap(model) + self.group_by: str | None = self.bindings_map.get( + model.model_config.get("group_by") + ) + + def get_items_query(self) -> str: + """Construct a SPARQL items query for use in rdfproxy.SPARQLModelAdapter.""" + if self.group_by is None: + return self._get_ungrouped_items_query() + return self._get_grouped_items_query() + + def get_count_query(self) -> str: + """Construct a SPARQL count query for use in rdfproxy.SPARQLModelAdapter""" + if self.group_by is None: + select_clause = "select (count(*) as ?cnt)" + else: + select_clause = f"select (count(distinct ?{self.group_by}) as ?cnt)" + + return replace_query_select_clause(self.query, select_clause) + + @staticmethod + def _calculate_offset(page: int, size: int) -> int: + """Calculate the offset value for paginated SPARQL templates.""" + match page: + case 1: + return 0 + case 2: + return size + case _: + return size * (page - 1) + + def _get_grouped_items_query(self) -> str: + """Construct a SPARQL items query for grouped models.""" + filter_clause: str | None = self._compute_filter_clause() + select_clause: str = self._compute_select_clause() + order_by_value: str = self._compute_order_by_value() + limit, offset = self._compute_limit_offset() + + subquery = compose_left( + remove_sparql_prefixes, + component(replace_query_select_clause, repl=select_clause), + component(inject_into_query, injectant=filter_clause), + component( + add_solution_modifier, + order_by=order_by_value, + limit=limit, + offset=offset, + ), + )(self.query) + + return inject_into_query(self.query, subquery) + + def _get_ungrouped_items_query(self) -> str: + """Construct a SPARQL items query for ungrouped models.""" + filter_clause: str | None = self._compute_filter_clause() + order_by_value: str = self._compute_order_by_value() + limit, offset = self._compute_limit_offset() + + return compose_left( + component(inject_into_query, injectant=filter_clause), + component( + add_solution_modifier, + order_by=order_by_value, + limit=limit, + offset=offset, + ), + )(self.query) + + def _compute_limit_offset(self) -> tuple[int, int]: + """Calculate limit and offset values for SPARQL-based pagination.""" + limit = self.query_parameters.size + offset = self._calculate_offset( + self.query_parameters.page, self.query_parameters.size + ) + + 
return limit, offset + + def _compute_filter_clause(self) -> str | None: + """Stub: Always None for now.""" + return None + + def _compute_select_clause(self): + """Stub: Static SELECT clause for now.""" + return f"select distinct ?{self.group_by}" + + def _compute_order_by_value(self): + """Stub: Only basic logic for now.""" + if self.group_by is None: + return get_query_projection(self.query)[0] + return f"{self.group_by}" diff --git a/rdfproxy/mapper.py b/rdfproxy/mapper.py index da10054..7c3d56d 100644 --- a/rdfproxy/mapper.py +++ b/rdfproxy/mapper.py @@ -5,7 +5,7 @@ from pydantic import BaseModel from rdfproxy.utils._types import ModelBoolPredicate, _TModelInstance -from rdfproxy.utils.utils import ( +from rdfproxy.utils.mapper_utils import ( _collect_values_from_bindings, _get_group_by, _get_key_from_metadata, @@ -65,7 +65,6 @@ def _generate_binding_pairs( and (x[self._contexts[0]] == kwargs[self._contexts[0]]), self.bindings, ) - value = self._get_unique_models(group_model, applicable_bindings) elif _is_list_type(v.annotation): diff --git a/rdfproxy/sparql_strategies.py b/rdfproxy/sparql_strategies.py index 6b61860..da26d41 100644 --- a/rdfproxy/sparql_strategies.py +++ b/rdfproxy/sparql_strategies.py @@ -2,6 +2,7 @@ import abc from collections.abc import Iterator +from typing import cast from SPARQLWrapper import JSON, QueryResult, SPARQLWrapper import httpx @@ -13,7 +14,7 @@ def __init__(self, endpoint: str): @abc.abstractmethod def query(self, sparql_query: str) -> Iterator[dict[str, str]]: - raise NotImplementedError + raise NotImplementedError # pragma: no cover @staticmethod def _get_bindings_from_bindings_dict(bindings_dict: dict) -> Iterator[dict]: @@ -35,7 +36,9 @@ def query(self, sparql_query: str) -> Iterator[dict[str, str]]: self._sparql_wrapper.setQuery(sparql_query) result: QueryResult = self._sparql_wrapper.query() - return self._get_bindings_from_bindings_dict(result.convert()) + # SPARQLWrapper.Wrapper.convert is not overloaded properly and needs casting + # https://github.com/RDFLib/sparqlwrapper/blob/master/SPARQLWrapper/Wrapper.py#L1135 + return self._get_bindings_from_bindings_dict(cast(dict, result.convert())) class HttpxStrategy(SPARQLStrategy): diff --git a/rdfproxy/utils/mapper_utils.py b/rdfproxy/utils/mapper_utils.py new file mode 100644 index 0000000..acce14e --- /dev/null +++ b/rdfproxy/utils/mapper_utils.py @@ -0,0 +1,127 @@ +from collections.abc import Callable, Iterable +from typing import Any, TypeGuard, get_args, get_origin + +from pydantic import BaseModel +from pydantic.fields import FieldInfo +from rdfproxy.utils._exceptions import ( + InvalidGroupingKeyException, + MissingModelConfigException, +) +from rdfproxy.utils._types import _TModelInstance +from rdfproxy.utils._types import ModelBoolPredicate, SPARQLBinding, _TModelBoolValue + + +def _is_type(obj: type | None, _type: type) -> bool: + """Check if an obj is type _type or a GenericAlias with origin _type.""" + return (obj is _type) or (get_origin(obj) is _type) + + +def _is_list_type(obj: type | None) -> bool: + """Check if obj is a list type.""" + return _is_type(obj, list) + + +def _is_list_basemodel_type(obj: type | None) -> bool: + """Check if a type is list[pydantic.BaseModel].""" + return (get_origin(obj) is list) and all( + issubclass(cls, BaseModel) for cls in get_args(obj) + ) + + +def _collect_values_from_bindings( + binding_name: str, + bindings: Iterable[dict], + predicate: Callable[[Any], bool] = lambda x: x is not None, +) -> list: + """Scan bindings for a key binding_name and 
collect unique predicate-compliant values. + + Note that element order is important for testing, so a set cast won't do. + """ + values = dict.fromkeys( + value + for binding in bindings + if predicate(value := binding.get(binding_name, None)) + ) + return list(values) + + +def _get_key_from_metadata(v: FieldInfo, *, default: Any) -> str | Any: + """Try to get a SPARQLBinding object from a field's metadata attribute. + + Helper for _generate_binding_pairs. + """ + return next(filter(lambda x: isinstance(x, SPARQLBinding), v.metadata), default) + + +def _get_applicable_grouping_keys(model: type[_TModelInstance]) -> list[str]: + return [k for k, v in model.model_fields.items() if not _is_list_type(v.annotation)] + + +def _get_group_by(model: type[_TModelInstance]) -> str: + """Get the name of a grouping key from a model Config class.""" + try: + group_by = model.model_config["group_by"] # type: ignore + except KeyError as e: + raise MissingModelConfigException( + "Model config with 'group_by' value required " + "for field-based grouping behavior." + ) from e + else: + applicable_keys = _get_applicable_grouping_keys(model=model) + + if group_by not in applicable_keys: + raise InvalidGroupingKeyException( + f"Invalid grouping key '{group_by}'. " + f"Applicable grouping keys: {', '.join(applicable_keys)}." + ) + + if meta := model.model_fields[group_by].metadata: + if binding := next( + filter(lambda entry: isinstance(entry, SPARQLBinding), meta), None + ): + return binding + return group_by + + +def default_model_bool_predicate(model: BaseModel) -> bool: + """Default predicate for determining model truthiness. + + Adheres to rdfproxy.utils._types.ModelBoolPredicate. + """ + return any(dict(model).values()) + + +def _is_iterable_of_str(iterable: Iterable) -> TypeGuard[Iterable[str]]: + return (not isinstance(iterable, str)) and all( + map(lambda i: isinstance(i, str), iterable) + ) + + +def _get_model_bool_predicate_from_config_value( + model_bool_value: _TModelBoolValue, +) -> ModelBoolPredicate: + """Get a model_bool predicate function given the value of the model_bool config setting.""" + match model_bool_value: + case ModelBoolPredicate(): + return model_bool_value + case str(): + return lambda model: bool(dict(model)[model_bool_value]) + case model_bool_value if _is_iterable_of_str(model_bool_value): + return lambda model: all(map(lambda k: dict(model)[k], model_bool_value)) + case _: # pragma: no cover + raise TypeError( + "Argument for 'model_bool' must be of type ModelBoolPredicate | str | Iterable[str].\n" + f"Received {type(model_bool_value)}" + ) + + +def get_model_bool_predicate(model: BaseModel) -> ModelBoolPredicate: + """Get the applicable model_bool predicate function given a model.""" + if (model_bool_value := model.model_config.get("model_bool", None)) is None: + model_bool_predicate = default_model_bool_predicate + else: + model_bool_predicate = _get_model_bool_predicate_from_config_value( + model_bool_value + ) + + return model_bool_predicate diff --git a/rdfproxy/utils/sparql_utils.py b/rdfproxy/utils/sparql_utils.py index a064898..4eafea3 100644 --- a/rdfproxy/utils/sparql_utils.py +++ b/rdfproxy/utils/sparql_utils.py @@ -1,19 +1,13 @@ """Functionality for dynamic SPARQL query modifcation.""" -from collections.abc import Iterator -from contextlib import contextmanager -from functools import partial +from itertools import chain import re -from typing import cast +from typing import overload -from SPARQLWrapper import QueryResult, SPARQLWrapper +from rdflib import Variable 
+from rdflib.plugins.sparql.parser import parseQuery +from rdflib.plugins.sparql.parserutils import CompValue, ParseResults from rdfproxy.utils._exceptions import QueryConstructionException -from rdfproxy.utils._types import ItemsQueryConstructor, SPARQLBinding, _TModelInstance - - -def construct_ungrouped_pagination_query(query: str, limit: int, offset: int) -> str: - """Construct an ungrouped pagination query.""" - return f"{query} limit {limit} offset {offset}" def replace_query_select_clause(query: str, repl: str) -> str: @@ -23,7 +17,7 @@ def replace_query_select_clause(query: str, repl: str) -> str: ) if re.search(pattern=pattern, string=query) is None: - raise Exception("Unable to obtain SELECT clause.") + raise QueryConstructionException("Unable to obtain SELECT clause.") modified_query = re.sub( pattern=pattern, @@ -35,148 +29,90 @@ def replace_query_select_clause(query: str, repl: str) -> str: return modified_query -def _remove_sparql_prefixes(query: str) -> str: +def remove_sparql_prefixes(query: str) -> str: """Remove SPARQL prefixes from a query. This is needed for subquery injection, because subqueries cannot have prefixes. - Note that this is not generic, all prefixes are simply ut from the subquery - and do not get appended to the outer query prefixes. + Note that this is not generic, all prefixes are simply cut from the subquery + and are not resolved against the outer query prefixes. """ - prefix_pattern = re.compile(r"PREFIX\s+\w*:\s?<[^>]+>\s*", flags=re.I) + prefix_pattern = re.compile(r"PREFIX\s+\w*:\s?<[^>]+>\s*", flags=re.IGNORECASE) cleaned_query = re.sub(prefix_pattern, "", query).strip() return cleaned_query -def inject_subquery(query: str, subquery: str) -> str: - """Inject a SPARQL query with a subquery.""" +def inject_into_query(query: str, injectant: str) -> str: + """Inject some injectant (e.g. subquery or filter clause) into a query.""" if (tail := re.search(r"}[^}]*\Z", query)) is None: - raise QueryConstructionException("Unable to inject subquery.") + raise QueryConstructionException( + "Unable to inject subquery." 
+ ) # pragma: no cover ; this will be unreachable once query checking runs tail_index: int = tail.start() - injected: str = f"{query[:tail_index]} {{{_remove_sparql_prefixes(subquery)}}} {query[tail_index:]}" - return injected + injected_query: str = f"{query[:tail_index]} {{{injectant}}} {query[tail_index:]}" + return injected_query -def construct_grouped_pagination_query( - query: str, group_by_value: str, limit: int, offset: int +def add_solution_modifier( + query: str, + *, + order_by: str | None = None, + limit: int | None = None, + offset: int | None = None, ) -> str: - """Construct a grouped pagination query.""" - _subquery_base: str = replace_query_select_clause( - query=query, repl=f"select distinct ?{group_by_value}" - ) - subquery: str = construct_ungrouped_pagination_query( - query=_subquery_base, limit=limit, offset=offset - ) - - grouped_pagination_query: str = inject_subquery(query=query, subquery=subquery) - return grouped_pagination_query + """Add optional solution modifiers in SPARQL-conformant order to a query.""" + modifiers = [] + if order_by is not None: + modifiers.append(f"order by ?{order_by}") + if limit is not None: + modifiers.append(f"limit {limit}") + if offset is not None: + modifiers.append(f"offset {offset}") -def get_items_query_constructor( - model: type[_TModelInstance], -) -> ItemsQueryConstructor: - """Get the applicable query constructor function given a model class.""" + return f"{query} {' '.join(modifiers)}".strip() - if (group_by_value := model.model_config.get("group_by", None)) is None: - return construct_ungrouped_pagination_query - elif meta := model.model_fields[group_by_value].metadata: - group_by_value = next( - filter(lambda x: isinstance(x, SPARQLBinding), meta), group_by_value - ) +@overload +def _compvalue_to_dict(comp_value: dict | CompValue) -> dict: ... - return partial(construct_grouped_pagination_query, group_by_value=group_by_value) +@overload +def _compvalue_to_dict(comp_value: list | ParseResults) -> list: ... -def construct_items_query( - query: str, model: type[_TModelInstance], limit: int, offset: int -) -> str: - """Construct a grouped pagination query.""" - items_query_constructor: ItemsQueryConstructor = get_items_query_constructor( - model=model - ) - return items_query_constructor(query=query, limit=limit, offset=offset) - - -def construct_count_query(query: str, model: type[_TModelInstance]) -> str: - """Construct a generic count query from a SELECT query.""" - try: - group_by: str = model.model_config["group_by"] - group_by_binding = next( - filter( - lambda x: isinstance(x, SPARQLBinding), - model.model_fields[group_by].metadata, - ), - group_by, - ) - count_query = construct_grouped_count_query(query, group_by_binding) - except KeyError: - count_query = replace_query_select_clause(query, "select (count(*) as ?cnt)") - - return count_query - - -def calculate_offset(page: int, size: int) -> int: - """Calculate offset value for paginated SPARQL templates.""" - match page: - case 1: - return 0 - case 2: - return size - case _: - return size * (page - 1) - - -def construct_grouped_count_query(query: str, group_by) -> str: - grouped_count_query = replace_query_select_clause( - query, f"select (count(distinct ?{group_by}) as ?cnt)" - ) - return grouped_count_query +def _compvalue_to_dict(comp_value: CompValue): + """Convert a CompValue parsing object into a Python dict/list representation. 
- -def _get_bindings_from_bindings_dict(bindings_dict: dict) -> Iterator[dict]: - bindings = map( - lambda binding: {k: v["value"] for k, v in binding.items()}, - bindings_dict["results"]["bindings"], - ) - return bindings - - -def get_bindings_from_query_result(query_result: QueryResult) -> Iterator[dict]: - """Extract just the bindings from a SPARQLWrapper.QueryResult.""" - if (result_format := query_result.requestedFormat) != "json": - raise Exception( - "Only QueryResult objects with JSON format are currently supported. " - f"Received object with requestedFormat '{result_format}'." - ) - - query_json: dict = cast(dict, query_result.convert()) - bindings = _get_bindings_from_bindings_dict(query_json) - - return bindings - - -@contextmanager -def temporary_query_override(sparql_wrapper: SPARQLWrapper): - """Context manager that allows to contextually overwrite a query in a SPARQLWrapper object.""" - _query_cache = sparql_wrapper.queryString - - try: - yield sparql_wrapper - finally: - sparql_wrapper.setQuery(_query_cache) + Helper for get_query_projection. + """ + if isinstance(comp_value, dict): + return {key: _compvalue_to_dict(value) for key, value in comp_value.items()} + elif isinstance(comp_value, list | ParseResults): + return [_compvalue_to_dict(item) for item in comp_value] + else: + return comp_value -def query_with_wrapper(query: str, sparql_wrapper: SPARQLWrapper) -> Iterator[dict]: - """Execute a SPARQL query using a predefined sparql_wrapper object. +def get_query_projection(query: str) -> list[Variable]: + """Parse a SPARQL SELECT query and extract the ordered bindings projection. - The query attribute of the wrapper object is temporarily overridden - and gets restored after query execution. + The first case handles explicit/literal binding projections. + The second case handles implicit/* binding projections. + The third case handles implicit/* binding projections with VALUES. 
""" - with temporary_query_override(sparql_wrapper=sparql_wrapper): - sparql_wrapper.setQuery(query) - result: QueryResult = sparql_wrapper.query() - - bindings: Iterator[dict] = get_bindings_from_query_result(result) - return bindings + _parse_result: CompValue = parseQuery(query)[1] + parsed_query: dict = _compvalue_to_dict(_parse_result) + + match parsed_query: + case {"projection": projection}: + return [i["var"] for i in projection] + case {"where": {"part": [{"triples": triples}]}}: + projection = dict.fromkeys( + i for i in chain.from_iterable(triples) if isinstance(i, Variable) + ) + return list(projection) + case {"where": {"part": [{"var": var}]}}: + return var + case _: # pragma: no cover + raise Exception("Unable to obtain query projection.") diff --git a/rdfproxy/utils/utils.py b/rdfproxy/utils/utils.py index e7e4c12..279d19c 100644 --- a/rdfproxy/utils/utils.py +++ b/rdfproxy/utils/utils.py @@ -1,129 +1,72 @@ """SPARQL/FastAPI utils.""" -from collections.abc import Callable, Iterable -from typing import Any, TypeGuard, get_args, get_origin - -from pydantic import BaseModel -from pydantic.fields import FieldInfo -from rdfproxy.utils._exceptions import ( - InvalidGroupingKeyException, - MissingModelConfigException, -) -from rdfproxy.utils._types import _TModelInstance -from rdfproxy.utils._types import ModelBoolPredicate, SPARQLBinding, _TModelBoolValue - - -def _is_type(obj: type | None, _type: type) -> bool: - """Check if an obj is type _type or a GenericAlias with origin _type.""" - return (obj is _type) or (get_origin(obj) is _type) +from collections import UserDict +from collections.abc import Callable +from functools import partial +from typing import TypeVar +from rdfproxy.utils._types import _TModelInstance +from rdfproxy.utils._types import SPARQLBinding -def _is_list_type(obj: type | None) -> bool: - """Check if obj is a list type.""" - return _is_type(obj, list) +T = TypeVar("T") -def _is_list_basemodel_type(obj: type | None) -> bool: - """Check if a type is list[pydantic.BaseModel].""" - return (get_origin(obj) is list) and all( - issubclass(cls, BaseModel) for cls in get_args(obj) - ) +class FieldsBindingsMap(UserDict): + """Mapping for resolving SPARQLBinding aliases. -def _collect_values_from_bindings( - binding_name: str, - bindings: Iterable[dict], - predicate: Callable[[Any], bool] = lambda x: x is not None, -) -> list: - """Scan bindings for a key binding_name and collect unique predicate-compliant values. + Model field names are mapped to SPARQLBinding names. + The FieldsBindingsMap.reverse allows reverse lookup + (i.e. from SPARQLBindings to model fields). - Note that element order is important for testing, so a set cast won't do. + Note: It might be useful to recursively resolve aliases for nested models. """ - values = dict.fromkeys( - value - for binding in bindings - if predicate(value := binding.get(binding_name, None)) - ) - return list(values) + def __init__(self, model: type[_TModelInstance]) -> None: + self.data = self._get_field_binding_mapping(model) + self._reversed = {v: k for k, v in self.data.items()} -def _get_key_from_metadata(v: FieldInfo, *, default: Any) -> str | Any: - """Try to get a SPARQLBinding object from a field's metadata attribute. + @property + def reverse(self) -> dict[str, str]: + """Reverse lookup map from SPARQL bindings to model fields.""" + return self._reversed - Helper for _generate_binding_pairs. 
- """ - return next(filter(lambda x: isinstance(x, SPARQLBinding), v.metadata), default) + @staticmethod + def _get_field_binding_mapping(model: type[_TModelInstance]) -> dict[str, str]: + """Resolve model fields against rdfproxy.SPARQLBindings.""" + return { + k: next(filter(lambda x: isinstance(x, SPARQLBinding), v.metadata), k) + for k, v in model.model_fields.items() + } -def _get_applicable_grouping_keys(model: type[_TModelInstance]) -> list[str]: - return [k for k, v in model.model_fields.items() if not _is_list_type(v.annotation)] +def compose_left(*fns: Callable[[T], T]) -> Callable[[T], T]: + """Left associative compose.""" + def _left_wrapper(*fns): + fn, *rest_fns = fns -def _get_group_by(model: type[_TModelInstance]) -> str: - """Get the name of a grouping key from a model Config class.""" - try: - group_by = model.model_config["group_by"] # type: ignore - except KeyError as e: - raise MissingModelConfigException( - "Model config with 'group_by' value required " - "for field-based grouping behavior." - ) from e - else: - applicable_keys = _get_applicable_grouping_keys(model=model) + if rest_fns: + return lambda *args, **kwargs: fn(_left_wrapper(*rest_fns)(*args, **kwargs)) + return fn - if group_by not in applicable_keys: - raise InvalidGroupingKeyException( - f"Invalid grouping key '{group_by}'. " - f"Applicable grouping keys: {', '.join(applicable_keys)}." - ) + return _left_wrapper(*reversed(fns)) - if meta := model.model_fields[group_by].metadata: - if binding := next( - filter(lambda entry: isinstance(entry, SPARQLBinding), meta), None - ): - return binding - return group_by +class QueryConstructorComponent: + """Query modification component factory. -def default_model_bool_predicate(model: BaseModel) -> bool: - """Default predicate for determining model truthiness. + Components either call the wrapped function with non-None value kwargs applied + or (if all kwargs values are None) fall back to the identity function. - Adheres to rdfproxy.utils._types.ModelBoolPredicate. + QueryConstructorComponents are used in QueryConstructor for query modification compose chains. 
""" - return any(dict(model).values()) - - -def _is_iterable_of_str(iterable: Iterable) -> TypeGuard[Iterable[str]]: - return (not isinstance(iterable, str)) and all( - map(lambda i: isinstance(i, str), iterable) - ) - - -def _get_model_bool_predicate_from_config_value( - model_bool_value: _TModelBoolValue, -) -> ModelBoolPredicate: - """Get a model_bool predicate function given the value of the model_bool config setting.""" - match model_bool_value: - case ModelBoolPredicate(): - return model_bool_value - case str(): - return lambda model: bool(dict(model)[model_bool_value]) - case model_bool_value if _is_iterable_of_str(model_bool_value): - return lambda model: all(map(lambda k: dict(model)[k], model_bool_value)) - case _: - raise TypeError( - "Argument for 'model_bool' must be of type ModelBoolPredicate | str | Iterable[str].\n" - f"Received {type(model_bool_value)}" - ) - - -def get_model_bool_predicate(model: BaseModel) -> ModelBoolPredicate: - """Get the applicable model_bool predicate function given a model.""" - if (model_bool_value := model.model_config.get("model_bool", None)) is None: - model_bool_predicate = default_model_bool_predicate - else: - model_bool_predicate = _get_model_bool_predicate_from_config_value( - model_bool_value - ) - - return model_bool_predicate + + def __init__(self, f: Callable[..., str], **kwargs) -> None: + self.f = f + self.kwargs = kwargs + + def __call__(self, query) -> str: + if tkwargs := {k: v for k, v in self.kwargs.items() if v is not None}: + return partial(self.f, **tkwargs)(query) + return query From 5c643ad872102da7ade48acce04f516a3aa36b56 Mon Sep 17 00:00:00 2001 From: Lukas Plank Date: Mon, 16 Dec 2024 11:59:32 +0100 Subject: [PATCH 3/3] chore(deps): install rdflib --- poetry.lock | 43 +++++++++---------------------------------- pyproject.toml | 1 + 2 files changed, 10 insertions(+), 34 deletions(-) diff --git a/poetry.lock b/poetry.lock index f5405f9..b9320b8 100644 --- a/poetry.lock +++ b/poetry.lock @@ -361,20 +361,6 @@ files = [ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] -[[package]] -name = "isodate" -version = "0.6.1" -description = "An ISO 8601 date/time/duration parser and formatter" -optional = false -python-versions = "*" -files = [ - {file = "isodate-0.6.1-py2.py3-none-any.whl", hash = "sha256:0751eece944162659049d35f4f549ed815792b38793f07cf73381c1c87cbed96"}, - {file = "isodate-0.6.1.tar.gz", hash = "sha256:48c5881de7e8b0a0d648cb024c8062dc84e7b840ed81e864c7614fd3c127bde9"}, -] - -[package.dependencies] -six = "*" - [[package]] name = "jinja2" version = "3.1.4" @@ -818,24 +804,24 @@ files = [ [[package]] name = "rdflib" -version = "7.0.0" +version = "7.1.1" description = "RDFLib is a Python library for working with RDF, a simple yet powerful language for representing information." 
optional = false -python-versions = ">=3.8.1,<4.0.0" +python-versions = "<4.0.0,>=3.8.1" files = [ - {file = "rdflib-7.0.0-py3-none-any.whl", hash = "sha256:0438920912a642c866a513de6fe8a0001bd86ef975057d6962c79ce4771687cd"}, - {file = "rdflib-7.0.0.tar.gz", hash = "sha256:9995eb8569428059b8c1affd26b25eac510d64f5043d9ce8c84e0d0036e995ae"}, + {file = "rdflib-7.1.1-py3-none-any.whl", hash = "sha256:e590fa9a2c34ba33a667818b5a84be3fb8a4d85868f8038f17912ec84f912a25"}, + {file = "rdflib-7.1.1.tar.gz", hash = "sha256:164de86bd3564558802ca983d84f6616a4a1a420c7a17a8152f5016076b2913e"}, ] [package.dependencies] -isodate = ">=0.6.0,<0.7.0" pyparsing = ">=2.1.0,<4" [package.extras] berkeleydb = ["berkeleydb (>=18.1.0,<19.0.0)"] -html = ["html5lib (>=1.0,<2.0)"] -lxml = ["lxml (>=4.3.0,<5.0.0)"] -networkx = ["networkx (>=2.0.0,<3.0.0)"] +html = ["html5rdf (>=1.2,<2)"] +lxml = ["lxml (>=4.3,<6.0)"] +networkx = ["networkx (>=2,<4)"] +orjson = ["orjson (>=3.9.14,<4)"] [[package]] name = "rich" @@ -893,17 +879,6 @@ files = [ {file = "shellingham-1.5.4.tar.gz", hash = "sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de"}, ] -[[package]] -name = "six" -version = "1.16.0" -description = "Python 2 and 3 compatibility utilities" -optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" -files = [ - {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, - {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, -] - [[package]] name = "sniffio" version = "1.3.1" @@ -1246,4 +1221,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "e9ac3d16b289eb2f29bfa0587a17105664b9951db114dfe9d2e9a35b1265e117" +content-hash = "0a2e465322bae2eaee949d9268fb74e79735a2dd33e1d7f7b390f2b21d797124" diff --git a/pyproject.toml b/pyproject.toml index e245608..b2ef8ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ pydantic = "^2.9.2" httpx = "^0.28.1" +rdflib = "^7.1.1" [tool.poetry.group.dev.dependencies] ruff = "^0.7.0" deptry = "^0.20.0"
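
Note: rdflib becomes a runtime dependency because the new
get_query_projection helper relies on rdflib's SPARQL parser. A minimal
sketch of that usage (the variable names are illustrative only):

    from rdflib.plugins.sparql.parser import parseQuery

    # parseQuery yields a parse result whose second element is the
    # SelectQuery CompValue that get_query_projection pattern-matches on.
    select_comp_value = parseQuery("select ?s ?p ?o where {?s ?p ?o}")[1]
    print(select_comp_value["projection"])  # the explicit ?s ?p ?o projection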