diff --git a/docs/installation.rst b/docs/installation.rst index c08dab109aef0..10720fa96909b 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -1379,8 +1379,14 @@ Prior to SIP-15 SQLAlchemy used inclusive endpoints however these may behave lik To remedy this rather than having to define the date/time format for every non-IS0 8601 date-time column, once can define a default column mapping on a per database level via the ``extra`` parameter :: { + "main_dttm_column": "ds", + "default_dttm_column_names": ["ds", "hour_ts"], "python_date_format_by_column_name": { - "ds": "%Y-%m-%d" + "ds": "%Y-%m-%d", + "hour_ts": "epoch_s" + }, + "expression_by_column_name": { + "hour_ts": "CAST(hour_ts as INTEGER)" } } diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 947551cf5bf21..91d4e4c98cbfe 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -125,6 +125,7 @@ class TableColumn(Model, BaseColumn): foreign_keys=[table_id], ) is_dttm = Column(Boolean, default=False) + # Please use get_expression as a getter for this field. 
expression = Column(Text) python_date_format = Column(String(255)) @@ -145,6 +146,17 @@ class TableColumn(Model, BaseColumn): update_from_object_fields = [s for s in export_fields if s not in ("table_id",)] export_parent = "table" + def get_expression(self): # own expression, else the DB-level default (SIP-15 only) + if config["SIP_15_ENABLED"] and not self.expression: + db_default_expression = ( + self.table.database.get_extra() + .get("expression_by_column_name", {}) + .get(self.column_name) + ) + # TODO(bkyryliuk): consider setting self.expression to db_default_expression + return db_default_expression + return self.expression + @property def is_numeric(self) -> bool: db_engine_spec = self.table.database.db_engine_spec @@ -168,8 +180,9 @@ def is_temporal(self) -> bool: def get_sqla_col(self, label: Optional[str] = None) -> Column: label = label or self.column_name - if self.expression: - col = literal_column(self.expression) + expression = self.get_expression() + if expression: + col = literal_column(expression) else: db_engine_spec = self.table.database.db_engine_spec type_ = db_engine_spec.get_sqla_column_type(self.type) @@ -221,11 +234,12 @@ def get_timestamp_expression( db = self.table.database pdf = self.python_date_format is_epoch = pdf in ("epoch_s", "epoch_ms") - if not self.expression and not time_grain and not is_epoch: + expression = self.get_expression() + if not expression and not time_grain and not is_epoch: sqla_col = column(self.column_name, type_=DateTime) return self.table.make_sqla_column_compatible(sqla_col, label) - if self.expression: - col = literal_column(self.expression) + if expression: + col = literal_column(expression) else: col = column(self.column_name) time_expr = db.db_engine_spec.get_timestamp_expr( @@ -268,6 +282,8 @@ def dttm_sql_literal( # Fallback to the default format (if defined) only if the SIP-15 time range # endpoints, i.e., [start, end) are enabled. 
+ # TODO(bkyryliuk): serialize the python_date_format_by_column_name in the column object for + # better debuggability and user experience. if not tf and time_range_endpoints == ( utils.TimeRangeEndpoint.INCLUSIVE, utils.TimeRangeEndpoint.EXCLUSIVE, @@ -1086,6 +1102,14 @@ def fetch_metadata(self, commit=True) -> None: ).format(self.table_name) ) + default_main_dttm_col = None + default_dttm_columns = [] + if config["SIP_15_ENABLED"]: + default_main_dttm_col = self.database.get_extra().get("main_dttm_column") + default_dttm_columns = ( + self.database.get_extra().get("default_dttm_column_names") or [] + ) + metrics = [] any_date_col = None db_engine_spec = self.database.db_engine_spec @@ -1113,6 +1137,15 @@ def fetch_metadata(self, commit=True) -> None: dbcol.avg = dbcol.is_numeric dbcol.is_dttm = dbcol.is_temporal db_engine_spec.alter_new_orm_column(dbcol) + # Apply default dttm setting from the database configuration. + if dbcol.column_name in default_dttm_columns: + dbcol.is_dttm = True + if ( + default_main_dttm_col + and dbcol.is_dttm + and dbcol.column_name == default_main_dttm_col + ): + any_date_col = default_main_dttm_col else: dbcol.type = datatype dbcol.groupby = True diff --git a/tests/core_tests.py b/tests/core_tests.py index 9c6b54a42cbf5..9d5bdfc3d11bd 100644 --- a/tests/core_tests.py +++ b/tests/core_tests.py @@ -1171,6 +1171,78 @@ def test_sqllab_backend_persistence_payload(self): payload = views.Superset._get_sqllab_tabs(user_id=user_id) self.assertEqual(len(payload["queries"]), 1) + def test_db_column_defaults(self): + db_extras = json.dumps( + { + "main_dttm_column": "dttm", # must match the key fetch_metadata reads + "default_dttm_column_names": ["id", "dttm"], + "python_date_format_by_column_name": { + "id": "epoch_ms", + "dttm": "epoch_s", + }, + "expression_by_column_name": {"dttm": "CAST(dttm as INTEGER)"}, + } + ) + + self.login(username="admin") + try: + test_db = utils.get_or_create_db( + "column_test_db", app.config["SQLALCHEMY_DATABASE_URI"], extra=db_extras + ) + + resp = 
self.client.post( + "/tablemodelview/add", + data=dict(database=test_db.id, table_name="logs"), + follow_redirects=True, + ) + self.assertEqual(resp.status_code, 200) + added_table = db.session.query(SqlaTable).filter_by(table_name="logs").one() + + # Make sure that dttm column is set properly + self.assertEqual(added_table.main_dttm_col, "dttm") + # Make sure that default_dttm_column_names is set + self.assertEqual(len(added_table.dttm_cols), 2) + + # validate python_date_format_by_column_name and expression_by_column_name + dttm_col = [c for c in added_table.columns if c.column_name == "dttm"][0] + self.assertEqual(dttm_col.get_expression(), "CAST(dttm as INTEGER)") + self.assertIn( + "CAST(dttm as INTEGER)", str(dttm_col.get_timestamp_expression("P1W")) + ) + self.assertIsNone( + dttm_col.python_date_format + ) # defaults are not serialized + tre = utils.TimeRangeEndpoint.INCLUSIVE, utils.TimeRangeEndpoint.EXCLUSIVE + expected_literal = ( + "STR_TO_DATE('2019-01-01 00:00:00.000000', '%Y-%m-%d %H:%i:%s.%f')" + ) + if test_db.backend == "sqlite": + expected_literal = "1546329600" + self.assertEqual( + dttm_col.dttm_sql_literal( + datetime.datetime(2019, 1, 1), time_range_endpoints=tre + ), + expected_literal, + ) + self.assertTrue(dttm_col.is_dttm) + + id_col = [c for c in added_table.columns if c.column_name == "id"][0] + self.assertIsNone(id_col.get_expression()) + self.assertIsNone(id_col.python_date_format) # defaults are not serialized + expected_literal = "1546329600000" + self.assertEqual( + id_col.dttm_sql_literal( + datetime.datetime(2019, 1, 1), time_range_endpoints=tre + ), + expected_literal, + ) + + self.assertTrue(id_col.is_dttm) + finally: + db.session.delete(added_table) + db.session.delete(test_db) + db.session.commit() + if __name__ == "__main__": unittest.main()