Skip to content

Commit 0ee11dd

Browse files
authored
Merge pull request #271 from DoubleML/s-update-blp-cov-type
Add cov_type and kwargs to BLP object
2 parents 31f7388 + 0e64d84 commit 0ee11dd

File tree

9 files changed

+122
-34
lines changed

9 files changed

+122
-34
lines changed

doubleml/irm/apo.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ def _check_data(self, obj_dml_data):
389389

390390
return
391391

392-
def capo(self, basis, is_gate=False):
392+
def capo(self, basis, is_gate=False, **kwargs):
393393
"""
394394
Calculate conditional average potential outcomes (CAPO) for a given basis.
395395
@@ -398,10 +398,14 @@ def capo(self, basis, is_gate=False):
398398
basis : :class:`pandas.DataFrame`
399399
The basis for estimating the best linear predictor. Has to have the shape ``(n_obs, d)``,
400400
where ``n_obs`` is the number of observations and ``d`` is the number of predictors.
401+
401402
is_gate : bool
402403
Indicates whether the basis is constructed for GATE/GAPOs (dummy-basis).
403404
Default is ``False``.
404405
406+
**kwargs: dict
407+
Additional keyword arguments to be passed to :meth:`statsmodels.regression.linear_model.OLS.fit` e.g. ``cov_type``.
408+
405409
Returns
406410
-------
407411
model : :class:`doubleML.DoubleMLBLP`
@@ -420,10 +424,10 @@ def capo(self, basis, is_gate=False):
420424
orth_signal = self.psi_elements['psi_b'].reshape(-1)
421425
# fit the best linear predictor
422426
model = DoubleMLBLP(orth_signal, basis=basis, is_gate=is_gate)
423-
model.fit()
427+
model.fit(**kwargs)
424428
return model
425429

426-
def gapo(self, groups):
430+
def gapo(self, groups, **kwargs):
427431
"""
428432
Calculate group average potential outcomes (GAPO) for groups.
429433
@@ -434,6 +438,9 @@ def gapo(self, groups):
434438
Has to be dummy coded with shape ``(n_obs, d)``, where ``n_obs`` is the number of observations
435439
and ``d`` is the number of groups or ``(n_obs, 1)`` and contain the corresponding groups (as str).
436440
441+
**kwargs: dict
442+
Additional keyword arguments to be passed to :meth:`statsmodels.regression.linear_model.OLS.fit` e.g. ``cov_type``.
443+
437444
Returns
438445
-------
439446
model : :class:`doubleML.DoubleMLBLP`
@@ -453,5 +460,5 @@ def gapo(self, groups):
453460
if any(groups.sum(0) <= 5):
454461
warnings.warn('At least one group effect is estimated with less than 6 observations.')
455462

456-
model = self.capo(groups, is_gate=True)
463+
model = self.capo(groups, is_gate=True, **kwargs)
457464
return model

doubleml/irm/irm.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,7 @@ def _nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune, n_
431431

432432
return res
433433

434-
def cate(self, basis, is_gate=False):
434+
def cate(self, basis, is_gate=False, **kwargs):
435435
"""
436436
Calculate conditional average treatment effects (CATE) for a given basis.
437437
@@ -440,10 +440,14 @@ def cate(self, basis, is_gate=False):
440440
basis : :class:`pandas.DataFrame`
441441
The basis for estimating the best linear predictor. Has to have the shape ``(n_obs, d)``,
442442
where ``n_obs`` is the number of observations and ``d`` is the number of predictors.
443+
443444
is_gate : bool
444445
Indicates whether the basis is constructed for GATEs (dummy-basis).
445446
Default is ``False``.
446447
448+
**kwargs: dict
449+
Additional keyword arguments to be passed to :meth:`statsmodels.regression.linear_model.OLS.fit` e.g. ``cov_type``.
450+
447451
Returns
448452
-------
449453
model : :class:`doubleML.DoubleMLBLP`
@@ -462,10 +466,10 @@ def cate(self, basis, is_gate=False):
462466
orth_signal = self.psi_elements['psi_b'].reshape(-1)
463467
# fit the best linear predictor
464468
model = DoubleMLBLP(orth_signal, basis=basis, is_gate=is_gate)
465-
model.fit()
469+
model.fit(**kwargs)
466470
return model
467471

468-
def gate(self, groups):
472+
def gate(self, groups, **kwargs):
469473
"""
470474
Calculate group average treatment effects (GATE) for groups.
471475
@@ -476,6 +480,9 @@ def gate(self, groups):
476480
Has to be dummy coded with shape ``(n_obs, d)``, where ``n_obs`` is the number of observations
477481
and ``d`` is the number of groups or ``(n_obs, 1)`` and contain the corresponding groups (as str).
478482
483+
**kwargs: dict
484+
Additional keyword arguments to be passed to :meth:`statsmodels.regression.linear_model.OLS.fit` e.g. ``cov_type``.
485+
479486
Returns
480487
-------
481488
model : :class:`doubleML.DoubleMLBLP`
@@ -495,7 +502,7 @@ def gate(self, groups):
495502
if any(groups.sum(0) <= 5):
496503
warnings.warn('At least one group effect is estimated with less than 6 observations.')
497504

