Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

limit n_neighbors to n_samples before matching #38

Merged
merged 1 commit into from
Jul 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions causallib/estimation/matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,7 @@ def _get_metric_dict(

return metric_dict

def _kneighbors(self, knn, source_df):
def _kneighbors(self, knn, source_df, n_neighbors):
"""Lookup neighbors in knn object.

Args:
Expand All @@ -575,6 +575,7 @@ def _kneighbors(self, knn, source_df):
original df index.
source_df (pd.DataFrame) : a DataFrame of source data points to use
as "needles" for the knn "haystack."
n_neighbors (int) : number of neighbors to look up in the knn object
for each data point in `source_df`.

Returns:
match_df (pd.DataFrame) : a DataFrame of matches
Expand All @@ -584,7 +585,7 @@ def _kneighbors(self, knn, source_df):
source_array = self._ensure_array_columnlike(source_array)

distances, neighbor_array_indices = knn.learner.kneighbors(
source_array, n_neighbors=self.n_neighbors
source_array, n_neighbors=n_neighbors
)

return self._generate_match_df(
Expand Down Expand Up @@ -673,7 +674,17 @@ def _withreplacement_match(self, X, a):
matches = {} # maps treatment value to list of matches TO that value

for treatment_value, knn in self.treatment_knns_.items():
matches[treatment_value] = self._kneighbors(knn, X)
n_matchable = sum(a==treatment_value)
if n_matchable < self.n_neighbors:
n_neighbors = n_matchable
warnings.warn(
f"Not enough matchable samples in treatment group {treatment_value}. "
f"Reducing `n_neighbors` for this direction to {n_neighbors}."
)
else:
n_neighbors = self.n_neighbors

matches[treatment_value] = self._kneighbors(knn, X, n_neighbors)
# when producing potential outcomes we may want to force the
# value of the observed outcome to be the actual observed
# outcome, and not an average of the k nearest samples.
Expand Down
16 changes: 15 additions & 1 deletion causallib/tests/test_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,4 +640,18 @@ def test_is_pickleable(self):
np.testing.assert_array_equal(
prepickle_estimate, postpickle_estimate)
np.testing.assert_array_equal(
prepickle_covariates, postpickle_covariates)
prepickle_covariates, postpickle_covariates)

def test_matching_one_way_works_even_when_other_is_undefined(self):
    """Run the too-few-neighbors check on the unbalanced fixture.

    The check is repeated for requested neighbor counts of 5, 20 and 50,
    using the `data_serial_unbalanced_x` fixture stored on the test case.
    """
    covariates, treatment, outcome = self.data_serial_unbalanced_x
    for requested_neighbors in [5, 20, 50]:
        self.check_matching_too_few_neighbors_adapts_matches(
            requested_neighbors, covariates, treatment, outcome
        )

def check_matching_too_few_neighbors_adapts_matches(self, n_neighbors, X, a, y):
    """Assert that the number of returned matches is capped per group.

    Fits a `Matching` estimator in "treatment_to_control" mode, matches the
    same data, and compares the largest number of distances found under each
    level-0 index key of the match table against the expected cap.

    Args:
        n_neighbors: requested number of neighbors passed to `Matching`.
        X: covariates passed to `fit` and `match`.
        a: treatment assignment passed to `fit` and `match`.
        y: outcome passed to `fit`.
    """
    estimator = Matching(
        n_neighbors=n_neighbors, matching_mode="treatment_to_control"
    )
    estimator.fit(X, a, y)
    matches = estimator.match(X, a)

    # Length of each row's distance list, then the maximum within each
    # level-0 index group.
    distances_per_row = matches.distances.apply(len)
    max_matches_by_group = distances_per_row.groupby(level=0).max()

    # NOTE(review): self.n and self.k are presumably the fixture's group
    # sizes -- confirm against the test class setUp.
    self.assertEqual(max_matches_by_group[0], min(n_neighbors, self.n))
    self.assertEqual(max_matches_by_group[1], min(n_neighbors, self.k))