Skip to content

Commit cb02bca

Browse files
committed
add kwargs to cate and gate irm
1 parent 88c8ec8 commit cb02bca

File tree

2 files changed

+21
-8
lines changed

2 files changed

+21
-8
lines changed

doubleml/irm/irm.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,7 @@ def _nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune, n_
431431

432432
return res
433433

434-
def cate(self, basis, is_gate=False):
434+
def cate(self, basis, is_gate=False, **kwargs):
435435
"""
436436
Calculate conditional average treatment effects (CATE) for a given basis.
437437
@@ -440,10 +440,14 @@ def cate(self, basis, is_gate=False):
440440
basis : :class:`pandas.DataFrame`
441441
The basis for estimating the best linear predictor. Has to have the shape ``(n_obs, d)``,
442442
where ``n_obs`` is the number of observations and ``d`` is the number of predictors.
443+
443444
is_gate : bool
444445
Indicates whether the basis is constructed for GATEs (dummy-basis).
445446
Default is ``False``.
446447
448+
**kwargs: dict
449+
Additional keyword arguments to be passed to :meth:`statsmodels.regression.linear_model.OLS.fit` e.g. ``cov_type``.
450+
447451
Returns
448452
-------
449453
model : :class:`doubleML.DoubleMLBLP`
@@ -462,10 +466,10 @@ def cate(self, basis, is_gate=False):
462466
orth_signal = self.psi_elements['psi_b'].reshape(-1)
463467
# fit the best linear predictor
464468
model = DoubleMLBLP(orth_signal, basis=basis, is_gate=is_gate)
465-
model.fit()
469+
model.fit(**kwargs)
466470
return model
467471

468-
def gate(self, groups):
472+
def gate(self, groups, **kwargs):
469473
"""
470474
Calculate group average treatment effects (GATE) for groups.
471475
@@ -476,6 +480,9 @@ def gate(self, groups):
476480
Has to be dummy coded with shape ``(n_obs, d)``, where ``n_obs`` is the number of observations
477481
and ``d`` is the number of groups or ``(n_obs, 1)`` and contain the corresponding groups (as str).
478482
483+
**kwargs: dict
484+
Additional keyword arguments to be passed to :meth:`statsmodels.regression.linear_model.OLS.fit` e.g. ``cov_type``.
485+
479486
Returns
480487
-------
481488
model : :class:`doubleML.DoubleMLBLP`
@@ -495,7 +502,7 @@ def gate(self, groups):
495502
if any(groups.sum(0) <= 5):
496503
warnings.warn('At least one group effect is estimated with less than 6 observations.')
497504

498-
model = self.cate(groups, is_gate=True)
505+
model = self.cate(groups, is_gate=True, **kwargs)
499506
return model
500507

501508
def policy_tree(self, features, depth=2, **tree_params):

doubleml/irm/tests/test_irm.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,14 @@ def test_dml_irm_sensitivity_rho0(dml_irm_fixture):
187187
rtol=1e-9, atol=1e-4)
188188

189189

190+
@pytest.fixture(scope='module',
191+
params=["nonrobust", "HC0", "HC1", "HC2", "HC3"])
192+
def cov_type(request):
193+
return request.param
194+
195+
190196
@pytest.mark.ci
191-
def test_dml_irm_cate_gate():
197+
def test_dml_irm_cate_gate(cov_type):
192198
n = 9
193199
# collect data
194200
np.random.seed(42)
@@ -207,7 +213,7 @@ def test_dml_irm_cate_gate():
207213
dml_irm_obj.fit()
208214
# create a random basis
209215
random_basis = pd.DataFrame(np.random.normal(0, 1, size=(n, 5)))
210-
cate = dml_irm_obj.cate(random_basis)
216+
cate = dml_irm_obj.cate(random_basis, cov_type=cov_type)
211217
assert isinstance(cate, dml.utils.blp.DoubleMLBLP)
212218
assert isinstance(cate.confint(), pd.DataFrame)
213219

@@ -216,7 +222,7 @@ def test_dml_irm_cate_gate():
216222
columns=['Group 1', 'Group 2'])
217223
msg = ('At least one group effect is estimated with less than 6 observations.')
218224
with pytest.warns(UserWarning, match=msg):
219-
gate_1 = dml_irm_obj.gate(groups_1)
225+
gate_1 = dml_irm_obj.gate(groups_1, cov_type=cov_type)
220226
assert isinstance(gate_1, dml.utils.blp.DoubleMLBLP)
221227
assert isinstance(gate_1.confint(), pd.DataFrame)
222228
assert all(gate_1.confint().index == groups_1.columns.to_list())
@@ -225,7 +231,7 @@ def test_dml_irm_cate_gate():
225231
groups_2 = pd.DataFrame(np.random.choice(["1", "2"], n))
226232
msg = ('At least one group effect is estimated with less than 6 observations.')
227233
with pytest.warns(UserWarning, match=msg):
228-
gate_2 = dml_irm_obj.gate(groups_2)
234+
gate_2 = dml_irm_obj.gate(groups_2, cov_type=cov_type)
229235
assert isinstance(gate_2, dml.utils.blp.DoubleMLBLP)
230236
assert isinstance(gate_2.confint(), pd.DataFrame)
231237
assert all(gate_2.confint().index == ["Group_1", "Group_2"])

0 commit comments

Comments
 (0)