diff --git a/news/add-operations-tests.rst b/news/add-operations-tests.rst new file mode 100644 index 00000000..e59edbd5 --- /dev/null +++ b/news/add-operations-tests.rst @@ -0,0 +1,23 @@ +**Added:** + +* unit tests for __add__ operation for DiffractionObject + +**Changed:** + +* + +**Deprecated:** + +* + +**Removed:** + +* + +**Fixed:** + +* + +**Security:** + +* diff --git a/src/diffpy/utils/diffraction_objects.py b/src/diffpy/utils/diffraction_objects.py index 25e3e28c..25e09381 100644 --- a/src/diffpy/utils/diffraction_objects.py +++ b/src/diffpy/utils/diffraction_objects.py @@ -14,9 +14,15 @@ XQUANTITIES = ANGLEQUANTITIES + DQUANTITIES + QQUANTITIES XUNITS = ["degrees", "radians", "rad", "deg", "inv_angs", "inv_nm", "nm-1", "A-1"] -x_grid_emsg = ( - "objects are not on the same x-grid. You may add them using the self.add method " - "and specifying how to handle the mismatch." +y_grid_length_mismatch_emsg = ( + "The two objects have different y-array lengths. " + "Please ensure the length of the y-value during initialization is identical." +) + +invalid_add_type_emsg = ( + "You may only add a DiffractionObject with another DiffractionObject or a scalar value. " + "Please rerun by adding another DiffractionObject instance or a scalar value. " + "e.g., my_do_1 + my_do_2 or my_do + 10 or 10 + my_do" ) @@ -169,32 +175,56 @@ def __eq__(self, other): return True def __add__(self, other): - summed = deepcopy(self) - if isinstance(other, int) or isinstance(other, float) or isinstance(other, np.ndarray): - summed.on_tth[1] = self.on_tth[1] + other - summed.on_q[1] = self.on_q[1] + other - elif not isinstance(other, DiffractionObject): - raise TypeError("I only know how to sum two DiffractionObject objects") - elif self.on_tth[0].all() != other.on_tth[0].all(): - raise RuntimeError(x_grid_emsg) - else: - summed.on_tth[1] = self.on_tth[1] + other.on_tth[1] - summed.on_q[1] = self.on_q[1] + other.on_q[1] - return summed + """Add a scalar value or another DiffractionObject to the yarray of the + DiffractionObject. - def __radd__(self, other): - summed = deepcopy(self) - if isinstance(other, int) or isinstance(other, float) or isinstance(other, np.ndarray): - summed.on_tth[1] = self.on_tth[1] + other - summed.on_q[1] = self.on_q[1] + other - elif not isinstance(other, DiffractionObject): - raise TypeError("I only know how to sum two Scattering_object objects") - elif self.on_tth[0].all() != other.on_tth[0].all(): - raise RuntimeError(x_grid_emsg) - else: - summed.on_tth[1] = self.on_tth[1] + other.on_tth[1] - summed.on_q[1] = self.on_q[1] + other.on_q[1] - return summed + Parameters + ---------- + other : DiffractionObject or int or float + The object to add to the current DiffractionObject. If `other` is a scalar value, + it will be added to all yarray. The length of the yarray must match if `other` is + an instance of DiffractionObject. + + Returns + ------- + DiffractionObject + The new and deep-copied DiffractionObject instance after adding values to the yarray. + + Raises + ------ + ValueError + Raised when the length of the yarray of the two DiffractionObject instances do not match. + TypeError + Raised when the type of `other` is not an instance of DiffractionObject, int, or float. + + Examples + -------- + Add a scalar value to the yarray of the DiffractionObject instance: + >>> new_do = my_do + 10.1 + >>> new_do = 10.1 + my_do + + Add the yarray of two DiffractionObject instances: + >>> new_do = my_do_1 + my_do_2 + """ + + self._check_operation_compatibility(other) + summed_do = deepcopy(self) + if isinstance(other, (int, float)): + summed_do._all_arrays[:, 0] += other + if isinstance(other, DiffractionObject): + summed_do._all_arrays[:, 0] += other.all_arrays[:, 0] + return summed_do + + __radd__ = __add__ + + def _check_operation_compatibility(self, other): + if not isinstance(other, (DiffractionObject, int, float)): + raise TypeError(invalid_add_type_emsg) + if isinstance(other, DiffractionObject): + self_yarray = self.all_arrays[:, 0] + other_yarray = other.all_arrays[:, 0] + if len(self_yarray) != len(other_yarray): + raise ValueError(y_grid_length_mismatch_emsg) def __sub__(self, other): subtracted = deepcopy(self) @@ -204,7 +234,7 @@ def __sub__(self, other): elif not isinstance(other, DiffractionObject): raise TypeError("I only know how to subtract two Scattering_object objects") elif self.on_tth[0].all() != other.on_tth[0].all(): - raise RuntimeError(x_grid_emsg) + raise RuntimeError(y_grid_length_mismatch_emsg) else: subtracted.on_tth[1] = self.on_tth[1] - other.on_tth[1] subtracted.on_q[1] = self.on_q[1] - other.on_q[1] @@ -218,7 +248,7 @@ def __rsub__(self, other): elif not isinstance(other, DiffractionObject): raise TypeError("I only know how to subtract two Scattering_object objects") elif self.on_tth[0].all() != other.on_tth[0].all(): - raise RuntimeError(x_grid_emsg) + raise RuntimeError(y_grid_length_mismatch_emsg) else: subtracted.on_tth[1] = other.on_tth[1] - self.on_tth[1] subtracted.on_q[1] = other.on_q[1] - self.on_q[1] @@ -232,7 +262,7 @@ def __mul__(self, other): elif not isinstance(other, DiffractionObject): raise TypeError("I only know how to multiply two Scattering_object objects") elif self.on_tth[0].all() != other.on_tth[0].all(): - raise RuntimeError(x_grid_emsg) + raise RuntimeError(y_grid_length_mismatch_emsg) else: multiplied.on_tth[1] = self.on_tth[1] * other.on_tth[1] multiplied.on_q[1] = self.on_q[1] * other.on_q[1] @@ -244,7 +274,7 @@ def __rmul__(self, other): multiplied.on_tth[1] = other * self.on_tth[1] multiplied.on_q[1] = other * self.on_q[1] elif self.on_tth[0].all() != other.on_tth[0].all(): - raise RuntimeError(x_grid_emsg) + raise RuntimeError(y_grid_length_mismatch_emsg) else: multiplied.on_tth[1] = self.on_tth[1] * other.on_tth[1] multiplied.on_q[1] = self.on_q[1] * other.on_q[1] @@ -258,7 +288,7 @@ def __truediv__(self, other): elif not isinstance(other, DiffractionObject): raise TypeError("I only know how to multiply two Scattering_object objects") elif self.on_tth[0].all() != other.on_tth[0].all(): - raise RuntimeError(x_grid_emsg) + raise RuntimeError(y_grid_length_mismatch_emsg) else: divided.on_tth[1] = self.on_tth[1] / other.on_tth[1] divided.on_q[1] = self.on_q[1] / other.on_q[1] @@ -270,7 +300,7 @@ def __rtruediv__(self, other): divided.on_tth[1] = other / self.on_tth[1] divided.on_q[1] = other / self.on_q[1] elif self.on_tth[0].all() != other.on_tth[0].all(): - raise RuntimeError(x_grid_emsg) + raise RuntimeError(y_grid_length_mismatch_emsg) else: divided.on_tth[1] = other.on_tth[1] / self.on_tth[1] divided.on_q[1] = other.on_q[1] / self.on_q[1] diff --git a/tests/conftest.py b/tests/conftest.py index 7f8de460..9e5f1e60 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -47,6 +47,12 @@ def do_minimal_tth(): return DiffractionObject(wavelength=2 * np.pi, xarray=np.array([30, 60]), yarray=np.array([1, 2]), xtype="tth") +@pytest.fixture +def do_minimal_d(): + # Create an instance of DiffractionObject with non-empty xarray, yarray, and wavelength values + return DiffractionObject(wavelength=1.54, xarray=np.array([1, 2]), yarray=np.array([1, 2]), xtype="d") + + @pytest.fixture def wavelength_warning_msg(): return ( @@ -63,3 +69,20 @@ def invalid_q_or_d_or_wavelength_error_msg(): "The supplied input array and wavelength will result in an impossible two-theta. " "Please check these values and re-instantiate the DiffractionObject with correct values." ) + + +@pytest.fixture +def invalid_add_type_error_msg(): + return ( + "You may only add a DiffractionObject with another DiffractionObject or a scalar value. " + "Please rerun by adding another DiffractionObject instance or a scalar value. " + "e.g., my_do_1 + my_do_2 or my_do + 10 or 10 + my_do" + ) + + +@pytest.fixture +def y_grid_size_mismatch_error_msg(): + return ( + "The two objects have different y-array lengths. " + "Please ensure the length of the y-value during initialization is identical." + ) diff --git a/tests/test_diffraction_objects.py b/tests/test_diffraction_objects.py index c6355882..ccab71fe 100644 --- a/tests/test_diffraction_objects.py +++ b/tests/test_diffraction_objects.py @@ -710,3 +710,78 @@ def test_copy_object(do_minimal): do_copy = do.copy() assert do == do_copy assert id(do) != id(do_copy) + + +@pytest.mark.parametrize( + "starting_all_arrays, scalar_to_add, expected_all_arrays", + [ + # Test scalar addition to yarray values (intensity) and expect no change to xarrays (q, tth, d) + ( # C1: Add integer of 5, expect yarray to increase by by 5 + np.array([[1.0, 0.51763809, 30.0, 12.13818192], [2.0, 1.0, 60.0, 6.28318531]]), + 5, + np.array([[6.0, 0.51763809, 30.0, 12.13818192], [7.0, 1.0, 60.0, 6.28318531]]), + ), + ( # C2: Add float of 5.1, expect yarray to be added by 5.1 + np.array([[1.0, 0.51763809, 30.0, 12.13818192], [2.0, 1.0, 60.0, 6.28318531]]), + 5.1, + np.array([[6.1, 0.51763809, 30.0, 12.13818192], [7.1, 1.0, 60.0, 6.28318531]]), + ), + ], +) +def test_addition_operator_by_scalar(starting_all_arrays, scalar_to_add, expected_all_arrays, do_minimal_tth): + do = do_minimal_tth + assert np.allclose(do.all_arrays, starting_all_arrays) + do_scalar_right_sum = do + scalar_to_add + assert np.allclose(do_scalar_right_sum.all_arrays, expected_all_arrays) + do_scalar_left_sum = scalar_to_add + do + assert np.allclose(do_scalar_left_sum.all_arrays, expected_all_arrays) + + +@pytest.mark.parametrize( + "do_1_all_arrays, " + "do_2_all_arrays, " + "expected_do_1_all_arrays_with_y_summed, " + "expected_do_2_all_arrays_with_y_summed", + [ + # Test addition of two DO objects, expect combined yarray values and no change to xarrays ((q, tth, d) + ( # C1: Add two DO objects, expect sum of yarray values + (np.array([[1.0, 0.51763809, 30.0, 12.13818192], [2.0, 1.0, 60.0, 6.28318531]]),), + (np.array([[1.0, 6.28318531, 100.70777771, 1], [2.0, 3.14159265, 45.28748053, 2.0]]),), + (np.array([[2.0, 0.51763809, 30.0, 12.13818192], [4.0, 1.0, 60.0, 6.28318531]]),), + (np.array([[2.0, 6.28318531, 100.70777771, 1], [4.0, 3.14159265, 45.28748053, 2.0]]),), + ), + ], +) +def test_addition_operator_by_another_do( + do_1_all_arrays, + do_2_all_arrays, + expected_do_1_all_arrays_with_y_summed, + expected_do_2_all_arrays_with_y_summed, + do_minimal_tth, + do_minimal_d, +): + do_1 = do_minimal_tth + assert np.allclose(do_1.all_arrays, do_1_all_arrays) + do_2 = do_minimal_d + assert np.allclose(do_2.all_arrays, do_2_all_arrays) + assert np.allclose((do_1 + do_2).all_arrays, expected_do_1_all_arrays_with_y_summed) + assert np.allclose((do_2 + do_1).all_arrays, expected_do_2_all_arrays_with_y_summed) + + +def test_addition_operator_invalid_type(do_minimal_tth, invalid_add_type_error_msg): + # Add a string to a DO object, expect TypeError, only scalar (int, float) allowed for addition + do = do_minimal_tth + with pytest.raises(TypeError, match=re.escape(invalid_add_type_error_msg)): + do + "string_value" + with pytest.raises(TypeError, match=re.escape(invalid_add_type_error_msg)): + "string_value" + do + + +def test_addition_operator_invalid_yarray_length(do_minimal, do_minimal_tth, y_grid_size_mismatch_error_msg): + # Combine two DO objects, one with empty xarrays (do_minimal) and the other with non-empty xarrays + do_1 = do_minimal + do_2 = do_minimal_tth + assert len(do_1.all_arrays[:, 0]) == 0 + assert len(do_2.all_arrays[:, 0]) == 2 + with pytest.raises(ValueError, match=re.escape(y_grid_size_mismatch_error_msg)): + do_1 + do_2