diff --git a/.github/workflows/pull_request_tests.yml b/.github/workflows/pull_request_tests.yml index cbb8f29d..56517637 100644 --- a/.github/workflows/pull_request_tests.yml +++ b/.github/workflows/pull_request_tests.yml @@ -28,6 +28,7 @@ jobs: pip install --upgrade pip pip install pytest pip install pytest-cov + pip install pytest-timeout pip install -e . - name: Run pytest and Generate coverage report run: | diff --git a/rex/multi_file_resource.py b/rex/multi_file_resource.py index 239144c9..613fa1a0 100644 --- a/rex/multi_file_resource.py +++ b/rex/multi_file_resource.py @@ -10,12 +10,12 @@ from rex.renewable_resource import (NSRDB, SolarResource, GeothermalResource, WindResource, WaveResource, AbstractInterpolatedResource) -from rex.resource import Resource +from rex.resource import Resource, BaseDatasetIterable from rex.utilities.exceptions import FileInputError, ResourceRuntimeError from rex.utilities.utilities import unstupify_path -class MultiH5: +class MultiH5(BaseDatasetIterable): """ Class to handle multiple h5 file Resources """ @@ -32,8 +32,6 @@ def __init__(self, h5_files, check_files=False): self._dset_map = self._map_file_dsets(h5_files) self._h5_map = self._map_file_instances(set(self._dset_map.values())) - self._i = 0 - if check_files: self._preflight_check() @@ -66,19 +64,6 @@ def __getitem__(self, dset): return ds - def __next__(self): - if self._i >= len(self.datasets): - self._i = 0 - raise StopIteration - - dset = self.datasets[self._i] - self._i += 1 - - return dset - - def __iter__(self): - return self - def __contains__(self, dset): return dset in self.datasets @@ -405,7 +390,6 @@ def __init__(self, h5_source, unscale=True, str_decode=True, self._shapes = None self._chunks = None self._dtypes = None - self._i = 0 self._interp_var = None self._use_lapse = use_lapse_rate diff --git a/rex/multi_res_resource.py b/rex/multi_res_resource.py index 75272b4d..1e55fd4e 100644 --- a/rex/multi_res_resource.py +++ b/rex/multi_res_resource.py @@ -70,7 +70,6 @@ def __init__(self, h5_hr, h5_lr, handler_class=Resource, self._lr_res = handler_class(h5_lr, **handle_kwargs) self._nn_map = nn_map self._nn_d = nn_d - self._i = 0 if self._nn_map is None: self._nn_d, self._nn_map = self.make_nn_map(self._hr_res, @@ -237,17 +236,7 @@ def __getitem__(self, keys): return out def __iter__(self): - return self - - def __next__(self): - if self._i >= len(self.datasets): - self._i = 0 - raise StopIteration - - dset = self.datasets[self._i] - self._i += 1 - - return dset + return iter(self.datasets) def __contains__(self, dset): return dset in self.datasets diff --git a/rex/multi_time_resource.py b/rex/multi_time_resource.py index 8b679070..2fcaae73 100644 --- a/rex/multi_time_resource.py +++ b/rex/multi_time_resource.py @@ -15,7 +15,7 @@ WaveResource, WindResource, ) -from rex.resource import Resource +from rex.resource import Resource, BaseDatasetIterable from rex.utilities.exceptions import FileInputError from rex.utilities.parse_keys import parse_keys, parse_slice @@ -58,7 +58,6 @@ def __init__(self, h5_path, res_cls=Resource, hsds=False, hsds_kwargs=None, self._shape = None self._time_index = None self._time_slice_map = [] - self._i = 0 def __repr__(self): msg = ("{} for {}:\n Contains data from {} files" @@ -419,7 +418,7 @@ def close(self): f.close() -class MultiTimeResource: +class MultiTimeResource(BaseDatasetIterable): """ Class to handle resource data stored temporally accross multiple .h5 files @@ -519,7 +518,6 @@ def __init__(self, h5_path, unscale=True, str_decode=True, self._h5 = MultiTimeH5(self.h5_path, res_cls=res_cls, **cls_kwargs) self.h5_files = self._h5.h5_files self.h5_file = self.h5_files[0] - self._i = 0 def __repr__(self): msg = "{} for {}".format(self.__class__.__name__, self.h5_path) @@ -537,19 +535,6 @@ def __exit__(self, type, value, traceback): def __len__(self): return len(self.h5.time_index) - def __iter__(self): - return self - - def __next__(self): - if self._i >= len(self.datasets): - self._i = 0 - raise StopIteration - - dset = self.datasets[self._i] - self._i += 1 - - return dset - def __getitem__(self, keys): ds, ds_slice = parse_keys(keys) diff --git a/rex/multi_year_resource.py b/rex/multi_year_resource.py index 21ba59d1..b6b04f1f 100644 --- a/rex/multi_year_resource.py +++ b/rex/multi_year_resource.py @@ -58,7 +58,6 @@ def __init__(self, h5_path, years=None, res_cls=Resource, hsds=False, self._datasets = None self._shape = None self._time_index = None - self._i = 0 def __repr__(self): msg = ("{} for {}:\n Contains data for {} years" @@ -82,17 +81,7 @@ def __getitem__(self, year): return h5 def __iter__(self): - return self - - def __next__(self): - if self._i >= len(self.years): - self._i = 0 - raise StopIteration - - year = self.years[self._i] - self._i += 1 - - return year + return iter(self.years) def __contains__(self, year): return year in self.years @@ -451,7 +440,6 @@ def __init__(self, h5_path, years=None, unscale=True, str_decode=True, **cls_kwargs) self.h5_files = self._h5.h5_files self.h5_file = self.h5_files[0] - self._i = 0 @property def years(self): diff --git a/rex/resource.py b/rex/resource.py index c078ba0d..64b7ccf5 100644 --- a/rex/resource.py +++ b/rex/resource.py @@ -3,7 +3,7 @@ Classes to handle resource data """ import os -from abc import ABC +from abc import ABC, abstractmethod from warnings import warn import dateutil @@ -17,6 +17,18 @@ from rex.utilities.utilities import check_tz, get_lat_lon_cols +class BaseDatasetIterable(ABC): + """Base class for file that is iterable over datasets. """ + + @property + @abstractmethod + def datasets(self): + """iterable: Datasets available in file. """ + + def __iter__(self): + return iter(self.datasets) + + class ResourceDataset: """ h5py.Dataset wrapper for Resource .h5 files @@ -583,7 +595,7 @@ def extract(cls, ds, ds_slice, scale_attr='scale_factor', return dset[ds_slice] -class BaseResource(ABC): +class BaseResource(BaseDatasetIterable): """ Abstract Base class to handle resource .h5 files """ @@ -646,7 +658,6 @@ def __init__(self, h5_file, mode='r', unscale=True, str_decode=True, self._shapes = None self._chunks = None self._dtypes = None - self._i = 0 def __repr__(self): msg = "{} for {}".format(self.__class__.__name__, self.h5_file) @@ -691,19 +702,6 @@ def __getitem__(self, keys): return out - def __iter__(self): - return self - - def __next__(self): - if self._i >= len(self.datasets): - self._i = 0 - raise StopIteration - - dset = self.datasets[self._i] - self._i += 1 - - return dset - def __contains__(self, dset): return dset in self.datasets diff --git a/rex/resource_extraction/resource_extraction.py b/rex/resource_extraction/resource_extraction.py index 223bdeea..d28939e9 100644 --- a/rex/resource_extraction/resource_extraction.py +++ b/rex/resource_extraction/resource_extraction.py @@ -27,7 +27,7 @@ WaveResource, WindResource, ) -from rex.resource import Resource, ResourceDataset +from rex.resource import Resource, ResourceDataset, BaseDatasetIterable from rex.temporal_stats.temporal_stats import TemporalStats from rex.utilities.exceptions import ResourceValueError, ResourceWarning from rex.utilities.execution import SpawnProcessPool @@ -39,7 +39,7 @@ logger = logging.getLogger(__name__) -class ResourceX: +class ResourceX(BaseDatasetIterable): """ Resource data extraction tool """ @@ -88,7 +88,6 @@ def __init__(self, res_h5, res_cls=None, tree=None, unscale=True, group=group, hsds=hsds, hsds_kwargs=hsds_kwargs) self._dist_thresh = None self._tree = tree - self._i = 0 def __repr__(self): msg = "{} extractor for {}".format(self._res.__class__.__name__, @@ -114,19 +113,6 @@ def __getitem__(self, keys): def __contains__(self, dset): return dset in self.datasets - def __iter__(self): - return self - - def __next__(self): - if self._i >= len(self.datasets): - self._i = 0 - raise StopIteration - - dset = self.datasets[self._i] - self._i += 1 - - return dset - @property def resource(self): """ @@ -1543,7 +1529,6 @@ def __init__(self, resource_path, res_cls=None, tree=None, str_decode=str_decode, check_files=check_files) self._dist_thresh = None self._tree = tree - self._i = 0 class MultiYearResourceX(ResourceX): @@ -1590,7 +1575,6 @@ def __init__(self, resource_path, years=None, tree=None, unscale=True, hsds_kwargs=hsds_kwargs) self._dist_thresh = None self._tree = tree - self._i = 0 def get_means_map(self, ds_name, year=None, region=None, region_col='state', max_workers=None, @@ -1676,7 +1660,6 @@ def __init__(self, resource_path, tree=None, unscale=True, hsds=hsds, hsds_kwargs=hsds_kwargs) self._dist_thresh = None self._tree = tree - self._i = 0 class SolarX(ResourceX): diff --git a/setup.py b/setup.py index 93f15176..4dc10e80 100644 --- a/setup.py +++ b/setup.py @@ -41,7 +41,7 @@ def run(self): with open("requirements.txt") as f: install_requires = f.readlines() -test_requires = ["pytest>=5.2", ] +test_requires = ["pytest>=5.2", "pytest-timeout>=2.3.1"] dev_requires = ["flake8", "pre-commit", "pylint", "hsds>=0.8.4"] description = ("National Renewable Energy Laboratory's (NREL's) REsource " "eXtraction tool: rex") diff --git a/tests/test_multi_res_resource.py b/tests/test_multi_res_resource.py index 0ca3438a..018d8285 100644 --- a/tests/test_multi_res_resource.py +++ b/tests/test_multi_res_resource.py @@ -170,3 +170,20 @@ def test_preload_sam(): assert np.allclose(true, test) mrr.close() + + +@pytest.mark.timeout(10) +def test_multi_res_resource_iterator(): + """ + test MultiResolutionResource iterator. Incorrect implementation can + cause an infinite loop + """ + with tempfile.TemporaryDirectory() as td: + fp_hr, fp_lr = make_multi_res_files(td) + mrr = MultiResolutionResource(fp_hr, fp_lr, handler_class=WindResource) + dsets_permutation = {(a, b) for a in mrr for b in mrr} + num_dsets = len(mrr.datasets) + + mrr.close() + + assert len(dsets_permutation) == num_dsets ** 2 diff --git a/tests/test_multi_time_resource.py b/tests/test_multi_time_resource.py index acca2233..12944bf1 100644 --- a/tests/test_multi_time_resource.py +++ b/tests/test_multi_time_resource.py @@ -9,8 +9,8 @@ import pytest from rex import TESTDATADIR -from rex.multi_time_resource import (MultiTimeH5, MultiTimeNSRDB, - MultiTimeWindResource) +from rex.multi_time_resource import (MultiTimeH5, MultiTimeResource, + MultiTimeNSRDB, MultiTimeWindResource) from rex.resource import Resource @@ -323,6 +323,21 @@ def test_map_hsds_files(): assert not any(wrong), 'Wrong files: {}'.format(wrong) +@pytest.mark.timeout(10) +def test_mt_iterator(): + """ + test MultiTimeResource iterator. Incorrect implementation can + cause an infinite loop + """ + path = os.path.join(TESTDATADIR, 'wtk/ri_100_wtk_*.h5') + + with MultiTimeResource(path) as res: + dsets_permutation = {(a, b) for a in res for b in res} + num_dsets = len(res.datasets) + + assert len(dsets_permutation) == num_dsets ** 2 + + def execute_pytest(capture='all', flags='-rapP'): """Execute module as pytest with detailed summary report. diff --git a/tests/test_multi_year_resource.py b/tests/test_multi_year_resource.py index 60df156a..0376a846 100644 --- a/tests/test_multi_year_resource.py +++ b/tests/test_multi_year_resource.py @@ -358,6 +358,36 @@ def test_multi_file_year(): assert test_meta.equals(f.meta) +@pytest.mark.timeout(10) +def test_my_resource_iterator(): + """ + test MultiYearWindResource iterator. Incorrect implementation can + cause an infinite loop + """ + path = os.path.join(TESTDATADIR, 'wtk/ri_100_wtk_*.h5') + + with MultiYearWindResource(path) as res: + dsets_permutation = {(a, b) for a in res for b in res} + num_dsets = len(res.datasets) + + assert len(dsets_permutation) == num_dsets ** 2 + + +@pytest.mark.timeout(10) +def test_myh5_iterator(): + """ + test MultiYearH5 iterator. Incorrect implementation can + cause an infinite loop + """ + path = os.path.join(TESTDATADIR, 'nsrdb/ri_100_nsrdb_*.h5') + + myh5 = MultiYearH5(path) + dsets_permutation = {(a, b) for a in myh5 for b in myh5} + num_dsets = len(myh5.years) + + assert len(dsets_permutation) == num_dsets ** 2 + + def execute_pytest(capture='all', flags='-rapP'): """Execute module as pytest with detailed summary report. diff --git a/tests/test_resource.py b/tests/test_resource.py index 3c3c86c9..37fe759a 100644 --- a/tests/test_resource.py +++ b/tests/test_resource.py @@ -949,6 +949,32 @@ def test_1D_dataset_slicing_spatial_repeat(): assert res['dset3', 55, 79].dtype == np.float32 +@pytest.mark.timeout(10) +def test_resource_iterator(): + """ + test Resource iterator. Incorrect implementation can cause an infinite loop + """ + h5_file = os.path.join(TESTDATADIR, 'nsrdb', 'nsrdb_irradiance_2018.h5') + with Resource(h5_file) as res: + dsets_permutation = {(a, b) for a in res for b in res} + num_dsets = len(res.datasets) + + assert len(dsets_permutation) == num_dsets ** 2 + + +@pytest.mark.timeout(10) +def test_mh5_iterator(): + """ + test MultiH5 iterator. Incorrect implementation can cause an infinite loop + """ + h5_files = [os.path.join(TESTDATADIR, 'nsrdb', 'nsrdb_irradiance_2018.h5'), + os.path.join(TESTDATADIR, 'wtk', 'wtk_2010_100m.h5')] + + mh5 = MultiH5(h5_files) + dsets_permutation = {(a, b) for a in mh5 for b in mh5} + assert len(dsets_permutation) == len(mh5.datasets) ** 2 + + def execute_pytest(capture='all', flags='-rapP'): """Execute module as pytest with detailed summary report. diff --git a/tests/test_resource_extraction.py b/tests/test_resource_extraction.py index 20393a1f..94a17111 100644 --- a/tests/test_resource_extraction.py +++ b/tests/test_resource_extraction.py @@ -1130,6 +1130,19 @@ def test_get_bad_raster_index(): ext.get_raster_index(target, shape, meta=meta) +@pytest.mark.timeout(10) +def test_resourcex_iterable(NSRDBX_cls): + """ + test ResourceX iterator. Incorrect implementation can cause an + infinite loop + """ + with NSRDBX_cls as res: + dsets_permutation = {(a, b) for a in res for b in res} + num_dsets = len(res.datasets) + + assert len(dsets_permutation) == num_dsets ** 2 + + def execute_pytest(capture='all', flags='-rapP'): """Execute module as pytest with detailed summary report.