Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More flexible gdp adapter #237

Merged
merged 11 commits into from
Aug 29, 2023
5 changes: 4 additions & 1 deletion clouddrift/adapters/gdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
"""

import numpy as np
import os
import pandas as pd
import xarray as xr
import urllib.request
import os
import warnings

GDP_VERSION = "2.00"

Expand Down Expand Up @@ -161,6 +162,8 @@ def fetch_netcdf(url: str, file: str):
"""
if not os.path.isfile(file):
urllib.request.urlretrieve(url, file)
else:
warnings.warn(f"{file} already exists; skip download.")


def decode_date(t):
Expand Down
22 changes: 16 additions & 6 deletions clouddrift/adapters/gdp1h.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,10 @@


def download(
drifter_ids: list = None, n_random_id: int = None, url: str = GDP_DATA_URL
drifter_ids: list = None,
n_random_id: int = None,
url: str = GDP_DATA_URL,
tmp_path: str = GDP_TMP_PATH,
):
"""Download individual NetCDF files from the AOML server.

Expand All @@ -60,17 +63,20 @@ def download(
Randomly select n_random_id drifter IDs to download (Default: None)
url : str
URL from which to download the data (Default: GDP_DATA_URL). Alternatively, it can be GDP_DATA_URL_EXPERIMENTAL.
tmp_path : str, optional
Path to the directory where the individual NetCDF files are stored
(default varies depending on operating system; /tmp/clouddrift/gdp on Linux)

Returns
-------
out : list
List of retrieved drifters
"""

print(f"Downloading GDP hourly data to {GDP_TMP_PATH}...")
print(f"Downloading GDP hourly data from {url} to {tmp_path}...")

# Create a temporary directory if it doesn't already exist.
os.makedirs(GDP_TMP_PATH, exist_ok=True)
os.makedirs(tmp_path, exist_ok=True)

