From fbc70cbe0d35ff0916bb26fae98c488a7c256d2d Mon Sep 17 00:00:00 2001 From: michellethomas Date: Tue, 22 May 2018 09:58:38 -0700 Subject: [PATCH] Allow MetricsControl to aggregate on a column with an expression (#5021) * Allow MetricsControl to aggregate on a column with an expression * Adding test case for metrics based on columns --- superset/connectors/sqla/models.py | 20 +++++++++++++++++--- superset/data/__init__.py | 23 +++++++++++++++++++++++ 2 files changed, 40 insertions(+), 3 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index cef3bd61a5d1e..875707f55ccbe 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -450,10 +450,24 @@ def get_from_clause(self, template_processor=None, db_engine_spec=None): return TextAsFrom(sa.text(from_sql), []).alias('expr_qry') return self.get_sqla_table() - def adhoc_metric_to_sa(self, metric): + def adhoc_metric_to_sa(self, metric, cols): + """ + Turn an adhoc metric into a sqlalchemy column. + + :param dict metric: Adhoc metric definition + :param dict cols: Columns for the current table + :returns: The metric defined as a sqlalchemy column + :rtype: sqlalchemy.sql.column + """ expressionType = metric.get('expressionType') if expressionType == utils.ADHOC_METRIC_EXPRESSION_TYPES['SIMPLE']: - sa_column = column(metric.get('column').get('column_name')) + column_name = metric.get('column').get('column_name') + sa_column = column(column_name) + table_column = cols.get(column_name) + + if table_column: + sa_column = table_column.sqla_col + sa_metric = self.sqla_aggregations[metric.get('aggregate')](sa_column) sa_metric = sa_metric.label(metric.get('label')) return sa_metric @@ -518,7 +532,7 @@ def get_sqla_query( # sqla metrics_exprs = [] for m in metrics: if utils.is_adhoc_metric(m): - metrics_exprs.append(self.adhoc_metric_to_sa(m)) + metrics_exprs.append(self.adhoc_metric_to_sa(m, cols)) elif m in metrics_dict: metrics_exprs.append(metrics_dict.get(m).sqla_col) else: diff --git a/superset/data/__init__.py b/superset/data/__init__.py index 8451b95631c15..4f79be842a90b 100644 --- a/superset/data/__init__.py +++ b/superset/data/__init__.py @@ -19,6 +19,7 @@ from superset import app, db, security_manager, utils from superset.connectors.connector_registry import ConnectorRegistry +from superset.connectors.sqla.models import TableColumn from superset.models import core as models # Shortcuts @@ -585,6 +586,10 @@ def load_birth_names(): obj.main_dttm_col = 'ds' obj.database = utils.get_or_create_main_db() obj.filter_select_enabled = True + obj.columns.append(TableColumn( + column_name='num_california', + expression="CASE WHEN state = 'CA' THEN num ELSE 0 END" + )) db.session.merge(obj) db.session.commit() obj.fetch_metadata() @@ -737,6 +742,24 @@ def load_birth_names(): 'val': ['girl'], }], subheader='total female participants')), + Slice( + slice_name="Number of California Births", + viz_type='big_number_total', + datasource_type='table', + datasource_id=tbl.id, + params=get_slice_json( + defaults, + metric={ + "expressionType": "SIMPLE", + "column": { + "column_name": "num_california", + "expression": "CASE WHEN state = 'CA' THEN num ELSE 0 END", + }, + "aggregate": "SUM", + "label": "SUM(num_california)", + }, + viz_type="big_number_total", + granularity_sqla="ds")), ] for slc in slices: merge_slice(slc)