44from collections .abc import Callable
55
66from scipy .stats import norm
7- from rdrobust import rdrobust , rdbwselect
87
98from sklearn .base import clone
109from sklearn .utils .multiclass import type_of_target
1312from doubleml .double_ml import DoubleML
1413from doubleml .utils .resampling import DoubleMLResampling
1514from doubleml .utils ._checks import _check_resampling_specification , _check_supports_sample_weights
15+ from doubleml .rdd ._utils import _is_rdrobust_available
16+
17+ # validate optional rdrobust import
18+ rdrobust = _is_rdrobust_available ()
1619
1720
1821class RDFlex ():
@@ -30,7 +33,7 @@ class RDFlex():
3033 defined as :math:`\\eta_0(X) = (g_0^{+}(X) + g_0^{-}(X))/2`.
3134
3235 ml_m : classifier implementing ``fit()`` and ``predict_proba()`` or None
33- A machine learner implementing ``fit()`` and ``predict_proba()`` methods and support ``sample_weights``(e.g.
36+ A machine learner implementing ``fit()`` and ``predict_proba()`` methods and support ``sample_weights`` (e.g.
3437 :py:class:`sklearn.ensemble.RandomForestClassifier`) for the nuisance functions
3538 :math:`m_0^{\\pm}(X) = E[D|\\text{score}=\\text{cutoff}^{\\pm}, X]`. The adjustment function is then
3639 defined as :math:`\\eta_0(X) = (m_0^{+}(X) + m_0^{-}(X))/2`.
@@ -66,17 +69,29 @@ class RDFlex():
6669 Default is ``cutoff``.
6770
6871 fs_kernel : str
69- Kernel for the first stage estimation. ``uniform``, ``triangular`` and ``epanechnikov``are supported.
72+ Kernel for the first stage estimation. ``uniform``, ``triangular`` and ``epanechnikov`` are supported.
7073 Default is ``triangular``.
7174
7275 **kwargs : kwargs
7376 Key-worded arguments that are not used within RDFlex but directly handed to rdrobust.
7477
7578 Examples
7679 --------
77-
78- Notes
79- -----
80+ >>> import numpy as np
81+ >>> import doubleml as dml
82+ >>> from doubleml.rdd.datasets import make_simple_rdd_data
83+ >>> from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
84+ >>> np.random.seed(123)
85+ >>> data_dict = make_simple_rdd_data(fuzzy=True)
86+ >>> obj_dml_data = dml.DoubleMLData.from_arrays(x=data_dict["X"], y=data_dict["Y"], d=data_dict["D"], s=data_dict["score"])
87+ >>> ml_g = RandomForestRegressor()
88+ >>> ml_m = RandomForestClassifier()
89+ >>> rdflex_obj = dml.rdd.RDFlex(obj_dml_data, ml_g, ml_m, fuzzy=True)
90+ >>> print(rdflex_obj.fit())
91+ Method Coef. S.E. t-stat P>|t| 95% CI
92+ -------------------------------------------------------------------------
93+ Conventional 0.935 0.220 4.244 2.196e-05 [0.503, 1.367]
94+ Robust - - 3.635 2.785e-04 [0.418, 1.396]
8095
8196 """
8297
@@ -112,9 +127,10 @@ def __init__(self,
112127
113128 if h_fs is None :
114129 fuzzy = self ._dml_data .d if self ._fuzzy else None
115- self ._h_fs = rdbwselect (y = obj_dml_data .y ,
116- x = self ._score ,
117- fuzzy = fuzzy ).bws .values .flatten ().max ()
130+ self ._h_fs = rdrobust .rdbwselect (
131+ y = obj_dml_data .y ,
132+ x = self ._score ,
133+ fuzzy = fuzzy ).bws .values .flatten ().max ()
118134 else :
119135 if not isinstance (h_fs , (float )):
120136 raise TypeError ("Initial bandwidth 'h_fs' has to be a float. "
@@ -437,11 +453,13 @@ def _update_weights(self):
437453
438454 def _fit_rdd (self , h = None , b = None ):
439455 if self .fuzzy :
440- rdd_res = rdrobust (y = self ._M_Y [:, self ._i_rep ], x = self ._score ,
441- fuzzy = self ._M_D [:, self ._i_rep ], h = h , b = b , ** self .kwargs )
456+ rdd_res = rdrobust .rdrobust (
457+ y = self ._M_Y [:, self ._i_rep ], x = self ._score ,
458+ fuzzy = self ._M_D [:, self ._i_rep ], h = h , b = b , ** self .kwargs )
442459 else :
443- rdd_res = rdrobust (y = self ._M_Y [:, self ._i_rep ], x = self ._score ,
444- h = h , b = b , ** self .kwargs )
460+ rdd_res = rdrobust .rdrobust (
461+ y = self ._M_Y [:, self ._i_rep ], x = self ._score ,
462+ h = h , b = b , ** self .kwargs )
445463 return rdd_res
446464
447465 def _set_coefs (self , rdd_res , h ):
0 commit comments