Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement RaggedArray.from_xarray() #51

Merged
merged 6 commits into from
Nov 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 46 additions & 29 deletions clouddrift/dataformat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from collections.abc import Callable
from typing import Tuple, Optional
from tqdm import tqdm
import warnings


class RaggedArray:
Expand Down Expand Up @@ -61,42 +62,17 @@ def from_files(

@classmethod
def from_netcdf(cls, filename: str):
"""Read a ragged arrays archive from a NetCDF file
"""Read a ragged arrays archive from a NetCDF file.

This is a thin wrapper around from_xarray().

Args:
filename (str): filename of NetCDF archive

Returns:
obj: ragged array class object
"""
coords = {}
metadata = {}
data = {}
attrs_global = {}
attrs_variables = {}

with xr.open_dataset(filename) as ds:
nb_traj = ds.dims["traj"]
nb_obs = ds.dims["obs"]

attrs_global = ds.attrs

for var in ds.coords.keys():
coords[var] = ds[var].data
attrs_variables[var] = ds[var].attrs

for var in ds.data_vars.keys():
if len(ds[var]) == nb_traj:
metadata[var] = ds[var].data
elif len(ds[var]) == nb_obs:
data[var] = ds[var].data
else:
print(
f"Error: variable '{var}' has unknown dimension size of {len(ds[var])}, which is not traj={nb_traj} or obs={nb_obs}."
)
attrs_variables[var] = ds[var].attrs

return cls(coords, metadata, data, attrs_global, attrs_variables)
return cls.from_xarray(xr.open_dataset(filename))

@classmethod
def from_parquet(cls, filename: str):
Expand Down Expand Up @@ -132,6 +108,47 @@ def from_parquet(cls, filename: str):

return cls(coords, metadata, data, attrs_global, attrs_variables)

@classmethod
def from_xarray(cls, ds: xr.Dataset, dim_traj: str = "traj", dim_obs: str = "obs"):
"""Populate a RaggedArray instance from an xarray Dataset instance.

Args:
ds (xarray.Dataset): xarray Dataset from which to load the RaggedArray
dim_traj (str, optional): Name of the trajectories dimension in the xarray Dataset
dim_obs (str, optional): Name of the observations dimension in the xarray Dataset

Returns:
res (RaggedArray): A RaggedArray instance
"""
coords = {}
metadata = {}
data = {}
attrs_global = {}
attrs_variables = {}

attrs_global = ds.attrs

for var in ds.coords.keys():
coords[var] = ds[var].data
attrs_variables[var] = ds[var].attrs

for var in ds.data_vars.keys():
if len(ds[var]) == ds.dims[dim_traj]:
metadata[var] = ds[var].data
elif len(ds[var]) == ds.dims[dim_obs]:
data[var] = ds[var].data
else:
warnings.warn(
f"""
Variable '{var}' has unknown dimension size of
{len(ds[var])}, which is not traj={ds.dims[dim_traj]} or
obs={ds.dims[dim_obs]}; skipping.
"""
)
attrs_variables[var] = ds[var].attrs

return cls(coords, metadata, data, attrs_global, attrs_variables)

@staticmethod
def number_of_observations(
rowsize_func: Callable[[int], int], indices: list
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "clouddrift"
version = "0.2.2"
version = "0.3.0"
authors = [
{ name="Shane Elipot", email="selipot@miami.edu" },
{ name="Philippe Miron", email="philippemiron@gmail.com" },
Expand Down
13 changes: 12 additions & 1 deletion tests/dataformat_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import xarray as xr
import numpy as np
from clouddrift.dataformat import RaggedArray
from clouddrift import RaggedArray
import awkward._v2 as ak

if __name__ == "__main__":
Expand Down Expand Up @@ -87,6 +87,17 @@ def tearDown(self):
os.remove("test_archive.nc")
os.remove("test_archive.parquet")

def test_from_xarray(self):
ra = RaggedArray.from_xarray(xr.open_dataset("test_archive.nc"))
self.compare_awkward_array(ra.to_awkward())

def test_from_xarray_dim_names(self):
ds = xr.open_dataset("test_archive.nc")
ra = RaggedArray.from_xarray(
ds.rename_dims({"traj": "t", "obs": "o"}), dim_traj="t", dim_obs="o"
)
self.compare_awkward_array(ra.to_awkward())

def test_length_ragged_arrays(self):
"""
Validate the size of the ragged array variables
Expand Down