Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Embrace the walrus operator #24127

Merged
merged 1 commit into from
May 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ repos:
rev: 5.12.0
hooks:
- id: isort
- repo: https://github.com/MarcoGorelli/auto-walrus
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably should reformat the code before running Mypy, Black, etc.

rev: v0.2.2
hooks:
- id: auto-walrus
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.3.0
hooks:
Expand Down
6 changes: 2 additions & 4 deletions superset/charts/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,8 +649,7 @@ def screenshot(self, pk: int, digest: str) -> WerkzeugResponse:
return self.response_404()

# fetch the chart screenshot using the current user and cache if set
img = ChartScreenshot.get_from_cache_key(thumbnail_cache, digest)
if img:
if img := ChartScreenshot.get_from_cache_key(thumbnail_cache, digest):
return Response(
FileWrapper(img), mimetype="image/png", direct_passthrough=True
)
Expand Down Expand Up @@ -783,7 +782,6 @@ def export(self, **kwargs: Any) -> Response:
500:
$ref: '#/components/responses/500'
"""
token = request.args.get("token")
requested_ids = kwargs["rison"]
timestamp = datetime.now().strftime("%Y%m%dT%H%M%S")
root = f"chart_export_{timestamp}"
Expand All @@ -805,7 +803,7 @@ def export(self, **kwargs: Any) -> Response:
as_attachment=True,
download_name=filename,
)
if token:
if token := request.args.get("token"):
response.set_cookie(token, "done", max_age=600)
return response

Expand Down
3 changes: 1 addition & 2 deletions superset/charts/commands/bulk_delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@ def validate(self) -> None:
if not self._models or len(self._models) != len(self._model_ids):
raise ChartNotFoundError()
# Check there are no associated ReportSchedules
reports = ReportScheduleDAO.find_by_chart_ids(self._model_ids)
if reports:
if reports := ReportScheduleDAO.find_by_chart_ids(self._model_ids):
report_names = [report.name for report in reports]
raise ChartBulkDeleteFailedReportsExistError(
_("There are associated alerts or reports: %s" % ",".join(report_names))
Expand Down
3 changes: 1 addition & 2 deletions superset/charts/commands/delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,7 @@ def validate(self) -> None:
if not self._model:
raise ChartNotFoundError()
# Check there are no associated ReportSchedules
reports = ReportScheduleDAO.find_by_chart_id(self._model_id)
if reports:
if reports := ReportScheduleDAO.find_by_chart_id(self._model_id):
report_names = [report.name for report in reports]
raise ChartDeleteFailedReportsExistError(
_("There are associated alerts or reports: %s" % ",".join(report_names))
Expand Down
3 changes: 1 addition & 2 deletions superset/common/query_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,7 @@ def get_query_results(
:raises QueryObjectValidationError: if an unsupported result type is requested
:return: JSON serializable result payload
"""
result_func = _result_type_functions.get(result_type)
if result_func:
if result_func := _result_type_functions.get(result_type):
return result_func(query_context, query_obj, force_cached)
raise QueryObjectValidationError(
_("Invalid result type: %(result_type)s", result_type=result_type)
Expand Down
3 changes: 1 addition & 2 deletions superset/common/query_context_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,9 @@ def _apply_granularity(
for column in datasource.columns
if (column["is_dttm"] if isinstance(column, dict) else column.is_dttm)
}
granularity = query_object.granularity
x_axis = form_data and form_data.get("x_axis")

if granularity:
if granularity := query_object.granularity:
filter_to_remove = None
if x_axis and x_axis in temporal_columns:
filter_to_remove = x_axis
Expand Down
3 changes: 1 addition & 2 deletions superset/common/query_context_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,8 +500,7 @@ def get_payload(
return return_value

def get_cache_timeout(self) -> int:
cache_timeout_rv = self._query_context.get_cache_timeout()
if cache_timeout_rv:
if cache_timeout_rv := self._query_context.get_cache_timeout():
return cache_timeout_rv
if (
data_cache_timeout := config["DATA_CACHE_CONFIG"].get(
Expand Down
3 changes: 1 addition & 2 deletions superset/common/utils/query_cache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,7 @@ def get(
if not key or not _cache[region] or force_query:
return query_cache

cache_value = _cache[region].get(key)
if cache_value:
if cache_value := _cache[region].get(key):
logger.debug("Cache key: %s", key)
stats_logger.incr("loading_from_cache")
try:
Expand Down
3 changes: 1 addition & 2 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -993,11 +993,10 @@ def adhoc_column_to_sqla( # pylint: disable=too-many-locals
schema=self.schema,
template_processor=template_processor,
)
col_in_metadata = self.get_column(expression)
time_grain = col.get("timeGrain")
has_timegrain = col.get("columnType") == "BASE_AXIS" and time_grain
is_dttm = False
if col_in_metadata:
if col_in_metadata := self.get_column(expression):
sqla_column = col_in_metadata.get_sqla_col(
template_processor=template_processor
)
Expand Down
3 changes: 1 addition & 2 deletions superset/dashboards/commands/bulk_delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ def validate(self) -> None:
if not self._models or len(self._models) != len(self._model_ids):
raise DashboardNotFoundError()
# Check there are no associated ReportSchedules
reports = ReportScheduleDAO.find_by_dashboard_ids(self._model_ids)
if reports:
if reports := ReportScheduleDAO.find_by_dashboard_ids(self._model_ids):
report_names = [report.name for report in reports]
raise DashboardBulkDeleteFailedReportsExistError(
_("There are associated alerts or reports: %s" % ",".join(report_names))
Expand Down
3 changes: 1 addition & 2 deletions superset/dashboards/commands/delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@ def validate(self) -> None:
if not self._model:
raise DashboardNotFoundError()
# Check there are no associated ReportSchedules
reports = ReportScheduleDAO.find_by_dashboard_id(self._model_id)
if reports:
if reports := ReportScheduleDAO.find_by_dashboard_id(self._model_id):
report_names = [report.name for report in reports]
raise DashboardDeleteFailedReportsExistError(
_("There are associated alerts or reports: %s" % ",".join(report_names))
Expand Down
3 changes: 1 addition & 2 deletions superset/dashboards/dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,11 +200,10 @@ def set_dash_metadata( # pylint: disable=too-many-locals
old_to_new_slice_ids: Optional[Dict[int, int]] = None,
commit: bool = False,
) -> Dashboard:
positions = data.get("positions")
new_filter_scopes = {}
md = dashboard.params_dict

if positions is not None:
if (positions := data.get("positions")) is not None:
# find slices in the position data
slice_ids = [
value.get("meta", {}).get("chartId")
Expand Down
3 changes: 1 addition & 2 deletions superset/databases/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1036,7 +1036,6 @@ def export(self, **kwargs: Any) -> Response:
500:
$ref: '#/components/responses/500'
"""
token = request.args.get("token")
requested_ids = kwargs["rison"]
timestamp = datetime.now().strftime("%Y%m%dT%H%M%S")
root = f"database_export_{timestamp}"
Expand All @@ -1060,7 +1059,7 @@ def export(self, **kwargs: Any) -> Response:
as_attachment=True,
download_name=filename,
)
if token:
if token := request.args.get("token"):
response.set_cookie(token, "done", max_age=600)
return response

Expand Down
3 changes: 1 addition & 2 deletions superset/databases/commands/delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,8 @@ def validate(self) -> None:
if not self._model:
raise DatabaseNotFoundError()
# Check there are no associated ReportSchedules
reports = ReportScheduleDAO.find_by_database_id(self._model_id)

if reports:
if reports := ReportScheduleDAO.find_by_database_id(self._model_id):
report_names = [report.name for report in reports]
raise DatabaseDeleteFailedReportsExistError(
_("There are associated alerts or reports: %s" % ",".join(report_names))
Expand Down
3 changes: 1 addition & 2 deletions superset/databases/commands/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,5 @@ def ping(engine: Engine) -> bool:
raise DatabaseTestConnectionUnexpectedError(errors) from ex

def validate(self) -> None:
database_name = self._properties.get("database_name")
if database_name is not None:
if (database_name := self._properties.get("database_name")) is not None:
self._model = DatabaseDAO.get_database_by_name(database_name)
3 changes: 1 addition & 2 deletions superset/databases/commands/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,5 @@ def run(self) -> None:
)

def validate(self) -> None:
database_id = self._properties.get("id")
if database_id is not None:
if (database_id := self._properties.get("id")) is not None:
self._model = DatabaseDAO.find_by_id(database_id)
3 changes: 1 addition & 2 deletions superset/datasets/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -977,8 +977,7 @@ def get_or_create_dataset(self) -> Response:
return self.response(400, message=ex.messages)
table_name = body["table_name"]
database_id = body["database_id"]
table = DatasetDAO.get_table_by_name(database_id, table_name)
if table:
if table := DatasetDAO.get_table_by_name(database_id, table_name):
return self.response(200, result={"table_id": table.id})

body["database"] = database_id
Expand Down
3 changes: 1 addition & 2 deletions superset/datasets/commands/importers/v1/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,7 @@ def get_sqla_type(native_type: str) -> VisitableType:
if native_type.upper() in type_map:
return type_map[native_type.upper()]

match = VARCHAR.match(native_type)
if match:
if match := VARCHAR.match(native_type):
size = int(match.group(1))
return String(size)

Expand Down
6 changes: 2 additions & 4 deletions superset/datasets/commands/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,11 @@ def validate(self) -> None:
exceptions.append(DatasetEndpointUnsafeValidationError())

# Validate columns
columns = self._properties.get("columns")
if columns:
if columns := self._properties.get("columns"):
self._validate_columns(columns, exceptions)

# Validate metrics
metrics = self._properties.get("metrics")
if metrics:
if metrics := self._properties.get("metrics"):
self._validate_metrics(metrics, exceptions)

if exceptions:
Expand Down
6 changes: 2 additions & 4 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1704,8 +1704,7 @@ def get_column_spec( # pylint: disable=unused-argument
:param source: Type coming from the database table or cursor description
:return: ColumnSpec object
"""
col_types = cls.get_column_types(native_type)
if col_types:
if col_types := cls.get_column_types(native_type):
column_type, generic_type = col_types
is_dttm = generic_type == GenericDataType.TEMPORAL
return ColumnSpec(
Expand Down Expand Up @@ -1996,9 +1995,8 @@ def validate_parameters(
required = {"host", "port", "username", "database"}
parameters = properties.get("parameters", {})
present = {key for key in parameters if parameters.get(key, ())}
missing = sorted(required - present)

if missing:
if missing := sorted(required - present):
errors.append(
SupersetError(
message=f'One or more parameters are missing: {", ".join(missing)}',
Expand Down
3 changes: 1 addition & 2 deletions superset/db_engine_specs/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,9 +384,8 @@ def df_to_sql(
}

# Add credentials if they are set on the SQLAlchemy dialect.
creds = engine.dialect.credentials_info

if creds:
if creds := engine.dialect.credentials_info:
to_gbq_kwargs[
"credentials"
] = service_account.Credentials.from_service_account_info(creds)
Expand Down
3 changes: 1 addition & 2 deletions superset/db_engine_specs/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,9 +285,8 @@ def validate_parameters( # type: ignore
parameters["http_path"] = connect_args.get("http_path")

present = {key for key in parameters if parameters.get(key, ())}
missing = sorted(required - present)

if missing:
if missing := sorted(required - present):
errors.append(
SupersetError(
message=f'One or more parameters are missing: {", ".join(missing)}',
Expand Down
6 changes: 2 additions & 4 deletions superset/db_engine_specs/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -1213,8 +1213,7 @@ def extra_table_metadata(
) -> Dict[str, Any]:
metadata = {}

indexes = database.get_indexes(table_name, schema_name)
if indexes:
if indexes := database.get_indexes(table_name, schema_name):
col_names, latest_parts = cls.latest_partition(
table_name, schema_name, database, show_first=True
)
Expand Down Expand Up @@ -1278,8 +1277,7 @@ def get_tracking_url(cls, cursor: "Cursor") -> Optional[str]:
@classmethod
def handle_cursor(cls, cursor: "Cursor", query: Query, session: Session) -> None:
"""Updates progress information"""
tracking_url = cls.get_tracking_url(cursor)
if tracking_url:
if tracking_url := cls.get_tracking_url(cursor):
query.tracking_url = tracking_url
session.commit()

Expand Down
3 changes: 1 addition & 2 deletions superset/db_engine_specs/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,9 +312,8 @@ def validate_parameters(
}
parameters = properties.get("parameters", {})
present = {key for key in parameters if parameters.get(key, ())}
missing = sorted(required - present)

if missing:
if missing := sorted(required - present):
errors.append(
SupersetError(
message=f'One or more parameters are missing: {", ".join(missing)}',
Expand Down
6 changes: 2 additions & 4 deletions superset/db_engine_specs/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@ def extra_table_metadata(
) -> Dict[str, Any]:
metadata = {}

indexes = database.get_indexes(table_name, schema_name)
if indexes:
if indexes := database.get_indexes(table_name, schema_name):
col_names, latest_parts = cls.latest_partition(
table_name, schema_name, database, show_first=True
)
Expand Down Expand Up @@ -150,8 +149,7 @@ def get_tracking_url(cls, cursor: Cursor) -> Optional[str]:

@classmethod
def handle_cursor(cls, cursor: Cursor, query: Query, session: Session) -> None:
tracking_url = cls.get_tracking_url(cursor)
if tracking_url:
if tracking_url := cls.get_tracking_url(cursor):
query.tracking_url = tracking_url

# Adds the executed query id to the extra payload so the query can be cancelled
Expand Down
3 changes: 1 addition & 2 deletions superset/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,8 +211,7 @@ def __post_init__(self) -> None:
Mutates the extra params with user facing error codes that map to backend
errors.
"""
issue_codes = ERROR_TYPES_TO_ISSUE_CODES_MAPPING.get(self.error_type)
if issue_codes:
if issue_codes := ERROR_TYPES_TO_ISSUE_CODES_MAPPING.get(self.error_type):
self.extra = self.extra or {}
self.extra.update(
{
Expand Down
3 changes: 1 addition & 2 deletions superset/initialization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,8 +453,7 @@ def init_app_in_ctx(self) -> None:

# Hook that provides administrators a handle on the Flask APP
# after initialization
flask_app_mutator = self.config["FLASK_APP_MUTATOR"]
if flask_app_mutator:
if flask_app_mutator := self.config["FLASK_APP_MUTATOR"]:
flask_app_mutator(self.superset_app)

if feature_flag_manager.is_feature_enabled("TAGGING_SYSTEM"):
Expand Down
5 changes: 2 additions & 3 deletions superset/migrations/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,16 +103,15 @@ def process_revision_directives( # pylint: disable=redefined-outer-name, unused
kwargs = {}
if engine.name in ("sqlite", "mysql"):
kwargs = {"transaction_per_migration": True, "transactional_ddl": True}
configure_args = current_app.extensions["migrate"].configure_args
if configure_args:
if configure_args := current_app.extensions["migrate"].configure_args:
kwargs.update(configure_args)

context.configure(
connection=connection,
target_metadata=target_metadata,
# compare_type=True,
process_revision_directives=process_revision_directives,
**kwargs
**kwargs,
)

try:
Expand Down
6 changes: 2 additions & 4 deletions superset/migrations/shared/migrate_viz/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,17 +130,15 @@ def upgrade_slice(cls, slc: Slice) -> Slice:
# only backup params
slc.params = json.dumps({**clz.data, FORM_DATA_BAK_FIELD_NAME: form_data_bak})

query_context = try_load_json(slc.query_context)
if "form_data" in query_context:
if "form_data" in (query_context := try_load_json(slc.query_context)):
query_context["form_data"] = clz.data
slc.query_context = json.dumps(query_context)
return slc

@classmethod
def downgrade_slice(cls, slc: Slice) -> Slice:
form_data = try_load_json(slc.params)
form_data_bak = form_data.get(FORM_DATA_BAK_FIELD_NAME, {})
if "viz_type" in form_data_bak:
if "viz_type" in (form_data_bak := form_data.get(FORM_DATA_BAK_FIELD_NAME, {})):
slc.params = json.dumps(form_data_bak)
slc.viz_type = form_data_bak.get("viz_type")
query_context = try_load_json(slc.query_context)
Expand Down
6 changes: 2 additions & 4 deletions superset/migrations/shared/migrate_viz/processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,14 @@ def _pre_action(self) -> None:
if self.data.get("contribution"):
self.data["contributionMode"] = "row"

stacked = self.data.get("stacked_style")
if stacked:
if stacked := self.data.get("stacked_style"):
stacked_map = {
"expand": "Expand",
"stack": "Stack",
}
self.data["show_extra_controls"] = True
self.data["stack"] = stacked_map.get(stacked)

x_axis_label = self.data.get("x_axis_label")
if x_axis_label:
if x_axis_label := self.data.get("x_axis_label"):
self.data["x_axis_title"] = x_axis_label
self.data["x_axis_title_margin"] = 30
Loading