Skip to content

Commit

Permalink
refactor: mix queryables kwargs and defaults
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrunato committed Feb 19, 2024
1 parent a4dc65d commit 5388042
Show file tree
Hide file tree
Showing 9 changed files with 114 additions and 105 deletions.
25 changes: 10 additions & 15 deletions eodag/api/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
HTTP_REQ_TIMEOUT,
MockResponse,
_deprecated,
copy_deepcopy,
deepcopy,
get_args,
get_geometry_from_various,
Expand Down Expand Up @@ -2171,7 +2172,9 @@ def list_queryables(
if k in queryables_keys
}

all_queryables = model_fields_to_annotated(Queryables.model_fields)
all_queryables = copy_deepcopy(
model_fields_to_annotated(Queryables.model_fields)
)

try:
plugin = next(
Expand Down Expand Up @@ -2209,15 +2212,7 @@ def list_queryables(
getattr(plugin.config, "products", {}).get(product_type, {})
)
default_values.pop("metadata_mapping", None)
removed_defaults = []
for param in kwargs:
if not kwargs[param]:
default_values.pop(param, None)
removed_defaults.append(param)
else:
default_values[param] = kwargs[param]
kwargs = {key: kwargs[key] for key in kwargs if key not in removed_defaults}
kwargs["defaults"] = default_values
kwargs = dict(default_values, **kwargs)

# remove not mapped parameters or non-queryables
for param in list(metadata_mapping.keys()):
Expand All @@ -2233,8 +2228,8 @@ def list_queryables(
field_info = annotated_args[1]
if not isinstance(field_info, FieldInfo):
continue
if key in default_values:
field_info.default = default_values[key]
if key in kwargs:
field_info.default = kwargs[key]
if field_info.is_required() or (
(field_info.alias or key) in metadata_mapping
):
Expand All @@ -2245,10 +2240,10 @@ def list_queryables(
provider_queryables.update(providers_available_queryables[provider])

# always keep at least CommonQueryables
common_queryables = deepcopy(CommonQueryables.model_fields)
common_queryables = copy_deepcopy(CommonQueryables.model_fields)
for key, queryable in common_queryables.items():
if key in default_values:
queryable.default = default_values[key]
if key in kwargs:
queryable.default = kwargs[key]

provider_queryables.update(model_fields_to_annotated(common_queryables))

Expand Down
30 changes: 25 additions & 5 deletions eodag/plugins/apis/cds.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@
logger = logging.getLogger("eodag.apis.cds")

CDS_KNOWN_FORMATS = {"grib": "grib", "netcdf": "nc"}
# always available queryables (needed as not available in constraints)
CDS_ALLOWED_QUERYABLES = ["format"]


class CdsApi(HTTPDownload, Api, BuildPostSearchResult):
Expand Down Expand Up @@ -457,9 +459,21 @@ def discover_queryables(
product_type = kwargs.pop("productType", None)
if not product_type:
return {}

provider_product_type = self.config.products.get(product_type, {}).get(
"dataset", None
)
user_provider_product_type = kwargs.pop("dataset", None)
if (
user_provider_product_type
and user_provider_product_type != provider_product_type
):
raise ValidationError(
f"Cannot change dataset from {provider_product_type} to {user_provider_product_type}"
)

non_empty_kwargs = {k: v for k, v in kwargs.items() if v}

if "{" in constraints_file_url:
constraints_file_url = constraints_file_url.format(
dataset=provider_product_type
Expand All @@ -468,7 +482,7 @@ def discover_queryables(
if not constraints:
return {}
constraint_params: Dict[str, Dict[str, Set[Any]]] = {}
if len(kwargs) == 0 or (len(kwargs) == 1 and len(kwargs["defaults"]) == 0):
if len(kwargs) == 0:
# get values from constraints without additional filters
for constraint in constraints:
for key in constraint.keys():
Expand All @@ -479,26 +493,32 @@ def discover_queryables(
constraint_params[key]["enum"] = set(constraint[key])
else:
# get values from constraints with additional filters
constraints_input_params = {
k: v
for k, v in non_empty_kwargs.items()
if k not in CDS_ALLOWED_QUERYABLES
}
constraint_params = get_constraint_queryables_with_additional_params(
constraints, kwargs, self, product_type
constraints, constraints_input_params, self, product_type
)
# query params that are not in constraints but might be default queryables
if len(constraint_params) == 1 and "not_available" in constraint_params:
not_queryables = set()
for constraint_param in constraint_params["not_available"]["enum"]:
param = CommonQueryables.get_queryable_from_alias(constraint_param)
if param in CommonQueryables.model_fields:
kwargs.pop(constraint_param)
non_empty_kwargs.pop(constraint_param)
else:
not_queryables.add(constraint_param)
if not_queryables:
raise ValidationError(
f"parameter(s) {str(not_queryables)} not queryable"
)
else:
# get constraints again without common queryables
constraint_params = (
get_constraint_queryables_with_additional_params(
constraints, kwargs, self, product_type
constraints, non_empty_kwargs, self, product_type
)
)

Expand All @@ -508,7 +528,7 @@ def discover_queryables(
get_queryable_from_provider(json_param, self.config.metadata_mapping)
or json_param
)
default = kwargs.get("defaults", {}).get(param, None)
default = kwargs.get(param, None)
annotated_def = json_field_definition_to_python(
json_mtd, default_value=default, required=True
)
Expand Down
2 changes: 1 addition & 1 deletion eodag/plugins/search/qssearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1353,7 +1353,7 @@ def discover_queryables(
or json_param
)

default = kwargs.get("defaults", {}).get(param, None)
default = kwargs.get(param, None)
annotated_def = json_field_definition_to_python(
json_mtd, default_value=default
)
Expand Down
25 changes: 16 additions & 9 deletions eodag/utils/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ def get_constraint_queryables_with_additional_params(
:returns: dict containing queryable data
:rtype: Dict[str, Dict[str, Set[Any]]]
"""
params = copy.deepcopy(input_params)
defaults = copy.deepcopy(input_params)
constraint_matches = {}
defaults = params.pop("defaults", {})
params = {k: v for k, v in defaults.items() if v}
for p in params.keys():
defaults.pop(p, None)
params_available = {k: False for k in params.keys()}
Expand Down Expand Up @@ -101,13 +101,20 @@ def get_constraint_queryables_with_additional_params(
# add values of constraints matching params
queryables: Dict[str, Dict[str, Set[Any]]] = {}
for num, matches in constraint_matches.items():
if False not in matches.values():
for key in constraints[num]:
if key in queryables:
queryables[key]["enum"].update(constraints[num][key])
else:
queryables[key] = {}
queryables[key]["enum"] = set(constraints[num][key])
for key in constraints[num]:
other_keys_matching = [v for k, v in matches.items() if k != key]
key_matches_a_constraint = any(
v.get(key, False) for v in constraint_matches.values()
)
if False in other_keys_matching or (
not key_matches_a_constraint and key in matches
):
continue
if key in queryables:
queryables[key]["enum"].update(constraints[num][key])
else:
queryables[key] = {}
queryables[key]["enum"] = set(constraints[num][key])

other_values = _get_other_possible_values_for_values_with_defaults(
defaults, params, constraints, metadata_mapping
Expand Down
12 changes: 12 additions & 0 deletions tests/resources/constraints.json
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
],
"type": [
"A", "B"
],
"product_type": [
"ensemble_mean", "reanalysis"
]
},
{
Expand All @@ -40,6 +43,9 @@
],
"type": [
"C", "B"
],
"product_type": [
"ensemble_mean", "reanalysis"
]
},
{
Expand All @@ -57,6 +63,9 @@
],
"variable": [
"b", "c"
],
"product_type": [
"ensemble_mean", "reanalysis"
]
},
{
Expand All @@ -74,6 +83,9 @@
],
"variable": [
"e", "f"
],
"product_type": [
"ensemble_mean", "reanalysis"
]
}
]
5 changes: 2 additions & 3 deletions tests/units/test_apis_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -925,15 +925,14 @@ def test_plugins_apis_cds_discover_queryables(self, mock_requests_constraints):
queryables = self.api_plugin.discover_queryables(
productType="CAMS_EU_AIR_QUALITY_RE"
)
self.assertEqual(7, len(queryables))
self.assertEqual(8, len(queryables))
self.assertIn("variable", queryables)
# with additional param
queryables = self.api_plugin.discover_queryables(
productType="CAMS_EU_AIR_QUALITY_RE",
variable="a",
defaults={"variable": "a"},
)
self.assertEqual(7, len(queryables))
self.assertEqual(8, len(queryables))
queryable = queryables.get("variable")
self.assertEqual("a", queryable.__metadata__[0].get_default())
queryable = queryables.get("month")
Expand Down
66 changes: 27 additions & 39 deletions tests/units/test_constraints.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,20 @@
# -*- coding: utf-8 -*-
# Copyright 2024, CS GROUP - France, https://www.csgroup.eu/
#
# This file is part of EODAG project
# https://www.github.com/CS-SI/EODAG
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import unittest
Expand All @@ -21,48 +38,38 @@ def test_get_constraint_queryables_with_additional_params(self):
constraints = json.load(f)
plugins = self.plugins_manager.get_search_plugins("ERA5_SL", "cop_cds")
plugin = next(plugins)

# filter on one parameter
queryables = get_constraint_queryables_with_additional_params(
constraints, {"variable": "f"}, plugin, "ERA5_SL"
)
self.assertEqual(5, len(queryables))
self.assertEqual(6, len(queryables))
self.assertIn("year", queryables)
queryable = queryables.get("year")
self.assertSetEqual({"2000", "2001"}, queryable["enum"])
self.assertIn("variable", queryables)

# not existing parameter
queryables = get_constraint_queryables_with_additional_params(
constraints, {"param": "f"}, plugin, "ERA5_SL"
)
self.assertIn("not_available", queryables)
self.assertEqual("param", queryables["not_available"]["enum"].pop())

# not existing value of parameter
with self.assertRaises(ValidationError):
get_constraint_queryables_with_additional_params(
constraints, {"variable": "g"}, plugin, "ERA5_SL"
)
# 2 parameters
queryables = get_constraint_queryables_with_additional_params(
constraints, {"variable": "c", "year": "2000"}, plugin, "ERA5_SL"
)
self.assertEqual(6, len(queryables))
self.assertIn("year", queryables)
self.assertIn("variable", queryables)
self.assertIn("month", queryables)
self.assertIn("day", queryables)
self.assertIn("time", queryables)
queryable = queryables.get("time")
self.assertSetEqual({"01:00", "12:00", "18:00", "22:00"}, queryable["enum"])
# with param and defaults

# with params/defaults
queryables = get_constraint_queryables_with_additional_params(
constraints,
{
"variable": "c",
"defaults": {"type": "B", "year": "2000", "field": "test"},
},
{"variable": "c", "type": "B", "year": "2000"},
plugin,
"ERA5_SL",
)
self.assertEqual(6, len(queryables))
self.assertEqual(7, len(queryables))
self.assertIn("year", queryables)
self.assertIn("variable", queryables)
self.assertIn("month", queryables)
Expand All @@ -74,23 +81,4 @@ def test_get_constraint_queryables_with_additional_params(self):
queryable = queryables.get("type")
self.assertSetEqual({"C", "B"}, queryable["enum"])
queryable = queryables.get("year")
self.assertSetEqual(
{"2000", "2001", "2002", "2003", "2004", "2005"}, queryable["enum"]
)
# only with defaults
queryables = get_constraint_queryables_with_additional_params(
constraints,
{"defaults": {"type": "A", "year": "2000", "field": "test"}},
plugin,
"ERA5_SL",
)
self.assertEqual(7, len(queryables))
self.assertIn("year", queryables)
self.assertIn("variable", queryables)
self.assertIn("type", queryables)
queryable = queryables.get("time")
self.assertSetEqual({"01:00", "12:00", "18:00", "22:00"}, queryable["enum"])
queryable = queryables.get("variable")
self.assertSetEqual({"a", "b", "e", "f"}, queryable["enum"])
queryable = queryables.get("type")
self.assertSetEqual({"A", "B", "C"}, queryable["enum"])
self.assertSetEqual({"2000", "2001"}, queryable["enum"])
Loading

0 comments on commit 5388042

Please sign in to comment.