Skip to content

Commit

Permalink
fix: type hints related fixes and refactoring (#1052)
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrunato authored Mar 6, 2024
1 parent 211bdfd commit d1df25e
Show file tree
Hide file tree
Showing 29 changed files with 488 additions and 249 deletions.
39 changes: 24 additions & 15 deletions eodag/api/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
)
from eodag.utils.exceptions import (
AuthenticationError,
EodagError,
MisconfiguredError,
NoMatchingProductType,
PluginImplementationError,
Expand All @@ -105,7 +106,8 @@
from eodag.plugins.crunch.base import Crunch
from eodag.plugins.search.base import Search
from eodag.types import ProviderSortables
from eodag.utils import Annotated, DownloadedCallback, ProgressCallback
from eodag.types.download_args import DownloadConf
from eodag.utils import Annotated, DownloadedCallback, ProgressCallback, Unpack

logger = logging.getLogger("eodag.core")

Expand Down Expand Up @@ -502,10 +504,10 @@ def set_locations_conf(self, locations_conf_path: str) -> None:
locations_config = load_yml_config(locations_conf_path)

main_key = next(iter(locations_config))
locations_config = locations_config[main_key]
main_locations_config = locations_config[main_key]

logger.info("Locations configuration loaded from %s" % locations_conf_path)
self.locations_config: List[Dict[str, Any]] = locations_config
self.locations_config: List[Dict[str, Any]] = main_locations_config
else:
logger.info(
"Could not load locations configuration from %s" % locations_conf_path
Expand Down Expand Up @@ -612,9 +614,8 @@ def fetch_product_types_list(self, provider: Optional[str] = None) -> None:

if not ext_product_types_conf:
# empty ext_product_types conf
discover_kwargs = dict(provider=provider) if provider else {}
ext_product_types_conf = self.discover_product_types(
**discover_kwargs
ext_product_types_conf = (
self.discover_product_types(provider=provider) or {}
)

# update eodag product types list with new conf
Expand Down Expand Up @@ -693,8 +694,8 @@ def fetch_product_types_list(self, provider: Optional[str] = None) -> None:
# or not in ext_product_types_conf (if eodag system conf != eodag conf used for ext_product_types_conf)

# discover product types for user configured provider
provider_ext_product_types_conf = self.discover_product_types(
provider=provider
provider_ext_product_types_conf = (
self.discover_product_types(provider=provider) or {}
)

# update eodag product types list with new conf
Expand Down Expand Up @@ -738,7 +739,9 @@ def discover_product_types(
auth_plugin = self._plugins_manager.get_auth_plugin(
search_plugin.provider
)
if callable(getattr(auth_plugin, "authenticate", None)):
if auth_plugin and callable(
getattr(auth_plugin, "authenticate", None)
):
try:
search_plugin.auth = auth_plugin.authenticate()
except (AuthenticationError, MisconfiguredError) as e:
Expand Down Expand Up @@ -934,6 +937,8 @@ def guess_product_type(self, **kwargs: Any) -> List[str]:
)
if kwargs.get(param, None) is not None
}
if not self._product_types_index:
raise EodagError("Missing product types index")
with self._product_types_index.searcher() as searcher:
results = None
# For each search key, do a guess and then upgrade the result (i.e. when
Expand Down Expand Up @@ -965,7 +970,7 @@ def search(
locations: Optional[Dict[str, str]] = None,
provider: Optional[str] = None,
**kwargs: Any,
) -> Tuple[SearchResult, int]:
) -> Tuple[SearchResult, Optional[int]]:
"""Look for products matching criteria on known providers.
The default behaviour is to look for products on the provider with the
Expand Down Expand Up @@ -1448,7 +1453,7 @@ def _search_by_id(
# parameters which might not given an exact result are used
for result in results:
if result.properties["id"] == uid.split(".")[0]:
return [results[0]], 1
return SearchResult([results[0]]), 1
logger.info(
"Several products found for this id (%s). You may try searching using more selective criteria.",
results,
Expand Down Expand Up @@ -1676,7 +1681,7 @@ def _do_search(
can_authenticate = callable(getattr(auth_plugin, "authenticate", None))

results: List[EOProduct] = []
total_results = 0
total_results: Optional[int] = 0

try:
if need_auth and auth_plugin and can_authenticate:
Expand Down Expand Up @@ -1779,7 +1784,11 @@ def _do_search(
eo_product.register_downloader(download_plugin, auth_plugin)

results.extend(res)
total_results = None if nb_res is None else total_results + nb_res
total_results = (
None
if (nb_res is None or total_results is None)
else total_results + nb_res
)
if count:
logger.info(
"Found %s result(s) on provider '%s'",
Expand Down Expand Up @@ -1868,7 +1877,7 @@ def download_all(
progress_callback: Optional[ProgressCallback] = None,
wait: int = DEFAULT_DOWNLOAD_WAIT,
timeout: int = DEFAULT_DOWNLOAD_TIMEOUT,
**kwargs: Any,
**kwargs: Unpack[DownloadConf],
) -> List[str]:
"""Download all products resulting from a search.
Expand Down Expand Up @@ -2043,7 +2052,7 @@ def download(
progress_callback: Optional[ProgressCallback] = None,
wait: int = DEFAULT_DOWNLOAD_WAIT,
timeout: int = DEFAULT_DOWNLOAD_TIMEOUT,
**kwargs: Any,
**kwargs: Unpack[DownloadConf],
) -> str:
"""Download a single product.
Expand Down
9 changes: 7 additions & 2 deletions eodag/api/product/_assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@

import re
from collections import UserDict
from typing import TYPE_CHECKING, Any, Dict, List
from typing import TYPE_CHECKING, Any, Dict, List, Optional

from eodag.utils.exceptions import NotAvailableError

if TYPE_CHECKING:
from eodag.api.product import EOProduct
from eodag.types.download_args import DownloadConf
from eodag.utils import Unpack


class AssetsDict(UserDict):
Expand Down Expand Up @@ -98,6 +100,9 @@ class Asset(UserDict):
"""

product: EOProduct
size: int
filename: Optional[str]
rel_path: str

def __init__(self, product: EOProduct, key: str, *args: Any, **kwargs: Any) -> None:
self.product = product
Expand All @@ -113,7 +118,7 @@ def as_dict(self) -> Dict[str, Any]:
"""
return self.data

def download(self, **kwargs: Any) -> str:
def download(self, **kwargs: Unpack[DownloadConf]) -> str:
"""Downloads a single asset
:param kwargs: (optional) Additional named-arguments passed to `plugin.download()`
Expand Down
16 changes: 13 additions & 3 deletions eodag/api/product/_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@
import logging
import os
import re
import tempfile
import urllib.parse
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union

import requests
from requests import RequestException
from requests.auth import AuthBase
from shapely import geometry, wkb, wkt
from shapely.errors import ShapelyError

Expand All @@ -49,6 +51,8 @@
from eodag.plugins.apis.base import Api
from eodag.plugins.authentication.base import Authentication
from eodag.plugins.download.base import Download
from eodag.types.download_args import DownloadConf
from eodag.utils import Unpack

try:
from shapely.errors import GEOSException
Expand Down Expand Up @@ -308,7 +312,7 @@ def download(
progress_callback: Optional[ProgressCallback] = None,
wait: int = DEFAULT_DOWNLOAD_WAIT,
timeout: int = DEFAULT_DOWNLOAD_TIMEOUT,
**kwargs: Any,
**kwargs: Unpack[DownloadConf],
) -> str:
"""Download the EO product using the provided download plugin and the
authenticator if necessary.
Expand Down Expand Up @@ -469,9 +473,13 @@ def format_quicklook_address() -> None:
if base_dir is not None:
quicklooks_base_dir = os.path.abspath(os.path.realpath(base_dir))
else:
quicklooks_base_dir = os.path.join(
self.downloader.config.outputs_prefix, "quicklooks"
tempdir = tempfile.gettempdir()
outputs_prefix = (
getattr(self.downloader.config, "outputs_prefix", tempdir)
if self.downloader
else tempdir
)
quicklooks_base_dir = os.path.join(outputs_prefix, "quicklooks")
if not os.path.isdir(quicklooks_base_dir):
os.makedirs(quicklooks_base_dir)
quicklook_file = os.path.join(
Expand All @@ -497,6 +505,8 @@ def format_quicklook_address() -> None:
if self.downloader_auth is not None
else None
)
if not isinstance(auth, AuthBase):
auth = None
with requests.get(
self.properties["quicklook"],
stream=True,
Expand Down
43 changes: 29 additions & 14 deletions eodag/api/product/metadata_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from typing import (
TYPE_CHECKING,
Any,
AnyStr,
Callable,
Dict,
Iterator,
List,
Expand All @@ -40,7 +42,7 @@
import pyproj
from dateutil.parser import isoparse
from dateutil.tz import UTC, tzutc
from jsonpath_ng.jsonpath import Child
from jsonpath_ng.jsonpath import Child, JSONPath
from lxml import etree
from lxml.etree import XPathEvalError
from shapely import wkt
Expand Down Expand Up @@ -151,7 +153,7 @@ def get_search_param(map_value: List[str]) -> str:
return map_value[0]


def format_metadata(search_param: str, *args: Tuple[Any], **kwargs: Any) -> str:
def format_metadata(search_param: str, *args: Any, **kwargs: Any) -> str:
"""Format a string of form {<field_name>#<conversion_function>}
The currently understood converters are:
Expand Down Expand Up @@ -203,8 +205,8 @@ class MetadataFormatter(Formatter):
)

def __init__(self) -> None:
self.custom_converter = None
self.custom_args = None
self.custom_converter: Optional[Callable] = None
self.custom_args: Optional[str] = None

def get_field(self, field_name: str, args: Any, kwargs: Any) -> Any:
conversion_func_spec = self.CONVERSION_REGEX.match(field_name)
Expand Down Expand Up @@ -479,12 +481,15 @@ def convert_remove_extension(string: str) -> str:
@staticmethod
def convert_get_group_name(string: str, pattern: str) -> str:
try:
return re.search(pattern, str(string)).lastgroup
match = re.search(pattern, str(string))
if match:
return match.lastgroup or NOT_AVAILABLE
except AttributeError:
logger.warning(
"Could not extract property from %s using %s", string, pattern
)
return NOT_AVAILABLE
pass
logger.warning(
"Could not extract property from %s using %s", string, pattern
)
return NOT_AVAILABLE

@staticmethod
def convert_replace_str(string: str, args: str) -> str:
Expand Down Expand Up @@ -812,7 +817,10 @@ def convert_get_ecmwf_time(date: str) -> List[str]:
@staticmethod
def convert_get_dates_from_string(text: str, split_param="-"):
reg = "[0-9]{8}" + split_param + "[0-9]{8}"
dates_str = re.search(reg, text).group()
match = re.search(reg, text)
if not match:
return NOT_AVAILABLE
dates_str = match.group()
dates = dates_str.split(split_param)
start_date = datetime.strptime(dates[0], "%Y%m%d")
end_date = datetime.strptime(dates[1], "%Y%m%d")
Expand Down Expand Up @@ -926,13 +934,18 @@ def properties_from_json(
discovery_pattern = discovery_config.get("metadata_pattern", None)
discovery_path = discovery_config.get("metadata_path", None)
if discovery_pattern and discovery_path:
discovered_properties = string_to_jsonpath(discovery_path).find(json)
discovery_jsonpath = string_to_jsonpath(discovery_path)
discovered_properties = (
discovery_jsonpath.find(json)
if isinstance(discovery_jsonpath, JSONPath)
else []
)
for found_jsonpath in discovered_properties:
if "metadata_path_id" in discovery_config.keys():
found_key_paths = string_to_jsonpath(
discovery_config["metadata_path_id"], force=True
).find(found_jsonpath.value)
if not found_key_paths:
if not found_key_paths or isinstance(found_key_paths, int):
continue
found_key = found_key_paths[0].value
used_jsonpath = Child(
Expand All @@ -955,7 +968,9 @@ def properties_from_json(
discovery_config["metadata_path_value"], force=True
).find(found_jsonpath.value)
properties[found_key] = (
found_value_path[0].value if found_value_path else NOT_AVAILABLE
found_value_path[0].value
if found_value_path and not isinstance(found_value_path, int)
else NOT_AVAILABLE
)
else:
# default value got from metadata_path
Expand All @@ -971,7 +986,7 @@ def properties_from_json(


def properties_from_xml(
xml_as_text: str,
xml_as_text: AnyStr,
mapping: Any,
empty_ns_prefix: str = "ns",
discovery_config: Optional[Dict[str, Any]] = None,
Expand Down
Loading

0 comments on commit d1df25e

Please sign in to comment.