Skip to content

Commit

Permalink
simplify the conversion functions using call_on_dataset (#110)
Browse files Browse the repository at this point in the history
* vendor a short helper function

* only attempt to convert to DataArray if the result is a Dataset

* also try attaching a empty dict

* refactor the conversion functions to use call_on_dataset

* catch the ImportError
  • Loading branch information
keewis authored Jul 24, 2021
1 parent 451b639 commit 5902933
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 126 deletions.
18 changes: 18 additions & 0 deletions pint_xarray/compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import xarray as xr

try:
from xarray import call_on_dataset
except ImportError:

def call_on_dataset(func, obj, name, *args, **kwargs):
if isinstance(obj, xr.DataArray):
ds = obj.to_dataset(name=name)
else:
ds = obj

result = func(ds, *args, **kwargs)

if isinstance(obj, xr.DataArray) and isinstance(result, xr.Dataset):
result = result.get(name).rename(obj.name)

return result
240 changes: 114 additions & 126 deletions pint_xarray/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import pint
from xarray import DataArray, Dataset, IndexVariable, Variable

from .compat import call_on_dataset
from .errors import format_error_message

no_unit_values = ("none", None)
unit_attribute_name = "units"
slice_attributes = ("start", "stop", "step")
temporary_name = "<this-array>"


def array_attach_units(data, unit):
Expand Down Expand Up @@ -107,40 +109,49 @@ def attach_units_variable(variable, units):
return new_obj


def dataset_from_variables(variables, coords, attrs):
data_vars = {name: var for name, var in variables.items() if name not in coords}
coords = {name: var for name, var in variables.items() if name in coords}

return Dataset(data_vars=data_vars, coords=coords, attrs=attrs)


def attach_units_dataset(obj, units):
attached = {}
rejected_vars = {}
for name, var in obj.variables.items():
unit = units.get(name)
try:
converted = attach_units_variable(var, unit)
attached[name] = converted
except ValueError as e:
rejected_vars[name] = (unit, e)

if rejected_vars:
raise ValueError(rejected_vars)

return dataset_from_variables(attached, obj._coord_names, obj.attrs)


def attach_units(obj, units):
if not isinstance(obj, (DataArray, Dataset)):
raise ValueError(f"cannot attach units to {obj!r}: unknown type")

if isinstance(obj, DataArray):
old_name = obj.name
new_name = old_name if old_name is not None else "<this-array>"
ds = obj.rename(new_name).to_dataset()
units = units.copy()
units[new_name] = units.get(old_name)
if obj.name in units:
units[temporary_name] = units.get(obj.name)

new_ds = attach_units(ds, units)
new_obj = new_ds.get(new_name).rename(old_name)
elif isinstance(obj, Dataset):
attached = {}
rejected_vars = {}
for name, var in obj.variables.items():
unit = units.get(name)
try:
converted = attach_units_variable(var, unit)
attached[name] = converted
except ValueError as e:
rejected_vars[name] = (unit, e)

if rejected_vars:
raise ValueError(format_error_message(rejected_vars, "attach"))

data_vars = {
name: var for name, var in attached.items() if name not in obj._coord_names
}
coords = {
name: var for name, var in attached.items() if name in obj._coord_names
}

new_obj = Dataset(data_vars=data_vars, coords=coords, attrs=obj.attrs)
else:
raise ValueError(f"cannot attach units to {obj!r}: unknown type")
try:
new_obj = call_on_dataset(
attach_units_dataset, obj, name=temporary_name, units=units
)
except ValueError as e:
(rejected_vars,) = e.args
if temporary_name in rejected_vars:
rejected_vars[obj.name] = rejected_vars.pop(temporary_name)

raise ValueError(format_error_message(rejected_vars, "attach")) from e

return new_obj

Expand Down Expand Up @@ -192,87 +203,81 @@ def convert_units_variable(variable, units):
return new_obj


def convert_units_dataset(obj, units):
converted = {}
failed = {}
for name, var in obj.variables.items():
unit = units.get(name)
try:
converted[name] = convert_units_variable(var, unit)
except (ValueError, pint.errors.PintTypeError) as e:
failed[name] = e

if failed:
raise ValueError(failed)

return dataset_from_variables(converted, obj._coord_names, obj.attrs)


def convert_units(obj, units):
if isinstance(obj, DataArray):
original_name = obj.name
name = obj.name if obj.name is not None else "<this-array>"
if not isinstance(obj, (DataArray, Dataset)):
raise ValueError(f"cannot convert object: {obj!r}: unknown type")

units_ = units.copy()
if obj.name in units_:
units_[name] = units_[obj.name]
if isinstance(obj, DataArray):
units = units.copy()
if obj.name in units:
units[temporary_name] = units.pop(obj.name)

ds = obj.rename(name).to_dataset()
converted = convert_units(ds, units_)
try:
new_obj = call_on_dataset(
convert_units_dataset, obj, name=temporary_name, units=units
)
except ValueError as e:
(failed,) = e.args
if temporary_name in failed:
failed[obj.name] = failed.pop(temporary_name)

new_obj = converted[name].rename(original_name)
elif isinstance(obj, Dataset):
converted = {}
failed = {}
for name, var in obj.variables.items():
unit = units.get(name)
try:
converted[name] = convert_units_variable(var, unit)
except (ValueError, pint.errors.PintTypeError) as e:
failed[name] = e

if failed:
raise ValueError(format_error_message(failed, "convert"))

coords = {
name: var for name, var in converted.items() if name in obj._coord_names
}
data_vars = {
name: var for name, var in converted.items() if name not in obj._coord_names
}

new_obj = Dataset(data_vars=data_vars, coords=coords, attrs=obj.attrs)
else:
raise ValueError(f"cannot convert object: {obj}")
raise ValueError(format_error_message(failed, "convert")) from e

return new_obj


def extract_units(obj):
if isinstance(obj, Dataset):
units = extract_unit_attributes(obj)
dims = obj.dims
units.update(
{
name: array_extract_units(value.data)
for name, value in obj.variables.items()
if name not in dims
}
)
elif isinstance(obj, DataArray):
original_name = obj.name
name = obj.name if obj.name is not None else "<this-array>"
def extract_units_dataset(obj):
return {name: array_extract_units(var.data) for name, var in obj.variables.items()}

ds = obj.rename(name).to_dataset()

units = extract_units(ds)
units[original_name] = units.pop(name)
else:
def extract_units(obj):
if not isinstance(obj, (DataArray, Dataset)):
raise ValueError(f"unknown type: {type(obj)}")

return units
unit_attributes = extract_unit_attributes(obj)

units = call_on_dataset(extract_units_dataset, obj, name=temporary_name)
if temporary_name in units:
units[obj.name] = units.pop(temporary_name)

def extract_unit_attributes(obj, attr="units"):
if isinstance(obj, DataArray):
original_name = obj.name
name = obj.name if obj.name is not None else "<this-array>"
units_ = unit_attributes.copy()
units_.update({k: v for k, v in units.items() if v is not None})

ds = obj.rename(name).to_dataset()
return units_

units = extract_unit_attributes(ds)
units[original_name] = units.pop(name)
elif isinstance(obj, Dataset):
units = {name: var.attrs.get(attr, None) for name, var in obj.variables.items()}
else:

def extract_unit_attributes_dataset(obj, attr="units"):
return {name: var.attrs.get(attr, None) for name, var in obj.variables.items()}


def extract_unit_attributes(obj, attr="units"):
if not isinstance(obj, (DataArray, Dataset)):
raise ValueError(
f"cannot retrieve unit attributes from unknown type: {type(obj)}"
)

units = call_on_dataset(
extract_unit_attributes_dataset, obj, name=temporary_name, attr=attr
)
if temporary_name in units:
units[obj.name] = units.pop(temporary_name)

return units


Expand All @@ -281,51 +286,34 @@ def strip_units_variable(var):
return var.copy(data=data)


def strip_units(obj):
if isinstance(obj, DataArray):
original_name = obj.name
name = obj.name if obj.name is not None else "<this-array>"
ds = obj.rename(name).to_dataset()
stripped = strip_units(ds)
def strip_units_dataset(obj):
variables = {name: strip_units_variable(var) for name, var in obj.variables.items()}

new_obj = stripped[name].rename(original_name)
elif isinstance(obj, Dataset):
data_vars = {
name: strip_units_variable(variable)
for name, variable in obj.variables.items()
if name not in obj._coord_names
}
coords = {
name: strip_units_variable(variable)
for name, variable in obj.variables.items()
if name in obj._coord_names
}

new_obj = Dataset(data_vars=data_vars, coords=coords, attrs=obj.attrs)
else:
return dataset_from_variables(variables, obj._coord_names, obj.attrs)


def strip_units(obj):
if not isinstance(obj, (DataArray, Dataset)):
raise ValueError("cannot strip units from {obj!r}: unknown type")

return new_obj
return call_on_dataset(strip_units_dataset, obj, name=temporary_name)


def strip_unit_attributes(obj, attr="units"):
if isinstance(obj, DataArray):
original_name = obj.name
name = obj.name if obj.name is not None else "<this-array>"
def strip_unit_attributes_dataset(obj, attr="units"):
new_obj = obj.copy()
for var in new_obj.variables.values():
var.attrs.pop(attr, None)

ds = obj.rename(name).to_dataset()
return new_obj

stripped = strip_unit_attributes(ds)

new_obj = stripped[name].rename(original_name)
elif isinstance(obj, Dataset):
new_obj = obj.copy()
for var in new_obj.variables.values():
var.attrs.pop(attr, None)
else:
def strip_unit_attributes(obj, attr="units"):
if not isinstance(obj, (DataArray, Dataset)):
raise ValueError(f"cannot strip unit attributes from unknown type: {type(obj)}")

return new_obj
return call_on_dataset(
strip_unit_attributes_dataset, obj, name=temporary_name, attr=attr
)


def slice_extract_units(indexer):
Expand Down
1 change: 1 addition & 0 deletions pint_xarray/tests/test_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ class TestXarrayFunctions:
@pytest.mark.parametrize(
"units",
(
pytest.param({}, id="empty units"),
pytest.param({"a": None, "b": None, "u": None, "x": None}, id="no units"),
pytest.param(
{"a": unit_registry.m, "b": unit_registry.m, "u": None, "x": None},
Expand Down

0 comments on commit 5902933

Please sign in to comment.