Several minor improvements #804
Conversation
kbattocchi commented on Aug 4, 2023 (edited):
- Support direct covariance fitting for DRIV
- Ensure that groups can be passed to DMLIV and DRIV (a usage sketch follows this list)
- Dependency cleanup:
  - Enable newer versions of shap, matplotlib, seaborn, and dowhy
  - Drop support for sklearn<1.0 and enable support for sklearn 1.3
- CI improvements:
  - Run doctests as part of the build
  - Don't fail fast when building packages fails on one platform
  - Store test output in an artifact
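A minimal sketch of the new groups support, using made-up data; DMLIV and its fit signature are econml's public API, but the data and model choices here are illustrative rather than code from this PR:

import numpy as np
from econml.iv.dml import DMLIV

rng = np.random.default_rng(0)
n = 500
groups = np.repeat(np.arange(n // 5), 5)  # five samples per unit
Z = rng.binomial(1, 0.5, size=n)          # binary instrument
T = Z * rng.binomial(1, 0.7, size=n)      # binary treatment driven by Z
X = rng.normal(size=(n, 3))
y = T * (1 + X[:, 0]) + rng.normal(size=n)

est = DMLIV(discrete_treatment=True, discrete_instrument=True)
# passing groups keeps samples from the same unit in the same cross-fitting fold
est.fit(y, T, Z=Z, X=X, groups=groups)
print(est.effect(X[:5]))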
econml/iv/dml/_dml.py (outdated)
@@ -526,7 +526,7 @@ def score(self, Y, T, Z, X=None, W=None, sample_weight=None):
            The MSE of the final CATE model on the new data.
        """
        # Replacing score from _OrthoLearner, to enforce Z to be required and improve the docstring
-       return super().score(Y, T, X=X, W=W, Z=Z, sample_weight=sample_weight)
+       return super().score(Y, T, X=X, W=W, Z=Z, sample_weight=sample_weight, groups=None)
A reviewer commented:
should it be groups=groups here?
kbattocchi replied:
Yes, good catch. (It doesn't affect the results since groups are never used in scoring, but I'll fix it in the next set of changes).
kbattocchi added:
Upon further consideration, I've removed groups from the DMLIV and DRIV scoring methods, because they are never used and so there's no point in including them. The groups argument needs to exist on the nuisance models, because the signatures for fit, predict, and score all need to be compatible for how we do cross-fitting, but there's no need for them to pollute the estimators themselves; indeed, our existing classes like LinearDML do not have groups on their scoring methods.
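To illustrate the signature-compatibility point (a generic sketch of cross-fitting, not econml's actual internals): a driver threads the same keyword set to every nuisance model, so each model must accept groups even if it ignores the value.

import numpy as np
from sklearn.model_selection import GroupKFold

class MeanModel:
    """Toy nuisance model whose fit() accepts (and ignores) groups,
    keeping its signature compatible with the cross-fitting driver."""
    def fit(self, X, y, groups=None):
        self.mean_ = y.mean()
        return self

    def predict(self, X):
        return np.full(len(X), self.mean_)

def crossfit_predictions(model, y, X, groups):
    # The same keywords are passed on every call, which is why the
    # signatures must line up even for arguments a model doesn't use.
    preds = np.empty_like(y, dtype=float)
    for train, test in GroupKFold(n_splits=2).split(X, groups=groups):
        model.fit(X[train], y[train], groups=groups[train])
        preds[test] = model.predict(X[test])
    return preds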
@@ -318,7 +318,7 @@ def predict(self, X=None):
        X = self._transform_X(X, fitting=False)
        return self._model_final.predict(X).reshape((-1,) + self.d_y + self.d_t)

-    def score(self, Y, T, X=None, W=None, Z=None, nuisances=None, sample_weight=None):
+    def score(self, Y, T, X=None, W=None, Z=None, nuisances=None, sample_weight=None, groups=None):
A reviewer commented:
groups=groups?
kbattocchi replied:
This is the method definition, so groups=None is correct.
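To spell out the distinction (an illustrative snippet, not code from the PR): a default of None belongs in the definition, while the call site should forward the caller's value, which is exactly the bug caught above.

class Base:
    def score(self, Y, groups=None):
        return 0.0 if groups is None else 1.0

class Derived(Base):
    def score(self, Y, groups=None):            # definition: a default of None is correct
        # call site: forward the argument; hard-coding groups=None here
        # was the bug flagged in the earlier comment
        return super().score(Y, groups=groups)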
econml/tests/test_dmliv.py (outdated)
        ]

        for est in est_list:
            est.fit(y, T, Z=Z, X=X, W=W, groups=groups)
A reviewer commented:
Is there a way to make sure the groups are actually being used here? To avoid problems like groups accidentally being left as None in the call to super().score() instead of being threaded through from the args.
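One conceivable way to check that (a hypothetical test helper, not code from this PR) is to wrap a first-stage model so that fitting fails loudly whenever groups is dropped; whether econml forwards groups to user-supplied nuisance models is an assumption here.

from sklearn.linear_model import LinearRegression

class GroupsRequiredRegression(LinearRegression):
    """Hypothetical wrapper: raises if the cross-fitting machinery
    fails to thread `groups` through to this nuisance model."""
    def fit(self, X, y, groups=None, **kwargs):
        if groups is None:
            raise AssertionError("groups was not threaded through")
        return super().fit(X, y, **kwargs)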
econml/iv/dml/_dml.py (outdated)
@@ -526,7 +526,7 @@ def score(self, Y, T, Z, X=None, W=None, sample_weight=None):
            The MSE of the final CATE model on the new data.
A reviewer commented:
Minor, since groups aren't really used for scoring, but they are not included in the docstring as parameters.
kbattocchi replied:
As mentioned in a previous comment, I removed groups from scoring on the estimator, since they do nothing.
econml/iv/dml/_dml.py (outdated)
@@ -837,7 +837,7 @@ def fit(self, Y, T, *, Z, X=None, W=None, sample_weight=None, freq_weight=None,
                         sample_weight=sample_weight, freq_weight=freq_weight, sample_var=sample_var, groups=groups,
                         cache_values=cache_values, inference=inference)

-    def score(self, Y, T, Z, X=None, W=None, sample_weight=None):
+    def score(self, Y, T, Z, X=None, W=None, sample_weight=None, groups=None):
        """
        Score the fitted CATE model on a new data set. Generates nuisance parameters
        for the new data set based on the fitted residual nuisance models created at fit time.
A reviewer commented:
groups missing from docstring
kbattocchi replied:
As mentioned, I removed groups from this method.
@@ -1151,7 +1151,7 @@ def test_groups(self):
        est.fit(y, t, groups=groups)

        # test outer grouping
        est = LinearDML(model_y=LinearRegression(), model_t=LinearRegression(), cv=GroupKFold(2))
A reviewer commented:
Is it worth adding some check to verify that a GroupKFold splitter was used under the hood?
        for est in ests_list:
            with self.subTest(est=est):
                # no heterogeneity
A reviewer commented:
Minor question, but is there a benefit to moving this inside the for loop?
kbattocchi replied:
The test passes :-). The default is for fit_cov_directly to be True, which means that the previous random seed doesn't generate results identical to what they were before; that led to a marginal failure on this test, but slightly reorganizing it made it pass again.
Logically, I think this makes more sense anyway: it's weird to have different loops creating two sets of identical subtests that test different things; if you run the tests locally via unittest, you'll see one result per subtest but there won't be any way to tell which was which.
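For reference, the subTest pattern under discussion (a generic sketch, not the PR's test code): each loop iteration reports as its own labeled result, so a failure names the exact case.

import unittest

class TestExample(unittest.TestCase):
    def test_squares(self):
        for n in [1, 2, 3]:
            # Each iteration is a labeled subtest, so a failure reports
            # which n broke rather than an anonymous loop index.
            with self.subTest(n=n):
                self.assertEqual(n * n, n ** 2)

if __name__ == "__main__":
    unittest.main()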
A reviewer commented:
The diff is hard to parse here for some reason, even though the actual changes are minimal, just like in the econml+dowhy version of the notebook. Not sure why; different Jupyter version?
kbattocchi replied:
That was indeed very weird; fixed.
@@ -793,9 +793,12 @@ def test_groups(self):
        est.fit(y, t, W=w, groups=groups)

        # test outer grouping
+       # NOTE: we should ideally use a stratified split with grouping, but sklearn doesn't have one yet
+       # NOTE: StratifiedGroupKFold has a bug when shuffle is True where it doesn't always stratify properly
A reviewer commented:
Is this bug worth worrying about for our users, since our crossfitting uses StratifiedGroupKFold with shuffle=True?
kbattocchi replied:
Hopefully it will be fixed in sklearn, and then it will have the right behavior; until then it's possible that users will run into it (although the buggy behavior only occurs with certain datasets, so hopefully it works most of the time).
However, I don't think there's any good fix on our end: in general we do want to shuffle. It's just that for the purposes of this one test we can ignore shuffling, but that wouldn't be an appropriate substitute in general.
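For reference, the splitter in question (a small sketch with toy data; whether stratification holds exactly under shuffle=True is the dataset-dependent bug discussed above):

import numpy as np
from sklearn.model_selection import StratifiedGroupKFold

X = np.zeros((8, 1))
y = np.array([0, 0, 0, 0, 1, 1, 1, 1])       # labels to stratify on
groups = np.array([0, 0, 1, 1, 2, 2, 3, 3])  # groups to keep intact

cv = StratifiedGroupKFold(n_splits=2, shuffle=True, random_state=0)
for train, test in cv.split(X, y, groups=groups):
    print("test groups:", np.unique(groups[test]), "test labels:", y[test])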
@@ -1151,7 +1151,7 @@ def test_groups(self):
        est.fit(y, t, groups=groups)

        # test outer grouping
-       est = LinearDML(model_y=LinearRegression(), model_t=LinearRegression(), cv=GroupKFold(2))
+       est = LinearDML(model_y=LinearRegression(), model_t=LinearRegression())
        est.fit(y, t, groups=groups)
A reviewer commented:
assert isinstance(est.splitter, GroupKFold)
What about adding something like this, just to protect against the case where groups isn't actually used under the hood?
The reviewer added:
Though it seems like currently we don't save the splitter to our estimators as an attribute.
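Since the splitter isn't stored on the estimator, one hypothetical alternative (assuming econml calls split() on the cv object as passed, which is an assumption here) would be to hand in an instrumented GroupKFold and check that it actually received groups:

from sklearn.model_selection import GroupKFold

class RecordingGroupKFold(GroupKFold):
    """Hypothetical instrumented splitter: records whether split()
    was ever called with non-None groups."""
    received_groups = False

    def split(self, X, y=None, groups=None):
        if groups is not None:
            RecordingGroupKFold.received_groups = True
        return super().split(X, y, groups=groups)

After fitting with cv=RecordingGroupKFold(2), a test could assert RecordingGroupKFold.received_groups.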
A reviewer commented:
Looks good