498-
model = self.cate(groups, is_gate=True)
505+
model = self.cate(groups, is_gate=True, **kwargs)
499506
return model
500507

501508
def policy_tree(self, features, depth=2, **tree_params):

doubleml/irm/tests/test_apo.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -200,8 +200,14 @@ def test_dml_apo_sensitivity(dml_apo_fixture):
200200
rtol=1e-9, atol=1e-4)
201201

202202

203+
@pytest.fixture(scope='module',
204+
params=["nonrobust", "HC0", "HC1", "HC2", "HC3"])
205+
def cov_type(request):
206+
return request.param
207+
208+
203209
@pytest.mark.ci
204-
def test_dml_apo_capo_gapo(treatment_level):
210+
def test_dml_apo_capo_gapo(treatment_level, cov_type):
205211
n = 20
206212
# collect data
207213
np.random.seed(42)
@@ -221,25 +227,28 @@ def test_dml_apo_capo_gapo(treatment_level):
221227
dml_obj.fit()
222228
# create a random basis
223229
random_basis = pd.DataFrame(np.random.normal(0, 1, size=(n, 5)))
224-
capo = dml_obj.capo(random_basis)
230+
capo = dml_obj.capo(random_basis, cov_type=cov_type)
225231
assert isinstance(capo, dml.utils.blp.DoubleMLBLP)
226232
assert isinstance(capo.confint(), pd.DataFrame)
233+
assert capo.blp_model.cov_type == cov_type
227234

228235
groups_1 = pd.DataFrame(np.column_stack([obj_dml_data.data['X1'] <= -1.0,
229236
obj_dml_data.data['X1'] > 0.2]),
230237
columns=['Group 1', 'Group 2'])
231238
msg = ('At least one group effect is estimated with less than 6 observations.')
232239
with pytest.warns(UserWarning, match=msg):
233-
gapo_1 = dml_obj.gapo(groups_1)
240+
gapo_1 = dml_obj.gapo(groups_1, cov_type=cov_type)
234241
assert isinstance(gapo_1, dml.utils.blp.DoubleMLBLP)
235242
assert isinstance(gapo_1.confint(), pd.DataFrame)
236243
assert all(gapo_1.confint().index == groups_1.columns.to_list())
244+
assert gapo_1.blp_model.cov_type == cov_type
237245

238246
np.random.seed(42)
239247
groups_2 = pd.DataFrame(np.random.choice(["1", "2"], n, p=[0.1, 0.9]))
240248
msg = ('At least one group effect is estimated with less than 6 observations.')
241249
with pytest.warns(UserWarning, match=msg):
242-
gapo_2 = dml_obj.gapo(groups_2)
250+
gapo_2 = dml_obj.gapo(groups_2, cov_type=cov_type)
243251
assert isinstance(gapo_2, dml.utils.blp.DoubleMLBLP)
244252
assert isinstance(gapo_2.confint(), pd.DataFrame)
245253
assert all(gapo_2.confint().index == ["Group_1", "Group_2"])
254+
assert gapo_2.blp_model.cov_type == cov_type

doubleml/irm/tests/test_irm.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,14 @@ def test_dml_irm_sensitivity_rho0(dml_irm_fixture):
187187
rtol=1e-9, atol=1e-4)
188188

189189

190+
@pytest.fixture(scope='module',
191+
params=["nonrobust", "HC0", "HC1", "HC2", "HC3"])
192+
def cov_type(request):
193+
return request.param
194+
195+
190196
@pytest.mark.ci
191-
def test_dml_irm_cate_gate():
197+
def test_dml_irm_cate_gate(cov_type):
192198
n = 9
193199
# collect data
194200
np.random.seed(42)
@@ -207,28 +213,31 @@ def test_dml_irm_cate_gate():
207213
dml_irm_obj.fit()
208214
# create a random basis
209215
random_basis = pd.DataFrame(np.random.normal(0, 1, size=(n, 5)))
210-
cate = dml_irm_obj.cate(random_basis)
216+
cate = dml_irm_obj.cate(random_basis, cov_type=cov_type)
211217
assert isinstance(cate, dml.utils.blp.DoubleMLBLP)
212218
assert isinstance(cate.confint(), pd.DataFrame)
219+
assert cate.blp_model.cov_type == cov_type
213220

