Skip to content

Commit

Permalink
fix: do not drop calculated column on metadata sync (#11731)
Browse files Browse the repository at this point in the history
  • Loading branch information
villebro authored Nov 18, 2020
1 parent 676e0bb commit 7ae8cd0
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 10 deletions.
6 changes: 2 additions & 4 deletions superset/connectors/base/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,8 @@ def slices(self) -> RelationshipProperty:
),
)

# placeholder for a relationship to a derivative of BaseColumn
columns: List[Any] = []
# placeholder for a relationship to a derivative of BaseMetric
metrics: List[Any] = []
columns: List["BaseColumn"] = []
metrics: List["BaseMetric"] = []

@property
def type(self) -> str:
Expand Down
2 changes: 2 additions & 0 deletions superset/connectors/druid/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,8 @@ class DruidDatasource(Model, BaseDatasource):
type = "druid"
query_language = "json"
cluster_class = DruidCluster
columns: List[DruidColumn] = []
metrics: List[DruidMetric] = []
metric_class = DruidMetric
column_class = DruidColumn
owner_class = security_manager.user_model
Expand Down
12 changes: 10 additions & 2 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,8 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
type = "table"
query_language = "sql"
is_rls_supported = True
columns: List[TableColumn] = []
metrics: List[SqlMetric] = []
metric_class = SqlMetric
column_class = TableColumn
owner_class = security_manager.user_model
Expand Down Expand Up @@ -1333,7 +1335,9 @@ def fetch_metadata(self, commit: bool = True) -> MetadataResult:
db_engine_spec = self.database.db_engine_spec
old_columns = db.session.query(TableColumn).filter(TableColumn.table == self)

