Skip to content

Commit 2287e65

Browse files
authored
Merge pull request #250 from DoubleML/s-add-irm-apo
Add Average potential outcome model
2 parents 20d9864 + 56ce8cc commit 2287e65

39 files changed

+4112
-332
lines changed

doubleml/__init__.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from .plm.plr import DoubleMLPLR
66
from .plm.pliv import DoubleMLPLIV
77
from .irm.irm import DoubleMLIRM
8+
from .irm.apo import DoubleMLAPO
9+
from .irm.apos import DoubleMLAPOS
810
from .irm.iivm import DoubleMLIIVM
911
from .double_ml_data import DoubleMLData, DoubleMLClusterData
1012
from .did.did import DoubleMLDID
@@ -18,22 +20,26 @@
1820
from .utils.blp import DoubleMLBLP
1921
from .utils.policytree import DoubleMLPolicyTree
2022

21-
__all__ = ['concat',
22-
'DoubleMLFramework',
23-
'DoubleMLPLR',
24-
'DoubleMLPLIV',
25-
'DoubleMLIRM',
26-
'DoubleMLIIVM',
27-
'DoubleMLData',
28-
'DoubleMLClusterData',
29-
'DoubleMLDID',
30-
'DoubleMLDIDCS',
31-
'DoubleMLPQ',
32-
'DoubleMLQTE',
33-
'DoubleMLLPQ',
34-
'DoubleMLCVAR',
35-
'DoubleMLBLP',
36-
'DoubleMLPolicyTree',
37-
'DoubleMLSSM']
23+
__all__ = [
24+
'concat',
25+
'DoubleMLFramework',
26+
'DoubleMLPLR',
27+
'DoubleMLPLIV',
28+
'DoubleMLIRM',
29+
'DoubleMLAPO',
30+
'DoubleMLAPOS',
31+
'DoubleMLIIVM',
32+
'DoubleMLData',
33+
'DoubleMLClusterData',
34+
'DoubleMLDID',
35+
'DoubleMLDIDCS',
36+
'DoubleMLPQ',
37+
'DoubleMLQTE',
38+
'DoubleMLLPQ',
39+
'DoubleMLCVAR',
40+
'DoubleMLBLP',
41+
'DoubleMLPolicyTree',
42+
'DoubleMLSSM'
43+
]
3844

3945
__version__ = importlib.metadata.version('doubleml')

