Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Instantiate iterators #182

Merged
merged 11 commits into from
Aug 14, 2024
1 change: 1 addition & 0 deletions .github/workflows/pull_request_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ jobs:
pip install --upgrade pip
pip install pytest
pip install pytest-cov
pip install pytest-timeout
pip install -e .
- name: Run pytest and Generate coverage report
run: |
Expand Down
20 changes: 2 additions & 18 deletions rex/multi_file_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
from rex.renewable_resource import (NSRDB, SolarResource, GeothermalResource,
WindResource, WaveResource,
AbstractInterpolatedResource)
from rex.resource import Resource
from rex.resource import Resource, BaseDatasetIterable
from rex.utilities.exceptions import FileInputError, ResourceRuntimeError
from rex.utilities.utilities import unstupify_path


class MultiH5:
class MultiH5(BaseDatasetIterable):
"""
Class to handle multiple h5 file Resources
"""
Expand All @@ -32,8 +32,6 @@ def __init__(self, h5_files, check_files=False):
self._dset_map = self._map_file_dsets(h5_files)
self._h5_map = self._map_file_instances(set(self._dset_map.values()))

self._i = 0

if check_files:
self._preflight_check()

Expand Down Expand Up @@ -66,19 +64,6 @@ def __getitem__(self, dset):

return ds

def __next__(self):
if self._i >= len(self.datasets):
self._i = 0
raise StopIteration

dset = self.datasets[self._i]
self._i += 1

return dset

def __iter__(self):
return self

def __contains__(self, dset):
return dset in self.datasets

