Skip to content

Commit

Permalink
194 write method added to NetCDFFieldList method (#195)
Browse files Browse the repository at this point in the history
* fixed save method for netCDF reader

---------

Co-authored-by: Sandor Kertesz <Sandor.Kertesz@ecmwf.int>
  • Loading branch information
EddyCMWF and sandorkertesz authored Sep 25, 2023
1 parent bdb85a8 commit 683c80e
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 5 deletions.
18 changes: 15 additions & 3 deletions earthkit/data/readers/netcdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,9 @@ def to_xarray_multi_from_paths(cls, paths, **kwargs):
**options,
)

def write(self, *args, **kwargs):
return self.to_netcdf(*args, **kwargs)


class NetCDFFieldListInFiles(NetCDFFieldList):
pass
Expand All @@ -632,15 +635,24 @@ class NetCDFMaskFieldList(NetCDFFieldList, MaskIndex):
def __init__(self, *args, **kwargs):
MaskIndex.__init__(self, *args, **kwargs)

# TODO: Implement this, but discussion required
def to_xarray(self, *args, **kwargs):
self._not_implemented()


class NetCDFMultiFieldList(NetCDFFieldList, MultiIndex):
def __init__(self, *args, **kwargs):
MultiIndex.__init__(self, *args, **kwargs)

def to_xarray(self, **kwargs):
return NetCDFFieldList.to_xarray_multi_from_paths(
[x.path for x in self.indexes], **kwargs
)
try:
return NetCDFFieldList.to_xarray_multi_from_paths(
[x.path for x in self.indexes], **kwargs
)
except AttributeError:
# TODO: Implement this, but discussion required
# This catches Multi-MaskFieldLists which cannot be openned in xarray
self._not_implemented()


class NetCDFFieldListReader(NetCDFFieldListInOneFile, Reader):
Expand Down
65 changes: 65 additions & 0 deletions tests/netcdf/test_netcdf_output.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#!/usr/bin/env python3

# (C) Copyright 2020 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#

import os

import pytest

from earthkit.data import from_source
from earthkit.data.core.temporary import temp_file
from earthkit.data.testing import earthkit_examples_file


def test_netcdf_fieldlist_save():
ds = from_source("file", earthkit_examples_file("test.nc"))
assert len(ds) == 2

tmp = temp_file()
ds.save(tmp.path)
assert os.path.exists(tmp.path)
r_tmp = from_source("file", tmp.path)
assert len(r_tmp) == 2


def test_netcdf_fieldlist_subset_save():
ds = from_source("file", earthkit_examples_file("test.nc"))
assert len(ds) == 2
r = ds[1]

tmp = temp_file()
with pytest.raises(NotImplementedError):
r.save(tmp.path)


def test_netcdf_fieldlist_multi_subset_save():
ds1 = from_source("file", earthkit_examples_file("test.nc"))
ds2 = from_source("file", earthkit_examples_file("tuv_pl.nc"))

ds = ds1 + ds2
assert len(ds) == 20

tmp = temp_file()
ds.save(tmp.path)
assert os.path.exists(tmp.path)
r_tmp = from_source("file", tmp.path)
assert len(r_tmp) == 20


def test_netcdf_fieldlist_multi_subset_save_bad():
ds1 = from_source("file", earthkit_examples_file("test.nc"))
ds2 = from_source("file", earthkit_examples_file("tuv_pl.nc"))

ds = ds1 + ds2[1:5]
assert len(ds) == 6

tmp = temp_file()
with pytest.raises(NotImplementedError):
ds.save(tmp.path)
19 changes: 17 additions & 2 deletions tests/sources/test_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,12 @@
import pytest

from earthkit.data import from_source, settings
from earthkit.data.core.temporary import temp_directory
from earthkit.data.testing import earthkit_file, network_off
from earthkit.data.core.temporary import temp_directory, temp_file
from earthkit.data.testing import (
earthkit_file,
earthkit_remote_test_data_file,
network_off,
)


@pytest.mark.skipif( # TODO: fix
Expand Down Expand Up @@ -129,6 +133,17 @@ def test_url_part_file_source():
assert f.read() == b"GRIB7777GRIB7777"


def test_url_netcdf_source_save():
ds = from_source(
"url",
earthkit_remote_test_data_file("examples/test.nc"),
)

tmp = temp_file()
ds.save(tmp.path)
assert os.path.exists(tmp.path)


if __name__ == "__main__":
test_part_url()
# from earthkit.data.testing import main
Expand Down

0 comments on commit 683c80e

Please sign in to comment.