Skip to content
This repository has been archived by the owner on Nov 19, 2024. It is now read-only.

Commit

Permalink
implement wait until complete option (#59)
Browse files Browse the repository at this point in the history
* implement wait until complete option

* warning

* cleanup

* implement logging methods

* cleanup
  • Loading branch information
malmans2 authored Aug 26, 2024
1 parent 6619318 commit 65c7f8f
Show file tree
Hide file tree
Showing 4 changed files with 227 additions and 41 deletions.
7 changes: 7 additions & 0 deletions cads_api_client/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,13 @@ def retrieve(
**request,
)

def submit(
self, collection_id: str, retry_options: Dict[str, Any] = {}, **request: Any
) -> processing.Remote:
return self.retrieve_api.submit(
collection_id, retry_options=retry_options, **request
)

def submit_and_wait_on_result(
self, collection_id: str, retry_options: Dict[str, Any] = {}, **request: Any
) -> processing.Results:
Expand Down
94 changes: 60 additions & 34 deletions cads_api_client/legacy_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,27 +97,26 @@ def __init__(
self.quiet = quiet
self._debug = debug

self.wait_until_complete = kwargs.pop("wait_until_complete", True)

if kwargs:
warnings.warn(
"This is a beta version."
f" The following parameters have not been implemented yet: {kwargs}.",
UserWarning,
)

with LoggingContext(
logger=LOGGER, quiet=self.quiet, debug=self._debug
) as logger:
logger.debug(
"CDSAPI %s",
{
"url": self.url,
"key": self.key,
"quiet": self.quiet,
"timeout": self.timeout,
"sleep_max": self.sleep_max,
"retry_max": self.retry_max,
},
)
self.debug(
"CDSAPI %s",
{
"url": self.url,
"key": self.key,
"quiet": self.quiet,
"timeout": self.timeout,
"sleep_max": self.sleep_max,
"retry_max": self.retry_max,
},
)

@classmethod
def raise_not_implemented_error(self) -> None:
Expand Down Expand Up @@ -145,39 +144,66 @@ def retrieve(

def retrieve(
self, name: str, request: dict[str, Any], target: str | None = None
) -> str | processing.Results:
result = self.logging_decorator(self.client.submit_and_wait_on_result)(
collection_id=name,
retry_options=self.retry_options,
**request,
)
) -> str | processing.Remote | processing.Results:
submitted: processing.Remote | processing.Results
if self.wait_until_complete:
submitted = self.logging_decorator(self.client.submit_and_wait_on_result)(
collection_id=name,
retry_options=self.retry_options,
**request,
)
else:
submitted = self.logging_decorator(self.client.submit)(
collection_id=name,
retry_options=self.retry_options,
**request,
)

# Assign legacy methods
partial_download: Callable[..., str] = functools.partial(
result.download,
submitted.download,
timeout=self.timeout,
retry_options=self.retry_options,
)
result.download = self.logging_decorator(partial_download) # type: ignore[method-assign]
return result if target is None else result.download(target)
submitted.download = self.logging_decorator(partial_download) # type: ignore[method-assign]
submitted.info = self.logging_decorator(submitted.info) # type: ignore[method-assign]
submitted.warning = self.logging_decorator(submitted.warning) # type: ignore[method-assign]
submitted.error = self.logging_decorator(submitted.error) # type: ignore[method-assign]
submitted.debug = self.logging_decorator(submitted.debug) # type: ignore[method-assign]

def service(self, name, *args, **kwargs): # type: ignore
self.raise_not_implemented_error()
return submitted if target is None else submitted.download(target)

def workflow(self, code, *args, **kwargs): # type: ignore
self.raise_not_implemented_error()
def info(self, *args: Any, **kwargs: Any) -> None:
with LoggingContext(
logger=LOGGER, quiet=self.quiet, debug=self._debug
) as logger:
logger.info(*args, **kwargs)

def status(self, context=None): # type: ignore
self.raise_not_implemented_error()
def warning(self, *args: Any, **kwargs: Any) -> None:
with LoggingContext(
logger=LOGGER, quiet=self.quiet, debug=self._debug
) as logger:
logger.warning(*args, **kwargs)

def info(self, *args, **kwargs): # type: ignore
self.raise_not_implemented_error()
def error(self, *args: Any, **kwargs: Any) -> None:
with LoggingContext(
logger=LOGGER, quiet=self.quiet, debug=self._debug
) as logger:
logger.error(*args, **kwargs)

def debug(self, *args: Any, **kwargs: Any) -> None:
with LoggingContext(
logger=LOGGER, quiet=self.quiet, debug=self._debug
) as logger:
logger.debug(*args, **kwargs)

def warning(self, *args, **kwargs): # type: ignore
def service(self, name, *args, **kwargs): # type: ignore
self.raise_not_implemented_error()

def error(self, *args, **kwargs): # type: ignore
def workflow(self, code, *args, **kwargs): # type: ignore
self.raise_not_implemented_error()

def debug(self, *args, **kwargs): # type: ignore
def status(self, context=None): # type: ignore
self.raise_not_implemented_error()

def download(self, results, targets=None): # type: ignore
Expand Down
85 changes: 79 additions & 6 deletions cads_api_client/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import time
import urllib.parse
import warnings
from typing import Any, Dict, List, Optional, Type, TypeVar

try:
Expand Down Expand Up @@ -202,7 +203,7 @@ def log_metadata(self, metadata: dict[str, Any]) -> None:
def request_uid(self) -> str:
return self.url.rpartition("/")[2]

def _get_status(self, robust: bool, **retry_options: Any) -> str:
def _get_reply(self, robust: bool, **retry_options: Any) -> dict[str, Any]:
# TODO: cache responses for a timeout (possibly reported nby the server)
get = self.session.get
if robust:
Expand All @@ -216,7 +217,10 @@ def _get_status(self, robust: bool, **retry_options: Any) -> str:
requests_response = get(url=self.url, headers=self.headers, params=params)
logger.debug(f"REPLY {requests_response.text}")
requests_response.raise_for_status()
json = requests_response.json()
return dict(requests_response.json())

def _get_status(self, robust: bool, **retry_options: Any) -> str:
json = self._get_reply(robust, **retry_options)
self.log_metadata(json.get("metadata", {}))
return str(json["status"])

Expand Down Expand Up @@ -278,16 +282,73 @@ def make_results(self, url: Optional[str] = None) -> Results:
return results

def _download_result(
self, target: Optional[str] = None, retry_options: Dict[str, Any] = {}
self,
target: str | None = None,
timeout: int = 60,
retry_options: Dict[str, Any] = {},
) -> str:
results: Results = multiurl.robust(self.make_results, **retry_options)(self.url)
return results.download(target, retry_options=retry_options)
return results.download(target, timeout=timeout, retry_options=retry_options)

def download(
self, target: Optional[str] = None, retry_options: Dict[str, Any] = {}
self,
target: str | None = None,
timeout: int = 60,
retry_options: Dict[str, Any] = {},
) -> str:
self.wait_on_result(retry_options=retry_options)
return self._download_result(target, retry_options=retry_options)
return self._download_result(
target, timeout=timeout, retry_options=retry_options
)

def _warn(self) -> None:
message = (
".update and .reply are available for backward compatibility."
" You can now use .download directly without needing to check whether the request is completed."
)
warnings.warn(message, DeprecationWarning)

def update(self, request_id: str | None = None) -> None:
self._warn()
if request_id:
assert request_id == self.request_uid
try:
del self.reply
except AttributeError:
pass
self.reply

@functools.cached_property
def reply(self) -> dict[str, Any]:
self._warn()

reply = self._get_reply(True)

reply.setdefault("state", reply["status"])
if reply["state"] == "successful":
reply["state"] = "completed"
elif reply["state"] == "queued":
reply["state"] = "accepted"
elif reply["state"] == "failed":
results = multiurl.robust(self.make_results)(self.url)
message = error_json_to_message(results.json)
reply.setdefault("error", {})
reply["error"].setdefault("message", message)

reply.setdefault("request_id", self.request_uid)
return reply

def info(self, *args: Any, **kwargs: Any) -> None:
logger.info(*args, **kwargs)

def warning(self, *args: Any, **kwargs: Any) -> None:
logger.warning(*args, **kwargs)

def error(self, *args: Any, **kwargs: Any) -> None:
logger.error(*args, **kwargs)

def debug(self, *args: Any, **kwargs: Any) -> None:
logger.debug(*args, **kwargs)


@attrs.define
Expand Down Expand Up @@ -370,6 +431,18 @@ def download(
)
return target

def info(self, *args: Any, **kwargs: Any) -> None:
logger.info(*args, **kwargs)

def warning(self, *args: Any, **kwargs: Any) -> None:
logger.warning(*args, **kwargs)

def error(self, *args: Any, **kwargs: Any) -> None:
logger.error(*args, **kwargs)

def debug(self, *args: Any, **kwargs: Any) -> None:
logger.debug(*args, **kwargs)


class Processing:
supported_api_version = "v1"
Expand Down
82 changes: 81 additions & 1 deletion tests/integration_test_50_legacy_api_client.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import pathlib
import time

import pytest
import requests

from cads_api_client import legacy_api_client
from cads_api_client import legacy_api_client, processing


def test_retrieve(tmp_path: pathlib.Path, api_root_url: str, api_anon_key: str) -> None:
Expand Down Expand Up @@ -52,3 +53,82 @@ def test_debug(
legacy_api_client.LegacyApiClient(url=api_root_url, key=api_anon_key, debug=debug)
records = [record for record in caplog.records if record.levelname == "DEBUG"]
assert records if debug else not records


@pytest.mark.parametrize(
"wait_until_complete,expected_type",
[(True, processing.Results), (False, processing.Remote)],
)
def test_wait_until_complete(
tmp_path: pathlib.Path,
api_root_url: str,
api_anon_key: str,
wait_until_complete: bool,
expected_type: type,
) -> None:
client = legacy_api_client.LegacyApiClient(
url=api_root_url,
key=api_anon_key,
wait_until_complete=wait_until_complete,
)

collection_id = "test-adaptor-dummy"
request = {"size": 1}

result = client.retrieve(collection_id, request)
assert isinstance(result, expected_type)

target = tmp_path / "test.grib"
result.download(str(target))
assert target.stat().st_size == 1


def test_legacy_update(
tmp_path: pathlib.Path,
api_root_url: str,
api_anon_key: str,
) -> None:
client = legacy_api_client.LegacyApiClient(
url=api_root_url,
key=api_anon_key,
wait_until_complete=False,
)
collection_id = "test-adaptor-dummy"
request = {"size": 1}
remote = client.retrieve(collection_id, request)
assert isinstance(remote, processing.Remote)

# See https://github.com/ecmwf/cdsapi/blob/master/examples/example-era5-update.py
sleep = 1
while True:
with pytest.deprecated_call():
remote.update()

reply = remote.reply
remote.info("Request ID: %s, state: %s" % (reply["request_id"], reply["state"]))

if reply["state"] == "completed":
break
elif reply["state"] in ("queued", "running"):
remote.info("Request ID: %s, sleep: %s", reply["request_id"], sleep)
time.sleep(sleep)
elif reply["state"] in ("failed",):
remote.error("Message: %s", reply["error"].get("message"))
remote.error("Reason: %s", reply["error"].get("reason"))
for n in (
reply.get("error", {})
.get("context", {})
.get("traceback", "")
.split("\n")
):
if n.strip() == "":
break
remote.error(" %s", n)
raise Exception(
"%s. %s."
% (reply["error"].get("message"), reply["error"].get("reason"))
)

target = tmp_path / "test.grib"
remote.download(str(target))
assert target.stat().st_size == 1

0 comments on commit 65c7f8f

Please sign in to comment.