Skip to content

Commit 67c4c58

Browse files
committed
add apo kwargs for cate and gate
1 parent cb02bca commit 67c4c58

File tree

2 files changed

+24
-8
lines changed

2 files changed

+24
-8
lines changed

doubleml/irm/apo.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ def _check_data(self, obj_dml_data):
389389

390390
return
391391

392-
def capo(self, basis, is_gate=False):
392+
def capo(self, basis, is_gate=False, **kwargs):
393393
"""
394394
Calculate conditional average potential outcomes (CAPO) for a given basis.
395395
@@ -398,10 +398,14 @@ def capo(self, basis, is_gate=False):
398398
basis : :class:`pandas.DataFrame`
399399
The basis for estimating the best linear predictor. Has to have the shape ``(n_obs, d)``,
400400
where ``n_obs`` is the number of observations and ``d`` is the number of predictors.
401+
401402
is_gate : bool
402403
Indicates whether the basis is constructed for GATE/GAPOs (dummy-basis).
403404
Default is ``False``.
404405
406+
**kwargs: dict
407+
Additional keyword arguments to be passed to :meth:`statsmodels.regression.linear_model.OLS.fit` e.g. ``cov_type``.
408+
405409
Returns
406410
-------
407411
model : :class:`doubleML.DoubleMLBLP`
@@ -420,10 +424,10 @@ def capo(self, basis, is_gate=False):
420424
orth_signal = self.psi_elements['psi_b'].reshape(-1)
421425
# fit the best linear predictor
422426
model = DoubleMLBLP(orth_signal, basis=basis, is_gate=is_gate)
423-
model.fit()
427+
model.fit(**kwargs)
424428
return model
425429

426-
def gapo(self, groups):
430+
def gapo(self, groups, **kwargs):
427431
"""
428432
Calculate group average potential outcomes (GAPO) for groups.
429433
@@ -434,6 +438,9 @@ def gapo(self, groups):
434438
Has to be dummy coded with shape ``(n_obs, d)``, where ``n_obs`` is the number of observations
435439
and ``d`` is the number of groups or ``(n_obs, 1)`` and contain the corresponding groups (as str).
436440
441+
**kwargs: dict
442+
Additional keyword arguments to be passed to :meth:`statsmodels.regression.linear_model.OLS.fit` e.g. ``cov_type``.
443+
437444
Returns
438445
-------
439446
model : :class:`doubleML.DoubleMLBLP`
@@ -453,5 +460,5 @@ def gapo(self, groups):
453460
if any(groups.sum(0) <= 5):
454461
warnings.warn('At least one group effect is estimated with less than 6 observations.')
455462

456-
model = self.capo(groups, is_gate=True)
463+
model = self.capo(groups, is_gate=True, **kwargs)
457464
return model

doubleml/irm/tests/test_apo.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -200,8 +200,14 @@ def test_dml_apo_sensitivity(dml_apo_fixture):
200200
rtol=1e-9, atol=1e-4)
201201

202202

203+
@pytest.fixture(scope='module',
204+
params=["nonrobust", "HC0", "HC1", "HC2", "HC3"])
205+
def cov_type(request):
206+
return request.param
207+
208+
203209
@pytest.mark.ci
204-
def test_dml_apo_capo_gapo(treatment_level):
210+
def test_dml_apo_capo_gapo(treatment_level, cov_type):
205211
n = 20
206212
# collect data
207213
np.random.seed(42)
@@ -221,25 +227,28 @@ def test_dml_apo_capo_gapo(treatment_level):
221227
dml_obj.fit()
222228
# create a random basis
223229
random_basis = pd.DataFrame(np.random.normal(0, 1, size=(n, 5)))
224-
capo = dml_obj.capo(random_basis)
230+
capo = dml_obj.capo(random_basis, cov_type=cov_type)
225231
assert isinstance(capo, dml.utils.blp.DoubleMLBLP)
226232
assert isinstance(capo.confint(), pd.DataFrame)
233+
assert capo.blp_model.cov_type == cov_type
227234

228235
groups_1 = pd.DataFrame(np.column_stack([obj_dml_data.data['X1'] <= -1.0,
229236
obj_dml_data.data['X1'] > 0.2]),
230237
columns=['Group 1', 'Group 2'])
231238
msg = ('At least one group effect is estimated with less than 6 observations.')
232239
with pytest.warns(UserWarning, match=msg):
233-
gapo_1 = dml_obj.gapo(groups_1)
240+
gapo_1 = dml_obj.gapo(groups_1, cov_type=cov_type)
234241
assert isinstance(gapo_1, dml.utils.blp.DoubleMLBLP)
235242
assert isinstance(gapo_1.confint(), pd.DataFrame)
236243
assert all(gapo_1.confint().index == groups_1.columns.to_list())
244+
assert gapo_1.blp_model.cov_type == cov_type
237245

238246
np.random.seed(42)
239247
groups_2 = pd.DataFrame(np.random.choice(["1", "2"], n, p=[0.1, 0.9]))
240248
msg = ('At least one group effect is estimated with less than 6 observations.')
241249
with pytest.warns(UserWarning, match=msg):
242-
gapo_2 = dml_obj.gapo(groups_2)
250+
gapo_2 = dml_obj.gapo(groups_2, cov_type=cov_type)
243251
assert isinstance(gapo_2, dml.utils.blp.DoubleMLBLP)
244252
assert isinstance(gapo_2.confint(), pd.DataFrame)
245253
assert all(gapo_2.confint().index == ["Group_1", "Group_2"])
254+
assert gapo_2.blp_model.cov_type == cov_type

0 commit comments

Comments
 (0)