From f8ad7988c11e78d3ddc05ba325444bad67aa8403 Mon Sep 17 00:00:00 2001 From: Yann Forget Date: Wed, 6 Nov 2024 23:34:35 +0100 Subject: [PATCH 01/18] feat(era5): use cads_api_client instead of cadsapi --- openhexa/toolbox/era5/cds.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/openhexa/toolbox/era5/cds.py b/openhexa/toolbox/era5/cds.py index 4ce46b0..b192eb1 100644 --- a/openhexa/toolbox/era5/cds.py +++ b/openhexa/toolbox/era5/cds.py @@ -11,9 +11,8 @@ from math import ceil from pathlib import Path -import cads_api_client -import cdsapi import geopandas as gpd +from cads_api_client import ApiClient, Remote from dateutil.relativedelta import relativedelta with importlib.resources.open_text("openhexa.toolbox.era5", "variables.json") as f: @@ -57,25 +56,39 @@ def bounds_from_file(fp: Path, buffer: float = 0.5) -> list[float]: class Client: def __init__(self, key: str): - self.client = cdsapi.Client(url=URL, key=key, wait_until_complete=True, quiet=True, progress=False) - self.cads_api_client = cads_api_client.ApiClient(key=key, url=URL) + self.client = ApiClient(key=key, url=URL) + self.check_authentication() @cached_property def latest(self) -> datetime: """Get date of latest available product.""" - collection = self.cads_api_client.get_collection(DATASET) + collection = self.client.get_collection(DATASET) dt = collection.end_datetime # make datetime unaware of timezone for comparability with other datetimes dt = datetime(dt.year, dt.month, dt.day) return dt + def get_jobs(self) -> list[dict]: + """Get list of current jobs for the account in the CDS.""" + r = self.client.get_jobs() + return "jobs" in r.json.get("jobs") + + def get_remote(self, request_id: str) -> Remote: + """Get remote object from request uid.""" + return self.client.get_remote(request_id) + + def submit(self, request: dict) -> str: + """Submit an async data request to the CDS API.""" + r = self.client.submit(DATASET, **request) + return r.request_uid + @staticmethod def build_request( variable: str, year: int, month: int, days: list[int] = None, - time: list[str] = None, + time: list[int] = None, data_format: str = "grib", area: list[float] = None, ) -> dict: From 70b0696d0d44167342c33548d8197013897cb153 Mon Sep 17 00:00:00 2001 From: Yann Forget Date: Wed, 6 Nov 2024 23:36:43 +0100 Subject: [PATCH 02/18] feat(era5): use int for days when building payload --- openhexa/toolbox/era5/cds.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/openhexa/toolbox/era5/cds.py b/openhexa/toolbox/era5/cds.py index b192eb1..8192a04 100644 --- a/openhexa/toolbox/era5/cds.py +++ b/openhexa/toolbox/era5/cds.py @@ -104,8 +104,8 @@ def build_request( Month of interest. days : list[int] Days of interest. Defauls to None (all days). - time : list[str] - Hours of interest (ex: ["01:00", "06:00", "18:00"]). Defaults to None (all hours). + time : list[int] + Hours of interest (ex: [1, 6, 18]). Defaults to None (all hours). data_format : str Output data format ("grib" or "netcdf"). Defaults to "grib". 
area : list[float] @@ -139,7 +139,8 @@ def build_request( days = [day for day in range(1, dmax + 1)] if not time: - time = [f"{hour:02}:00" for hour in range(0, 24)] + time = [hour for hour in range(0, 24)] + time = [f"{hour:02}:00" for hour in time] year = str(year) month = f"{month:02}" From 8f89f4da0cd38678818ce6364016a2263433a7cd Mon Sep 17 00:00:00 2001 From: Yann Forget Date: Thu, 7 Nov 2024 00:04:17 +0100 Subject: [PATCH 03/18] feat(era5): add async data requests methods --- openhexa/toolbox/era5/cds.py | 44 +++++++++++++++++++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/openhexa/toolbox/era5/cds.py b/openhexa/toolbox/era5/cds.py index 8192a04..fae4f0f 100644 --- a/openhexa/toolbox/era5/cds.py +++ b/openhexa/toolbox/era5/cds.py @@ -10,9 +10,10 @@ from functools import cached_property from math import ceil from pathlib import Path +from typing import Optional, Union import geopandas as gpd -from cads_api_client import ApiClient, Remote +from cads_api_client import ApiClient, Remote, Results from dateutil.relativedelta import relativedelta with importlib.resources.open_text("openhexa.toolbox.era5", "variables.json") as f: @@ -77,11 +78,52 @@ def get_remote(self, request_id: str) -> Remote: """Get remote object from request uid.""" return self.client.get_remote(request_id) + def get_remote_from_request(self, request: dict, max_age: int = 1) -> Optional[Remote]: + """Look for a remote object that matches the provided request payload. + + Parameters + ---------- + request : dict + Request payload. + max_age : int, optional + Maximum age of the remote object in days (default=1). + + Returns + ------- + Optional[Remote] + Remote object if found, None otherwise. + """ + jobs = self.get_jobs() + if not jobs: + return None + for job in jobs: + remote = self.get_remote(job["jobID"]) + if remote.request == request: + age = datetime.now() - remote.creation_datetime + if age.days <= max_age: + return remote + return None + def submit(self, request: dict) -> str: """Submit an async data request to the CDS API.""" r = self.client.submit(DATASET, **request) + log.debug("Submitted data request %s", r.request_uid) return r.request_uid + def submit_and_wait(self, request: dict) -> Results: + """Submit a data request and wait for completion.""" + result = self.client.submit_and_wait_on_results(DATASET, **request) + return result + + def download(self, request: dict, dst_file: Union[str, Path]): + """Submit a data request and wait for completion before download.""" + if isinstance(dst_file, str): + dst_file = Path(dst_file) + dst_file.parent.mkdir(parents=True, exist_ok=True) + result = self.submit_and_wait(request) + result.download(dst_file.as_posix()) + log.debug("Downloaded %s", dst_file.name) + @staticmethod def build_request( variable: str, From 2e76c9c53c0cdd4ba9f57308fefafe30702e3260 Mon Sep 17 00:00:00 2001 From: Yann Forget Date: Thu, 7 Nov 2024 01:54:47 +0100 Subject: [PATCH 04/18] feat(era5): add async download_between method --- openhexa/toolbox/era5/cds.py | 321 +++++++++++++++++++++-------------- 1 file changed, 198 insertions(+), 123 deletions(-) diff --git a/openhexa/toolbox/era5/cds.py b/openhexa/toolbox/era5/cds.py index fae4f0f..8ed2197 100644 --- a/openhexa/toolbox/era5/cds.py +++ b/openhexa/toolbox/era5/cds.py @@ -3,18 +3,18 @@ import importlib.resources import json import logging -import shutil -import tempfile from calendar import monthrange -from datetime import datetime, timedelta +from datetime import date, datetime, timedelta from functools 
import cached_property from math import ceil from pathlib import Path +from time import sleep from typing import Optional, Union import geopandas as gpd +import numpy as np +import xarray as xr from cads_api_client import ApiClient, Remote, Results -from dateutil.relativedelta import relativedelta with importlib.resources.open_text("openhexa.toolbox.era5", "variables.json") as f: VARIABLES = json.load(f) @@ -55,6 +55,138 @@ def bounds_from_file(fp: Path, buffer: float = 0.5) -> list[float]: return ymax, xmin, ymin, xmax +def get_period_chunk(dtimes: list[datetime]) -> dict: + """Get the period chunk for a list of datetimes. + + The period chunk is a dictionary with the "year", "month", "day" and "time" keys as expected by + the CDS API. A period chunk cannot contain more than 1 year and 1 month. However, it can + contain any number of days and times. + + Parameters + ---------- + dtimes : list[datetime] + A list of datetimes for which we want data + + Returns + ------- + dict + The period chunk, in other words the temporal part of the request payload + + Raises + ------ + ValueError + If the list of datetimes contains more than 1 year or more than 1 month + """ + years = list(set([dtime.year for dtime in dtimes])) + if len(set(years)) > 1: + raise ValueError("Cannot create a period chunk for multiple years") + months = list(set([dtime.month for dtime in dtimes])) + if len(months) > 1: + raise ValueError("Cannot create a period chunk for multiple months") + + year = years[0] + month = months[0] + days = [] + + for dtime in sorted(dtimes): + if dtime.day not in days: + days.append(dtime.day) + + return { + "year": year, + "month": month, + "day": days, + "time": [h for h in range(0, 24)], + } + + +def get_period_chunks(dtimes: list[datetime]) -> list[dict]: + """Get the period chunks for a list of datetimes. + + The period chunks are a list of dictionaries with the "year", "month", "day" and "time" keys as + expected by the CDS API. A period chunk cannot contain more than 1 year and 1 month. However, + it can contain any number of days and times. + + The function tries its best to generate the minimum amount of chunks to minize the amount of requests. + + Parameters + ---------- + dtimes : list[datetime] + A list of datetimes for which we want data + + Returns + ------- + list[dict] + The period chunks (one per month max) + """ + chunks = [] + for year in range(min(dtimes).year, max(dtimes).year + 1): + for month in range(1, 13): + dtimes_month = [dtime for dtime in dtimes if dtime.year == year and dtime.month == month] + if dtimes_month: + chunk = get_period_chunk(dtimes_month) + chunks.append(chunk) + return chunks + + +def _np_to_datetime(dt64: np.datetime64) -> datetime: + epoch = np.datetime64(0, "s") + one_second = np.timedelta64(1, "s") + seconds_since_epoch = (dt64 - epoch) / one_second + return datetime.fromtimestamp(seconds_since_epoch) + + +def available_datetimes(data_dir: Path) -> list[date]: + """Get available datetimes from a directory of ERA5 data files. + + Dates are considered as available if data for all 24 hours of the day are found in the file. + + Parameters + ---------- + data_dir : Path + Directory containing the ERA5 data files. + + Returns + ------- + list[date] + List of available dates. 
+ """ + dtimes = [] + + for f in data_dir.glob("*.grib"): + ds = xr.open_dataset(f, engine="cfgrib") + var = [v for v in ds.data_vars][0] + + for time in ds.time: + dtime = _np_to_datetime(time.values).date() + if dtime in dtimes: + continue + + is_complete = True + for hour in range(1, 25): + step = timedelta(hours=hour) + if not ds.sel(time=time, step=step).get(var).notnull().any().values.item(): + is_complete = False + break + + if is_complete: + dtimes.append(dtime) + + log.debug(f"Scanned {data_dir.as_posix()}, found {len(dtimes)} available dates") + + return dtimes + + +def date_range(start: date, end: date) -> list[date]: + """Get a range of dates with a 1-day step.""" + drange = [] + dt = start + while dt <= end: + drange.append(dt) + dt += timedelta(days=1) + return drange + + class Client: def __init__(self, key: str): self.client = ApiClient(key=key, url=URL) @@ -124,6 +256,68 @@ def download(self, request: dict, dst_file: Union[str, Path]): result.download(dst_file.as_posix()) log.debug("Downloaded %s", dst_file.name) + def download_between(self, start: date, end: date, variable: str, area: list[float], dst_dir: Union[str, Path]): + """Download all ERA5 data files needed to cover the period. + + Data requests are sent asynchronously (max one per month) to the CDS API and fetched when + they are completed. + + Parameters + ---------- + start : date + Start date. + end : date + End date. + variable : str + Climate data store variable name (ex: "2m_temperature"). + area : list[float] + Area of interest (north, west, south, east). + dst_dir : Path + Output directory. + """ + if isinstance(dst_dir, str): + dst_dir = Path(dst_dir) + dst_dir.mkdir(parents=True, exist_ok=True) + + if end > self.latest: + end = self.latest.date() + log.debug(f"End date is after latest available product, setting end date to {end.strftime('%Y-%m-%d')}") + + drange = date_range(start, end) + available = available_datetimes(dst_dir) + dates = [d for d in drange if d not in available] + + chunks = get_period_chunks(dates) + requests = [] + remotes = [] + + for chunk in chunks: + request = self.build_request(variable=variable, data_format="grib", area=area, **chunk) + + # has a similar request been submitted recently? 
if yes, use it + remote = self.get_remote_from_request(request) + if remote: + remotes.append(remote) + log.debug(f"Found existing request for date {request["year"]}-{request["month"]}") + continue + + requests.append(self.submit(request)) + sleep(3) + + remotes = [self.get_remote(request) for request in requests] + done = [] + + while not all([remote.request_uid in done for remote in remotes]): + for remote in remotes: + if remote.results_ready: + fname = f"{date.year}{date.month:02}_{remote.request_uid}.grib" + dst_file = Path(dst_dir, fname) + remote.download(dst_file.as_posix()) + log.debug(f"Downloaded {dst_file.name}") + done.append(remote.request_uid) + remote.delete() + sleep(60) + @staticmethod def build_request( variable: str, @@ -201,122 +395,3 @@ def build_request( payload["area"] = area return payload - - @staticmethod - def _filename(variable: str, year: int, month: int, day: int = None, data_format: str = "grib") -> str: - """Get filename from variable name and date.""" - EXTENSION = {"grib": "grib", "netcdf": "nc"} - if day is not None: - return f"{variable}_{year}-{month:02}-{day:02}.{EXTENSION[data_format]}" - else: - return f"{variable}_{year}-{month:02}.{EXTENSION[data_format]}" - - def download(self, request: dict, dst_file: str | Path, overwrite: bool = False): - """Download Era5 product. - - Parameters - ---------- - request : dict - Request payload as returned by the build_request() method. - dst_file : Path - Output file path. - overwrite : bool, optional - Overwrite existing file (default=False). - """ - dst_file = Path(dst_file) - dst_file.parent.mkdir(parents=True, exist_ok=True) - - if dst_file.exists() and not overwrite: - log.debug("File %s already exists, skipping download", str(dst_file.absolute())) - return - - # if we request daily data while a monthly file is already present, also skip download - if len(request["day"]) == 1: - dst_file_monthly = Path( - dst_file.parent, self._filename(request["variable"], request["year"], request["month"]) - ) - if dst_file_monthly.exists() and not overwrite: - log.debug("Monthly file `{}` already exists, skipping download".format(dst_file_monthly.name)) - - with tempfile.NamedTemporaryFile() as tmp: - self.client.retrieve(name=DATASET, request=request, target=tmp.name) - shutil.copy(tmp.name, dst_file) - - log.debug("Downloaded Era5 product to %s", str(dst_file.absolute())) - - @staticmethod - def _period_chunks(start: datetime, end: datetime) -> list[dict]: - """Generate list of period chunks to prepare CDS API requests. - - If we can, prepare requests for full months to optimize wait times. If we can't, prepare - daily requests. - - Parameters - ---------- - start : datetime - Start date. - end : datetime - End date. - - Returns - ------- - list[dict] - List of period chunks as dicts with `year`, `month` and `days` keys. - """ - chunks = [] - date = start - while date <= end: - last_day_in_month = datetime(date.year, date.month, monthrange(date.year, date.month)[1]) - if last_day_in_month <= end: - chunks.append( - {"year": date.year, "month": date.month, "days": [day for day in range(1, last_day_in_month.day)]} - ) - date += relativedelta(months=1) - else: - chunks.append({"year": date.year, "month": date.month, "days": [date.day]}) - date += timedelta(days=1) - return chunks - - def download_between( - self, - variable: str, - start: datetime, - end: datetime, - dst_dir: str | Path, - area: list[float] = None, - overwrite: bool = False, - ): - """Download all ERA5 products between two dates. 
- - Parameters - ---------- - variable : str - Climate data store variable name (ex: "2m_temperature"). - start : datetime - Start date. - end : datetime - End date. - dst_dir : Path - Output directory. - area : list[float], optional - Area of interest (north, west, south, east). Defaults to None (world). - overwrite : bool, optional - Overwrite existing files (default=False). - """ - if end > self.latest: - end = self.latest - log.debug("End date is after latest available product, setting end date to %s", end) - - chunks = self._period_chunks(start, end) - - for chunk in chunks: - request = self.build_request( - variable=variable, year=chunk["year"], month=chunk["month"], days=chunk["days"], area=area - ) - - if len(chunk["days"]) == 1: - dst_file = Path(dst_dir, self._filename(variable, chunk["year"], chunk["month"], chunk["days"][0])) - else: - dst_file = Path(dst_dir, self._filename(variable, chunk["year"], chunk["month"])) - - self.download(request=request, dst_file=dst_file, overwrite=overwrite) From 1787618c45b8880ffc08fdf138f8232945327953 Mon Sep 17 00:00:00 2001 From: Yann Forget Date: Thu, 7 Nov 2024 01:58:45 +0100 Subject: [PATCH 05/18] fix(era5): check authentication --- openhexa/toolbox/era5/cds.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/openhexa/toolbox/era5/cds.py b/openhexa/toolbox/era5/cds.py index 8ed2197..c7d02ed 100644 --- a/openhexa/toolbox/era5/cds.py +++ b/openhexa/toolbox/era5/cds.py @@ -190,7 +190,7 @@ def date_range(start: date, end: date) -> list[date]: class Client: def __init__(self, key: str): self.client = ApiClient(key=key, url=URL) - self.check_authentication() + self.client.check_authentication() @cached_property def latest(self) -> datetime: @@ -298,7 +298,7 @@ def download_between(self, start: date, end: date, variable: str, area: list[flo remote = self.get_remote_from_request(request) if remote: remotes.append(remote) - log.debug(f"Found existing request for date {request["year"]}-{request["month"]}") + log.debug(f"Found existing request for date {request['year']}-{request['month']}") continue requests.append(self.submit(request)) From 5c9aa100cc6eaa0403b537a4eb9e490556bbce14 Mon Sep 17 00:00:00 2001 From: Yann Forget Date: Thu, 7 Nov 2024 02:01:48 +0100 Subject: [PATCH 06/18] fix(era5): cannot compare date and datetimes --- openhexa/toolbox/era5/cds.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openhexa/toolbox/era5/cds.py b/openhexa/toolbox/era5/cds.py index c7d02ed..5d9784e 100644 --- a/openhexa/toolbox/era5/cds.py +++ b/openhexa/toolbox/era5/cds.py @@ -279,7 +279,7 @@ def download_between(self, start: date, end: date, variable: str, area: list[flo dst_dir = Path(dst_dir) dst_dir.mkdir(parents=True, exist_ok=True) - if end > self.latest: + if end > self.latest.date(): end = self.latest.date() log.debug(f"End date is after latest available product, setting end date to {end.strftime('%Y-%m-%d')}") From 213b7c4adc26c82efbe6ebc48e2cc246d3b5f6c9 Mon Sep 17 00:00:00 2001 From: Yann Forget Date: Thu, 7 Nov 2024 02:02:49 +0100 Subject: [PATCH 07/18] fix(era5): wrong field name --- openhexa/toolbox/era5/cds.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openhexa/toolbox/era5/cds.py b/openhexa/toolbox/era5/cds.py index 5d9784e..588bd05 100644 --- a/openhexa/toolbox/era5/cds.py +++ b/openhexa/toolbox/era5/cds.py @@ -95,7 +95,7 @@ def get_period_chunk(dtimes: list[datetime]) -> dict: return { "year": year, "month": month, - "day": days, + "days": days, "time": [h for h in range(0, 24)], } 
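The asynchronous request workflow introduced in the patches above can also be driven step by step, outside of `download_between`. The sketch below only uses methods defined in this series (`build_request`, `submit`, `get_remote`, plus the `Remote.results_ready`, `download` and `delete` members that `download_between` relies on); the API key, the bounding box and the output path are placeholders, not values taken from the patches.

```python
from pathlib import Path
from time import sleep

from openhexa.toolbox.era5.cds import Client

client = Client(key="xxx")  # placeholder CDS API key

# one month of 2m temperature over a hypothetical bounding box (north, west, south, east)
request = client.build_request(
    variable="2m_temperature",
    year=2024,
    month=4,
    area=[16.0, -6.0, 9.0, 3.0],
)

# submit the request asynchronously, then poll until the CDS has produced the results
request_uid = client.submit(request)
remote = client.get_remote(request_uid)
while not remote.results_ready:
    sleep(60)

dst_file = Path("data/raw/2m_temperature_202404.grib")
dst_file.parent.mkdir(parents=True, exist_ok=True)
remote.download(dst_file.as_posix())
remote.delete()
```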
From db5dd787138d873002f5cbe35d3dca8d18224aff Mon Sep 17 00:00:00 2001 From: Yann Forget Date: Thu, 7 Nov 2024 02:08:59 +0100 Subject: [PATCH 08/18] fix(era5): string format --- openhexa/toolbox/era5/cds.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openhexa/toolbox/era5/cds.py b/openhexa/toolbox/era5/cds.py index 588bd05..44c7b8c 100644 --- a/openhexa/toolbox/era5/cds.py +++ b/openhexa/toolbox/era5/cds.py @@ -310,7 +310,7 @@ def download_between(self, start: date, end: date, variable: str, area: list[flo while not all([remote.request_uid in done for remote in remotes]): for remote in remotes: if remote.results_ready: - fname = f"{date.year}{date.month:02}_{remote.request_uid}.grib" + fname = f"{request["year"]}{request["month"]}_{remote.request_uid}.grib" dst_file = Path(dst_dir, fname) remote.download(dst_file.as_posix()) log.debug(f"Downloaded {dst_file.name}") From 365d2c6f78645e4e181ecf3e2df8c4b6f9d35c6b Mon Sep 17 00:00:00 2001 From: Yann Forget Date: Thu, 7 Nov 2024 02:51:40 +0100 Subject: [PATCH 09/18] fix(era5): check successful jobs 1st --- openhexa/toolbox/era5/cds.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/openhexa/toolbox/era5/cds.py b/openhexa/toolbox/era5/cds.py index 44c7b8c..a82bb79 100644 --- a/openhexa/toolbox/era5/cds.py +++ b/openhexa/toolbox/era5/cds.py @@ -201,10 +201,10 @@ def latest(self) -> datetime: dt = datetime(dt.year, dt.month, dt.day) return dt - def get_jobs(self) -> list[dict]: + def get_jobs(self, **kwargs) -> Optional[list[dict]]: """Get list of current jobs for the account in the CDS.""" - r = self.client.get_jobs() - return "jobs" in r.json.get("jobs") + r = self.client.get_jobs(limit=100, **kwargs) + return r.json.get("jobs") def get_remote(self, request_id: str) -> Remote: """Get remote object from request uid.""" @@ -228,6 +228,9 @@ def get_remote_from_request(self, request: dict, max_age: int = 1) -> Optional[R jobs = self.get_jobs() if not jobs: return None + + jobs = sorted(jobs, key=lambda job: job["status"], reverse=True) + for job in jobs: remote = self.get_remote(job["jobID"]) if remote.request == request: @@ -299,18 +302,18 @@ def download_between(self, start: date, end: date, variable: str, area: list[flo if remote: remotes.append(remote) log.debug(f"Found existing request for date {request['year']}-{request['month']}") - continue - - requests.append(self.submit(request)) + else: + requests.append(self.submit(request)) sleep(3) - remotes = [self.get_remote(request) for request in requests] + for request in requests: + remotes.append(self.get_remote(request)) done = [] while not all([remote.request_uid in done for remote in remotes]): for remote in remotes: if remote.results_ready: - fname = f"{request["year"]}{request["month"]}_{remote.request_uid}.grib" + fname = f"{request['year']}{request['month']}_{remote.request_uid}.grib" dst_file = Path(dst_dir, fname) remote.download(dst_file.as_posix()) log.debug(f"Downloaded {dst_file.name}") From 5f955eafcddc3ebbb80d397204513728b9b488ea Mon Sep 17 00:00:00 2001 From: Yann Forget Date: Thu, 7 Nov 2024 02:57:17 +0100 Subject: [PATCH 10/18] fix(era5): use remote request for filenames --- openhexa/toolbox/era5/cds.py | 1 + 1 file changed, 1 insertion(+) diff --git a/openhexa/toolbox/era5/cds.py b/openhexa/toolbox/era5/cds.py index a82bb79..81aa039 100644 --- a/openhexa/toolbox/era5/cds.py +++ b/openhexa/toolbox/era5/cds.py @@ -313,6 +313,7 @@ def download_between(self, start: date, end: date, variable: str, area: list[flo while not 
all([remote.request_uid in done for remote in remotes]): for remote in remotes: if remote.results_ready: + request = remote.request fname = f"{request['year']}{request['month']}_{remote.request_uid}.grib" dst_file = Path(dst_dir, fname) remote.download(dst_file.as_posix()) From 1d27c460feadafb63d059ca7c202712f6c54766f Mon Sep 17 00:00:00 2001 From: Yann Forget Date: Fri, 13 Dec 2024 12:41:18 +0100 Subject: [PATCH 11/18] feat(era5): upgrade to datapi cds client --- openhexa/toolbox/era5/cds.py | 353 ++++++++++++++++++----------------- pyproject.toml | 2 + 2 files changed, 185 insertions(+), 170 deletions(-) diff --git a/openhexa/toolbox/era5/cds.py b/openhexa/toolbox/era5/cds.py index 81aa039..92d96e0 100644 --- a/openhexa/toolbox/era5/cds.py +++ b/openhexa/toolbox/era5/cds.py @@ -4,31 +4,40 @@ import json import logging from calendar import monthrange -from datetime import date, datetime, timedelta +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone from functools import cached_property from math import ceil from pathlib import Path from time import sleep -from typing import Optional, Union +from typing import Iterator import geopandas as gpd -import numpy as np import xarray as xr -from cads_api_client import ApiClient, Remote, Results +from datapi import ApiClient, Remote +from requests.exceptions import HTTPError with importlib.resources.open_text("openhexa.toolbox.era5", "variables.json") as f: VARIABLES = json.load(f) DATASET = "reanalysis-era5-land" -logging.basicConfig(level=logging.DEBUG, format="%(name)s %(asctime)s %(levelname)s %(message)s") log = logging.getLogger(__name__) URL = "https://cds-beta.climate.copernicus.eu/api" -class ParameterError(ValueError): - pass +@dataclass +class DataRequest: + """CDS data request as expected by the API.""" + + variable: list[str] + year: str + month: str + day: list[str] + time: list[str] + data_format: str = "grib" + area: list[float] | None = None def bounds_from_file(fp: Path, buffer: float = 0.5) -> list[float]: @@ -48,10 +57,10 @@ def bounds_from_file(fp: Path, buffer: float = 0.5) -> list[float]: """ boundaries = gpd.read_parquet(fp) xmin, ymin, xmax, ymax = boundaries.total_bounds - xmin = ceil(xmin - 0.5) - ymin = ceil(ymin - 0.5) - xmax = ceil(xmax + 0.5) - ymax = ceil(ymax + 0.5) + xmin = ceil(xmin - buffer) + ymin = ceil(ymin - buffer) + xmax = ceil(xmax + buffer) + ymax = ceil(ymax + buffer) return ymax, xmin, ymin, xmax @@ -62,6 +71,8 @@ def get_period_chunk(dtimes: list[datetime]) -> dict: the CDS API. A period chunk cannot contain more than 1 year and 1 month. However, it can contain any number of days and times. + This is the temporal part of a CDS data request. 
+ Parameters ---------- dtimes : list[datetime] @@ -77,15 +88,17 @@ def get_period_chunk(dtimes: list[datetime]) -> dict: ValueError If the list of datetimes contains more than 1 year or more than 1 month """ - years = list(set([dtime.year for dtime in dtimes])) - if len(set(years)) > 1: - raise ValueError("Cannot create a period chunk for multiple years") - months = list(set([dtime.month for dtime in dtimes])) + years = {dtime.year for dtime in dtimes} + if len(years) > 1: + msg = "Cannot create a period chunk for multiple years" + raise ValueError(msg) + months = {dtime.month for dtime in dtimes} if len(months) > 1: - raise ValueError("Cannot create a period chunk for multiple months") + msg = "Cannot create a period chunk for multiple months" + raise ValueError(msg) - year = years[0] - month = months[0] + year = next(iter(years)) + month = next(iter(months)) days = [] for dtime in sorted(dtimes): @@ -93,21 +106,21 @@ def get_period_chunk(dtimes: list[datetime]) -> dict: days.append(dtime.day) return { - "year": year, - "month": month, - "days": days, - "time": [h for h in range(0, 24)], + "year": str(year), + "month": f"{month:02}", + "day": [f"{day:02}" for day in days], + "time": [f"{hour:02}:00" for hour in range(24)], } -def get_period_chunks(dtimes: list[datetime]) -> list[dict]: +def iter_chunks(dtimes: list[datetime]) -> Iterator[dict]: """Get the period chunks for a list of datetimes. The period chunks are a list of dictionaries with the "year", "month", "day" and "time" keys as expected by the CDS API. A period chunk cannot contain more than 1 year and 1 month. However, it can contain any number of days and times. - The function tries its best to generate the minimum amount of chunks to minize the amount of requests. + The function tries its best to generate the minimum amount of chunks to minimize the amount of requests. Parameters ---------- @@ -116,30 +129,21 @@ def get_period_chunks(dtimes: list[datetime]) -> list[dict]: Returns ------- - list[dict] + Iterator[dict] The period chunks (one per month max) """ - chunks = [] for year in range(min(dtimes).year, max(dtimes).year + 1): - for month in range(1, 13): - dtimes_month = [dtime for dtime in dtimes if dtime.year == year and dtime.month == month] + for month in range(12): + dtimes_month = [dtime for dtime in dtimes if dtime.year == year and dtime.month == month + 1] if dtimes_month: - chunk = get_period_chunk(dtimes_month) - chunks.append(chunk) - return chunks - + yield get_period_chunk(dtimes_month) -def _np_to_datetime(dt64: np.datetime64) -> datetime: - epoch = np.datetime64(0, "s") - one_second = np.timedelta64(1, "s") - seconds_since_epoch = (dt64 - epoch) / one_second - return datetime.fromtimestamp(seconds_since_epoch) - -def available_datetimes(data_dir: Path) -> list[date]: +def available_datetimes(data_dir: Path) -> list[datetime]: """Get available datetimes from a directory of ERA5 data files. - Dates are considered as available if data for all 24 hours of the day are found in the file. + Dates are considered as available if non-null values are present for the day for more than 1 step. + Assumes data files are stored as .grib files. Parameters ---------- @@ -148,36 +152,32 @@ def available_datetimes(data_dir: Path) -> list[date]: Returns ------- - list[date] + list[datetime] List of available dates. 
""" dtimes = [] for f in data_dir.glob("*.grib"): ds = xr.open_dataset(f, engine="cfgrib") - var = [v for v in ds.data_vars][0] + data_vars = list(ds.data_vars) + var = data_vars[0] - for time in ds.time: - dtime = _np_to_datetime(time.values).date() + for time in ds.time.values: + dtime = datetime.fromtimestamp(time.astype(int) / 1e9, tz=timezone.utc) if dtime in dtimes: continue - - is_complete = True - for hour in range(1, 25): - step = timedelta(hours=hour) - if not ds.sel(time=time, step=step).get(var).notnull().any().values.item(): - is_complete = False - break - - if is_complete: + non_null = ds.sel(time=time)[var].notnull().sum().values.item() + non_null /= len(ds.latitude) * len(ds.longitude) + if non_null > 1: dtimes.append(dtime) - log.debug(f"Scanned {data_dir.as_posix()}, found {len(dtimes)} available dates") + msg = f"Scanned {data_dir.as_posix()}, found data for {len(dtimes)} dates" + log.info(msg) return dtimes -def date_range(start: date, end: date) -> list[date]: +def date_range(start: datetime, end: datetime) -> list[datetime]: """Get a range of dates with a 1-day step.""" drange = [] dt = start @@ -187,79 +187,71 @@ def date_range(start: date, end: date) -> list[date]: return drange -class Client: - def __init__(self, key: str): +class CDS: + """Climate data store API client based on datapi.""" + + def __init__(self, key: str) -> None: + """Initialize CDS client.""" self.client = ApiClient(key=key, url=URL) self.client.check_authentication() + msg = f"Sucessfully authenticated to {URL}" + log.info(msg) @cached_property def latest(self) -> datetime: """Get date of latest available product.""" collection = self.client.get_collection(DATASET) - dt = collection.end_datetime - # make datetime unaware of timezone for comparability with other datetimes - dt = datetime(dt.year, dt.month, dt.day) - return dt - - def get_jobs(self, **kwargs) -> Optional[list[dict]]: - """Get list of current jobs for the account in the CDS.""" - r = self.client.get_jobs(limit=100, **kwargs) - return r.json.get("jobs") + return collection.end_datetime - def get_remote(self, request_id: str) -> Remote: - """Get remote object from request uid.""" - return self.client.get_remote(request_id) + def get_remote_requests(self) -> list[dict]: + """Fetch list of the last 100 data requests in the CDS account.""" + requests = [] + jobs = self.client.get_jobs(limit=100) + for request_id in jobs.request_uids: + try: + remote = self.client.get_remote(request_id) + if remote.status in ["failed", "dismissed", "deleted"]: + continue + requests.append({"request_id": request_id, "request": remote.request}) + except HTTPError: + continue + return requests - def get_remote_from_request(self, request: dict, max_age: int = 1) -> Optional[Remote]: + def get_remote_from_request(self, request: DataRequest, existing_requests: list[dict]) -> Remote | None: """Look for a remote object that matches the provided request payload. Parameters ---------- - request : dict - Request payload. - max_age : int, optional - Maximum age of the remote object in days (default=1). + request : DataRequest + Data request payload to look for. + existing_requests : list[dict] + List of existing data requests (as returned by self.get_remote_requests()). Returns ------- - Optional[Remote] + Remote | None Remote object if found, None otherwise. 
""" - jobs = self.get_jobs() - if not jobs: + if not existing_requests: return None - jobs = sorted(jobs, key=lambda job: job["status"], reverse=True) + for remote_request in existing_requests: + if remote_request["request"] == request.__dict__: + return self.client.get_remote(remote_request["request_id"]) - for job in jobs: - remote = self.get_remote(job["jobID"]) - if remote.request == request: - age = datetime.now() - remote.creation_datetime - if age.days <= max_age: - return remote return None - def submit(self, request: dict) -> str: - """Submit an async data request to the CDS API.""" - r = self.client.submit(DATASET, **request) - log.debug("Submitted data request %s", r.request_uid) - return r.request_uid - - def submit_and_wait(self, request: dict) -> Results: - """Submit a data request and wait for completion.""" - result = self.client.submit_and_wait_on_results(DATASET, **request) - return result - - def download(self, request: dict, dst_file: Union[str, Path]): - """Submit a data request and wait for completion before download.""" - if isinstance(dst_file, str): - dst_file = Path(dst_file) - dst_file.parent.mkdir(parents=True, exist_ok=True) - result = self.submit_and_wait(request) - result.download(dst_file.as_posix()) - log.debug("Downloaded %s", dst_file.name) - - def download_between(self, start: date, end: date, variable: str, area: list[float], dst_dir: Union[str, Path]): + def submit(self, request: DataRequest) -> Remote: + """Submit an async data request to the CDS API. + + If an identical data request has already been submitted, the Remote object corresponding to + the existing data request is returned instead of submitting a new one. + """ + return self.client.submit(DATASET, **request.__dict__) + + def download_between( + self, start: datetime, end: datetime, variable: str, area: list[float], dst_dir: str | Path + ) -> None: """Download all ERA5 data files needed to cover the period. Data requests are sent asynchronously (max one per month) to the CDS API and fetched when @@ -267,71 +259,83 @@ def download_between(self, start: date, end: date, variable: str, area: list[flo Parameters ---------- - start : date + start : datetime Start date. - end : date + end : datetime End date. variable : str Climate data store variable name (ex: "2m_temperature"). area : list[float] Area of interest (north, west, south, east). - dst_dir : Path + dst_dir : str | Path Output directory. 
""" - if isinstance(dst_dir, str): - dst_dir = Path(dst_dir) + dst_dir = Path(dst_dir) dst_dir.mkdir(parents=True, exist_ok=True) - if end > self.latest.date(): - end = self.latest.date() - log.debug(f"End date is after latest available product, setting end date to {end.strftime('%Y-%m-%d')}") + if not start.tzinfo: + start = start.astimezone(tz=timezone.utc) + if not end.tzinfo: + end = end.astimezone(tz=timezone.utc) + + if end > self.latest: + end = self.latest + msg = "End date is after latest available product, setting end date to {}".format(end.strftime("%Y-%m-%d")) + log.info(msg) + # get the list of dates for which we will want to download data, which is the difference + # between the available (already downloaded) and the requested dates drange = date_range(start, end) - available = available_datetimes(dst_dir) - dates = [d for d in drange if d not in available] + available = [dtime.date() for dtime in available_datetimes(dst_dir)] + dates = [d for d in drange if d.date() not in available] + msg = f"Will request data for {len(dates)} dates" + log.info(msg) - chunks = get_period_chunks(dates) - requests = [] - remotes = [] + existing_requests = self.get_remote_requests() + remotes: list[Remote] = [] - for chunk in chunks: + for chunk in iter_chunks(dates): request = self.build_request(variable=variable, data_format="grib", area=area, **chunk) - # has a similar request been submitted recently? if yes, use it - remote = self.get_remote_from_request(request) + # has a similar request been submitted recently? if yes, use it instead of submitting + # a new one + remote = self.get_remote_from_request(request, existing_requests) if remote: remotes.append(remote) - log.debug(f"Found existing request for date {request['year']}-{request['month']}") + msg = f"Found existing request for date {request.year}-{request.month}" + log.info(msg) else: - requests.append(self.submit(request)) - sleep(3) - - for request in requests: - remotes.append(self.get_remote(request)) - done = [] + remote = self.submit(request) + remotes.append(remote) + msg = f"Submitted new data request {remote.request_uid} for {request.year}-{request.month}" - while not all([remote.request_uid in done for remote in remotes]): + while remotes: for remote in remotes: if remote.results_ready: request = remote.request fname = f"{request['year']}{request['month']}_{remote.request_uid}.grib" dst_file = Path(dst_dir, fname) remote.download(dst_file.as_posix()) - log.debug(f"Downloaded {dst_file.name}") - done.append(remote.request_uid) + msg = f"Downloaded {dst_file.name}" + log.info(msg) + remotes.remove(remote) remote.delete() - sleep(60) + + if remotes: + msg = f"Still {len(remotes)} files to download. Waiting 30s before retrying..." + log.info(msg) + sleep(30) @staticmethod def build_request( variable: str, year: int, month: int, - days: list[int] = None, - time: list[int] = None, + day: list[int] | list[str] | None = None, + time: list[int] | list[str] | None = None, data_format: str = "grib", - area: list[float] = None, - ) -> dict: + area: list[float] | None = None, + ) -> DataRequest: """Build request payload. Parameters @@ -342,60 +346,69 @@ def build_request( Year of interest. month : int Month of interest. - days : list[int] + day : list[int] | list[str] | None, optional Days of interest. Defauls to None (all days). - time : list[int] + time : list[int] | list[str] | None, optional Hours of interest (ex: [1, 6, 18]). Defaults to None (all hours). 
- data_format : str + data_format : str, optional Output data format ("grib" or "netcdf"). Defaults to "grib". - area : list[float] + area : list[float] | None, optional Area of interest (north, west, south, east). Defaults to None (world). Returns ------- - dict - Request payload. + DataRequest + CDS data equest payload. Raises ------ - ParameterError + ValueError Request parameters are not valid. """ if variable not in VARIABLES: - raise ParameterError("Variable %s not supported", variable) + msg = f"Variable {variable} not supported" + raise ValueError(msg) if data_format not in ["grib", "netcdf"]: - raise ParameterError("Data format %s not supported", data_format) + msg = f"Data format {data_format} not supported" + raise ValueError(msg) + # in the CDS data request, area is an array of float or int in the following order: + # [north, west, south, east] if area: n, w, s, e = area - if ((abs(n) > 90) or (abs(s) > 90)) or ((abs(w) > 180) or (abs(e) > 180)): - raise ParameterError("Invalid area of interest") + msg = "Invalid area of interest" + max_lat = 90 + max_lon = 180 + if ((abs(n) > max_lat) or (abs(s) > max_lat)) or ((abs(w) > max_lon) or (abs(e) > max_lon)): + raise ValueError(msg) if (n < s) or (e < w): - raise ParameterError("Invalid area of interest") + raise ValueError(msg) - if not days: + # in the CDS data request, days must be an array of strings (one string per day) + # ex: ["01", "02", "03"] + if not day: dmax = monthrange(year, month)[1] - days = [day for day in range(1, dmax + 1)] + day = list(range(1, dmax + 1)) - if not time: - time = [hour for hour in range(0, 24)] - time = [f"{hour:02}:00" for hour in time] - - year = str(year) - month = f"{month:02}" - days = [f"{day:02}" for day in days] - - payload = { - "variable": [variable], - "year": year, - "month": month, - "day": days, - "time": time, - "data_format": data_format, - } + if isinstance(day[0], int): + day = [f"{d:02}" for d in day] - if area: - payload["area"] = area - - return payload + # in the CDS data request, time must be an array of strings (one string per hour) + # only hours between 00:00 and 23:00 are valid + # ex: ["00:00", "03:00", "06:00"] + if not time: + time = range(24) + + if isinstance(time[0], int): + time = [f"{hour:02}:00" for hour in time] + + return DataRequest( + variable=[variable], + year=str(year), + month=f"{month:02}", + day=day, + time=time, + data_format="grib", + area=list(area) if area else None, + ) diff --git a/pyproject.toml b/pyproject.toml index 50dd906..38ff5f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,8 @@ dependencies = [ "cfgrib", "xarray", "epiweeks", + "datapi >=0.1.1", + "multiurl >=0.3.2" ] [project.optional-dependencies] From ddc1d565603ddd3eef905e8524df4194353cd383 Mon Sep 17 00:00:00 2001 From: Yann Forget Date: Fri, 13 Dec 2024 12:42:52 +0100 Subject: [PATCH 12/18] docs(era5): missing docstring --- openhexa/toolbox/era5/cds.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/openhexa/toolbox/era5/cds.py b/openhexa/toolbox/era5/cds.py index 92d96e0..387ce03 100644 --- a/openhexa/toolbox/era5/cds.py +++ b/openhexa/toolbox/era5/cds.py @@ -1,3 +1,8 @@ +"""Client to download ERA5-Land data products from the climate data store. + +See . 
+""" + from __future__ import annotations import importlib.resources From 68bf773690e66a33d7f2f338ae4fb50ba4f234c9 Mon Sep 17 00:00:00 2001 From: Yann Forget Date: Fri, 13 Dec 2024 15:01:48 +0100 Subject: [PATCH 13/18] refactor(era5): move build_request out of class --- openhexa/toolbox/era5/cds.py | 176 +++++++++++++++++------------------ 1 file changed, 88 insertions(+), 88 deletions(-) diff --git a/openhexa/toolbox/era5/cds.py b/openhexa/toolbox/era5/cds.py index 387ce03..98bf79b 100644 --- a/openhexa/toolbox/era5/cds.py +++ b/openhexa/toolbox/era5/cds.py @@ -192,6 +192,93 @@ def date_range(start: datetime, end: datetime) -> list[datetime]: return drange +def build_request( + variable: str, + year: int, + month: int, + day: list[int] | list[str] | None = None, + time: list[int] | list[str] | None = None, + data_format: str = "grib", + area: list[float] | None = None, +) -> DataRequest: + """Build request payload. + + Parameters + ---------- + variable : str + Climate data store variable name (ex: "2m_temperature"). + year : int + Year of interest. + month : int + Month of interest. + day : list[int] | list[str] | None, optional + Days of interest. Defauls to None (all days). + time : list[int] | list[str] | None, optional + Hours of interest (ex: [1, 6, 18]). Defaults to None (all hours). + data_format : str, optional + Output data format ("grib" or "netcdf"). Defaults to "grib". + area : list[float] | None, optional + Area of interest (north, west, south, east). Defaults to None (world). + + Returns + ------- + DataRequest + CDS data equest payload. + + Raises + ------ + ValueError + Request parameters are not valid. + """ + if variable not in VARIABLES: + msg = f"Variable {variable} not supported" + raise ValueError(msg) + + if data_format not in ["grib", "netcdf"]: + msg = f"Data format {data_format} not supported" + raise ValueError(msg) + + # in the CDS data request, area is an array of float or int in the following order: + # [north, west, south, east] + if area: + n, w, s, e = area + msg = "Invalid area of interest" + max_lat = 90 + max_lon = 180 + if ((abs(n) > max_lat) or (abs(s) > max_lat)) or ((abs(w) > max_lon) or (abs(e) > max_lon)): + raise ValueError(msg) + if (n < s) or (e < w): + raise ValueError(msg) + + # in the CDS data request, days must be an array of strings (one string per day) + # ex: ["01", "02", "03"] + if not day: + dmax = monthrange(year, month)[1] + day = list(range(1, dmax + 1)) + + if isinstance(day[0], int): + day = [f"{d:02}" for d in day] + + # in the CDS data request, time must be an array of strings (one string per hour) + # only hours between 00:00 and 23:00 are valid + # ex: ["00:00", "03:00", "06:00"] + if not time: + time = range(24) + + if isinstance(time[0], int): + time = [f"{hour:02}:00" for hour in time] + + return DataRequest( + variable=[variable], + year=str(year), + month=f"{month:02}", + day=day, + time=time, + data_format="grib", + area=list(area) if area else None, + ) + + class CDS: """Climate data store API client based on datapi.""" @@ -300,7 +387,7 @@ def download_between( remotes: list[Remote] = [] for chunk in iter_chunks(dates): - request = self.build_request(variable=variable, data_format="grib", area=area, **chunk) + request = build_request(variable=variable, data_format="grib", area=area, **chunk) # has a similar request been submitted recently? if yes, use it instead of submitting # a new one @@ -330,90 +417,3 @@ def download_between( msg = f"Still {len(remotes)} files to download. Waiting 30s before retrying..." 
log.info(msg) sleep(30) - - @staticmethod - def build_request( - variable: str, - year: int, - month: int, - day: list[int] | list[str] | None = None, - time: list[int] | list[str] | None = None, - data_format: str = "grib", - area: list[float] | None = None, - ) -> DataRequest: - """Build request payload. - - Parameters - ---------- - variable : str - Climate data store variable name (ex: "2m_temperature"). - year : int - Year of interest. - month : int - Month of interest. - day : list[int] | list[str] | None, optional - Days of interest. Defauls to None (all days). - time : list[int] | list[str] | None, optional - Hours of interest (ex: [1, 6, 18]). Defaults to None (all hours). - data_format : str, optional - Output data format ("grib" or "netcdf"). Defaults to "grib". - area : list[float] | None, optional - Area of interest (north, west, south, east). Defaults to None (world). - - Returns - ------- - DataRequest - CDS data equest payload. - - Raises - ------ - ValueError - Request parameters are not valid. - """ - if variable not in VARIABLES: - msg = f"Variable {variable} not supported" - raise ValueError(msg) - - if data_format not in ["grib", "netcdf"]: - msg = f"Data format {data_format} not supported" - raise ValueError(msg) - - # in the CDS data request, area is an array of float or int in the following order: - # [north, west, south, east] - if area: - n, w, s, e = area - msg = "Invalid area of interest" - max_lat = 90 - max_lon = 180 - if ((abs(n) > max_lat) or (abs(s) > max_lat)) or ((abs(w) > max_lon) or (abs(e) > max_lon)): - raise ValueError(msg) - if (n < s) or (e < w): - raise ValueError(msg) - - # in the CDS data request, days must be an array of strings (one string per day) - # ex: ["01", "02", "03"] - if not day: - dmax = monthrange(year, month)[1] - day = list(range(1, dmax + 1)) - - if isinstance(day[0], int): - day = [f"{d:02}" for d in day] - - # in the CDS data request, time must be an array of strings (one string per hour) - # only hours between 00:00 and 23:00 are valid - # ex: ["00:00", "03:00", "06:00"] - if not time: - time = range(24) - - if isinstance(time[0], int): - time = [f"{hour:02}:00" for hour in time] - - return DataRequest( - variable=[variable], - year=str(year), - month=f"{month:02}", - day=day, - time=time, - data_format="grib", - area=list(area) if area else None, - ) From 9136114af44550668a5d93d846f4a55899f8ad04 Mon Sep 17 00:00:00 2001 From: Yann Forget Date: Fri, 13 Dec 2024 17:09:17 +0100 Subject: [PATCH 14/18] refactor(era5): add api url as parameter --- openhexa/toolbox/era5/cds.py | 57 ++++++++++++++++++------------------ 1 file changed, 28 insertions(+), 29 deletions(-) diff --git a/openhexa/toolbox/era5/cds.py b/openhexa/toolbox/era5/cds.py index 98bf79b..7fb63fa 100644 --- a/openhexa/toolbox/era5/cds.py +++ b/openhexa/toolbox/era5/cds.py @@ -29,8 +29,6 @@ log = logging.getLogger(__name__) -URL = "https://cds-beta.climate.copernicus.eu/api" - @dataclass class DataRequest: @@ -144,37 +142,38 @@ def iter_chunks(dtimes: list[datetime]) -> Iterator[dict]: yield get_period_chunk(dtimes_month) -def available_datetimes(data_dir: Path) -> list[datetime]: - """Get available datetimes from a directory of ERA5 data files. +def list_datetimes_in_dataset(ds: xr.Dataset) -> list[datetime]: + """List datetimes in input dataset for which data is available. - Dates are considered as available if non-null values are present for the day for more than 1 step. - Assumes data files are stored as .grib files. 
+ It is assumed that the dataset has a `time` dimension, in addition to `latitude` and `longitude` + dimensions. We consider that a datetime is available in a dataset if non-null data values are + present for more than 1 step. + """ + dtimes = [] + data_vars = list(ds.data_vars) + var = data_vars[0] + + for time in ds.time.values: + dtime = datetime.fromtimestamp(time.astype(int) / 1e9, tz=timezone.utc) + if dtime in dtimes: + continue + non_null = ds.sel(time=time)[var].notnull().sum().values.item() + non_null /= len(ds.latitude) * len(ds.longitude) + if non_null >= len(ds.step): + dtimes.append(dtime) - Parameters - ---------- - data_dir : Path - Directory containing the ERA5 data files. + return dtimes - Returns - ------- - list[datetime] - List of available dates. - """ + +def list_datetimes_in_dir(data_dir: Path) -> list[datetime]: + """List datetimes in datasets that can be found in an input directory.""" dtimes = [] for f in data_dir.glob("*.grib"): ds = xr.open_dataset(f, engine="cfgrib") - data_vars = list(ds.data_vars) - var = data_vars[0] + dtimes += list_datetimes_in_dataset(ds) - for time in ds.time.values: - dtime = datetime.fromtimestamp(time.astype(int) / 1e9, tz=timezone.utc) - if dtime in dtimes: - continue - non_null = ds.sel(time=time)[var].notnull().sum().values.item() - non_null /= len(ds.latitude) * len(ds.longitude) - if non_null > 1: - dtimes.append(dtime) + dtimes = sorted(set(dtimes)) msg = f"Scanned {data_dir.as_posix()}, found data for {len(dtimes)} dates" log.info(msg) @@ -282,11 +281,11 @@ def build_request( class CDS: """Climate data store API client based on datapi.""" - def __init__(self, key: str) -> None: + def __init__(self, key: str, url: str = "https://cds-beta.climate.copernicus.eu/api") -> None: """Initialize CDS client.""" - self.client = ApiClient(key=key, url=URL) + self.client = ApiClient(key=key, url=url) self.client.check_authentication() - msg = f"Sucessfully authenticated to {URL}" + msg = f"Sucessfully authenticated to {url}" log.info(msg) @cached_property @@ -378,7 +377,7 @@ def download_between( # get the list of dates for which we will want to download data, which is the difference # between the available (already downloaded) and the requested dates drange = date_range(start, end) - available = [dtime.date() for dtime in available_datetimes(dst_dir)] + available = [dtime.date() for dtime in list_datetimes_in_dir(dst_dir)] dates = [d for d in drange if d.date() not in available] msg = f"Will request data for {len(dates)} dates" log.info(msg) From 76250ec282d166f62ca57499f51e6545e842fddc Mon Sep 17 00:00:00 2001 From: Yann Forget Date: Fri, 13 Dec 2024 18:30:32 +0100 Subject: [PATCH 15/18] test(era5): add tests for cds module --- tests/era5/test_cds.py | 165 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 165 insertions(+) create mode 100644 tests/era5/test_cds.py diff --git a/tests/era5/test_cds.py b/tests/era5/test_cds.py new file mode 100644 index 0000000..1d48356 --- /dev/null +++ b/tests/era5/test_cds.py @@ -0,0 +1,165 @@ +"""Unit tests for the ERA5 Climate Data Store client.""" + +from __future__ import annotations + +import tempfile +from dataclasses import dataclass +from datetime import datetime, timezone +from pathlib import Path +from unittest.mock import MagicMock, Mock, patch + +import datapi +import pytest + +from openhexa.toolbox.era5.cds import ( + CDS, + DataRequest, + bounds_from_file, + date_range, + get_period_chunk, + iter_chunks, +) + + +class TestCollection(datapi.catalogue.Collection): + """Datapi 
Collection object with mocked end_datetime property.""" + + __test__ = False + + def __init__(self, end_datetime: datetime) -> None: + self._end_datetime = end_datetime + + @property + def end_datetime(self): + return self._end_datetime + + +@patch("datapi.ApiClient.check_authentication") +def test_cds_init(mock_check_authentication: Mock): + """Test CDS class initialization.""" + mock_check_authentication.return_value = True + CDS(key="xxx") + + +@pytest.fixture +@patch("datapi.ApiClient.check_authentication") +def fake_cds(mock_check_authentication: Mock): + mock_check_authentication.return_value = True + return CDS(key="xxx") + + +@patch("datapi.ApiClient.get_collection") +def test_latest(mock_get_collection: Mock, fake_cds: CDS): + mock_get_collection.return_value = TestCollection(end_datetime=datetime(2023, 1, 1).astimezone()) + assert fake_cds.latest == datetime(2023, 1, 1).astimezone() + + +class TestJobs(datapi.processing.Jobs): + """Datapi Jobs class with mocked request_uids property.""" + + __test__ = False + + def __init__(self, request_uids: list[str]) -> None: + self._request_uids = request_uids + + @property + def request_uids(self): + return self._request_uids + + +class TestRemote(datapi.processing.Remote): + """Datapi Remote class with mocked properties.""" + + __test__ = False + + def __init__(self, request_uid: str, status: str, results_ready: bool, request: dict) -> None: + self._request_uid = request_uid + self._status = status + self._results_ready = results_ready + self._request = request + self.cleanup = False + + @property + def status(self): + return self._status + + @property + def request_uid(self): + return self._request_uid + + @property + def results_ready(self): + return self._results_ready + + @property + def request(self): + return self._request + + +@patch("datapi.ApiClient.get_jobs") +@patch("datapi.ApiClient.get_remote") +def test_cds_get_remote_requests(mock_get_remote: Mock, mock_get_jobs: Mock, fake_cds: CDS): + mock_get_jobs.return_value = TestJobs( + request_uids=[ + "73dc0d2d-8288-4041-a84d-87e70772d5a8", + "3973ec55-4b38-449b-b7f1-5edd1034f663", + "a5c7093d-56d9-40a4-a363-c60cd242ce66", + ] + ) + + mock_get_remote.return_value = TestRemote( + request_uid="73dc0d2d-8288-4041-a84d-87e70772d5a8", status="successful", results_ready=True, request={} + ) + + remote_requests = fake_cds.get_remote_requests() + + assert len(remote_requests) == 3 + assert remote_requests[0]["request_id"] == "73dc0d2d-8288-4041-a84d-87e70772d5a8" + assert isinstance(remote_requests[0]["request"], dict) + + +@pytest.fixture +def tp_request() -> DataRequest: + return DataRequest( + variable=["total_precipitation"], + year="2024", + month="12", + day=["01", "02", "03", "04", "05"], + time=["01:00", "06:00", "18:00"], + data_format="grib", + area=[16, -6, 9, 3], + ) + + +@pytest.fixture +def tp_request_remote() -> dict: + return { + "request_id": "73dc0d2d-8288-4041-a84d-87e70772d5a8", + "request": { + "day": ["01", "02", "03", "04", "05"], + "area": [16, -6, 9, 3], + "time": ["01:00", "06:00", "18:00"], + "year": "2024", + "month": "12", + "variable": ["total_precipitation"], + "data_format": "grib", + }, + } + + +@patch("datapi.ApiClient.get_remote") +def test_cds_get_remote_from_request( + mock_get_remote: Mock, fake_cds: CDS, tp_request: DataRequest, tp_request_remote: dict +): + mock_get_remote.return_value = TestRemote( + request_uid="73dc0d2d-8288-4041-a84d-87e70772d5a8", + status="successful", + results_ready=True, + request=tp_request_remote, + ) + + existing_requests 
= [tp_request_remote] + remote = fake_cds.get_remote_from_request(tp_request, existing_requests=existing_requests) + assert remote + assert remote.request_uid == "73dc0d2d-8288-4041-a84d-87e70772d5a8" + assert remote.request["request"] == tp_request.__dict__ From a6ec824cdf9769d3d476c7abb67a2af19935bdf7 Mon Sep 17 00:00:00 2001 From: Yann Forget Date: Fri, 13 Dec 2024 18:38:02 +0100 Subject: [PATCH 16/18] feat(era5): download single product --- openhexa/toolbox/era5/cds.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/openhexa/toolbox/era5/cds.py b/openhexa/toolbox/era5/cds.py index 7fb63fa..6947b8e 100644 --- a/openhexa/toolbox/era5/cds.py +++ b/openhexa/toolbox/era5/cds.py @@ -340,6 +340,12 @@ def submit(self, request: DataRequest) -> Remote: """ return self.client.submit(DATASET, **request.__dict__) + def retrieve(self, request: DataRequest, dst_file: Path | str) -> None: + """Submit and download a data request to the CDS API.""" + dst_file = Path(dst_file) + dst_file.parent.mkdir(parents=True, exist_ok=True) + self.client.retrieve(collection_id=DATASET, target=dst_file, **request.__dict__) + def download_between( self, start: datetime, end: datetime, variable: str, area: list[float], dst_dir: str | Path ) -> None: From 6ebc1e72f213a1d364ec7d2f46069275a42f8e2f Mon Sep 17 00:00:00 2001 From: Yann Forget Date: Fri, 13 Dec 2024 18:38:27 +0100 Subject: [PATCH 17/18] docs(era5): update examples in README --- openhexa/toolbox/era5/README.md | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/openhexa/toolbox/era5/README.md b/openhexa/toolbox/era5/README.md index 0b0e651..7e6ec0c 100644 --- a/openhexa/toolbox/era5/README.md +++ b/openhexa/toolbox/era5/README.md @@ -31,19 +31,21 @@ The package contains 3 modules: To download products from the Climate Data Store, you will need to create an account and generate an API key in ECMWF (see [CDS](https://cds.climate.copernicus.eu/)). ```python -from openhexa.toolbox.era5.cds import Client +from openhexa.toolbox.era5.cds import CDS, build_request, bounds_from_file -cds = Client(key="") +cds = CDS(key="") -request = cds.build_request( +request = build_request( variable="2m_temperature", year=2024, - month=4 + month=4, + day=[1, 2, 3], + time=[1, 6, 12, 18] ) -cds.download( +cds.retrieve( request=request, - dst_file="data/product.grib" + dst_file="data/t2m.grib" ) ``` @@ -54,7 +56,7 @@ downloaded. 
```python bounds = bounds_from_file(fp=Path("data/districts.parquet"), buffer=0.5) -request = cds.build_request( +request = build_request( variable="total_precipitation", year=2023, month=10, @@ -62,7 +64,7 @@ request = cds.build_request( area=bounds ) -cds.download( +cds.retrieve( request=request, dst_file="data/product.grib" ) @@ -73,8 +75,8 @@ To download multiple products for a given period, use `Client.download_between() ```python cds.download_between( variable="2m_temperature", - start=datetime(2020, 1, 1), - end=datetime(2021, 6, 1), + start=datetime(2020, 1, 1, tzinfo=timezone.utc), + end=datetime(2021, 6, 1, tzinfo=timezone.utc), dst_dir="data/raw/2m_temperature", area=bounds ) @@ -83,7 +85,7 @@ cds.download_between( Checking latest available date in the ERA5-Land dataset: ```python -cds = Client("") +cds = CDS("") cds.latest ``` From 2618f2dba4037401994ca15b8eb06dbd33fa90c7 Mon Sep 17 00:00:00 2001 From: Yann Forget Date: Fri, 13 Dec 2024 18:44:31 +0100 Subject: [PATCH 18/18] style(era5): remove unused imports --- tests/era5/test_cds.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/tests/era5/test_cds.py b/tests/era5/test_cds.py index 1d48356..ae35e20 100644 --- a/tests/era5/test_cds.py +++ b/tests/era5/test_cds.py @@ -2,11 +2,8 @@ from __future__ import annotations -import tempfile -from dataclasses import dataclass -from datetime import datetime, timezone -from pathlib import Path -from unittest.mock import MagicMock, Mock, patch +from datetime import datetime +from unittest.mock import Mock, patch import datapi import pytest @@ -14,10 +11,6 @@ from openhexa.toolbox.era5.cds import ( CDS, DataRequest, - bounds_from_file, - date_range, - get_period_chunk, - iter_chunks, )
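The new test module covers the client plumbing (authentication, `latest`, remote lookups) but not the request builder itself. A test along the following lines could complement it by exercising the validation and normalization performed by `build_request`; this is only a sketch that follows the conventions of `tests/era5/test_cds.py` and assumes `2m_temperature` is listed in `variables.json`.

```python
import pytest

from openhexa.toolbox.era5.cds import DataRequest, build_request


def test_build_request_normalizes_payload():
    """Integer days and hours are converted to the zero-padded strings the CDS API expects."""
    request = build_request(
        variable="2m_temperature",
        year=2024,
        month=4,
        day=[1, 2, 3],
        time=[1, 6, 18],
        area=[16, -6, 9, 3],
    )
    assert isinstance(request, DataRequest)
    assert request.variable == ["2m_temperature"]
    assert request.year == "2024"
    assert request.month == "04"
    assert request.day == ["01", "02", "03"]
    assert request.time == ["01:00", "06:00", "18:00"]
    assert request.area == [16, -6, 9, 3]


def test_build_request_rejects_invalid_parameters():
    """Unsupported variables and malformed areas raise ValueError."""
    with pytest.raises(ValueError):
        build_request(variable="not_a_variable", year=2024, month=4)
    with pytest.raises(ValueError):
        # north < south is not a valid bounding box
        build_request(variable="2m_temperature", year=2024, month=4, area=[9, -6, 16, 3])
```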