Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

196 attempt to detect filename when saving #218

Merged
merged 38 commits into from
Feb 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
11ef134
detect filename decorator
EddyCMWF Sep 19, 2023
c75bae2
CDS save tests added
EddyCMWF Sep 19, 2023
28f83ec
CDS save tests added
EddyCMWF Sep 19, 2023
d09f1fd
Merge branch 'develop' into 196-attempt-to-detect-filename-when-saving
EddyCMWF Oct 11, 2023
4b7efce
filename detector
EddyCMWF Oct 12, 2023
21e24eb
filename detector
EddyCMWF Oct 12, 2023
1d72662
filename detector
EddyCMWF Oct 12, 2023
0b77c2b
Merge branch 'develop' into 196-attempt-to-detect-filename-when-saving
EddyCMWF Oct 12, 2023
e247820
merge from develop
EddyCMWF Oct 17, 2023
c649518
qa
EddyCMWF Oct 17, 2023
a9225dd
qa
EddyCMWF Oct 17, 2023
1b86669
qa
EddyCMWF Oct 17, 2023
6169e11
Merge branch 'develop' into 196-attempt-to-detect-filename-when-saving
EddyCMWF Oct 17, 2023
1314550
Do not overwrite file being read
EddyCMWF Oct 17, 2023
e86a242
prevent anyone overwriting the file they are reading
EddyCMWF Oct 17, 2023
abed6af
QA
EddyCMWF Oct 17, 2023
3720f0a
better implementation
EddyCMWF Oct 17, 2023
ba4eac8
even better implementation
EddyCMWF Oct 18, 2023
85ae984
adding source_filename
EddyCMWF Nov 21, 2023
a1e8556
Merge branch 'develop' into 196-attempt-to-detect-filename-when-saving
EddyCMWF Nov 21, 2023
1d4b3d3
use source_filename
EddyCMWF Nov 21, 2023
249b59c
preserve source filename
EddyCMWF Nov 21, 2023
c191be9
preserve source filename
EddyCMWF Nov 21, 2023
58f1ffc
basename only
EddyCMWF Nov 21, 2023
e346b60
qa and warning tests
EddyCMWF Nov 22, 2023
590bf83
typo
EddyCMWF Nov 27, 2023
abb8198
Merge branch 'develop' into 196-attempt-to-detect-filename-when-saving
EddyCMWF Nov 29, 2023
d7b1ab0
Merge branch 'develop' into 196-attempt-to-detect-filename-when-saving
EddyCMWF Dec 13, 2023
37f6c85
merge from develop, with conflict resoved
EddyCMWF Feb 2, 2024
460fbde
remove debug print statement
EddyCMWF Feb 2, 2024
7405fe9
Merge branch 'main' into 196-attempt-to-detect-filename-when-saving
EddyCMWF Feb 8, 2024
beb52aa
save docstring updated
EddyCMWF Feb 8, 2024
acfc548
source_filename as a property
EddyCMWF Feb 8, 2024
21e6041
source_filename as a standard field in Source class
EddyCMWF Feb 8, 2024
05a33bd
Merge branch 'develop' into 196-attempt-to-detect-filename-when-saving
EddyCMWF Feb 8, 2024
f71b895
removing sneaky files
EddyCMWF Feb 8, 2024
5aae0f8
Review comments responded to, test now using the correct filename
EddyCMWF Feb 12, 2024
015e0c2
remove debug prints
EddyCMWF Feb 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions earthkit/data/core/fieldlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from earthkit.data.core import Base
from earthkit.data.core.index import Index
from earthkit.data.decorators import cached_method
from earthkit.data.decorators import cached_method, detect_out_filename
from earthkit.data.utils.metadata import metadata_argument


Expand Down Expand Up @@ -1164,16 +1164,17 @@ def _is_shared_grid(self):
)
return False

@detect_out_filename
def save(self, filename, append=False, **kwargs):
r"""Write all the fields into a file.
sandorkertesz marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
filename: str
The target file path.
append: bool
filename: str, optional
The target file path, if not defined attempts will be made to detect the filename
append: bool, optional
When it is true append data to the target file. Otherwise
the target file be overwritten if already exists.
the target file be overwritten if already exists. Default is False
**kwargs: dict, optional
Other keyword arguments passed to :obj:`write`.
"""
Expand Down
33 changes: 33 additions & 0 deletions earthkit/data/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import os
import re
import threading
import warnings

from earthkit.data.utils import load_json_or_yaml
from earthkit.data.utils.availability import Availability
Expand All @@ -35,6 +36,38 @@ def wrapped(*args, **kwargs):
return wrapped


