Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

simplify the conversion functions using call_on_dataset #110

Merged
merged 6 commits into from
Jul 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions pint_xarray/compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import xarray as xr

try:
from xarray import call_on_dataset
except ImportError:

def call_on_dataset(func, obj, name, *args, **kwargs):
if isinstance(obj, xr.DataArray):
ds = obj.to_dataset(name=name)
else:
ds = obj

result = func(ds, *args, **kwargs)

if isinstance(obj, xr.DataArray) and isinstance(result, xr.Dataset):
result = result.get(name).rename(obj.name)

return result
240 changes: 114 additions & 126 deletions pint_xarray/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
import pint
from xarray import DataArray, Dataset, IndexVariable, Variable

from .compat import call_on_dataset
from .errors import format_error_message

no_unit_values = ("none", None)
unit_attribute_name = "units"
slice_attributes = ("start", "stop", "step")
temporary_name = "<this-array>"


def array_attach_units(data, unit):
Expand Down Expand Up @@ -107,40 +109,49 @@ def attach_units_variable(variable, units):
return new_obj


def dataset_from_variables(variables, coords, attrs):
data_vars = {name: var for name, var in variables.items() if name not in coords}
coords = {name: var for name, var in variables.items() if name in coords}

return Dataset(data_vars=data_vars, coords=coords, attrs=attrs)


def attach_units_dataset(obj, units):
attached = {}
rejected_vars = {}
for name, var in obj.variables.items():
unit = units.get(name)
try:
converted = attach_units_variable(var, unit)
attached[name] = converted
except ValueError as e:
rejected_vars[name] = (unit, e)

if rejected_vars:
raise ValueError(rejected_vars)

return dataset_from_variables(attached, obj._coord_names, obj.attrs)


def attach_units(obj, units):
if not isinstance(obj, (DataArray, Dataset)):
raise ValueError(f"cannot attach units to {obj!r}: unknown type")

if isinstance(obj, DataArray):
old_name = obj.name
new_name = old_name if old_name is not None else "<this-array>"
ds = obj.rename(new_name).to_dataset()
units = units.copy()
units[new_name] = units.get(old_name)
if obj.name in units:
units[temporary_name] = units.get(obj.name)

new_ds = attach_units(ds, units)
new_obj = new_ds.get(new_name).rename(old_name)
elif isinstance(obj, Dataset):
attached = {}
rejected_vars = {}
for name, var in obj.variables.items():
unit = units.get(name)
try:
converted = attach_units_variable(var, unit)
attached[name] = converted
except ValueError as e:
rejected_vars[name] = (unit, e)

if rejected_vars:
raise ValueError(format_error_message(rejected_vars, "attach"))

data_vars = {
name: var for name, var in attached.items() if name not in obj._coord_names
}
coords = {
name: var for name, var in attached.items() if name in obj._coord_names
}

new_obj = Dataset(data_vars=data_vars, coords=coords, attrs=obj.attrs)
else:
raise ValueError(f"cannot attach units to {obj!r}: unknown type")
try:
new_obj = call_on_dataset(
attach_units_dataset, obj, name=temporary_name, units=units
)
except ValueError as e:
(rejected_vars,) = e.args
if temporary_name in rejected_vars:
rejected_vars[obj.name] = rejected_vars.pop(temporary_name)

raise ValueError(format_error_message(rejected_vars, "attach")) from e

return new_obj

Expand Down Expand Up @@ -192,87 +203,81 @@ def convert_units_variable(variable, units):
return new_obj


def convert_units_dataset(obj, units):
converted = {}
failed = {}
for name, var in obj.variables.items():
unit = units.get(name)
try:
converted[name] = convert_units_variable(var, unit)
except (ValueError, pint.errors.PintTypeError) as e:
failed[name] = e

if failed:
raise ValueError(failed)

return dataset_from_variables(converted, obj._coord_names, obj.attrs)


def convert_units(obj, units):
if isinstance(obj, DataArray):
original_name = obj.name
name = obj.name if obj.name is not None else "<this-array>"
if not isinstance(obj, (DataArray, Dataset)):
raise ValueError(f"cannot convert object: {obj!r}: unknown type")

units_ = units.copy()
if obj.name in units_:
units_[name] = units_[obj.name]
if isinstance(obj, DataArray):
units = units.copy()
if obj.name in units:
units[temporary_name] = units.pop(obj.name)

ds = obj.rename(name).to_dataset()
converted = convert_units(ds, units_)
try:
new_obj = call_on_dataset(
convert_units_dataset, obj, name=temporary_name, units=units
)
except ValueError as e:
(failed,) = e.args
if temporary_name in failed:
failed[obj.name] = failed.pop(temporary_name)

new_obj = converted[name].rename(original_name)
elif isinstance(obj, Dataset):
converted = {}
failed = {}
for name, var in obj.variables.items():
unit = units.get(name)
try:
converted[name] = convert_units_variable(var, unit)
except (ValueError, pint.errors.PintTypeError) as e:
failed[name] = e

if failed:
raise ValueError(format_error_message(failed, "convert"))

