Skip to content

Commit

Permalink
#76 fix tests which strangely warn locally but fail remotely
Browse files Browse the repository at this point in the history
  • Loading branch information
drbenvincent committed Dec 26, 2022
1 parent bcf38f9 commit 0ed7240
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 17 deletions.
12 changes: 6 additions & 6 deletions causalpy/skl_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,30 +190,30 @@ def __init__(
self.y, self.X = np.asarray(y), np.asarray(X)
self.outcome_variable_name = y.design_info.column_names[0]

# TODO: `treated` is a deterministic function of group and time, so this should
# be a function rather than supplied data

# fit the model to all the data
self.prediction_model.fit(X=self.X, y=self.y)

# predicted outcome for control group
self.x_pred_control = pd.DataFrame(
{"group": [0, 0], "t": [0.0, 1.0], "treated": [0, 0]}
{"group": [0, 0], "t": [0.0, 1.0], "post_treatment": [0, 0]}
)
assert not self.x_pred_control.empty
(new_x,) = build_design_matrices([self._x_design_info], self.x_pred_control)
self.y_pred_control = self.prediction_model.predict(np.asarray(new_x))

# predicted outcome for treatment group
self.x_pred_treatment = pd.DataFrame(
{"group": [1, 1], "t": [0.0, 1.0], "treated": [0, 1]}
{"group": [1, 1], "t": [0.0, 1.0], "post_treatment": [0, 1]}
)
assert not self.x_pred_treatment.empty
(new_x,) = build_design_matrices([self._x_design_info], self.x_pred_treatment)
self.y_pred_treatment = self.prediction_model.predict(np.asarray(new_x))

# predicted outcome for counterfactual
self.x_pred_counterfactual = pd.DataFrame(
{"group": [1], "t": [1.0], "treated": [0]}
{"group": [1], "t": [1.0], "post_treatment": [0]}
)
assert not self.x_pred_counterfactual.empty
(new_x,) = build_design_matrices(
[self._x_design_info], self.x_pred_counterfactual
)
Expand Down
2 changes: 1 addition & 1 deletion causalpy/tests/test_integration_skl_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def test_did():
data = cp.load_data("did")
result = cp.skl_experiments.DifferenceInDifferences(
data,
formula="y ~ 1 + group + t + treated:group",
formula="y ~ 1 + group + t + group:post_treatment",
time_variable_name="t",
prediction_model=LinearRegression(),
)
Expand Down
2 changes: 1 addition & 1 deletion causalpy/tests/test_pymc_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def test_idata_property():
df = cp.load_data("did")
result = cp.pymc_experiments.DifferenceInDifferences(
df,
formula="y ~ 1 + group + t + treated:group",
formula="y ~ 1 + group + t + group:post_treatment",
time_variable_name="t",
group_variable_name="group",
treated=1,
Expand Down
27 changes: 18 additions & 9 deletions docs/notebooks/did_skl.ipynb

Large diffs are not rendered by default.

0 comments on commit 0ed7240

Please sign in to comment.