@@ -105,49 +105,55 @@ def test_set_muls_at_angle(input_mu, expected_muls):
105105
106106
107107@pytest .mark .parametrize (
108- "input_xtype, expected" ,
109- [
110- (
111- "tth" ,
108+ "input_diffraction_data, input_cve_params" ,
109+ [ # Test that cve diffraction object contains the expected info
110+ # Note that all cve values are interpolated to 0.5
111+ # cve do should contain the same input xarray, xtype,
112+ # wavelength, and metadata
113+ ( # C1: User did not specify method, default to fast calculation
112114 {
113115 "xarray" : np .array ([90 , 90.1 , 90.2 ]),
114- "yarray" : np .array ([0.5 , 0.5 , 0.5 ]),
115- "xtype" : "tth" ,
116+ "yarray" : np .array ([2 , 2 , 2 ]),
116117 },
118+ {"mud" : 1 , "xtype" : "tth" },
117119 ),
118- (
119- "q" ,
120+ ( # C2: User specified brute-force computation method
120121 {
121- "xarray" : np .array ([5.76998 , 5.77501 , 5.78004 ]),
122- "yarray" : np .array ([0.5 , 0.5 , 0.5 ]),
123- "xtype" : "q" ,
122+ "xarray" : np .array ([5.1 , 5.2 , 5.3 ]),
123+ "yarray" : np .array ([2 , 2 , 2 ]),
124124 },
125+ {"mud" : 1 , "method" : "brute_force" , "xtype" : "q" },
126+ ),
127+ ( # C3: User specified mu*D outside the fast calculation range,
128+ # default to brute-force computation
129+ {
130+ "xarray" : np .array ([5.1 , 5.2 , 5.3 ]),
131+ "yarray" : np .array ([2 , 2 , 2 ]),
132+ },
133+ {"mud" : 20 , "xtype" : "q" },
125134 ),
126135 ],
127136)
128- def test_compute_cve (input_xtype , expected , mocker ):
129- xarray , yarray = np . array ([ 90 , 90.1 , 90.2 ]), np . array ([ 2 , 2 , 2 ])
137+ def test_compute_cve (mocker , input_diffraction_data , input_cve_params ):
138+ expected_xarray = input_diffraction_data [ "xarray" ]
130139 expected_cve = np .array ([0.5 , 0.5 , 0.5 ])
140+ expected_xtype = input_cve_params ["xtype" ]
141+ mocker .patch ("diffpy.labpdfproc.functions.N_POINTS_ON_DIAMETER" , 4 )
131142 mocker .patch ("numpy.interp" , return_value = expected_cve )
132143 input_pattern = DiffractionObject (
133- xarray = xarray ,
134- yarray = yarray ,
135- xtype = "tth" ,
144+ xarray = input_diffraction_data [ " xarray" ] ,
145+ yarray = input_diffraction_data [ " yarray" ] ,
146+ xtype = input_cve_params [ "xtype" ] ,
136147 wavelength = 1.54 ,
137148 scat_quantity = "x-ray" ,
138149 name = "test" ,
139150 metadata = {"thing1" : 1 , "thing2" : "thing2" },
140151 )
141- actual_cve_do = compute_cve (
142- input_pattern ,
143- mud = 1 ,
144- method = "polynomial_interpolation" ,
145- xtype = input_xtype ,
146- )
152+ actual_cve_do = compute_cve (input_pattern , ** input_cve_params )
147153 expected_cve_do = DiffractionObject (
148- xarray = expected [ "xarray" ] ,
149- yarray = expected [ "yarray" ] ,
150- xtype = expected [ "xtype" ] ,
154+ xarray = expected_xarray ,
155+ yarray = expected_cve ,
156+ xtype = expected_xtype ,
151157 wavelength = 1.54 ,
152158 scat_quantity = "cve" ,
153159 name = "absorption correction, cve, for test" ,
@@ -156,32 +162,9 @@ def test_compute_cve(input_xtype, expected, mocker):
156162 assert actual_cve_do == expected_cve_do
157163
158164
159- @pytest .mark .parametrize (
160- "inputs, msg" ,
161- [
162- (
163- {"mud" : 10 , "method" : "polynomial_interpolation" },
164- f"mu*D = 10 is out of the acceptable range (0.5 to 7) "
165- f"for polynomial interpolation. "
166- f"Please rerun with a value within this range "
167- f"or specifying another method from { * CVE_METHODS , } ." ,
168- ),
169- (
170- {"mud" : 1 , "method" : "invalid_method" },
171- f"Unknown method: invalid_method. "
172- f"Allowed methods are { * CVE_METHODS , } ." ,
173- ),
174- (
175- {"mud" : 7 , "method" : "invalid_method" },
176- f"Unknown method: invalid_method. "
177- f"Allowed methods are { * CVE_METHODS , } ." ,
178- ),
179- ],
180- )
181- def test_compute_cve_bad (mocker , inputs , msg ):
165+ def test_compute_cve_bad (mocker ):
182166 xarray , yarray = np .array ([90 , 90.1 , 90.2 ]), np .array ([2 , 2 , 2 ])
183167 expected_cve = np .array ([0.5 , 0.5 , 0.5 ])
184- mocker .patch ("diffpy.labpdfproc.functions.TTH_GRID" , xarray )
185168 mocker .patch ("numpy.interp" , return_value = expected_cve )
186169 input_pattern = DiffractionObject (
187170 xarray = xarray ,
@@ -192,14 +175,21 @@ def test_compute_cve_bad(mocker, inputs, msg):
192175 name = "test" ,
193176 metadata = {"thing1" : 1 , "thing2" : "thing2" },
194177 )
195- with pytest .raises (ValueError , match = re .escape (msg )):
196- compute_cve (input_pattern , mud = inputs ["mud" ], method = inputs ["method" ])
178+ # Test that the function raises a ValueError
179+ # when an invalid method is provided
180+ with pytest .raises (
181+ ValueError ,
182+ match = re .escape (
183+ f"Unknown method: invalid_method. "
184+ f"Allowed methods are { * CVE_METHODS , } ."
185+ ),
186+ ):
187+ compute_cve (input_pattern , mud = 1 , method = "invalid_method" )
197188
198189
199190def test_apply_corr (mocker ):
200191 xarray , yarray = np .array ([90 , 90.1 , 90.2 ]), np .array ([2 , 2 , 2 ])
201192 expected_cve = np .array ([0.5 , 0.5 , 0.5 ])
202- mocker .patch ("diffpy.labpdfproc.functions.TTH_GRID" , xarray )
203193 mocker .patch ("numpy.interp" , return_value = expected_cve )
204194 input_pattern = DiffractionObject (
205195 xarray = xarray ,
0 commit comments