From 8fbb9ec5e3e98b2a6f3502874ca3888a1cecd69f Mon Sep 17 00:00:00 2001 From: Pascal Bourgault Date: Fri, 3 Feb 2023 14:57:29 -0500 Subject: [PATCH 1/7] Make aggregation_control really optional --- intake_esm/cat.py | 39 +++++++++++++++++++++------------------ intake_esm/core.py | 43 +++++++++++++++++++++++++++++++------------ intake_esm/source.py | 10 +++++----- tests/test_core.py | 17 ++++++++++++++++- tests/utils.py | 1 + 5 files changed, 74 insertions(+), 36 deletions(-) diff --git a/intake_esm/cat.py b/intake_esm/cat.py index a06c8ad5..1991fb36 100644 --- a/intake_esm/cat.py +++ b/intake_esm/cat.py @@ -294,35 +294,38 @@ def df(self) -> pd.DataFrame: @property def has_multiple_variable_assets(self) -> bool: """Return True if the catalog has multiple variable assets.""" - return self.aggregation_control.variable_column_name in self.columns_with_iterables + if self.aggregation_control: + return self.aggregation_control.variable_column_name in self.columns_with_iterables + return False def _cast_agg_columns_with_iterables(self) -> None: """Cast all agg_columns with iterables to tuple values so as to avoid hashing issues (e.g. TypeError: unhashable type: 'list') """ - columns = list( - self.columns_with_iterables.intersection( - set(map(lambda agg: agg.attribute_name, self.aggregation_control.aggregations)) + if self.aggregation_control: + columns = list( + self.columns_with_iterables.intersection( + set(map(lambda agg: agg.attribute_name, self.aggregation_control.aggregations)) + ) ) - ) - if columns: - self._df[columns] = self._df[columns].apply(tuple) + if columns: + self._df[columns] = self._df[columns].apply(tuple) @property def grouped(self) -> typing.Union[pd.core.groupby.DataFrameGroupBy, pd.DataFrame]: - - if self.aggregation_control.groupby_attrs: - self.aggregation_control.groupby_attrs = list( - filter( - functools.partial(_allnan_or_nonan, self.df), - self.aggregation_control.groupby_attrs, + if self.aggregation_control: + if self.aggregation_control.groupby_attrs: + self.aggregation_control.groupby_attrs = list( + filter( + functools.partial(_allnan_or_nonan, self.df), + self.aggregation_control.groupby_attrs, + ) ) - ) - if self.aggregation_control.groupby_attrs and set( - self.aggregation_control.groupby_attrs - ) != set(self.df.columns): - return self.df.groupby(self.aggregation_control.groupby_attrs) + if self.aggregation_control.groupby_attrs and set( + self.aggregation_control.groupby_attrs + ) != set(self.df.columns): + return self.df.groupby(self.aggregation_control.groupby_attrs) return self.df def _construct_group_keys(self, sep: str = '.') -> dict[str, typing.Union[str, tuple[str]]]: diff --git a/intake_esm/core.py b/intake_esm/core.py index 55e0f515..a24d24be 100644 --- a/intake_esm/core.py +++ b/intake_esm/core.py @@ -101,6 +101,10 @@ def __init__( self._validate_derivedcat() def _validate_derivedcat(self) -> None: + if self.esmcat.aggregation_control is None and len(self.derivedcat): + raise ValueError( + 'Variable derivation requires an `aggregation_control` to be specified in the catalog.' + ) for key, entry in self.derivedcat.items(): if self.esmcat.aggregation_control.variable_column_name not in entry.query.keys(): raise ValueError( @@ -149,6 +153,10 @@ def keys_info(self) -> pd.DataFrame: """ results = self.esmcat._construct_group_keys(sep=self.sep) + if self.esmcat.aggregation_control and self.esmcat.aggregation_control.groupby_attrs: + groupby_attrs = self.esmcat.aggregation_control.groupby_attrs + else: + groupby_attrs = self.df.columns data = { key: dict(zip(self.esmcat.aggregation_control.groupby_attrs, results[key])) for key in results @@ -167,7 +175,7 @@ def key_template(self) -> str: str string template used to create catalog entry keys """ - if self.esmcat.aggregation_control.groupby_attrs: + if self.esmcat.aggregation_control and self.esmcat.aggregation_control.groupby_attrs: return self.sep.join(self.esmcat.aggregation_control.groupby_attrs) else: return self.sep.join(self.esmcat.df.columns) @@ -233,15 +241,21 @@ def __getitem__(self, key: str) -> ESMDataSource: else: records = grouped.get_group(internal_key).to_dict(orient='records') + if self.esmcat.aggregation_control: + variable_column_name = self.esmcat.aggregation_control.variable_column_name + aggregations = self.esmcat.aggregation_control.aggregations + else: + variable_column_name = None + aggregations = [] # Create a new entry entry = ESMDataSource( key=key, records=records, - variable_column_name=self.esmcat.aggregation_control.variable_column_name, + variable_column_name=variable_column_name, path_column_name=self.esmcat.assets.column_name, data_format=self.esmcat.assets.format, format_column_name=self.esmcat.assets.format_column_name, - aggregations=self.esmcat.aggregation_control.aggregations, + aggregations=aggregations, intake_kwargs={'metadata': {}}, ) self._entries[key] = entry @@ -366,7 +380,10 @@ def search(self, require_all_on: typing.Union[str, list[str]] = None, **query: t # step 2: Search for entries required to derive variables in the derived catalogs # This requires a bit of a hack i.e. the user has to specify the variable in the query derivedcat_results = [] - variables = query.pop(self.esmcat.aggregation_control.variable_column_name, None) + if self.esmcat.aggregation_control: + variables = query.pop(self.esmcat.aggregation_control.variable_column_name, None) + else: + variables = None dependents = [] derived_cat_subset = {} if variables: @@ -488,9 +505,10 @@ def nunique(self) -> pd.Series: dtype: int64 """ nunique = self.esmcat.nunique() - nunique[f'derived_{self.esmcat.aggregation_control.variable_column_name}'] = len( - self.derivedcat.keys() - ) + if self.esmcat.aggregation_control: + nunique[f'derived_{self.esmcat.aggregation_control.variable_column_name}'] = len( + self.derivedcat.keys() + ) return nunique def unique(self) -> pd.Series: @@ -498,9 +516,10 @@ def unique(self) -> pd.Series: catalog. """ unique = self.esmcat.unique() - unique[f'derived_{self.esmcat.aggregation_control.variable_column_name}'] = list( - self.derivedcat.keys() - ) + if self.esmcat.aggregation_control: + unique[f'derived_{self.esmcat.aggregation_control.variable_column_name}'] = list( + self.derivedcat.keys() + ) return unique @pydantic.validate_arguments @@ -585,7 +604,7 @@ def to_dataset_dict( ) return {} - if ( + if self.esmcat.aggregation_control and ( self.esmcat.aggregation_control.variable_column_name in self.esmcat.aggregation_control.groupby_attrs ) and len(self.derivedcat) > 0: @@ -618,7 +637,7 @@ def to_dataset_dict( storage_options=storage_options, ) - if aggregate is not None and not aggregate: + if aggregate is not None and not aggregate and self.esmcat.aggregation_control: self = deepcopy(self) self.esmcat.aggregation_control.groupby_attrs = [] if progressbar is not None: diff --git a/intake_esm/source.py b/intake_esm/source.py index e7cc12f2..d313def9 100644 --- a/intake_esm/source.py +++ b/intake_esm/source.py @@ -86,7 +86,7 @@ def _open_dataset( ds = ds.set_coords(scalar_variables) ds = ds[variables] ds.attrs[OPTIONS['vars_key']] = variables - else: + elif varname: ds.attrs[OPTIONS['vars_key']] = varname ds = _expand_dims(expand_dims, ds) @@ -126,11 +126,11 @@ def __init__( self, key: pydantic.StrictStr, records: list[dict[str, typing.Any]], - variable_column_name: pydantic.StrictStr, path_column_name: pydantic.StrictStr, data_format: typing.Optional[DataFormat], format_column_name: typing.Optional[pydantic.StrictStr], *, + variable_column_name: typing.Optional[pydantic.StrictStr] = None, aggregations: typing.Optional[list[Aggregation]] = None, requested_variables: list[str] = None, preprocess: typing.Callable = None, @@ -148,12 +148,12 @@ def __init__( records: list of dict A list of records, each of which is a dictionary mapping column names to values. - variable_column_name: str - The column name of the variable name. path_column_name: str The column name of the path. data_format: DataFormat The data format of the data. + variable_column_name: str, optional + The column name of the variable name. aggregations: list of Aggregation, optional A list of aggregations to apply to the data. requested_variables: list of str, optional @@ -220,7 +220,7 @@ def _open_dataset(self): datasets = [ _open_dataset( record[self.path_column_name], - record[self.variable_column_name], + record[self.variable_column_name] if self.variable_column_name else None, xarray_open_kwargs=_get_xarray_open_kwargs( record['_data_format_'], self.xarray_open_kwargs, self.storage_options ), diff --git a/tests/test_core.py b/tests/test_core.py index 4df33013..150a01f0 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -47,6 +47,7 @@ def func_multivar(ds): cdf_cat_sample_cmip6, mixed_cat_sample_cmip6, multi_variable_cat, + noagg_cat, sample_df, sample_esmcat_data, zarr_cat_aws_cesm, @@ -58,6 +59,7 @@ def func_multivar(ds): 'obj, sep, read_csv_kwargs', [ (catalog_dict_records, '.', None), + (noagg_cat, '.', None), (cdf_cat_sample_cmip6, '/', None), (zarr_cat_aws_cesm, '.', None), (zarr_cat_pangeo_cmip6, '*', None), @@ -99,6 +101,18 @@ def func(ds): intake.open_esm_datastore(catalog_dict_records, registry=registry) +def test_impossible_derivedcat(): + registry = intake_esm.DerivedVariableRegistry() + + @registry.register(variable='FOO', query={'variable': ['FLNS', 'FLUT']}) + def func(ds): + ds['FOO'] = ds.FLNS + ds.FLUT + return ds + + with pytest.raises(ValueError, match="Variable derivation requires an `aggregation_control`"): + intake.open_esm_datastore(noagg_cat, registry=registry) + + @pytest.mark.parametrize( 'obj, sep, read_csv_kwargs', [ @@ -107,6 +121,7 @@ def func(ds): (cdf_cat_sample_cmip5, '.', None), (cdf_cat_sample_cmip6, '*', None), (catalog_dict_records, '.', None), + (noagg_cat, '.', None), ({'esmcat': sample_esmcat_data, 'df': sample_df}, '.', None), ], ) @@ -116,7 +131,7 @@ def test_catalog_unique(obj, sep, read_csv_kwargs): nuniques = cat.nunique() assert isinstance(uniques, pd.Series) assert isinstance(nuniques, pd.Series) - assert len(uniques.keys()) == len(cat.df.columns) + 1 # for derived_variable entry + assert len(uniques.keys()) == len(cat.df.columns) + (0 if obj is noagg_cat else 1) # for derived_variable entry def test_catalog_contains(): diff --git a/tests/utils.py b/tests/utils.py index 01380fb3..ae1f1189 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -9,6 +9,7 @@ cdf_cat_sample_cmip5 = os.path.join(here, 'sample-catalogs/cmip5-netcdf.json') cdf_cat_sample_cesmle = os.path.join(here, 'sample-catalogs/cesm1-lens-netcdf.json') catalog_dict_records = os.path.join(here, 'sample-catalogs/catalog-dict-records.json') +noagg_cat = os.path.join(here, 'sample-catalogs/catalog-dict-records-noagg.json') zarr_cat_aws_cesm = ( 'https://raw.githubusercontent.com/NCAR/cesm-lens-aws/master/intake-catalogs/aws-cesm1-le.json' ) From f4f06a02bfc6612d9a1e16cf0fc4d9457ae7c1cf Mon Sep 17 00:00:00 2001 From: Pascal Bourgault Date: Fri, 3 Feb 2023 15:05:09 -0500 Subject: [PATCH 2/7] run precommit --- intake_esm/core.py | 16 ++++++++++------ tests/test_core.py | 6 ++++-- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/intake_esm/core.py b/intake_esm/core.py index a24d24be..26b52ee8 100644 --- a/intake_esm/core.py +++ b/intake_esm/core.py @@ -154,9 +154,9 @@ def keys_info(self) -> pd.DataFrame: """ results = self.esmcat._construct_group_keys(sep=self.sep) if self.esmcat.aggregation_control and self.esmcat.aggregation_control.groupby_attrs: - groupby_attrs = self.esmcat.aggregation_control.groupby_attrs + pass else: - groupby_attrs = self.df.columns + pass data = { key: dict(zip(self.esmcat.aggregation_control.groupby_attrs, results[key])) for key in results @@ -604,10 +604,14 @@ def to_dataset_dict( ) return {} - if self.esmcat.aggregation_control and ( - self.esmcat.aggregation_control.variable_column_name - in self.esmcat.aggregation_control.groupby_attrs - ) and len(self.derivedcat) > 0: + if ( + self.esmcat.aggregation_control + and ( + self.esmcat.aggregation_control.variable_column_name + in self.esmcat.aggregation_control.groupby_attrs + ) + and len(self.derivedcat) > 0 + ): raise NotImplementedError( f'The `{self.esmcat.aggregation_control.variable_column_name}` column name is used as a groupby attribute: {self.esmcat.aggregation_control.groupby_attrs}. ' 'This is not yet supported when computing derived variables.' diff --git a/tests/test_core.py b/tests/test_core.py index 150a01f0..959c8b69 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -109,7 +109,7 @@ def func(ds): ds['FOO'] = ds.FLNS + ds.FLUT return ds - with pytest.raises(ValueError, match="Variable derivation requires an `aggregation_control`"): + with pytest.raises(ValueError, match='Variable derivation requires an `aggregation_control`'): intake.open_esm_datastore(noagg_cat, registry=registry) @@ -131,7 +131,9 @@ def test_catalog_unique(obj, sep, read_csv_kwargs): nuniques = cat.nunique() assert isinstance(uniques, pd.Series) assert isinstance(nuniques, pd.Series) - assert len(uniques.keys()) == len(cat.df.columns) + (0 if obj is noagg_cat else 1) # for derived_variable entry + assert len(uniques.keys()) == len(cat.df.columns) + ( + 0 if obj is noagg_cat else 1 + ) # for derived_variable entry def test_catalog_contains(): From cd658b7ea4d8b548e802f45f8e13e807982fb15f Mon Sep 17 00:00:00 2001 From: Pascal Bourgault Date: Fri, 3 Feb 2023 15:07:58 -0500 Subject: [PATCH 3/7] add new noagg cat --- .../catalog-dict-records-noagg.json | 64 +++++++++++++++++++ 1 file changed, 64 insertions(+) create mode 100644 tests/sample-catalogs/catalog-dict-records-noagg.json diff --git a/tests/sample-catalogs/catalog-dict-records-noagg.json b/tests/sample-catalogs/catalog-dict-records-noagg.json new file mode 100644 index 00000000..1afffd48 --- /dev/null +++ b/tests/sample-catalogs/catalog-dict-records-noagg.json @@ -0,0 +1,64 @@ +{ + "esmcat_version": "0.1.0", + "id": "aws-cesm1-le-noagg", + "description": "This is an ESM catalog for CESM1 Large Ensemble Zarr dataset publicly available on Amazon S3 (us-west-2 region), without any aggregation info.", + "catalog_dict": [ + { + "component": "atm", + "frequency": "daily", + "experiment": "20C", + "variable": "FLNS", + "path": "s3://ncar-cesm-lens/atm/daily/cesmLE-20C-FLNS.zarr" + }, + { + "component": "atm", + "frequency": "daily", + "experiment": "20C", + "variable": "FLNSC", + "path": "s3://ncar-cesm-lens/atm/daily/cesmLE-20C-FLNSC.zarr" + }, + { + "component": "atm", + "frequency": "daily", + "experiment": "20C", + "variable": "FLUT", + "path": "s3://ncar-cesm-lens/atm/daily/cesmLE-20C-FLUT.zarr" + }, + { + "component": "atm", + "frequency": "daily", + "experiment": "20C", + "variable": "FSNS", + "path": "s3://ncar-cesm-lens/atm/daily/cesmLE-20C-FSNS.zarr" + }, + { + "component": "atm", + "frequency": "daily", + "experiment": "20C", + "variable": "FSNSC", + "path": "s3://ncar-cesm-lens/atm/daily/cesmLE-20C-FSNSC.zarr" + } + ], + "attributes": [ + { + "column_name": "component", + "vocabulary": "" + }, + { + "column_name": "frequency", + "vocabulary": "" + }, + { + "column_name": "experiment", + "vocabulary": "" + }, + { + "column_name": "variable", + "vocabulary": "" + } + ], + "assets": { + "column_name": "path", + "format": "zarr" + } +} From 642735aa79020d250b9970e6f2ff46c5c3c67e26 Mon Sep 17 00:00:00 2001 From: Pascal Bourgault Date: Wed, 10 May 2023 17:04:09 -0400 Subject: [PATCH 4/7] Update intake_esm/core.py Co-authored-by: Max Grover --- intake_esm/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/intake_esm/core.py b/intake_esm/core.py index 0edca7b2..8780c1fc 100644 --- a/intake_esm/core.py +++ b/intake_esm/core.py @@ -117,7 +117,7 @@ def __init__( def _validate_derivedcat(self) -> None: if self.esmcat.aggregation_control is None and len(self.derivedcat): raise ValueError( - 'Variable derivation requires an `aggregation_control` to be specified in the catalog.' + 'Variable derivation requires `aggregation_control` to be specified in the catalog.' ) for key, entry in self.derivedcat.items(): if self.esmcat.aggregation_control.variable_column_name not in entry.query.keys(): From 15c2eb0f2850a939800542efda8c5bc63e255814 Mon Sep 17 00:00:00 2001 From: Pascal Bourgault Date: Wed, 10 May 2023 18:01:04 -0400 Subject: [PATCH 5/7] Add tests - change test cat - upd changelog --- CHANGELOG.md | 1 + .../catalog-dict-records-noagg.json | 64 ------------------- tests/sample-catalogs/cmip6-netcdf-noagg.json | 38 +++++++++++ tests/test_cat.py | 2 + tests/test_core.py | 21 ++++-- tests/utils.py | 2 +- 6 files changed, 56 insertions(+), 72 deletions(-) delete mode 100644 tests/sample-catalogs/catalog-dict-records-noagg.json create mode 100644 tests/sample-catalogs/cmip6-netcdf-noagg.json diff --git a/CHANGELOG.md b/CHANGELOG.md index edd13cb6..4a5f2720 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ - Fix the link to documentation build status [#591](https://github.com/intake/intake-esm/pull/591) ([@mgrover1](https://github.com/mgrover1)) - Add optional `columns_with_iterables` argument to `esm_datastore` [#589](https://github.com/intake/intake-esm/pull/589) ([@dougiesquire](https://github.com/dougiesquire)) - Add `opendap` as a possible data format [#570](https://github.com/intake/intake-esm/pull/570) ([@aulemahal](https://github.com/aulemahal)) +- Fix for catalogs without `aggregation_control`, as allowed by the ESM catalog specification [#569](https://github.com/intake/intake-esm/pull/569) ([@aulemahal](https://github.com/aulemahal)) ## v2022.9.18 diff --git a/tests/sample-catalogs/catalog-dict-records-noagg.json b/tests/sample-catalogs/catalog-dict-records-noagg.json deleted file mode 100644 index 1afffd48..00000000 --- a/tests/sample-catalogs/catalog-dict-records-noagg.json +++ /dev/null @@ -1,64 +0,0 @@ -{ - "esmcat_version": "0.1.0", - "id": "aws-cesm1-le-noagg", - "description": "This is an ESM catalog for CESM1 Large Ensemble Zarr dataset publicly available on Amazon S3 (us-west-2 region), without any aggregation info.", - "catalog_dict": [ - { - "component": "atm", - "frequency": "daily", - "experiment": "20C", - "variable": "FLNS", - "path": "s3://ncar-cesm-lens/atm/daily/cesmLE-20C-FLNS.zarr" - }, - { - "component": "atm", - "frequency": "daily", - "experiment": "20C", - "variable": "FLNSC", - "path": "s3://ncar-cesm-lens/atm/daily/cesmLE-20C-FLNSC.zarr" - }, - { - "component": "atm", - "frequency": "daily", - "experiment": "20C", - "variable": "FLUT", - "path": "s3://ncar-cesm-lens/atm/daily/cesmLE-20C-FLUT.zarr" - }, - { - "component": "atm", - "frequency": "daily", - "experiment": "20C", - "variable": "FSNS", - "path": "s3://ncar-cesm-lens/atm/daily/cesmLE-20C-FSNS.zarr" - }, - { - "component": "atm", - "frequency": "daily", - "experiment": "20C", - "variable": "FSNSC", - "path": "s3://ncar-cesm-lens/atm/daily/cesmLE-20C-FSNSC.zarr" - } - ], - "attributes": [ - { - "column_name": "component", - "vocabulary": "" - }, - { - "column_name": "frequency", - "vocabulary": "" - }, - { - "column_name": "experiment", - "vocabulary": "" - }, - { - "column_name": "variable", - "vocabulary": "" - } - ], - "assets": { - "column_name": "path", - "format": "zarr" - } -} diff --git a/tests/sample-catalogs/cmip6-netcdf-noagg.json b/tests/sample-catalogs/cmip6-netcdf-noagg.json new file mode 100644 index 00000000..d1fe2f84 --- /dev/null +++ b/tests/sample-catalogs/cmip6-netcdf-noagg.json @@ -0,0 +1,38 @@ +{ + "esmcat_version": "0.1.0", + "id": "sample-cmip6", + "description": "This is a sample ESM catalog for CMIP6 data in netcdf format", + "catalog_file": "cmip6-netcdf-test.csv", + "attributes": [ + { + "column_name": "activity_id", + "vocabulary": "https://raw.githubusercontent.com/WCRP-CMIP/CMIP6_CVs/master/CMIP6_activity_id.json" + }, + { + "column_name": "source_id", + "vocabulary": "https://raw.githubusercontent.com/WCRP-CMIP/CMIP6_CVs/master/CMIP6_source_id.json" + }, + { + "column_name": "institution_id", + "vocabulary": "https://raw.githubusercontent.com/WCRP-CMIP/CMIP6_CVs/master/CMIP6_institution_id.json" + }, + { + "column_name": "experiment_id", + "vocabulary": "https://raw.githubusercontent.com/WCRP-CMIP/CMIP6_CVs/master/CMIP6_experiment_id.json" + }, + { "column_name": "member_id", "vocabulary": "" }, + { + "column_name": "table_id", + "vocabulary": "https://raw.githubusercontent.com/WCRP-CMIP/CMIP6_CVs/master/CMIP6_table_id.json" + }, + { "column_name": "variable_id", "vocabulary": "" }, + { + "column_name": "grid_label", + "vocabulary": "https://raw.githubusercontent.com/WCRP-CMIP/CMIP6_CVs/master/CMIP6_grid_label.json" + } + ], + "assets": { + "column_name": "path", + "format": "netcdf" + } +} diff --git a/tests/test_cat.py b/tests/test_cat.py index 6726c31d..2190023a 100644 --- a/tests/test_cat.py +++ b/tests/test_cat.py @@ -9,6 +9,7 @@ cdf_cat_sample_cesmle, cdf_cat_sample_cmip5, cdf_cat_sample_cmip6, + cdf_cat_sample_cmip6_noagg, multi_variable_cat, sample_df, sample_esmcat_data, @@ -43,6 +44,7 @@ def test_assets_mutually_exclusive(): zarr_cat_pangeo_cmip6, cdf_cat_sample_cmip5, cdf_cat_sample_cmip6, + cdf_cat_sample_cmip6_noagg, cdf_cat_sample_cesmle, multi_variable_cat, ], diff --git a/tests/test_core.py b/tests/test_core.py index 4fa87801..4065e11a 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -43,7 +43,7 @@ def func_multivar(ds): cdf_cat_sample_cmip6, mixed_cat_sample_cmip6, multi_variable_cat, - noagg_cat, + cdf_cat_sample_cmip6_noagg, opendap_cat_sample_noaa, sample_df, sample_esmcat_data, @@ -56,7 +56,7 @@ def func_multivar(ds): 'obj, sep, read_csv_kwargs, columns_with_iterables', [ (catalog_dict_records, '.', None, None), - (noagg_cat, '.', None, None), + (cdf_cat_sample_cmip6_noagg, '.', None, None), (cdf_cat_sample_cmip6, '/', None, None), (zarr_cat_aws_cesm, '.', None, None), (zarr_cat_pangeo_cmip6, '*', None, None), @@ -142,8 +142,8 @@ def func(ds): ds['FOO'] = ds.FLNS + ds.FLUT return ds - with pytest.raises(ValueError, match='Variable derivation requires an `aggregation_control`'): - intake.open_esm_datastore(noagg_cat, registry=registry) + with pytest.raises(ValueError, match='Variable derivation requires `aggregation_control`'): + intake.open_esm_datastore(cdf_cat_sample_cmip6_noagg, registry=registry) @pytest.mark.parametrize( @@ -154,7 +154,7 @@ def func(ds): (cdf_cat_sample_cmip5, '.', None), (cdf_cat_sample_cmip6, '*', None), (catalog_dict_records, '.', None), - (noagg_cat, '.', None), + (cdf_cat_sample_cmip6_noagg, '.', None), ({'esmcat': sample_esmcat_data, 'df': sample_df}, '.', None), ], ) @@ -165,7 +165,7 @@ def test_catalog_unique(obj, sep, read_csv_kwargs): assert isinstance(uniques, pd.Series) assert isinstance(nuniques, pd.Series) assert len(uniques.keys()) == len(cat.df.columns) + ( - 0 if obj is noagg_cat else 1 + 0 if obj is cdf_cat_sample_cmip6_noagg else 1 ) # for derived_variable entry @@ -341,6 +341,11 @@ def test_multi_variable_catalog_derived_cat(): dict(source_id=['CNRM-ESM2-1', 'CNRM-CM6-1', 'BCC-ESM1'], variable_id=['tasmax']), {'chunks': {'time': 1}}, ), + ( + cdf_cat_sample_cmip6_noagg, + dict(source_id=['CNRM-ESM2-1', 'CNRM-CM6-1', 'BCC-ESM1'], variable_id=['tasmax']), + {'chunks': {'time': 1}}, + ), (mixed_cat_sample_cmip6, dict(institution_id='BCC'), {}), ], ) @@ -348,7 +353,8 @@ def test_to_dataset_dict(path, query, xarray_open_kwargs): cat = intake.open_esm_datastore(path) cat_sub = cat.search(**query) _, ds = cat_sub.to_dataset_dict(xarray_open_kwargs=xarray_open_kwargs).popitem() - assert 'member_id' in ds.dims + if path != cdf_cat_sample_cmip6_noagg: + assert 'member_id' in ds.dims assert len(ds.__dask_keys__()) > 0 assert ds.time.encoding @@ -425,6 +431,7 @@ def test_to_dask(path, query, xarray_open_kwargs): 'path, query', [ (cdf_cat_sample_cmip6, {'experiment_id': ['historical', 'rcp85']}), + (cdf_cat_sample_cmip6_noagg, {'experiment_id': ['historical', 'rcp85']}), (cdf_cat_sample_cmip5, {'experiment': ['historical', 'rcp85']}), ], ) diff --git a/tests/utils.py b/tests/utils.py index 89cd7dc0..661ff5f8 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -9,7 +9,7 @@ cdf_cat_sample_cmip5 = os.path.join(here, 'sample-catalogs/cmip5-netcdf.json') cdf_cat_sample_cesmle = os.path.join(here, 'sample-catalogs/cesm1-lens-netcdf.json') catalog_dict_records = os.path.join(here, 'sample-catalogs/catalog-dict-records.json') -noagg_cat = os.path.join(here, 'sample-catalogs/catalog-dict-records-noagg.json') +cdf_cat_sample_cmip6_noagg = os.path.join(here, 'sample-catalogs/cmip6-netcdf-noagg.json') opendap_cat_sample_noaa = os.path.join(here, 'sample-catalogs/noaa-pathfinder-opendap.json') zarr_cat_aws_cesm = ( 'https://raw.githubusercontent.com/NCAR/cesm-lens-aws/master/intake-catalogs/aws-cesm1-le.json' From e830c70d2c8e832091b954d16907964e537f5253 Mon Sep 17 00:00:00 2001 From: Pascal Bourgault Date: Wed, 10 May 2023 18:02:30 -0400 Subject: [PATCH 6/7] Ran pre-commit --- tests/test_core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_core.py b/tests/test_core.py index 4065e11a..020b41fe 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -41,9 +41,9 @@ def func_multivar(ds): cdf_cat_sample_cesmle, cdf_cat_sample_cmip5, cdf_cat_sample_cmip6, + cdf_cat_sample_cmip6_noagg, mixed_cat_sample_cmip6, multi_variable_cat, - cdf_cat_sample_cmip6_noagg, opendap_cat_sample_noaa, sample_df, sample_esmcat_data, From e5205423b33c871a7deca7f3d9d8cf6f8944b25f Mon Sep 17 00:00:00 2001 From: Pascal Bourgault Date: Thu, 11 May 2023 14:39:35 -0400 Subject: [PATCH 7/7] Simplify and fix grouping with no agg --- intake_esm/cat.py | 28 ++++++++++++---------------- intake_esm/core.py | 9 +++------ tests/test_core.py | 5 +++-- 3 files changed, 18 insertions(+), 24 deletions(-) diff --git a/intake_esm/cat.py b/intake_esm/cat.py index 752d212e..7b276353 100644 --- a/intake_esm/cat.py +++ b/intake_esm/cat.py @@ -327,24 +327,20 @@ def grouped(self) -> typing.Union[pd.core.groupby.DataFrameGroupBy, pd.DataFrame self.aggregation_control.groupby_attrs ) != set(self.df.columns): return self.df.groupby(self.aggregation_control.groupby_attrs) - return self.df - - def _construct_group_keys(self, sep: str = '.') -> dict[str, typing.Union[str, tuple[str]]]: - grouped = self.grouped - if isinstance(grouped, pd.core.groupby.generic.DataFrameGroupBy): - internal_keys = grouped.groups.keys() - public_keys = map( - lambda key: key if isinstance(key, str) else sep.join(str(value) for value in key), - internal_keys, + cols = list( + filter( + functools.partial(_allnan_or_nonan, self.df), + self.df.columns, ) + ) + return self.df.groupby(cols) - else: - internal_keys = grouped.index - public_keys = ( - grouped[grouped.columns.tolist()] - .apply(lambda row: sep.join(str(v) for v in row), axis=1) - .tolist() - ) + def _construct_group_keys(self, sep: str = '.') -> dict[str, typing.Union[str, tuple[str]]]: + internal_keys = self.grouped.groups.keys() + public_keys = map( + lambda key: key if isinstance(key, str) else sep.join(str(value) for value in key), + internal_keys, + ) return dict(zip(public_keys, internal_keys)) diff --git a/intake_esm/core.py b/intake_esm/core.py index 8780c1fc..ce0dd7ac 100644 --- a/intake_esm/core.py +++ b/intake_esm/core.py @@ -168,13 +168,10 @@ def keys_info(self) -> pd.DataFrame: """ results = self.esmcat._construct_group_keys(sep=self.sep) if self.esmcat.aggregation_control and self.esmcat.aggregation_control.groupby_attrs: - pass + groupby_attrs = self.esmcat.aggregation_control.groupby_attrs else: - pass - data = { - key: dict(zip(self.esmcat.aggregation_control.groupby_attrs, results[key])) - for key in results - } + groupby_attrs = list(self.df.columns) + data = {key: dict(zip(groupby_attrs, results[key])) for key in results} data = pd.DataFrame.from_dict(data, orient='index') data.index.name = 'key' return data diff --git a/tests/test_core.py b/tests/test_core.py index 020b41fe..75515108 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -222,8 +222,9 @@ def test_catalog_getitem_error(): cat['foo'] -def test_catalog_keys_info(): - cat = intake.open_esm_datastore(cdf_cat_sample_cesmle) +@pytest.mark.parametrize('cat', [cdf_cat_sample_cesmle, cdf_cat_sample_cmip6_noagg]) +def test_catalog_keys_info(cat): + cat = intake.open_esm_datastore(cat) data = cat.keys_info() assert isinstance(data, pd.DataFrame) assert data.index.name == 'key'