Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

override units for datetime64/timedelta64 variables to preserve integer dtype #8201

Merged
merged 22 commits into from
Sep 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
37db2b0
remove `dtype` from encoding for datetime64/timedelta64 variables to …
kmuehlbauer Sep 18, 2023
490b186
Merge branch 'main' into drop-time-dtype
kmuehlbauer Sep 19, 2023
aff4e63
adapt tests
kmuehlbauer Sep 19, 2023
f6a954c
add whats-new.rst entry
kmuehlbauer Sep 19, 2023
e8de234
Update xarray/coding/times.py
kmuehlbauer Sep 19, 2023
b19a0e2
Update doc/whats-new.rst
kmuehlbauer Sep 19, 2023
7c19267
Merge branch 'main' into drop-time-dtype
kmuehlbauer Sep 19, 2023
fc63ed9
add test per review suggestion, replace .kind-check with np.issubdtyp…
kmuehlbauer Sep 19, 2023
35b77d6
align timedelta64 check with datetime64 check
kmuehlbauer Sep 20, 2023
ebd2e9a
Merge branch 'main' into drop-time-dtype
kmuehlbauer Sep 20, 2023
6691bfa
Merge branch 'main' into drop-time-dtype
kmuehlbauer Sep 21, 2023
64e2087
override units instead of dtype
kmuehlbauer Sep 22, 2023
f3ca8bf
remove print statement
kmuehlbauer Sep 22, 2023
90d4e0c
warn in case of serialization to floating point, too
kmuehlbauer Sep 22, 2023
4340884
align if-else
kmuehlbauer Sep 22, 2023
a7ba341
Add instructions to warnings
kmuehlbauer Sep 24, 2023
9e25c3a
Merge branch 'main' into drop-time-dtype
kmuehlbauer Sep 24, 2023
f4b3153
Fix test
kmuehlbauer Sep 24, 2023
32461c0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 24, 2023
410b9de
Merge branch 'main' into drop-time-dtype
kmuehlbauer Sep 24, 2023
43664c9
use warnings.catch_warnings
kmuehlbauer Sep 24, 2023
33a3e1a
Update doc/whats-new.rst
kmuehlbauer Sep 24, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ Bug fixes
- ``.rolling_exp`` functions no longer mistakenly lose non-dimensioned coords
(:issue:`6528`, :pull:`8114`)
By `Maximilian Roos <https://github.com/max-sixty>`_.
- In the event that user-provided datetime64/timedelta64 units and integer dtype encoding parameters conflict with each other, override the units to preserve an integer dtype for most faithful serialization to disk (:issue:`1064`, :pull:`8201`).
By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.

Documentation
~~~~~~~~~~~~~
Expand Down
110 changes: 76 additions & 34 deletions xarray/coding/times.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,8 +656,22 @@ def cast_to_int_if_safe(num) -> np.ndarray:
return num


def _division(deltas, delta, floor):
if floor:
# calculate int64 floor division
# to preserve integer dtype if possible (GH 4045, GH7817).
num = deltas // delta.astype(np.int64)
num = num.astype(np.int64, copy=False)
else:
num = deltas / delta
return num


