From 352977c663fb070733df4914ecfe4afaff9e8a6d Mon Sep 17 00:00:00 2001 From: Samuel Oranyeli Date: Wed, 12 Oct 2022 02:21:22 +1100 Subject: [PATCH] [ENH] explicit default parameter for case_when (#1165) * add if_else function as standalone for use in pandas.assign * changelog * Update janitor/functions/case_when.py Co-authored-by: 40% * Update janitor/functions/case_when.py Co-authored-by: 40% * changelog * updates based on feedback * if_else addition to docs * drop if_else idea; make default parameter mandatory * changelog * add deprecation warning * single function for checks and computation - separation unnecessary * clean up tests * update tests Co-authored-by: sammychoco Co-authored-by: 40% Co-authored-by: root --- CHANGELOG.md | 2 + janitor/functions/case_when.py | 149 +++++++++++++++--------------- janitor/functions/utils.py | 2 + tests/functions/test_case_when.py | 119 +++++++++++++++++------- 4 files changed, 162 insertions(+), 110 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1fa9ed78b..c92fee966 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,8 @@ - [BUG] Avoid `change_type` mutating original `DataFrame`. PR #1162 @Zeroto521 - [ENH] The parameter `column_name` of `change_type` totally supports inputing multi-column now. #1163 @Zeroto521 - [ENH] Fix error when `sort_by_appearance=True` is combined with `dropna=True`. Issue #1168 @samukweku +- [ENH] Add explicit default parameter to `case_when` function. Issue #1159 @samukweku + ## [v0.23.1] - 2022-05-03 diff --git a/janitor/functions/case_when.py b/janitor/functions/case_when.py index 37a372fad..7812fc9e8 100644 --- a/janitor/functions/case_when.py +++ b/janitor/functions/case_when.py @@ -1,14 +1,18 @@ -from itertools import count from pandas.core.common import apply_if_callable -from pandas.api.types import is_list_like +from typing import Any import pandas_flavor as pf import pandas as pd - +from pandas.api.types import is_scalar +import warnings from janitor.utils import check +warnings.simplefilter("always", DeprecationWarning) + @pf.register_dataframe_method -def case_when(df: pd.DataFrame, *args, column_name: str) -> pd.DataFrame: +def case_when( + df: pd.DataFrame, *args, default: Any = None, column_name: str +) -> pd.DataFrame: """ Create a column based on a condition or multiple conditions. @@ -33,8 +37,8 @@ def case_when(df: pd.DataFrame, *args, column_name: str) -> pd.DataFrame: >>> df.case_when( ... ((df.a == 0) & (df.b != 0)) | (df.c == "wait"), df.a, ... (df.b == 0) & (df.a == 0), "x", - ... df.c, - ... column_name="value", + ... default = df.c, + ... column_name = "value", ... ) a b c value 0 0 0 6 x @@ -90,7 +94,7 @@ def case_when(df: pd.DataFrame, *args, column_name: str) -> pd.DataFrame: :param df: A pandas DataFrame. :param args: Variable argument of conditions and expected values. Takes the form - `condition0`, `value0`, `condition1`, `value1`, ..., `default`. + `condition0`, `value0`, `condition1`, `value1`, ... . `condition` can be a 1-D boolean array, a callable, or a string. If `condition` is a callable, it should evaluate to a 1-D boolean array. The array should have the same length @@ -99,84 +103,67 @@ def case_when(df: pd.DataFrame, *args, column_name: str) -> pd.DataFrame: `result` can be a scalar, a 1-D array, or a callable. If `result` is a callable, it should evaluate to a 1-D array. For a 1-D array, it should have the same length as the DataFrame. - The `default` argument applies if none of `condition0`, - `condition1`, ..., evaluates to `True`. - Value can be a scalar, a callable, or a 1-D array. if `default` is a - callable, it should evaluate to a 1-D array. + :param default: scalar, 1-D array or callable. + This is the element inserted in the output + when all conditions evaluate to False. + If callable, it should evaluate to a 1-D array. The 1-D array should be the same length as the DataFrame. + :param column_name: Name of column to assign results to. A new column is created, if it does not already exist in the DataFrame. - :raises ValueError: If the condition fails to evaluate. + :raises ValueError: if condition/value fails to evaluate. :returns: A pandas DataFrame. """ - conditions, targets, default = _case_when_checks(df, args, column_name) - - if len(conditions) == 1: - default = default.mask(conditions[0], targets[0]) - return df.assign(**{column_name: default}) - - # ensures value assignment is on a first come basis - conditions = conditions[::-1] - targets = targets[::-1] - for condition, value, index in zip(conditions, targets, count()): - try: - default = default.mask(condition, value) - # error `feedoff` idea from SO - # https://stackoverflow.com/a/46091127/7175713 - except Exception as e: - raise ValueError( - f"condition{index} and value{index} failed to evaluate. " - f"Original error message: {e}" - ) from e - - return df.assign(**{column_name: default}) - - -def _case_when_checks(df: pd.DataFrame, args, column_name): - """ - Preliminary checks on the case_when function. - """ - if len(args) < 3: - raise ValueError( - "At least three arguments are required for the `args` parameter." - ) - if len(args) % 2 != 1: + # Preliminary checks on the case_when function. + # The bare minimum checks are done; the remaining checks + # are done within `pd.Series.mask`. + check("column_name", column_name, [str]) + len_args = len(args) + if len_args < 2: raise ValueError( - "It seems the `default` argument is missing from the variable " - "`args` parameter." + "At least two arguments are required for the `args` parameter" ) - check("column_name", column_name, [str]) - - *args, default = args + if len_args % 2: + if default is None: + warnings.warn( + "The last argument in the variable arguments " + "has been assigned as the default. " + "Note however that this will be deprecated " + "in a future release; use an even number " + "of boolean conditions and values, " + "and pass the default argument to the `default` " + "parameter instead.", + DeprecationWarning, + stacklevel=2, + ) + *args, default = args + else: + raise ValueError( + "The number of conditions and values do not match. " + f"There are {len_args - len_args//2} conditions " + f"and {len_args//2} values." + ) booleans = [] replacements = [] + for index, value in enumerate(args): - if index % 2 == 0: - booleans.append(value) - else: + if index % 2: + if callable(value): + value = apply_if_callable(value, df) replacements.append(value) - - conditions = [] - for condition in booleans: - if callable(condition): - condition = apply_if_callable(condition, df) - elif isinstance(condition, str): - condition = df.eval(condition) - conditions.append(condition) - - targets = [] - for replacement in replacements: - if callable(replacement): - replacement = apply_if_callable(replacement, df) - targets.append(replacement) + else: + if callable(value): + value = apply_if_callable(value, df) + elif isinstance(value, str): + value = df.eval(value) + booleans.append(value) if callable(default): default = apply_if_callable(default, df) - if not is_list_like(default): + if is_scalar(default): default = pd.Series([default]).repeat(len(df)) - default.index = df.index if not hasattr(default, "shape"): default = pd.Series([*default]) if isinstance(default, pd.Index): @@ -185,14 +172,26 @@ def _case_when_checks(df: pd.DataFrame, args, column_name): arr_ndim = default.ndim if arr_ndim != 1: raise ValueError( - "The `default` argument should either be a 1-D array, a scalar, " + "The argument for the `default` parameter " + "should either be a 1-D array, a scalar, " "or a callable that can evaluate to a 1-D array." ) if not isinstance(default, pd.Series): default = pd.Series(default) - if default.size != len(df): - raise ValueError( - "The length of the `default` argument should be equal to the " - "length of the DataFrame." - ) - return conditions, targets, default + default.index = df.index + # actual computation + # ensures value assignment is on a first come basis + booleans = booleans[::-1] + replacements = replacements[::-1] + for index, (condition, value) in enumerate(zip(booleans, replacements)): + try: + default = default.mask(condition, value) + # error `feedoff` idea from SO + # https://stackoverflow.com/a/46091127/7175713 + except Exception as error: + raise ValueError( + f"condition{index} and value{index} failed to evaluate. " + f"Original error message: {error}" + ) from error + + return df.assign(**{column_name: default}) diff --git a/janitor/functions/utils.py b/janitor/functions/utils.py index b4e051496..597ef3885 100644 --- a/janitor/functions/utils.py +++ b/janitor/functions/utils.py @@ -23,6 +23,8 @@ from janitor.utils import check_column import functools +warnings.simplefilter("always", DeprecationWarning) + def unionize_dataframe_categories( *dataframes, column_names: Optional[Iterable[pd.CategoricalDtype]] = None diff --git a/tests/functions/test_case_when.py b/tests/functions/test_case_when.py index 7d97fd296..846072be9 100644 --- a/tests/functions/test_case_when.py +++ b/tests/functions/test_case_when.py @@ -1,7 +1,7 @@ import numpy as np import pandas as pd import pytest -from hypothesis import assume, given, settings +from hypothesis import given from pandas.testing import assert_frame_equal from janitor.testing_utils.strategies import ( @@ -32,7 +32,7 @@ def test_case_when_1(): df.a, (df.b == 0) & (df.a == 0), "x", - df.c, + default=df.c, column_name="value", ) @@ -40,19 +40,40 @@ def test_case_when_1(): def test_len_args(dataframe): - """Raise ValueError if `args` length is less than 3.""" - with pytest.raises(ValueError, match="three arguments are required"): - dataframe.case_when(dataframe.a < 10, "less_than_10", column_name="a") + """Raise ValueError if `args` length is less than 2.""" + with pytest.raises( + ValueError, + match="At least two arguments are required for the `args` parameter", + ): + dataframe.case_when( + dataframe.a < 10, default="less_than_10", column_name="a" + ) def test_args_even(dataframe): - """Raise ValueError if `args` length is even.""" - with pytest.raises(ValueError, match="`default` argument is missing"): + """Raise ValueError if `args` length is odd.""" + with pytest.raises( + ValueError, match="The number of conditions and values do not match.+" + ): + dataframe.case_when( + dataframe.a < 10, + "less_than_10", + dataframe.a == 5, + default="five", + column_name="a", + ) + + +def test_args_even_warning(dataframe): + """ + Raise Warning if `args` length + is odd and `default` is None. + """ + with pytest.warns(DeprecationWarning): dataframe.case_when( dataframe.a < 10, "less_than_10", dataframe.a == 5, - "five", column_name="a", ) @@ -63,51 +84,54 @@ def test_column_name(dataframe): dataframe.case_when( dataframe.a < 10, "less_than_10", - dataframe.a, + default=dataframe.a, column_name=("a",), ) -@given(df=df_strategy()) -def test_default_ndim(df): +def test_default_ndim(): """Raise ValueError if `default` ndim > 1.""" - with pytest.raises(ValueError): - df.case_when(df.a < 10, "less_than_10", df, column_name="a") + df = pd.DataFrame({"a": range(20)}) + with pytest.raises( + ValueError, + match="The argument for the `default` parameter " + "should either be a 1-D array.+", + ): + df.case_when( + df.a < 10, "less_than_10", default=df.to_numpy(), column_name="a" + ) -@pytest.mark.turtle -@given(df=df_strategy()) -@settings(deadline=None) -def test_default_length(df): +@pytest.mark.xfail(reason="Error handled by pd.Series.mask") +def test_default_length(): """Raise ValueError if `default` length != len(df).""" - assume(len(df) > 10) + df = pd.DataFrame({"a": range(20)}) with pytest.raises( ValueError, - match=( - "length of the `default` argument should be equal to the length of" - " the DataFrame" - ), + match=("The length of the argument for the `default` parameter is.+"), ): df.case_when( df.a < 10, "less_than_10", - df.loc[:5, "a"], + default=df.loc[:5, "a"], column_name="a", ) -@given(df=df_strategy()) -def test_error_multiple_conditions(df): +def test_error_multiple_conditions(): """Raise ValueError for multiple conditions.""" + df = pd.DataFrame({"a": range(20)}) with pytest.raises(ValueError): - df.case_when(df.a < 10, "baby", df.a + 5, "kid", df.a, column_name="a") + df.case_when( + df.a < 10, "baby", df.a + 5, "kid", default=df.a, column_name="a" + ) @given(df=df_strategy()) def test_case_when_condition_callable(df): """Test case_when for callable.""" result = df.case_when( - lambda df: df.a < 10, "baby", "bleh", column_name="bleh" + lambda df: df.a < 10, "baby", default="bleh", column_name="bleh" ) expected = np.where(df.a < 10, "baby", "bleh") expected = df.assign(bleh=expected) @@ -117,7 +141,7 @@ def test_case_when_condition_callable(df): @given(df=df_strategy()) def test_case_when_condition_eval(df): """Test case_when for callable.""" - result = df.case_when("a < 10", "baby", "bleh", column_name="bleh") + result = df.case_when("a < 10", "baby", default="bleh", column_name="bleh") expected = np.where(df.a < 10, "baby", "bleh") expected = df.assign(bleh=expected) assert_frame_equal(result, expected) @@ -127,7 +151,10 @@ def test_case_when_condition_eval(df): def test_case_when_replacement_callable(df): """Test case_when for callable.""" result = df.case_when( - "a > 10", lambda df: df.a + 10, lambda df: df.a * 2, column_name="bleh" + "a > 10", + lambda df: df.a + 10, + default=lambda df: df.a * 2, + column_name="bleh", ) expected = np.where(df.a > 10, df.a + 10, df.a * 2) expected = df.assign(bleh=expected) @@ -135,14 +162,34 @@ def test_case_when_replacement_callable(df): @given(df=categoricaldf_strategy()) -def test_case_when_default_list(df): +def test_case_when_default_array(df): + """ + Test case_when for scenarios where `default` is array-like + """ + default = np.arange(len(df)) + result = df.case_when( + "numbers > 1", + lambda df: df.numbers + 10, + default=default, + column_name="bleh", + ) + expected = np.where(df.numbers > 1, df.numbers + 10, default) + expected = df.assign(bleh=expected) + assert_frame_equal(result, expected) + + +@given(df=categoricaldf_strategy()) +def test_case_when_default_list_like(df): """ Test case_when for scenarios where `default` is list-like, - but not a Pandas or numpy object. + but has no shape attribute. """ default = range(len(df)) result = df.case_when( - "numbers > 1", lambda df: df.numbers + 10, default, column_name="bleh" + "numbers > 1", + lambda df: df.numbers + 10, + default=default, + column_name="bleh", ) expected = np.where(df.numbers > 1, df.numbers + 10, default) expected = df.assign(bleh=expected) @@ -151,12 +198,14 @@ def test_case_when_default_list(df): @given(df=categoricaldf_strategy()) def test_case_when_default_index(df): - """Test case_when for scenarios where `default` is an index.""" + """ + Test case_when for scenarios where `default` is an index. + """ default = range(len(df)) result = df.case_when( "numbers > 1", lambda df: df.numbers + 10, - pd.Index(default), + default=pd.Index(default), column_name="bleh", ) expected = np.where(df.numbers > 1, df.numbers + 10, default) @@ -176,7 +225,7 @@ def test_case_when_multiple_args(df): "young", "30 <= a < 50", "mature", - "grandpa", + default="grandpa", column_name="elderly", ) conditions = [