Skip to content

Commit

Permalink
refactor: PreparedSearch and RawSearchResult usage (#1191)
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrunato authored Jun 8, 2024
1 parent 8360a41 commit 2c37db3
Show file tree
Hide file tree
Showing 16 changed files with 404 additions and 371 deletions.
23 changes: 15 additions & 8 deletions eodag/api/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
provider_config_init,
)
from eodag.plugins.manager import PluginManager
from eodag.plugins.search import PreparedSearch
from eodag.plugins.search.build_search_result import BuildPostSearchResult
from eodag.types import model_fields_to_annotated
from eodag.types.queryables import CommonQueryables
Expand Down Expand Up @@ -729,6 +730,7 @@ def discover_product_types(
else self.available_providers()
)
]
kwargs: Dict[str, Any] = {}
for provider in providers_to_fetch:
if hasattr(self.providers_config[provider], "search"):
search_plugin_config = self.providers_config[provider].search
Expand All @@ -749,7 +751,7 @@ def discover_product_types(
getattr(auth_plugin, "authenticate", None)
):
try:
search_plugin.auth = auth_plugin.authenticate()
kwargs["auth"] = auth_plugin.authenticate()
except (AuthenticationError, MisconfiguredError) as e:
logger.warning(
f"Could not authenticate on {provider}: {str(e)}"
Expand All @@ -763,9 +765,9 @@ def discover_product_types(
ext_product_types_conf[provider] = None
continue

ext_product_types_conf[
provider
] = search_plugin.discover_product_types()
ext_product_types_conf[provider] = search_plugin.discover_product_types(
**kwargs
)

return ext_product_types_conf

Expand Down Expand Up @@ -1811,16 +1813,21 @@ def _do_search(
total_results: Optional[int] = 0

try:
prep = PreparedSearch(count=count)
if need_auth and auth_plugin and can_authenticate:
search_plugin.auth = auth_plugin.authenticate()
prep.auth = auth_plugin.authenticate()

prep.auth_plugin = auth_plugin
prep.page = kwargs.pop("page", None)
prep.items_per_page = kwargs.pop("items_per_page", None)

res, nb_res = search_plugin.query(count=count, auth=auth_plugin, **kwargs)
res, nb_res = search_plugin.query(prep, **kwargs)

# Only do the pagination computations when it makes sense. For example,
# for a search by id, we can reasonably guess that the provider will return
# at most 1 product, so we don't need such a thing as pagination
page = kwargs.get("page")
items_per_page = kwargs.get("items_per_page")
page = prep.page
items_per_page = prep.items_per_page
if page and items_per_page and count:
# Take into account the fact that a provider may not return the count of
# products (in that case, fallback to using the length of the results it
Expand Down
15 changes: 15 additions & 0 deletions eodag/api/search_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,18 @@ def __geo_interface__(self) -> Dict[str, Any]:
See https://gist.github.com/sgillies/2217756
"""
return self.as_geojson_object()


class RawSearchResult(UserList):
    """An object representing a collection of raw/unparsed search results obtained from a provider.

    Besides the list content (available through ``data``), an instance carries
    two mappings, ``query_params`` and ``product_type_def_params``. Both start
    out as empty dictionaries and can be reassigned or filled after
    construction.

    :param results: A list of raw/unparsed search results
    :type results: List[Any]
    """

    data: List[Any]
    query_params: Dict[str, Any]
    product_type_def_params: Dict[str, Any]

    def __init__(self, results: List[Any]) -> None:
        super(RawSearchResult, self).__init__(results)
        # The class-level annotations above do not create the attributes.
        # Give each instance its own empty mappings so that reading them
        # before they have been filled does not raise AttributeError.
        self.query_params = {}
        self.product_type_def_params = {}
12 changes: 3 additions & 9 deletions eodag/plugins/apis/ecmwf.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,13 @@
from ecmwfapi.api import APIException, Connection, get_apikey_values

from eodag.plugins.apis.base import Api
from eodag.plugins.search import PreparedSearch
from eodag.plugins.search.base import Search
from eodag.plugins.search.build_search_result import BuildPostSearchResult
from eodag.utils import (
DEFAULT_DOWNLOAD_TIMEOUT,
DEFAULT_DOWNLOAD_WAIT,
DEFAULT_ITEMS_PER_PAGE,
DEFAULT_MISSION_START_DATE,
DEFAULT_PAGE,
get_geometry_from_various,
path_to_uri,
sanitize,
Expand Down Expand Up @@ -90,10 +89,7 @@ def do_search(self, *args: Any, **kwargs: Any) -> List[Dict[str, Any]]:

def query(
self,
product_type: Optional[str] = None,
items_per_page: int = DEFAULT_ITEMS_PER_PAGE,
page: int = DEFAULT_PAGE,
count: bool = True,
prep: PreparedSearch = PreparedSearch(),
**kwargs: Any,
) -> Tuple[List[EOProduct], Optional[int]]:
"""Build ready-to-download SearchResult"""
Expand Down Expand Up @@ -126,9 +122,7 @@ def query(
if "geometry" in kwargs:
kwargs["geometry"] = get_geometry_from_various(geometry=kwargs["geometry"])

return BuildPostSearchResult.query(
self, items_per_page=items_per_page, page=page, count=count, **kwargs
)
return BuildPostSearchResult.query(self, prep, **kwargs)

def authenticate(self) -> Dict[str, Optional[str]]:
"""Check credentials and return information needed for auth
Expand Down
12 changes: 8 additions & 4 deletions eodag/plugins/apis/usgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
properties_from_json,
)
from eodag.plugins.apis.base import Api
from eodag.plugins.search import PreparedSearch
from eodag.utils import (
DEFAULT_DOWNLOAD_TIMEOUT,
DEFAULT_DOWNLOAD_WAIT,
Expand Down Expand Up @@ -112,13 +113,16 @@ def authenticate(self) -> None:

def query(
self,
product_type: Optional[str] = None,
items_per_page: int = DEFAULT_ITEMS_PER_PAGE,
page: int = DEFAULT_PAGE,
count: bool = True,
prep: PreparedSearch = PreparedSearch(),
**kwargs: Any,
) -> Tuple[List[EOProduct], Optional[int]]:
"""Search for data on USGS catalogues"""
page = prep.page if prep.page is not None else DEFAULT_PAGE
items_per_page = (
prep.items_per_page
if prep.items_per_page is not None
else DEFAULT_ITEMS_PER_PAGE
)
product_type = kwargs.get("productType")
if product_type is None:
raise NoMatchingProductType(
Expand Down
36 changes: 36 additions & 0 deletions eodag/plugins/search/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,39 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""EODAG search package"""
from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING

from eodag.utils import DEFAULT_ITEMS_PER_PAGE, DEFAULT_PAGE

if TYPE_CHECKING:
from typing import Any, Dict, List, Optional, Union

from requests.auth import AuthBase

from eodag.plugins.authentication.base import Authentication


@dataclass
class PreparedSearch:
    """An object collecting needed information for search.

    The first group of fields can be given to the constructor. The second
    group is declared with ``init=False`` and without a default value: these
    attributes do not exist on a new instance until they are assigned, and
    reading one of them before that raises ``AttributeError``.
    """

    # Fields accepted by the constructor.
    product_type: Optional[str] = None
    page: Optional[int] = DEFAULT_PAGE
    items_per_page: Optional[int] = DEFAULT_ITEMS_PER_PAGE
    # Authentication data to use for the search, when available.
    auth: Optional[Union[AuthBase, Dict[str, str]]] = None
    # Authentication plugin associated with the search, when available.
    auth_plugin: Optional[Authentication] = None
    # Whether the total number of matching products is requested.
    count: bool = True
    # Target URL and log/error message prefixes for a single request.
    url: Optional[str] = None
    info_message: Optional[str] = None
    exception_message: Optional[str] = None

    # Working state assigned after construction (no default on purpose, so
    # ``hasattr`` can still tell whether a value has been computed).
    # ``repr=False`` keeps the generated ``__repr__`` from reading attributes
    # that may not be set yet; without it, ``repr()`` of a fresh instance
    # raises ``AttributeError``.
    # NOTE(review): the generated ``__eq__`` still reads these fields, so
    # comparing instances on which they are unset will raise as well.
    need_count: bool = field(init=False, repr=False)
    query_params: Dict[str, Any] = field(init=False, repr=False)
    query_string: str = field(init=False, repr=False)
    search_urls: List[str] = field(init=False, repr=False)
    product_type_def_params: Dict[str, Any] = field(init=False, repr=False)
    total_items_nb: int = field(init=False, repr=False)
    sort_by_qs: str = field(init=False, repr=False)
10 changes: 3 additions & 7 deletions eodag/plugins/search/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,11 @@
mtd_cfg_as_conversion_and_querypath,
)
from eodag.plugins.base import PluginTopic
from eodag.plugins.search import PreparedSearch
from eodag.types import model_fields_to_annotated
from eodag.types.queryables import Queryables
from eodag.types.search_args import SortByList
from eodag.utils import (
DEFAULT_ITEMS_PER_PAGE,
DEFAULT_PAGE,
GENERIC_PRODUCT_TYPE,
Annotated,
copy_deepcopy,
Expand Down Expand Up @@ -93,17 +92,14 @@ def clear(self) -> None:

def query(
self,
product_type: Optional[str] = None,
items_per_page: int = DEFAULT_ITEMS_PER_PAGE,
page: int = DEFAULT_PAGE,
count: bool = True,
prep: PreparedSearch = PreparedSearch(),
**kwargs: Any,
) -> Tuple[List[EOProduct], Optional[int]]:
"""Implementation of how the products must be searched goes here.
This method must return a tuple with (1) a list of EOProduct instances (see eodag.api.product module)
which will be processed by a Download plugin (2) and the total number of products matching
the search criteria. If ``count`` is False, the second element returned must be ``None``.
the search criteria. If ``prep.count`` is False, the second element returned must be ``None``.
"""
raise NotImplementedError("A Search plugin must implement a method named query")

Expand Down
48 changes: 21 additions & 27 deletions eodag/plugins/search/build_search_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,14 @@
mtd_cfg_as_conversion_and_querypath,
properties_from_json,
)
from eodag.api.search_result import RawSearchResult
from eodag.plugins.search import PreparedSearch
from eodag.plugins.search.base import Search
from eodag.plugins.search.qssearch import PostJsonSearch
from eodag.types import json_field_definition_to_python, model_fields_to_annotated
from eodag.types.queryables import CommonQueryables
from eodag.utils import (
DEFAULT_ITEMS_PER_PAGE,
DEFAULT_MISSION_START_DATE,
DEFAULT_PAGE,
Annotated,
deepcopy,
dict_items_recursive_sort,
Expand Down Expand Up @@ -98,30 +98,29 @@ def count_hits(

def collect_search_urls(
self,
page: Optional[int] = None,
items_per_page: Optional[int] = None,
count: bool = True,
prep: PreparedSearch = PreparedSearch(),
**kwargs: Any,
) -> Tuple[List[str], int]:
"""Wraps PostJsonSearch.collect_search_urls to force product count to 1"""
urls, _ = super(BuildPostSearchResult, self).collect_search_urls(
page=page, items_per_page=items_per_page, count=count, **kwargs
)
urls, _ = super(BuildPostSearchResult, self).collect_search_urls(prep, **kwargs)
return urls, 1

def do_search(self, *args: Any, **kwargs: Any) -> List[Dict[str, Any]]:
def do_search(
self, prep: PreparedSearch = PreparedSearch(items_per_page=None), **kwargs: Any
) -> List[Dict[str, Any]]:
"""Perform the actual search request, and return result in a single element."""
search_url = self.search_urls[0]
response = self._request(
search_url,
info_message=f"Sending search request: {search_url}",
exception_message=f"Skipping error while searching for {self.provider} "
f"{self.__class__.__name__} instance:",
prep.url = prep.search_urls[0]
prep.info_message = f"Sending search request: {prep.url}"
prep.exception_message = (
f"Skipping error while searching for {self.provider} "
f"{self.__class__.__name__} instance:"
)
response = self._request(prep)

return [response.json()]

def normalize_results(
self, results: List[Dict[str, Any]], **kwargs: Any
self, results: RawSearchResult, **kwargs: Any
) -> List[EOProduct]:
"""Build :class:`~eodag.api.product._product.EOProduct` from provider result
Expand All @@ -145,15 +144,15 @@ def normalize_results(
# update result with query parameters without pagination (or search-only params)
if isinstance(
self.config.pagination["next_page_query_obj"], str
) and hasattr(self, "query_params_unpaginated"):
unpaginated_query_params = self.query_params_unpaginated
) and hasattr(results, "query_params_unpaginated"):
unpaginated_query_params = results.query_params_unpaginated
elif isinstance(self.config.pagination["next_page_query_obj"], str):
next_page_query_obj = orjson.loads(
self.config.pagination["next_page_query_obj"].format()
)
unpaginated_query_params = {
k: v[0] if (isinstance(v, list) and len(v) == 1) else v
for k, v in self.query_params.items()
for k, v in results.query_params.items()
if (k, v) not in next_page_query_obj.items()
}
else:
Expand Down Expand Up @@ -181,7 +180,7 @@ def normalize_results(

# update result with product_type_def_params and search args if not None (and not auth)
kwargs.pop("auth", None)
result.update(self.product_type_def_params)
result.update(results.product_type_def_params)
result = dict(result, **{k: v for k, v in kwargs.items() if v is not None})

# parse properties
Expand Down Expand Up @@ -314,19 +313,14 @@ def do_search(self, *args: Any, **kwargs: Any) -> List[Dict[str, Any]]:

def query(
self,
product_type: Optional[str] = None,
items_per_page: int = DEFAULT_ITEMS_PER_PAGE,
page: int = DEFAULT_PAGE,
count: bool = True,
prep: PreparedSearch = PreparedSearch(),
**kwargs: Any,
) -> Tuple[List[EOProduct], Optional[int]]:
"""Build ready-to-download SearchResult"""

self._preprocess_search_params(kwargs)

return BuildPostSearchResult.query(
self, items_per_page=items_per_page, page=page, count=count, **kwargs
)
return BuildPostSearchResult.query(self, prep, **kwargs)

def clear(self) -> None:
"""Clear search context"""
Expand Down
5 changes: 3 additions & 2 deletions eodag/plugins/search/creodias_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@
# limitations under the License.
import logging
from types import MethodType
from typing import Any, Dict, List
from typing import Any, List

import boto3
import botocore
from botocore.exceptions import BotoCoreError

from eodag.api.product import AssetsDict, EOProduct # type: ignore
from eodag.api.search_result import RawSearchResult
from eodag.config import PluginConfig
from eodag.plugins.authentication.aws_auth import AwsAuth
from eodag.plugins.search.qssearch import ODataV4Search
Expand Down Expand Up @@ -119,7 +120,7 @@ def __init__(self, provider, config):
super(CreodiasS3Search, self).__init__(provider, config)

def normalize_results(
self, results: List[Dict[str, Any]], **kwargs: Any
self, results: RawSearchResult, **kwargs: Any
) -> List[EOProduct]:
"""Build EOProducts from provider results"""

Expand Down
10 changes: 4 additions & 6 deletions eodag/plugins/search/csw.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@

from eodag.api.product import EOProduct
from eodag.api.product.metadata_mapping import properties_from_xml
from eodag.plugins.search import PreparedSearch
from eodag.plugins.search.base import Search
from eodag.utils import DEFAULT_ITEMS_PER_PAGE, DEFAULT_PAGE, DEFAULT_PROJ
from eodag.utils import DEFAULT_PROJ
from eodag.utils.import_system import patch_owslib_requests

if TYPE_CHECKING:
Expand Down Expand Up @@ -64,10 +65,7 @@ def clear(self) -> None:

def query(
self,
product_type: Optional[str] = None,
items_per_page: int = DEFAULT_ITEMS_PER_PAGE,
page: int = DEFAULT_PAGE,
count: bool = True,
prep: PreparedSearch = PreparedSearch(),
**kwargs: Any,
) -> Tuple[List[EOProduct], Optional[int]]:
"""Perform a search on a OGC/CSW-like interface"""
Expand Down Expand Up @@ -118,7 +116,7 @@ def query(
)
results.extend(partial_results)
logger.info("Found %s overall results", len(results))
total_results = len(results) if count else None
total_results = len(results) if prep.count else None
return results, total_results

def __init_catalog(
Expand Down
Loading

0 comments on commit 2c37db3

Please sign in to comment.