doubleml/datasets.py

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1485,3 +1485,162 @@ def make_ssm_data(n_obs=8000, dim_x=100, theta=1, mar=True, return_type='DoubleM
14851485
return DoubleMLData(data, 'y', 'd', x_cols, 'z', None, 's')
14861486
else:
14871487
raise ValueError('Invalid return_type.')
1488+
1489+
1490+
def make_irm_data_discrete_treatments(n_obs=200, n_levels=3, linear=False, random_state=None, **kwargs):
1491+
"""
1492+
Generates data from a interactive regression (IRM) model with multiple treatment levels (based on an
1493+
underlying continous treatment).
1494+
1495+
The data generating process is defined as follows (similar to the Monte Carlo simulation used
1496+
in Sant'Anna and Zhao (2020)).
1497+
1498+
Let :math:`X= (X_1, X_2, X_3, X_4, X_5)^T \\sim \\mathcal{N}(0, \\Sigma)`, where :math:`\\Sigma` corresponds
1499+
to the identity matrix.
1500+
Further, define :math:`Z_j = (\\tilde{Z_j} - \\mathbb{E}[\\tilde{Z}_j]) / \\sqrt{\\text{Var}(\\tilde{Z}_j)}`,
1501+
where
1502+
1503+
.. math::
1504+
1505+
\\tilde{Z}_1 &= \\exp(0.5 \\cdot X_1)
1506+
1507+
\\tilde{Z}_2 &= 10 + X_2/(1 + \\exp(X_1))
1508+
1509+
\\tilde{Z}_3 &= (0.6 + X_1 \\cdot X_3 / 25)^3
1510+
1511+
\\tilde{Z}_4 &= (20 + X_2 + X_4)^2
1512+
1513+
\\tilde{Z}_5 &= X_5.
1514+
1515+
A continuous treatment :math:`D_{\\text{cont}}` is generated as
1516+
1517+
.. math::
1518+
1519+
D_{\\text{cont}} = \\xi (-Z_1 + 0.5 Z_2 - 0.25 Z_3 - 0.1 Z_4) + \\varepsilon_D,
1520+
1521+
where :math:`\\varepsilon_D \\sim \\mathcal{N}(0,1)` and :math:`\\xi=0.3`. The corresponding treatment
1522+
effect is defined as
1523+
1524+
.. math::
1525+
1526+
\\text{\\theta}(d) = 0.1 \\exp(d) + 10 \\sin(0.7 d) + 2 d - 0.2 d^2.
1527+
1528+
Based on the continous treatment, a discrete treatment :math:`D` is generated as with a baseline level of
1529+
:math:`D=0` and additional levels based on the quantiles of :math:`D_{\\text{cont}}`. The number of levels
1530+
is defined by :math:`n_{\\text{levels}}`. Each level is chosen to have the same probability of being selected.
1531+
1532+
The potential outcomes are defined as
1533+
1534+
.. math::
1535+
1536+
Y(0) &= 210 + 27.4 Z_1 + 13.7 (Z_2 + Z_3 + Z_4) + \\varepsilon_Y
1537+
1538+
Y(1) &= \\text{\\theta}(D_{\\text{cont}}) 1\\{D_{\\text{cont}} > 0\\} + Y(0),
1539+
1540+
where :math:`\\varepsilon_Y \\sim \\mathcal{N}(0,5)`. Further, the observed outcome is defined as
1541+
1542+
.. math::
1543+
1544+
Y = Y(1) 1\\{D > 0\\} + Y(0) 1\\{D = 0\\}.
1545+
1546+
The data is returned as a dictionary with the entries ``x``, ``y``, ``d`` and ``oracle_values``.
1547+
1548+
Parameters
1549+
----------
1550+
n_obs : int
1551+
The number of observations to simulate.
1552+
Default is ``200``.
1553+
1554+
n_levels : int
1555+
The number of treatment levels.
1556+
Default is ``3``.
1557+
1558+
linear : bool
1559+
Indicates whether the true underlying regression is linear.
1560+
Default is ``False``.
1561+
1562+
random_state : int
1563+
Random seed for reproducibility.
1564+
Default is ``42``.
1565+
1566+
Returns
1567+
-------
1568+
res_dict : dictionary
1569+
Dictionary with entries ``x``, ``y``, ``d`` and ``oracle_values``.
1570+
1571+
"""
1572+
if random_state is not None:
1573+
np.random.seed(random_state)
1574+
xi = kwargs.get('xi', 0.3)
1575+
c = kwargs.get('c', 0.0)
1576+
dim_x = kwargs.get('dim_x', 5)
1577+
1578+
if not isinstance(n_levels, int):
1579+
raise ValueError('n_levels must be an integer.')
1580+
if n_levels < 2:
1581+
raise ValueError('n_levels must be at least 2.')
1582+
1583+
# observed covariates
1584+
cov_mat = toeplitz([np.power(c, k) for k in range(dim_x)])
1585+
x = np.random.multivariate_normal(np.zeros(dim_x), cov_mat, size=[n_obs, ])
1586+
1587+
def f_reg(w):
1588+
res = 210 + 27.4*w[:, 0] + 13.7*(w[:, 1] + w[:, 2] + w[:, 3])
1589+
return res
1590+
1591+
def f_treatment(w, xi):
1592+
res = xi * (-w[:, 0] + 0.5*w[:, 1] - 0.25*w[:, 2] - 0.1*w[:, 3])
1593+
return res
1594+
1595+
def treatment_effect(d, scale=15):
1596+
return scale * (1 / (1 + np.exp(-d - 1.2 * np.cos(d)))) - 2
1597+
1598+
z_tilde_1 = np.exp(0.5 * x[:, 0])
1599+
z_tilde_2 = 10 + x[:, 1] / (1 + np.exp(x[:, 0]))
1600+
z_tilde_3 = (0.6 + x[:, 0] * x[:, 2]/25)**3
1601+
z_tilde_4 = (20 + x[:, 1] + x[:, 3])**2
1602+
1603+
z_tilde = np.column_stack((z_tilde_1, z_tilde_2, z_tilde_3, z_tilde_4, x[:, 4:]))
1604+
z = (z_tilde - np.mean(z_tilde, axis=0)) / np.std(z_tilde, axis=0)
1605+
1606+
# error terms
1607+
var_eps_y = 5
1608+
eps_y = np.random.normal(loc=0, scale=np.sqrt(var_eps_y), size=n_obs)
1609+
var_eps_d = 1
1610+
eps_d = np.random.normal(loc=0, scale=np.sqrt(var_eps_d), size=n_obs)
1611+
1612+
if linear:
1613+
g = f_reg(x)
1614+
m = f_treatment(x, xi)
1615+
else:
1616+
assert not linear
1617+
g = f_reg(z)
1618+
m = f_treatment(z, xi)
1619+
1620+
cont_d = m + eps_d
1621+
level_bounds = np.quantile(cont_d, q=np.linspace(0, 1, n_levels + 1))
1622+
potential_level = sum([1.0 * (cont_d >= bound) for bound in level_bounds[1:-1]]) + 1
1623+
eta = np.random.uniform(0, 1, size=n_obs)
1624+
d = 1.0 * (eta >= 1/n_levels) * potential_level
1625+
1626+
ite = treatment_effect(cont_d)
1627+
y0 = g + eps_y
1628+
# only treated for d > 0 compared to the baseline
1629+
y = ite * (d > 0) + y0
1630+
1631+
oracle_values = {
1632+
'cont_d': cont_d,
1633+
'level_bounds': level_bounds,
1634+
'potential_level': potential_level,
1635+
'ite': ite,
1636+
'y0': y0,
1637+
}
1638+
1639+
resul_dict = {
1640+
'x': x,
1641+
'y': y,
1642+
'd': d,
1643+
'oracle_values': oracle_values
1644+
}
1645+
1646+
return resul_dict

doubleml/double_ml.py

Lines changed: 6 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1219,6 +1219,8 @@ def set_sample_splitting(self, all_smpls, all_smpls_cluster=None):
12191219
>>> ml_m = learner
12201220
>>> obj_dml_data = make_plr_CCDDHNR2018(n_obs=10, alpha=0.5)
12211221
>>> dml_plr_obj = dml.DoubleMLPLR(obj_dml_data, ml_g, ml_m)
1222+
>>> # simple sample splitting with two folds and without cross-fitting
1223+
>>> smpls = ([0, 1, 2, 3, 4], [5, 6, 7, 8, 9])
12221224
>>> dml_plr_obj.set_sample_splitting(smpls)
12231225
>>> # sample splitting with two folds and cross-fitting
12241226
>>> smpls = [([0, 1, 2, 3, 4], [5, 6, 7, 8, 9]),
@@ -1434,44 +1436,11 @@ def sensitivity_summary(self):
14341436
res : str
14351437
Summary for the sensitivity analysis.
14361438
"""
1437-
header = '================== Sensitivity Analysis ==================\n'
1438-
if self.sensitivity_params is None:
1439-
res = header + 'Apply sensitivity_analysis() to generate sensitivity_summary.'
1439+
if self._framework is None:
1440+
raise ValueError('Apply sensitivity_analysis() before sensitivity_summary.')
14401441
else:
1441-
sig_level = f'Significance Level: level={self.sensitivity_params["input"]["level"]}\n'
1442-
scenario_params = f'Sensitivity parameters: cf_y={self.sensitivity_params["input"]["cf_y"]}; ' \
1443-
f'cf_d={self.sensitivity_params["input"]["cf_d"]}, ' \
1444-
f'rho={self.sensitivity_params["input"]["rho"]}'
1445-
1446-
theta_and_ci_col_names = ['CI lower', 'theta lower', ' theta', 'theta upper', 'CI upper']
1447-
theta_and_ci = np.transpose(np.vstack((self.sensitivity_params['ci']['lower'],
1448-
self.sensitivity_params['theta']['lower'],
1449-
self.coef,
1450-
self.sensitivity_params['theta']['upper'],
1451-
self.sensitivity_params['ci']['upper'])))
1452-
df_theta_and_ci = pd.DataFrame(theta_and_ci,
1453-
columns=theta_and_ci_col_names,
1454-
index=self._dml_data.d_cols)
1455-
theta_and_ci_summary = str(df_theta_and_ci)
1456-
1457-
rvs_col_names = ['H_0', 'RV (%)', 'RVa (%)']
1458-
rvs = np.transpose(np.vstack((self.sensitivity_params['rv'],
1459-
self.sensitivity_params['rva']))) * 100
1460-
1461-
df_rvs = pd.DataFrame(np.column_stack((self.sensitivity_params["input"]["null_hypothesis"], rvs)),
1462-
columns=rvs_col_names,
1463-
index=self._dml_data.d_cols)
1464-
rvs_summary = str(df_rvs)
1465-
1466-
res = header + \
1467-
'\n------------------ Scenario ------------------\n' + \
1468-
sig_level + scenario_params + '\n' + \
1469-
'\n------------------ Bounds with CI ------------------\n' + \
1470-
theta_and_ci_summary + '\n' + \
1471-
'\n------------------ Robustness Values ------------------\n' + \
1472-
rvs_summary
1473-
1474-
return res
1442+
sensitivity_summary = self._framework.sensitivity_summary
1443+
return sensitivity_summary
14751444

14761445
def sensitivity_plot(self, idx_treatment=0, value='theta', rho=1.0, level=0.95, null_hypothesis=0.0,
14771446
include_scenario=True, benchmarks=None, fill=True, grid_bounds=(0.15, 0.15), grid_size=100):

0 commit comments

Comments
 (0)