diff --git a/.vscode/settings.json b/.vscode/settings.json index ef7bb4b0..63974d0e 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -11,6 +11,7 @@ "python.linting.flake8Enabled": true, "python.linting.flake8Args": ["--config=setup.cfg"], "python.linting.mypyEnabled": true, + "python.linting.mypyArgs": ["--config=setup.cfg"], "python.pythonPath": "/opt/miniconda3/envs/xcdat_dev/bin/python", "python.testing.unittestEnabled": false, "python.testing.nosetestsEnabled": false, diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index db3772dc..64005df5 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -183,14 +183,14 @@ Local Development plugins: anyio-2.2.0, cov-2.11.1 collected 3 items - tests/test_utils.py .. + tests/test_dataset.py .. tests/test_xcdat.py . ---------- coverage: platform darwin, python 3.8.8-final-0 ----------- Name Stmts Miss Cover --------------------------------------- xcdat/__init__.py 3 0 100% - xcdat/utils.py 18 0 100% + xcdat/dataset.py 18 0 100% xcdat/xcdat.py 0 0 100% --------------------------------------- TOTAL 21 0 100% diff --git a/conda-env/dev.yml b/conda-env/dev.yml index 7928e8ed..0ed4f1b8 100644 --- a/conda-env/dev.yml +++ b/conda-env/dev.yml @@ -9,11 +9,10 @@ dependencies: - python=3.8.8 - pip=21.0.1 - typing_extensions=3.7.4 # Required to make use of Python >=3.8 backported types - - cartopy=0.18.0 - - matplotlib=3.3.4 - netcdf4=1.5.6 - xarray=0.17.0 - cf_xarray=0.6.0 + - dask=2021.7.0 # Additional # ================== - bump2version==1.0.1 diff --git a/conda-env/readthedocs.yml b/conda-env/readthedocs.yml index 9a0f0e7b..fd15c285 100644 --- a/conda-env/readthedocs.yml +++ b/conda-env/readthedocs.yml @@ -9,10 +9,10 @@ dependencies: - python=3.8.8 - pip=21.0.1 - typing_extensions=3.7.4 # Required to make use of Python >=3.8 backported types - - cartopy=0.18.0 - - matplotlib=3.3.4 - netcdf4=1.5.6 - xarray=0.17.0 + - dask=2021.7.0 + - cf_xarray=0.6.0 # Documentation # ================== - sphinx=3.5.1 diff --git a/conda-recipe/meta.yaml b/conda-recipe/meta.yaml index 2a66ede7..205dd746 100644 --- a/conda-recipe/meta.yaml +++ b/conda-recipe/meta.yaml @@ -21,19 +21,18 @@ requirements: run: - python - typing_extensions - - cartopy - - matplotlib - netcdf4 - xarray - - typing_extensions + - dask - cf_xarray test: imports: - xcdat - - xcdat.logs - - xcdat.coord - - xcdat.utils + - xcdat.bounds + - xcdat.dataset + - xcdat.logger + - xcdat.variable about: home: https://github.com/tomvothecoder/xcdat diff --git a/docs/api.rst b/docs/api.rst index 542e9adc..b9dd123a 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -4,6 +4,7 @@ API Reference .. autosummary:: :toctree: generated/ - xcdat.coord - xcdat.log - xcdat.utils + xcdat.bounds + xcdat.dataset + xcdat.logger + xcdat.variable diff --git a/setup.py b/setup.py index f490114b..0e33786d 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ history = history_file.read() # https://packaging.python.org/discussions/install-requires-vs-requirements/#install-requires -requirements: List[str] = [] +requirements: List[str] = ["xarray", "pandas", "numpy"] setup_requirements = [ "pytest-runner", diff --git a/tests/fixtures.py b/tests/fixtures.py new file mode 100644 index 00000000..c93d59cf --- /dev/null +++ b/tests/fixtures.py @@ -0,0 +1,188 @@ +"""This module stores reusable test fixtures.""" +from datetime import datetime + +import numpy as np +import xarray as xr + +# If the fixture is an xarray object, make sure to use .copy() to create a +# shallow copy of the object. 
+# Otherwise, you might run into unintentional
+# side-effects caused by reference assignment.
+# https://xarray.pydata.org/en/stable/generated/xarray.DataArray.copy.html
+
+# Dataset coordinates
+time = xr.DataArray(
+    data=[
+        datetime(2000, 1, 1),
+        datetime(2000, 2, 1),
+        datetime(2000, 3, 1),
+        datetime(2000, 4, 1),
+        datetime(2000, 5, 1),
+        datetime(2000, 6, 1),
+        datetime(2000, 7, 1),
+        datetime(2000, 8, 1),
+        datetime(2000, 9, 1),
+        datetime(2000, 10, 1),
+        datetime(2000, 11, 1),
+        datetime(2000, 12, 1),
+    ],
+    dims=["time"],
+)
+time_non_cf_compliant = xr.DataArray(
+    data=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
+    dims=["time"],
+    attrs={"units": "months since 2000-01-01"},
+)
+
+lat = xr.DataArray(
+    data=np.array([-90, -88.75, 88.75, 90]),
+    dims=["lat"],
+    attrs={"units": "degrees_north", "axis": "Y"},
+)
+lon = xr.DataArray(
+    data=np.array([0, 1.875, 356.25, 358.125]),
+    dims=["lon"],
+    attrs={"units": "degrees_east", "axis": "X"},
+)
+
+# Dataset data variables (bounds)
+time_bnds = xr.DataArray(
+    name="time_bnds",
+    data=[
+        [datetime(1999, 12, 16, 12), datetime(2000, 1, 16, 12)],
+        [datetime(2000, 1, 16, 12), datetime(2000, 2, 15, 12)],
+        [datetime(2000, 2, 15, 12), datetime(2000, 3, 16, 12)],
+        [datetime(2000, 3, 16, 12), datetime(2000, 4, 16)],
+        [datetime(2000, 4, 16), datetime(2000, 5, 16, 12)],
+        [datetime(2000, 5, 16, 12), datetime(2000, 6, 16)],
+        [datetime(2000, 6, 16), datetime(2000, 7, 16, 12)],
+        [datetime(2000, 7, 16, 12), datetime(2000, 8, 16, 12)],
+        [datetime(2000, 8, 16, 12), datetime(2000, 9, 16)],
+        [datetime(2000, 9, 16), datetime(2000, 10, 16, 12)],
+        [datetime(2000, 10, 16, 12), datetime(2000, 11, 16)],
+        [datetime(2000, 11, 16), datetime(2000, 12, 16)],
+    ],
+    coords={"time": time},
+    dims=["time", "bnds"],
+    attrs={"is_generated": "True"},
+)
+
+time_bnds_non_cf_compliant = xr.DataArray(
+    name="time_bnds",
+    data=[
+        [datetime(1999, 12, 16, 12), datetime(2000, 1, 16, 12)],
+        [datetime(2000, 1, 16, 12), datetime(2000, 2, 15, 12)],
+        [datetime(2000, 2, 15, 12), datetime(2000, 3, 16, 12)],
+        [datetime(2000, 3, 16, 12), datetime(2000, 4, 16)],
+        [datetime(2000, 4, 16), datetime(2000, 5, 16, 12)],
+        [datetime(2000, 5, 16, 12), datetime(2000, 6, 16)],
+        [datetime(2000, 6, 16), datetime(2000, 7, 16, 12)],
+        [datetime(2000, 7, 16, 12), datetime(2000, 8, 16, 12)],
+        [datetime(2000, 8, 16, 12), datetime(2000, 9, 16)],
+        [datetime(2000, 9, 16), datetime(2000, 10, 16, 12)],
+        [datetime(2000, 10, 16, 12), datetime(2000, 11, 16)],
+        [datetime(2000, 11, 16), datetime(2000, 12, 16)],
+    ],
+    coords={"time": time.data},
+    dims=["time", "bnds"],
+    attrs={"is_generated": "True"},
+)
+lat_bnds = xr.DataArray(
+    name="lat_bnds",
+    data=np.array([[-90, -89.375], [-89.375, 0.0], [0.0, 89.375], [89.375, 90]]),
+    coords={"lat": lat.data},
+    dims=["lat", "bnds"],
+    attrs={"units": "degrees_north", "axis": "Y", "is_generated": "True"},
+)
+lon_bnds = xr.DataArray(
+    name="lon_bnds",
+    data=np.array(
+        [
+            [-0.9375, 0.9375],
+            [0.9375, 179.0625],
+            [179.0625, 357.1875],
+            [357.1875, 359.0625],
+        ]
+    ),
+    coords={"lon": lon.data},
+    dims=["lon", "bnds"],
+    attrs={"units": "degrees_east", "axis": "X", "is_generated": "True"},
+)
+
+# Dataset data variables (variables)
+ts = xr.DataArray(
+    name="ts",
+    data=np.ones((12, 4, 4)),
+    coords={"time": time, "lat": lat, "lon": lon},
+    dims=["time", "lat", "lon"],
+)
+ts_non_cf_compliant = xr.DataArray(
+    name="ts",
+    data=np.ones((12, 4, 4)),
+    coords={"time": time_non_cf_compliant, "lat": lat, "lon": lon},
+    dims=["time", "lat", "lon"],
+) + +ts_with_bnds = xr.DataArray( + name="ts", + data=np.ones((2, 12, 4, 4)), + coords={ + "bnds": np.array([0, 1]), + "time": time.assign_attrs(bounds="time_bnds"), + "lat": lat.assign_attrs(bounds="lat_bnds"), + "lon": lon.assign_attrs(bounds="lon_bnds"), + "lat_bnds": lat_bnds, + "lon_bnds": lon_bnds, + "time_bnds": time_bnds, + }, + dims=[ + "bnds", + "time", + "lat", + "lon", + ], +) + + +def generate_dataset(cf_compliant=True, has_bounds: bool = True) -> xr.Dataset: + """Generates a dataset using coordinate and data variable fixtures. + + NOTE: Using ``.assign()`` to add data variables to an existing dataset will + remove attributes from existing coordinates. The workaround is to update a + data_vars dict then create the dataset. https://github.com/pydata/xarray/issues/2245 + + Parameters + ---------- + cf_compliant : bool, optional + CF compliant time units, by default True + has_bounds : bool, optional + Include bounds for coordinates, by default True + + Returns + ------- + xr.Dataset + Test dataset. + """ + data_vars = {} + coords = { + "lat": lat.copy(), + "lon": lon.copy(), + } + + if cf_compliant: + coords.update({"time": time.copy()}) + data_vars.update({"ts": ts.copy()}) + else: + coords.update({"time": time_non_cf_compliant.copy()}) + data_vars.update({"ts": ts_non_cf_compliant.copy()}) + + if has_bounds: + data_vars.update( + { + "time_bnds": time_bnds.copy(), + "lat_bnds": lat_bnds.copy(), + "lon_bnds": lon_bnds.copy(), + } + ) + + ds = xr.Dataset(data_vars=data_vars, coords=coords) + return ds diff --git a/tests/test_coord.py b/tests/test_bounds.py similarity index 59% rename from tests/test_coord.py rename to tests/test_bounds.py index 25c04d39..2b92fa5a 100644 --- a/tests/test_coord.py +++ b/tests/test_bounds.py @@ -2,10 +2,11 @@ import pytest import xarray as xr -from xcdat.coord import CoordAccessor +from tests.fixtures import generate_dataset, ts_with_bnds +from xcdat.bounds import DataArrayBoundsAccessor, DatasetBoundsAccessor -class TestCoordAccessor: +class TestDatasetBoundsAccessor: @pytest.fixture(autouse=True) def setup(self): # Coordinate information @@ -64,61 +65,60 @@ def setup(self): self.ds = xr.Dataset(coords={"lat": lat, "lon": lon}) def test__init__(self): - obj = CoordAccessor(self.ds) + obj = DatasetBoundsAccessor(self.ds) assert obj._dataset.identical(self.ds) def test_decorator_call(self): - assert self.ds.coord._dataset.identical(self.ds) + assert self.ds.bounds._dataset.identical(self.ds) def test_get_bounds_when_bounds_exist_in_dataset(self): - obj = CoordAccessor(self.ds) - obj._dataset = obj._dataset.assign( + ds = self.ds.copy() + ds = ds.assign( lat_bnds=self.lat_bnds, lon_bnds=self.lon_bnds, ) - lat_bnds = obj.get_bounds("lat") + lat_bnds = ds.bounds.get_bounds("lat") assert lat_bnds is not None and lat_bnds.identical(self.lat_bnds) assert lat_bnds.is_generated - lon_bnds = obj.get_bounds("lon") + lon_bnds = ds.bounds.get_bounds("lon") assert lon_bnds is not None and lon_bnds.identical(self.lon_bnds) assert lon_bnds.is_generated def test_get_bounds_when_bounds_do_not_exist_in_dataset(self): - # Check bounds generated if bounds do not exist. 
- obj = CoordAccessor(self.ds) + ds = self.ds.copy() - lat_bnds = obj.get_bounds("lat") + lat_bnds = ds.bounds.get_bounds("lat") assert lat_bnds is not None assert lat_bnds.identical(self.lat_bnds) assert lat_bnds.is_generated - lon_bnds = obj.get_bounds("lon") + lon_bnds = ds.bounds.get_bounds("lon") assert lon_bnds is not None assert lon_bnds.identical(self.lon_bnds) assert lon_bnds.is_generated # Check raises error when bounds do not exist and not allowing generated bounds. with pytest.raises(ValueError): - obj._dataset = obj._dataset.drop_vars(["lat_bnds"]) - obj.get_bounds("lat", allow_generating=False) + ds = ds.drop_vars(["lat_bnds"]) + ds.bounds.get_bounds("lat", allow_generating=False) def test_get_bounds_raises_error_with_incorrect_axis_argument(self): - obj = CoordAccessor(self.ds) + ds = self.ds.copy() with pytest.raises(ValueError): - obj.get_bounds("incorrect_axis_argument") # type: ignore + ds.bounds.get_bounds("incorrect_axis_argument") def test__get_bounds_does_not_drop_attrs_of_existing_coords_when_generating_bounds( self, ): ds = self.ds.copy() - lat_bnds = ds.coord.get_bounds("lat", allow_generating=True) + lat_bnds = ds.bounds.get_bounds("lat", allow_generating=True) assert lat_bnds.identical(self.lat_bnds) - ds = ds.drop("lat_bnds") + ds = ds.drop_vars("lat_bnds") assert ds.identical(self.ds) def test__generate_bounds_raises_errors_for_data_dim_and_length(self): @@ -135,40 +135,83 @@ def test__generate_bounds_raises_errors_for_data_dim_and_length(self): attrs={"units": "degrees_east", "axis": "X"}, ) ds = xr.Dataset(coords={"lat": lat, "lon": lon}) - obj = CoordAccessor(ds) # If coords dimensions does not equal 1. with pytest.raises(ValueError): - obj._generate_bounds("lat") + ds.bounds._generate_bounds("lat") # If coords are length of <=1. 
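+        # (An inferred note: a coordinate needs at least two points to
+        # generate bounds, since bound widths are derived from the midpoints
+        # between neighboring points.)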
        with pytest.raises(ValueError):
-            obj._generate_bounds("lon")
+            ds.bounds._generate_bounds("lon")
 
     def test__generate_bounds_returns_bounds(self):
-        obj = CoordAccessor(self.ds)
+        ds = self.ds.copy()
 
-        lat_bnds = obj._generate_bounds("lat")
+        lat_bnds = ds.bounds._generate_bounds("lat")
         assert lat_bnds.equals(self.lat_bnds)
-        assert obj._dataset.lat_bnds.is_generated
+        assert ds.lat_bnds.is_generated
 
-        lon_bnds = obj._generate_bounds("lon")
+        lon_bnds = ds.bounds._generate_bounds("lon")
         assert lon_bnds.equals(self.lon_bnds)
-        assert obj._dataset.lon_bnds.is_generated
+        assert ds.lon_bnds.is_generated
 
     def test__get_coord(self):
-        obj = CoordAccessor(self.ds)
+        ds = self.ds.copy()
 
         # Check lat axis coordinates exist
-        lat = obj._get_coord("lat")
+        lat = ds.bounds._get_coord("lat")
         assert lat is not None
 
         # Check lon axis coordinates exist
-        lon = obj._get_coord("lon")
+        lon = ds.bounds._get_coord("lon")
         assert lon is not None
 
     def test__get_coord_raises_error_if_coord_does_not_exist(self):
-        obj = CoordAccessor(self.ds)
+        ds = self.ds.copy()
 
         with pytest.raises(KeyError):
-            obj._dataset = obj._dataset.drop_vars("lat")
-            obj._get_coord("lat")
+            ds = ds.drop_vars("lat")
+            ds.bounds._get_coord("lat")
+
+
+class TestDataArrayBoundsAccessor:
+    @pytest.fixture(autouse=True)
+    def setUp(self):
+        self.ds = generate_dataset(has_bounds=True)
+        self.ds.lat.attrs["bounds"] = "lat_bnds"
+        self.ds.lon.attrs["bounds"] = "lon_bnds"
+        self.ds.time.attrs["bounds"] = "time_bnds"
+
+    def test__init__(self):
+        obj = DataArrayBoundsAccessor(self.ds.ts)
+        assert obj._dataarray.identical(self.ds.ts)
+
+    def test_decorator_call(self):
+        expected = self.ds.ts
+        result = self.ds["ts"].bounds._dataarray
+        assert result.identical(expected)
+
+    def test__copy_from_parent_copies_bounds(self):
+        ts_expected = ts_with_bnds.copy()
+
+        ts = self.ds["ts"].copy()
+        ts_result = ts.bounds.copy_from_parent(self.ds)
+
+        assert ts_result.identical(ts_expected)
+
+    def test__set_bounds_dim_adds_bnds(self):
+        ts = self.ds["ts"].copy()
+        ts_expected = ts.copy()
+        ts_expected = ts_expected.expand_dims(bnds=np.array([0, 1]))
+        ts_result = ts.bounds._set_bounds_dim(self.ds)
+
+        assert ts_result.identical(ts_expected)
+
+    def test__set_bounds_dim_adds_bounds(self):
+        ds = self.ds.swap_dims({"bnds": "bounds"}).copy()
+        ts = ds["ts"].copy()
+
+        ts_expected = ts.copy()
+        ts_expected = ts_expected.expand_dims(bounds=np.array([0, 1]))
+
+        ts_result = ts.bounds._set_bounds_dim(ds)
+        assert ts_result.identical(ts_expected)
diff --git a/tests/test_dataset.py b/tests/test_dataset.py
new file mode 100644
index 00000000..96d32cfb
--- /dev/null
+++ b/tests/test_dataset.py
@@ -0,0 +1,286 @@
+import numpy as np
+import pytest
+import xarray as xr
+
+from tests.fixtures import generate_dataset
+from xcdat.dataset import decode_time_units, open_dataset, open_mfdataset
+
+
+class TestOpenDataset:
+    @pytest.fixture(autouse=True)
+    def setUp(self, tmp_path):
+        # Create temporary directory to save files.
+        self.dir = tmp_path / "input_data"
+        self.dir.mkdir()
+
+        # Path to the dummy dataset.
+        self.file_path = f"{self.dir}/file.nc"
+
+    def test_non_cf_compliant_time_is_decoded(self):
+        # Generate a dummy dataset with non-CF compliant time units.
+        ds = generate_dataset(cf_compliant=False, has_bounds=False)
+        ds.to_netcdf(self.file_path)
+
+        # Generate an expected dataset with decoded time units and coordinate
+        # bounds.
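+        # Per the fixtures, the integer time values [0, 1, ..., 11] with units
+        # "months since 2000-01-01" should decode to monthly datetime64 values
+        # starting at 2000-01-01.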
+ expected_ds = generate_dataset(cf_compliant=True, has_bounds=True) + expected_ds.time.attrs["units"] = "months since 2000-01-01" + expected_ds.time_bnds.attrs["units"] = "months since 2000-01-01" + expected_ds.time.encoding = { + "source": None, + "dtype": np.dtype(np.int64), + "original_shape": expected_ds.time.data.shape, + "units": "months since 2000-01-01", + "calendar": "proleptic_gregorian", + } + + # Check that non-cf compliant time was decoded and bounds were generated. + result_ds = open_dataset(self.file_path) + assert result_ds.identical(expected_ds) + + def test_preserves_lat_and_lon_bounds_if_they_exist(self): + # Create expected dataset which includes bounds. + expected_ds = generate_dataset(has_bounds=True) + expected_ds.to_netcdf(self.file_path) + + # Need to add time units since decode time units is also called, + # which adds this attribute. + expected_ds.time.attrs["units"] = "days since 2000-01-01 00:00:00" + expected_ds.time_bnds.attrs["units"] = "days since 2000-01-01 00:00:00" + + # Check resulting dataset and expected are identical + result_ds = open_dataset(self.file_path) + assert result_ds.identical(expected_ds) + + def test_generates_lat_and_lon_bounds_if_they_dont_exist(self): + # Create expected dataset without bounds. + ds = generate_dataset(has_bounds=False) + ds.to_netcdf(self.file_path) + + # Make sure bounds don't exist + data_vars = list(ds.data_vars.keys()) + assert "lat_bnds" not in data_vars + assert "lon_bnds" not in data_vars + + # Check bounds were generated. + result = open_dataset(self.file_path) + result_data_vars = list(result.data_vars.keys()) + assert "lat_bnds" in result_data_vars + assert "lon_bnds" in result_data_vars + + +class TestOpenMfDataset: + @pytest.fixture(autouse=True) + def setUp(self, tmp_path): + # Create temporary directory to save files. + self.dir = tmp_path / "input_data" + self.dir.mkdir() + + # Paths to the dummy datasets. + self.file_path1 = f"{self.dir}/file1.nc" + self.file_path2 = f"{self.dir}/file2.nc" + + def test_non_cf_compliant_time_is_decoded(self): + # Generate two dummy datasets with non-CF compliant time units. + ds1 = generate_dataset(cf_compliant=False, has_bounds=False) + ds1.to_netcdf(self.file_path1) + + ds2 = generate_dataset(cf_compliant=False, has_bounds=False) + ds2 = ds2.rename_vars({"ts": "tas"}) + ds2.to_netcdf(self.file_path2) + + # Generate an expected dataset, which is a combination of both datasets + # with decoded time units and coordinate bounds. + expected_ds = generate_dataset(cf_compliant=True, has_bounds=True) + expected_ds["tas"] = expected_ds["ts"].copy() + expected_ds.time.attrs["units"] = "months since 2000-01-01" + expected_ds.time_bnds.attrs["units"] = "months since 2000-01-01" + expected_ds.time.encoding = { + "source": None, + "dtype": np.dtype(np.int64), + "original_shape": expected_ds.time.data.shape, + "units": "months since 2000-01-01", + "calendar": "proleptic_gregorian", + } + + # Check that non-cf compliant time was decoded and bounds were generated. + result_ds = open_mfdataset([self.file_path1, self.file_path2]) + assert result_ds.identical(expected_ds) + + def test_preserves_lat_and_lon_bounds_if_they_exist(self): + # Generate two dummy datasets. + ds1 = generate_dataset(cf_compliant=True, has_bounds=True) + ds1.to_netcdf(self.file_path1) + + ds2 = generate_dataset(cf_compliant=True, has_bounds=True) + ds2 = ds2.rename_vars({"ts": "tas"}) + ds2.to_netcdf(self.file_path2) + + # Generate expected dataset, which is a combination of the two datasets. 
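+        # Both files share identical coordinates, so xarray merges them into a
+        # single dataset holding "ts" and "tas" as separate data variables.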
+        expected_ds = generate_dataset(has_bounds=True)
+        expected_ds["tas"] = expected_ds["ts"].copy()
+        expected_ds.time.attrs["units"] = "days since 2000-01-01 00:00:00"
+        expected_ds.time_bnds.attrs["units"] = "days since 2000-01-01 00:00:00"
+
+        # Check that the result is identical to the expected.
+        result_ds = open_mfdataset([self.file_path1, self.file_path2])
+        assert result_ds.identical(expected_ds)
+
+    def test_generates_lat_and_lon_bounds_if_they_dont_exist(self):
+        # Generate two dummy datasets.
+        ds1 = generate_dataset(cf_compliant=True, has_bounds=False)
+        ds1.to_netcdf(self.file_path1)
+
+        ds2 = generate_dataset(cf_compliant=True, has_bounds=False)
+        ds2 = ds2.rename_vars({"ts": "tas"})
+        ds2.to_netcdf(self.file_path2)
+
+        # Make sure no bounds exist in the input files.
+        data_vars1 = list(ds1.data_vars.keys())
+        data_vars2 = list(ds2.data_vars.keys())
+        assert "lat_bnds" not in data_vars1 + data_vars2
+        assert "lon_bnds" not in data_vars1 + data_vars2
+
+        # Check that bounds were generated.
+        result = open_mfdataset([self.file_path1, self.file_path2])
+        result_data_vars = list(result.data_vars.keys())
+        assert "lat_bnds" in result_data_vars
+        assert "lon_bnds" in result_data_vars
+
+
+class TestDecodeTimeUnits:
+    @pytest.fixture(autouse=True)
+    def setup(self):
+        # Common attributes for the time coordinate. Units are overridden based
+        # on the unit that needs to be tested (days (CF compliant) or months
+        # (non-CF compliant)).
+        self.time_attrs = {
+            "units": None,
+            "bounds": "time_bnds",
+            "axis": "T",
+            "long_name": "time",
+            "standard_name": "time",
+        }
+
+    def test_throws_error_if_function_is_called_on_already_decoded_cf_compliant_dataset(
+        self,
+    ):
+        ds = generate_dataset(cf_compliant=True)
+
+        with pytest.raises(KeyError):
+            decode_time_units(ds)
+
+    def test_decodes_cf_compliant_time_units(self):
+        # Create a dummy dataset with CF compliant time units.
+        time_attrs = self.time_attrs
+        time_attrs.update({"units": "days since 2000-01-01"})
+        time_coord = xr.DataArray(
+            name="time", data=[0, 1, 2], dims=["time"], attrs=time_attrs
+        )
+        input_ds = xr.Dataset({"time": time_coord})
+
+        # Create an expected dataset with properly decoded time units.
+        expected_ds = xr.Dataset(
+            {
+                "time": xr.DataArray(
+                    name="time",
+                    data=[
+                        np.datetime64("2000-01-01"),
+                        np.datetime64("2000-01-02"),
+                        np.datetime64("2000-01-03"),
+                    ],
+                    dims=["time"],
+                    attrs=time_attrs,
+                )
+            }
+        )
+
+        # Check the resulting dataset is identical to the expected.
+        result_ds = decode_time_units(input_ds)
+        assert result_ds.identical(expected_ds)
+
+        # Check the encodings are the same.
+        expected_ds.time.encoding = {
+            # Default entries when `decode_times=True`
+            "dtype": np.dtype(np.int64),
+            "units": time_attrs["units"],
+        }
+        assert result_ds.time.encoding == expected_ds.time.encoding
+
+    def test_decodes_non_cf_compliant_time_units_months(self):
+        # Create a dummy dataset with non-CF compliant time units.
+        time_attrs = self.time_attrs
+        time_attrs.update({"units": "months since 2000-01-01"})
+        time_coord = xr.DataArray(
+            name="time", data=[0, 1, 2], dims=["time"], attrs=time_attrs
+        )
+        input_ds = xr.Dataset({"time": time_coord})
+
+        # Create an expected dataset with properly decoded time units.
+        expected_ds = xr.Dataset(
+            {
+                "time": xr.DataArray(
+                    name="time",
+                    data=[
+                        np.datetime64("2000-01-01"),
+                        np.datetime64("2000-02-01"),
+                        np.datetime64("2000-03-01"),
+                    ],
+                    dims=["time"],
+                    attrs=time_attrs,
+                )
+            }
+        )
+
+        # Check the resulting dataset is identical to the expected.
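+        # "months" is not a CF compliant unit, so decode_time_units should
+        # fall back to pandas date_range (freq="MS") rather than
+        # xarray.decode_cf.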
+ result_ds = decode_time_units(input_ds) + assert result_ds.identical(expected_ds) + + # Check result and expected time coordinate encodings are the same. + expected_ds.time.encoding = { + "source": None, + "dtype": np.dtype(np.int64), + "original_shape": expected_ds.time.data.shape, + "units": time_attrs["units"], + "calendar": "proleptic_gregorian", + } + assert result_ds.time.encoding == expected_ds.time.encoding + + def test_decodes_non_cf_compliant_time_units_years(self): + # Create a dummy dataset with non-CF compliant time units. + time_attrs = self.time_attrs + time_attrs.update({"units": "years since 2000-01-01"}) + time_coord = xr.DataArray( + name="time", data=[0, 1, 2], dims=["time"], attrs=time_attrs + ) + input_ds = xr.Dataset({"time": time_coord}) + + # Create an expected dataset with properly decoded time units. + expected_ds = xr.Dataset( + { + "time": xr.DataArray( + name="time", + data=[ + np.datetime64("2000-01-01"), + np.datetime64("2001-01-01"), + np.datetime64("2002-01-01"), + ], + dims=["time"], + attrs=time_attrs, + ) + } + ) + + # Check the resulting dataset is identical to the expected. + result_ds = decode_time_units(input_ds) + assert result_ds.identical(expected_ds) + + # Check result and expected time coordinate encodings are the same. + expected_ds.time.encoding = { + "source": None, + "dtype": np.dtype(np.int64), + "original_shape": expected_ds.time.data.shape, + "units": time_attrs["units"], + "calendar": "proleptic_gregorian", + } + assert result_ds.time.encoding == expected_ds.time.encoding diff --git a/tests/test_utils.py b/tests/test_utils.py deleted file mode 100644 index 1f3b5227..00000000 --- a/tests/test_utils.py +++ /dev/null @@ -1,35 +0,0 @@ -import numpy as np -import pytest -import xarray as xr - -from xcdat.utils import open_datasets - - -class TestOpenDatasets: - @pytest.fixture(autouse=True) - def setUp(self, tmp_path): - # Create temporary directory - self.dir = tmp_path / "input_data" - self.dir.mkdir() - - # Create dummy dataset - self.ds = xr.Dataset( - {"longitude": np.linspace(0, 10), "latitude": np.linspace(0, 20)} - ) - self.ds.to_netcdf(f"{self.dir}/file.nc") - - def test_returns_files_with_specific_extension(self): - # Compare expected and result - expected = {"file.nc": self.ds} - result = open_datasets(self.dir, extension="nc") - - for filename, dataset in result.items(): - assert dataset.equals(expected[filename]) - - def test_returns_all_files(self): - # Compare expected and result - expected = {"file.nc": self.ds} - result = open_datasets(self.dir) - - for filename, dataset in result.items(): - assert dataset.equals(expected[filename]) diff --git a/tests/test_variable.py b/tests/test_variable.py new file mode 100644 index 00000000..062a9d1b --- /dev/null +++ b/tests/test_variable.py @@ -0,0 +1,38 @@ +import pytest + +from tests.fixtures import generate_dataset, lat_bnds, ts_with_bnds +from xcdat.variable import open_variable + + +class TestOpenVariable: + def test_raises_error_if_variable_does_not_exist(self): + ds = generate_dataset(has_bounds=False) + + with pytest.raises(KeyError): + open_variable(ds, "invalid_var") + + def test_raises_error_if_bounds_dim_is_missing(self): + ds = generate_dataset(has_bounds=False) + + with pytest.raises(KeyError): + open_variable(ds, "ts") + + def test_raises_error_if_bounds_are_missing_for_coordinates(self): + ds = generate_dataset(has_bounds=False) + + # By adding bounds to the parent dataset, it will initiate copying bounds + # and find that bounds are missing for the other coords (lon 
+        # and time).
+        ds["lat_bnds"] = lat_bnds.copy()
+        with pytest.raises(ValueError):
+            open_variable(ds, "ts")
+
+    def test_returns_variable_with_bounds(self):
+        ds = generate_dataset(has_bounds=True)
+        ds.lat.attrs["bounds"] = "lat_bnds"
+        ds.lon.attrs["bounds"] = "lon_bnds"
+        ds.time.attrs["bounds"] = "time_bnds"
+
+        ts_expected = ts_with_bnds.copy()
+
+        ts_result = open_variable(ds, "ts")
+        assert ts_result.identical(ts_expected)
diff --git a/xcdat/coord.py b/xcdat/bounds.py
similarity index 53%
rename from xcdat/coord.py
rename to xcdat/bounds.py
index 11b97a53..47951ec7 100644
--- a/xcdat/coord.py
+++ b/xcdat/bounds.py
@@ -1,4 +1,4 @@
-"""Coordinate module that contains coordinate related functions (e.g., getting bounds)."""
+"""Bounds module for functions related to coordinate bounds."""
 from typing import Optional, Tuple, get_args
 
 import cf_xarray as cfxr  # noqa: F401
@@ -6,41 +6,55 @@
 import xarray as xr
 from typing_extensions import Literal
 
-from xcdat.log import setup_custom_logger
+from xcdat.logger import setup_custom_logger
 
 logger = setup_custom_logger("root")
 
+Coord = Literal["lat", "latitude", "lon", "longitude", "time"]
+#: Tuple of supported coordinates in xCDAT functions and methods.
+SUPPORTED_COORDS: Tuple[Coord, ...] = get_args(Coord)
 
-@xr.register_dataset_accessor("coord")
-class CoordAccessor:
-    """A class to represent the CoordAccessor xarray extension.
+
+@xr.register_dataset_accessor("bounds")
+class DatasetBoundsAccessor:
+    """A class to represent the DatasetBoundsAccessor.
 
     Examples
     ---------
     Import:
 
-    >>> from xcdat import coord
+    >>> from xcdat import bounds
 
     Get coordinate bounds if they exist, otherwise return generated bounds:
 
     >>> ds = xr.open_dataset("file_path")
-    >>> lat_bnds = ds.coord.get_bounds("lat") # or pass "latitude"
-    >>> lon_bnds = ds.coord.get_bounds("lon") # or pass "longitude"
-    >>> time_bnds = ds.coord.get_bounds("time")
+    >>> lat_bounds = ds.bounds.get_bounds("lat") # or pass "latitude"
+    >>> lon_bounds = ds.bounds.get_bounds("lon") # or pass "longitude"
+    >>> time_bounds = ds.bounds.get_bounds("time")
 
     Get coordinate bounds and don't generate if they don't exist:
 
     >>> ds = xr.open_dataset("file_path")
     >>> # Throws error if no bounds exist
-    >>> lat_bnds = ds.coord.get_bounds("lat", allow_generating=False)
+    >>> lat_bounds = ds.bounds.get_bounds("lat", allow_generating=False)
     """
 
-    Coord = Literal["lat", "latitude", "lon", "longitude", "time"]
-    COORD_ARGUMENTS: Tuple[Coord, ...] = get_args(Coord)
-
     def __init__(self, dataset: xr.Dataset):
         self._dataset: xr.Dataset = dataset
 
+    def get_bounds_for_all_coords(self, allow_generating=True) -> xr.Dataset:
+        """Gets existing bounds or generates new ones for supported coordinates.
+
+        Returns
+        -------
+        xr.Dataset
+            The dataset with bounds added for all supported coordinates.
+        """
+        for coord in [*self._dataset.coords]:
+            if coord in SUPPORTED_COORDS:
+                self.get_bounds(coord, allow_generating)
+
+        return self._dataset
+
     def get_bounds(
         self,
         coord: Coord,
@@ -75,10 +89,10 @@
         ValueError
             If ``allow_generating=False`` and no bounds were found in the dataset.
         """
-        if coord not in CoordAccessor.COORD_ARGUMENTS:
+        if coord not in SUPPORTED_COORDS:
             raise ValueError(
-                "Incorrect `coord` argument. Supported coordinates include: Supported "
-                f"arguments include: {', '.join(CoordAccessor.COORD_ARGUMENTS)}."
+                "Incorrect `coord` argument. Supported coordinates include: "
+                f"{', '.join(SUPPORTED_COORDS)}."
             )
 
         try:
@@ -109,8 +123,8 @@
         coord : Coord
            The coordinate.
        width : float, optional
-            Width of the bounds relative to the position of the nearest points,
-            by default 0.5.
+            Width of the bounds relative to the position of the nearest
+            points, by default 0.5.
 
         Returns
         -------
@@ -201,3 +215,121 @@ def _get_coord(self, coord: Coord) -> xr.DataArray:
             raise KeyError(f"No matching coordinates for coord: {coord}")
 
         return matching_coord
+
+
+@xr.register_dataarray_accessor("bounds")
+class DataArrayBoundsAccessor:
+    """A class representing the DataArrayBoundsAccessor.
+
+    Examples
+    --------
+    Import module:
+
+    >>> from xcdat import bounds
+    >>> from xcdat.dataset import open_dataset
+
+    Return an attribute (refer to ``tas.bounds.__dict__`` for the available
+    attributes):
+
+    >>> tas.bounds.<attribute>
+
+    Copy axis bounds from parent Dataset to data variable:
+
+    >>> ds = open_dataset("file_path") # Auto-generates bounds if missing
+    >>> tas = ds["tas"]
+    >>> tas.bounds.copy_from_parent(ds)
+    """
+
+    def __init__(self, dataarray: xr.DataArray):
+        self._dataarray = dataarray
+
+    def copy_from_parent(self, dataset: xr.Dataset) -> xr.DataArray:
+        """Copies coordinate bounds from the parent Dataset to the DataArray.
+
+        In an xarray.Dataset, variables (e.g., "tas") and coordinate bounds
+        (e.g., "lat_bnds") are stored in the Dataset's data variables as
+        independent DataArrays that have no link between one another [3]_. As a
+        result, this creates an issue when you need to reference coordinate
+        bounds after extracting a variable to work on it independently.
+
+        This function works around this issue by copying the coordinate bounds
+        from the parent Dataset to the DataArray variable.
+
+        Parameters
+        ----------
+        dataset : xr.Dataset
+            The parent Dataset.
+
+        Returns
+        -------
+        xr.DataArray
+            The data variable with coordinate bounds added to its coordinates.
+
+        Notes
+        -----
+
+        .. [3] https://github.com/pydata/xarray/issues/1475
+
+        """
+        da = self._set_bounds_dim(dataset)
+
+        coords = [*dataset.coords]
+        boundless_coords = []
+        for coord in coords:
+            if coord in SUPPORTED_COORDS:
+                try:
+                    bounds = dataset.cf.get_bounds(coord)
+                    da[bounds.name] = bounds.copy()
+                except KeyError:
+                    boundless_coords.append(coord)
+
+        if boundless_coords:
+            raise ValueError(
+                "The dataset is missing bounds for the following coords: "
+                f"{', '.join(boundless_coords)}. Pass the dataset to "
+                "`xcdat.dataset.open_dataset` to auto-generate missing bounds first."
+            )
+
+        self._dataarray = da
+        return self._dataarray
+
+    def _set_bounds_dim(self, dataset: xr.Dataset) -> xr.DataArray:
+        """Sets the bounds dimension in the DataArray using the parent Dataset.
+
+        The bounds dimension must be set before adding bounds to the DataArray
+        coordinates, otherwise the error below will be thrown:
+        ``ValueError: cannot add coordinates with new dimensions to a DataArray``.
+
+        Parameters
+        ----------
+        dataset : xr.Dataset
+            The parent Dataset.
+
+        Returns
+        -------
+        xr.DataArray
+            The data variable with a bounds dimension.
+
+        Raises
+        ------
+        KeyError
+            When no bounds dimension exists in the parent Dataset.
+        """
+        da = self._dataarray.copy()
+        dims = dataset.dims.keys()
+
+        if "bnds" in dims:
+            bounds_dim = "bnds"
+        elif "bounds" in dims:
+            bounds_dim = "bounds"
+        else:
+            raise KeyError(
+                "No bounds dimension in the parent dataset. This indicates that there "
+                "are probably no coordinate bounds in the dataset. Pass the "
+                "dataset to `xcdat.dataset.open_dataset` to auto-generate bounds."
+            )
+
+        da = da.expand_dims(dim={bounds_dim: np.array([0, 1])})
+
+        self._dataarray = da
+        return self._dataarray
diff --git a/xcdat/dataset.py b/xcdat/dataset.py
new file mode 100644
index 00000000..e8c2752f
--- /dev/null
+++ b/xcdat/dataset.py
@@ -0,0 +1,221 @@
+"""Dataset module for functions related to an xarray.Dataset."""
+from typing import Any, Dict, List, Union
+
+import pandas as pd
+import xarray as xr
+
+from xcdat import bounds  # noqa: F401
+
+
+def open_dataset(path: str, **kwargs: Dict[str, Any]) -> xr.Dataset:
+    """Wrapper for ``xarray.open_dataset`` that applies common operations.
+
+    Operations include:
+
+    - Decode all time units, including non-CF compliant units (months and years).
+    - Generate bounds for supported coordinates if they don't exist.
+
+    Parameters
+    ----------
+    path : str
+        Path to Dataset.
+    kwargs : Dict[str, Any]
+        Additional arguments passed on to ``xarray.open_dataset``.
+
+        - Visit the xarray docs for accepted arguments [1]_.
+        - ``decode_times`` defaults to ``False`` to allow for the manual
+          decoding of non-CF time units.
+
+    Returns
+    -------
+    xr.Dataset
+        Dataset after applying operations.
+
+    Notes
+    -----
+    ``xarray.open_dataset`` opens the file with read-only access. When you
+    modify values of a Dataset, even one linked to files on disk, only the
+    in-memory copy you are manipulating in xarray is modified: the original file
+    on disk is never touched.
+
+    References
+    ----------
+
+    .. [1] https://xarray.pydata.org/en/stable/generated/xarray.open_dataset.html
+
+    Examples
+    --------
+    Import and call the function:
+
+    >>> from xcdat.dataset import open_dataset
+    >>> ds = open_dataset("file_path")
+    """
+    ds = xr.open_dataset(path, decode_times=False, **kwargs)
+    ds = decode_time_units(ds)
+    ds = ds.bounds.get_bounds_for_all_coords()
+
+    return ds
+
+
+def open_mfdataset(paths: Union[str, List[str]], **kwargs) -> xr.Dataset:
+    """Wrapper for ``xarray.open_mfdataset`` that applies common operations.
+
+    Operations include:
+
+    - Decode all time units, including non-CF compliant units (months and years).
+    - Generate bounds for supported coordinates if they don't exist.
+
+    Parameters
+    ----------
+    paths : Union[str, List[str]]
+        Either a string glob in the form ``"path/to/my/files/*.nc"`` or an
+        explicit list of files to open. Paths can be given as strings or as
+        pathlib Paths. If concatenation along more than one dimension is desired,
+        then ``paths`` must be a nested list-of-lists (see ``combine_nested``
+        for details). (A string glob will be expanded to a 1-dimensional list.)
+
+    kwargs : Dict[str, Any]
+        Additional arguments passed on to ``xarray.open_mfdataset`` and/or
+        ``xarray.open_dataset``.
+
+        - Visit the xarray docs for accepted arguments, [2]_ and [3]_.
+        - ``decode_times`` defaults to ``False`` to allow for the manual
+          decoding of non-CF time units.
+
+    Returns
+    -------
+    xr.Dataset
+        Dataset after applying operations.
+
+    Notes
+    -----
+    ``xarray.open_mfdataset`` opens the file with read-only access. When you
+    modify values of a Dataset, even one linked to files on disk, only the
+    in-memory copy you are manipulating in xarray is modified: the original file
+    on disk is never touched.
+
+    References
+    ----------
+
+    .. [2] https://xarray.pydata.org/en/stable/generated/xarray.open_mfdataset.html
+    .. [3] https://xarray.pydata.org/en/stable/generated/xarray.open_dataset.html
+
+    Examples
+    --------
+    Import and call the function:
+
+    >>> from xcdat.dataset import open_mfdataset
+    >>> ds = open_mfdataset(["file_path1", "file_path2"])
+    """
+    ds = xr.open_mfdataset(paths, decode_times=False, **kwargs)
+    ds = decode_time_units(ds)
+    ds = ds.bounds.get_bounds_for_all_coords()
+
+    return ds
+
+
+def decode_time_units(dataset: xr.Dataset) -> xr.Dataset:
+    """Decodes both CF and non-CF compliant time units.
+
+    ``xarray`` uses the ``cftime`` module, which only supports CF compliant
+    time units [4]_. As a result, opening datasets with non-CF compliant
+    time units (months and years) will throw an error if ``decode_times=True``.
+
+    This function works around this issue by first checking if the time units
+    are CF or non-CF compliant. Datasets with CF compliant time units are passed
+    to ``xarray.decode_cf``. Datasets with non-CF compliant time units are
+    manually decoded by extracting the units and reference date, which are used
+    to generate an array of datetime values.
+
+    Parameters
+    ----------
+    dataset : xr.Dataset
+        Dataset with non-decoded CF/non-CF compliant time units.
+
+    Returns
+    -------
+    xr.Dataset
+        Dataset with decoded time units.
+
+    Notes
+    -----
+    .. [4] https://unidata.github.io/cftime/api.html#cftime.num2date
+
+    Examples
+    --------
+
+    Decode non-CF compliant time units in a Dataset:
+
+    >>> from xcdat.dataset import decode_time_units
+    >>> ds = xr.open_dataset("file_path", decode_times=False)
+    >>> ds.time
+    <xarray.DataArray 'time' (time: 3)>
+    array([0, 1, 2])
+    Coordinates:
+      * time     (time) int64 0 1 2
+    Attributes:
+        units:          years since 2000-01-01
+        bounds:         time_bnds
+        axis:           T
+        long_name:      time
+        standard_name:  time
+    >>> ds = decode_time_units(ds)
+    >>> ds.time
+    <xarray.DataArray 'time' (time: 3)>
+    array(['2000-01-01T00:00:00.000000000', '2001-01-01T00:00:00.000000000',
+           '2002-01-01T00:00:00.000000000'], dtype='datetime64[ns]')
+    Coordinates:
+      * time     (time) datetime64[ns] 2000-01-01 2001-01-01 2002-01-01
+    Attributes:
+        units:          years since 2000-01-01
+        bounds:         time_bnds
+        axis:           T
+        long_name:      time
+        standard_name:  time
+
+    View time coordinate encoding information:
+
+    >>> ds.time.encoding
+    {'source': None, 'dtype': dtype('int64'), 'original_shape': (3,), 'units':
+    'years since 2000-01-01', 'calendar': 'proleptic_gregorian'}
+    """
+    time = dataset["time"]
+    units_attr = time.attrs.get("units")
+
+    if units_attr is None:
+        raise KeyError(
+            "No 'units' attribute found for time coordinate. Make sure to open "
+            "the dataset with `decode_times=False`."
+        )
+
+    units, reference_date = units_attr.split(" since ")
+    non_cf_units_to_freq = {"months": "MS", "years": "YS"}
+
+    cf_compliant = units not in non_cf_units_to_freq.keys()
+    if cf_compliant:
+        dataset = xr.decode_cf(dataset, decode_times=True)
+        dataset.time.attrs["units"] = units_attr
+    else:
+        # NOTE: The calendar type does not matter for "months" or "years"
+        # because the number of days in a month is not a factor when
+        # generating date ranges for these units.
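+        # For example, "months since 2000-01-01" with values [0, 1, 2] maps to
+        # pd.date_range("2000-01-01", periods=3, freq="MS"), which yields
+        # 2000-01-01, 2000-02-01, and 2000-03-01.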
+ decoded_time = xr.DataArray( + data=pd.date_range( + start=reference_date, + periods=time.size, + freq=non_cf_units_to_freq[units], + ), + dims=["time"], + attrs=dataset["time"].attrs, + ) + decoded_time.encoding = { + "source": dataset.encoding.get("source"), + "dtype": time.dtype, + "original_shape": decoded_time.shape, + "units": units_attr, + # pandas.date_range() returns "proleptic_gregorian" by default + "calendar": time.attrs.get("calendar", "proleptic_gregorian"), + } + + dataset = dataset.assign_coords({"time": decoded_time}) + return dataset diff --git a/xcdat/log.py b/xcdat/logger.py similarity index 96% rename from xcdat/log.py rename to xcdat/logger.py index b29cf7b3..827e39dc 100644 --- a/xcdat/log.py +++ b/xcdat/logger.py @@ -1,4 +1,4 @@ -"""Log module for setting up a logger.""" +"""Logger module for setting up a logger.""" import logging import logging.handlers diff --git a/xcdat/utils.py b/xcdat/utils.py deleted file mode 100644 index 38723981..00000000 --- a/xcdat/utils.py +++ /dev/null @@ -1,48 +0,0 @@ -"""Utility functions that might be helpful to use.""" -import glob -import os -from typing import Dict, Tuple, get_args - -import xarray as xr -from typing_extensions import Literal - -# Add supported extensions on as-need basis -# https://xarray.pydata.org/en/stable/io.html# -SupportedExtensions = Literal["nc"] - -SUPPORTED_EXTENSIONS: Tuple[SupportedExtensions, ...] = get_args(SupportedExtensions) - - -def open_datasets( - path: str, - extension: SupportedExtensions = None, -) -> Dict[str, xr.Dataset]: - """Lazily loads datasets from a specified path. - - Parameters - ---------- - path : str - The relative or absolute path of the input files. - extension : SupportedExtensions, optional - The file extension to look for. Refer to ``SupportedExtensions``, by - default None. - - Returns - ------- - Dict[str, xr.Dataset] - A dictionary of datasets, key is file name and value is dataset object. - """ - datasets: Dict[str, xr.Dataset] = dict() - files_grabbed = [] - - if extension: - files_grabbed.extend(glob.glob(os.path.join(path, f"*.{extension}"))) - else: - for extension in SUPPORTED_EXTENSIONS: - files_grabbed.extend(glob.glob(os.path.join(path, f"*.{extension}"))) - - for file in files_grabbed: - key = file.replace(f"{path}/", "") - datasets[key] = xr.open_dataset(file) - - return datasets diff --git a/xcdat/variable.py b/xcdat/variable.py new file mode 100644 index 00000000..b2e970de --- /dev/null +++ b/xcdat/variable.py @@ -0,0 +1,65 @@ +"""Variable module for functions related to an xarray.Dataset data variable""" +import xarray as xr + +from xcdat.bounds import DataArrayBoundsAccessor # noqa: F401 + + +def open_variable(dataset: xr.Dataset, name: str) -> xr.DataArray: + """Opens a Dataset data variable and applies additional operations. + + Operations include: + + - Propagate coordinate bounds from the parent ``Dataset`` to the + ``DataArray`` data variable. + + Parameters + ---------- + dataset : xr.Dataset + The parent Dataset. + name : str + The name of the data variable to be opened. + + Returns + ------- + xr.DataArray + The data variable. + + Notes + ----- + If you are familiar with CDAT, the ``DataArray`` data variable output is + similar to a ``TransientVariable``, which stores coordinate bounds as object + attributes. 
+ + Examples + -------- + Import module: + + >>> from xcdat.dataset import open_dataset + >>> from xcdat.variable import open_variable + + Open a variable from a Dataset: + + >>> ds = open_dataset("file_path") # Auto-generate bounds if missing + >>> ts = open_variable(ds, "ts") + + List coordinate bounds: + + >>> ts.coords + * bnds (bnds) int64 0 1 + * time (time) datetime64[ns] 1850-01-16T12:00:00 ... 2005-12-16T12:00:00 + * lat (lat) float64 -90.0 -88.75 -87.5 -86.25 ... 86.25 87.5 88.75 90.0 + * lon (lon) float64 0.0 1.875 3.75 5.625 ... 352.5 354.4 356.2 358.1 + lon_bnds (lon, bnds) float64 ... + lat_bnds (lat, bnds) float64 ... + time_bnds (time, bnds) datetime64[ns] ... + + Return coordinate bounds: + + >>> ts.lon_bnds + >>> ts.lat_bnds + >>> ts.time_bnds + """ + data_var = dataset[name].copy() + data_var = data_var.bounds.copy_from_parent(dataset) + + return data_var