diff --git a/pint_xarray/compat.py b/pint_xarray/compat.py new file mode 100644 index 00000000..dcabd47f --- /dev/null +++ b/pint_xarray/compat.py @@ -0,0 +1,18 @@ +import xarray as xr + +try: + from xarray import call_on_dataset +except ImportError: + + def call_on_dataset(func, obj, name, *args, **kwargs): + if isinstance(obj, xr.DataArray): + ds = obj.to_dataset(name=name) + else: + ds = obj + + result = func(ds, *args, **kwargs) + + if isinstance(obj, xr.DataArray) and isinstance(result, xr.Dataset): + result = result.get(name).rename(obj.name) + + return result diff --git a/pint_xarray/conversion.py b/pint_xarray/conversion.py index 26ba6afd..1940e492 100644 --- a/pint_xarray/conversion.py +++ b/pint_xarray/conversion.py @@ -3,11 +3,13 @@ import pint from xarray import DataArray, Dataset, IndexVariable, Variable +from .compat import call_on_dataset from .errors import format_error_message no_unit_values = ("none", None) unit_attribute_name = "units" slice_attributes = ("start", "stop", "step") +temporary_name = "" def array_attach_units(data, unit): @@ -107,40 +109,49 @@ def attach_units_variable(variable, units): return new_obj +def dataset_from_variables(variables, coords, attrs): + data_vars = {name: var for name, var in variables.items() if name not in coords} + coords = {name: var for name, var in variables.items() if name in coords} + + return Dataset(data_vars=data_vars, coords=coords, attrs=attrs) + + +def attach_units_dataset(obj, units): + attached = {} + rejected_vars = {} + for name, var in obj.variables.items(): + unit = units.get(name) + try: + converted = attach_units_variable(var, unit) + attached[name] = converted + except ValueError as e: + rejected_vars[name] = (unit, e) + + if rejected_vars: + raise ValueError(rejected_vars) + + return dataset_from_variables(attached, obj._coord_names, obj.attrs) + + def attach_units(obj, units): + if not isinstance(obj, (DataArray, Dataset)): + raise ValueError(f"cannot attach units to {obj!r}: unknown type") + if isinstance(obj, DataArray): - old_name = obj.name - new_name = old_name if old_name is not None else "" - ds = obj.rename(new_name).to_dataset() units = units.copy() - units[new_name] = units.get(old_name) + if obj.name in units: + units[temporary_name] = units.get(obj.name) - new_ds = attach_units(ds, units) - new_obj = new_ds.get(new_name).rename(old_name) - elif isinstance(obj, Dataset): - attached = {} - rejected_vars = {} - for name, var in obj.variables.items(): - unit = units.get(name) - try: - converted = attach_units_variable(var, unit) - attached[name] = converted - except ValueError as e: - rejected_vars[name] = (unit, e) - - if rejected_vars: - raise ValueError(format_error_message(rejected_vars, "attach")) - - data_vars = { - name: var for name, var in attached.items() if name not in obj._coord_names - } - coords = { - name: var for name, var in attached.items() if name in obj._coord_names - } - - new_obj = Dataset(data_vars=data_vars, coords=coords, attrs=obj.attrs) - else: - raise ValueError(f"cannot attach units to {obj!r}: unknown type") + try: + new_obj = call_on_dataset( + attach_units_dataset, obj, name=temporary_name, units=units + ) + except ValueError as e: + (rejected_vars,) = e.args + if temporary_name in rejected_vars: + rejected_vars[obj.name] = rejected_vars.pop(temporary_name) + + raise ValueError(format_error_message(rejected_vars, "attach")) from e return new_obj @@ -192,87 +203,81 @@ def convert_units_variable(variable, units): return new_obj +def convert_units_dataset(obj, units): + converted = {} + failed = {} + for name, var in obj.variables.items(): + unit = units.get(name) + try: + converted[name] = convert_units_variable(var, unit) + except (ValueError, pint.errors.PintTypeError) as e: + failed[name] = e + + if failed: + raise ValueError(failed) + + return dataset_from_variables(converted, obj._coord_names, obj.attrs) + + def convert_units(obj, units): - if isinstance(obj, DataArray): - original_name = obj.name - name = obj.name if obj.name is not None else "" + if not isinstance(obj, (DataArray, Dataset)): + raise ValueError(f"cannot convert object: {obj!r}: unknown type") - units_ = units.copy() - if obj.name in units_: - units_[name] = units_[obj.name] + if isinstance(obj, DataArray): + units = units.copy() + if obj.name in units: + units[temporary_name] = units.pop(obj.name) - ds = obj.rename(name).to_dataset() - converted = convert_units(ds, units_) + try: + new_obj = call_on_dataset( + convert_units_dataset, obj, name=temporary_name, units=units + ) + except ValueError as e: + (failed,) = e.args + if temporary_name in failed: + failed[obj.name] = failed.pop(temporary_name) - new_obj = converted[name].rename(original_name) - elif isinstance(obj, Dataset): - converted = {} - failed = {} - for name, var in obj.variables.items(): - unit = units.get(name) - try: - converted[name] = convert_units_variable(var, unit) - except (ValueError, pint.errors.PintTypeError) as e: - failed[name] = e - - if failed: - raise ValueError(format_error_message(failed, "convert")) - - coords = { - name: var for name, var in converted.items() if name in obj._coord_names - } - data_vars = { - name: var for name, var in converted.items() if name not in obj._coord_names - } - - new_obj = Dataset(data_vars=data_vars, coords=coords, attrs=obj.attrs) - else: - raise ValueError(f"cannot convert object: {obj}") + raise ValueError(format_error_message(failed, "convert")) from e return new_obj -def extract_units(obj): - if isinstance(obj, Dataset): - units = extract_unit_attributes(obj) - dims = obj.dims - units.update( - { - name: array_extract_units(value.data) - for name, value in obj.variables.items() - if name not in dims - } - ) - elif isinstance(obj, DataArray): - original_name = obj.name - name = obj.name if obj.name is not None else "" +def extract_units_dataset(obj): + return {name: array_extract_units(var.data) for name, var in obj.variables.items()} - ds = obj.rename(name).to_dataset() - units = extract_units(ds) - units[original_name] = units.pop(name) - else: +def extract_units(obj): + if not isinstance(obj, (DataArray, Dataset)): raise ValueError(f"unknown type: {type(obj)}") - return units + unit_attributes = extract_unit_attributes(obj) + units = call_on_dataset(extract_units_dataset, obj, name=temporary_name) + if temporary_name in units: + units[obj.name] = units.pop(temporary_name) -def extract_unit_attributes(obj, attr="units"): - if isinstance(obj, DataArray): - original_name = obj.name - name = obj.name if obj.name is not None else "" + units_ = unit_attributes.copy() + units_.update({k: v for k, v in units.items() if v is not None}) - ds = obj.rename(name).to_dataset() + return units_ - units = extract_unit_attributes(ds) - units[original_name] = units.pop(name) - elif isinstance(obj, Dataset): - units = {name: var.attrs.get(attr, None) for name, var in obj.variables.items()} - else: + +def extract_unit_attributes_dataset(obj, attr="units"): + return {name: var.attrs.get(attr, None) for name, var in obj.variables.items()} + + +def extract_unit_attributes(obj, attr="units"): + if not isinstance(obj, (DataArray, Dataset)): raise ValueError( f"cannot retrieve unit attributes from unknown type: {type(obj)}" ) + units = call_on_dataset( + extract_unit_attributes_dataset, obj, name=temporary_name, attr=attr + ) + if temporary_name in units: + units[obj.name] = units.pop(temporary_name) + return units @@ -281,51 +286,34 @@ def strip_units_variable(var): return var.copy(data=data) -def strip_units(obj): - if isinstance(obj, DataArray): - original_name = obj.name - name = obj.name if obj.name is not None else "" - ds = obj.rename(name).to_dataset() - stripped = strip_units(ds) +def strip_units_dataset(obj): + variables = {name: strip_units_variable(var) for name, var in obj.variables.items()} - new_obj = stripped[name].rename(original_name) - elif isinstance(obj, Dataset): - data_vars = { - name: strip_units_variable(variable) - for name, variable in obj.variables.items() - if name not in obj._coord_names - } - coords = { - name: strip_units_variable(variable) - for name, variable in obj.variables.items() - if name in obj._coord_names - } - - new_obj = Dataset(data_vars=data_vars, coords=coords, attrs=obj.attrs) - else: + return dataset_from_variables(variables, obj._coord_names, obj.attrs) + + +def strip_units(obj): + if not isinstance(obj, (DataArray, Dataset)): raise ValueError("cannot strip units from {obj!r}: unknown type") - return new_obj + return call_on_dataset(strip_units_dataset, obj, name=temporary_name) -def strip_unit_attributes(obj, attr="units"): - if isinstance(obj, DataArray): - original_name = obj.name - name = obj.name if obj.name is not None else "" +def strip_unit_attributes_dataset(obj, attr="units"): + new_obj = obj.copy() + for var in new_obj.variables.values(): + var.attrs.pop(attr, None) - ds = obj.rename(name).to_dataset() + return new_obj - stripped = strip_unit_attributes(ds) - new_obj = stripped[name].rename(original_name) - elif isinstance(obj, Dataset): - new_obj = obj.copy() - for var in new_obj.variables.values(): - var.attrs.pop(attr, None) - else: +def strip_unit_attributes(obj, attr="units"): + if not isinstance(obj, (DataArray, Dataset)): raise ValueError(f"cannot strip unit attributes from unknown type: {type(obj)}") - return new_obj + return call_on_dataset( + strip_unit_attributes_dataset, obj, name=temporary_name, attr=attr + ) def slice_extract_units(indexer): diff --git a/pint_xarray/tests/test_conversion.py b/pint_xarray/tests/test_conversion.py index 0919f12a..6386b60c 100644 --- a/pint_xarray/tests/test_conversion.py +++ b/pint_xarray/tests/test_conversion.py @@ -208,6 +208,7 @@ class TestXarrayFunctions: @pytest.mark.parametrize( "units", ( + pytest.param({}, id="empty units"), pytest.param({"a": None, "b": None, "u": None, "x": None}, id="no units"), pytest.param( {"a": unit_registry.m, "b": unit_registry.m, "u": None, "x": None},