diff --git a/python/pyspark/pandas/indexes/base.py b/python/pyspark/pandas/indexes/base.py index facedb1dc91c0..4ae39420ad38d 100644 --- a/python/pyspark/pandas/indexes/base.py +++ b/python/pyspark/pandas/indexes/base.py @@ -2666,6 +2666,93 @@ def inferred_type(self) -> str: """ return lib.infer_dtype([self.to_series().head(1).item()]) + def putmask( + self, + mask: "Index", + value: Any + ) -> "Index": + """ + Return a new Index of the values set with the mask. + + Returns + ------- + Index + + See Also + -------- + pandas.index.putmask : Changes elements of an array + based on conditional and input values. + """ + # validate mask + mask = np.asarray(mask.tolist(), dtype=bool) # type: ignore[assignment] + if mask.shape != self.values.shape: + raise ValueError("cond and data must be the same size") + + noop = not mask.any() + + if noop: + return self.copy() + + # convert the insert value, check whether it can be inserted. + converted_other = value + if value is not np.nan: + try: + converted_other = np.dtype(self.dtype).type(value) # type: ignore[arg-type] + except ValueError: + raise ValueError("The inserted value should be in the same dtype.") + + # reconstruct a new index + values = self.values.copy() + new_values = [] + if isinstance(values, np.ndarray): + for val, is_need in zip(values, mask): + if is_need: + new_values.append(converted_other) + else: + new_values.append(val) + else: + new_values.append(converted_other) + + return Index(new_values) + + def where( + self, + cond: "Index", + other: Any = np.nan + ) -> "Index": + """ + Replace values where the condition is False. + + The replacement is taken from other. + + Parameters + ---------- + cond : bool array-like with the same length as self + Condition to select the values on. + other : scalar, or array-like, default None + Replacement if the condition is False. + + Returns + ------- + pandas.Index + A copy of self with values replaced from other + where the condition is False. + + See Also + -------- + Series.where : Same method for Series. + DataFrame.where : Same method for DataFrame. + + Examples + -------- + >>> idx = ps.Index(['car', 'bike', 'train', 'tractor']) + >>> idx + Index(['car', 'bike', 'train', 'tractor'], dtype='object') + >>> idx.where(idx.isin(['car', 'train']), 'other') + Index(['car', 'other', 'train', 'other'], dtype='object') + """ + return self.putmask(~cond, other) + def __getattr__(self, item: str) -> Any: if hasattr(MissingPandasLikeIndex, item): property_or_func = getattr(MissingPandasLikeIndex, item) diff --git a/python/pyspark/pandas/missing/indexes.py b/python/pyspark/pandas/missing/indexes.py index 55050a3e8c967..e5fa407c51a07 100644 --- a/python/pyspark/pandas/missing/indexes.py +++ b/python/pyspark/pandas/missing/indexes.py @@ -55,7 +55,6 @@ class MissingPandasLikeIndex: groupby = _unsupported_function("groupby") is_ = _unsupported_function("is_") join = _unsupported_function("join") - putmask = _unsupported_function("putmask") ravel = _unsupported_function("ravel") reindex = _unsupported_function("reindex") searchsorted = _unsupported_function("searchsorted") @@ -63,7 +62,6 @@ class MissingPandasLikeIndex: slice_locs = _unsupported_function("slice_locs") sortlevel = _unsupported_function("sortlevel") to_flat_index = _unsupported_function("to_flat_index") - where = _unsupported_function("where") is_mixed = _unsupported_function("is_mixed") # Deprecated functions diff --git a/python/pyspark/pandas/tests/indexes/test_base.py b/python/pyspark/pandas/tests/indexes/test_base.py index 5379e512825b0..2d605e55d03a1 100644 --- a/python/pyspark/pandas/tests/indexes/test_base.py +++ b/python/pyspark/pandas/tests/indexes/test_base.py @@ -2511,6 +2511,36 @@ def test_drop_level(self): ): psmidx.droplevel(-3) + def test_where_putmask(self): + pidx = pd.Index([1, 2, 3, 4]) + psidx = ps.from_pandas(pidx) + + # where and putmask with default inserted value np.nan + self.assert_eq( + pidx.where(pidx > 2), + psidx.where(psidx > 2) + ) + self.assert_eq( + pidx.putmask(pidx > 2, 99), + psidx.putmask(psidx > 2, 99) + ) + + # where and putmask with isin func + self.assert_eq( + pidx.where(pidx.isin([1, 2])), + psidx.where(psidx.isin([1, 2])) + ) + self.assert_eq( + pidx.putmask(pidx.isin([1, 2]), 99), + psidx.putmask(psidx.isin([1, 2]), 99) + ) + + # negative + self.assertRaises( + ValueError, + lambda: psidx.where(psidx > 2, "True"), + ) + if __name__ == "__main__": from pyspark.pandas.tests.indexes.test_base import * # noqa: F401