Expand Down Expand Up @@ -405,7 +390,6 @@ def __init__(self, h5_source, unscale=True, str_decode=True,
self._shapes = None
self._chunks = None
self._dtypes = None
self._i = 0

self._interp_var = None
self._use_lapse = use_lapse_rate
Expand Down
13 changes: 1 addition & 12 deletions rex/multi_res_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def __init__(self, h5_hr, h5_lr, handler_class=Resource,
self._lr_res = handler_class(h5_lr, **handle_kwargs)
self._nn_map = nn_map
self._nn_d = nn_d
self._i = 0

if self._nn_map is None:
self._nn_d, self._nn_map = self.make_nn_map(self._hr_res,
Expand Down Expand Up @@ -237,17 +236,7 @@ def __getitem__(self, keys):
return out

def __iter__(self):
return self

def __next__(self):
if self._i >= len(self.datasets):
self._i = 0
raise StopIteration

dset = self.datasets[self._i]
self._i += 1

return dset
return iter(self.datasets)

def __contains__(self, dset):
return dset in self.datasets
Expand Down
19 changes: 2 additions & 17 deletions rex/multi_time_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
WaveResource,
WindResource,
)
from rex.resource import Resource
from rex.resource import Resource, BaseDatasetIterable
from rex.utilities.exceptions import FileInputError
from rex.utilities.parse_keys import parse_keys, parse_slice

Expand Down Expand Up @@ -58,7 +58,6 @@ def __init__(self, h5_path, res_cls=Resource, hsds=False, hsds_kwargs=None,
self._shape = None
self._time_index = None
self._time_slice_map = []
self._i = 0

def __repr__(self):
msg = ("{} for {}:\n Contains data from {} files"
Expand Down Expand Up @@ -419,7 +418,7 @@ def close(self):
f.close()


class MultiTimeResource:
class MultiTimeResource(BaseDatasetIterable):
"""
Class to handle resource data stored temporally across multiple
.h5 files
Expand Down Expand Up @@ -519,7 +518,6 @@ def __init__(self, h5_path, unscale=True, str_decode=True,
self._h5 = MultiTimeH5(self.h5_path, res_cls=res_cls, **cls_kwargs)
self.h5_files = self._h5.h5_files
self.h5_file = self.h5_files[0]
self._i = 0

def __repr__(self):
msg = "{} for {}".format(self.__class__.__name__, self.h5_path)
Expand All @@ -537,19 +535,6 @@ def __exit__(self, type, value, traceback):
def __len__(self):
return len(self.h5.time_index)

def __iter__(self):
return self

def __next__(self):
if self._i >= len(self.datasets):
self._i = 0
raise StopIteration

dset = self.datasets[self._i]
self._i += 1

return dset

def __getitem__(self, keys):
ds, ds_slice = parse_keys(keys)

Expand Down
14 changes: 1 addition & 13 deletions rex/multi_year_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ def __init__(self, h5_path, years=None, res_cls=Resource, hsds=False,
self._datasets = None
self._shape = None
self._time_index = None
self._i = 0

def __repr__(self):
msg = ("{} for {}:\n Contains data for {} years"
Expand All @@ -82,17 +81,7 @@ def __getitem__(self, year):
return h5

def __iter__(self):
return self

def __next__(self):
if self._i >= len(self.years):
self._i = 0
raise StopIteration

year = self.years[self._i]
self._i += 1

return year
return iter(self.years)

def __contains__(self, year):
return year in self.years
Expand Down Expand Up @@ -451,7 +440,6 @@ def __init__(self, h5_path, years=None, unscale=True, str_decode=True,
**cls_kwargs)
self.h5_files = self._h5.h5_files
self.h5_file = self.h5_files[0]
self._i = 0

@property
def years(self):
Expand Down
30 changes: 14 additions & 16 deletions rex/resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Classes to handle resource data
"""
import os
from abc import ABC
from abc import ABC, abstractmethod
from warnings import warn

import dateutil
Expand All @@ -17,6 +17,18 @@
from rex.utilities.utilities import check_tz, get_lat_lon_cols


class BaseDatasetIterable(ABC):
    """
    Abstract base for file handlers that are iterated dataset by dataset.

    Subclasses provide the ``datasets`` property. Iterating over an
    instance requests a fresh iterator from ``datasets`` on every call, so
    no iteration position is stored on the instance itself.
    """

    @property
    @abstractmethod
    def datasets(self):
        """iterable: Datasets available in the file."""

    def __iter__(self):
        # Look the datasets up at call time and hand iteration over to them.
        dsets = self.datasets
        return iter(dsets)


class ResourceDataset:
"""
h5py.Dataset wrapper for Resource .h5 files
Expand Down Expand Up @@ -583,7 +595,7 @@ def extract(cls, ds, ds_slice, scale_attr='scale_factor',
return dset[ds_slice]


class BaseResource(ABC):
class BaseResource(BaseDatasetIterable):
"""
Abstract Base class to handle resource .h5 files
"""
Expand Down Expand Up @@ -646,7 +658,6 @@ def __init__(self, h5_file, mode='r', unscale=True, str_decode=True,
self._shapes = None
self._chunks = None
self._dtypes = None
self._i = 0

def __repr__(self):
msg = "{} for {}".format(self.__class__.__name__, self.h5_file)
Expand Down Expand Up @@ -691,19 +702,6 @@ def __getitem__(self, keys):

return out

def __iter__(self):
return self

def __next__(self):
if self._i >= len(self.datasets):
self._i = 0
raise StopIteration

dset = self.datasets[self._i]
self._i += 1

return dset

def __contains__(self, dset):
return dset in self.datasets

Expand Down
21 changes: 2 additions & 19 deletions rex/resource_extraction/resource_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
WaveResource,
WindResource,
)
from rex.resource import Resource, ResourceDataset
from rex.resource import Resource, ResourceDataset, BaseDatasetIterable
from rex.temporal_stats.temporal_stats import TemporalStats
from rex.utilities.exceptions import ResourceValueError, ResourceWarning
from rex.utilities.execution import SpawnProcessPool
Expand All @@ -39,7 +39,7 @@
logger = logging.getLogger(__name__)


class ResourceX:
class ResourceX(BaseDatasetIterable):
"""
Resource data extraction tool
"""
Expand Down Expand Up @@ -88,7 +88,6 @@ def __init__(self, res_h5, res_cls=None, tree=None, unscale=True,
group=group, hsds=hsds, hsds_kwargs=hsds_kwargs)
self._dist_thresh = None
self._tree = tree
self._i = 0

def __repr__(self):
msg = "{} extractor for {}".format(self._res.__class__.__name__,
Expand All @@ -114,19 +113,6 @@ def __getitem__(self, keys):
def __contains__(self, dset):
return dset in self.datasets

def __iter__(self):
return self

def __next__(self):
if self._i >= len(self.datasets):
self._i = 0
raise StopIteration

dset = self.datasets[self._i]
self._i += 1

return dset

@property
def resource(self):
"""
Expand Down Expand Up @@ -1543,7 +1529,6 @@ def __init__(self, resource_path, res_cls=None, tree=None,
str_decode=str_decode, check_files=check_files)
self._dist_thresh = None
self._tree = tree
self._i = 0


class MultiYearResourceX(ResourceX):
Expand Down Expand Up @@ -1590,7 +1575,6 @@ def __init__(self, resource_path, years=None, tree=None, unscale=True,
hsds_kwargs=hsds_kwargs)
self._dist_thresh = None
self._tree = tree
self._i = 0

def get_means_map(self, ds_name, year=None, region=None,
region_col='state', max_workers=None,
Expand Down Expand Up @@ -1676,7 +1660,6 @@ def __init__(self, resource_path, tree=None, unscale=True,
hsds=hsds, hsds_kwargs=hsds_kwargs)
self._dist_thresh = None
self._tree = tree
self._i = 0


class SolarX(ResourceX):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def run(self):
with open("requirements.txt") as f:
install_requires = f.readlines()

test_requires = ["pytest>=5.2", ]
test_requires = ["pytest>=5.2", "pytest-timeout>=2.3.1"]
dev_requires = ["flake8", "pre-commit", "pylint", "hsds>=0.8.4"]
description = ("National Renewable Energy Laboratory's (NREL's) REsource "
"eXtraction tool: rex")
Expand Down
17 changes: 17 additions & 0 deletions tests/test_multi_res_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,20 @@ def test_preload_sam():
assert np.allclose(true, test)

mrr.close()


@pytest.mark.timeout(10)
def test_multi_res_resource_iterator():
    """
    Test that a MultiResolutionResource can be iterated in nested loops.

    An incorrect iterator implementation can cause an infinite loop in the
    nested iteration below, which is why the test carries a hard timeout.
    """
    with tempfile.TemporaryDirectory() as td:
        fp_hr, fp_lr = make_multi_res_files(td)
        mrr = MultiResolutionResource(fp_hr, fp_lr, handler_class=WindResource)
        # Always release the handler, even if iteration raises, before the
        # temporary directory is removed.
        try:
            # Every (outer, inner) pairing must be produced exactly once.
            dsets_permutation = {(a, b) for a in mrr for b in mrr}
            num_dsets = len(mrr.datasets)
        finally:
            mrr.close()

    assert len(dsets_permutation) == num_dsets ** 2
19 changes: 17 additions & 2 deletions tests/test_multi_time_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
import pytest

from rex import TESTDATADIR
from rex.multi_time_resource import (MultiTimeH5, MultiTimeNSRDB,
MultiTimeWindResource)
from rex.multi_time_resource import (MultiTimeH5, MultiTimeResource,
MultiTimeNSRDB, MultiTimeWindResource)
from rex.resource import Resource


Expand Down Expand Up @@ -323,6 +323,21 @@ def test_map_hsds_files():
assert not any(wrong), 'Wrong files: {}'.format(wrong)


@pytest.mark.timeout(10)
def test_mt_iterator():
    """
    Check that MultiTimeResource supports nested iteration over itself.

    An incorrect iterator implementation can cause an infinite loop here,
    so the test is guarded by a timeout.
    """
    wtk_pattern = os.path.join(TESTDATADIR, 'wtk/ri_100_wtk_*.h5')

    pairs = set()
    with MultiTimeResource(wtk_pattern) as res:
        # Iterate the same handler inside its own loop; each unique
        # (outer, inner) pairing is recorded once.
        for outer in res:
            for inner in res:
                pairs.add((outer, inner))

        n_dsets = len(res.datasets)

    assert len(pairs) == n_dsets ** 2


def execute_pytest(capture='all', flags='-rapP'):
"""Execute module as pytest with detailed summary report.

Expand Down
Loading
Loading