214221
groups_1 = pd.DataFrame(np.column_stack([obj_dml_data.data['X1'] <= 0,
215222
obj_dml_data.data['X1'] > 0.2]),
216223
columns=['Group 1', 'Group 2'])
217224
msg = ('At least one group effect is estimated with less than 6 observations.')
218225
with pytest.warns(UserWarning, match=msg):
219-
gate_1 = dml_irm_obj.gate(groups_1)
226+
gate_1 = dml_irm_obj.gate(groups_1, cov_type=cov_type)
220227
assert isinstance(gate_1, dml.utils.blp.DoubleMLBLP)
221228
assert isinstance(gate_1.confint(), pd.DataFrame)
222229
assert all(gate_1.confint().index == groups_1.columns.to_list())
230+
assert gate_1.blp_model.cov_type == cov_type
223231

224232
np.random.seed(42)
225233
groups_2 = pd.DataFrame(np.random.choice(["1", "2"], n))
226234
msg = ('At least one group effect is estimated with less than 6 observations.')
227235
with pytest.warns(UserWarning, match=msg):
228-
gate_2 = dml_irm_obj.gate(groups_2)
236+
gate_2 = dml_irm_obj.gate(groups_2, cov_type=cov_type)
229237
assert isinstance(gate_2, dml.utils.blp.DoubleMLBLP)
230238
assert isinstance(gate_2.confint(), pd.DataFrame)
231239
assert all(gate_2.confint().index == ["Group_1", "Group_2"])
240+
assert gate_2.blp_model.cov_type == cov_type
232241

233242

234243
@pytest.fixture(scope='module',

doubleml/plm/plr.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ def _nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune, n_
341341

342342
return res
343343

344-
def cate(self, basis, is_gate=False):
344+
def cate(self, basis, is_gate=False, **kwargs):
345345
"""
346346
Calculate conditional average treatment effects (CATE) for a given basis.
347347
@@ -350,10 +350,14 @@ def cate(self, basis, is_gate=False):
350350
basis : :class:`pandas.DataFrame`
351351
The basis for estimating the best linear predictor. Has to have the shape ``(n_obs, d)``,
352352
where ``n_obs`` is the number of observations and ``d`` is the number of predictors.
353+
353354
is_gate : bool
354355
Indicates whether the basis is constructed for GATEs (dummy-basis).
355356
Default is ``False``.
356357
358+
**kwargs: dict
359+
Additional keyword arguments to be passed to :meth:`statsmodels.regression.linear_model.OLS.fit` e.g. ``cov_type``.
360+
357361
Returns
358362
-------
359363
model : :class:`doubleML.DoubleMLBLP`
@@ -374,10 +378,10 @@ def cate(self, basis, is_gate=False):
374378
basis=D_basis,
375379
is_gate=is_gate,
376380
)
377-
model.fit()
381+
model.fit(**kwargs)
378382
return model
379383

380-
def gate(self, groups):
384+
def gate(self, groups, **kwargs):
381385
"""
382386
Calculate group average treatment effects (GATE) for groups.
383387
@@ -388,6 +392,9 @@ def gate(self, groups):
388392
Has to be dummy coded with shape ``(n_obs, d)``, where ``n_obs`` is the number of observations
389393
and ``d`` is the number of groups or ``(n_obs, 1)`` and contain the corresponding groups (as str).
390394
395+
**kwargs: dict
396+
Additional keyword arguments to be passed to :meth:`statsmodels.regression.linear_model.OLS.fit` e.g. ``cov_type``.
397+
391398
Returns
392399
-------
393400
model : :class:`doubleML.DoubleMLBLP`
@@ -407,7 +414,7 @@ def gate(self, groups):
407414
if any(groups.sum(0) <= 5):
408415
warnings.warn('At least one group effect is estimated with less than 6 observations.')
409416

410-
model = self.cate(groups, is_gate=True)
417+
model = self.cate(groups, is_gate=True, **kwargs)
411418
return model
412419

413420
def _partial_out(self):

doubleml/plm/tests/test_plr.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -301,8 +301,14 @@ def test_dml_plr_ols_manual_boot(dml_plr_ols_manual_fixture):
301301
rtol=1e-9, atol=1e-4)
302302

303303

