Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

custom wx_unit validation behavior #739

Merged
merged 6 commits into from
Apr 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ ASDF

- update the default sorting order of ``select_tag`` for ``WeldxConverter`` [:pull:`733`]

- add custom validation behavior to ``wx_unit`` [:pull:`739`]

deprecations
============

Expand Down
31 changes: 31 additions & 0 deletions weldx/asdf/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from warnings import warn

import asdf
import pint
from asdf.asdf import SerializationContext
from asdf.config import AsdfConfig, get_config
from asdf.extension import Extension
Expand All @@ -18,6 +19,7 @@

from weldx.asdf.constants import SCHEMA_PATH, WELDX_EXTENSION_URI
from weldx.asdf.types import WeldxConverter
from weldx.constants import U_, UNITS_KEY, WELDX_UNIT_REGISTRY
from weldx.types import (
SupportsFileReadWrite,
types_file_like,
Expand Down Expand Up @@ -544,6 +546,35 @@ def _get_instance_shape(
return None


def _get_instance_units(
instance_dict: Union[TaggedDict, dict[str, Any]]
) -> Union[pint.Unit, None]:
"""Get the units of an ASDF instance from its tagged dict form.

Parameters
----------
instance_dict
The yaml node to evaluate.

Returns
-------
pint.Unit
The pint unit or `None` if no unit information is present.
"""
if isinstance(instance_dict, (float, int)): # base types
return WELDX_UNIT_REGISTRY.dimensionless
elif isinstance(instance_dict, Mapping) and UNITS_KEY in instance_dict:
return U_(str(instance_dict[UNITS_KEY])) # catch TaggedString as str
elif isinstance(instance_dict, asdf.types.tagged.Tagged):
# try calling units_from_tagged for custom types
if instance_dict._tag.startswith("tag:stsci.edu:asdf/core/ndarray"):
return WELDX_UNIT_REGISTRY.dimensionless
converter = get_converter_for_tag(instance_dict._tag)
if hasattr(converter, "units_from_tagged"):
return converter.units_from_tagged(instance_dict)
return None


class _ProtectedViewDict(MutableMapping):
def __init__(self, protected_keys, data=None):
super(_ProtectedViewDict, self).__init__()
Expand Down
18 changes: 12 additions & 6 deletions weldx/asdf/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from asdf.schema import _type_to_tag

from weldx.asdf.types import WxSyntaxError
from weldx.asdf.util import _get_instance_shape, uri_match
from weldx.asdf.util import _get_instance_shape, _get_instance_units, uri_match
from weldx.constants import U_

__all__ = ["wx_unit_validator", "wx_shape_validator", "wx_property_tag_validator"]
Expand Down Expand Up @@ -95,15 +95,21 @@ def _unit_validator(
if not position:
position = instance

unit = instance["units"]
unit = str(unit) # catch TaggedString
valid = U_(unit).is_compatible_with(U_(expected_dimensionality))
if not valid:
units = _get_instance_units(instance)
if units is None:
yield ValidationError(
f"Error validating unit dimension for property '{position}'. "
f"Expected unit of dimension '{expected_dimensionality}' "
f"but got unit '{unit}'"
"but found no unit information"
)
else:
valid = units.is_compatible_with(U_(expected_dimensionality))
if not valid:
yield ValidationError(
f"Error validating unit dimension for property '{position}'. "
f"Expected unit of dimension '{expected_dimensionality}' "
f"but got unit '{units}'"
)


def _compare(_int, exp_string):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ properties:
tag: "asdf://weldx.bam.de/weldx/tags/units/quantity-0.1.*"
wx_unit: "Δ°C"

dimensionless:
description: a basic numeric type without any units or dimensionless quantity
wx_unit: " "

required: [length_prop, velocity_prop, current_prop, delta_prop]
propertyOrder: [length_prop, velocity_prop, current_prop]
flowStyle: block
Expand Down
8 changes: 7 additions & 1 deletion weldx/tags/core/data_array.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""Serialization for xarray.DataArray."""
from __future__ import annotations

import pint
from asdf.tagged import TaggedDict
from xarray import DataArray

import weldx.tags.core.common_types as ct
from weldx.asdf.types import WeldxConverter
from weldx.asdf.util import _get_instance_shape
from weldx.asdf.util import _get_instance_shape, _get_instance_units


class XarrayDataArrayConverter(WeldxConverter):
Expand Down Expand Up @@ -47,3 +48,8 @@ def from_yaml_tree(self, node: dict, tag: str, ctx):
def shape_from_tagged(node: TaggedDict) -> list[int]:
"""Calculate the shape from static tagged tree instance."""
return _get_instance_shape(node["data"]["data"])

@staticmethod
def units_from_tagged(node: TaggedDict) -> pint.Unit:
"""Get the units information from static tagged tree instance."""
return _get_instance_units(node["data"])
3 changes: 3 additions & 0 deletions weldx/tags/debug/test_unit_validator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from dataclasses import dataclass, field
from typing import Union

import numpy as np
import pint
import xarray as xr

from weldx.asdf.util import dataclass_serialization_class
from weldx.constants import Q_
Expand All @@ -21,6 +23,7 @@ class UnitValidatorTestClass:
)
simple_prop: dict = field(default_factory=lambda: dict(value=float(3), units="m"))
delta_prop: dict = Q_(100, "Δ°C")
dimensionless: Union[float, int, np.ndarray, pint.Quantity, xr.DataArray] = 3.14


UnitValidatorTestClassConverter = dataclass_serialization_class(
Expand Down
10 changes: 10 additions & 0 deletions weldx/tests/asdf_tests/test_asdf_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
import pandas as pd
import pytest
import xarray as xr
from asdf import ValidationError

from weldx.asdf.types import WxSyntaxError
Expand Down Expand Up @@ -179,6 +180,9 @@ def test_shape_validator_exceptions(test_input):
[
UnitValidatorTestClass(),
UnitValidatorTestClass(length_prop=Q_(1, "inch")),
UnitValidatorTestClass(dimensionless=Q_(1, " ")),
UnitValidatorTestClass(dimensionless=np.ones(3)),
UnitValidatorTestClass(dimensionless=xr.DataArray(data=Q_(np.ones(3)))),
],
)
def test_unit_validator(test):
Expand Down Expand Up @@ -211,6 +215,12 @@ def test_unit_validator(test):
UnitValidatorTestClass(
simple_prop={"value": float(3), "units": "s"}, # wrong unit
),
UnitValidatorTestClass(
dimensionless=Q_(1, "s"), # wrong unit
),
UnitValidatorTestClass(
dimensionless=xr.DataArray(data=np.ones(3)), # xarray must be quantity
),
],
)
def test_unit_validator_exception(test):
Expand Down