def encode_cf_datetime(
dates, units: str | None = None, calendar: str | None = None
dates,
units: str | None = None,
calendar: str | None = None,
dtype: np.dtype | None = None,
) -> tuple[np.ndarray, str, str]:
"""Given an array of datetime objects, returns the tuple `(num, units,
calendar)` suitable for a CF compliant time variable.
Expand Down Expand Up @@ -689,6 +703,12 @@ def encode_cf_datetime(
time_units, ref_date = _unpack_time_units_and_ref_date(units)
time_delta = _time_units_to_timedelta64(time_units)

# Wrap the dates in a DatetimeIndex to do the subtraction to ensure
# an OverflowError is raised if the ref_date is too far away from
# dates to be encoded (GH 2272).
dates_as_index = pd.DatetimeIndex(dates.ravel())
time_deltas = dates_as_index - ref_date

# retrieve needed units to faithfully encode to int64
needed_units, data_ref_date = _unpack_time_units_and_ref_date(data_units)
if data_units != units:
Expand All @@ -697,26 +717,32 @@ def encode_cf_datetime(
if ref_delta > np.timedelta64(0, "ns"):
needed_units = _infer_time_units_from_diff(ref_delta)

# Wrap the dates in a DatetimeIndex to do the subtraction to ensure
# an OverflowError is raised if the ref_date is too far away from
# dates to be encoded (GH 2272).
dates_as_index = pd.DatetimeIndex(dates.ravel())
time_deltas = dates_as_index - ref_date

# needed time delta to encode faithfully to int64
needed_time_delta = _time_units_to_timedelta64(needed_units)
if time_delta <= needed_time_delta:
# calculate int64 floor division
# to preserve integer dtype if possible (GH 4045, GH7817).
num = time_deltas // time_delta.astype(np.int64)
num = num.astype(np.int64, copy=False)
else:
emit_user_level_warning(
f"Times can't be serialized faithfully with requested units {units!r}. "
f"Resolution of {needed_units!r} needed. "
f"Serializing timeseries to floating point."
)
num = time_deltas / time_delta

floor_division = True
if time_delta > needed_time_delta:
floor_division = False
if dtype is None:
emit_user_level_warning(
f"Times can't be serialized faithfully to int64 with requested units {units!r}. "
f"Resolution of {needed_units!r} needed. Serializing times to floating point instead. "
f"Set encoding['dtype'] to integer dtype to serialize to int64. "
f"Set encoding['dtype'] to floating point dtype to silence this warning."
)
elif np.issubdtype(dtype, np.integer):
new_units = f"{needed_units} since {format_timestamp(ref_date)}"
emit_user_level_warning(
f"Times can't be serialized faithfully to int64 with requested units {units!r}. "
f"Serializing with units {new_units!r} instead. "
f"Set encoding['dtype'] to floating point dtype to serialize with units {units!r}. "
f"Set encoding['units'] to {new_units!r} to silence this warning ."
)
units = new_units
time_delta = needed_time_delta
floor_division = True

num = _division(time_deltas, time_delta, floor_division)
num = num.values.reshape(dates.shape)

except (OutOfBoundsDatetime, OverflowError, ValueError):
Expand All @@ -728,7 +754,9 @@ def encode_cf_datetime(
return (num, units, calendar)


def encode_cf_timedelta(timedeltas, units: str | None = None) -> tuple[np.ndarray, str]:
def encode_cf_timedelta(
timedeltas, units: str | None = None, dtype: np.dtype | None = None
) -> tuple[np.ndarray, str]:
data_units = infer_timedelta_units(timedeltas)

if units is None:
Expand All @@ -744,18 +772,29 @@ def encode_cf_timedelta(timedeltas, units: str | None = None) -> tuple[np.ndarra

# needed time delta to encode faithfully to int64
needed_time_delta = _time_units_to_timedelta64(needed_units)
if time_delta <= needed_time_delta:
# calculate int64 floor division
# to preserve integer dtype if possible
num = time_deltas // time_delta.astype(np.int64)
num = num.astype(np.int64, copy=False)
else:
emit_user_level_warning(
f"Timedeltas can't be serialized faithfully with requested units {units!r}. "
f"Resolution of {needed_units!r} needed. "
f"Serializing timedeltas to floating point."
)
num = time_deltas / time_delta

floor_division = True
if time_delta > needed_time_delta:
floor_division = False
if dtype is None:
emit_user_level_warning(
f"Timedeltas can't be serialized faithfully to int64 with requested units {units!r}. "
f"Resolution of {needed_units!r} needed. Serializing timeseries to floating point instead. "
f"Set encoding['dtype'] to integer dtype to serialize to int64. "
f"Set encoding['dtype'] to floating point dtype to silence this warning."
)
elif np.issubdtype(dtype, np.integer):
emit_user_level_warning(
f"Timedeltas can't be serialized faithfully with requested units {units!r}. "
f"Serializing with units {needed_units!r} instead. "
f"Set encoding['dtype'] to floating point dtype to serialize with units {units!r}. "
f"Set encoding['units'] to {needed_units!r} to silence this warning ."
)
units = needed_units
time_delta = needed_time_delta
floor_division = True

num = _division(time_deltas, time_delta, floor_division)
num = num.values.reshape(timedeltas.shape)
return (num, units)

Expand All @@ -772,7 +811,8 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable:

units = encoding.pop("units", None)
calendar = encoding.pop("calendar", None)
(data, units, calendar) = encode_cf_datetime(data, units, calendar)
dtype = encoding.get("dtype", None)
(data, units, calendar) = encode_cf_datetime(data, units, calendar, dtype)

safe_setitem(attrs, "units", units, name=name)
safe_setitem(attrs, "calendar", calendar, name=name)
Expand Down Expand Up @@ -807,7 +847,9 @@ def encode(self, variable: Variable, name: T_Name = None) -> Variable:
if np.issubdtype(variable.data.dtype, np.timedelta64):
dims, data, attrs, encoding = unpack_for_encoding(variable)

data, units = encode_cf_timedelta(data, encoding.pop("units", None))
data, units = encode_cf_timedelta(
data, encoding.pop("units", None), encoding.get("dtype", None)
)
safe_setitem(attrs, "units", units, name=name)

return Variable(dims, data, attrs, encoding, fastpath=True)
Expand Down
65 changes: 54 additions & 11 deletions xarray/tests/test_coding_times.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from xarray.coding.variables import SerializationWarning
from xarray.conventions import _update_bounds_attributes, cf_encoder
from xarray.core.common import contains_cftime_datetimes
from xarray.testing import assert_allclose, assert_equal, assert_identical
from xarray.testing import assert_equal, assert_identical
from xarray.tests import (
FirstElementAccessibleArray,
arm_xfail,
Expand Down Expand Up @@ -1036,7 +1036,7 @@ def test_encode_cf_datetime_defaults_to_correct_dtype(
pytest.skip("Nanosecond frequency is not valid for cftime dates.")
times = date_range("2000", periods=3, freq=freq)
units = f"{encoding_units} since 2000-01-01"
encoded, _, _ = coding.times.encode_cf_datetime(times, units)
encoded, _units, _ = coding.times.encode_cf_datetime(times, units)

numpy_timeunit = coding.times._netcdf_to_numpy_timeunit(encoding_units)
encoding_units_as_timedelta = np.timedelta64(1, numpy_timeunit)
Expand Down Expand Up @@ -1212,6 +1212,7 @@ def test_contains_cftime_lazy() -> None:
("1677-09-21T00:12:43.145224193", "ns", np.int64, None, False),
("1677-09-21T00:12:43.145225", "us", np.int64, None, False),
("1970-01-01T00:00:01.000001", "us", np.int64, None, False),
("1677-09-21T00:21:52.901038080", "ns", np.float32, 20.0, True),
],
)
def test_roundtrip_datetime64_nanosecond_precision(
Expand Down Expand Up @@ -1261,14 +1262,52 @@ def test_roundtrip_datetime64_nanosecond_precision_warning() -> None:
]
units = "days since 1970-01-10T01:01:00"
needed_units = "hours"
encoding = dict(_FillValue=20, units=units)
new_units = f"{needed_units} since 1970-01-10T01:01:00"

encoding = dict(dtype=None, _FillValue=20, units=units)
var = Variable(["time"], times, encoding=encoding)
wmsg = (
f"Times can't be serialized faithfully with requested units {units!r}. "
f"Resolution of {needed_units!r} needed. "
)
with pytest.warns(UserWarning, match=wmsg):
with pytest.warns(UserWarning, match=f"Resolution of {needed_units!r} needed."):
encoded_var = conventions.encode_cf_variable(var)
assert encoded_var.dtype == np.float64
assert encoded_var.attrs["units"] == units
assert encoded_var.attrs["_FillValue"] == 20.0

decoded_var = conventions.decode_cf_variable("foo", encoded_var)
assert_identical(var, decoded_var)

encoding = dict(dtype="int64", _FillValue=20, units=units)
var = Variable(["time"], times, encoding=encoding)
with pytest.warns(
UserWarning, match=f"Serializing with units {new_units!r} instead."
):
encoded_var = conventions.encode_cf_variable(var)
assert encoded_var.dtype == np.int64
assert encoded_var.attrs["units"] == new_units
assert encoded_var.attrs["_FillValue"] == 20

decoded_var = conventions.decode_cf_variable("foo", encoded_var)
assert_identical(var, decoded_var)

encoding = dict(dtype="float64", _FillValue=20, units=units)
var = Variable(["time"], times, encoding=encoding)
with warnings.catch_warnings():
warnings.simplefilter("error")
encoded_var = conventions.encode_cf_variable(var)
assert encoded_var.dtype == np.float64
assert encoded_var.attrs["units"] == units
assert encoded_var.attrs["_FillValue"] == 20.0

decoded_var = conventions.decode_cf_variable("foo", encoded_var)
assert_identical(var, decoded_var)

encoding = dict(dtype="int64", _FillValue=20, units=new_units)
var = Variable(["time"], times, encoding=encoding)
with warnings.catch_warnings():
warnings.simplefilter("error")
encoded_var = conventions.encode_cf_variable(var)
assert encoded_var.dtype == np.int64
assert encoded_var.attrs["units"] == new_units
assert encoded_var.attrs["_FillValue"] == 20

decoded_var = conventions.decode_cf_variable("foo", encoded_var)
assert_identical(var, decoded_var)
Expand Down Expand Up @@ -1309,14 +1348,18 @@ def test_roundtrip_timedelta64_nanosecond_precision_warning() -> None:
needed_units = "hours"
wmsg = (
f"Timedeltas can't be serialized faithfully with requested units {units!r}. "
f"Resolution of {needed_units!r} needed. "
f"Serializing with units {needed_units!r} instead."
)
encoding = dict(_FillValue=20, units=units)
encoding = dict(dtype=np.int64, _FillValue=20, units=units)
kmuehlbauer marked this conversation as resolved.
Show resolved Hide resolved
var = Variable(["time"], timedelta_values, encoding=encoding)
with pytest.warns(UserWarning, match=wmsg):
encoded_var = conventions.encode_cf_variable(var)
assert encoded_var.dtype == np.int64
assert encoded_var.attrs["units"] == needed_units
assert encoded_var.attrs["_FillValue"] == 20
decoded_var = conventions.decode_cf_variable("foo", encoded_var)
assert_allclose(var, decoded_var)
assert_identical(var, decoded_var)
assert decoded_var.encoding["dtype"] == np.int64


def test_roundtrip_float_times() -> None:
Expand Down
Loading