Skip to content

Commit

Permalink
Merge pull request #200 from pymc-labs/its-class
Browse files Browse the repository at this point in the history
Add `InterruptedTimeSeries` class
  • Loading branch information
drbenvincent authored Jun 2, 2023
2 parents b61295e + b7f023a commit ca90775
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 66 deletions.
14 changes: 10 additions & 4 deletions causalpy/pymc_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def _input_validation(self, data, treatment_time):
"If data.index is not DatetimeIndex, treatment_time must be pd.Timestamp." # noqa: E501
)

def plot(self):
def plot(self, counterfactual_label="Counterfactual", **kwargs):
"""Plot the results"""
fig, ax = plt.subplots(3, 1, sharex=True, figsize=(7, 8))

Expand All @@ -161,7 +161,7 @@ def plot(self):
plot_hdi_kwargs={"color": "C1"},
)
handles.append((h_line, h_patch))
labels.append("Synthetic control")
labels.append(counterfactual_label)

ax[0].plot(self.datapost.index, self.post_y, "k.")
# Shaded causal effect
Expand Down Expand Up @@ -243,14 +243,20 @@ def summary(self):
self.print_coefficients()


class InterruptedTimeSeries(PrePostFit):
"""Interrupted time series analysis"""

expt_type = "Interrupted Time Series"


class SyntheticControl(PrePostFit):
"""A wrapper around the PrePostFit class"""

expt_type = "Synthetic Control"

def plot(self, plot_predictors=False):
def plot(self, plot_predictors=False, **kwargs):
"""Plot the results"""
fig, ax = super().plot()
fig, ax = super().plot(counterfactual_label="Synthetic control", **kwargs)
if plot_predictors:
# plot control units as well
ax[0].plot(self.datapre.index, self.pre_X, "-", c=[0.8, 0.8, 0.8], zorder=1)
Expand Down
16 changes: 11 additions & 5 deletions causalpy/skl_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(
# cumulative impact post
self.post_impact_cumulative = np.cumsum(self.post_impact)

def plot(self):
def plot(self, counterfactual_label="Counterfactual", **kwargs):
fig, ax = plt.subplots(3, 1, sharex=True, figsize=(7, 8))

ax[0].plot(self.datapre.index, self.pre_y, "k.")
Expand All @@ -84,7 +84,7 @@ def plot(self):
ax[0].plot(
self.datapost.index,
self.post_pred,
label="counterfactual",
label=counterfactual_label,
ls=":",
c="k",
)
Expand All @@ -95,7 +95,7 @@ def plot(self):
self.datapost.index,
self.post_impact,
"k.",
label="counterfactual",
label=counterfactual_label,
)
ax[1].axhline(y=0, c="k")
ax[1].set(title="Causal Impact")
Expand Down Expand Up @@ -151,12 +151,18 @@ def plot_coeffs(self):
)


class InterruptedTimeSeries(PrePostFit):
"""Interrupted time series analysis"""

expt_type = "Interrupted Time Series"


class SyntheticControl(PrePostFit):
"""A wrapper around the PrePostFit class"""

def plot(self, plot_predictors=False):
def plot(self, plot_predictors=False, **kwargs):
"""Plot the results"""
fig, ax = super().plot()
fig, ax = super().plot(counterfactual_label="Synthetic control", **kwargs)
if plot_predictors:
# plot control units as well
ax[0].plot(self.datapre.index, self.pre_X, "-", c=[0.8, 0.8, 0.8], zorder=1)
Expand Down
4 changes: 2 additions & 2 deletions causalpy/tests/test_integration_pymc_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,14 +167,14 @@ def test_its_covid():
.set_index("date")
)
treatment_time = pd.to_datetime("2020-01-01")
result = cp.pymc_experiments.SyntheticControl(
result = cp.pymc_experiments.InterruptedTimeSeries(
df,
treatment_time,
formula="standardize(deaths) ~ 0 + standardize(t) + C(month) + standardize(temp)", # noqa E501
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
)
assert isinstance(df, pd.DataFrame)
assert isinstance(result, cp.pymc_experiments.SyntheticControl)
assert isinstance(result, cp.pymc_experiments.InterruptedTimeSeries)
assert len(result.idata.posterior.coords["chain"]) == sample_kwargs["chains"]
assert len(result.idata.posterior.coords["draw"]) == sample_kwargs["draws"]

Expand Down
Binary file modified docs/source/_static/classes.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
6 changes: 3 additions & 3 deletions docs/source/_static/interrogate_badge.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
92 changes: 43 additions & 49 deletions docs/source/notebooks/its_covid.ipynb

Large diffs are not rendered by default.

11 changes: 8 additions & 3 deletions docs/source/notebooks/sc_skl.ipynb

Large diffs are not rendered by default.

0 comments on commit ca90775

Please sign in to comment.