Skip to content

Commit

Permalink
#76 DiD tests now pass
Browse files (browse the repository at this point in the history)
  • Loading branch information
drbenvincent committed Dec 25, 2022
1 parent a0faff9 commit 0ab2fcf
Show file tree
Hide file tree
Showing 7 changed files with 294 additions and 153 deletions.
82 changes: 41 additions & 41 deletions causalpy/data/did.csv
Original file line number Diff line number Diff line change
@@ -1,41 +1,41 @@
group,t,unit,treated,y
0,0.0,0,0,1.037235444367556
0,1.0,0,0,2.1803326054240513
1,0.0,1,0,1.1815211596102946
1,1.0,1,1,2.5731948057471734
0,0.0,2,0,1.237781412485492
0,1.0,2,0,2.064583683807223
1,0.0,3,0,1.186896528144606
1,1.0,3,1,2.7215532618312173
0,0.0,4,0,1.0649519874697355
0,1.0,4,0,1.9612022680643093
1,0.0,5,0,1.2657299075634194
1,1.0,5,1,2.5508204631468674
0,0.0,6,0,0.8947560664459198
0,1.0,6,0,2.227724135358723
1,0.0,7,0,1.3074586207263057
1,1.0,7,1,2.6021177943564844
0,0.0,8,0,1.1845042721745236
0,1.0,8,0,2.1371357945762255
1,0.0,9,0,1.277659512523703
1,1.0,9,1,2.7971363729134455
0,0.0,10,0,0.948046520978673
0,1.0,10,0,1.9911586181231065
1,0.0,11,0,1.2956793345692803
1,1.0,11,1,2.714212580309264
0,0.0,12,0,1.0840699944593897
0,1.0,12,0,1.9949161598698812
1,0.0,13,0,1.279213688044527
1,1.0,13,1,2.781563007268219
0,0.0,14,0,0.9987011891791635
0,1.0,14,0,1.8914366349764102
1,0.0,15,0,1.2112578927664674
1,1.0,15,1,2.7420363802422196
0,0.0,16,0,0.993752136853551
0,1.0,16,0,2.272692180324228
1,0.0,17,0,1.1786513493076058
1,1.0,17,1,2.69965381847017
0,0.0,18,0,1.0980883419399656
0,1.0,18,0,1.9685015295514094
1,0.0,19,0,1.3616585803269048
1,1.0,19,1,2.591156615919988
group,t,unit,post_treatment,y
0,0.0,0,False,0.897122432901507
0,1.0,0,True,1.9612135788421983
1,0.0,1,False,1.2335249009813691
1,1.0,1,True,2.7527941327437286
0,0.0,2,False,1.149207391077308
0,1.0,2,True,1.9107194958946412
1,0.0,3,False,1.2096028435304764
1,1.0,3,True,2.7870530562317772
0,0.0,4,False,1.0182211686591378
0,1.0,4,True,2.1355782741951903
1,0.0,5,False,1.2566023467285772
1,1.0,5,True,2.6352164140993417
0,0.0,6,False,1.1206312917156163
0,1.0,6,True,2.0293786635661104
1,0.0,7,False,1.2253914316635341
1,1.0,7,True,2.836234979171606
0,0.0,8,False,1.0937901142584816
0,1.0,8,True,2.0046646527573992
1,0.0,9,False,1.1311676279399658
1,1.0,9,True,2.597416938762001
0,0.0,10,False,1.1338268148431594
0,1.0,10,True,2.0396150424632604
1,0.0,11,False,1.2769574784336464
1,1.0,11,True,2.7237901979669057
0,0.0,12,False,1.0548219817786735
0,1.0,12,True,2.0966644540989554
1,0.0,13,False,1.2941834769826859
1,1.0,13,True,2.828746461772019
0,0.0,14,False,1.0011352011986534
0,1.0,14,True,2.2367233120727237
1,0.0,15,False,1.2621457689408864
1,1.0,15,True,2.737756363134591
0,0.0,16,False,1.0613566957247114
0,1.0,16,True,2.105012700050028
1,0.0,17,False,1.228130146156384
1,1.0,17,True,2.6887857541638813
0,0.0,18,False,1.2259823349004162
0,1.0,18,True,2.097530059810398
1,0.0,19,False,1.263074342393256
1,1.0,19,True,2.697326984058356
13 changes: 6 additions & 7 deletions causalpy/data/simulate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,17 +154,16 @@ def generate_did():
intervention_time = 0.5

