@@ -431,7 +431,7 @@ def _nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune, n_
431431
432432 return res
433433
434- def cate (self , basis , is_gate = False ):
434+ def cate (self , basis , is_gate = False , ** kwargs ):
435435 """
436436 Calculate conditional average treatment effects (CATE) for a given basis.
437437
@@ -440,10 +440,14 @@ def cate(self, basis, is_gate=False):
440440 basis : :class:`pandas.DataFrame`
441441 The basis for estimating the best linear predictor. Has to have the shape ``(n_obs, d)``,
442442 where ``n_obs`` is the number of observations and ``d`` is the number of predictors.
443+
443444 is_gate : bool
444445 Indicates whether the basis is constructed for GATEs (dummy-basis).
445446 Default is ``False``.
446447
448+ **kwargs: dict
449+ Additional keyword arguments to be passed to :meth:`statsmodels.regression.linear_model.OLS.fit` e.g. ``cov_type``.
450+
447451 Returns
448452 -------
449453 model : :class:`doubleML.DoubleMLBLP`
@@ -462,10 +466,10 @@ def cate(self, basis, is_gate=False):
462466 orth_signal = self .psi_elements ['psi_b' ].reshape (- 1 )
463467 # fit the best linear predictor
464468 model = DoubleMLBLP (orth_signal , basis = basis , is_gate = is_gate )
465- model .fit ()
469+ model .fit (** kwargs )
466470 return model
467471
468- def gate (self , groups ):
472+ def gate (self , groups , ** kwargs ):
469473 """
470474 Calculate group average treatment effects (GATE) for groups.
471475
@@ -476,6 +480,9 @@ def gate(self, groups):
476480 Has to be dummy coded with shape ``(n_obs, d)``, where ``n_obs`` is the number of observations
477481 and ``d`` is the number of groups or ``(n_obs, 1)`` and contain the corresponding groups (as str).
478482
483+ **kwargs: dict
484+ Additional keyword arguments to be passed to :meth:`statsmodels.regression.linear_model.OLS.fit` e.g. ``cov_type``.
485+
479486 Returns
480487 -------
481488 model : :class:`doubleML.DoubleMLBLP`
@@ -495,7 +502,7 @@ def gate(self, groups):
495502 if any (groups .sum (0 ) <= 5 ):
496503 warnings .warn ('At least one group effect is estimated with less than 6 observations.' )
497504
498- model = self .cate (groups , is_gate = True )
505+ model = self .cate (groups , is_gate = True , ** kwargs )
499506 return model
500507
501508 def policy_tree (self , features , depth = 2 , ** tree_params ):
0 commit comments