diff --git a/doc/source/analyzing/time_series_analysis.rst b/doc/source/analyzing/time_series_analysis.rst
index 9f4db0fc2a..f0ac68d178 100644
--- a/doc/source/analyzing/time_series_analysis.rst
+++ b/doc/source/analyzing/time_series_analysis.rst
@@ -77,6 +77,29 @@ see:
 * The cookbook recipe for :ref:`cookbook-time-series-analysis`
 * :class:`~yt.data_objects.time_series.DatasetSeries`
 
+In addition, the :class:`~yt.data_objects.time_series.DatasetSeries` object allows you
+to select an output based on its time or its redshift (if defined), as follows:
+
+.. code-block:: python
+
+    import yt
+
+    ts = yt.load("*/*.index")
+    # Get the output closest to 3 Gyr
+    ds = ts.get_by_time((3, "Gyr"))
+    # Raise an error if no output is found within 100 Myr of 3 Gyr
+    ds = ts.get_by_time((3, "Gyr"), tolerance=(100, "Myr"))
+    # Get the outputs immediately before and after 3 Gyr
+    ds_before = ts.get_by_time((3, "Gyr"), prefer="smaller")
+    ds_after = ts.get_by_time((3, "Gyr"), prefer="larger")
+
+    # For cosmological simulations, you can also select an output by its redshift,
+    # with the same options as above
+    ds = ts.get_by_redshift(0.5)
+
+For more information, see :meth:`~yt.data_objects.time_series.DatasetSeries.get_by_time` and
+:meth:`~yt.data_objects.time_series.DatasetSeries.get_by_redshift`.
+
 .. _analyzing-an-entire-simulation:
 
 Analyzing an Entire Simulation
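The documentation hunk above says redshift selection takes "the same options as
above" without showing them. A brief illustrative sketch (not part of the patch,
assuming a cosmological ``ts`` whose outputs define ``current_redshift``):

.. code-block:: python

    # Nearest output to z = 0.5, but only if it lies within dz = 0.1
    ds = ts.get_by_redshift(0.5, tolerance=0.1)

    # Outputs bracketing z = 0.5; redshift decreases with time, so the output
    # with the *smaller* redshift is the *later* one
    ds_later = ts.get_by_redshift(0.5, prefer="smaller")
    ds_earlier = ts.get_by_redshift(0.5, prefer="larger")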
diff --git a/nose_ignores.txt b/nose_ignores.txt
index b5529a6ecc..61d20fab0b 100644
--- a/nose_ignores.txt
+++ b/nose_ignores.txt
@@ -47,3 +47,4 @@
 --ignore-file=test_disks\.py
 --ignore-file=test_offaxisprojection_pytestonly\.py
 --ignore-file=test_sph_pixelization_pytestonly\.py
+--ignore-file=test_time_series\.py
diff --git a/tests/tests.yaml b/tests/tests.yaml
index 267c1a95a8..fc7a880b3c 100644
--- a/tests/tests.yaml
+++ b/tests/tests.yaml
@@ -223,6 +223,7 @@ other_tests:
   - "--ignore-file=test_gadget_pytest\\.py"
   - "--ignore-file=test_vr_orientation\\.py"
   - "--ignore-file=test_particle_trajectories_pytest\\.py"
+  - "--ignore-file=test_time_series\\.py"
   - "--exclude-test=yt.frontends.gdf.tests.test_outputs.TestGDF"
   - "--exclude-test=yt.frontends.adaptahop.tests.test_outputs"
   - "--exclude-test=yt.frontends.stream.tests.test_stream_particles.test_stream_non_cartesian_particles"
diff --git a/yt/data_objects/tests/test_time_series.py b/yt/data_objects/tests/test_time_series.py
index 99c5472d61..9c28d24ca8 100644
--- a/yt/data_objects/tests/test_time_series.py
+++ b/yt/data_objects/tests/test_time_series.py
@@ -1,7 +1,7 @@
 import tempfile
 from pathlib import Path
 
-from numpy.testing import assert_raises
+import pytest
 
 from yt.data_objects.static_output import Dataset
 from yt.data_objects.time_series import DatasetSeries
@@ -29,13 +29,49 @@ def test_pattern_expansion():
 def test_no_match_pattern():
     with tempfile.TemporaryDirectory() as tmpdir:
         pattern = Path(tmpdir).joinpath("fake_data_file_*")
-        assert_raises(
-            FileNotFoundError, DatasetSeries._get_filenames_from_glob_pattern, pattern
-        )
+        with pytest.raises(FileNotFoundError):
+            DatasetSeries._get_filenames_from_glob_pattern(pattern)
 
 
-def test_init_fake_dataseries():
-    file_list = [f"fake_data_file_{str(i).zfill(4)}" for i in range(10)]
+@pytest.fixture
+def FakeDataset():
+    class _FakeDataset(Dataset):
+        """A minimal loadable fake dataset subclass"""
+
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+
+        @classmethod
+        def _is_valid(cls, *args, **kwargs):
+            return True
+
+        def _parse_parameter_file(self):
+            return
+
+        def _set_code_unit_attributes(self):
+            return
+
+        def set_code_units(self):
+            i = int(Path(self.filename).name.split("_")[-1])
+            self.current_time = i
+            self.current_redshift = 1 / (i + 1)
+            return
+
+        def _hash(self):
+            return
+
+        def _setup_classes(self):
+            return
+
+    try:
+        yield _FakeDataset
+    finally:
+        output_type_registry.pop("_FakeDataset")
+
+
+@pytest.fixture
+def fake_datasets():
+    file_list = [f"fake_data_file_{i:04d}" for i in range(10)]
     with tempfile.TemporaryDirectory() as tmpdir:
         pfile_list = [Path(tmpdir) / file for file in file_list]
         sfile_list = [str(file) for file in pfile_list]
@@ -43,62 +79,77 @@ def test_init_fake_dataseries():
             file.touch()
         pattern = Path(tmpdir) / "fake_data_file_*"
 
-        # init from str pattern
-        ts = DatasetSeries(pattern)
-        assert ts._pre_outputs == sfile_list
+        yield file_list, pfile_list, sfile_list, pattern
+
+
+def test_init_fake_dataseries(fake_datasets):
+    file_list, pfile_list, sfile_list, pattern = fake_datasets
+
+    # init from str pattern
+    ts = DatasetSeries(pattern)
+    assert ts._pre_outputs == sfile_list
+
+    # init from Path pattern
+    ppattern = Path(pattern)
+    ts = DatasetSeries(ppattern)
+    assert ts._pre_outputs == sfile_list
+
+    # init from str list
+    ts = DatasetSeries(sfile_list)
+    assert ts._pre_outputs == sfile_list
+
+    # init from Path list
+    ts = DatasetSeries(pfile_list)
+    assert ts._pre_outputs == pfile_list
+
+    # rejected input type (str repr of a list) "[file1, file2, ...]"
+    with pytest.raises(FileNotFoundError):
+        DatasetSeries(str(file_list))
 
-        # init from Path pattern
-        ppattern = Path(pattern)
-        ts = DatasetSeries(ppattern)
-        assert ts._pre_outputs == sfile_list
+    # finally, check that ts[0] fails to actually load
+    with pytest.raises(YTUnidentifiedDataType):
+        ts[0]
 
-        # init form str list
-        ts = DatasetSeries(sfile_list)
-        assert ts._pre_outputs == sfile_list
 
-        # init form Path list
-        ts = DatasetSeries(pfile_list)
-        assert ts._pre_outputs == pfile_list
+def test_init_fake_dataseries2(FakeDataset, fake_datasets):
+    _file_list, _pfile_list, _sfile_list, pattern = fake_datasets
+    ds = DatasetSeries(pattern)[0]
+    assert isinstance(ds, FakeDataset)
 
-        # rejected input type (str repr of a list) "[file1, file2, ...]"
-        assert_raises(FileNotFoundError, DatasetSeries, str(file_list))
+    ts = DatasetSeries(pattern, my_unsupported_kwarg=None)
 
-        # finally, check that ts[0] fails to actually load
-        assert_raises(YTUnidentifiedDataType, ts.__getitem__, 0)
+    with pytest.raises(TypeError):
+        ts[0]
 
-        class FakeDataset(Dataset):
-            """A minimal loadable fake dataset subclass"""
 
-            @classmethod
-            def _is_valid(cls, *args, **kwargs):
-                return True
+def test_get_by_key(FakeDataset, fake_datasets):
+    _file_list, _pfile_list, sfile_list, pattern = fake_datasets
+    ts = DatasetSeries(pattern)
 
-            def _parse_parameter_file(self):
-                return
+    Ntot = len(sfile_list)
 
-            def _set_code_unit_attributes(self):
-                return
+    t = ts[0].quan(1, "code_time")
 
-            def set_code_units(self):
-                self.current_time = 0
-                return
+    assert sfile_list[0] == ts.get_by_time(-t).filename
+    assert sfile_list[0] == ts.get_by_time(t - t).filename
+    assert sfile_list[1] == ts.get_by_time((0.8, "code_time")).filename
+    assert sfile_list[1] == ts.get_by_time((1.2, "code_time")).filename
+    assert sfile_list[Ntot - 1] == ts.get_by_time(t * (Ntot - 1)).filename
+    assert sfile_list[Ntot - 1] == ts.get_by_time(t * Ntot).filename
 
-            def _hash(self):
-                return
+    with pytest.raises(ValueError):
+        ts.get_by_time(-2 * t, tolerance=0.1 * t)
+    with pytest.raises(ValueError):
+        ts.get_by_time(1000 * t, tolerance=0.1 * t)
 
-            def _setup_classes(self):
-                return
+    assert sfile_list[1] == ts.get_by_redshift(1 / 2.2).filename
+    assert sfile_list[1] == ts.get_by_redshift(1 / 2).filename
+    assert sfile_list[1] == ts.get_by_redshift(1 / 1.6).filename
 
-        try:
-            ds = DatasetSeries(pattern)[0]
-            assert isinstance(ds, FakeDataset)
+    with pytest.raises(ValueError):
+        ts.get_by_redshift(1000, tolerance=0.1)
 
-            ts = DatasetSeries(pattern, my_unsupported_kwarg=None)
+    zmid = (ts[0].current_redshift + ts[1].current_redshift) / 2
 
-            assert_raises(TypeError, ts.__getitem__, 0)
-            # the exact error message is supposed to be this
-            # """__init__() got an unexpected keyword argument 'my_unsupported_kwarg'"""
-            # but it's hard to check for within the framework
-        finally:
-            # tear down to avoid possible breakage in following tests
-            output_type_registry.pop("FakeDataset")
+    assert sfile_list[1] == ts.get_by_redshift(zmid, prefer="smaller").filename
+    assert sfile_list[0] == ts.get_by_redshift(zmid, prefer="larger").filename
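The redshift assertions in ``test_get_by_key`` rely on the fixture mapping
output index ``i`` to ``current_redshift = 1 / (i + 1)``. A standalone check
(illustrative only, plain Python rather than the yt test harness) of why
``1 / 2.2``, ``1 / 2`` and ``1 / 1.6`` all resolve to output 1:

.. code-block:: python

    # Fixture mapping: output i has redshift 1 / (i + 1), i.e. 1.0, 0.5, 0.333...
    redshifts = [1 / (i + 1) for i in range(10)]

    def nearest_index(z):
        # mimic prefer="nearest": pick the output whose redshift is closest to z
        return min(range(len(redshifts)), key=lambda i: abs(redshifts[i] - z))

    assert nearest_index(1 / 2.2) == 1  # 0.4545... is closest to 0.5
    assert nearest_index(1 / 2.0) == 1  # exact match with output 1
    assert nearest_index(1 / 1.6) == 1  # 0.625 is closer to 0.5 than to 1.0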
diff --git a/yt/data_objects/time_series.py b/yt/data_objects/time_series.py
index b642626e67..6edc5aa5fb 100644
--- a/yt/data_objects/time_series.py
+++ b/yt/data_objects/time_series.py
@@ -5,10 +5,11 @@ import weakref
 from abc import ABC, abstractmethod
 from functools import wraps
-from typing import Optional
+from typing import TYPE_CHECKING, Literal, Optional, Union
 
 import numpy as np
 from more_itertools import always_iterable
+from unyt import Unit, unyt_quantity
 
 from yt.config import ytcfg
 from yt.data_objects.analyzer_objects import AnalysisTask, create_quantity_proxy
@@ -28,6 +29,9 @@
     parallel_root_only,
 )
 
+if TYPE_CHECKING:
+    from yt.data_objects.static_output import Dataset
+
 
 class AnalysisTaskProxy:
     def __init__(self, time_series):
@@ -144,10 +148,7 @@ class DatasetSeries:
 
     """
 
-    # this annotation should really be Optional[Type[Dataset]]
-    # but we cannot import the yt.data_objects.static_output.Dataset
-    # class here without creating a circular import for now
-    _dataset_cls: Optional[type] = None
+    _dataset_cls: Optional[type["Dataset"]] = None
 
     def __init_subclass__(cls, *args, **kwargs):
         super().__init_subclass__(*args, **kwargs)
@@ -436,6 +437,173 @@ def particle_trajectories(
             self, indices, fields=fields, suppress_logging=suppress_logging, ptype=ptype
         )
 
+    def _get_by_attribute(
+        self,
+        attribute: str,
+        value: Union[unyt_quantity, tuple[float, Union[Unit, str]]],
+        tolerance: Union[None, unyt_quantity, tuple[float, Union[Unit, str]]] = None,
+        prefer: Literal["nearest", "smaller", "larger"] = "nearest",
+    ) -> "Dataset":
+        r"""
+        Get a dataset at or near a given value.
+
+        Parameters
+        ----------
+        attribute : str
+            The attribute by which to retrieve an output, usually
+            'current_time' or 'current_redshift'. It must exist on every
+            dataset in the series and vary monotonically with output order.
+        value : unyt_quantity or (value, unit)
+            The value to search for.
+        tolerance : unyt_quantity or (value, unit), optional
+            If not None, do not return a dataset unless its value is
+            within the tolerance of the requested value. If None, simply
+            return the nearest dataset.
+            Default: None.
+        prefer : str
+            The side of the value to return. Can be 'nearest', 'smaller'
+            or 'larger'.
+            Default: 'nearest'.
+        """
+
+        if prefer not in ("nearest", "smaller", "larger"):
+            raise ValueError(
+                f"prefer must be 'nearest', 'smaller' or 'larger', got {prefer}."
+            )
+
+        # Use a binary search to find the closest value
+        iL = 0
+        iR = len(self._pre_outputs) - 1
+
+        if iL == iR:
+            ds = self[0]
+            if (
+                tolerance is not None
+                and abs(getattr(ds, attribute) - value) > tolerance
+            ):
+                raise ValueError(
+                    f"No dataset found with {attribute} within {tolerance} of {value}."
+                )
+            return ds
+
+        # Check signedness
+        dsL = self[iL]
+        dsR = self[iR]
+        vL = getattr(dsL, attribute)
+        vR = getattr(dsR, attribute)
+
+        if vL < vR:
+            sign = 1
+        elif vL > vR:
+            sign = -1
+        else:
+            raise ValueError(
+                f"{dsL} and {dsR} both have {attribute}={vL}, cannot perform search."
+            )
+
+        if isinstance(value, tuple):
+            value = dsL.quan(*value)
+        if isinstance(tolerance, tuple):
+            tolerance = dsL.quan(*tolerance)
+
+        # Short-circuit if value is out-of-range: skip the search so that
+        # the selection below chooses between the two endpoint datasets
+        if not (vL * sign < value * sign < vR * sign):
+            iL = iR = 0
+
+        while iR - iL > 1:
+            iM = (iR + iL) // 2
+            dsM = self[iM]
+            vM = getattr(dsM, attribute)
+
+            if sign * value < sign * vM:
+                iR = iM
+                dsR = dsM
+            elif sign * value > sign * vM:
+                iL = iM
+                dsL = dsM
+            else:  # Exact match
+                dsL = dsR = dsM
+                break
+
+        if prefer == "smaller":
+            ds_best = dsL if sign > 0 else dsR
+        elif prefer == "larger":
+            ds_best = dsR if sign > 0 else dsL
+        elif abs(value - getattr(dsL, attribute)) < abs(
+            value - getattr(dsR, attribute)
+        ):
+            ds_best = dsL
+        else:
+            ds_best = dsR
+
+        if tolerance is not None:
+            if abs(value - getattr(ds_best, attribute)) > tolerance:
+                raise ValueError(
+                    f"No dataset found with {attribute} within {tolerance} of {value}."
+                )
+        return ds_best
+ ) + + # Use a binary search to find the closest value + iL = 0 + iR = len(self._pre_outputs) - 1 + + if iL == iR: + ds = self[0] + if ( + tolerance is not None + and abs(getattr(ds, attribute) - value) > tolerance + ): + raise ValueError( + f"No dataset found with {attribute} within {tolerance} of {value}." + ) + return ds + + # Check signedness + dsL = self[iL] + dsR = self[iR] + vL = getattr(dsL, attribute) + vR = getattr(dsR, attribute) + + if vL < vR: + sign = 1 + elif vL > vR: + sign = -1 + else: + raise ValueError( + f"{dsL} and {dsR} have both {attribute}={vL}, cannot perform search." + ) + + if isinstance(value, tuple): + value = dsL.quan(*value) + if isinstance(tolerance, tuple): + tolerance = dsL.quan(*tolerance) + + # Short-circuit if value is out-of-range + if not (vL * sign < value * sign < vR * sign): + iL = iR = 0 + + while iR - iL > 1: + iM = (iR + iL) // 2 + dsM = self[iM] + vM = getattr(dsM, attribute) + + if sign * value < sign * vM: + iR = iM + dsR = dsM + elif sign * value > sign * vM: + iL = iM + dsL = dsM + else: # Exact match + dsL = dsR = dsM + break + + if prefer == "smaller": + ds_best = dsL if sign > 0 else dsR + elif prefer == "larger": + ds_best = dsR if sign > 0 else dsL + elif abs(value - getattr(dsL, attribute)) < abs( + value - getattr(dsR, attribute) + ): + ds_best = dsL + else: + ds_best = dsR + + if tolerance is not None: + if abs(value - getattr(ds_best, attribute)) > tolerance: + raise ValueError( + f"No dataset found with {attribute} within {tolerance} of {value}." + ) + return ds_best + + def get_by_time( + self, + time: Union[unyt_quantity, tuple[float, Union[Unit, str]]], + tolerance: Union[None, unyt_quantity, tuple[float, Union[Unit, str]]] = None, + prefer: Literal["nearest", "smaller", "larger"] = "nearest", + ) -> "Dataset": + """ + Get a dataset at or near to a given time. + + Parameters + ---------- + time : unyt_quantity or (value, unit) + The time to search for. + tolerance : unyt_quantity or (value, unit) + If not None, do not return a dataset unless the time is + within the tolerance value. If None, simply return the + nearest dataset. + Default: None. + prefer : str + The side of the value to return. Can be 'nearest', 'smaller' or 'larger'. + Default: 'nearest'. + + Examples + -------- + >>> ds = ts.get_by_time((12, "Gyr")) + >>> t = ts[0].quan(12, "Gyr") + ... ds = ts.get_by_time(t, tolerance=(100, "Myr")) + """ + return self._get_by_attribute( + "current_time", time, tolerance=tolerance, prefer=prefer + ) + + def get_by_redshift( + self, + redshift: float, + tolerance: Optional[float] = None, + prefer: Literal["nearest", "smaller", "larger"] = "nearest", + ) -> "Dataset": + """ + Get a dataset at or near to a given time. + + Parameters + ---------- + redshift : float + The redshift to search for. + tolerance : float + If not None, do not return a dataset unless the redshift is + within the tolerance value. If None, simply return the + nearest dataset. + Default: None. + prefer : str + The side of the value to return. Can be 'nearest', 'smaller' or 'larger'. + Default: 'nearest'. + + Examples + -------- + >>> ds = ts.get_by_redshift(0.0) + """ + return self._get_by_attribute( + "current_redshift", redshift, tolerance=tolerance, prefer=prefer + ) + class TimeSeriesQuantitiesContainer: def __init__(self, data_object, quantities):