From b64bb18e164397e5ed0d968957652b5ce0441bba Mon Sep 17 00:00:00 2001 From: Michael M Danziger Date: Wed, 13 Jul 2022 17:24:14 +0300 Subject: [PATCH] limit n_neighbors to n_samples before matching this is to enable one direction of matching even when the other direction is not well-defined, in response to https://github.com/IBM/causallib/issues/37 Signed-off-by: Michael M Danziger --- causallib/estimation/matching.py | 17 ++++++++++++++--- causallib/tests/test_matching.py | 16 +++++++++++++++- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/causallib/estimation/matching.py b/causallib/estimation/matching.py index 27ab9df3..15889d8c 100644 --- a/causallib/estimation/matching.py +++ b/causallib/estimation/matching.py @@ -566,7 +566,7 @@ def _get_metric_dict( return metric_dict - def _kneighbors(self, knn, source_df): + def _kneighbors(self, knn, source_df, n_neighbors): """Lookup neighbors in knn object. Args: @@ -575,6 +575,7 @@ def _kneighbors(self, knn, source_df): original df index. source_df (pd.DataFrame) : a DataFrame of source data points to use as "needles" for the knn "haystack." + n_neighbors Returns: match_df (pd.DataFrame) : a DataFrame of matches @@ -584,7 +585,7 @@ def _kneighbors(self, knn, source_df): source_array = self._ensure_array_columnlike(source_array) distances, neighbor_array_indices = knn.learner.kneighbors( - source_array, n_neighbors=self.n_neighbors + source_array, n_neighbors=n_neighbors ) return self._generate_match_df( @@ -673,7 +674,17 @@ def _withreplacement_match(self, X, a): matches = {} # maps treatment value to list of matches TO that value for treatment_value, knn in self.treatment_knns_.items(): - matches[treatment_value] = self._kneighbors(knn, X) + n_matchable = sum(a==treatment_value) + if n_matchable < self.n_neighbors: + n_neighbors = n_matchable + warnings.warn( + f"Not enough matchable samples in treatment group {treatment_value}. " + f"Reducing `n_neighbors` for this direction to {n_neighbors}." + ) + else: + n_neighbors = self.n_neighbors + + matches[treatment_value] = self._kneighbors(knn, X, n_neighbors) # when producing potential outcomes we may want to force the # value of the observed outcome to be the actual observed # outcome, and not an average of the k nearest samples. diff --git a/causallib/tests/test_matching.py b/causallib/tests/test_matching.py index 51c4275e..77e09b7a 100644 --- a/causallib/tests/test_matching.py +++ b/causallib/tests/test_matching.py @@ -640,4 +640,18 @@ def test_is_pickleable(self): np.testing.assert_array_equal( prepickle_estimate, postpickle_estimate) np.testing.assert_array_equal( - prepickle_covariates, postpickle_covariates) \ No newline at end of file + prepickle_covariates, postpickle_covariates) + + def test_matching_one_way_works_even_when_other_is_undefined(self): + X, a, y = self.data_serial_unbalanced_x + for n_neighbors in [5, 20, 50]: + self.check_matching_too_few_neighbors_adapts_matches(n_neighbors, X, a, y) + + def check_matching_too_few_neighbors_adapts_matches(self, n_neighbors, X, a, y): + matching = Matching(n_neighbors=n_neighbors, matching_mode="treatment_to_control") + matching.fit(X,a,y) + match_df = matching.match(X, a) + n_matches_actual = match_df.distances.apply(len).groupby(level=0).max() + self.assertEqual(n_matches_actual[0], min(n_neighbors, self.n)) + self.assertEqual(n_matches_actual[1], min(n_neighbors, self.k)) + \ No newline at end of file