);
}
@@ -70,8 +76,10 @@ Filters.defaultProps = defaultProps;
function mapStateToProps(state) {
return {
+ datasource_type: state.datasource_type,
filterColumnOpts: state.filterColumnOpts,
filters: state.viz.form_data.filters,
+ renderFilterSelect: state.filter_select,
};
}
diff --git a/superset/assets/javascripts/explorev2/components/SelectField.jsx b/superset/assets/javascripts/explorev2/components/SelectField.jsx
index 999d397e74057..b3514449ee563 100644
--- a/superset/assets/javascripts/explorev2/components/SelectField.jsx
+++ b/superset/assets/javascripts/explorev2/components/SelectField.jsx
@@ -58,8 +58,18 @@ export default class SelectField extends React.Component {
if (this.props.freeForm) {
// For FreeFormSelect, insert value into options if not exist
const values = choices.map((c) => c[0]);
- if (values.indexOf(this.props.value) === -1) {
- options.push({ value: this.props.value, label: this.props.value });
+ if (this.props.value) {
+ if (typeof this.props.value === 'object') {
+ this.props.value.forEach((v) => {
+ if (values.indexOf(v) === -1) {
+ options.push({ value: v, label: v });
+ }
+ });
+ } else {
+ if (values.indexOf(this.props.value) === -1) {
+ options.push({ value: this.props.value, label: this.props.value });
+ }
+ }
}
}
@@ -77,13 +87,19 @@ export default class SelectField extends React.Component {
// Tab, comma or Enter will trigger a new option created for FreeFormSelect
const selectWrap = this.props.freeForm ?
(
-
+
{selectWrap}
);
diff --git a/superset/assets/javascripts/explorev2/index.jsx b/superset/assets/javascripts/explorev2/index.jsx
index 9cf8e34cd3634..ec45b654bd118 100644
--- a/superset/assets/javascripts/explorev2/index.jsx
+++ b/superset/assets/javascripts/explorev2/index.jsx
@@ -25,6 +25,7 @@ const bootstrappedState = Object.assign(
initialState(bootstrapData.viz.form_data.viz_type, bootstrapData.datasource_type), {
can_edit: bootstrapData.can_edit,
can_download: bootstrapData.can_download,
+ filter_select: bootstrapData.filter_select,
datasources: bootstrapData.datasources,
datasource_type: bootstrapData.datasource_type,
viz: bootstrapData.viz,
diff --git a/superset/assets/javascripts/explorev2/stores/store.js b/superset/assets/javascripts/explorev2/stores/store.js
index 240ec1df880c2..3f533f9506624 100644
--- a/superset/assets/javascripts/explorev2/stores/store.js
+++ b/superset/assets/javascripts/explorev2/stores/store.js
@@ -43,6 +43,7 @@ export function initialState(vizType = 'table', datasourceType = 'table') {
datasources: null,
datasource_type: null,
filterColumnOpts: [],
+ filter_select: false,
fields,
viz: defaultViz(vizType, datasourceType),
isStarred: false,
diff --git a/superset/assets/spec/javascripts/explorev2/components/Filter_spec.js b/superset/assets/spec/javascripts/explorev2/components/Filter_spec.js
index 9a55b59cdc51d..ef8be8c145d85 100644
--- a/superset/assets/spec/javascripts/explorev2/components/Filter_spec.js
+++ b/superset/assets/spec/javascripts/explorev2/components/Filter_spec.js
@@ -8,7 +8,9 @@ import { shallow } from 'enzyme';
import Filter from '../../../../javascripts/explorev2/components/Filter';
const defaultProps = {
- actions: {},
+ actions: {
+ fetchFilterValues: () => ({}),
+ },
filterColumnOpts: ['country_name'],
filter: {
id: 1,
diff --git a/superset/migrations/versions/f1f2d4af5b90_.py b/superset/migrations/versions/f1f2d4af5b90_.py
new file mode 100644
index 0000000000000..36bae518ce063
--- /dev/null
+++ b/superset/migrations/versions/f1f2d4af5b90_.py
@@ -0,0 +1,25 @@
+"""Enable Filter Select
+
+Revision ID: f1f2d4af5b90
+Revises: e46f2d27a08e
+Create Date: 2016-11-23 10:27:18.517919
+
+"""
+
+# revision identifiers, used by Alembic.
+revision = 'f1f2d4af5b90'
+down_revision = 'e46f2d27a08e'
+
+from alembic import op
+import sqlalchemy as sa
+
+
+def upgrade():
+ op.add_column('datasources', sa.Column('filter_select_enabled',
+ sa.Boolean(), default=False))
+ op.add_column('tables', sa.Column('filter_select_enabled',
+ sa.Boolean(), default=False))
+
+def downgrade():
+ op.drop_column('tables', 'filter_select_enabled')
+ op.drop_column('datasources', 'filter_select_enabled')
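
Worth noting about this migration: default=False on sa.Column is a client-side SQLAlchemy default, so rows that already exist keep NULL in filter_select_enabled after the upgrade; only new inserts get False. If backfilling were desired, a variant along these lines could be used on most backends (illustrative only, not what this revision does):

    # Illustrative alternative with a server-side default so existing rows are
    # backfilled to false; NOT part of the migration above.
    from alembic import op
    import sqlalchemy as sa


    def upgrade():
        op.add_column(
            'tables',
            sa.Column('filter_select_enabled', sa.Boolean(),
                      server_default='0', nullable=False))
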
diff --git a/superset/models.py b/superset/models.py
index 10928ff1e6e49..847700563cb41 100644
--- a/superset/models.py
+++ b/superset/models.py
@@ -32,6 +32,7 @@
from flask_babel import lazy_gettext as _
from pydruid.client import PyDruid
+from pydruid.utils.aggregators import count
from pydruid.utils.filters import Dimension, Filter
from pydruid.utils.postaggregator import Postaggregator
from pydruid.utils.having import Aggregation
@@ -866,6 +867,7 @@ class SqlaTable(Model, Queryable, AuditMixinNullable, ImportMixin):
default_endpoint = Column(Text)
database_id = Column(Integer, ForeignKey('dbs.id'), nullable=False)
is_featured = Column(Boolean, default=False)
+ filter_select_enabled = Column(Boolean, default=False)
user_id = Column(Integer, ForeignKey('ab_user.id'))
owner = relationship('User', backref='tables', foreign_keys=[user_id])
database = relationship(
@@ -977,6 +979,45 @@ def get_col(self, col_name):
if col_name == col.column_name:
return col
+ def values_for_column(self,
+ column_name,
+ from_dttm,
+ to_dttm,
+ limit=500):
+ """Runs query against sqla to retrieve some
+ sample values for the given column.
+ """
+ granularity = self.main_dttm_col
+
+ cols = {col.column_name: col for col in self.columns}
+ target_col = cols[column_name]
+
+ tbl = table(self.table_name)
+ qry = select([target_col.sqla_col])
+ qry = qry.select_from(tbl)
+ qry = qry.distinct(column_name)
+ qry = qry.limit(limit)
+
+ if granularity:
+ dttm_col = cols[granularity]
+ timestamp = dttm_col.sqla_col.label('timestamp')
+ time_filter = [
+ timestamp >= text(dttm_col.dttm_sql_literal(from_dttm)),
+ timestamp <= text(dttm_col.dttm_sql_literal(to_dttm)),
+ ]
+ qry = qry.where(and_(*time_filter))
+
+ engine = self.database.get_sqla_engine()
+ sql = "{}".format(
+ qry.compile(
+ engine, compile_kwargs={"literal_binds": True}, ),
+ )
+
+ return pd.read_sql_query(
+ sql=sql,
+ con=engine
+ )
+
def query( # sqla
self, groupby, metrics,
granularity,
@@ -1594,6 +1635,7 @@ class DruidDatasource(Model, AuditMixinNullable, Queryable):
datasource_name = Column(String(255), unique=True)
is_featured = Column(Boolean, default=False)
is_hidden = Column(Boolean, default=False)
+ filter_select_enabled = Column(Boolean, default=False)
description = Column(Text)
default_endpoint = Column(Text)
user_id = Column(Integer, ForeignKey('ab_user.id'))
@@ -1930,6 +1972,35 @@ def granularity(period_name, timezone=None, origin=None):
period_name).total_seconds() * 1000
return granularity
+ def values_for_column(self,
+ column_name,
+ from_dttm,
+ to_dttm,
+ limit=500):
+ """Retrieve some values for the given column"""
+ # TODO: Use Lexicographic TopNMetricSpec once supported by PyDruid
+ from_dttm = from_dttm.replace(tzinfo=config.get("DRUID_TZ"))
+ to_dttm = to_dttm.replace(tzinfo=config.get("DRUID_TZ"))
+
+ qry = dict(
+ datasource=self.datasource_name,
+ granularity="all",
+ intervals=from_dttm.isoformat() + '/' + to_dttm.isoformat(),
+ aggregations=dict(count=count("count")),
+ dimension=column_name,
+ metric="count",
+ threshold=limit,
+ )
+
+ client = self.cluster.get_pydruid_client()
+ client.topn(**qry)
+ df = client.export_pandas()
+
+ if df is None or df.size == 0:
+ raise Exception(_("No data was returned."))
+
+ return df
+
def query( # druid
self, groupby, metrics,
granularity,
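
The two values_for_column implementations above serve the same purpose for the two datasource types: the SQLA version compiles a SELECT DISTINCT ... LIMIT query, optionally bounded by the main datetime column, while the Druid version issues a topN query over the dimension. A minimal sketch of the query shape the SQLA path builds (the table and column names are hypothetical, for illustration only):

    # Sketch of the query shape built by SqlaTable.values_for_column;
    # the table and column names here are made up.
    from sqlalchemy import Column, DateTime, MetaData, String, Table, and_, select

    metadata = MetaData()
    birth_names = Table(
        'birth_names', metadata,
        Column('name', String(255)),
        Column('ds', DateTime),
    )

    qry = (
        select([birth_names.c.name])
        .distinct()
        .where(and_(
            birth_names.c.ds >= '2015-01-01',
            birth_names.c.ds <= '2016-01-01',
        ))
        .limit(500)
    )
    # Renders roughly as:
    #   SELECT DISTINCT birth_names.name FROM birth_names
    #   WHERE birth_names.ds >= :ds_1 AND birth_names.ds <= :ds_2
    #   LIMIT :param_1
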
diff --git a/superset/views.py b/superset/views.py
index 85fe9a11d286b..34709a1720dc8 100755
--- a/superset/views.py
+++ b/superset/views.py
@@ -646,7 +646,8 @@ class TableModelView(SupersetModelView, DeleteMixin): # noqa
'link', 'database', 'is_featured', 'changed_on_']
add_columns = ['table_name', 'database', 'schema']
edit_columns = [
- 'table_name', 'sql', 'is_featured', 'database', 'schema',
+ 'table_name', 'sql', 'is_featured', 'filter_select_enabled',
+ 'database', 'schema',
'description', 'owner',
'main_dttm_col', 'default_endpoint', 'offset', 'cache_timeout']
show_columns = edit_columns + ['perm']
@@ -674,6 +675,7 @@ class TableModelView(SupersetModelView, DeleteMixin): # noqa
'database': _("Database"),
'changed_on_': _("Last Changed"),
'is_featured': _("Is Featured"),
+ 'filter_select_enabled': _("Enable Filter Select"),
'schema': _("Schema"),
'default_endpoint': _("Default Endpoint"),
'offset': _("Offset"),
@@ -1031,8 +1033,8 @@ class DruidDatasourceModelView(SupersetModelView, DeleteMixin): # noqa
related_views = [DruidColumnInlineView, DruidMetricInlineView]
edit_columns = [
'datasource_name', 'cluster', 'description', 'owner',
- 'is_featured', 'is_hidden', 'default_endpoint', 'offset',
- 'cache_timeout']
+ 'is_featured', 'is_hidden', 'filter_select_enabled',
+ 'default_endpoint', 'offset', 'cache_timeout']
add_columns = edit_columns
show_columns = add_columns + ['perm']
page_size = 500
@@ -1051,6 +1053,7 @@ class DruidDatasourceModelView(SupersetModelView, DeleteMixin): # noqa
'owner': _("Owner"),
'is_featured': _("Is Featured"),
'is_hidden': _("Is Hidden"),
+ 'filter_select_enabled': _("Enable Filter Select"),
'default_endpoint': _("Default Endpoint"),
'offset': _("Time Offset"),
'cache_timeout': _("Cache Timeout"),
@@ -1494,7 +1497,8 @@ def explore(self, datasource_type, datasource_id):
"datasource_name": viz_obj.datasource.name,
"datasource_type": datasource_type,
"user_id": user_id,
- "viz": json.loads(viz_obj.json_data)
+ "viz": json.loads(viz_obj.json_data),
+ "filter_select": viz_obj.datasource.filter_select_enabled
}
table_name = viz_obj.datasource.table_name \
if datasource_type == 'table' \
@@ -1513,6 +1517,53 @@ def explore(self, datasource_type, datasource_id):
userid=g.user.get_id() if g.user else ''
)
+ @api
+ @has_access_api
+ @expose("/filter/
///")
+ def filter(self, datasource_type, datasource_id, column):
+ """
+ Endpoint to retrieve values for specified column.
+
+ :param datasource_type: Type of datasource e.g. table
+ :param datasource_id: Datasource id
+ :param column: Column name to retrieve values for
+ :return: JSON response with distinct values for the column
+ """
+ # TODO: Cache endpoint by user, datasource and column
+ error_redirect = '/slicemodelview/list/'
+ datasource_class = models.SqlaTable \
+ if datasource_type == "table" else models.DruidDatasource
+
+ datasource = db.session.query(
+ datasource_class).filter_by(id=datasource_id).first()
+
+ if not datasource:
+ flash(DATASOURCE_MISSING_ERR, "alert")
+ return json_error_response(DATASOURCE_MISSING_ERR)
+ if not self.datasource_access(datasource):
+ flash(get_datasource_access_error_msg(datasource.name), "danger")
+ return json_error_response(DATASOURCE_ACCESS_ERR)
+
+ viz_type = request.args.get("viz_type")
+ if not viz_type and datasource.default_endpoint:
+ return redirect(datasource.default_endpoint)
+ if not viz_type:
+ viz_type = "table"
+ try:
+ obj = viz.viz_types[viz_type](
+ datasource,
+ form_data=request.args,
+ slice_=None)
+ except Exception as e:
+ flash(str(e), "danger")
+ return redirect(error_redirect)
+ status = 200
+ payload = obj.get_values_for_column(column)
+ return Response(
+ payload,
+ status=status,
+ mimetype="application/json")
+
def save_or_overwrite_slice(
self, args, slc, slice_add_perm, slice_edit_perm):
"""Save or overwrite a slice"""
diff --git a/superset/viz.py b/superset/viz.py
index 2c142c443df12..3f19dcd7ad632 100755
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -151,6 +151,34 @@ def get_url(self, for_cache_key=False, json_endpoint=False, **kwargs):
del od['force']
return href(od)
+ def get_filter_url(self):
+ """Returns the URL to retrieve column values used in the filter"""
+ data = self.orig_form_data.copy()
+ # Remove unchecked checkboxes because HTML is weird like that
+ ordered_data = MultiDict()
+ for key in sorted(data.keys()):
+ # if MultiDict is initialized with MD({key:[emptyarray]}),
+ # key is included in d.keys() but accessing it raises an IndexError
+ try:
+ if data[key] is False:
+ del data[key]
+ continue
+ except IndexError:
+ pass
+
+ if isinstance(data, (MultiDict, ImmutableMultiDict)):
+ v = data.getlist(key)
+ else:
+ v = data.get(key)
+ if not isinstance(v, list):
+ v = [v]
+ for item in v:
+ ordered_data.add(key, item)
+ href = Href(
+ '/superset/filter/{self.datasource.type}/'
+ '{self.datasource.id}/'.format(**locals()))
+ return href(ordered_data)
+
def get_df(self, query_obj=None):
"""Returns a pandas dataframe based on the query object"""
if not query_obj:
@@ -325,6 +353,7 @@ def get_json(self, force=False):
'form_data': self.form_data,
'json_endpoint': self.json_endpoint,
'query': self.query,
+ 'filter_endpoint': self.filter_endpoint,
'standalone_endpoint': self.standalone_endpoint,
'column_formats': self.data['column_formats'],
}
@@ -359,9 +388,11 @@ def data(self):
'csv_endpoint': self.csv_endpoint,
'form_data': self.form_data,
'json_endpoint': self.json_endpoint,
+ 'filter_endpoint': self.filter_endpoint,
'standalone_endpoint': self.standalone_endpoint,
'token': self.token,
'viz_name': self.viz_type,
+ 'filter_select_enabled': self.datasource.filter_select_enabled,
'column_formats': {
m.metric_name: m.d3format
for m in self.datasource.metrics
@@ -375,6 +406,34 @@ def get_csv(self):
include_index = not isinstance(df.index, pd.RangeIndex)
return df.to_csv(index=include_index, encoding="utf-8")
+ def get_values_for_column(self, column):
+ """
+ Retrieves values for a column to be used by the filter dropdown.
+
+ :param column: column name
+ :return: JSON containing some sample values for the column
+ """
+ form_data = self.form_data
+
+ since = form_data.get("since", "1 year ago")
+ from_dttm = utils.parse_human_datetime(since)
+ now = datetime.now()
+ if from_dttm > now:
+ from_dttm = now - (from_dttm - now)
+ until = form_data.get("until", "now")
+ to_dttm = utils.parse_human_datetime(until)
+ if from_dttm > to_dttm:
+ flasher("The date range doesn't seem right.", "danger")
+ from_dttm = to_dttm # Making them identical to not raise
+
+ kwargs = dict(
+ column_name=column,
+ from_dttm=from_dttm,
+ to_dttm=to_dttm,
+ )
+ df = self.datasource.values_for_column(**kwargs)
+ return df[column].to_json()
+
def get_data(self):
return []
@@ -382,6 +441,10 @@ def get_data(self):
def json_endpoint(self):
return self.get_url(json_endpoint=True)
+ @property
+ def filter_endpoint(self):
+ return self.get_filter_url()
+
@property
def cache_key(self):
url = self.get_url(for_cache_key=True, json="true", force="false")
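
Since get_values_for_column returns df[column].to_json(), the payload is a pandas Series serialized by row index, i.e. a JSON object keyed by "0", "1", ... rather than a bare array. A quick sketch of the shape:

    # Shape of the payload produced by df[column].to_json(); sample data only.
    import pandas as pd

    df = pd.DataFrame({'name': ['Aaron', 'Amy', 'Ben']})
    print(df['name'].to_json())  # -> {"0":"Aaron","1":"Amy","2":"Ben"}
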
diff --git a/tests/core_tests.py b/tests/core_tests.py
index faa8d78baf6f3..4964464e91726 100644
--- a/tests/core_tests.py
+++ b/tests/core_tests.py
@@ -122,6 +122,25 @@ def test_save_slice(self):
assert 'Energy' in self.get_resp(
url.format(tbl_id, slice_id, copy_name, 'overwrite'))
+ def test_filter_endpoint(self):
+ self.login(username='admin')
+ slice_name = "Energy Sankey"
+ slice_id = self.get_slice(slice_name, db.session).id
+ db.session.commit()
+ tbl_id = self.table_ids.get('energy_usage')
+ table = db.session.query(models.SqlaTable).filter(models.SqlaTable.id == tbl_id).first()
+ table.filter_select_enabled = True
+ url = (
+ "/superset/filter/table/{}/target/?viz_type=sankey&groupby=source"
+ "&metric=sum__value&flt_col_0=source&flt_op_0=in&flt_eq_0=&"
+ "slice_id={}&datasource_name=energy_usage&"
+ "datasource_id=1&datasource_type=table")
+
+ # The response should contain distinct values from the "target" column
+ resp = self.get_resp(url.format(tbl_id, slice_id))
+ assert len(resp) > 0
+ assert 'Carbon Dioxide' in resp
+
def test_slices(self):
# Testing by hitting the two supported end points for all slices
self.login(username='admin')