Skip to content

Commit

Permalink
Fix validate_api_response config
Browse files Browse the repository at this point in the history
  • Loading branch information
ml-evs committed Jul 21, 2023
1 parent f36eeeb commit 1a64f86
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 42 deletions.
20 changes: 5 additions & 15 deletions optimade/server/entry_collections/entry_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,22 +137,15 @@ def count(self, **kwargs: Any) -> int:

def find(
self, params: Union[EntryListingQueryParams, SingleEntryQueryParams]
) -> Tuple[
Union[None, List[EntryResource], EntryResource, List[Dict]],
int,
bool,
Set[str],
Set[str],
]:
) -> Tuple[Union[None, Dict, List[Dict]], int, bool, Set[str], Set[str],]:
"""
Fetches results and indicates if more data is available.
Also gives the total number of data available in the absence of `page_limit`.
See [`EntryListingQueryParams`][optimade.server.query_params.EntryListingQueryParams]
for more information.
Returns either the list of validated pydantic models matching the query, or simply the
mapped database reponse, depending on the value of `CONFIG.validate_api_response`.
Returns a list of the mapped database reponse.
If no results match the query, then `results` is set to `None`.
Expand Down Expand Up @@ -202,13 +195,10 @@ def find(
detail=f"Unrecognised OPTIMADE field(s) in requested `response_fields`: {bad_optimade_fields}."
)

results: Union[None, List[EntryResource], EntryResource, List[Dict]] = None
results: Union[None, List[Dict], Dict] = None

if raw_results:
if CONFIG.validate_api_response:
results = self.resource_mapper.deserialize(raw_results)
else:
results = [self.resource_mapper.map_back(doc) for doc in raw_results]
results = [self.resource_mapper.map_back(doc) for doc in raw_results]

if single_entry:
results = results[0] # type: ignore[assignment]
Expand Down Expand Up @@ -468,7 +458,7 @@ def parse_sort_params(self, sort_params: str) -> Iterable[Tuple[str, int]]:
def get_next_query_params(
self,
params: EntryListingQueryParams,
results: Union[None, List[EntryResource], EntryResource, List[Dict]],
results: Union[None, Dict, List[Dict]],
) -> Dict[str, List[str]]:
"""Provides url query pagination parameters that will be used in the next
link.
Expand Down
4 changes: 2 additions & 2 deletions optimade/server/routers/links.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Dict

from fastapi import APIRouter, Depends, Request

Expand All @@ -21,7 +21,7 @@

@router.get(
"/links",
response_model=LinksResponse,
response_model=LinksResponse if CONFIG.validate_api_response else Dict,
response_model_exclude_unset=True,
tags=["Links"],
responses=ERROR_RESPONSES,
Expand Down
6 changes: 3 additions & 3 deletions optimade/server/routers/references.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Dict

from fastapi import APIRouter, Depends, Request

Expand All @@ -25,7 +25,7 @@

@router.get(
"/references",
response_model=ReferenceResponseMany,
response_model=ReferenceResponseMany if CONFIG.validate_api_response else Dict,
response_model_exclude_unset=True,
tags=["References"],
responses=ERROR_RESPONSES,
Expand All @@ -43,7 +43,7 @@ def get_references(

@router.get(
"/references/{entry_id:path}",
response_model=ReferenceResponseOne,
response_model=ReferenceResponseOne if CONFIG.validate_api_response else Dict,
response_model_exclude_unset=True,
tags=["References"],
responses=ERROR_RESPONSES,
Expand Down
6 changes: 3 additions & 3 deletions optimade/server/routers/structures.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, Dict

from fastapi import APIRouter, Depends, Request

Expand All @@ -25,7 +25,7 @@

@router.get(
"/structures",
response_model=StructureResponseMany,
response_model=StructureResponseMany if CONFIG.validate_api_response else Dict,
response_model_exclude_unset=True,
tags=["Structures"],
responses=ERROR_RESPONSES,
Expand All @@ -43,7 +43,7 @@ def get_structures(

@router.get(
"/structures/{entry_id:path}",
response_model=StructureResponseOne,
response_model=StructureResponseOne if CONFIG.validate_api_response else Dict,
response_model_exclude_unset=True,
tags=["Structures"],
responses=ERROR_RESPONSES,
Expand Down
2 changes: 1 addition & 1 deletion optimade/server/routers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def get_base_url(

def get_entries(
collection: EntryCollection,
response: Type[EntryResponseMany],
response: Type[EntryResponseMany], # noqa
request: Request,
params: EntryListingQueryParams,
) -> Dict:
Expand Down
34 changes: 16 additions & 18 deletions tests/server/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

from optimade.client.cli import _get
from optimade.server.config import CONFIG, SupportedBackend
from optimade.warnings import MissingExpectedField

try:
from optimade.client import OptimadeClient as OptimadeTestClient
Expand Down Expand Up @@ -115,24 +114,23 @@ def test_filter_validation(async_http_client, http_client, use_async):

@pytest.mark.parametrize("use_async", [True, False])
def test_client_response_fields(async_http_client, http_client, use_async):
with pytest.warns(MissingExpectedField):
cli = OptimadeClient(
base_urls=[TEST_URL],
use_async=use_async,
http_client=async_http_client if use_async else http_client,
)
results = cli.get(response_fields=["chemical_formula_reduced"])
for d in results["structures"][""][TEST_URL]["data"]:
assert "chemical_formula_reduced" in d["attributes"]
assert len(d["attributes"]) == 1
cli = OptimadeClient(
base_urls=[TEST_URL],
use_async=use_async,
http_client=async_http_client if use_async else http_client,
)
results = cli.get(response_fields=["chemical_formula_reduced"])
for d in results["structures"][""][TEST_URL]["data"]:
assert "chemical_formula_reduced" in d["attributes"]
assert len(d["attributes"]) == 1

results = cli.get(
response_fields=["chemical_formula_reduced", "cartesian_site_positions"]
)
for d in results["structures"][""][TEST_URL]["data"]:
assert "chemical_formula_reduced" in d["attributes"]
assert "cartesian_site_positions" in d["attributes"]
assert len(d["attributes"]) == 2
results = cli.get(
response_fields=["chemical_formula_reduced", "cartesian_site_positions"]
)
for d in results["structures"][""][TEST_URL]["data"]:
assert "chemical_formula_reduced" in d["attributes"]
assert "cartesian_site_positions" in d["attributes"]
assert len(d["attributes"]) == 2


@pytest.mark.parametrize("use_async", [True, False])
Expand Down

0 comments on commit 1a64f86

Please sign in to comment.