Skip to content

Commit 0e64d84

Browse files
committed
test cov_type gate and cate irm and plr
1 parent 67c4c58 commit 0e64d84

File tree

2 files changed

+6
-0
lines changed

2 files changed

+6
-0
lines changed

doubleml/irm/tests/test_irm.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ def test_dml_irm_cate_gate(cov_type):
216216
cate = dml_irm_obj.cate(random_basis, cov_type=cov_type)
217217
assert isinstance(cate, dml.utils.blp.DoubleMLBLP)
218218
assert isinstance(cate.confint(), pd.DataFrame)
219+
assert cate.blp_model.cov_type == cov_type
219220

220221
groups_1 = pd.DataFrame(np.column_stack([obj_dml_data.data['X1'] <= 0,
221222
obj_dml_data.data['X1'] > 0.2]),
@@ -226,6 +227,7 @@ def test_dml_irm_cate_gate(cov_type):
226227
assert isinstance(gate_1, dml.utils.blp.DoubleMLBLP)
227228
assert isinstance(gate_1.confint(), pd.DataFrame)
228229
assert all(gate_1.confint().index == groups_1.columns.to_list())
230+
assert gate_1.blp_model.cov_type == cov_type
229231

230232
np.random.seed(42)
231233
groups_2 = pd.DataFrame(np.random.choice(["1", "2"], n))
@@ -235,6 +237,7 @@ def test_dml_irm_cate_gate(cov_type):
235237
assert isinstance(gate_2, dml.utils.blp.DoubleMLBLP)
236238
assert isinstance(gate_2.confint(), pd.DataFrame)
237239
assert all(gate_2.confint().index == ["Group_1", "Group_2"])
240+
assert gate_2.blp_model.cov_type == cov_type
238241

239242

240243
@pytest.fixture(scope='module',

doubleml/plm/tests/test_plr.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,7 @@ def test_dml_plr_cate_gate(score, cov_type):
327327
cate = dml_plr_obj.cate(random_basis, cov_type=cov_type)
328328
assert isinstance(cate, dml.DoubleMLBLP)
329329
assert isinstance(cate.confint(), pd.DataFrame)
330+
assert cate.blp_model.cov_type == cov_type
330331

331332
groups_1 = pd.DataFrame(
332333
np.column_stack([obj_dml_data.data['X1'] <= 0,
@@ -338,6 +339,7 @@ def test_dml_plr_cate_gate(score, cov_type):
338339
assert isinstance(gate_1, dml.utils.blp.DoubleMLBLP)
339340
assert isinstance(gate_1.confint(), pd.DataFrame)
340341
assert all(gate_1.confint().index == groups_1.columns.tolist())
342+
assert gate_1.blp_model.cov_type == cov_type
341343

342344
np.random.seed(42)
343345
groups_2 = pd.DataFrame(np.random.choice(["1", "2"], n))
@@ -347,3 +349,4 @@ def test_dml_plr_cate_gate(score, cov_type):
347349
assert isinstance(gate_2, dml.utils.blp.DoubleMLBLP)
348350
assert isinstance(gate_2.confint(), pd.DataFrame)
349351
assert all(gate_2.confint().index == ["Group_1", "Group_2"])
352+
assert gate_2.blp_model.cov_type == cov_type

0 commit comments

Comments
 (0)