coords = {
name: var for name, var in converted.items() if name in obj._coord_names
}
data_vars = {
name: var for name, var in converted.items() if name not in obj._coord_names
}

new_obj = Dataset(data_vars=data_vars, coords=coords, attrs=obj.attrs)
else:
raise ValueError(f"cannot convert object: {obj}")
raise ValueError(format_error_message(failed, "convert")) from e

return new_obj


def extract_units(obj):
if isinstance(obj, Dataset):
units = extract_unit_attributes(obj)
dims = obj.dims
units.update(
{
name: array_extract_units(value.data)
for name, value in obj.variables.items()
if name not in dims
}
)
elif isinstance(obj, DataArray):
original_name = obj.name
name = obj.name if obj.name is not None else "<this-array>"
def extract_units_dataset(obj):
return {name: array_extract_units(var.data) for name, var in obj.variables.items()}

ds = obj.rename(name).to_dataset()

units = extract_units(ds)
units[original_name] = units.pop(name)
else:
def extract_units(obj):
if not isinstance(obj, (DataArray, Dataset)):
raise ValueError(f"unknown type: {type(obj)}")

return units
unit_attributes = extract_unit_attributes(obj)

units = call_on_dataset(extract_units_dataset, obj, name=temporary_name)
if temporary_name in units:
units[obj.name] = units.pop(temporary_name)

def extract_unit_attributes(obj, attr="units"):
if isinstance(obj, DataArray):
original_name = obj.name
name = obj.name if obj.name is not None else "<this-array>"
units_ = unit_attributes.copy()
units_.update({k: v for k, v in units.items() if v is not None})

ds = obj.rename(name).to_dataset()
return units_

units = extract_unit_attributes(ds)
units[original_name] = units.pop(name)
elif isinstance(obj, Dataset):
units = {name: var.attrs.get(attr, None) for name, var in obj.variables.items()}
else:

def extract_unit_attributes_dataset(obj, attr="units"):
return {name: var.attrs.get(attr, None) for name, var in obj.variables.items()}


def extract_unit_attributes(obj, attr="units"):
if not isinstance(obj, (DataArray, Dataset)):
raise ValueError(
f"cannot retrieve unit attributes from unknown type: {type(obj)}"
)

units = call_on_dataset(
extract_unit_attributes_dataset, obj, name=temporary_name, attr=attr
)
if temporary_name in units:
units[obj.name] = units.pop(temporary_name)

return units


Expand All @@ -281,51 +286,34 @@ def strip_units_variable(var):
return var.copy(data=data)


def strip_units(obj):
if isinstance(obj, DataArray):
original_name = obj.name
name = obj.name if obj.name is not None else "<this-array>"
ds = obj.rename(name).to_dataset()
stripped = strip_units(ds)
def strip_units_dataset(obj):
variables = {name: strip_units_variable(var) for name, var in obj.variables.items()}

new_obj = stripped[name].rename(original_name)
elif isinstance(obj, Dataset):
data_vars = {
name: strip_units_variable(variable)
for name, variable in obj.variables.items()
if name not in obj._coord_names
}
coords = {
name: strip_units_variable(variable)
for name, variable in obj.variables.items()
if name in obj._coord_names
}

new_obj = Dataset(data_vars=data_vars, coords=coords, attrs=obj.attrs)
else:
return dataset_from_variables(variables, obj._coord_names, obj.attrs)


def strip_units(obj):
if not isinstance(obj, (DataArray, Dataset)):
raise ValueError("cannot strip units from {obj!r}: unknown type")

return new_obj
return call_on_dataset(strip_units_dataset, obj, name=temporary_name)


def strip_unit_attributes(obj, attr="units"):
if isinstance(obj, DataArray):
original_name = obj.name
name = obj.name if obj.name is not None else "<this-array>"
def strip_unit_attributes_dataset(obj, attr="units"):
new_obj = obj.copy()
for var in new_obj.variables.values():
var.attrs.pop(attr, None)

ds = obj.rename(name).to_dataset()
return new_obj

stripped = strip_unit_attributes(ds)

new_obj = stripped[name].rename(original_name)
elif isinstance(obj, Dataset):
new_obj = obj.copy()
for var in new_obj.variables.values():
var.attrs.pop(attr, None)
else:
def strip_unit_attributes(obj, attr="units"):
if not isinstance(obj, (DataArray, Dataset)):
raise ValueError(f"cannot strip unit attributes from unknown type: {type(obj)}")

return new_obj
return call_on_dataset(
strip_unit_attributes_dataset, obj, name=temporary_name, attr=attr
)


def slice_extract_units(indexer):
Expand Down
1 change: 1 addition & 0 deletions pint_xarray/tests/test_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ class TestXarrayFunctions:
@pytest.mark.parametrize(
"units",
(
pytest.param({}, id="empty units"),
pytest.param({"a": None, "b": None, "u": None, "x": None}, id="no units"),
pytest.param(
{"a": unit_registry.m, "b": unit_registry.m, "u": None, "x": None},
Expand Down