Skip to content

Commit

Permalink
[ENH] explicit default parameter for case_when (#1165)
Browse files Browse the repository at this point in the history
* add if_else function as standalone for use in pandas.assign

* changelog

* Update janitor/functions/case_when.py

Co-authored-by: 40% <Zeroto521@gmail.com>

* Update janitor/functions/case_when.py

Co-authored-by: 40% <Zeroto521@gmail.com>

* changelog

* updates based on feedback

* if_else addition to docs

* drop if_else idea; make default parameter mandatory

* changelog

* add deprecation warning

* single function for checks and computation - separation unnecessary

* clean up tests

* update tests

Co-authored-by: sammychoco <samuel.oranyeli@slalom.com>
Co-authored-by: 40% <Zeroto521@gmail.com>
Co-authored-by: root <root@45d364731ba2>
  • Loading branch information
4 people authored Oct 11, 2022
1 parent 27b3201 commit 352977c
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 110 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
- [BUG] Avoid `change_type` mutating original `DataFrame`. PR #1162 @Zeroto521
- [ENH] The parameter `column_name` of `change_type` totally supports inputing multi-column now. #1163 @Zeroto521
- [ENH] Fix error when `sort_by_appearance=True` is combined with `dropna=True`. Issue #1168 @samukweku
- [ENH] Add explicit default parameter to `case_when` function. Issue #1159 @samukweku


## [v0.23.1] - 2022-05-03

Expand Down
149 changes: 74 additions & 75 deletions janitor/functions/case_when.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
from itertools import count
from pandas.core.common import apply_if_callable
from pandas.api.types import is_list_like
from typing import Any
import pandas_flavor as pf
import pandas as pd

from pandas.api.types import is_scalar
import warnings
from janitor.utils import check

warnings.simplefilter("always", DeprecationWarning)


@pf.register_dataframe_method
def case_when(df: pd.DataFrame, *args, column_name: str) -> pd.DataFrame:
def case_when(
df: pd.DataFrame, *args, default: Any = None, column_name: str
) -> pd.DataFrame:
"""
Create a column based on a condition or multiple conditions.
Expand All @@ -33,8 +37,8 @@ def case_when(df: pd.DataFrame, *args, column_name: str) -> pd.DataFrame:
>>> df.case_when(
... ((df.a == 0) & (df.b != 0)) | (df.c == "wait"), df.a,
... (df.b == 0) & (df.a == 0), "x",
... df.c,
... column_name="value",
... default = df.c,
... column_name = "value",
... )
a b c value
0 0 0 6 x
Expand Down Expand Up @@ -90,7 +94,7 @@ def case_when(df: pd.DataFrame, *args, column_name: str) -> pd.DataFrame:
:param df: A pandas DataFrame.
:param args: Variable argument of conditions and expected values.
Takes the form
`condition0`, `value0`, `condition1`, `value1`, ..., `default`.
`condition0`, `value0`, `condition1`, `value1`, ... .
`condition` can be a 1-D boolean array, a callable, or a string.
If `condition` is a callable, it should evaluate
to a 1-D boolean array. The array should have the same length
Expand All @@ -99,84 +103,67 @@ def case_when(df: pd.DataFrame, *args, column_name: str) -> pd.DataFrame:
`result` can be a scalar, a 1-D array, or a callable.
If `result` is a callable, it should evaluate to a 1-D array.
For a 1-D array, it should have the same length as the DataFrame.
The `default` argument applies if none of `condition0`,
`condition1`, ..., evaluates to `True`.
Value can be a scalar, a callable, or a 1-D array. if `default` is a
callable, it should evaluate to a 1-D array.
:param default: scalar, 1-D array or callable.
This is the element inserted in the output
when all conditions evaluate to False.
If callable, it should evaluate to a 1-D array.
The 1-D array should be the same length as the DataFrame.
:param column_name: Name of column to assign results to. A new column
is created, if it does not already exist in the DataFrame.
:raises ValueError: If the condition fails to evaluate.
:raises ValueError: if condition/value fails to evaluate.
:returns: A pandas DataFrame.
"""
conditions, targets, default = _case_when_checks(df, args, column_name)

if len(conditions) == 1:
default = default.mask(conditions[0], targets[0])
return df.assign(**{column_name: default})

# ensures value assignment is on a first come basis
conditions = conditions[::-1]
targets = targets[::-1]
for condition, value, index in zip(conditions, targets, count()):
try:
default = default.mask(condition, value)
# error `feedoff` idea from SO
# https://stackoverflow.com/a/46091127/7175713
except Exception as e:
raise ValueError(
f"condition{index} and value{index} failed to evaluate. "
f"Original error message: {e}"
) from e

return df.assign(**{column_name: default})


def _case_when_checks(df: pd.DataFrame, args, column_name):
"""
Preliminary checks on the case_when function.
"""
if len(args) < 3:
raise ValueError(
"At least three arguments are required for the `args` parameter."
)
if len(args) % 2 != 1:
# Preliminary checks on the case_when function.
# The bare minimum checks are done; the remaining checks
# are done within `pd.Series.mask`.
check("column_name", column_name, [str])
len_args = len(args)
if len_args < 2:
raise ValueError(
"It seems the `default` argument is missing from the variable "
"`args` parameter."
"At least two arguments are required for the `args` parameter"
)

check("column_name", column_name, [str])

*args, default = args
if len_args % 2:
if default is None:
warnings.warn(
"The last argument in the variable arguments "
"has been assigned as the default. "
"Note however that this will be deprecated "
"in a future release; use an even number "
"of boolean conditions and values, "
"and pass the default argument to the `default` "
"parameter instead.",
DeprecationWarning,
stacklevel=2,
)
*args, default = args
else:
raise ValueError(
"The number of conditions and values do not match. "
f"There are {len_args - len_args//2} conditions "
f"and {len_args//2} values."
)

booleans = []
replacements = []

for index, value in enumerate(args):
if index % 2 == 0:
booleans.append(value)
else:
if index % 2:
if callable(value):
value = apply_if_callable(value, df)
replacements.append(value)

conditions = []
for condition in booleans:
if callable(condition):
condition = apply_if_callable(condition, df)
elif isinstance(condition, str):
condition = df.eval(condition)
conditions.append(condition)

targets = []
for replacement in replacements:
if callable(replacement):
replacement = apply_if_callable(replacement, df)
targets.append(replacement)
else:
if callable(value):
value = apply_if_callable(value, df)
elif isinstance(value, str):
value = df.eval(value)
booleans.append(value)

if callable(default):
default = apply_if_callable(default, df)
if not is_list_like(default):
if is_scalar(default):
default = pd.Series([default]).repeat(len(df))
default.index = df.index
if not hasattr(default, "shape"):
default = pd.Series([*default])
if isinstance(default, pd.Index):
Expand All @@ -185,14 +172,26 @@ def _case_when_checks(df: pd.DataFrame, args, column_name):
arr_ndim = default.ndim
if arr_ndim != 1:
raise ValueError(
"The `default` argument should either be a 1-D array, a scalar, "
"The argument for the `default` parameter "
"should either be a 1-D array, a scalar, "
"or a callable that can evaluate to a 1-D array."
)
if not isinstance(default, pd.Series):
default = pd.Series(default)
if default.size != len(df):
raise ValueError(
"The length of the `default` argument should be equal to the "
"length of the DataFrame."
)
return conditions, targets, default
default.index = df.index
# actual computation
# ensures value assignment is on a first come basis
booleans = booleans[::-1]
replacements = replacements[::-1]
for index, (condition, value) in enumerate(zip(booleans, replacements)):
try:
default = default.mask(condition, value)
# error `feedoff` idea from SO
# https://stackoverflow.com/a/46091127/7175713
except Exception as error:
raise ValueError(
f"condition{index} and value{index} failed to evaluate. "
f"Original error message: {error}"
) from error

return df.assign(**{column_name: default})
2 changes: 2 additions & 0 deletions janitor/functions/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from janitor.utils import check_column
import functools

warnings.simplefilter("always", DeprecationWarning)


def unionize_dataframe_categories(
*dataframes, column_names: Optional[Iterable[pd.CategoricalDtype]] = None
Expand Down
Loading

0 comments on commit 352977c

Please sign in to comment.