if url == GDP_DATA_URL:
pattern = "drifter_[0-9]*.nc"
Expand Down Expand Up @@ -103,7 +109,7 @@ def download(
for i in drifter_ids:
file = filename_pattern.format(id=i)
urls.append(os.path.join(url, file))
files.append(os.path.join(GDP_TMP_PATH, file))
files.append(os.path.join(tmp_path, file))

# parallel retrieving of individual netCDF files
list(
Expand Down Expand Up @@ -493,6 +499,7 @@ def to_raggedarray(
drifter_ids: Optional[list[int]] = None,
n_random_id: Optional[int] = None,
url: Optional[str] = GDP_DATA_URL,
tmp_path: Optional[str] = GDP_TMP_PATH,
) -> RaggedArray:
"""Download and process individual GDP hourly files and return a RaggedArray
instance with the data.
Expand All @@ -506,6 +513,9 @@ def to_raggedarray(
url : str, optional
URL from which to download the data (Default: GDP_DATA_URL).
Alternatively, it can be GDP_DATA_URL_EXPERIMENTAL.
tmp_path : str, optional
Path to the directory where the individual NetCDF files are stored
(default varies depending on operating system; /tmp/clouddrift/gdp on Linux)

Returns
-------
Expand Down Expand Up @@ -551,7 +561,7 @@ def to_raggedarray(
>>> arr = ra.to_awkward()
>>> arr.to_parquet("gdp1h.parquet")
"""
ids = download(drifter_ids, n_random_id, url)
ids = download(drifter_ids, n_random_id, url, tmp_path)

if url == GDP_DATA_URL:
filename_pattern = "drifter_{id}.nc"
Expand All @@ -568,5 +578,5 @@ def to_raggedarray(
name_data=GDP_DATA,
rowsize_func=gdp.rowsize,
filename_pattern=filename_pattern,
tmp_path=GDP_TMP_PATH,
tmp_path=tmp_path,
)
25 changes: 16 additions & 9 deletions clouddrift/adapters/gdp6h.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@


def download(
drifter_ids: list = None, n_random_id: int = None, url: str = GDP_DATA_URL
drifter_ids: list = None,
n_random_id: int = None,
url: str = GDP_DATA_URL,
tmp_path: str = GDP_TMP_PATH,
):
"""Download individual NetCDF files from the AOML server.

Expand All @@ -47,17 +50,20 @@ def download(
Randomly select n_random_id drifter IDs to download (Default: None)
url : str
URL from which to download the data (Default: GDP_DATA_URL). Alternatively, it can be GDP_DATA_URL_EXPERIMENTAL.
tmp_path : str, optional
Path to the directory where the individual NetCDF files are stored
(default varies depending on operating system; /tmp/clouddrift/gdp6h on Linux)

Returns
-------
out : list
List of retrieved drifters
"""

print(f"Downloading GDP 6-hourly data to {GDP_TMP_PATH}...")
print(f"Downloading GDP 6-hourly data to {tmp_path}...")

# Create a temporary directory if it doesn't already exist.
os.makedirs(GDP_TMP_PATH, exist_ok=True)
os.makedirs(tmp_path, exist_ok=True)

pattern = "drifter_[0-9]*.nc"
directory_list = [
Expand Down Expand Up @@ -95,10 +101,7 @@ def download(
executor.map(
gdp.fetch_netcdf,
drifter_urls,
[
os.path.join(GDP_TMP_PATH, os.path.basename(f))
for f in drifter_urls
],
[os.path.join(tmp_path, os.path.basename(f)) for f in drifter_urls],
),
total=len(drifter_urls),
desc="Downloading files",
Expand Down Expand Up @@ -424,6 +427,7 @@ def preprocess(index: int, **kwargs) -> xr.Dataset:
def to_raggedarray(
drifter_ids: Optional[list[int]] = None,
n_random_id: Optional[int] = None,
tmp_path: Optional[str] = GDP_TMP_PATH,
) -> RaggedArray:
"""Download and process individual GDP 6-hourly files and return a
RaggedArray instance with the data.
Expand All @@ -434,6 +438,9 @@ def to_raggedarray(
List of drifters to retrieve (Default: all)
n_random_id : list[int], optional
Randomly select n_random_id drifter NetCDF files
tmp_path : str, optional
Path to the directory where the individual NetCDF files are stored
(default varies depending on operating system; /tmp/clouddrift/gdp6h on Linux)

Returns
-------
Expand Down Expand Up @@ -473,7 +480,7 @@ def to_raggedarray(
>>> arr = ra.to_awkward()
>>> arr.to_parquet("gdp6h.parquet")
"""
ids = download(drifter_ids, n_random_id, GDP_DATA_URL)
ids = download(drifter_ids, n_random_id, GDP_DATA_URL, tmp_path)

return RaggedArray.from_files(
indices=ids,
Expand All @@ -483,5 +490,5 @@ def to_raggedarray(
name_data=GDP_DATA,
rowsize_func=gdp.rowsize,
filename_pattern="drifter_{id}.nc",
tmp_path=GDP_TMP_PATH,
tmp_path=tmp_path,
)
2 changes: 1 addition & 1 deletion clouddrift/raggedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def to_awkward(self):
ak.Array
Awkward Array containing the ragged array and its attributes
"""
index_traj = np.insert(np.cumsum(self.metadata["count"]), 0, 0)
index_traj = np.insert(np.cumsum(self.metadata["rowsize"]), 0, 0)
offset = ak.index.Index64(index_traj)

data = []
Expand Down
6 changes: 3 additions & 3 deletions tests/raggedarray_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ def setUpClass(self):
[self.drifter_id[i]],
{"long_name": f"variable ID", "units": "-"},
)
xr_data["count"] = (
xr_data["rowsize"] = (
["traj"],
[self.count[i]],
{"long_name": f"variable count", "units": "-"},
{"long_name": f"variable rowsize", "units": "-"},
)
xr_data["temp"] = (
["obs"],
Expand All @@ -71,7 +71,7 @@ def setUpClass(self):
[0, 1, 2],
lambda i: list_ds[i],
self.variables_coords,
["ID", "count"],
["ID", "rowsize"],
["temp"],
)

Expand Down