Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement dttm config #4

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions superset/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,18 @@ def _try_json_readsha(filepath, length): # pylint: disable=unused-argument
# }
TIME_GRAIN_ADDON_FUNCTIONS: Dict[str, Dict[str, str]] = {}

# Default python_date_format and expression for the sql tables dttm columns.
# It is useful for the use cases when there is a company wide convention.
# Example:
# DTTM_CONFIG = {
# 'ts': {'python_date_format': 'epoch_s'},
# 'hour_ts': {
# 'python_date_format': 'epoch_s',
# 'expression': 'CAST(hour_ts as INTEGER)'
# },
# }
DTTM_CONFIG: Optional[Dict[str, Dict[str, str]]] = None

# ---------------------------------------------------
# List of viz_types not allowed in your environment
# For example: Blacklist pivot table and treemap:
Expand Down
21 changes: 19 additions & 2 deletions superset/connectors/sqla/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@
"""Views used by the SqlAlchemy connector"""
import logging
import re
from collections import defaultdict

from flask import flash, Markup, redirect
from flask import app, flash, Markup, redirect
from flask_appbuilder import CompactCRUDMixin, expose
from flask_appbuilder.actions import action
from flask_appbuilder.fieldwidgets import Select2Widget
Expand All @@ -29,7 +30,7 @@
from wtforms.ext.sqlalchemy.fields import QuerySelectField
from wtforms.validators import Regexp

from superset import appbuilder, db, security_manager
from superset import app as superset_app, appbuilder, db, security_manager
from superset.connectors.base.views import DatasourceModelView
from superset.utils import core as utils
from superset.views.base import (
Expand All @@ -44,6 +45,7 @@
from . import models

logger = logging.getLogger(__name__)
config = superset_app.config


class TableColumnInlineView(CompactCRUDMixin, SupersetModelView):
Expand Down Expand Up @@ -359,6 +361,21 @@ def pre_add(self, table):

def post_add(self, table, flash_message=True):
table.fetch_metadata()
dttm_config = config.get("DTTM_CONFIG", {})
for col in table.columns:
if col.column_name in dttm_config:
col.is_dttm = True
if not col.expression and "expression" in dttm_config[col.column_name]:
col.expression = dttm_config[col.column_name]["expression"]
if (
not col.python_date_format
and "python_date_format" in dttm_config[col.column_name]
):
col.python_date_format = dttm_config[col.column_name][
"python_date_format"
]
db.session.commit()

security_manager.add_permission_view_menu("datasource_access", table.get_perm())
if table.schema:
security_manager.add_permission_view_menu(
Expand Down
32 changes: 32 additions & 0 deletions tests/core_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,38 @@ def test_tablemodelview_list(self):
assert table.name in resp
assert "/superset/explore/table/{}".format(table.id) in resp

@mock.patch("superset.connectors.sqla.views.config")
def test_tablemodelview_add(self, mock_config):
mock_config.get.return_value = {
"dttm": {
"expression": "test_expression",
"python_date_format": "test_python_date_format",
},
"nonexistant": {
"expression": "test_expression_v2",
"python_date_format": "python_date_format_v2",
},
}
self.login(username="admin")
# assert that /tablemodelview/add responds with 200
example_db = utils.get_example_database()
resp = self.client.post(
"/tablemodelview/add",
data=dict(database=example_db.id, table_name="logs"),
follow_redirects=True,
)
self.assertEqual(resp.status_code, 200)
added_table = db.session.query(SqlaTable).filter_by(table_name="logs").one()

assert len(added_table.dttm_cols) == 1
dttm_col_name = added_table.dttm_cols[0]
assert dttm_col_name == "dttm"

# Make sure that dttm defaults were propagated.
dttm_col = [c for c in added_table.columns if c.column_name == dttm_col_name][0]
assert dttm_col.expression == "test_expression"
assert dttm_col.python_date_format == "test_python_date_format"

def test_add_slice(self):
self.login(username="admin")
# assert that /chart/add responds with 200
Expand Down