def detect_out_filename(func):
@functools.wraps(func)
def wrapped(self, *args, **kwargs):
# Detect filename:
if len(args) == 0:
for att in ["source_filename", "path"]:
if hasattr(self, att) and getattr(self, att) is not None:
args = [os.path.basename(getattr(self, att))]
break
else:
raise TypeError("Please provide an output filename")

# Ensure we do not overwrite file that is being read:
if (
args[0] is not None
and os.path.isfile(args[0])
and hasattr(self, "path")
and self.path is not None
and os.path.samefile(args[0], self.path)
):
warnings.warn(
UserWarning(
f"Earhtkit refusing to overwrite the file we are currently reading: {args[0]}"
)
)
return

return func(self, *args, **kwargs)

return wrapped


LOCK = threading.RLock()


Expand Down
4 changes: 3 additions & 1 deletion earthkit/data/readers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from earthkit.data.core import Base
from earthkit.data.core.settings import SETTINGS
from earthkit.data.decorators import locked
from earthkit.data.decorators import detect_out_filename, locked

LOG = logging.getLogger(__name__)

Expand All @@ -32,6 +32,7 @@ def __init__(self, source, path):

self._source = weakref.ref(source)
self.path = path
self.source_filename = self.source.source_filename

@property
def source(self):
Expand Down Expand Up @@ -60,6 +61,7 @@ def ignore(self):
def cache_file(self, *args, **kwargs):
return self.source.cache_file(*args, **kwargs)

@detect_out_filename
def save(self, path, **kwargs):
mode = "wb" if self.binary else "w"
with open(path, mode) as f:
Expand Down
2 changes: 2 additions & 0 deletions earthkit/data/sources/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ class Source(Base):
_dataset = None
_parent = None

source_filename = None

def __init__(self, **kwargs):
self._kwargs = kwargs

Expand Down
7 changes: 5 additions & 2 deletions earthkit/data/sources/cds.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,16 @@ def __init__(self, dataset, *args, **kwargs):

def _retrieve(self, dataset, request):
def retrieve(target, args):
self.client().retrieve(args[0], args[1], target)
sandorkertesz marked this conversation as resolved.
Show resolved Hide resolved
cds_result = self.client().retrieve(args[0], args[1])
self.source_filename = cds_result.location.split("/")[-1]
cds_result.download(target=target)

