diff --git a/superset/datasets/dao.py b/superset/datasets/dao.py index 44ab8efa0ce54..961bf1fc12390 100644 --- a/superset/datasets/dao.py +++ b/superset/datasets/dao.py @@ -163,14 +163,10 @@ def update( """ if "columns" in properties: - properties["columns"] = cls.update_columns( - model, properties.get("columns", []), commit=commit - ) + properties["columns"] = cls.update_columns(model, properties["columns"]) if "metrics" in properties: - properties["metrics"] = cls.update_metrics( - model, properties.get("metrics", []), commit=commit - ) + properties["metrics"] = cls.update_metrics(model, properties["metrics"]) return super().update(model, properties, commit=commit) @@ -179,7 +175,6 @@ def update_columns( cls, model: SqlaTable, property_columns: List[Dict[str, Any]], - commit: bool = True, ) -> List[TableColumn]: """ Creates/updates and/or deletes a list of columns, based on a @@ -190,28 +185,34 @@ def update_columns( - If there are extra columns on the metadata db that are not defined on the List then we delete. """ - new_columns = [] - for column in property_columns: - column_id = column.get("id") - if column_id: - column_obj = db.session.query(TableColumn).get(column_id) - column_obj = DatasetDAO.update_column(column_obj, column, commit=commit) + column_by_id = {column.id: column for column in model.columns} + columns = [] + + for properties in property_columns: + if "id" in properties: + columns.append( + DatasetDAO.update_column( + column_by_id[properties["id"]], + properties, + commit=False, + ) + ) else: - column_obj = DatasetDAO.create_column(column, commit=commit) - new_columns.append(column_obj) - # Checks if an exiting column is missing from properties and delete it - for existing_column in model.columns: - if existing_column.id not in [column.id for column in new_columns]: - DatasetDAO.delete_column(existing_column) - return new_columns + + # Note for new columns the primary key is undefined sans a commit/flush. + columns.append(DatasetDAO.create_column(properties, commit=False)) + + for id_ in {obj.id for obj in model.columns} - {obj.id for obj in columns}: + DatasetDAO.delete_column(column_by_id[id_], commit=False) + + return columns @classmethod def update_metrics( cls, model: SqlaTable, property_metrics: List[Dict[str, Any]], - commit: bool = True, ) -> List[SqlMetric]: """ Creates/updates and/or deletes a list of metrics, based on a @@ -222,21 +223,28 @@ def update_metrics( - If there are extra metrics on the metadata db that are not defined on the List then we delete. """ - new_metrics = [] - for metric in property_metrics: - metric_id = metric.get("id") - if metric.get("id"): - metric_obj = db.session.query(SqlMetric).get(metric_id) - metric_obj = DatasetDAO.update_metric(metric_obj, metric, commit=commit) + + metric_by_id = {metric.id: metric for metric in model.metrics} + metrics = [] + + for properties in property_metrics: + if "id" in properties: + metrics.append( + DatasetDAO.update_metric( + metric_by_id[properties["id"]], + properties, + commit=False, + ) + ) else: - metric_obj = DatasetDAO.create_metric(metric, commit=commit) - new_metrics.append(metric_obj) - # Checks if an exiting column is missing from properties and delete it - for existing_metric in model.metrics: - if existing_metric.id not in [metric.id for metric in new_metrics]: - DatasetDAO.delete_metric(existing_metric) - return new_metrics + # Note for new metrics the primary key is undefined sans a commit/flush. + metrics.append(DatasetDAO.create_metric(properties, commit=False)) + + for id_ in {obj.id for obj in model.metrics} - {obj.id for obj in metrics}: + DatasetDAO.delete_column(metric_by_id[id_], commit=False) + + return metrics @classmethod def find_dataset_column( @@ -254,23 +262,24 @@ def find_dataset_column( @classmethod def update_column( - cls, model: TableColumn, properties: Dict[str, Any], commit: bool = True - ) -> Optional[TableColumn]: + cls, + model: TableColumn, + properties: Dict[str, Any], + commit: bool = True, + ) -> TableColumn: return DatasetColumnDAO.update(model, properties, commit=commit) @classmethod def create_column( cls, properties: Dict[str, Any], commit: bool = True - ) -> Optional[TableColumn]: + ) -> TableColumn: """ Creates a Dataset model on the metadata DB """ return DatasetColumnDAO.create(properties, commit=commit) @classmethod - def delete_column( - cls, model: TableColumn, commit: bool = True - ) -> Optional[TableColumn]: + def delete_column(cls, model: TableColumn, commit: bool = True) -> TableColumn: """ Deletes a Dataset column """ @@ -287,9 +296,7 @@ def find_dataset_metric( return db.session.query(SqlMetric).get(metric_id) @classmethod - def delete_metric( - cls, model: SqlMetric, commit: bool = True - ) -> Optional[TableColumn]: + def delete_metric(cls, model: SqlMetric, commit: bool = True) -> SqlMetric: """ Deletes a Dataset metric """ @@ -297,14 +304,19 @@ def delete_metric( @classmethod def update_metric( - cls, model: SqlMetric, properties: Dict[str, Any], commit: bool = True - ) -> Optional[SqlMetric]: + cls, + model: SqlMetric, + properties: Dict[str, Any], + commit: bool = True, + ) -> SqlMetric: return DatasetMetricDAO.update(model, properties, commit=commit) @classmethod def create_metric( - cls, properties: Dict[str, Any], commit: bool = True - ) -> Optional[SqlMetric]: + cls, + properties: Dict[str, Any], + commit: bool = True, + ) -> SqlMetric: """ Creates a Dataset model on the metadata DB """