# local functions
def outcome(t, control_intercept, treat_intercept_delta, trend, Δ, group, treated):
def outcome(
t, control_intercept, treat_intercept_delta, trend, Δ, group, post_treatment
):
return (
control_intercept
+ (treat_intercept_delta * group)
+ (t * trend)
+ (Δ * treated * group)
+ (Δ * post_treatment * group)
)

def _is_treated(t, intervention_time, group):
return (t > intervention_time) * group

df = pd.DataFrame(
{
"group": [0, 0, 1, 1] * 10,
Expand All @@ -173,7 +172,7 @@ def _is_treated(t, intervention_time, group):
}
)

df["treated"] = _is_treated(df["t"], intervention_time, df["group"])
df["post_treatment"] = df["t"] > intervention_time

df["y"] = outcome(
df["t"],
Expand All @@ -182,7 +181,7 @@ def _is_treated(t, intervention_time, group):
trend,
Δ,
df["group"],
df["treated"],
df["post_treatment"],
)
df["y"] += rng.normal(0, 0.1, df.shape[0])
return df
Expand Down
9 changes: 6 additions & 3 deletions causalpy/pymc_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,29 +298,31 @@ def __init__(
self.x_pred_control = (
self.data
# just the untreated group
.query(f"district == '{self.untreated}'")
.query(f"{self.group_variable_name} == @self.untreated") # 🔥
# drop the outcome variable
.drop(self.outcome_variable_name, axis=1)
)
assert not self.x_pred_control.empty
(new_x,) = build_design_matrices([self._x_design_info], self.x_pred_control)
self.y_pred_control = self.prediction_model.predict(np.asarray(new_x))

# predicted outcome for treatment group
self.x_pred_treatment = (
self.data
# just the treated group
.query(f"district == '{self.treated}'")
.query(f"{self.group_variable_name} == @self.treated") # 🔥
# drop the outcome variable
.drop(self.outcome_variable_name, axis=1)
)
assert not self.x_pred_treatment.empty
(new_x,) = build_design_matrices([self._x_design_info], self.x_pred_treatment)
self.y_pred_treatment = self.prediction_model.predict(np.asarray(new_x))

# predicted outcome for counterfactual
self.x_pred_counterfactual = (
self.data
# just the treated group
.query(f"district == '{self.treated}'")
.query(f"{self.group_variable_name} == @self.treated") # 🔥
# just the treatment period(s)
# TODO: the line below might need some work to be more robust
.query("post_treatment == True")
Expand All @@ -329,6 +331,7 @@ def __init__(
# DO AN INTERVENTION. Set the post_treatment variable to False
.assign(post_treatment=False)
)
assert not self.x_pred_counterfactual.empty
(new_x,) = build_design_matrices(
[self._x_design_info], self.x_pred_counterfactual
)
Expand Down
7 changes: 4 additions & 3 deletions causalpy/tests/test_integration_pymc_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def test_did():
df = cp.load_data("did")
result = cp.pymc_experiments.DifferenceInDifferences(
df,
formula="y ~ 1 + group + t + treated:group",
formula="y ~ 1 + group + t + group:post_treatment",
time_variable_name="t",
group_variable_name="group",
treated=1,
Expand All @@ -26,6 +26,7 @@ def test_did():

@pytest.mark.integration
def test_did_banks():
treatment_time = 1930.5
df = (
cp.load_data("banks")
.filter(items=["bib6", "bib8", "year"])
Expand All @@ -43,10 +44,10 @@ def test_did_banks():
).sort_values("year")
df_long["district"] = df_long["district"].astype("category")
df_long["unit"] = df_long["district"]
df_long["treated"] = (df_long.year >= 1931) & (df_long.district == "Sixth District")
df_long["post_treatment"] = df_long.year >= treatment_time
result = cp.pymc_experiments.DifferenceInDifferences(
df_long[df_long.year.isin([1930, 1931])],
formula="bib ~ 1 + district + year + district:treated",
formula="bib ~ 1 + district + year + district:post_treatment",
time_variable_name="year",
group_variable_name="district",
treated="Sixth District",
Expand Down
81 changes: 46 additions & 35 deletions docs/notebooks/did_pymc.ipynb

Large diffs are not rendered by default.

249 changes: 188 additions & 61 deletions docs/notebooks/did_pymc_banks.ipynb

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions img/interrogate_badge.svg
Sorry, this file could not be displayed (it failed to load or is invalid).

0 comments on commit 0ab2fcf

Please sign in to comment.