diff --git a/superset/utils/pandas_postprocessing.py b/superset/utils/pandas_postprocessing.py index f352e5fad45de..e22a9c744e785 100644 --- a/superset/utils/pandas_postprocessing.py +++ b/superset/utils/pandas_postprocessing.py @@ -25,6 +25,7 @@ from geopy.point import Point from pandas import DataFrame, NamedAgg, Series, Timestamp +from superset.constants import NULL_STRING from superset.exceptions import QueryObjectValidationError from superset.utils.core import ( DTTM_ALIAS, @@ -214,7 +215,7 @@ def pivot( # pylint: disable=too-many-arguments aggregates: Dict[str, Dict[str, Any]], columns: Optional[List[str]] = None, metric_fill_value: Optional[Any] = None, - column_fill_value: Optional[str] = None, + column_fill_value: Optional[str] = NULL_STRING, drop_missing_columns: Optional[bool] = True, combine_value_with_metric: bool = False, marginal_distributions: Optional[bool] = None, @@ -228,7 +229,9 @@ def pivot( # pylint: disable=too-many-arguments :param index: Columns to group by on the table index (=rows) :param columns: Columns to group by on the table columns :param metric_fill_value: Value to replace missing values with - :param column_fill_value: Value to replace missing pivot columns with + :param column_fill_value: Value to replace missing pivot columns with. By default + replaces missing values with "". Set to `None` to remove columns + with missing values. :param drop_missing_columns: Do not include columns whose entries are all missing :param combine_value_with_metric: Display metrics side by side within each column, as opposed to each column being displayed side by side for each metric. @@ -250,7 +253,7 @@ def pivot( # pylint: disable=too-many-arguments _("Pivot operation must include at least one aggregate") ) - if column_fill_value: + if columns and column_fill_value: df[columns] = df[columns].fillna(value=column_fill_value) aggregate_funcs = _get_aggregate_funcs(df, aggregates) diff --git a/tests/pandas_postprocessing_tests.py b/tests/pandas_postprocessing_tests.py index 73d3d7c94b95f..3f54f7e79db73 100644 --- a/tests/pandas_postprocessing_tests.py +++ b/tests/pandas_postprocessing_tests.py @@ -197,6 +197,21 @@ def test_pivot_fill_values(self): ) self.assertEqual(df.sum()[1], 382) + def test_pivot_fill_column_values(self): + """ + Make sure pivot witn null column names returns correct DataFrame + """ + df_copy = categories_df.copy() + df_copy["category"] = None + df = proc.pivot( + df=df_copy, + index=["name"], + columns=["category"], + aggregates={"idx_nulls": {"operator": "sum"}}, + ) + assert len(df) == 101 + assert df.columns.tolist() == ["name", ""] + def test_pivot_exceptions(self): """ Make sure pivot raises correct Exceptions