Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions python/pyspark/pandas/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2666,6 +2666,93 @@ def inferred_type(self) -> str:
"""
return lib.infer_dtype([self.to_series().head(1).item()])

def putmask(
    self,
    mask: "Index",
    value: Any
) -> "Index":
    """
    Return a new Index with the entries selected by ``mask`` replaced by ``value``.

    Parameters
    ----------
    mask : bool array-like with the same length as self
        Positions where the mask is True are overwritten with ``value``.
    value : scalar
        Replacement value. Unless it is NaN, it is cast to the dtype of this
        index before being inserted.

    Returns
    -------
    Index

    Raises
    ------
    ValueError
        If ``mask`` does not have the same shape as the index values, or if
        ``value`` cannot be cast to the dtype of the index.

    See Also
    --------
    numpy.putmask : Changes elements of an array
        based on conditional and input values.
    """
    # Fetch the values once and reuse them for the shape check and the rebuild
    # (presumably each access materializes the index data -- TODO confirm).
    values = self.values

    # validate mask; plain sequences have no ``tolist`` and are used as they are
    raw_mask = mask.tolist() if hasattr(mask, "tolist") else mask
    mask = np.asarray(raw_mask, dtype=bool)  # type: ignore[assignment]
    if mask.shape != values.shape:
        raise ValueError("cond and data must be the same size")

    # Nothing selected: hand back an unchanged copy.
    if not mask.any():
        return self.copy()

    # convert the insert value, check whether it can be inserted.
    # NaN is recognised by value rather than by identity with ``np.nan`` so that
    # ``float("nan")`` and NumPy float NaNs also skip the dtype cast.
    is_nan = isinstance(value, (float, np.floating)) and np.isnan(value)
    converted_other = value
    if not is_nan:
        try:
            converted_other = np.dtype(self.dtype).type(value)  # type: ignore[arg-type]
        except (ValueError, TypeError) as err:
            raise ValueError("The inserted value should be in the same dtype.") from err

    # reconstruct a new index
    new_values = [
        converted_other if is_need else val for val, is_need in zip(values, mask)
    ]

    # NOTE(review): the result is built without passing a name, so the name of
    # this index is not carried over -- confirm whether that is intended.
    return Index(new_values)

def where(
    self,
    cond: "Index",
    other: Any = np.nan
) -> "Index":
    """
    Replace values where the condition is False.

    The replacement is taken from other.

    Parameters
    ----------
    cond : bool array-like with the same length as self
        Condition to select the values on. Entries where it is True are kept.
    other : scalar, default np.nan
        Replacement if the condition is False.

    Returns
    -------
    Index
        A copy of self with values replaced from other
        where the condition is False.

    See Also
    --------
    Series.where : Same method for Series.
    DataFrame.where : Same method for DataFrame.

    Examples
    --------
    >>> idx = ps.Index(['car', 'bike', 'train', 'tractor'])
    >>> idx
    Index(['car', 'bike', 'train', 'tractor'], dtype='object')
    >>> idx.where(idx.isin(['car', 'train']), 'other')
    Index(['car', 'other', 'train', 'other'], dtype='object')
    """
    # Plain Python sequences do not support ``~``; turn them into a boolean
    # array first so that any bool array-like can be used as the condition.
    if isinstance(cond, (list, tuple)):
        cond = np.asarray(cond, dtype=bool)  # type: ignore[assignment]

    # putmask overwrites where its mask is True, so the condition is inverted.
    return self.putmask(~cond, other)

def __getattr__(self, item: str) -> Any:
if hasattr(MissingPandasLikeIndex, item):
property_or_func = getattr(MissingPandasLikeIndex, item)
Expand Down
2 changes: 0 additions & 2 deletions python/pyspark/pandas/missing/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,13 @@ class MissingPandasLikeIndex:
groupby = _unsupported_function("groupby")
is_ = _unsupported_function("is_")
join = _unsupported_function("join")
putmask = _unsupported_function("putmask")
ravel = _unsupported_function("ravel")
reindex = _unsupported_function("reindex")
searchsorted = _unsupported_function("searchsorted")
slice_indexer = _unsupported_function("slice_indexer")
slice_locs = _unsupported_function("slice_locs")
sortlevel = _unsupported_function("sortlevel")
to_flat_index = _unsupported_function("to_flat_index")
where = _unsupported_function("where")
is_mixed = _unsupported_function("is_mixed")

# Deprecated functions
Expand Down
30 changes: 30 additions & 0 deletions python/pyspark/pandas/tests/indexes/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2511,6 +2511,36 @@ def test_drop_level(self):
):
psmidx.droplevel(-3)

def test_where_putmask(self):
    """Check Index.where / Index.putmask against pandas on a small integer index."""
    pidx = pd.Index([1, 2, 3, 4])
    psidx = ps.from_pandas(pidx)

    # Masks built from a comparison: where() with its default replacement,
    # putmask() with an explicit integer.
    expected_where = pidx.where(pidx > 2)
    actual_where = psidx.where(psidx > 2)
    self.assert_eq(expected_where, actual_where)

    expected_putmask = pidx.putmask(pidx > 2, 99)
    actual_putmask = psidx.putmask(psidx > 2, 99)
    self.assert_eq(expected_putmask, actual_putmask)

    # Masks built with isin().
    expected_where = pidx.where(pidx.isin([1, 2]))
    actual_where = psidx.where(psidx.isin([1, 2]))
    self.assert_eq(expected_where, actual_where)

    expected_putmask = pidx.putmask(pidx.isin([1, 2]), 99)
    actual_putmask = psidx.putmask(psidx.isin([1, 2]), 99)
    self.assert_eq(expected_putmask, actual_putmask)

    # Negative case: a replacement that cannot be cast to the index dtype.
    with self.assertRaises(ValueError):
        psidx.where(psidx > 2, "True")


if __name__ == "__main__":
from pyspark.pandas.tests.indexes.test_base import * # noqa: F401
Expand Down