304+
@pytest.fixture(scope='module',
305+
params=["nonrobust", "HC0", "HC1", "HC2", "HC3"])
306+
def cov_type(request):
307+
return request.param
308+
309+
304310
@pytest.mark.ci
305-
def test_dml_plr_cate_gate(score):
311+
def test_dml_plr_cate_gate(score, cov_type):
306312
n = 9
307313

308314
# collect data
@@ -318,26 +324,29 @@ def test_dml_plr_cate_gate(score):
318324
score=score)
319325
dml_plr_obj.fit()
320326
random_basis = pd.DataFrame(np.random.normal(0, 1, size=(n, 5)))
321-
cate = dml_plr_obj.cate(random_basis)
327+
cate = dml_plr_obj.cate(random_basis, cov_type=cov_type)
322328
assert isinstance(cate, dml.DoubleMLBLP)
323329
assert isinstance(cate.confint(), pd.DataFrame)
330+
assert cate.blp_model.cov_type == cov_type
324331

325332
groups_1 = pd.DataFrame(
326333
np.column_stack([obj_dml_data.data['X1'] <= 0,
327334
obj_dml_data.data['X1'] > 0.2]),
328335
columns=['Group 1', 'Group 2'])
329336
msg = ('At least one group effect is estimated with less than 6 observations.')
330337
with pytest.warns(UserWarning, match=msg):
331-
gate_1 = dml_plr_obj.gate(groups_1)
338+
gate_1 = dml_plr_obj.gate(groups_1, cov_type=cov_type)
332339
assert isinstance(gate_1, dml.utils.blp.DoubleMLBLP)
333340
assert isinstance(gate_1.confint(), pd.DataFrame)
334341
assert all(gate_1.confint().index == groups_1.columns.tolist())
342+
assert gate_1.blp_model.cov_type == cov_type
335343

336344
np.random.seed(42)
337345
groups_2 = pd.DataFrame(np.random.choice(["1", "2"], n))
338346
msg = ('At least one group effect is estimated with less than 6 observations.')
339347
with pytest.warns(UserWarning, match=msg):
340-
gate_2 = dml_plr_obj.gate(groups_2)
348+
gate_2 = dml_plr_obj.gate(groups_2, cov_type=cov_type)
341349
assert isinstance(gate_2, dml.utils.blp.DoubleMLBLP)
342350
assert isinstance(gate_2.confint(), pd.DataFrame)
343351
assert all(gate_2.confint().index == ["Group_1", "Group_2"])
352+
assert gate_2.blp_model.cov_type == cov_type

doubleml/utils/blp.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,18 +110,27 @@ def summary(self):
110110
columns=col_names)
111111
return df_summary
112112

113-
def fit(self):
113+
def fit(self, cov_type='HC0', **kwargs):
114114
"""
115115
Estimate DoubleMLBLP models.
116116
117+
Parameters
118+
----------
119+
cov_type : str
120+
The covariance type to be used in the estimation. Default is ``'HC0'``.
121+
See :meth:`statsmodels.regression.linear_model.OLS.fit` for more information.
122+
123+
**kwargs: dict
124+
Additional keyword arguments to be passed to :meth:`statsmodels.regression.linear_model.OLS.fit`.
125+
117126
Returns
118127
-------
119128
self : object
120129
"""
121130

122131
# fit the best-linear-predictor of the orthogonal signal with respect to the grid
123-
self._blp_model = sm.OLS(self._orth_signal, self._basis).fit()
124-
self._blp_omega = self._blp_model.cov_HC0
132+
self._blp_model = sm.OLS(self._orth_signal, self._basis).fit(cov_type=cov_type, **kwargs)
133+
self._blp_omega = self._blp_model.cov_params().to_numpy()
125134

126135
return self
127136

doubleml/utils/tests/_utils_blp_manual.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
import pandas as pd
66

77

8-
def fit_blp(orth_signal, basis):
9-
blp_model = sm.OLS(orth_signal, basis).fit()
8+
def fit_blp(orth_signal, basis, cov_type, **kwargs):
9+
blp_model = sm.OLS(orth_signal, basis).fit(cov_type=cov_type, **kwargs)
1010

1111
return blp_model
1212

@@ -15,7 +15,7 @@ def blp_confint(blp_model, basis, joint=False, level=0.95, n_rep_boot=500):
1515
alpha = 1 - level
1616
g_hat = blp_model.predict(basis)
1717

18-
blp_omega = blp_model.cov_HC0
18+
blp_omega = blp_model.cov_params().to_numpy()
1919

2020
blp_se = np.sqrt((basis.dot(blp_omega) * basis).sum(axis=1))
2121

0 commit comments

Comments
 (0)