Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optional aggregation control #569

Merged
merged 8 commits into from
May 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
- Fix the link to documentation build status [#591](https://github.com/intake/intake-esm/pull/591) ([@mgrover1](https://github.com/mgrover1))
- Add optional `columns_with_iterables` argument to `esm_datastore` [#589](https://github.com/intake/intake-esm/pull/589) ([@dougiesquire](https://github.com/dougiesquire))
- Add `opendap` as a possible data format [#570](https://github.com/intake/intake-esm/pull/570) ([@aulemahal](https://github.com/aulemahal))
- Fix for catalogs without `aggregation_control`, as allowed by the ESM catalog specification [#569](https://github.com/intake/intake-esm/pull/569) ([@aulemahal](https://github.com/aulemahal))

## v2022.9.18

Expand Down
66 changes: 33 additions & 33 deletions intake_esm/cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,52 +295,52 @@ def df(self) -> pd.DataFrame:
@property
def has_multiple_variable_assets(self) -> bool:
"""Return True if the catalog has multiple variable assets."""
return self.aggregation_control.variable_column_name in self.columns_with_iterables
if self.aggregation_control:
return self.aggregation_control.variable_column_name in self.columns_with_iterables
return False

def _cast_agg_columns_with_iterables(self) -> None:
"""Cast all agg_columns with iterables to tuple values so as
to avoid hashing issues (e.g. TypeError: unhashable type: 'list')
"""
columns = list(
self.columns_with_iterables.intersection(
set(map(lambda agg: agg.attribute_name, self.aggregation_control.aggregations))
if self.aggregation_control:
columns = list(
self.columns_with_iterables.intersection(
set(map(lambda agg: agg.attribute_name, self.aggregation_control.aggregations))
)
)
)
if columns:
self._df[columns] = self._df[columns].apply(tuple)
if columns:
self._df[columns] = self._df[columns].apply(tuple)

@property
def grouped(self) -> typing.Union[pd.core.groupby.DataFrameGroupBy, pd.DataFrame]:
if self.aggregation_control.groupby_attrs:
self.aggregation_control.groupby_attrs = list(
filter(
functools.partial(_allnan_or_nonan, self.df),
self.aggregation_control.groupby_attrs,
if self.aggregation_control:
if self.aggregation_control.groupby_attrs:
self.aggregation_control.groupby_attrs = list(
filter(
functools.partial(_allnan_or_nonan, self.df),
self.aggregation_control.groupby_attrs,
)
)
)

if self.aggregation_control.groupby_attrs and set(
self.aggregation_control.groupby_attrs
) != set(self.df.columns):
return self.df.groupby(self.aggregation_control.groupby_attrs)
return self.df

def _construct_group_keys(self, sep: str = '.') -> dict[str, typing.Union[str, tuple[str]]]:
grouped = self.grouped
if isinstance(grouped, pd.core.groupby.generic.DataFrameGroupBy):
internal_keys = grouped.groups.keys()
public_keys = map(
lambda key: key if isinstance(key, str) else sep.join(str(value) for value in key),
internal_keys,
if self.aggregation_control.groupby_attrs and set(
self.aggregation_control.groupby_attrs
) != set(self.df.columns):
return self.df.groupby(self.aggregation_control.groupby_attrs)
cols = list(
filter(
functools.partial(_allnan_or_nonan, self.df),
self.df.columns,
)
)
return self.df.groupby(cols)

else:
internal_keys = grouped.index
public_keys = (
grouped[grouped.columns.tolist()]
.apply(lambda row: sep.join(str(v) for v in row), axis=1)
.tolist()
)
def _construct_group_keys(self, sep: str = '.') -> dict[str, typing.Union[str, tuple[str]]]:
internal_keys = self.grouped.groups.keys()
public_keys = map(
lambda key: key if isinstance(key, str) else sep.join(str(value) for value in key),
internal_keys,
)

return dict(zip(public_keys, internal_keys))

Expand Down
56 changes: 38 additions & 18 deletions intake_esm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,10 @@ def __init__(
self._validate_derivedcat()

def _validate_derivedcat(self) -> None:
if self.esmcat.aggregation_control is None and len(self.derivedcat):
raise ValueError(
'Variable derivation requires `aggregation_control` to be specified in the catalog.'
)
for key, entry in self.derivedcat.items():
if self.esmcat.aggregation_control.variable_column_name not in entry.query.keys():
raise ValueError(
Expand Down Expand Up @@ -163,10 +167,11 @@ def keys_info(self) -> pd.DataFrame:

"""
results = self.esmcat._construct_group_keys(sep=self.sep)
data = {
key: dict(zip(self.esmcat.aggregation_control.groupby_attrs, results[key]))
for key in results
}
if self.esmcat.aggregation_control and self.esmcat.aggregation_control.groupby_attrs:
groupby_attrs = self.esmcat.aggregation_control.groupby_attrs
else:
groupby_attrs = list(self.df.columns)
data = {key: dict(zip(groupby_attrs, results[key])) for key in results}
data = pd.DataFrame.from_dict(data, orient='index')
data.index.name = 'key'
return data
Expand All @@ -181,7 +186,7 @@ def key_template(self) -> str:
str
string template used to create catalog entry keys
"""
if self.esmcat.aggregation_control.groupby_attrs:
if self.esmcat.aggregation_control and self.esmcat.aggregation_control.groupby_attrs:
return self.sep.join(self.esmcat.aggregation_control.groupby_attrs)
else:
return self.sep.join(self.esmcat.df.columns)
Expand Down Expand Up @@ -247,15 +252,21 @@ def __getitem__(self, key: str) -> ESMDataSource:
else:
records = grouped.get_group(internal_key).to_dict(orient='records')

if self.esmcat.aggregation_control:
variable_column_name = self.esmcat.aggregation_control.variable_column_name
aggregations = self.esmcat.aggregation_control.aggregations
else:
variable_column_name = None
aggregations = []
# Create a new entry
entry = ESMDataSource(
key=key,
records=records,
variable_column_name=self.esmcat.aggregation_control.variable_column_name,
variable_column_name=variable_column_name,
path_column_name=self.esmcat.assets.column_name,
data_format=self.esmcat.assets.format,
format_column_name=self.esmcat.assets.format_column_name,
aggregations=self.esmcat.aggregation_control.aggregations,
aggregations=aggregations,
intake_kwargs={'metadata': {}},
)
self._entries[key] = entry
Expand Down Expand Up @@ -380,7 +391,10 @@ def search(self, require_all_on: typing.Union[str, list[str]] = None, **query: t
# step 2: Search for entries required to derive variables in the derived catalogs
# This requires a bit of a hack i.e. the user has to specify the variable in the query
derivedcat_results = []
variables = query.pop(self.esmcat.aggregation_control.variable_column_name, None)
if self.esmcat.aggregation_control:
variables = query.pop(self.esmcat.aggregation_control.variable_column_name, None)
else:
variables = None
dependents = []
derived_cat_subset = {}
if variables:
Expand Down Expand Up @@ -502,19 +516,21 @@ def nunique(self) -> pd.Series:
dtype: int64
"""
nunique = self.esmcat.nunique()
nunique[f'derived_{self.esmcat.aggregation_control.variable_column_name}'] = len(
self.derivedcat.keys()
)
if self.esmcat.aggregation_control:
nunique[f'derived_{self.esmcat.aggregation_control.variable_column_name}'] = len(
self.derivedcat.keys()
)
return nunique

def unique(self) -> pd.Series:
"""Return unique values for given columns in the
catalog.
"""
unique = self.esmcat.unique()
unique[f'derived_{self.esmcat.aggregation_control.variable_column_name}'] = list(
self.derivedcat.keys()
)
if self.esmcat.aggregation_control:
unique[f'derived_{self.esmcat.aggregation_control.variable_column_name}'] = list(
self.derivedcat.keys()
)
return unique

@pydantic.validate_arguments
Expand Down Expand Up @@ -600,9 +616,13 @@ def to_dataset_dict(
return {}

if (
self.esmcat.aggregation_control.variable_column_name
in self.esmcat.aggregation_control.groupby_attrs
) and len(self.derivedcat) > 0:
self.esmcat.aggregation_control
and (
self.esmcat.aggregation_control.variable_column_name
in self.esmcat.aggregation_control.groupby_attrs
)
and len(self.derivedcat) > 0
):
raise NotImplementedError(
f'The `{self.esmcat.aggregation_control.variable_column_name}` column name is used as a groupby attribute: {self.esmcat.aggregation_control.groupby_attrs}. '
'This is not yet supported when computing derived variables.'
Expand Down Expand Up @@ -632,7 +652,7 @@ def to_dataset_dict(
storage_options=storage_options,
)

if aggregate is not None and not aggregate:
if aggregate is not None and not aggregate and self.esmcat.aggregation_control:
self = deepcopy(self)
self.esmcat.aggregation_control.groupby_attrs = []
if progressbar is not None:
Expand Down
10 changes: 5 additions & 5 deletions intake_esm/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def _open_dataset(
ds = ds.set_coords(scalar_variables)
ds = ds[variables]
ds.attrs[OPTIONS['vars_key']] = variables
else:
elif varname:
ds.attrs[OPTIONS['vars_key']] = varname

ds = _expand_dims(expand_dims, ds)
Expand Down Expand Up @@ -125,11 +125,11 @@ def __init__(
self,
key: pydantic.StrictStr,
records: list[dict[str, typing.Any]],
variable_column_name: pydantic.StrictStr,
path_column_name: pydantic.StrictStr,
data_format: typing.Optional[DataFormat],
format_column_name: typing.Optional[pydantic.StrictStr],
*,
variable_column_name: typing.Optional[pydantic.StrictStr] = None,
aggregations: typing.Optional[list[Aggregation]] = None,
requested_variables: list[str] = None,
preprocess: typing.Callable = None,
Expand All @@ -147,12 +147,12 @@ def __init__(
records: list of dict
A list of records, each of which is a dictionary
mapping column names to values.
variable_column_name: str
The column name of the variable name.
path_column_name: str
The column name of the path.
data_format: DataFormat
The data format of the data.
variable_column_name: str, optional
The column name of the variable name.
aggregations: list of Aggregation, optional
A list of aggregations to apply to the data.
requested_variables: list of str, optional
Expand Down Expand Up @@ -218,7 +218,7 @@ def _open_dataset(self):
datasets = [
_open_dataset(
record[self.path_column_name],
record[self.variable_column_name],
record[self.variable_column_name] if self.variable_column_name else None,
xarray_open_kwargs=_get_xarray_open_kwargs(
record['_data_format_'], self.xarray_open_kwargs, self.storage_options
),
Expand Down
38 changes: 38 additions & 0 deletions tests/sample-catalogs/cmip6-netcdf-noagg.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
{
"esmcat_version": "0.1.0",
"id": "sample-cmip6",
"description": "This is a sample ESM catalog for CMIP6 data in netcdf format",
"catalog_file": "cmip6-netcdf-test.csv",
"attributes": [
{
"column_name": "activity_id",
"vocabulary": "https://raw.githubusercontent.com/WCRP-CMIP/CMIP6_CVs/master/CMIP6_activity_id.json"
},
{
"column_name": "source_id",
"vocabulary": "https://raw.githubusercontent.com/WCRP-CMIP/CMIP6_CVs/master/CMIP6_source_id.json"
},
{
"column_name": "institution_id",
"vocabulary": "https://raw.githubusercontent.com/WCRP-CMIP/CMIP6_CVs/master/CMIP6_institution_id.json"
},
{
"column_name": "experiment_id",
"vocabulary": "https://raw.githubusercontent.com/WCRP-CMIP/CMIP6_CVs/master/CMIP6_experiment_id.json"
},
{ "column_name": "member_id", "vocabulary": "" },
{
"column_name": "table_id",
"vocabulary": "https://raw.githubusercontent.com/WCRP-CMIP/CMIP6_CVs/master/CMIP6_table_id.json"
},
{ "column_name": "variable_id", "vocabulary": "" },
{
"column_name": "grid_label",
"vocabulary": "https://raw.githubusercontent.com/WCRP-CMIP/CMIP6_CVs/master/CMIP6_grid_label.json"
}
],
"assets": {
"column_name": "path",
"format": "netcdf"
}
}
2 changes: 2 additions & 0 deletions tests/test_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
cdf_cat_sample_cesmle,
cdf_cat_sample_cmip5,
cdf_cat_sample_cmip6,
cdf_cat_sample_cmip6_noagg,
multi_variable_cat,
sample_df,
sample_esmcat_data,
Expand Down Expand Up @@ -43,6 +44,7 @@ def test_assets_mutually_exclusive():
zarr_cat_pangeo_cmip6,
cdf_cat_sample_cmip5,
cdf_cat_sample_cmip6,
cdf_cat_sample_cmip6_noagg,
cdf_cat_sample_cesmle,
multi_variable_cat,
],
Expand Down
Loading