fix: rolling and cum operator on multiple series #16945

Merged · 3 commits · merged on Oct 7, 2021
60 changes: 51 additions & 9 deletions superset/utils/pandas_postprocessing.py
@@ -131,6 +131,9 @@ def _flatten_column_after_pivot(
 def validate_column_args(*argnames: str) -> Callable[..., Any]:
     def wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
         def wrapped(df: DataFrame, **options: Any) -> Any:
+            if options.get("is_pivot_df"):
+                # skip validation for a pivoted DataFrame
+                return func(df, **options)
             columns = df.columns.tolist()
             for name in argnames:
                 if name in options and not all(
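For context, a minimal pandas sketch (not part of the PR) of why this guard is needed: after a pivot with flatten_columns=False, the frame's columns are a MultiIndex of (metric, series) tuples, so checking flat argument names such as "sum_metric" against df.columns would always fail.

import pandas as pd

df = pd.DataFrame(
    {
        "dttm": pd.to_datetime(["2019-01-01", "2019-01-01", "2019-01-02", "2019-01-02"]),
        "country": ["UK", "US", "UK", "US"],
        "sum_metric": [5, 6, 7, 8],
    }
)
pivoted = df.pivot_table(
    index="dttm", columns="country", values=["sum_metric"], aggfunc="sum"
)
print(pivoted.columns.tolist())  # [('sum_metric', 'UK'), ('sum_metric', 'US')]
print("sum_metric" in pivoted.columns.tolist())  # False -> flat-name validation must be skipped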
@@ -223,6 +226,7 @@ def pivot(  # pylint: disable=too-many-arguments,too-many-locals
     marginal_distributions: Optional[bool] = None,
     marginal_distribution_name: Optional[str] = None,
     flatten_columns: bool = True,
+    reset_index: bool = True,
 ) -> DataFrame:
     """
     Perform a pivot operation on a DataFrame.
@@ -243,6 +247,7 @@ def pivot(  # pylint: disable=too-many-arguments,too-many-locals
     :param marginal_distribution_name: Name of row/column with marginal distribution.
            Default to 'All'.
     :param flatten_columns: Convert column names to strings
+    :param reset_index: Convert index to column
     :return: A pivot table
     :raises QueryObjectValidationError: If the request is incorrect
     """
@@ -300,7 +305,8 @@ def pivot(  # pylint: disable=too-many-arguments,too-many-locals
             _flatten_column_after_pivot(col, aggregates) for col in df.columns
         ]
     # return index as regular column
-    df.reset_index(level=0, inplace=True)
+    if reset_index:
+        df.reset_index(level=0, inplace=True)
     return df
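A hedged usage sketch of the two new knobs together, mirroring the tests added below (proc stands for superset.utils.pandas_postprocessing; single_metric_df is the fixture defined later in this PR, assuming the fixture module is importable as shown): with flatten_columns=False and reset_index=False the pivot keeps its DatetimeIndex and MultiIndex columns, which is exactly the shape the rolling and cum operators expect when is_pivot_df=True.

from superset.utils import pandas_postprocessing as proc
from tests.integration_tests.fixtures.dataframes import single_metric_df

pivoted = proc.pivot(
    df=single_metric_df,
    index=["dttm"],
    columns=["country"],
    aggregates={"sum_metric": {"operator": "sum"}},
    flatten_columns=False,  # keep ('sum_metric', 'UK')-style tuple columns
    reset_index=False,      # new in this PR: keep dttm as the index
)
# pivoted.columns -> [('sum_metric', 'UK'), ('sum_metric', 'US')]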


@@ -343,13 +349,14 @@ def sort(df: DataFrame, columns: Dict[str, bool]) -> DataFrame:
 @validate_column_args("columns")
 def rolling(  # pylint: disable=too-many-arguments
     df: DataFrame,
-    columns: Dict[str, str],
     rolling_type: str,
+    columns: Optional[Dict[str, str]] = None,
     window: Optional[int] = None,
     rolling_type_options: Optional[Dict[str, Any]] = None,
     center: bool = False,
     win_type: Optional[str] = None,
     min_periods: Optional[int] = None,
+    is_pivot_df: bool = False,
 ) -> DataFrame:
     """
     Apply a rolling window on the dataset. See the Pandas docs for further details:
@@ -369,11 +376,16 @@ def rolling(  # pylint: disable=too-many-arguments
     :param win_type: Type of window function.
     :param min_periods: The minimum amount of periods required for a row to be included
            in the result set.
+    :param is_pivot_df: Whether the DataFrame is pivoted
     :return: DataFrame with the rolling columns
     :raises QueryObjectValidationError: If the request is incorrect
     """
     rolling_type_options = rolling_type_options or {}
-    df_rolling = df[columns.keys()]
+    columns = columns or {}
+    if is_pivot_df:
+        df_rolling = df
+    else:
+        df_rolling = df[columns.keys()]
     kwargs: Dict[str, Union[str, int]] = {}
     if window is None:
         raise QueryObjectValidationError(_("Undefined window for rolling operation"))
@@ -405,10 +417,20 @@ def rolling(  # pylint: disable=too-many-arguments
                 options=rolling_type_options,
             )
         ) from ex
-    df = _append_columns(df, df_rolling, columns)

+    if is_pivot_df:
+        agg_in_pivot_df = df.columns.get_level_values(0).drop_duplicates().to_list()
+        agg: Dict[str, Dict[str, Any]] = {col: {} for col in agg_in_pivot_df}
+        df_rolling.columns = [
+            _flatten_column_after_pivot(col, agg) for col in df_rolling.columns
+        ]
Comment on lines +422 to +426 — @zhaoyongjie (Member, Author), Oct 6, 2021:
I will create a separate PR for refactoring _flatten_column_after_pivot
+        df_rolling.reset_index(level=0, inplace=True)
+    else:
+        df_rolling = _append_columns(df, df_rolling, columns)

     if min_periods:
-        df = df[min_periods:]
-    return df
+        df_rolling = df_rolling[min_periods:]
+    return df_rolling
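The column names the pivot branch produces come from _flatten_column_after_pivot (slated for refactoring per the comment above). A hedged, illustrative re-implementation — not the actual Superset helper — of the behavior the new tests rely on: with a single aggregate the metric prefix is dropped, with several it is kept.

from typing import Any, Dict, Tuple, Union

def flatten_after_pivot(
    column: Union[str, Tuple[str, ...]], aggregates: Dict[str, Dict[str, Any]]
) -> str:
    # illustrative only; mirrors the observed output of _flatten_column_after_pivot
    if not isinstance(column, tuple):
        column = (column,)
    if len(aggregates) == 1 and len(column) > 1:
        column = column[1:]  # single metric: drop the redundant metric name
    return ", ".join(str(col) for col in column)

flatten_after_pivot(("sum_metric", "UK"), {"sum_metric": {}})
# -> 'UK'
flatten_after_pivot(("sum_metric", "UK"), {"sum_metric": {}, "count_metric": {}})
# -> 'sum_metric, UK'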


@validate_column_args("columns", "drop", "rename")
Expand Down Expand Up @@ -524,7 +546,12 @@ def compare( # pylint: disable=too-many-arguments


@validate_column_args("columns")
def cum(df: DataFrame, columns: Dict[str, str], operator: str) -> DataFrame:
def cum(
df: DataFrame,
operator: str,
columns: Optional[Dict[str, str]] = None,
is_pivot_df: bool = False,
) -> DataFrame:
"""
Calculate cumulative sum/product/min/max for select columns.

@@ -535,17 +562,32 @@ def cum(df: DataFrame, columns: Dict[str, str], operator: str) -> DataFrame:
            `y2` based on cumulative values calculated from `y`, leaving the original
            column `y` unchanged.
     :param operator: cumulative operator, e.g. `sum`, `prod`, `min`, `max`
+    :param is_pivot_df: Whether the DataFrame is pivoted
     :return: DataFrame with cumulated columns
     """
-    df_cum = df[columns.keys()]
+    columns = columns or {}
+    if is_pivot_df:
+        df_cum = df
+    else:
+        df_cum = df[columns.keys()]
operation = "cum" + operator
if operation not in ALLOWLIST_CUMULATIVE_FUNCTIONS or not hasattr(
df_cum, operation
):
raise QueryObjectValidationError(
_("Invalid cumulative operator: %(operator)s", operator=operator)
)
-    return _append_columns(df, getattr(df_cum, operation)(), columns)
+    if is_pivot_df:
+        df_cum = getattr(df_cum, operation)()
+        agg_in_pivot_df = df.columns.get_level_values(0).drop_duplicates().to_list()
+        agg: Dict[str, Dict[str, Any]] = {col: {} for col in agg_in_pivot_df}
+        df_cum.columns = [
+            _flatten_column_after_pivot(col, agg) for col in df_cum.columns
+        ]
Comment on lines +581 to +586 — @zhaoyongjie (Member, Author):
I will create a separate PR for refactoring _flatten_column_after_pivot
+        df_cum.reset_index(level=0, inplace=True)
+    else:
+        df_cum = _append_columns(df, getattr(df_cum, operation)(), columns)
+    return df_cum
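End to end, the two changes compose like this — a hedged sketch mirroring test_cum_with_pivot_df_and_single_metric below (proc is superset.utils.pandas_postprocessing; the fixture import path is an assumption):

from superset.utils import pandas_postprocessing as proc
from tests.integration_tests.fixtures.dataframes import single_metric_df

pivoted = proc.pivot(
    df=single_metric_df,
    index=["dttm"],
    columns=["country"],
    aggregates={"sum_metric": {"operator": "sum"}},
    flatten_columns=False,
    reset_index=False,
)
cum_df = proc.cum(df=pivoted, operator="sum", is_pivot_df=True)
# cum_df["UK"] -> [5, 12]; cum_df["US"] -> [6, 14]; dttm is a regular column again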


 def geohash_decode(
16 changes: 16 additions & 0 deletions tests/integration_tests/fixtures/dataframes.py
@@ -165,3 +165,19 @@
"b": [4, 3, 4.1, 3.95],
}
)

single_metric_df = DataFrame(
{
"dttm": to_datetime(["2019-01-01", "2019-01-01", "2019-01-02", "2019-01-02",]),
"country": ["UK", "US", "UK", "US"],
"sum_metric": [5, 6, 7, 8],
}
)
multiple_metrics_df = DataFrame(
{
"dttm": to_datetime(["2019-01-01", "2019-01-01", "2019-01-02", "2019-01-02",]),
"country": ["UK", "US", "UK", "US"],
"sum_metric": [5, 6, 7, 8],
"count_metric": [1, 2, 3, 4],
}
)
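As a sanity check on the numbers these fixtures produce (not part of the test suite; plain pandas on multiple_metrics_df as defined above), a cumulative sum over the pivoted frame yields exactly the values asserted in the tests below:

from tests.integration_tests.fixtures.dataframes import multiple_metrics_df

pivoted = multiple_metrics_df.pivot_table(
    index="dttm",
    columns="country",
    values=["sum_metric", "count_metric"],
    aggfunc="sum",
)
print(pivoted.cumsum())
#            count_metric     sum_metric
# country              UK  US         UK  US
# dttm
# 2019-01-01            1   2          5   6
# 2019-01-02            4   6         12  14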
118 changes: 118 additions & 0 deletions tests/integration_tests/pandas_postprocessing_tests.py
@@ -35,6 +35,8 @@
 from .base_tests import SupersetTestCase
 from .fixtures.dataframes import (
     categories_df,
+    single_metric_df,
+    multiple_metrics_df,
     lonlat_df,
     names_df,
     timeseries_df,
@@ -305,6 +307,23 @@ def test_pivot_eliminate_cartesian_product_columns(self):
         )
         self.assertTrue(np.isnan(df["metric, 1, 1"][0]))

+    def test_pivot_without_flatten_columns_and_reset_index(self):
+        df = proc.pivot(
+            df=single_metric_df,
+            index=["dttm"],
+            columns=["country"],
+            aggregates={"sum_metric": {"operator": "sum"}},
+            flatten_columns=False,
+            reset_index=False,
+        )
+        # metric
+        # country      UK  US
+        # dttm
+        # 2019-01-01    5   6
+        # 2019-01-02    7   8
+        assert df.columns.to_list() == [("sum_metric", "UK"), ("sum_metric", "US")]
+        assert df.index.to_list() == to_datetime(["2019-01-01", "2019-01-02"]).to_list()
+
     def test_aggregate(self):
         aggregates = {
             "asc sum": {"column": "asc_idx", "operator": "sum"},
@@ -405,6 +424,60 @@ def test_rolling(self):
             window=2,
         )

+    def test_rolling_with_pivot_df_and_single_metric(self):
+        pivot_df = proc.pivot(
+            df=single_metric_df,
+            index=["dttm"],
+            columns=["country"],
+            aggregates={"sum_metric": {"operator": "sum"}},
+            flatten_columns=False,
+            reset_index=False,
+        )
+        rolling_df = proc.rolling(
+            df=pivot_df, rolling_type="sum", window=2, min_periods=0, is_pivot_df=True,
+        )
+        #         dttm  UK  US
+        # 0 2019-01-01   5   6
+        # 1 2019-01-02  12  14
+        assert rolling_df["UK"].to_list() == [5.0, 12.0]
+        assert rolling_df["US"].to_list() == [6.0, 14.0]
+        assert (
+            rolling_df["dttm"].to_list()
+            == to_datetime(["2019-01-01", "2019-01-02",]).to_list()
+        )
+
+        rolling_df = proc.rolling(
+            df=pivot_df, rolling_type="sum", window=2, min_periods=2, is_pivot_df=True,
+        )
+        assert rolling_df.empty is True
+
+    def test_rolling_with_pivot_df_and_multiple_metrics(self):
+        pivot_df = proc.pivot(
+            df=multiple_metrics_df,
+            index=["dttm"],
+            columns=["country"],
+            aggregates={
+                "sum_metric": {"operator": "sum"},
+                "count_metric": {"operator": "sum"},
+            },
+            flatten_columns=False,
+            reset_index=False,
+        )
+        rolling_df = proc.rolling(
+            df=pivot_df, rolling_type="sum", window=2, min_periods=0, is_pivot_df=True,
+        )
+        #         dttm  count_metric, UK  count_metric, US  sum_metric, UK  sum_metric, US
+        # 0 2019-01-01               1.0               2.0             5.0             6.0
+        # 1 2019-01-02               4.0               6.0            12.0            14.0
+        assert rolling_df["count_metric, UK"].to_list() == [1.0, 4.0]
+        assert rolling_df["count_metric, US"].to_list() == [2.0, 6.0]
+        assert rolling_df["sum_metric, UK"].to_list() == [5.0, 12.0]
+        assert rolling_df["sum_metric, US"].to_list() == [6.0, 14.0]
+        assert (
+            rolling_df["dttm"].to_list()
+            == to_datetime(["2019-01-01", "2019-01-02",]).to_list()
+        )
+
     def test_select(self):
         # reorder columns
         post_df = proc.select(df=timeseries_df, columns=["y", "label"])
@@ -557,6 +630,51 @@ def test_cum(self):
             operator="abc",
         )

+    def test_cum_with_pivot_df_and_single_metric(self):
+        pivot_df = proc.pivot(
+            df=single_metric_df,
+            index=["dttm"],
+            columns=["country"],
+            aggregates={"sum_metric": {"operator": "sum"}},
+            flatten_columns=False,
+            reset_index=False,
+        )
+        cum_df = proc.cum(df=pivot_df, operator="sum", is_pivot_df=True,)
+        #         dttm  UK  US
+        # 0 2019-01-01   5   6
+        # 1 2019-01-02  12  14
+        assert cum_df["UK"].to_list() == [5.0, 12.0]
+        assert cum_df["US"].to_list() == [6.0, 14.0]
+        assert (
+            cum_df["dttm"].to_list()
+            == to_datetime(["2019-01-01", "2019-01-02",]).to_list()
+        )
+
+    def test_cum_with_pivot_df_and_multiple_metrics(self):
+        pivot_df = proc.pivot(
+            df=multiple_metrics_df,
+            index=["dttm"],
+            columns=["country"],
+            aggregates={
+                "sum_metric": {"operator": "sum"},
+                "count_metric": {"operator": "sum"},
+            },
+            flatten_columns=False,
+            reset_index=False,
+        )
+        cum_df = proc.cum(df=pivot_df, operator="sum", is_pivot_df=True,)
+        #         dttm  count_metric, UK  count_metric, US  sum_metric, UK  sum_metric, US
+        # 0 2019-01-01                 1                 2               5               6
+        # 1 2019-01-02                 4                 6              12              14
+        assert cum_df["count_metric, UK"].to_list() == [1.0, 4.0]
+        assert cum_df["count_metric, US"].to_list() == [2.0, 6.0]
+        assert cum_df["sum_metric, UK"].to_list() == [5.0, 12.0]
+        assert cum_df["sum_metric, US"].to_list() == [6.0, 14.0]
+        assert (
+            cum_df["dttm"].to_list()
+            == to_datetime(["2019-01-01", "2019-01-02",]).to_list()
+        )
+
     def test_geohash_decode(self):
         # decode lon/lat from geohash
         post_df = proc.geohash_decode(