Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement dttm column configuration through db extra config #9441

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion docs/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1379,8 +1379,14 @@ Prior to SIP-15 SQLAlchemy used inclusive endpoints however these may behave lik
To remedy this rather than having to define the date/time format for every non-IS0 8601 date-time column, once can define a default column mapping on a per database level via the ``extra`` parameter ::

{
"main_dttm_column": "ds",
"default_dttm_column_names": ["ds", "hour_ts"],
"python_date_format_by_column_name": {
"ds": "%Y-%m-%d"
"ds": "%Y-%m-%d",
"hour_ts": "epoch_s",
}
"expression_by_column_name": {
"hour_ts": "CAST(hour_ts as INTEGER)",
}
}

Expand Down
43 changes: 38 additions & 5 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ class TableColumn(Model, BaseColumn):
foreign_keys=[table_id],
)
is_dttm = Column(Boolean, default=False)
# Please use get_expression as a getter for this field.
expression = Column(Text)
python_date_format = Column(String(255))

Expand All @@ -145,6 +146,17 @@ class TableColumn(Model, BaseColumn):
update_from_object_fields = [s for s in export_fields if s not in ("table_id",)]
export_parent = "table"

def get_expression(self):
if config["SIP_15_ENABLED"] and not self.expression:
db_default_expression = (
self.table.database.get_extra()
.get("expression_by_column_name", {})
.get(self.column_name)
)
# TODO(bkyryliuk): consider setting self.expression to db_default_expreassion
return db_default_expression
return self.expression

@property
def is_numeric(self) -> bool:
db_engine_spec = self.table.database.db_engine_spec
Expand All @@ -168,8 +180,9 @@ def is_temporal(self) -> bool:

def get_sqla_col(self, label: Optional[str] = None) -> Column:
label = label or self.column_name
if self.expression:
col = literal_column(self.expression)
expression = self.get_expression()
if expression:
col = literal_column(expression)
else:
db_engine_spec = self.table.database.db_engine_spec
type_ = db_engine_spec.get_sqla_column_type(self.type)
Expand Down Expand Up @@ -221,11 +234,12 @@ def get_timestamp_expression(
db = self.table.database
pdf = self.python_date_format
is_epoch = pdf in ("epoch_s", "epoch_ms")
if not self.expression and not time_grain and not is_epoch:
expression = self.get_expression()
if not expression and not time_grain and not is_epoch:
sqla_col = column(self.column_name, type_=DateTime)
return self.table.make_sqla_column_compatible(sqla_col, label)
if self.expression:
col = literal_column(self.expression)
if expression:
col = literal_column(expression)
else:
col = column(self.column_name)
time_expr = db.db_engine_spec.get_timestamp_expr(
Expand Down Expand Up @@ -268,6 +282,8 @@ def dttm_sql_literal(

# Fallback to the default format (if defined) only if the SIP-15 time range
# endpoints, i.e., [start, end) are enabled.
# TODO(bkyryliuk): serialize the python_date_format_by_column_name in the column object for
# better debuggability and user experience.
if not tf and time_range_endpoints == (
utils.TimeRangeEndpoint.INCLUSIVE,
utils.TimeRangeEndpoint.EXCLUSIVE,
Expand Down Expand Up @@ -1086,6 +1102,14 @@ def fetch_metadata(self, commit=True) -> None:
).format(self.table_name)
)

default_main_dttm_col = None
default_dttm_columns = []
if config["SIP_15_ENABLED"]:
default_main_dttm_col = self.database.get_extra().get("main_dttm_column")
default_dttm_columns = (
self.database.get_extra().get("default_dttm_column_names") or []
)

metrics = []
any_date_col = None
db_engine_spec = self.database.db_engine_spec
Expand Down Expand Up @@ -1113,6 +1137,15 @@ def fetch_metadata(self, commit=True) -> None:
dbcol.avg = dbcol.is_numeric
dbcol.is_dttm = dbcol.is_temporal
db_engine_spec.alter_new_orm_column(dbcol)
# Apply default dttm setting from the database configuration.
if dbcol.column_name in default_dttm_columns:
dbcol.is_dttm = True
if (
default_main_dttm_col
and dbcol.is_dttm
and dbcol.column_name == default_main_dttm_col
):
any_date_col = default_main_dttm_col
else:
dbcol.type = datatype
dbcol.groupby = True
Expand Down
72 changes: 72 additions & 0 deletions tests/core_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1171,6 +1171,78 @@ def test_sqllab_backend_persistence_payload(self):
payload = views.Superset._get_sqllab_tabs(user_id=user_id)
self.assertEqual(len(payload["queries"]), 1)

def test_db_column_defaults(self):
db_extras = json.dumps(
{
"main_datetime_column": "dttm",
"default_dttm_column_names": ["id", "dttm"],
"python_date_format_by_column_name": {
"id": "epoch_ms",
"dttm": "epoch_s",
},
"expression_by_column_name": {"dttm": "CAST(dttm as INTEGER)"},
}
)

self.login(username="admin")
try:
test_db = utils.get_or_create_db(
"column_test_db", app.config["SQLALCHEMY_DATABASE_URI"], extra=db_extras
)

resp = self.client.post(
"/tablemodelview/add",
data=dict(database=test_db.id, table_name="logs"),
follow_redirects=True,
)
self.assertEqual(resp.status_code, 200)
added_table = db.session.query(SqlaTable).filter_by(table_name="logs").one()

# Make sure that dttm column is set properly
self.assertEqual(added_table.main_dttm_col, "dttm")
# Make sure that default_dttm_column_names is set
self.assertEqual(len(added_table.dttm_cols), 2)

# validate python_date_format_by_column_name and expression_by_column_name
dttm_col = [c for c in added_table.columns if c.column_name == "dttm"][0]
self.assertEqual(dttm_col.get_expression(), "CAST(dttm as INTEGER)")
self.assertIn(
"CAST(dttm as INTEGER)", str(dttm_col.get_timestamp_expression("P1W"))
)
self.assertIsNone(
dttm_col.python_date_format
) # defaults are not serialized
tre = utils.TimeRangeEndpoint.INCLUSIVE, utils.TimeRangeEndpoint.EXCLUSIVE
expected_literal = (
"STR_TO_DATE('2019-01-01 00:00:00.000000', '%Y-%m-%d %H:%i:%s.%f')"
)
if test_db.backend == "sqlite":
expected_literal = "1546329600"
self.assertEqual(
dttm_col.dttm_sql_literal(
datetime.datetime(2019, 1, 1), time_range_endpoints=tre
),
expected_literal,
)
self.assertTrue(dttm_col.is_dttm)

id_col = [c for c in added_table.columns if c.column_name == "id"][0]
self.assertIsNone(id_col.get_expression())
self.assertIsNone(id_col.python_date_format) # defaults are not serialized
expected_literal = "1546329600000"
self.assertEqual(
id_col.dttm_sql_literal(
datetime.datetime(2019, 1, 1), time_range_endpoints=tre
),
expected_literal,
)

self.assertTrue(id_col.is_dttm)
finally:
db.session.delete(added_table)
db.session.delete(test_db)
db.session.commit()


if __name__ == "__main__":
unittest.main()