@@ -200,8 +200,14 @@ def test_dml_apo_sensitivity(dml_apo_fixture):
200200 rtol = 1e-9 , atol = 1e-4 )
201201
202202
203+ @pytest .fixture (scope = 'module' ,
204+ params = ["nonrobust" , "HC0" , "HC1" , "HC2" , "HC3" ])
205+ def cov_type (request ):
206+ return request .param
207+
208+
203209@pytest .mark .ci
204- def test_dml_apo_capo_gapo (treatment_level ):
210+ def test_dml_apo_capo_gapo (treatment_level , cov_type ):
205211 n = 20
206212 # collect data
207213 np .random .seed (42 )
@@ -221,25 +227,28 @@ def test_dml_apo_capo_gapo(treatment_level):
221227 dml_obj .fit ()
222228 # create a random basis
223229 random_basis = pd .DataFrame (np .random .normal (0 , 1 , size = (n , 5 )))
224- capo = dml_obj .capo (random_basis )
230+ capo = dml_obj .capo (random_basis , cov_type = cov_type )
225231 assert isinstance (capo , dml .utils .blp .DoubleMLBLP )
226232 assert isinstance (capo .confint (), pd .DataFrame )
233+ assert capo .blp_model .cov_type == cov_type
227234
228235 groups_1 = pd .DataFrame (np .column_stack ([obj_dml_data .data ['X1' ] <= - 1.0 ,
229236 obj_dml_data .data ['X1' ] > 0.2 ]),
230237 columns = ['Group 1' , 'Group 2' ])
231238 msg = ('At least one group effect is estimated with less than 6 observations.' )
232239 with pytest .warns (UserWarning , match = msg ):
233- gapo_1 = dml_obj .gapo (groups_1 )
240+ gapo_1 = dml_obj .gapo (groups_1 , cov_type = cov_type )
234241 assert isinstance (gapo_1 , dml .utils .blp .DoubleMLBLP )
235242 assert isinstance (gapo_1 .confint (), pd .DataFrame )
236243 assert all (gapo_1 .confint ().index == groups_1 .columns .to_list ())
244+ assert gapo_1 .blp_model .cov_type == cov_type
237245
238246 np .random .seed (42 )
239247 groups_2 = pd .DataFrame (np .random .choice (["1" , "2" ], n , p = [0.1 , 0.9 ]))
240248 msg = ('At least one group effect is estimated with less than 6 observations.' )
241249 with pytest .warns (UserWarning , match = msg ):
242- gapo_2 = dml_obj .gapo (groups_2 )
250+ gapo_2 = dml_obj .gapo (groups_2 , cov_type = cov_type )
243251 assert isinstance (gapo_2 , dml .utils .blp .DoubleMLBLP )
244252 assert isinstance (gapo_2 .confint (), pd .DataFrame )
245253 assert all (gapo_2 .confint ().index == ["Group_1" , "Group_2" ])
254+ assert gapo_2 .blp_model .cov_type == cov_type
0 commit comments