return self.cache_file(
return_object = self.cache_file(
retrieve,
(dataset, request),
extension=EXTENSIONS.get(request.get("format"), ".cache"),
)
return return_object

@staticmethod
@normalize("date", "date-list(%Y-%m-%d)")
Expand Down
2 changes: 2 additions & 0 deletions earthkit/data/sources/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from earthkit.data import from_source
from earthkit.data.core.caching import CACHE
from earthkit.data.decorators import detect_out_filename
from earthkit.data.readers import reader
from earthkit.data.utils.parts import check_urls_and_parts, ensure_urls_and_parts

Expand Down Expand Up @@ -125,6 +126,7 @@ def to_numpy(self, **kwargs):
def values(self):
return self._reader.values

@detect_out_filename
def save(self, path, **kwargs):
return self._reader.save(path, **kwargs)

Expand Down
66 changes: 66 additions & 0 deletions tests/sources/test_cds.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@
# nor does it submit to any jurisdiction.
#

import os

import pytest

from earthkit.data import from_source
from earthkit.data.core.temporary import temp_directory
from earthkit.data.testing import NO_CDS


Expand Down Expand Up @@ -120,6 +123,42 @@ def test_cds_grib_multi_var_date(date, expected_date):
assert s.metadata("date") == expected_date


@pytest.mark.long_test
@pytest.mark.download
@pytest.mark.skipif(NO_CDS, reason="No access to CDS")
def test_cds_grib_save():
s = from_source(
"cds",
"reanalysis-era5-single-levels",
variable=["2t", "msl"],
product_type="reanalysis",
area=[50, -50, 20, 50],
date="2012-12-12",
time="12:00",
)
with temp_directory() as tmpdir:
# Check file save to assigned filename
s.save(os.path.join(tmpdir, "test.grib"))
assert os.path.isfile(os.path.join(tmpdir, "test.grib"))

s = from_source(
"cds",
"reanalysis-era5-single-levels",
variable=["2t", "msl"],
product_type="reanalysis",
area=[50, -50, 20, 50],
date="2012-12-12",
time="12:00",
)
with temp_directory() as tmpdir:
# Check file can be saved in current dir with detected filename:
here = os.curdir
os.chdir(tmpdir)
s.save()
assert os.path.isfile(os.path.basename(s.source_filename))
os.chdir(here)


@pytest.mark.long_test
@pytest.mark.download
@pytest.mark.skipif(NO_CDS, reason="No access to CDS")
Expand Down Expand Up @@ -212,6 +251,33 @@ def test_cds_netcdf():
assert s.metadata("variable") == ["t2m", "msl"]


@pytest.mark.long_test
@pytest.mark.download
@pytest.mark.skipif(NO_CDS, reason="No access to CDS")
def test_cds_netcdf_save():
s = from_source(
"cds",
"reanalysis-era5-single-levels",
variable=["2t", "msl"],
product_type="reanalysis",
area=[50, -50, 20, 50],
date="2012-12-12",
time="12:00",
format="netcdf",
)
with temp_directory() as tmpdir:
# Check file save to assigned filename
s.save(os.path.join(tmpdir, "test.nc"))
assert os.path.isfile(os.path.join(tmpdir, "test.nc"))

# Check file can be saved in current dir with detected filename:
here = os.curdir
os.chdir(tmpdir)
s.save()
assert os.path.isfile(os.path.basename(s.path))
os.chdir(here)


@pytest.mark.long_test
@pytest.mark.download
@pytest.mark.skipif(NO_CDS, reason="No access to CDS")
Expand Down
70 changes: 70 additions & 0 deletions tests/sources/test_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
import logging
import os

import pytest

from earthkit.data import from_source
from earthkit.data.core.temporary import temp_directory
from earthkit.data.testing import earthkit_examples_file
Expand All @@ -28,11 +30,79 @@ def test_file_source_grib():
assert len(s) == 2


def test_file_source_grib_save():
s = from_source("file", earthkit_examples_file("test.grib"))
with temp_directory() as tmpdir:
# Check file save to assigned filename
s.save(os.path.join(tmpdir, "test2.grib"))
assert os.path.isfile(os.path.join(tmpdir, "test2.grib"))
# Check file can be saved in current dir with detected filename:
here = os.curdir
os.chdir(tmpdir)
s.save()
assert os.path.isfile("test.grib")
os.chdir(here)


def test_file_source_grib_no_overwrite():
_s = from_source("file", earthkit_examples_file("test.grib"))
with temp_directory() as tmpdir:
os.chdir(tmpdir)
# Save the file locally
_s.save("test.grib")
# Open the local file
s = from_source("file", "test.grib")
with pytest.warns(
UserWarning,
match="Earhtkit refusing to overwrite the file we are currently reading",
):
s.save("test.grib")
with pytest.warns(
UserWarning,
match="Earhtkit refusing to overwrite the file we are currently reading",
):
s.save()


def test_file_source_netcdf():
s = from_source("file", earthkit_examples_file("test.nc"))
assert len(s) == 2


def test_file_source_netcdf_save():
s = from_source("file", earthkit_examples_file("test.nc"))
with temp_directory() as tmpdir:
# Check file save to assigned filename
s.save(os.path.join(tmpdir, "test2.nc"))
assert os.path.isfile(os.path.join(tmpdir, "test2.nc"))
# Check file can be saved in current dir with detected filename:
here = os.curdir
os.chdir(tmpdir)
s.save()
assert os.path.isfile("test.nc")
os.chdir(here)


def test_file_source_netcdf_no_overwrite():
_s = from_source("file", earthkit_examples_file("test.nc"))
with temp_directory() as tmpdir:
os.chdir(tmpdir)
# Save the file locally
_s.save("test.nc")
# Open the local file
s = from_source("file", "test.nc")
with pytest.warns(
UserWarning,
match="Earhtkit refusing to overwrite the file we are currently reading",
):
s.save("test.nc")
with pytest.warns(
UserWarning,
match="Earhtkit refusing to overwrite the file we are currently reading",
):
s.save()


def test_file_source_odb():
s = from_source("file", earthkit_examples_file("test.odb"))
assert s.path == earthkit_examples_file("test.odb")
Expand Down
3 changes: 2 additions & 1 deletion tests/translators/test_translators.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import xarray as xr

from earthkit.data import from_source, transform, translators, wrappers
from earthkit.data.testing import earthkit_test_data_file
from earthkit.data.translators import ndarray as ndtranslator
from earthkit.data.translators import pandas as pdtranslator
from earthkit.data.translators import string as strtranslator
Expand Down Expand Up @@ -133,7 +134,7 @@ def test_gpd_dataframe_translator():

def test_transform_from_grib_file():
# transform grib-based data object
f = from_source("file", "tests/data/test_single.grib")
f = from_source("file", earthkit_test_data_file("test_single.grib"))
EddyCMWF marked this conversation as resolved.
Show resolved Hide resolved

# np.ndarray
transformed = transform(f, np.ndarray)
Expand Down
Loading
Loading