@@ -216,6 +216,7 @@ def test_dml_irm_cate_gate(cov_type):
216216 cate = dml_irm_obj .cate (random_basis , cov_type = cov_type )
217217 assert isinstance (cate , dml .utils .blp .DoubleMLBLP )
218218 assert isinstance (cate .confint (), pd .DataFrame )
219+ assert cate .blp_model .cov_type == cov_type
219220
220221 groups_1 = pd .DataFrame (np .column_stack ([obj_dml_data .data ['X1' ] <= 0 ,
221222 obj_dml_data .data ['X1' ] > 0.2 ]),
@@ -226,6 +227,7 @@ def test_dml_irm_cate_gate(cov_type):
226227 assert isinstance (gate_1 , dml .utils .blp .DoubleMLBLP )
227228 assert isinstance (gate_1 .confint (), pd .DataFrame )
228229 assert all (gate_1 .confint ().index == groups_1 .columns .to_list ())
230+ assert gate_1 .blp_model .cov_type == cov_type
229231
230232 np .random .seed (42 )
231233 groups_2 = pd .DataFrame (np .random .choice (["1" , "2" ], n ))
@@ -235,6 +237,7 @@ def test_dml_irm_cate_gate(cov_type):
235237 assert isinstance (gate_2 , dml .utils .blp .DoubleMLBLP )
236238 assert isinstance (gate_2 .confint (), pd .DataFrame )
237239 assert all (gate_2 .confint ().index == ["Group_1" , "Group_2" ])
240+ assert gate_2 .blp_model .cov_type == cov_type
238241
239242
240243@pytest .fixture (scope = 'module' ,
0 commit comments