-
Notifications
You must be signed in to change notification settings - Fork 21
Refactor __add__
operation in DiffractionObject
and add tests
#285
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
7db3a4f
3d4841b
5d0ebcc
9741a8e
3f577d7
d561583
ac5a2f3
12848e8
11c4166
846c72a
da70bd6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
**Added:** | ||
|
||
* unit tests for __add__ operation for DiffractionObject | ||
|
||
**Changed:** | ||
|
||
* <news item> | ||
|
||
**Deprecated:** | ||
|
||
* <news item> | ||
|
||
**Removed:** | ||
|
||
* <news item> | ||
|
||
**Fixed:** | ||
|
||
* <news item> | ||
|
||
**Security:** | ||
|
||
* <news item> |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,9 +14,15 @@ | |
XQUANTITIES = ANGLEQUANTITIES + DQUANTITIES + QQUANTITIES | ||
XUNITS = ["degrees", "radians", "rad", "deg", "inv_angs", "inv_nm", "nm-1", "A-1"] | ||
|
||
x_grid_emsg = ( | ||
"objects are not on the same x-grid. You may add them using the self.add method " | ||
"and specifying how to handle the mismatch." | ||
y_grid_length_mismatch_emsg = ( | ||
"The two objects have different y-array lengths. " | ||
"Please ensure the length of the y-value during initialization is identical." | ||
) | ||
|
||
invalid_add_type_emsg = ( | ||
"You may only add a DiffractionObject with another DiffractionObject or a scalar value. " | ||
"Please rerun by adding another DiffractionObject instance or a scalar value. " | ||
"e.g., my_do_1 + my_do_2 or my_do + 10 or 10 + my_do" | ||
) | ||
|
||
|
||
|
@@ -169,32 +175,56 @@ def __eq__(self, other): | |
return True | ||
|
||
def __add__(self, other): | ||
summed = deepcopy(self) | ||
if isinstance(other, int) or isinstance(other, float) or isinstance(other, np.ndarray): | ||
summed.on_tth[1] = self.on_tth[1] + other | ||
summed.on_q[1] = self.on_q[1] + other | ||
elif not isinstance(other, DiffractionObject): | ||
raise TypeError("I only know how to sum two DiffractionObject objects") | ||
elif self.on_tth[0].all() != other.on_tth[0].all(): | ||
raise RuntimeError(x_grid_emsg) | ||
else: | ||
summed.on_tth[1] = self.on_tth[1] + other.on_tth[1] | ||
summed.on_q[1] = self.on_q[1] + other.on_q[1] | ||
return summed | ||
"""Add a scalar value or another DiffractionObject to the yarray of the | ||
DiffractionObject. | ||
|
||
def __radd__(self, other): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. radd i think we don't need? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we may. Are you sure? Anyway, we can test and see. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you are correct - added please see a new test below for |
||
summed = deepcopy(self) | ||
if isinstance(other, int) or isinstance(other, float) or isinstance(other, np.ndarray): | ||
summed.on_tth[1] = self.on_tth[1] + other | ||
summed.on_q[1] = self.on_q[1] + other | ||
elif not isinstance(other, DiffractionObject): | ||
raise TypeError("I only know how to sum two Scattering_object objects") | ||
elif self.on_tth[0].all() != other.on_tth[0].all(): | ||
raise RuntimeError(x_grid_emsg) | ||
else: | ||
summed.on_tth[1] = self.on_tth[1] + other.on_tth[1] | ||
summed.on_q[1] = self.on_q[1] + other.on_q[1] | ||
return summed | ||
Parameters | ||
---------- | ||
other : DiffractionObject or int or float | ||
The object to add to the current DiffractionObject. If `other` is a scalar value, | ||
it will be added to all yarray. The length of the yarray must match if `other` is | ||
an instance of DiffractionObject. | ||
|
||
Returns | ||
------- | ||
DiffractionObject | ||
The new and deep-copied DiffractionObject instance after adding values to the yarray. | ||
|
||
Raises | ||
------ | ||
ValueError | ||
Raised when the length of the yarray of the two DiffractionObject instances do not match. | ||
TypeError | ||
Raised when the type of `other` is not an instance of DiffractionObject, int, or float. | ||
|
||
Examples | ||
-------- | ||
Add a scalar value to the yarray of the DiffractionObject instance: | ||
>>> new_do = my_do + 10.1 | ||
>>> new_do = 10.1 + my_do | ||
sbillinge marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
Add the yarray of two DiffractionObject instances: | ||
>>> new_do = my_do_1 + my_do_2 | ||
""" | ||
|
||
self._check_operation_compatibility(other) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Created a private func that checks the validity of other def _check_operation_compatibility(self, other):
if not isinstance(other, (DiffractionObject, int, float)):
raise TypeError(invalid_add_type_emsg)
if isinstance(other, DiffractionObject):
self_yarray = self.all_arrays[:, 0]
other_yarray = other.all_arrays[:, 0]
if len(self_yarray) != len(other_yarray):
raise ValueError(y_grid_length_mismatch_emsg) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice! it may be more readable if we use
to accomplish the same thing? |
||
summed_do = deepcopy(self) | ||
if isinstance(other, (int, float)): | ||
summed_do._all_arrays[:, 0] += other | ||
if isinstance(other, DiffractionObject): | ||
summed_do._all_arrays[:, 0] += other.all_arrays[:, 0] | ||
return summed_do | ||
|
||
__radd__ = __add__ | ||
sbillinge marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def _check_operation_compatibility(self, other): | ||
if not isinstance(other, (DiffractionObject, int, float)): | ||
raise TypeError(invalid_add_type_emsg) | ||
if isinstance(other, DiffractionObject): | ||
self_yarray = self.all_arrays[:, 0] | ||
other_yarray = other.all_arrays[:, 0] | ||
if len(self_yarray) != len(other_yarray): | ||
raise ValueError(y_grid_length_mismatch_emsg) | ||
|
||
def __sub__(self, other): | ||
subtracted = deepcopy(self) | ||
|
@@ -204,7 +234,7 @@ def __sub__(self, other): | |
elif not isinstance(other, DiffractionObject): | ||
raise TypeError("I only know how to subtract two Scattering_object objects") | ||
elif self.on_tth[0].all() != other.on_tth[0].all(): | ||
raise RuntimeError(x_grid_emsg) | ||
raise RuntimeError(y_grid_length_mismatch_emsg) | ||
else: | ||
subtracted.on_tth[1] = self.on_tth[1] - other.on_tth[1] | ||
subtracted.on_q[1] = self.on_q[1] - other.on_q[1] | ||
|
@@ -218,7 +248,7 @@ def __rsub__(self, other): | |
elif not isinstance(other, DiffractionObject): | ||
raise TypeError("I only know how to subtract two Scattering_object objects") | ||
elif self.on_tth[0].all() != other.on_tth[0].all(): | ||
raise RuntimeError(x_grid_emsg) | ||
raise RuntimeError(y_grid_length_mismatch_emsg) | ||
else: | ||
subtracted.on_tth[1] = other.on_tth[1] - self.on_tth[1] | ||
subtracted.on_q[1] = other.on_q[1] - self.on_q[1] | ||
|
@@ -232,7 +262,7 @@ def __mul__(self, other): | |
elif not isinstance(other, DiffractionObject): | ||
raise TypeError("I only know how to multiply two Scattering_object objects") | ||
elif self.on_tth[0].all() != other.on_tth[0].all(): | ||
raise RuntimeError(x_grid_emsg) | ||
raise RuntimeError(y_grid_length_mismatch_emsg) | ||
else: | ||
multiplied.on_tth[1] = self.on_tth[1] * other.on_tth[1] | ||
multiplied.on_q[1] = self.on_q[1] * other.on_q[1] | ||
|
@@ -244,7 +274,7 @@ def __rmul__(self, other): | |
multiplied.on_tth[1] = other * self.on_tth[1] | ||
multiplied.on_q[1] = other * self.on_q[1] | ||
elif self.on_tth[0].all() != other.on_tth[0].all(): | ||
raise RuntimeError(x_grid_emsg) | ||
raise RuntimeError(y_grid_length_mismatch_emsg) | ||
else: | ||
multiplied.on_tth[1] = self.on_tth[1] * other.on_tth[1] | ||
multiplied.on_q[1] = self.on_q[1] * other.on_q[1] | ||
|
@@ -258,7 +288,7 @@ def __truediv__(self, other): | |
elif not isinstance(other, DiffractionObject): | ||
raise TypeError("I only know how to multiply two Scattering_object objects") | ||
elif self.on_tth[0].all() != other.on_tth[0].all(): | ||
raise RuntimeError(x_grid_emsg) | ||
raise RuntimeError(y_grid_length_mismatch_emsg) | ||
else: | ||
divided.on_tth[1] = self.on_tth[1] / other.on_tth[1] | ||
divided.on_q[1] = self.on_q[1] / other.on_q[1] | ||
|
@@ -270,7 +300,7 @@ def __rtruediv__(self, other): | |
divided.on_tth[1] = other / self.on_tth[1] | ||
divided.on_q[1] = other / self.on_q[1] | ||
elif self.on_tth[0].all() != other.on_tth[0].all(): | ||
raise RuntimeError(x_grid_emsg) | ||
raise RuntimeError(y_grid_length_mismatch_emsg) | ||
else: | ||
divided.on_tth[1] = other.on_tth[1] / self.on_tth[1] | ||
divided.on_q[1] = other.on_q[1] / self.on_q[1] | ||
|
sbillinge marked this conversation as resolved.
Show resolved
Hide resolved
|
Uh oh!
There was an error while loading. Please reload this page.