From 14022567dd64b58120ac55ac4a62a9531e60e06c Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Thu, 3 Sep 2020 13:09:30 +0300 Subject: [PATCH 1/2] fix: pivot table timestamp grouping --- superset/viz.py | 35 +++++++++++++++++++++++++++-------- tests/viz_tests.py | 34 +++++++++++++++++++++++++++++++++- 2 files changed, 60 insertions(+), 9 deletions(-) diff --git a/superset/viz.py b/superset/viz.py index e6d9856f6ce87..17b776bdf1c22 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -27,7 +27,7 @@ import math import re from collections import defaultdict, OrderedDict -from datetime import datetime, timedelta +from datetime import date, datetime, timedelta from itertools import product from typing import ( Any, @@ -856,6 +856,30 @@ def get_aggfunc( # only min and max work properly for non-numerics return aggfunc if aggfunc in ("min", "max") else "max" + @staticmethod + def _format_datetime(value: Any) -> Any: + """ + Format a timestamp in such a way that the viz will be able to apply + the correct formatting in the frontend. + + :param value: the value of a temporal column + :return: formatted timestamp if it is a valid timestamp, otherwise + the original value + """ + tstamp: Optional[pd.Timestamp] = None + if isinstance(value, pd.Timestamp): + tstamp = value + if isinstance(value, datetime) or isinstance(value, date): + tstamp = pd.Timestamp(value) + if isinstance(value, str): + try: + tstamp = pd.Timestamp(value) + except ValueError: + pass + if tstamp: + return f"__timestamp:{datetime_to_epoch(tstamp)}" + return value + def get_data(self, df: pd.DataFrame) -> VizData: if df.empty: return None @@ -871,15 +895,10 @@ def get_data(self, df: pd.DataFrame) -> VizData: groupby = self.form_data.get("groupby") or [] columns = self.form_data.get("columns") or [] - def _format_datetime(value: Any) -> Optional[str]: - if isinstance(value, str): - return f"__timestamp:{datetime_to_epoch(pd.Timestamp(value))}" - return None - for column_name in groupby + columns: column = self.datasource.get_column(column_name) - if column and column.type in ("DATE", "DATETIME", "TIMESTAMP"): - ts = df[column_name].apply(_format_datetime) + if column and column.is_temporal: + ts = df[column_name].apply(self._format_datetime) df[column_name] = ts if self.form_data.get("transpose_pivot"): diff --git a/tests/viz_tests.py b/tests/viz_tests.py index 637cd33cc1d1c..6b399e3f84edf 100644 --- a/tests/viz_tests.py +++ b/tests/viz_tests.py @@ -16,7 +16,7 @@ # under the License. # isort:skip_file import uuid -from datetime import datetime +from datetime import date, datetime, timezone import logging from math import nan from unittest.mock import Mock, patch @@ -1353,6 +1353,38 @@ def test_get_aggfunc_non_numeric(self): == "min" ) + def test_format_datetime_from_pd_timestamp(self): + tstamp = pd.Timestamp(datetime(2020, 9, 3, tzinfo=timezone.utc)) + assert ( + viz.PivotTableViz._format_datetime(tstamp) == "__timestamp:1599091200000.0" + ) + + def test_format_datetime_from_datetime(self): + tstamp = datetime(2020, 9, 3, tzinfo=timezone.utc) + assert ( + viz.PivotTableViz._format_datetime(tstamp) == "__timestamp:1599091200000.0" + ) + + def test_format_datetime_from_date(self): + tstamp = date(2020, 9, 3) + assert ( + viz.PivotTableViz._format_datetime(tstamp) == "__timestamp:1599091200000.0" + ) + + def test_format_datetime_from_string(self): + tstamp = "2020-09-03T00:00:00" + assert ( + viz.PivotTableViz._format_datetime(tstamp) == "__timestamp:1599091200000.0" + ) + + def test_format_datetime_from_invalid_string(self): + tstamp = "abracadabra" + assert viz.PivotTableViz._format_datetime(tstamp) == tstamp + + def test_format_datetime_from_int(self): + assert viz.PivotTableViz._format_datetime(123) == 123 + assert viz.PivotTableViz._format_datetime(123.0) == 123.0 + class TestDistributionPieViz(SupersetTestCase): base_df = pd.DataFrame( From 38f620cc9babbfac373daa448e196153d927edc7 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt Date: Thu, 3 Sep 2020 16:52:13 +0300 Subject: [PATCH 2/2] address comments --- superset/viz.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/superset/viz.py b/superset/viz.py index 17b776bdf1c22..57531ea4cb32f 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -647,11 +647,11 @@ class TableViz(BaseViz): def process_metrics(self) -> None: """Process form data and store parsed column configs. - 1. Determine query mode based on form_data params. - - Use `query_mode` if it has a valid value - - Set as RAW mode if `all_columns` is set - - Otherwise defaults to AGG mode - 2. Determine output columns based on query mode. + 1. Determine query mode based on form_data params. + - Use `query_mode` if it has a valid value + - Set as RAW mode if `all_columns` is set + - Otherwise defaults to AGG mode + 2. Determine output columns based on query mode. """ # Verify form data first: if not specifying query mode, then cannot have both # GROUP BY and RAW COLUMNS. @@ -857,7 +857,7 @@ def get_aggfunc( return aggfunc if aggfunc in ("min", "max") else "max" @staticmethod - def _format_datetime(value: Any) -> Any: + def _format_datetime(value: Union[pd.Timestamp, datetime, date, str]) -> str: """ Format a timestamp in such a way that the viz will be able to apply the correct formatting in the frontend. @@ -878,7 +878,8 @@ def _format_datetime(value: Any) -> Any: pass if tstamp: return f"__timestamp:{datetime_to_epoch(tstamp)}" - return value + # fallback in case something incompatible is returned + return cast(str, value) def get_data(self, df: pd.DataFrame) -> VizData: if df.empty: