diff --git a/CHANGELOG.rst b/CHANGELOG.rst index b4da0f3a3..099d5a294 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -46,6 +46,12 @@ changes - improve dimension handling of `SpatialData` `[#622] `__ +- The `MathematicalExpression` now supports `xarray.DataArray` as + parameters. Furthermore, multidimensional parameters of a + `MathematicalExpression` that is passed to a `TimeSeries` are + no longer required to have an extra dimension that represents time. + `[#621] `__ + fixes ===== diff --git a/tutorials/experiment_design_01.ipynb b/tutorials/experiment_design_01.ipynb index 274208498..7b10ef291 100644 --- a/tutorials/experiment_design_01.ipynb +++ b/tutorials/experiment_design_01.ipynb @@ -234,11 +234,11 @@ "metadata": {}, "outputs": [], "source": [ - "sine_y = sine(f=Q_(1, \"Hz\"), amp=Q_([[0, 1, 0]], \"mm\"))\n", - "csm.add_cs(\n", + "sine_y = sine(f=Q_(1, \"Hz\"), amp=Q_([0, 1, 0], \"mm\"))\n", + "csm.create_cs(\n", " coordinate_system_name=\"tcp_sine_y\",\n", " reference_system_name=\"tcp_wire\",\n", - " lcs=LCS(coordinates=sine_y),\n", + " coordinates=sine_y,\n", ")" ] }, @@ -255,11 +255,11 @@ "metadata": {}, "outputs": [], "source": [ - "sine_z = sine(f=Q_(1, \"Hz\"), amp=Q_([[0, 0, 2]], \"mm\"), bias=Q_([0, 0, 0], \"mm\"))\n", - "csm.add_cs(\n", + "sine_z = sine(f=Q_(1, \"Hz\"), amp=Q_([0, 0, 2], \"mm\"), bias=Q_([0, 0, 0], \"mm\"))\n", + "csm.create_cs(\n", " coordinate_system_name=\"tcp_sine_z\",\n", " reference_system_name=\"tcp_wire\",\n", - " lcs=LCS(coordinates=sine_z),\n", + " coordinates=sine_z,\n", ")" ] }, diff --git a/tutorials/timeseries_01.ipynb b/tutorials/timeseries_01.ipynb index 1556e605b..fbe52eb88 100644 --- a/tutorials/timeseries_01.ipynb +++ b/tutorials/timeseries_01.ipynb @@ -338,8 +338,8 @@ "source": [ "expr_string = \"a*sin(o*t)+b*t\"\n", "parameters = {\n", - " \"a\": Q_(np.asarray([[1, 0, 0]]), \"m\"),\n", - " \"b\": Q_([[0, 1, 0]], \"m/s\"),\n", + " \"a\": Q_(np.asarray([1, 0, 0]), \"m\"),\n", + " \"b\": Q_([0, 1, 0], \"m/s\"),\n", " \"o\": Q_(\"36 deg/s\"),\n", "}\n", "expr = MathematicalExpression(expression=expr_string, parameters=parameters)\n", @@ -383,7 +383,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.8" + "version": "3.8.11" } }, "nbformat": 4, diff --git a/tutorials/welding_example_02_weaving.ipynb b/tutorials/welding_example_02_weaving.ipynb index d7b0b4626..224d08e48 100644 --- a/tutorials/welding_example_02_weaving.ipynb +++ b/tutorials/welding_example_02_weaving.ipynb @@ -252,7 +252,7 @@ "metadata": {}, "outputs": [], "source": [ - "ts_sine = sine(f=Q_(0.5 * 2 * np.pi, \"Hz\"), amp=Q_([[0, 0.75, 0]], \"mm\"))" + "ts_sine = sine(f=Q_(0.5 * 2 * np.pi, \"Hz\"), amp=Q_([0, 0.75, 0], \"mm\"))" ] }, { @@ -488,7 +488,7 @@ "metadata": {}, "outputs": [], "source": [ - "ts_sine = sine(f=Q_(1 / 8 * 2 * np.pi, \"Hz\"), amp=Q_([[0, 0, 1]], \"mm\"))" + "ts_sine = sine(f=Q_(1 / 8 * 2 * np.pi, \"Hz\"), amp=Q_([0, 0, 1], \"mm\"))" ] }, { @@ -622,7 +622,7 @@ "kernelspec": { "display_name": "weldx", "language": "python", - "name": "weldx-dev" + "name": "weldx" }, "language_info": { "codemirror_mode": { @@ -634,7 +634,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.2" + "version": "3.8.11" } }, "nbformat": 4, diff --git a/weldx/core.py b/weldx/core.py index c05a7158d..c15033499 100644 --- a/weldx/core.py +++ b/weldx/core.py @@ -20,13 +20,16 @@ __all__ = ["MathematicalExpression", "TimeSeries"] +_me_parameter_types = Union[pint.Quantity, str, Tuple[pint.Quantity, str], xr.DataArray] + + class MathematicalExpression: """Mathematical expression using sympy syntax.""" def __init__( self, expression: Union[sympy.Expr, str], - parameters: Dict[str, Union[pint.Quantity, str]] = None, + parameters: _me_parameter_types = None, ): """Construct a MathematicalExpression. @@ -44,23 +47,14 @@ def __init__( if not isinstance(expression, sympy.Expr): expression = sympy.sympify(expression) self._expression = expression + self.function = sympy.lambdify( tuple(self._expression.free_symbols), self._expression, "numpy" ) - self._parameters = {} + + self._parameters: Union[pint.Quantity, xr.DataArray] = {} if parameters is not None: - if not isinstance(parameters, dict): - raise ValueError( - f'"parameters" must be dictionary, got {type(parameters)}' - ) - parameters = {k: Q_(v) for k, v in parameters.items()} - variable_names = self.get_variable_names() - for key in parameters: - if key not in variable_names: - raise ValueError( - f'The expression does not have a parameter "{key}"' - ) - self._parameters = parameters + self.set_parameters(parameters) def __repr__(self): """Give __repr__ output.""" @@ -155,14 +149,30 @@ def set_parameter(self, name, value): Parameter value. This can be number, array or pint.Quantity """ - if not isinstance(name, str): - raise TypeError(f'Parameter "name" must be a string, got {type(name)}') - if name not in str(self._expression.free_symbols): - raise ValueError( - f'The expression "{self._expression}" does not have a ' - f'parameter with name "{name}".' - ) - self._parameters[name] = value + self.set_parameters({name: value}) + + def set_parameters(self, params: _me_parameter_types): + """Set the expressions parameters. + + Parameters + ---------- + params: + Dictionary that contains the values for the specified parameters. + + """ + if not isinstance(params, dict): + raise ValueError(f'"parameters" must be dictionary, got {type(params)}') + + variable_names = [str(v) for v in self._expression.free_symbols] + + for k, v in params.items(): + if k not in variable_names: + raise ValueError(f'The expression does not have a parameter "{k}"') + if isinstance(v, tuple): + v = xr.DataArray(v[0], dims=v[1]) + if not isinstance(v, xr.DataArray): + v = Q_(v) + self._parameters[k] = v @property def num_parameters(self): @@ -246,8 +256,18 @@ def evaluate(self, **kwargs) -> Any: raise ValueError( f"The variables {intersection} are already defined as parameters." ) - inputs = {**kwargs, **self._parameters} - return self.function(**inputs) + + variables = { + k: v if isinstance(v, xr.DataArray) else xr.DataArray(Q_(v)) + for k, v in kwargs.items() + } + + parameters = { + k: v if isinstance(v, xr.DataArray) else xr.DataArray(v) + for k, v in self._parameters.items() + } + + return self.function(**variables, **parameters) # TimeSeries --------------------------------------------------------------------------- @@ -368,6 +388,8 @@ def _check_data_array(data_array: xr.DataArray): def _create_data_array( data: Union[pint.Quantity, xr.DataArray], time: Time ) -> xr.DataArray: + if isinstance(data, xr.DataArray): + return data return ( xr.DataArray(data=data) .rename({"dim_0": "time"}) @@ -416,7 +438,7 @@ def _init_expression(self, data): # check that the expression can be evaluated with a time quantity time_var_name = data.get_variable_names()[0] try: - eval_data = data.evaluate(**{time_var_name: Q_(1, "second")}) + eval_data = data.evaluate(**{time_var_name: Q_(1, "second")}).data self._units = eval_data.units if np.iterable(eval_data): self._shape = eval_data.shape @@ -450,7 +472,7 @@ def _interp_time_discrete(self, time: Time) -> xr.DataArray: """Interpolate the time series if its data is composed of discrete values.""" return ut.xr_interp_like( self._data, - {"time": time.as_timedelta()}, + {"time": time.as_data_array()}, method=self.interpolation, assume_sorted=False, broadcast_missing=False, @@ -459,20 +481,14 @@ def _interp_time_discrete(self, time: Time) -> xr.DataArray: def _interp_time_expression(self, time: Time, time_unit: str) -> xr.DataArray: """Interpolate the time series if its data is a mathematical expression.""" time_q = time.as_quantity(unit=time_unit) + if len(time_q.shape) == 0: + time_q = np.expand_dims(time_q, 0) - if len(self.shape) > 1 and np.iterable(time_q): - while len(time_q.shape) < len(self.shape): - time_q = time_q[:, np.newaxis] + time_xr = xr.DataArray(time_q, dims=["time"]) # evaluate expression - data = self._data.evaluate(**{self._time_var_name: time_q}) - data = data.astype(float).to_reduced_units() # float conversion before reduce! - - # create data array - if not np.iterable(data): # make sure quantity is not scalar value - data = np.expand_dims(data, 0) - - return self._create_data_array(data, time) + data = self._data.evaluate(**{self._time_var_name: time_xr}) + return data.assign_coords({"time": time.as_data_array()}) @property def data(self) -> Union[pint.Quantity, MathematicalExpression]: @@ -602,10 +618,11 @@ def interp_time( if isinstance(self._data, xr.DataArray): dax = self._interp_time_discrete(time_interp) + ts = TimeSeries(data=dax.data, time=time, interpolation=self.interpolation) else: dax = self._interp_time_expression(time_interp, time_unit) + ts = TimeSeries(data=dax, interpolation=self.interpolation) - ts = TimeSeries(data=dax.data, time=time, interpolation=self.interpolation) ts._interp_counter = self._interp_counter + 1 return ts diff --git a/weldx/measurement.py b/weldx/measurement.py index 7d8c3058e..0b5092689 100644 --- a/weldx/measurement.py +++ b/weldx/measurement.py @@ -546,7 +546,7 @@ def _determine_output_signal_unit( "The provided function is incompatible with the input signals unit." f" \nThe test raised the following exception:\n{e}" ) - return test_output.units + return test_output.data.units return input_unit diff --git a/weldx/tags/core/mathematical_expression.py b/weldx/tags/core/mathematical_expression.py index 7da76535a..83f390267 100644 --- a/weldx/tags/core/mathematical_expression.py +++ b/weldx/tags/core/mathematical_expression.py @@ -1,4 +1,7 @@ +import warnings + import sympy +from xarray import DataArray from weldx.asdf.types import WeldxConverter from weldx.core import MathematicalExpression @@ -15,12 +18,29 @@ class MathematicalExpressionConverter(WeldxConverter): def to_yaml_tree(self, obj: MathematicalExpression, tag: str, ctx) -> dict: """Convert to python dict.""" - tree = {"expression": obj.expression.__str__(), "parameters": obj.parameters} - return tree + parameters = {} + for k, v in obj.parameters.items(): + if isinstance(v, DataArray): + if len(v.coords) > 0: + warnings.warn("Coordinates are dropped during serialization.") + dims = v.dims + v = v.data + v.wx_metadata = dict(dims=dims) + parameters[k] = v + + return {"expression": obj.expression.__str__(), "parameters": parameters} def from_yaml_tree(self, node: dict, tag: str, ctx): """Construct from tree.""" - obj = MathematicalExpression( - sympy.sympify(node["expression"]), parameters=node["parameters"] + + parameters = {} + for k, v in node["parameters"].items(): + if hasattr(v, "wx_metadata"): + dims = v.wx_metadata["dims"] + delattr(v, "wx_metadata") + v = (v, dims) + parameters[k] = v + + return MathematicalExpression( + sympy.sympify(node["expression"]), parameters=parameters ) - return obj diff --git a/weldx/tests/asdf_tests/test_asdf_core.py b/weldx/tests/asdf_tests/test_asdf_core.py index 88ae9c8a8..1af2bd523 100644 --- a/weldx/tests/asdf_tests/test_asdf_core.py +++ b/weldx/tests/asdf_tests/test_asdf_core.py @@ -753,3 +753,34 @@ def test_graph_serialization(): assert all(e in g.edges for e in g2.edges) assert all(n in g.nodes for n in g2.nodes) + + +# -------------------------------------------------------------------------------------- +# MathematicalExpression +# -------------------------------------------------------------------------------------- + + +class TestMathematicalExpression: + @staticmethod + @pytest.mark.parametrize( + "a, b", + [ + (Q_([1, 2, 3], "m"), Q_([4, 5, 6], "m")), + ( + xr.DataArray(Q_([1, 2], "m"), dims=["a"]), + xr.DataArray(Q_([3, 4], "m"), dims=["b"]), + ), + ( + Q_([1, 2], "m"), + xr.DataArray(Q_([3, 4], "m"), dims=["b"]), + ), + ], + ) + def test_parameters(a, b): + expression = "a*x + b" + parameters = dict(a=a, b=b) + + me = ME(expression, parameters) + me_2 = write_read_buffer({"me": me})["me"] + + assert me == me_2 diff --git a/weldx/tests/test_core.py b/weldx/tests/test_core.py index 93f393146..1d5584225 100644 --- a/weldx/tests/test_core.py +++ b/weldx/tests/test_core.py @@ -110,7 +110,7 @@ def test_set_parameter(self): "name, value, exception_type, test_name", [ ("k", 1, ValueError, "# parameter not in expression"), - (33, 1, TypeError, "# wrong type as name #1"), + (33, 1, ValueError, "# wrong type as name #1"), ({"a": 1}, 1, TypeError, "# wrong type as name #2"), ], ids=get_test_name, @@ -220,7 +220,7 @@ class TestTimeSeries: me_expr_str = "a*t + b" me_params = {"a": Q_(2, "m/s"), "b": Q_(-2, "m")} - me_params_vec = {"a": Q_([[2, 0, 1]], "m/s"), "b": Q_([[-2, 3, 0]], "m")} + me_params_vec = {"a": Q_([2, 0, 1], "m/s"), "b": Q_([-2, 3, 0], "m")} ts_constant = TimeSeries(value_constant) ts_disc_step = TimeSeries(values_discrete, time_discrete, "step") @@ -322,7 +322,6 @@ def test_init_data_array(data, dims, coords, exception_type): time_def = Q_([0, 1, 2, 3, 4], "s") me_too_many_vars = ME("a*t + b", {}) me_param_units = ME("a*t + b", {"a": Q_(2, "1/s"), "b": Q_(-2, "m")}) - me_time_vec = ME("a*t + b", {"a": Q_([2, 3, 4], "1/s"), "b": Q_([-2, 3, 1], "")}) @staticmethod @pytest.mark.parametrize( @@ -332,7 +331,6 @@ def test_init_data_array(data, dims, coords, exception_type): (values_def, time_def.magnitude, "step", TypeError, "# invalid time type"), (me_too_many_vars, None, None, Exception, "# too many free variables"), (me_param_units, None, None, Exception, "# incompatible parameter units"), - (me_time_vec, None, None, Exception, "# not compatible with time vectors"), ("a string", None, None, TypeError, "# wrong data type"), ], ids=get_test_name, diff --git a/weldx/tests/test_time.py b/weldx/tests/test_time.py index 8f34f83cd..03aafc191 100644 --- a/weldx/tests/test_time.py +++ b/weldx/tests/test_time.py @@ -680,7 +680,7 @@ def test_convert_util(): assert np.all(time_q == Q_(range(10), "s")) assert time_q.time_ref == ts - arr2 = time.as_data_array().weldx.time_ref_restore() + arr2 = time.as_data_array() assert arr.time.identical(arr2.time) # test_duration -------------------------------------------------------------------- diff --git a/weldx/tests/transformations/test_cs_manager.py b/weldx/tests/transformations/test_cs_manager.py index 40dce1285..e6af3c693 100644 --- a/weldx/tests/transformations/test_cs_manager.py +++ b/weldx/tests/transformations/test_cs_manager.py @@ -1438,7 +1438,7 @@ def test_get_local_coordinate_system_timeseries( The expected rotation angles around the z-axis """ - me = MathematicalExpression("a*t", {"a": Q_([[0, 1, 0]], "mm/s")}) + me = MathematicalExpression("a*t", {"a": Q_([0, 1, 0], "mm/s")}) ts = TimeSeries(me) rotation = WXRotation.from_euler("z", [0, 90], degrees=True).as_matrix() translation = [[1, 0, 0], [2, 0, 0]] @@ -1535,7 +1535,7 @@ def test_get_cs_exception_timeseries(lcs, in_lcs, exp_exception): Set to `True` if the transformation should raise """ - me = MathematicalExpression("a*t", {"a": Q_([[0, 1, 0]], "mm/s")}) + me = MathematicalExpression("a*t", {"a": Q_([0, 1, 0], "mm/s")}) ts = TimeSeries(me) translation = [[1, 0, 0], [2, 0, 0]] diff --git a/weldx/tests/transformations/test_local_cs.py b/weldx/tests/transformations/test_local_cs.py index 3dd43eb02..60026b5da 100644 --- a/weldx/tests/transformations/test_local_cs.py +++ b/weldx/tests/transformations/test_local_cs.py @@ -175,7 +175,7 @@ def test_init_expr_time_series_as_coord(time, time_ref, angles): """ coordinates = MathematicalExpression( - expression="a*t+b", parameters=dict(a=Q_([[1, 0, 0]], "1/s"), b=[1, 2, 3]) + expression="a*t+b", parameters=dict(a=Q_([1, 0, 0], "1/s"), b=[1, 2, 3]) ) ts_coord = TimeSeries(data=coordinates) @@ -607,7 +607,7 @@ def test_interp_time_timeseries_as_coords( # create expression expr = "a*t+b" - param = dict(a=Q_([[1, 0, 0]], "mm/s"), b=Q_([1, 1, 1], "mm")) + param = dict(a=Q_([1, 0, 0], "mm/s"), b=Q_([1, 1, 1], "mm")) me = MathematicalExpression(expression=expr, parameters=param) # create orientation and time of LCS diff --git a/weldx/time.py b/weldx/time.py index bfc0e0e16..1f6f209b7 100644 --- a/weldx/time.py +++ b/weldx/time.py @@ -490,9 +490,22 @@ def as_pandas_index(self) -> Union[pd.TimedeltaIndex, pd.DatetimeIndex]: return pd.TimedeltaIndex([self._time]) return self._time - def as_data_array(self) -> DataArray: - """Return the data as `xarray.DataArray`.""" - da = xr.DataArray(self._time, coords={"time": self._time}, dims=["time"]) + def as_data_array(self, timedelta_base: bool = True) -> DataArray: + """Return the time data as a `xarray.DataArray` coordinate. + + By default the format is timedelta values with reference time as attribute. + + Parameters + ---------- + timedelta_base + If true (the default) the values of the xarray will always be timedeltas. + + """ + if timedelta_base: + t = self.as_timedelta_index() + else: + t = self.index + da = xr.DataArray(t, coords={"time": t}, dims=["time"]) da.time.attrs["time_ref"] = self.reference_time return da