Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add xarray support to MathematicalExpression #621

Merged
merged 26 commits into from
Nov 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ changes
- improve dimension handling of `SpatialData` `[#622]
<https://github.com/BAMWelDX/weldx/pull/622>`__

- The `MathematicalExpression` now supports `xarray.DataArray` as
parameters. Furthermore, multidimensional parameters of a
`MathematicalExpression` that is passed to a `TimeSeries` are
no longer required to have an extra dimension that represents time.
`[#621] <https://github.com/BAMWelDX/weldx/pull/621>`__

fixes
=====

Expand Down
12 changes: 6 additions & 6 deletions tutorials/experiment_design_01.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -234,11 +234,11 @@
"metadata": {},
"outputs": [],
"source": [
"sine_y = sine(f=Q_(1, \"Hz\"), amp=Q_([[0, 1, 0]], \"mm\"))\n",
"csm.add_cs(\n",
"sine_y = sine(f=Q_(1, \"Hz\"), amp=Q_([0, 1, 0], \"mm\"))\n",
"csm.create_cs(\n",
" coordinate_system_name=\"tcp_sine_y\",\n",
" reference_system_name=\"tcp_wire\",\n",
" lcs=LCS(coordinates=sine_y),\n",
" coordinates=sine_y,\n",
")"
]
},
Expand All @@ -255,11 +255,11 @@
"metadata": {},
"outputs": [],
"source": [
"sine_z = sine(f=Q_(1, \"Hz\"), amp=Q_([[0, 0, 2]], \"mm\"), bias=Q_([0, 0, 0], \"mm\"))\n",
"csm.add_cs(\n",
"sine_z = sine(f=Q_(1, \"Hz\"), amp=Q_([0, 0, 2], \"mm\"), bias=Q_([0, 0, 0], \"mm\"))\n",
"csm.create_cs(\n",
" coordinate_system_name=\"tcp_sine_z\",\n",
" reference_system_name=\"tcp_wire\",\n",
" lcs=LCS(coordinates=sine_z),\n",
" coordinates=sine_z,\n",
")"
]
},
Expand Down
6 changes: 3 additions & 3 deletions tutorials/timeseries_01.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -338,8 +338,8 @@
"source": [
"expr_string = \"a*sin(o*t)+b*t\"\n",
"parameters = {\n",
" \"a\": Q_(np.asarray([[1, 0, 0]]), \"m\"),\n",
" \"b\": Q_([[0, 1, 0]], \"m/s\"),\n",
" \"a\": Q_(np.asarray([1, 0, 0]), \"m\"),\n",
" \"b\": Q_([0, 1, 0], \"m/s\"),\n",
" \"o\": Q_(\"36 deg/s\"),\n",
"}\n",
"expr = MathematicalExpression(expression=expr_string, parameters=parameters)\n",
Expand Down Expand Up @@ -383,7 +383,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
"version": "3.8.11"
}
},
"nbformat": 4,
Expand Down
8 changes: 4 additions & 4 deletions tutorials/welding_example_02_weaving.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@
"metadata": {},
"outputs": [],
"source": [
"ts_sine = sine(f=Q_(0.5 * 2 * np.pi, \"Hz\"), amp=Q_([[0, 0.75, 0]], \"mm\"))"
"ts_sine = sine(f=Q_(0.5 * 2 * np.pi, \"Hz\"), amp=Q_([0, 0.75, 0], \"mm\"))"
]
},
{
Expand Down Expand Up @@ -488,7 +488,7 @@
"metadata": {},
"outputs": [],
"source": [
"ts_sine = sine(f=Q_(1 / 8 * 2 * np.pi, \"Hz\"), amp=Q_([[0, 0, 1]], \"mm\"))"
"ts_sine = sine(f=Q_(1 / 8 * 2 * np.pi, \"Hz\"), amp=Q_([0, 0, 1], \"mm\"))"
]
},
{
Expand Down Expand Up @@ -622,7 +622,7 @@
"kernelspec": {
"display_name": "weldx",
"language": "python",
"name": "weldx-dev"
"name": "weldx"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -634,7 +634,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.2"
"version": "3.8.11"
}
},
"nbformat": 4,
Expand Down
93 changes: 55 additions & 38 deletions weldx/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,16 @@
__all__ = ["MathematicalExpression", "TimeSeries"]


_me_parameter_types = Union[pint.Quantity, str, Tuple[pint.Quantity, str], xr.DataArray]


class MathematicalExpression:
"""Mathematical expression using sympy syntax."""

