From 3d8df3ac1222cb697c77c5727a118e2d88a813df Mon Sep 17 00:00:00 2001 From: Chang She Date: Fri, 15 Jun 2012 11:39:19 -0400 Subject: [PATCH 1/2] ENH: overwrite keyword in DataFrame.update --- pandas/core/frame.py | 11 +++++++++-- pandas/tests/test_frame.py | 17 +++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 7a575abc31a0e..baa48e290e4a2 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -3143,7 +3143,7 @@ def combine_first(self, other): combiner = lambda x, y: np.where(isnull(x), y, x) return self.combine(other, combiner) - def update(self, other, join='left'): + def update(self, other, join='left', overwrite=True): """ Modify DataFrame in place using non-NA values from passed DataFrame. Aligns on indices @@ -3152,6 +3152,9 @@ def update(self, other, join='left'): ---------- other : DataFrame join : {'left', 'right', 'outer', 'inner'}, default 'left' + overwrite : boolean, default True + If True then overwrite values for common keys in the calling + frame """ if join != 'left': raise NotImplementedError @@ -3160,7 +3163,11 @@ def update(self, other, join='left'): for col in self.columns: this = self[col].values that = other[col].values - self[col] = np.where(isnull(that), this, that) + if overwrite: + mask = isnull(that) + else: + mask = notnull(this) + self[col] = np.where(mask, this, that) #---------------------------------------------------------------------- # Misc methods diff --git a/pandas/tests/test_frame.py b/pandas/tests/test_frame.py index d9081eb4069cd..d61ce7df02a1d 100644 --- a/pandas/tests/test_frame.py +++ b/pandas/tests/test_frame.py @@ -5235,6 +5235,23 @@ def test_update(self): [1.5, nan, 7.]]) assert_frame_equal(df, expected) + def test_update_nooverwrite(self): + df = DataFrame([[1.5, nan, 3.], + [1.5, nan, 3.], + [1.5, nan, 3], + [1.5, nan, 3]]) + + other = DataFrame([[3.6, 2., np.nan], + [np.nan, np.nan, 7]], index=[1, 3]) + + df.update(other, overwrite=False) + + expected = DataFrame([[1.5, nan, 3], + [1.5, 2, 3], + [1.5, nan, 3], + [1.5, nan, 3.]]) + assert_frame_equal(df, expected) + def test_combineAdd(self): # trivial comb = self.frame.combineAdd(self.frame) From 8130f096f3678feb7ae43047d3ebba9b5322aaaa Mon Sep 17 00:00:00 2001 From: Chang She Date: Fri, 15 Jun 2012 11:52:26 -0400 Subject: [PATCH 2/2] ENH: filter_func keyword to DataFrame.update, related to #1477 --- pandas/core/frame.py | 17 +++++++++++------ pandas/tests/test_frame.py | 18 ++++++++++++++++++ 2 files changed, 29 insertions(+), 6 deletions(-) diff --git a/pandas/core/frame.py b/pandas/core/frame.py index baa48e290e4a2..8b5a128a75530 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -3143,7 +3143,7 @@ def combine_first(self, other): combiner = lambda x, y: np.where(isnull(x), y, x) return self.combine(other, combiner) - def update(self, other, join='left', overwrite=True): + def update(self, other, join='left', overwrite=True, filter_func=None): """ Modify DataFrame in place using non-NA values from passed DataFrame. Aligns on indices @@ -3153,8 +3153,10 @@ def update(self, other, join='left', overwrite=True): other : DataFrame join : {'left', 'right', 'outer', 'inner'}, default 'left' overwrite : boolean, default True - If True then overwrite values for common keys in the calling - frame + If True then overwrite values for common keys in the calling frame + filter_func : callable(1d-array) -> 1d-array, default None + Can choose to replace values other than NA. Return True for values + that should be updated """ if join != 'left': raise NotImplementedError @@ -3163,10 +3165,13 @@ def update(self, other, join='left', overwrite=True): for col in self.columns: this = self[col].values that = other[col].values - if overwrite: - mask = isnull(that) + if filter_func is not None: + mask = -filter_func(this) | isnull(that) else: - mask = notnull(this) + if overwrite: + mask = isnull(that) + else: + mask = notnull(this) self[col] = np.where(mask, this, that) #---------------------------------------------------------------------- diff --git a/pandas/tests/test_frame.py b/pandas/tests/test_frame.py index d61ce7df02a1d..669ce4e91b9dd 100644 --- a/pandas/tests/test_frame.py +++ b/pandas/tests/test_frame.py @@ -5252,6 +5252,24 @@ def test_update_nooverwrite(self): [1.5, nan, 3.]]) assert_frame_equal(df, expected) + def test_update_filtered(self): + df = DataFrame([[1.5, nan, 3.], + [1.5, nan, 3.], + [1.5, nan, 3], + [1.5, nan, 3]]) + + other = DataFrame([[3.6, 2., np.nan], + [np.nan, np.nan, 7]], index=[1, 3]) + + df.update(other, filter_func=lambda x: x > 2) + + expected = DataFrame([[1.5, nan, 3], + [1.5, nan, 3], + [1.5, nan, 3], + [1.5, nan, 7.]]) + assert_frame_equal(df, expected) + + def test_combineAdd(self): # trivial comb = self.frame.combineAdd(self.frame)