old_columns_by_name = {col.column_name: col for col in old_columns}
old_columns_by_name: Dict[str, TableColumn] = {
col.column_name: col for col in old_columns
}
results = MetadataResult(
removed=[
col
Expand All @@ -1345,7 +1349,7 @@ def fetch_metadata(self, commit: bool = True) -> MetadataResult:
# clear old columns before adding modified columns back
self.columns = []
for col in new_columns:
old_column = old_columns_by_name.get(col["name"], None)
old_column = old_columns_by_name.pop(col["name"], None)
if not old_column:
results.added.append(col["name"])
new_column = TableColumn(
Expand All @@ -1358,11 +1362,15 @@ def fetch_metadata(self, commit: bool = True) -> MetadataResult:
if new_column.type != col["type"]:
results.modified.append(col["name"])
new_column.type = col["type"]
new_column.expression = ""
new_column.groupby = True
new_column.filterable = True
self.columns.append(new_column)
if not any_date_col and new_column.is_temporal:
any_date_col = col["name"]
self.columns.extend(
[col for col in old_columns_by_name.values() if col.expression]
)
metrics.append(
SqlMetric(
metric_name="count",
Expand Down
10 changes: 7 additions & 3 deletions superset/connectors/sqla/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,9 +454,13 @@ def pre_add(self, item: "TableModelView") -> None:
validate_sqlatable(item)

def post_add( # pylint: disable=arguments-differ
self, item: "TableModelView", flash_message: bool = True
self,
item: "TableModelView",
flash_message: bool = True,
fetch_metadata: bool = True,
) -> None:
item.fetch_metadata()
if fetch_metadata:
item.fetch_metadata()
create_table_permissions(item)
if flash_message:
flash(
Expand All @@ -470,7 +474,7 @@ def post_add( # pylint: disable=arguments-differ
)

def post_update(self, item: "TableModelView") -> None:
self.post_add(item, flash_message=False)
self.post_add(item, flash_message=False, fetch_metadata=False)

def _delete(self, pk: int) -> None:
DeleteMixin._delete(self, pk)
Expand Down
61 changes: 60 additions & 1 deletion tests/sqla_models_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
from typing import Any, Dict, NamedTuple, List, Tuple, Union
import re
from typing import Any, Dict, NamedTuple, List, Pattern, Tuple, Union
from unittest.mock import patch
import pytest

Expand All @@ -30,6 +31,23 @@
from .base_tests import SupersetTestCase


# Per-backend regex for the type name expected on the integer column of the
# virtual-table test query; keyed by the example database's backend name.
VIRTUAL_TABLE_INT_TYPES: Dict[str, Pattern[str]] = {
    backend_name: re.compile(type_pattern)
    for backend_name, type_pattern in (
        ("hive", r"^INT_TYPE$"),
        ("mysql", r"^LONGLONG$"),
        ("postgresql", r"^INT$"),
        ("presto", r"^INTEGER$"),
        ("sqlite", r"^INT$"),
    )
}

# Per-backend regex for the type name expected on the string columns of the
# virtual-table test query; keyed by the example database's backend name.
VIRTUAL_TABLE_STRING_TYPES: Dict[str, Pattern[str]] = {
    "hive": re.compile(r"^STRING_TYPE$"),
    "mysql": re.compile(r"^VAR_STRING$"),
    "postgresql": re.compile(r"^STRING$"),
    # Prefix match on purpose (no trailing "$") so parameterised names such as
    # "VARCHAR(3)" are accepted. The previous pattern "^VARCHAR*" treated "*"
    # like a glob; as a regex it made the final "R" optional and therefore
    # also accepted "VARCHA".
    "presto": re.compile(r"^VARCHAR"),
    "sqlite": re.compile(r"^STRING$"),
}


class TestDatabaseModel(SupersetTestCase):
def test_is_time_druid_time_col(self):
"""Druid has a special __time column"""
Expand Down Expand Up @@ -247,3 +265,44 @@ def test_dml_statement_raises_exception(self):
query_obj = dict(**base_query_obj, extras={})
with pytest.raises(QueryObjectValidationError):
table.get_sqla_query(**query_obj)

def test_fetch_metadata_for_updated_virtual_table(self):
    """Syncing metadata of a virtual table keeps calculated columns.

    The table starts with four stored columns, then ``fetch_metadata`` is
    run against a query that yields ``intcol``, ``strcol`` and ``mycase``.
    Afterwards the test expects that:

    * ``oldcol`` (not in the query, no expression) is gone,
    * ``expr`` (calculated, not in the query) is kept with its expression,
    * ``mycase`` (calculated, but now produced by the query) has an empty
      expression and a string type,
    * ``intcol`` has an integer type instead of ``FLOAT``.
    """
    calculated_expression = "case when 1 then 1 else 0 end"
    table = SqlaTable(
        table_name="updated_sql_table",
        database=get_example_database(),
        sql="select 123 as intcol, 'abc' as strcol, 'abc' as mycase",
    )

    # Columns stored on the table before the sync; passing ``table=`` to the
    # constructor is what attaches each one to ``table.columns``.
    stored_columns = (
        {"column_name": "intcol", "type": "FLOAT"},
        {"column_name": "oldcol", "type": "INT"},
        {
            "column_name": "expr",
            "expression": calculated_expression,
            "type": "INT",
        },
        {
            "column_name": "mycase",
            "expression": calculated_expression,
            "type": "INT",
        },
    )
    for column_kwargs in stored_columns:
        TableColumn(**column_kwargs, table=table)

    # make sure the columns have been mapped properly
    assert len(table.columns) == 4

    table.fetch_metadata()

    # the removed column has been dropped, while the physical and
    # calculated columns are all present
    synced_names = {column.column_name for column in table.columns}
    assert synced_names == {"intcol", "strcol", "mycase", "expr"}

    columns_by_name: Dict[str, TableColumn] = {
        column.column_name: column for column in table.columns
    }
    # type names are looked up per backend through the module-level pattern
    # tables
    backend_name = get_example_database().backend

    # the type of intcol has been updated
    assert VIRTUAL_TABLE_INT_TYPES[backend_name].match(
        columns_by_name["intcol"].type
    )
    # the calculated "mycase" has been replaced by the new physical column
    assert columns_by_name["mycase"].expression == ""
    assert VIRTUAL_TABLE_STRING_TYPES[backend_name].match(
        columns_by_name["mycase"].type
    )
    # the calculated column that is absent from the query is left untouched
    assert columns_by_name["expr"].expression == calculated_expression

0 comments on commit 7ae8cd0

Please sign in to comment.