def __init__(
self,
expression: Union[sympy.Expr, str],
parameters: Dict[str, Union[pint.Quantity, str]] = None,
parameters: _me_parameter_types = None,
):
"""Construct a MathematicalExpression.

Expand All @@ -44,23 +47,14 @@ def __init__(
if not isinstance(expression, sympy.Expr):
expression = sympy.sympify(expression)
self._expression = expression

self.function = sympy.lambdify(
tuple(self._expression.free_symbols), self._expression, "numpy"
)
self._parameters = {}

self._parameters: Union[pint.Quantity, xr.DataArray] = {}
if parameters is not None:
if not isinstance(parameters, dict):
raise ValueError(
f'"parameters" must be dictionary, got {type(parameters)}'
)
parameters = {k: Q_(v) for k, v in parameters.items()}
variable_names = self.get_variable_names()
for key in parameters:
if key not in variable_names:
raise ValueError(
f'The expression does not have a parameter "{key}"'
)
self._parameters = parameters
self.set_parameters(parameters)

def __repr__(self):
"""Give __repr__ output."""
Expand Down Expand Up @@ -155,14 +149,30 @@ def set_parameter(self, name, value):
Parameter value. This can be number, array or pint.Quantity

"""
if not isinstance(name, str):
raise TypeError(f'Parameter "name" must be a string, got {type(name)}')
if name not in str(self._expression.free_symbols):
raise ValueError(
f'The expression "{self._expression}" does not have a '
f'parameter with name "{name}".'
)
self._parameters[name] = value
self.set_parameters({name: value})

def set_parameters(self, params: _me_parameter_types):
"""Set the expressions parameters.

Parameters
----------
params:
Dictionary that contains the values for the specified parameters.

"""
if not isinstance(params, dict):
raise ValueError(f'"parameters" must be dictionary, got {type(params)}')

variable_names = [str(v) for v in self._expression.free_symbols]

for k, v in params.items():
if k not in variable_names:
raise ValueError(f'The expression does not have a parameter "{k}"')
if isinstance(v, tuple):
v = xr.DataArray(v[0], dims=v[1])
if not isinstance(v, xr.DataArray):
v = Q_(v)
self._parameters[k] = v

@property
def num_parameters(self):
Expand Down Expand Up @@ -246,8 +256,18 @@ def evaluate(self, **kwargs) -> Any:
raise ValueError(
f"The variables {intersection} are already defined as parameters."
)
inputs = {**kwargs, **self._parameters}
return self.function(**inputs)

variables = {
k: v if isinstance(v, xr.DataArray) else xr.DataArray(Q_(v))
for k, v in kwargs.items()
}

parameters = {
k: v if isinstance(v, xr.DataArray) else xr.DataArray(v)
for k, v in self._parameters.items()
}

return self.function(**variables, **parameters)


