From e70a8aa91ac0d97a45167c04b74325e4641bb3b2 Mon Sep 17 00:00:00 2001 From: Corentin Cadiou Date: Fri, 27 Oct 2023 11:11:47 +0100 Subject: [PATCH 01/20] Add functionality to find a dataset by its closest value --- yt/data_objects/time_series.py | 85 +++++++++++++++++++++++++++++++++- 1 file changed, 84 insertions(+), 1 deletion(-) diff --git a/yt/data_objects/time_series.py b/yt/data_objects/time_series.py index b642626e67..3e9a6fe330 100644 --- a/yt/data_objects/time_series.py +++ b/yt/data_objects/time_series.py @@ -2,6 +2,7 @@ import glob import inspect import os +import typing import weakref from abc import ABC, abstractmethod from functools import wraps @@ -28,6 +29,9 @@ parallel_root_only, ) +if typing.TYPE_CHECKING: + from yt.data_objects.static_output import Dataset + class AnalysisTaskProxy: def __init__(self, time_series): @@ -147,7 +151,7 @@ class DatasetSeries: # this annotation should really be Optional[Type[Dataset]] # but we cannot import the yt.data_objects.static_output.Dataset # class here without creating a circular import for now - _dataset_cls: Optional[type] = None + _dataset_cls: Optional[type["Dataset"]] = None def __init_subclass__(cls, *args, **kwargs): super().__init_subclass__(*args, **kwargs) @@ -436,6 +440,85 @@ def particle_trajectories( self, indices, fields=fields, suppress_logging=suppress_logging, ptype=ptype ) + def get_by_key(self, key: str, value, tolerance=None) -> "Dataset": + r""" + Get a dataset at or near to a given value. + + Parameters + ---------- + key : str + The key by which to retrieve an output, usually 'current_time' or + 'current_redshift'. The key must be an attribute of the dataset + and monotonic. + value : float + The value to search for. + tolerance : float + If not None, do not return a dataset unless the value is + within the tolerance value. If None, simply return the + nearest dataset. + Default: None. 
+ + Examples + -------- + >>> ds = ts.get_by_key("current_redshift", 0.0) + """ + + # Use a binary search to find the closest value + iL = 0 + iH = len(self._pre_outputs) - 1 + + if iL == iH: + ds = self[0] + if tolerance is not None and abs(getattr(ds, key) - value) > tolerance: + raise ValueError( + f"No dataset found with {key} within {tolerance} of {value}." + ) + return ds + + # Check signedness + dsL = self[iL] + dsH = self[iH] + vL = getattr(dsL, key) + vH = getattr(dsH, key) + + if vL < vH: + sign = 1 + elif vL > vH: + sign = -1 + else: + raise ValueError( + f"{dsL} and {dsH} have both {key}={vL}, cannot perform search. " + "Try with another key." + ) + + if sign * value < sign * vL: + return dsL + elif sign * value > sign * vH: + return dsH + + while iH - iL > 1: + iM = (iH + iL) // 2 + dsM = self[iM] + vM = getattr(dsM, key) + if sign * value < sign * vM: + iH = iM + dsH = dsM + elif sign * value > sign * vM: + iL = iM + dsL = dsM + + if abs(value - getattr(dsL, key)) < abs(value - getattr(dsH, key)): + ds_best = dsL + else: + ds_best = dsH + + if tolerance is not None: + if abs(value - getattr(ds_best, key)) > tolerance: + raise ValueError( + f"No dataset found with {key} within {tolerance} of {value}." 
+ ) + return ds_best + class TimeSeriesQuantitiesContainer: def __init__(self, data_object, quantities): From a967fce48e2a431314350e38f62363e1a911aff3 Mon Sep 17 00:00:00 2001 From: Corentin Cadiou Date: Fri, 27 Oct 2023 16:25:41 +0100 Subject: [PATCH 02/20] Test the ability to get output from key --- yt/data_objects/tests/test_time_series.py | 129 ++++++++++++++-------- 1 file changed, 84 insertions(+), 45 deletions(-) diff --git a/yt/data_objects/tests/test_time_series.py b/yt/data_objects/tests/test_time_series.py index 99c5472d61..2d7980988c 100644 --- a/yt/data_objects/tests/test_time_series.py +++ b/yt/data_objects/tests/test_time_series.py @@ -1,6 +1,7 @@ import tempfile from pathlib import Path +import pytest from numpy.testing import assert_raises from yt.data_objects.static_output import Dataset @@ -34,7 +35,48 @@ def test_no_match_pattern(): ) -def test_init_fake_dataseries(): +@pytest.fixture +def FakeDataset(): + i = 0 + + class __FakeDataset(Dataset): + """A minimal loadable fake dataset subclass""" + + def __init__(self, *args, **kwargs): + nonlocal i + super().__init__(*args, **kwargs) + self.current_time = i + self.current_opposite_time = -i + i += 1 + + @classmethod + def _is_valid(cls, *args, **kwargs): + return True + + def _parse_parameter_file(self): + return + + def _set_code_unit_attributes(self): + return + + def set_code_units(self): + self.current_time = 0 + return + + def _hash(self): + return + + def _setup_classes(self): + return + + try: + yield __FakeDataset + finally: + output_type_registry.pop("__FakeDataset") + + +@pytest.fixture +def fake_datasets(): file_list = [f"fake_data_file_{str(i).zfill(4)}" for i in range(10)] with tempfile.TemporaryDirectory() as tmpdir: pfile_list = [Path(tmpdir) / file for file in file_list] @@ -43,62 +85,59 @@ def test_init_fake_dataseries(): file.touch() pattern = Path(tmpdir) / "fake_data_file_*" - # init from str pattern - ts = DatasetSeries(pattern) - assert ts._pre_outputs == sfile_list + yield 
file_list, pfile_list, sfile_list, pattern + + +def test_init_fake_dataseries(fake_datasets): + file_list, pfile_list, sfile_list, pattern = fake_datasets - # init from Path pattern - ppattern = Path(pattern) - ts = DatasetSeries(ppattern) - assert ts._pre_outputs == sfile_list + # init from str pattern + ts = DatasetSeries(pattern) + assert ts._pre_outputs == sfile_list - # init form str list - ts = DatasetSeries(sfile_list) - assert ts._pre_outputs == sfile_list + # init from Path pattern + ppattern = Path(pattern) + ts = DatasetSeries(ppattern) + assert ts._pre_outputs == sfile_list - # init form Path list - ts = DatasetSeries(pfile_list) - assert ts._pre_outputs == pfile_list + # init form str list + ts = DatasetSeries(sfile_list) + assert ts._pre_outputs == sfile_list - # rejected input type (str repr of a list) "[file1, file2, ...]" - assert_raises(FileNotFoundError, DatasetSeries, str(file_list)) + # init form Path list + ts = DatasetSeries(pfile_list) + assert ts._pre_outputs == pfile_list - # finally, check that ts[0] fails to actually load - assert_raises(YTUnidentifiedDataType, ts.__getitem__, 0) + # rejected input type (str repr of a list) "[file1, file2, ...]" + assert_raises(FileNotFoundError, DatasetSeries, str(file_list)) - class FakeDataset(Dataset): - """A minimal loadable fake dataset subclass""" + # finally, check that ts[0] fails to actually load + assert_raises(YTUnidentifiedDataType, ts.__getitem__, 0) - @classmethod - def _is_valid(cls, *args, **kwargs): - return True - def _parse_parameter_file(self): - return +def test_init_fake_dataseries2(FakeDataset, fake_datasets): + _file_list, _pfile_list, _sfile_list, pattern = fake_datasets + ds = DatasetSeries(pattern)[0] + assert isinstance(ds, FakeDataset) - def _set_code_unit_attributes(self): - return + ts = DatasetSeries(pattern, my_unsupported_kwarg=None) - def set_code_units(self): - self.current_time = 0 - return + assert_raises(TypeError, ts.__getitem__, 0) - def _hash(self): - return - 
def _setup_classes(self): - return +def test_get_by_key(FakeDataset, fake_datasets): + file_list, _pfile_list, _sfile_list, pattern = fake_datasets + ts = DatasetSeries(pattern) - try: - ds = DatasetSeries(pattern)[0] - assert isinstance(ds, FakeDataset) + Ntot = len(file_list) - ts = DatasetSeries(pattern, my_unsupported_kwarg=None) + assert ts[0] == ts.get_by_key("current_time", -1) + assert ts[0] == ts.get_by_key("current_time", 0) + assert ts[1] == ts.get_by_key("current_time", 0.8) + assert ts[1] == ts.get_by_key("current_time", 1.2) + assert ts[Ntot - 1] == ts.get_by_key("current_time", Ntot - 1) + assert ts[Ntot - 1] == ts.get_by_key("current_time", Ntot) - assert_raises(TypeError, ts.__getitem__, 0) - # the exact error message is supposed to be this - # """__init__() got an unexpected keyword argument 'my_unsupported_kwarg'""" - # but it's hard to check for within the framework - finally: - # tear down to avoid possible breakage in following tests - output_type_registry.pop("FakeDataset") + assert ts[1] == ts.get_by_key("current_opposite_time", -1.2) + assert ts[1] == ts.get_by_key("current_opposite_time", -1) + assert ts[1] == ts.get_by_key("current_opposite_time", -0.6) From 42fd1ae58a8f56ca106f2b9956942f1cadf9f1a7 Mon Sep 17 00:00:00 2001 From: Corentin Cadiou Date: Mon, 30 Oct 2023 09:13:50 +0000 Subject: [PATCH 03/20] Only 'get_by_time' and '_by_redshift' are user-facing --- yt/data_objects/tests/test_time_series.py | 29 +++++--- yt/data_objects/time_series.py | 88 ++++++++++++++++++----- 2 files changed, 89 insertions(+), 28 deletions(-) diff --git a/yt/data_objects/tests/test_time_series.py b/yt/data_objects/tests/test_time_series.py index 2d7980988c..b8e38037b7 100644 --- a/yt/data_objects/tests/test_time_series.py +++ b/yt/data_objects/tests/test_time_series.py @@ -46,7 +46,7 @@ def __init__(self, *args, **kwargs): nonlocal i super().__init__(*args, **kwargs) self.current_time = i - self.current_opposite_time = -i + self.current_opposite_time = 1 / 
(i + 1) i += 1 @classmethod @@ -60,7 +60,6 @@ def _set_code_unit_attributes(self): return def set_code_units(self): - self.current_time = 0 return def _hash(self): @@ -131,13 +130,21 @@ def test_get_by_key(FakeDataset, fake_datasets): Ntot = len(file_list) - assert ts[0] == ts.get_by_key("current_time", -1) - assert ts[0] == ts.get_by_key("current_time", 0) - assert ts[1] == ts.get_by_key("current_time", 0.8) - assert ts[1] == ts.get_by_key("current_time", 1.2) - assert ts[Ntot - 1] == ts.get_by_key("current_time", Ntot - 1) - assert ts[Ntot - 1] == ts.get_by_key("current_time", Ntot) + assert ts[0] == ts.get_by_time(-1) + assert ts[0] == ts.get_by_time(0) + assert ts[1] == ts.get_by_time(0.8) + assert ts[1] == ts.get_by_time(1.2) + assert ts[Ntot - 1] == ts.get_by_time(Ntot - 1) + assert ts[Ntot - 1] == ts.get_by_time(Ntot) - assert ts[1] == ts.get_by_key("current_opposite_time", -1.2) - assert ts[1] == ts.get_by_key("current_opposite_time", -1) - assert ts[1] == ts.get_by_key("current_opposite_time", -0.6) + with pytest.raises(ValueError): + ts.get_by_time(-2, tolerance=0.1) + with pytest.raises(ValueError): + ts.get_by_time(1000, tolerance=0.1) + + assert ts[1] == ts.get_by_redshift(1 / 2.2) + assert ts[1] == ts.get_by_redshift(1 / 2) + assert ts[1] == ts.get_by_redshift(1 / 1.6) + + with pytest.raises(ValueError): + ts.get_by_redshift(1000, tolerance=0.1) diff --git a/yt/data_objects/time_series.py b/yt/data_objects/time_series.py index 3e9a6fe330..7931100541 100644 --- a/yt/data_objects/time_series.py +++ b/yt/data_objects/time_series.py @@ -6,10 +6,11 @@ import weakref from abc import ABC, abstractmethod from functools import wraps -from typing import Optional +from typing import Optional, Union import numpy as np from more_itertools import always_iterable +from unyt import unyt_quantity from yt.config import ytcfg from yt.data_objects.analyzer_objects import AnalysisTask, create_quantity_proxy @@ -440,7 +441,12 @@ def particle_trajectories( self, indices, 
fields=fields, suppress_logging=suppress_logging, ptype=ptype ) - def get_by_key(self, key: str, value, tolerance=None) -> "Dataset": + def _get_by_attribute( + self, + attribute: str, + value: Union[unyt_quantity, tuple[float, str]], + tolerance: Union[None, unyt_quantity, tuple[float, str]] = None, + ) -> "Dataset": r""" Get a dataset at or near to a given value. @@ -450,17 +456,13 @@ def get_by_key(self, key: str, value, tolerance=None) -> "Dataset": The key by which to retrieve an output, usually 'current_time' or 'current_redshift'. The key must be an attribute of the dataset and monotonic. - value : float + value : unyt_array or (value, unit) The value to search for. - tolerance : float + tolerance : unyt_array or (value, unit) If not None, do not return a dataset unless the value is within the tolerance value. If None, simply return the nearest dataset. Default: None. - - Examples - -------- - >>> ds = ts.get_by_key("current_redshift", 0.0) """ # Use a binary search to find the closest value @@ -469,17 +471,20 @@ def get_by_key(self, key: str, value, tolerance=None) -> "Dataset": if iL == iH: ds = self[0] - if tolerance is not None and abs(getattr(ds, key) - value) > tolerance: + if ( + tolerance is not None + and abs(getattr(ds, attribute) - value) > tolerance + ): raise ValueError( - f"No dataset found with {key} within {tolerance} of {value}." + f"No dataset found with {attribute} within {tolerance} of {value}." ) return ds # Check signedness dsL = self[iL] dsH = self[iH] - vL = getattr(dsL, key) - vH = getattr(dsH, key) + vL = getattr(dsL, attribute) + vH = getattr(dsH, attribute) if vL < vH: sign = 1 @@ -487,10 +492,13 @@ def get_by_key(self, key: str, value, tolerance=None) -> "Dataset": sign = -1 else: raise ValueError( - f"{dsL} and {dsH} have both {key}={vL}, cannot perform search. " + f"{dsL} and {dsH} have both {attribute}={vL}, cannot perform search. " "Try with another key." 
) + if isinstance(value, tuple): + value = dsL.quan(*value) + if sign * value < sign * vL: return dsL elif sign * value > sign * vH: @@ -499,7 +507,7 @@ def get_by_key(self, key: str, value, tolerance=None) -> "Dataset": while iH - iL > 1: iM = (iH + iL) // 2 dsM = self[iM] - vM = getattr(dsM, key) + vM = getattr(dsM, attribute) if sign * value < sign * vM: iH = iM dsH = dsM @@ -507,18 +515,64 @@ def get_by_key(self, key: str, value, tolerance=None) -> "Dataset": iL = iM dsL = dsM - if abs(value - getattr(dsL, key)) < abs(value - getattr(dsH, key)): + if abs(value - getattr(dsL, attribute)) < abs(value - getattr(dsH, attribute)): ds_best = dsL else: ds_best = dsH if tolerance is not None: - if abs(value - getattr(ds_best, key)) > tolerance: + if abs(value - getattr(ds_best, attribute)) > tolerance: raise ValueError( - f"No dataset found with {key} within {tolerance} of {value}." + f"No dataset found with {attribute} within {tolerance} of {value}." ) return ds_best + def get_by_time( + self, + time: Union[unyt_quantity, tuple], + tolerance: Union[None, unyt_quantity, tuple] = None, + ): + """ + Get a dataset at or near to a given time. + + Parameters + ---------- + time : unyt_quantity or (value, unit) + The time to search for. + tolerance : unyt_quantity or (value, unit) + If not None, do not return a dataset unless the time is + within the tolerance value. If None, simply return the + nearest dataset. + Default: None. + + Examples + -------- + >>> ds = ts.get_by_time((12, "Gyr")) + >>> t = ts[0].quan(12, "Gyr") + ... ds = ts.get_by_time(t, tolerance=(100, "Myr")) + """ + return self._get_by_attribute("current_time", time, tolerance=tolerance) + + def get_by_redshift(self, redshift: float, tolerance: Optional[float] = None): + """ + Get a dataset at or near to a given time. + + Parameters + ---------- + redshift : unyt_quantity or (value, unit) + The redshift to search for. 
+ tolerance : unyt_quantity or (value, unit) + If not None, do not return a dataset unless the time is + within the tolerance value. If None, simply return the + nearest dataset. + Default: None. + + Examples + -------- + >>> ds = ts.get_redshift_time(0.0) + """ + return self._get_by_attribute("current_redshift", redshift, tolerance=tolerance) + class TimeSeriesQuantitiesContainer: def __init__(self, data_object, quantities): From 49c691038ac33a5e13ffcdb05e927a1188dd5d8f Mon Sep 17 00:00:00 2001 From: Corentin Cadiou Date: Mon, 30 Oct 2023 09:32:31 +0000 Subject: [PATCH 04/20] Fix typos in docstring --- yt/data_objects/time_series.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/yt/data_objects/time_series.py b/yt/data_objects/time_series.py index 7931100541..b6a0bbedb4 100644 --- a/yt/data_objects/time_series.py +++ b/yt/data_objects/time_series.py @@ -559,17 +559,17 @@ def get_by_redshift(self, redshift: float, tolerance: Optional[float] = None): Parameters ---------- - redshift : unyt_quantity or (value, unit) + redshift : float The redshift to search for. - tolerance : unyt_quantity or (value, unit) - If not None, do not return a dataset unless the time is + tolerance : float + If not None, do not return a dataset unless the redshift is within the tolerance value. If None, simply return the nearest dataset. Default: None. 
Examples -------- - >>> ds = ts.get_redshift_time(0.0) + >>> ds = ts.get_by_redshift(0.0) """ return self._get_by_attribute("current_redshift", redshift, tolerance=tolerance) From 4a83b764fc465ab85a9494478b971d80a7424399 Mon Sep 17 00:00:00 2001 From: Corentin Cadiou Date: Mon, 30 Oct 2023 09:53:41 +0000 Subject: [PATCH 05/20] Match by filename --- yt/data_objects/tests/test_time_series.py | 39 +++++++++++------------ 1 file changed, 19 insertions(+), 20 deletions(-) diff --git a/yt/data_objects/tests/test_time_series.py b/yt/data_objects/tests/test_time_series.py index b8e38037b7..cefb17fb9f 100644 --- a/yt/data_objects/tests/test_time_series.py +++ b/yt/data_objects/tests/test_time_series.py @@ -37,17 +37,11 @@ def test_no_match_pattern(): @pytest.fixture def FakeDataset(): - i = 0 - class __FakeDataset(Dataset): """A minimal loadable fake dataset subclass""" def __init__(self, *args, **kwargs): - nonlocal i super().__init__(*args, **kwargs) - self.current_time = i - self.current_opposite_time = 1 / (i + 1) - i += 1 @classmethod def _is_valid(cls, *args, **kwargs): @@ -60,6 +54,9 @@ def _set_code_unit_attributes(self): return def set_code_units(self): + i = int(Path(self.filename).name.split("_")[-1]) + self.current_time = i + self.current_redshift = 1 / (i + 1) return def _hash(self): @@ -76,7 +73,7 @@ def _setup_classes(self): @pytest.fixture def fake_datasets(): - file_list = [f"fake_data_file_{str(i).zfill(4)}" for i in range(10)] + file_list = [f"fake_data_file_{i:04d}" for i in range(10)] with tempfile.TemporaryDirectory() as tmpdir: pfile_list = [Path(tmpdir) / file for file in file_list] sfile_list = [str(file) for file in pfile_list] @@ -125,26 +122,28 @@ def test_init_fake_dataseries2(FakeDataset, fake_datasets): def test_get_by_key(FakeDataset, fake_datasets): - file_list, _pfile_list, _sfile_list, pattern = fake_datasets + _file_list, _pfile_list, sfile_list, pattern = fake_datasets ts = DatasetSeries(pattern) - Ntot = len(file_list) + Ntot = 
len(sfile_list) + + t = ts[0].quan(1, "code_time") - assert ts[0] == ts.get_by_time(-1) - assert ts[0] == ts.get_by_time(0) - assert ts[1] == ts.get_by_time(0.8) - assert ts[1] == ts.get_by_time(1.2) - assert ts[Ntot - 1] == ts.get_by_time(Ntot - 1) - assert ts[Ntot - 1] == ts.get_by_time(Ntot) + assert sfile_list[0] == ts.get_by_time(-t).filename + assert sfile_list[0] == ts.get_by_time(t - t).filename + assert sfile_list[1] == ts.get_by_time(0.8 * t).filename + assert sfile_list[1] == ts.get_by_time(1.2 * t).filename + assert sfile_list[Ntot - 1] == ts.get_by_time(t * (Ntot - 1)).filename + assert sfile_list[Ntot - 1] == ts.get_by_time(t * Ntot).filename with pytest.raises(ValueError): - ts.get_by_time(-2, tolerance=0.1) + ts.get_by_time(-2 * t, tolerance=0.1 * t) with pytest.raises(ValueError): - ts.get_by_time(1000, tolerance=0.1) + ts.get_by_time(1000 * t, tolerance=0.1 * t) - assert ts[1] == ts.get_by_redshift(1 / 2.2) - assert ts[1] == ts.get_by_redshift(1 / 2) - assert ts[1] == ts.get_by_redshift(1 / 1.6) + assert sfile_list[1] == ts.get_by_redshift(1 / 2.2).filename + assert sfile_list[1] == ts.get_by_redshift(1 / 2).filename + assert sfile_list[1] == ts.get_by_redshift(1 / 1.6).filename with pytest.raises(ValueError): ts.get_by_redshift(1000, tolerance=0.1) From 4883f5701de2f08289c880404f49ee515c86c6e9 Mon Sep 17 00:00:00 2001 From: Corentin Cadiou Date: Mon, 30 Oct 2023 09:54:06 +0000 Subject: [PATCH 06/20] Short-circuit iterations for out-of-bounds and exact matches --- yt/data_objects/time_series.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/yt/data_objects/time_series.py b/yt/data_objects/time_series.py index b6a0bbedb4..6394d839b8 100644 --- a/yt/data_objects/time_series.py +++ b/yt/data_objects/time_series.py @@ -498,11 +498,12 @@ def _get_by_attribute( if isinstance(value, tuple): value = dsL.quan(*value) + if isinstance(tolerance, tuple): + tolerance = dsL.quan(*tolerance) - if sign * value < sign * vL: - return 
dsL - elif sign * value > sign * vH: - return dsH + # Short-circuit if value is out-of-range + if not (vL * sign < value * sign < vH * sign): + iL = iH = 0 while iH - iL > 1: iM = (iH + iL) // 2 @@ -514,6 +515,9 @@ def _get_by_attribute( elif sign * value > sign * vM: iL = iM dsL = dsM + else: # Exact match + dsL = dsH = dsM + break if abs(value - getattr(dsL, attribute)) < abs(value - getattr(dsH, attribute)): ds_best = dsL From 87a2686401ed4144c76eee7d1449798b09c5d026 Mon Sep 17 00:00:00 2001 From: Corentin Cadiou Date: Mon, 30 Oct 2023 09:55:26 +0000 Subject: [PATCH 07/20] Make sure we can pass tuples for unitful quantities --- yt/data_objects/tests/test_time_series.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/yt/data_objects/tests/test_time_series.py b/yt/data_objects/tests/test_time_series.py index cefb17fb9f..35b342fecf 100644 --- a/yt/data_objects/tests/test_time_series.py +++ b/yt/data_objects/tests/test_time_series.py @@ -131,8 +131,8 @@ def test_get_by_key(FakeDataset, fake_datasets): assert sfile_list[0] == ts.get_by_time(-t).filename assert sfile_list[0] == ts.get_by_time(t - t).filename - assert sfile_list[1] == ts.get_by_time(0.8 * t).filename - assert sfile_list[1] == ts.get_by_time(1.2 * t).filename + assert sfile_list[1] == ts.get_by_time((0.8, "code_time")).filename + assert sfile_list[1] == ts.get_by_time((1.2, "code_time")).filename assert sfile_list[Ntot - 1] == ts.get_by_time(t * (Ntot - 1)).filename assert sfile_list[Ntot - 1] == ts.get_by_time(t * Ntot).filename From 559cab00dcf9d05354eaf1f25a77417178656017 Mon Sep 17 00:00:00 2001 From: Corentin Cadiou Date: Mon, 30 Oct 2023 10:48:05 +0000 Subject: [PATCH 08/20] Skip time_series for nose testing --- nose_ignores.txt | 1 + tests/tests.yaml | 1 + 2 files changed, 2 insertions(+) diff --git a/nose_ignores.txt b/nose_ignores.txt index b5529a6ecc..61d20fab0b 100644 --- a/nose_ignores.txt +++ b/nose_ignores.txt @@ -47,3 +47,4 @@ --ignore-file=test_disks\.py 
--ignore-file=test_offaxisprojection_pytestonly\.py --ignore-file=test_sph_pixelization_pytestonly\.py +--ignore-file=test_time_series\.py diff --git a/tests/tests.yaml b/tests/tests.yaml index 267c1a95a8..fc7a880b3c 100644 --- a/tests/tests.yaml +++ b/tests/tests.yaml @@ -223,6 +223,7 @@ other_tests: - "--ignore-file=test_gadget_pytest\\.py" - "--ignore-file=test_vr_orientation\\.py" - "--ignore-file=test_particle_trajectories_pytest\\.py" + - "--ignore-file=test_time_series\\.py" - "--exclude-test=yt.frontends.gdf.tests.test_outputs.TestGDF" - "--exclude-test=yt.frontends.adaptahop.tests.test_outputs" - "--exclude-test=yt.frontends.stream.tests.test_stream_particles.test_stream_non_cartesian_particles" From 77e705bfde6caf22850e1a0ab1e1832b0f945334 Mon Sep 17 00:00:00 2001 From: Corentin Cadiou Date: Wed, 1 Nov 2023 17:52:38 +0000 Subject: [PATCH 09/20] Provide 'side' to pick whether we want the closest, smaller or larger value --- yt/data_objects/time_series.py | 47 ++++++++++++++++++++++++++++++---- 1 file changed, 42 insertions(+), 5 deletions(-) diff --git a/yt/data_objects/time_series.py b/yt/data_objects/time_series.py index 6394d839b8..ef0cba525d 100644 --- a/yt/data_objects/time_series.py +++ b/yt/data_objects/time_series.py @@ -6,7 +6,7 @@ import weakref from abc import ABC, abstractmethod from functools import wraps -from typing import Optional, Union +from typing import Literal, Optional, Union import numpy as np from more_itertools import always_iterable @@ -446,6 +446,9 @@ def _get_by_attribute( attribute: str, value: Union[unyt_quantity, tuple[float, str]], tolerance: Union[None, unyt_quantity, tuple[float, str]] = None, + side: Union[ + Literal["nearest"], Literal["smaller"], Literal["larger"] + ] = "nearest", ) -> "Dataset": r""" Get a dataset at or near to a given value. @@ -463,8 +466,16 @@ def _get_by_attribute( within the tolerance value. If None, simply return the nearest dataset. Default: None. + side : str + The side of the value to return. 
Can be 'nearest', 'smaller' or 'larger'. + Default: 'nearest'. """ + if side not in ("nearest", "smaller", "larger"): + raise ValueError( + f"side must be 'nearest', 'smaller' or 'larger', got {side}" + ) + # Use a binary search to find the closest value iL = 0 iH = len(self._pre_outputs) - 1 @@ -519,7 +530,13 @@ def _get_by_attribute( dsL = dsH = dsM break - if abs(value - getattr(dsL, attribute)) < abs(value - getattr(dsH, attribute)): + if side == "smaller": + ds_best = dsL if sign > 0 else dsH + elif side == "larger": + ds_best = dsH if sign > 0 else dsL + elif abs(value - getattr(dsL, attribute)) < abs( + value - getattr(dsH, attribute) + ): ds_best = dsL else: ds_best = dsH @@ -535,6 +552,9 @@ def get_by_time( self, time: Union[unyt_quantity, tuple], tolerance: Union[None, unyt_quantity, tuple] = None, + side: Union[ + Literal["nearest"], Literal["smaller"], Literal["larger"] + ] = "nearest", ): """ Get a dataset at or near to a given time. @@ -548,6 +568,9 @@ def get_by_time( within the tolerance value. If None, simply return the nearest dataset. Default: None. + side : str + The side of the value to return. Can be 'nearest', 'smaller' or 'larger'. + Default: 'nearest'. Examples -------- @@ -555,9 +578,18 @@ def get_by_time( >>> t = ts[0].quan(12, "Gyr") ... ds = ts.get_by_time(t, tolerance=(100, "Myr")) """ - return self._get_by_attribute("current_time", time, tolerance=tolerance) + return self._get_by_attribute( + "current_time", time, tolerance=tolerance, side=side + ) - def get_by_redshift(self, redshift: float, tolerance: Optional[float] = None): + def get_by_redshift( + self, + redshift: float, + tolerance: Optional[float] = None, + side: Union[ + Literal["nearest"], Literal["smaller"], Literal["larger"] + ] = "nearest", + ): """ Get a dataset at or near to a given time. @@ -570,12 +602,17 @@ def get_by_redshift(self, redshift: float, tolerance: Optional[float] = None): within the tolerance value. If None, simply return the nearest dataset. 
Default: None. + side : str + The side of the value to return. Can be 'nearest', 'smaller' or 'larger'. + Default: 'nearest'. Examples -------- >>> ds = ts.get_by_redshift(0.0) """ - return self._get_by_attribute("current_redshift", redshift, tolerance=tolerance) + return self._get_by_attribute( + "current_redshift", redshift, tolerance=tolerance, side=side + ) class TimeSeriesQuantitiesContainer: From ad5a6c976247594a0e773b6b0d45047c8aca0566 Mon Sep 17 00:00:00 2001 From: Corentin Cadiou Date: Thu, 2 Nov 2023 09:20:58 +0000 Subject: [PATCH 10/20] Add documentation --- doc/source/analyzing/time_series_analysis.rst | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/doc/source/analyzing/time_series_analysis.rst b/doc/source/analyzing/time_series_analysis.rst index 9f4db0fc2a..6e3241f382 100644 --- a/doc/source/analyzing/time_series_analysis.rst +++ b/doc/source/analyzing/time_series_analysis.rst @@ -77,6 +77,29 @@ see: * The cookbook recipe for :ref:`cookbook-time-series-analysis` * :class:`~yt.data_objects.time_series.DatasetSeries` +In addition, the :class:`~yt.data_objects.time_series.DatasetSeries` object allows you to +select an output based on its time or by its redshift (if defined) as follows: + +.. code-block:: python + + import yt + + ts = yt.load("*/*.index") + # Get output at 3 Gyr + ds = ts.get_by_time((3, "Gyr")) + # This will fail if no output is found within 100 Myr + ds = ts.get_by_time((3, "Gyr"), tolerance=(100, "Myr")) + # Get the output at the time right before and after 3 Gyr + ds_before = ts.get_by_time((3, "Gyr"), side="smaller") + ds_after = ts.get_by_time((3, "Gyr"), side="larger") + + # For cosmological simulations, you can also select an output by its redshift + # with the same options as above + ds = ts.get_by_redshift(0.5) + +For more information, see :meth:`~yt.data_objects.time_series.DatasetSeries.get_by_time` and +:meth:`~yt.data_objects.time_series.DatasetSeries.get_by_redshift`. + .. 
_analyzing-an-entire-simulation: Analyzing an Entire Simulation From 652b685cced358f964b6677ceac6b797a011de55 Mon Sep 17 00:00:00 2001 From: Corentin Cadiou Date: Mon, 27 Nov 2023 16:48:28 +0100 Subject: [PATCH 11/20] Update yt/data_objects/tests/test_time_series.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Clément Robert --- yt/data_objects/tests/test_time_series.py | 6 +++--- yt/data_objects/time_series.py | 26 +++++++++-------------- 2 files changed, 13 insertions(+), 19 deletions(-) diff --git a/yt/data_objects/tests/test_time_series.py b/yt/data_objects/tests/test_time_series.py index 35b342fecf..56b9c80539 100644 --- a/yt/data_objects/tests/test_time_series.py +++ b/yt/data_objects/tests/test_time_series.py @@ -37,7 +37,7 @@ def test_no_match_pattern(): @pytest.fixture def FakeDataset(): - class __FakeDataset(Dataset): + class _FakeDataset(Dataset): """A minimal loadable fake dataset subclass""" def __init__(self, *args, **kwargs): @@ -66,9 +66,9 @@ def _setup_classes(self): return try: - yield __FakeDataset + yield _FakeDataset finally: - output_type_registry.pop("__FakeDataset") + output_type_registry.pop("_FakeDataset") @pytest.fixture diff --git a/yt/data_objects/time_series.py b/yt/data_objects/time_series.py index ef0cba525d..5d23ef52f7 100644 --- a/yt/data_objects/time_series.py +++ b/yt/data_objects/time_series.py @@ -10,7 +10,7 @@ import numpy as np from more_itertools import always_iterable -from unyt import unyt_quantity +from unyt import Unit, unyt_quantity from yt.config import ytcfg from yt.data_objects.analyzer_objects import AnalysisTask, create_quantity_proxy @@ -444,11 +444,9 @@ def particle_trajectories( def _get_by_attribute( self, attribute: str, - value: Union[unyt_quantity, tuple[float, str]], - tolerance: Union[None, unyt_quantity, tuple[float, str]] = None, - side: Union[ - Literal["nearest"], Literal["smaller"], Literal["larger"] - ] = "nearest", + value: 
Union[unyt_quantity, tuple[float, Union[Unit, str]]], + tolerance: Union[None, unyt_quantity, tuple[float, Union[Unit, str]]] = None, + side: Literal["nearest", "smaller", "larger"] = "nearest", ) -> "Dataset": r""" Get a dataset at or near to a given value. @@ -550,12 +548,10 @@ def _get_by_attribute( def get_by_time( self, - time: Union[unyt_quantity, tuple], - tolerance: Union[None, unyt_quantity, tuple] = None, - side: Union[ - Literal["nearest"], Literal["smaller"], Literal["larger"] - ] = "nearest", - ): + time: Union[unyt_quantity, tuple[float, Union[Unit, str]]], + tolerance: Union[None, unyt_quantity, tuple[float, Union[Unit, str]]] = None, + side: Literal["nearest", "smaller", "larger"] = "nearest", + ) -> Dataset: """ Get a dataset at or near to a given time. @@ -586,10 +582,8 @@ def get_by_redshift( self, redshift: float, tolerance: Optional[float] = None, - side: Union[ - Literal["nearest"], Literal["smaller"], Literal["larger"] - ] = "nearest", - ): + side: Literal["nearest", "smaller", "larger"] = "nearest", + ) -> Dataset: """ Get a dataset at or near to a given time. 
From e5e609a76326791484cebabd4162315d70983a20 Mon Sep 17 00:00:00 2001 From: Corentin Cadiou Date: Mon, 27 Nov 2023 16:43:34 +0100 Subject: [PATCH 12/20] Replace numpy.assert_raises with pytest's version --- yt/data_objects/tests/test_time_series.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/yt/data_objects/tests/test_time_series.py b/yt/data_objects/tests/test_time_series.py index 56b9c80539..babd349df7 100644 --- a/yt/data_objects/tests/test_time_series.py +++ b/yt/data_objects/tests/test_time_series.py @@ -2,7 +2,6 @@ from pathlib import Path import pytest -from numpy.testing import assert_raises from yt.data_objects.static_output import Dataset from yt.data_objects.time_series import DatasetSeries @@ -30,9 +29,8 @@ def test_pattern_expansion(): def test_no_match_pattern(): with tempfile.TemporaryDirectory() as tmpdir: pattern = Path(tmpdir).joinpath("fake_data_file_*") - assert_raises( - FileNotFoundError, DatasetSeries._get_filenames_from_glob_pattern, pattern - ) + with pytest.raises(FileNotFoundError): + DatasetSeries._get_filenames_from_glob_pattern(pattern) @pytest.fixture @@ -105,10 +103,12 @@ def test_init_fake_dataseries(fake_datasets): assert ts._pre_outputs == pfile_list # rejected input type (str repr of a list) "[file1, file2, ...]" - assert_raises(FileNotFoundError, DatasetSeries, str(file_list)) + with pytest.raises(FileNotFoundError): + DatasetSeries(str(file_list)) # finally, check that ts[0] fails to actually load - assert_raises(YTUnidentifiedDataType, ts.__getitem__, 0) + with pytest.raises(YTUnidentifiedDataType): + ts[0] def test_init_fake_dataseries2(FakeDataset, fake_datasets): @@ -118,7 +118,8 @@ def test_init_fake_dataseries2(FakeDataset, fake_datasets): ts = DatasetSeries(pattern, my_unsupported_kwarg=None) - assert_raises(TypeError, ts.__getitem__, 0) + with pytest.raises(TypeError): + ts[0] def test_get_by_key(FakeDataset, fake_datasets): From bc101e32e0a6e236a5a78cb74f646f3b04e8bdac Mon Sep 17 
00:00:00 2001 From: Corentin Cadiou Date: Mon, 27 Nov 2023 16:45:05 +0100 Subject: [PATCH 13/20] Import TYPE_CHECKING from typing --- yt/data_objects/time_series.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/yt/data_objects/time_series.py b/yt/data_objects/time_series.py index 5d23ef52f7..6c31015b8b 100644 --- a/yt/data_objects/time_series.py +++ b/yt/data_objects/time_series.py @@ -2,11 +2,10 @@ import glob import inspect import os -import typing import weakref from abc import ABC, abstractmethod from functools import wraps -from typing import Literal, Optional, Union +from typing import TYPE_CHECKING, Literal, Optional, Union import numpy as np from more_itertools import always_iterable @@ -30,7 +29,7 @@ parallel_root_only, ) -if typing.TYPE_CHECKING: +if TYPE_CHECKING: from yt.data_objects.static_output import Dataset From cd8970e582f1088f4f3145d8ddfeccb1fcbcb9aa Mon Sep 17 00:00:00 2001 From: Corentin Cadiou Date: Mon, 27 Nov 2023 16:45:28 +0100 Subject: [PATCH 14/20] Remove outdated comment --- yt/data_objects/time_series.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/yt/data_objects/time_series.py b/yt/data_objects/time_series.py index 6c31015b8b..2a79a43281 100644 --- a/yt/data_objects/time_series.py +++ b/yt/data_objects/time_series.py @@ -147,10 +147,6 @@ class DatasetSeries: ... 
SlicePlot(ds, "x", ("gas", "density")).save() """ - - # this annotation should really be Optional[Type[Dataset]] - # but we cannot import the yt.data_objects.static_output.Dataset - # class here without creating a circular import for now _dataset_cls: Optional[type["Dataset"]] = None def __init_subclass__(cls, *args, **kwargs): From 7fba4c8710a03a604ff5c45fdb525ea774c8802f Mon Sep 17 00:00:00 2001 From: Corentin Cadiou Date: Mon, 27 Nov 2023 16:58:53 +0100 Subject: [PATCH 15/20] Rename iH(igh) to the more canonical iR(ight) --- yt/data_objects/time_series.py | 36 +++++++++++++++++----------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/yt/data_objects/time_series.py b/yt/data_objects/time_series.py index 2a79a43281..6b98b52b0b 100644 --- a/yt/data_objects/time_series.py +++ b/yt/data_objects/time_series.py @@ -471,9 +471,9 @@ def _get_by_attribute( # Use a binary search to find the closest value iL = 0 - iH = len(self._pre_outputs) - 1 + iR = len(self._pre_outputs) - 1 - if iL == iH: + if iL == iR: ds = self[0] if ( tolerance is not None @@ -486,17 +486,17 @@ def _get_by_attribute( # Check signedness dsL = self[iL] - dsH = self[iH] + dsR = self[iR] vL = getattr(dsL, attribute) - vH = getattr(dsH, attribute) + vR = getattr(dsR, attribute) - if vL < vH: + if vL < vR: sign = 1 - elif vL > vH: + elif vL > vR: sign = -1 else: raise ValueError( - f"{dsL} and {dsH} have both {attribute}={vL}, cannot perform search. " + f"{dsL} and {dsR} have both {attribute}={vL}, cannot perform search." "Try with another key." 
) @@ -506,33 +506,33 @@ def _get_by_attribute( tolerance = dsL.quan(*tolerance) # Short-circuit if value is out-of-range - if not (vL * sign < value * sign < vH * sign): - iL = iH = 0 + if not (vL * sign < value * sign < vR * sign): + iL = iR = 0 - while iH - iL > 1: - iM = (iH + iL) // 2 + while iR - iL > 1: + iM = (iR + iL) // 2 dsM = self[iM] vM = getattr(dsM, attribute) if sign * value < sign * vM: - iH = iM - dsH = dsM + iR = iM + dsR = dsM elif sign * value > sign * vM: iL = iM dsL = dsM else: # Exact match - dsL = dsH = dsM + dsL = dsR = dsM break if side == "smaller": - ds_best = dsL if sign > 0 else dsH + ds_best = dsL if sign > 0 else dsR elif side == "larger": - ds_best = dsH if sign > 0 else dsL + ds_best = dsR if sign > 0 else dsL elif abs(value - getattr(dsL, attribute)) < abs( - value - getattr(dsH, attribute) + value - getattr(dsR, attribute) ): ds_best = dsL else: - ds_best = dsH + ds_best = dsR if tolerance is not None: if abs(value - getattr(ds_best, attribute)) > tolerance: From 8f05febf4ca20d4a950b9e1a60a0d1c62241040f Mon Sep 17 00:00:00 2001 From: Corentin Cadiou Date: Mon, 27 Nov 2023 17:03:34 +0100 Subject: [PATCH 16/20] Rename side to more explicit 'prefer' kwa --- doc/source/analyzing/time_series_analysis.rst | 4 +-- yt/data_objects/time_series.py | 30 +++++++++---------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/doc/source/analyzing/time_series_analysis.rst b/doc/source/analyzing/time_series_analysis.rst index 6e3241f382..f0ac68d178 100644 --- a/doc/source/analyzing/time_series_analysis.rst +++ b/doc/source/analyzing/time_series_analysis.rst @@ -90,8 +90,8 @@ select an output based on its time or by its redshift (if defined) as follows: # This will fail if no output is found within 100 Myr ds = ts.get_by_time((3, "Gyr"), tolerance=(100, "Myr")) # Get the output at the time right before and after 3 Gyr - ds_before = ts.get_by_time((3, "Gyr"), side="smaller") - ds_after = ts.get_by_time((3, "Gyr"), side="larger") + 
ds_before = ts.get_by_time((3, "Gyr"), prefer="smaller") + ds_after = ts.get_by_time((3, "Gyr"), prefer="larger") # For cosmological simulations, you can also select an output by its redshift # with the same options as above diff --git a/yt/data_objects/time_series.py b/yt/data_objects/time_series.py index 6b98b52b0b..681aa3c243 100644 --- a/yt/data_objects/time_series.py +++ b/yt/data_objects/time_series.py @@ -441,32 +441,32 @@ def _get_by_attribute( attribute: str, value: Union[unyt_quantity, tuple[float, Union[Unit, str]]], tolerance: Union[None, unyt_quantity, tuple[float, Union[Unit, str]]] = None, - side: Literal["nearest", "smaller", "larger"] = "nearest", + prefer: Literal["nearest", "smaller", "larger"] = "nearest", ) -> "Dataset": r""" Get a dataset at or near to a given value. Parameters ---------- - key : str + attribute : str The key by which to retrieve an output, usually 'current_time' or 'current_redshift'. The key must be an attribute of the dataset and monotonic. - value : unyt_array or (value, unit) + value : unyt_quantity or (value, unit) The value to search for. - tolerance : unyt_array or (value, unit) + tolerance : unyt_quantity or (value, unit), optional If not None, do not return a dataset unless the value is within the tolerance value. If None, simply return the nearest dataset. Default: None. - side : str + prefer : str The side of the value to return. Can be 'nearest', 'smaller' or 'larger'. Default: 'nearest'. 
""" - if side not in ("nearest", "smaller", "larger"): + if prefer not in ("nearest", "smaller", "larger"): raise ValueError( - f"side must be 'nearest', 'smaller' or 'larger', got {side}" + f"side must be 'nearest', 'smaller' or 'larger', got {prefer}" ) # Use a binary search to find the closest value @@ -523,9 +523,9 @@ def _get_by_attribute( dsL = dsR = dsM break - if side == "smaller": + if prefer == "smaller": ds_best = dsL if sign > 0 else dsR - elif side == "larger": + elif prefer == "larger": ds_best = dsR if sign > 0 else dsL elif abs(value - getattr(dsL, attribute)) < abs( value - getattr(dsR, attribute) @@ -545,7 +545,7 @@ def get_by_time( self, time: Union[unyt_quantity, tuple[float, Union[Unit, str]]], tolerance: Union[None, unyt_quantity, tuple[float, Union[Unit, str]]] = None, - side: Literal["nearest", "smaller", "larger"] = "nearest", + prefer: Literal["nearest", "smaller", "larger"] = "nearest", ) -> Dataset: """ Get a dataset at or near to a given time. @@ -559,7 +559,7 @@ def get_by_time( within the tolerance value. If None, simply return the nearest dataset. Default: None. - side : str + prefer : str The side of the value to return. Can be 'nearest', 'smaller' or 'larger'. Default: 'nearest'. @@ -570,14 +570,14 @@ def get_by_time( ... ds = ts.get_by_time(t, tolerance=(100, "Myr")) """ return self._get_by_attribute( - "current_time", time, tolerance=tolerance, side=side + "current_time", time, tolerance=tolerance, prefer=prefer ) def get_by_redshift( self, redshift: float, tolerance: Optional[float] = None, - side: Literal["nearest", "smaller", "larger"] = "nearest", + prefer: Literal["nearest", "smaller", "larger"] = "nearest", ) -> Dataset: """ Get a dataset at or near to a given time. @@ -591,7 +591,7 @@ def get_by_redshift( within the tolerance value. If None, simply return the nearest dataset. Default: None. - side : str + prefer : str The side of the value to return. Can be 'nearest', 'smaller' or 'larger'. Default: 'nearest'. 
@@ -600,7 +600,7 @@ def get_by_redshift( >>> ds = ts.get_by_redshift(0.0) """ return self._get_by_attribute( - "current_redshift", redshift, tolerance=tolerance, side=side + "current_redshift", redshift, tolerance=tolerance, prefer=prefer ) From 199cfd7acd42d50e9e3808bc83984d74fef7c14a Mon Sep 17 00:00:00 2001 From: Corentin Cadiou Date: Mon, 27 Nov 2023 17:32:24 +0100 Subject: [PATCH 17/20] Fix typing --- yt/data_objects/time_series.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/yt/data_objects/time_series.py b/yt/data_objects/time_series.py index 681aa3c243..46f1b7fc85 100644 --- a/yt/data_objects/time_series.py +++ b/yt/data_objects/time_series.py @@ -546,7 +546,7 @@ def get_by_time( time: Union[unyt_quantity, tuple[float, Union[Unit, str]]], tolerance: Union[None, unyt_quantity, tuple[float, Union[Unit, str]]] = None, prefer: Literal["nearest", "smaller", "larger"] = "nearest", - ) -> Dataset: + ) -> "Dataset": """ Get a dataset at or near to a given time. @@ -578,7 +578,7 @@ def get_by_redshift( redshift: float, tolerance: Optional[float] = None, prefer: Literal["nearest", "smaller", "larger"] = "nearest", - ) -> Dataset: + ) -> "Dataset": """ Get a dataset at or near to a given time. From 385aeb2230c2a0e8185dd7914f632b15de108d52 Mon Sep 17 00:00:00 2001 From: Corentin Cadiou Date: Mon, 27 Nov 2023 17:38:37 +0100 Subject: [PATCH 18/20] Do not suggest very vague fix --- yt/data_objects/time_series.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/yt/data_objects/time_series.py b/yt/data_objects/time_series.py index 46f1b7fc85..0c24edcc66 100644 --- a/yt/data_objects/time_series.py +++ b/yt/data_objects/time_series.py @@ -466,7 +466,7 @@ def _get_by_attribute( if prefer not in ("nearest", "smaller", "larger"): raise ValueError( - f"side must be 'nearest', 'smaller' or 'larger', got {prefer}" + f"Side must be 'nearest', 'smaller' or 'larger', got {prefer}." 
) # Use a binary search to find the closest value @@ -497,7 +497,6 @@ def _get_by_attribute( else: raise ValueError( f"{dsL} and {dsR} have both {attribute}={vL}, cannot perform search." - "Try with another key." ) if isinstance(value, tuple): From 70a57148f60324be9a916fdd5b24b7c625bbbb1a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 13 Apr 2024 06:12:01 +0000 Subject: [PATCH 19/20] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- yt/data_objects/time_series.py | 1 + 1 file changed, 1 insertion(+) diff --git a/yt/data_objects/time_series.py b/yt/data_objects/time_series.py index 0c24edcc66..6137976a70 100644 --- a/yt/data_objects/time_series.py +++ b/yt/data_objects/time_series.py @@ -147,6 +147,7 @@ class DatasetSeries: ... SlicePlot(ds, "x", ("gas", "density")).save() """ + _dataset_cls: Optional[type["Dataset"]] = None def __init_subclass__(cls, *args, **kwargs): From fe4642d5ccee680761f10dc6f52fbc18e1645140 Mon Sep 17 00:00:00 2001 From: Corentin Cadiou Date: Thu, 3 Oct 2024 10:14:31 +0200 Subject: [PATCH 20/20] Add extra test to check halfpoint Co-authored-by: Chris Havlin --- yt/data_objects/tests/test_time_series.py | 5 +++++ yt/data_objects/time_series.py | 1 + 2 files changed, 6 insertions(+) diff --git a/yt/data_objects/tests/test_time_series.py b/yt/data_objects/tests/test_time_series.py index babd349df7..9c28d24ca8 100644 --- a/yt/data_objects/tests/test_time_series.py +++ b/yt/data_objects/tests/test_time_series.py @@ -148,3 +148,8 @@ def test_get_by_key(FakeDataset, fake_datasets): with pytest.raises(ValueError): ts.get_by_redshift(1000, tolerance=0.1) + + zmid = (ts[0].current_redshift + ts[1].current_redshift) / 2 + + assert sfile_list[1] == ts.get_by_redshift(zmid, prefer="smaller").filename + assert sfile_list[0] == ts.get_by_redshift(zmid, prefer="larger").filename diff --git a/yt/data_objects/time_series.py 
b/yt/data_objects/time_series.py index 6137976a70..6edc5aa5fb 100644 --- a/yt/data_objects/time_series.py +++ b/yt/data_objects/time_series.py @@ -513,6 +513,7 @@ def _get_by_attribute( iM = (iR + iL) // 2 dsM = self[iM] vM = getattr(dsM, attribute) + if sign * value < sign * vM: iR = iM dsR = dsM