Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add xarray support to MathematicalExpression #621

Merged
merged 26 commits into from
Nov 3, 2021
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions tutorials/experiment_design_01.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -234,11 +234,11 @@
"metadata": {},
"outputs": [],
"source": [
"sine_y = sine(f=Q_(1, \"Hz\"), amp=Q_([[0, 1, 0]], \"mm\"))\n",
"csm.add_cs(\n",
"sine_y = sine(f=Q_(1, \"Hz\"), amp=Q_([0, 1, 0], \"mm\"))\n",
"csm.create_cs(\n",
" coordinate_system_name=\"tcp_sine_y\",\n",
" reference_system_name=\"tcp_wire\",\n",
" lcs=LCS(coordinates=sine_y),\n",
" coordinates=sine_y,\n",
")"
]
},
Expand All @@ -255,11 +255,11 @@
"metadata": {},
"outputs": [],
"source": [
"sine_z = sine(f=Q_(1, \"Hz\"), amp=Q_([[0, 0, 2]], \"mm\"), bias=Q_([0, 0, 0], \"mm\"))\n",
"csm.add_cs(\n",
"sine_z = sine(f=Q_(1, \"Hz\"), amp=Q_([0, 0, 2], \"mm\"), bias=Q_([0, 0, 0], \"mm\"))\n",
"csm.create_cs(\n",
" coordinate_system_name=\"tcp_sine_z\",\n",
" reference_system_name=\"tcp_wire\",\n",
" lcs=LCS(coordinates=sine_z),\n",
" coordinates=sine_z,\n",
")"
]
},
Expand Down
6 changes: 3 additions & 3 deletions tutorials/timeseries_01.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -338,8 +338,8 @@
"source": [
"expr_string = \"a*sin(o*t)+b*t\"\n",
"parameters = {\n",
" \"a\": Q_(np.asarray([[1, 0, 0]]), \"m\"),\n",
" \"b\": Q_([[0, 1, 0]], \"m/s\"),\n",
" \"a\": Q_(np.asarray([1, 0, 0]), \"m\"),\n",
" \"b\": Q_([0, 1, 0], \"m/s\"),\n",
" \"o\": Q_(\"36 deg/s\"),\n",
"}\n",
"expr = MathematicalExpression(expression=expr_string, parameters=parameters)\n",
Expand Down Expand Up @@ -383,7 +383,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
"version": "3.8.11"
}
},
"nbformat": 4,
Expand Down
8 changes: 4 additions & 4 deletions tutorials/welding_example_02_weaving.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@
"metadata": {},
"outputs": [],
"source": [
"ts_sine = sine(f=Q_(0.5 * 2 * np.pi, \"Hz\"), amp=Q_([[0, 0.75, 0]], \"mm\"))"
"ts_sine = sine(f=Q_(0.5 * 2 * np.pi, \"Hz\"), amp=Q_([0, 0.75, 0], \"mm\"))"
]
},
{
Expand Down Expand Up @@ -488,7 +488,7 @@
"metadata": {},
"outputs": [],
"source": [
"ts_sine = sine(f=Q_(1 / 8 * 2 * np.pi, \"Hz\"), amp=Q_([[0, 0, 1]], \"mm\"))"
"ts_sine = sine(f=Q_(1 / 8 * 2 * np.pi, \"Hz\"), amp=Q_([0, 0, 1], \"mm\"))"
]
},
{
Expand Down Expand Up @@ -622,7 +622,7 @@
"kernelspec": {
"display_name": "weldx",
"language": "python",
"name": "weldx-dev"
"name": "weldx"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -634,7 +634,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.2"
"version": "3.8.11"
}
},
"nbformat": 4,
Expand Down
81 changes: 46 additions & 35 deletions weldx/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,20 +47,29 @@ def __init__(
self.function = sympy.lambdify(
tuple(self._expression.free_symbols), self._expression, "numpy"
)
self._parameters = {}

self._parameters: Dict[str, Union[pint.Quantity, xr.DataArray]] = {}
if parameters is not None:
if not isinstance(parameters, dict):
raise ValueError(
f'"parameters" must be dictionary, got {type(parameters)}'
)
parameters = {k: Q_(v) for k, v in parameters.items()}
variable_names = self.get_variable_names()
for key in parameters:
if key not in variable_names:
raise ValueError(
f'The expression does not have a parameter "{key}"'
)
self._parameters = parameters
self.set_parameters(parameters)

def set_parameters(self, params: Dict[str, Union[pint.Quantity, xr.DataArray]]):
"""Set the expressions parameters.

Parameters
----------
params:
Dictionary that contains the values for the specified parameters.

"""
if not isinstance(params, dict):
raise ValueError(f'"parameters" must be dictionary, got {type(params)}')
variable_names = [str(v) for v in self._expression.free_symbols]
for k, v in params.items():
if k not in variable_names:
raise ValueError(f'The expression does not have a parameter "{k}"')
if not isinstance(v, xr.DataArray):
v = Q_(v)
self._parameters[k] = v

def __repr__(self):
"""Give __repr__ output."""
Expand Down Expand Up @@ -155,14 +164,7 @@ def set_parameter(self, name, value):
Parameter value. This can be number, array or pint.Quantity

"""
if not isinstance(name, str):
raise TypeError(f'Parameter "name" must be a string, got {type(name)}')
if name not in str(self._expression.free_symbols):
raise ValueError(
f'The expression "{self._expression}" does not have a '
f'parameter with name "{name}".'
)
self._parameters[name] = value
self.set_parameters({name: value})

@property
def num_parameters(self):
Expand Down Expand Up @@ -246,7 +248,19 @@ def evaluate(self, **kwargs) -> Any:
raise ValueError(
f"The variables {intersection} are already defined as parameters."
)
inputs = {**kwargs, **self._parameters}
variables = {}
for k, v in kwargs.items():
if not isinstance(v, xr.DataArray):
v = xr.DataArray(Q_(v))
variables[k] = v

parameters = {}
for k, v in self._parameters.items():
if not isinstance(v, xr.DataArray):
v = xr.DataArray(v)
parameters[k] = v

inputs = {**variables, **parameters}
return self.function(**inputs)


Expand Down Expand Up @@ -368,6 +382,8 @@ def _check_data_array(data_array: xr.DataArray):
def _create_data_array(
data: Union[pint.Quantity, xr.DataArray], time: Time
) -> xr.DataArray:
if isinstance(data, xr.DataArray):
return data
return (
xr.DataArray(data=data)
.rename({"dim_0": "time"})
Expand Down Expand Up @@ -416,7 +432,7 @@ def _init_expression(self, data):
# check that the expression can be evaluated with a time quantity
time_var_name = data.get_variable_names()[0]
try:
eval_data = data.evaluate(**{time_var_name: Q_(1, "second")})
eval_data = data.evaluate(**{time_var_name: Q_(1, "second")}).data
self._units = eval_data.units
if np.iterable(eval_data):
self._shape = eval_data.shape
Expand Down Expand Up @@ -459,20 +475,14 @@ def _interp_time_discrete(self, time: Time) -> xr.DataArray:
def _interp_time_expression(self, time: Time, time_unit: str) -> xr.DataArray:
"""Interpolate the time series if its data is a mathematical expression."""
time_q = time.as_quantity(unit=time_unit)
if len(time_q.shape) == 0:
time_q = np.expand_dims(time_q, 0)

if len(self.shape) > 1 and np.iterable(time_q):
while len(time_q.shape) < len(self.shape):
time_q = time_q[:, np.newaxis]
time_xr = xr.DataArray(time_q, dims=["time"])

# evaluate expression
data = self._data.evaluate(**{self._time_var_name: time_q})
data = data.astype(float).to_reduced_units() # float conversion before reduce!

# create data array
if not np.iterable(data): # make sure quantity is not scalar value
data = np.expand_dims(data, 0)

return self._create_data_array(data, time)
data = self._data.evaluate(**{self._time_var_name: time_xr})
return data.assign_coords({"time": time.as_timedelta_index()})
CagtayFabry marked this conversation as resolved.
Show resolved Hide resolved

@property
def data(self) -> Union[pint.Quantity, MathematicalExpression]:
Expand Down Expand Up @@ -602,10 +612,11 @@ def interp_time(

if isinstance(self._data, xr.DataArray):
dax = self._interp_time_discrete(time_interp)
ts = TimeSeries(data=dax.data, time=time, interpolation=self.interpolation)
else:
dax = self._interp_time_expression(time_interp, time_unit)
ts = TimeSeries(data=dax, interpolation=self.interpolation)

ts = TimeSeries(data=dax.data, time=time, interpolation=self.interpolation)
ts._interp_counter = self._interp_counter + 1
return ts

Expand Down
2 changes: 1 addition & 1 deletion weldx/measurement.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,7 @@ def _determine_output_signal_unit(
"The provided function is incompatible with the input signals unit."
f" \nThe test raised the following exception:\n{e}"
)
return test_output.units
return test_output.data.units

return input_unit

Expand Down
6 changes: 2 additions & 4 deletions weldx/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def test_set_parameter(self):
"name, value, exception_type, test_name",
[
("k", 1, ValueError, "# parameter not in expression"),
(33, 1, TypeError, "# wrong type as name #1"),
(33, 1, ValueError, "# wrong type as name #1"),
({"a": 1}, 1, TypeError, "# wrong type as name #2"),
],
ids=get_test_name,
Expand Down Expand Up @@ -220,7 +220,7 @@ class TestTimeSeries:
me_expr_str = "a*t + b"
me_params = {"a": Q_(2, "m/s"), "b": Q_(-2, "m")}

me_params_vec = {"a": Q_([[2, 0, 1]], "m/s"), "b": Q_([[-2, 3, 0]], "m")}
me_params_vec = {"a": Q_([2, 0, 1], "m/s"), "b": Q_([-2, 3, 0], "m")}
CagtayFabry marked this conversation as resolved.
Show resolved Hide resolved

ts_constant = TimeSeries(value_constant)
ts_disc_step = TimeSeries(values_discrete, time_discrete, "step")
Expand Down Expand Up @@ -322,7 +322,6 @@ def test_init_data_array(data, dims, coords, exception_type):
time_def = Q_([0, 1, 2, 3, 4], "s")
me_too_many_vars = ME("a*t + b", {})
me_param_units = ME("a*t + b", {"a": Q_(2, "1/s"), "b": Q_(-2, "m")})
me_time_vec = ME("a*t + b", {"a": Q_([2, 3, 4], "1/s"), "b": Q_([-2, 3, 1], "")})

@staticmethod
@pytest.mark.parametrize(
Expand All @@ -332,7 +331,6 @@ def test_init_data_array(data, dims, coords, exception_type):
(values_def, time_def.magnitude, "step", TypeError, "# invalid time type"),
(me_too_many_vars, None, None, Exception, "# too many free variables"),
(me_param_units, None, None, Exception, "# incompatible parameter units"),
(me_time_vec, None, None, Exception, "# not compatible with time vectors"),
("a string", None, None, TypeError, "# wrong data type"),
],
ids=get_test_name,
Expand Down
4 changes: 2 additions & 2 deletions weldx/tests/transformations/test_cs_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1438,7 +1438,7 @@ def test_get_local_coordinate_system_timeseries(
The expected rotation angles around the z-axis

"""
me = MathematicalExpression("a*t", {"a": Q_([[0, 1, 0]], "mm/s")})
me = MathematicalExpression("a*t", {"a": Q_([0, 1, 0], "mm/s")})
ts = TimeSeries(me)
rotation = WXRotation.from_euler("z", [0, 90], degrees=True).as_matrix()
translation = [[1, 0, 0], [2, 0, 0]]
Expand Down Expand Up @@ -1535,7 +1535,7 @@ def test_get_cs_exception_timeseries(lcs, in_lcs, exp_exception):
Set to `True` if the transformation should raise

"""
me = MathematicalExpression("a*t", {"a": Q_([[0, 1, 0]], "mm/s")})
me = MathematicalExpression("a*t", {"a": Q_([0, 1, 0], "mm/s")})
ts = TimeSeries(me)
translation = [[1, 0, 0], [2, 0, 0]]

Expand Down
4 changes: 2 additions & 2 deletions weldx/tests/transformations/test_local_cs.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def test_init_expr_time_series_as_coord(time, time_ref, angles):

"""
coordinates = MathematicalExpression(
expression="a*t+b", parameters=dict(a=Q_([[1, 0, 0]], "1/s"), b=[1, 2, 3])
expression="a*t+b", parameters=dict(a=Q_([1, 0, 0], "1/s"), b=[1, 2, 3])
)

ts_coord = TimeSeries(data=coordinates)
Expand Down Expand Up @@ -607,7 +607,7 @@ def test_interp_time_timeseries_as_coords(

# create expression
expr = "a*t+b"
param = dict(a=Q_([[1, 0, 0]], "mm/s"), b=Q_([1, 1, 1], "mm"))
param = dict(a=Q_([1, 0, 0], "mm/s"), b=Q_([1, 1, 1], "mm"))
me = MathematicalExpression(expression=expr, parameters=param)

# create orientation and time of LCS
Expand Down