# TimeSeries ---------------------------------------------------------------------------
Expand Down Expand Up @@ -368,6 +388,8 @@ def _check_data_array(data_array: xr.DataArray):
def _create_data_array(
data: Union[pint.Quantity, xr.DataArray], time: Time
) -> xr.DataArray:
if isinstance(data, xr.DataArray):
return data
return (
xr.DataArray(data=data)
.rename({"dim_0": "time"})
Expand Down Expand Up @@ -416,7 +438,7 @@ def _init_expression(self, data):
# check that the expression can be evaluated with a time quantity
time_var_name = data.get_variable_names()[0]
try:
eval_data = data.evaluate(**{time_var_name: Q_(1, "second")})
eval_data = data.evaluate(**{time_var_name: Q_(1, "second")}).data
self._units = eval_data.units
if np.iterable(eval_data):
self._shape = eval_data.shape
Expand Down Expand Up @@ -450,7 +472,7 @@ def _interp_time_discrete(self, time: Time) -> xr.DataArray:
"""Interpolate the time series if its data is composed of discrete values."""
return ut.xr_interp_like(
self._data,
{"time": time.as_timedelta()},
{"time": time.as_data_array()},
method=self.interpolation,
assume_sorted=False,
broadcast_missing=False,
Expand All @@ -459,20 +481,14 @@ def _interp_time_discrete(self, time: Time) -> xr.DataArray:
def _interp_time_expression(self, time: Time, time_unit: str) -> xr.DataArray:
"""Interpolate the time series if its data is a mathematical expression."""
time_q = time.as_quantity(unit=time_unit)
if len(time_q.shape) == 0:
time_q = np.expand_dims(time_q, 0)

if len(self.shape) > 1 and np.iterable(time_q):
while len(time_q.shape) < len(self.shape):
time_q = time_q[:, np.newaxis]
time_xr = xr.DataArray(time_q, dims=["time"])

# evaluate expression
data = self._data.evaluate(**{self._time_var_name: time_q})
data = data.astype(float).to_reduced_units() # float conversion before reduce!

# create data array
if not np.iterable(data): # make sure quantity is not scalar value
data = np.expand_dims(data, 0)

return self._create_data_array(data, time)
data = self._data.evaluate(**{self._time_var_name: time_xr})
return data.assign_coords({"time": time.as_data_array()})

@property
def data(self) -> Union[pint.Quantity, MathematicalExpression]:
Expand Down Expand Up @@ -602,10 +618,11 @@ def interp_time(

if isinstance(self._data, xr.DataArray):
dax = self._interp_time_discrete(time_interp)
ts = TimeSeries(data=dax.data, time=time, interpolation=self.interpolation)
else:
dax = self._interp_time_expression(time_interp, time_unit)
ts = TimeSeries(data=dax, interpolation=self.interpolation)

ts = TimeSeries(data=dax.data, time=time, interpolation=self.interpolation)
ts._interp_counter = self._interp_counter + 1
return ts

Expand Down
2 changes: 1 addition & 1 deletion weldx/measurement.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,7 @@ def _determine_output_signal_unit(
"The provided function is incompatible with the input signals unit."
f" \nThe test raised the following exception:\n{e}"
)
return test_output.units
return test_output.data.units

return input_unit

Expand Down
30 changes: 25 additions & 5 deletions weldx/tags/core/mathematical_expression.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import warnings

import sympy
from xarray import DataArray

from weldx.asdf.types import WeldxConverter
from weldx.core import MathematicalExpression
Expand All @@ -15,12 +18,29 @@ class MathematicalExpressionConverter(WeldxConverter):

def to_yaml_tree(self, obj: MathematicalExpression, tag: str, ctx) -> dict:
"""Convert to python dict."""
tree = {"expression": obj.expression.__str__(), "parameters": obj.parameters}
return tree
parameters = {}
for k, v in obj.parameters.items():
if isinstance(v, DataArray):
if len(v.coords) > 0:
warnings.warn("Coordinates are dropped during serialization.")
dims = v.dims
v = v.data
v.wx_metadata = dict(dims=dims)
parameters[k] = v

return {"expression": obj.expression.__str__(), "parameters": parameters}

def from_yaml_tree(self, node: dict, tag: str, ctx):
"""Construct from tree."""
obj = MathematicalExpression(
sympy.sympify(node["expression"]), parameters=node["parameters"]

parameters = {}
for k, v in node["parameters"].items():
if hasattr(v, "wx_metadata"):
dims = v.wx_metadata["dims"]
delattr(v, "wx_metadata")
v = (v, dims)
parameters[k] = v

return MathematicalExpression(
sympy.sympify(node["expression"]), parameters=parameters
)
return obj
31 changes: 31 additions & 0 deletions weldx/tests/asdf_tests/test_asdf_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,3 +753,34 @@ def test_graph_serialization():

assert all(e in g.edges for e in g2.edges)
assert all(n in g.nodes for n in g2.nodes)


# --------------------------------------------------------------------------------------
# MathematicalExpression
# --------------------------------------------------------------------------------------


class TestMathematicalExpression:
@staticmethod
@pytest.mark.parametrize(
"a, b",
[
(Q_([1, 2, 3], "m"), Q_([4, 5, 6], "m")),
(
xr.DataArray(Q_([1, 2], "m"), dims=["a"]),
xr.DataArray(Q_([3, 4], "m"), dims=["b"]),
),
(
Q_([1, 2], "m"),
xr.DataArray(Q_([3, 4], "m"), dims=["b"]),
),
],
)
def test_parameters(a, b):
expression = "a*x + b"
parameters = dict(a=a, b=b)

me = ME(expression, parameters)
me_2 = write_read_buffer({"me": me})["me"]

assert me == me_2
6 changes: 2 additions & 4 deletions weldx/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def test_set_parameter(self):
"name, value, exception_type, test_name",
[
("k", 1, ValueError, "# parameter not in expression"),
(33, 1, TypeError, "# wrong type as name #1"),
(33, 1, ValueError, "# wrong type as name #1"),
({"a": 1}, 1, TypeError, "# wrong type as name #2"),
],
ids=get_test_name,
Expand Down Expand Up @@ -220,7 +220,7 @@ class TestTimeSeries:
me_expr_str = "a*t + b"
me_params = {"a": Q_(2, "m/s"), "b": Q_(-2, "m")}

me_params_vec = {"a": Q_([[2, 0, 1]], "m/s"), "b": Q_([[-2, 3, 0]], "m")}
me_params_vec = {"a": Q_([2, 0, 1], "m/s"), "b": Q_([-2, 3, 0], "m")}
CagtayFabry marked this conversation as resolved.
Show resolved Hide resolved

ts_constant = TimeSeries(value_constant)
ts_disc_step = TimeSeries(values_discrete, time_discrete, "step")
Expand Down Expand Up @@ -322,7 +322,6 @@ def test_init_data_array(data, dims, coords, exception_type):
time_def = Q_([0, 1, 2, 3, 4], "s")
me_too_many_vars = ME("a*t + b", {})
me_param_units = ME("a*t + b", {"a": Q_(2, "1/s"), "b": Q_(-2, "m")})
me_time_vec = ME("a*t + b", {"a": Q_([2, 3, 4], "1/s"), "b": Q_([-2, 3, 1], "")})

@staticmethod
@pytest.mark.parametrize(
Expand All @@ -332,7 +331,6 @@ def test_init_data_array(data, dims, coords, exception_type):
(values_def, time_def.magnitude, "step", TypeError, "# invalid time type"),
(me_too_many_vars, None, None, Exception, "# too many free variables"),
(me_param_units, None, None, Exception, "# incompatible parameter units"),
(me_time_vec, None, None, Exception, "# not compatible with time vectors"),
("a string", None, None, TypeError, "# wrong data type"),
],
ids=get_test_name,
Expand Down
Loading