Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: rename get_iterable #24994

Merged
merged 7 commits into from
Aug 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions superset/connectors/base/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ def data_for_slices( # pylint: disable=too-many-locals
form_data = slc.form_data
# pull out all required metrics from the form_data
for metric_param in METRIC_FORM_DATA_PARAMS:
for metric in utils.get_iterable(form_data.get(metric_param) or []):
for metric in utils.as_list(form_data.get(metric_param) or []):
metric_names.add(utils.get_metric_name(metric))
if utils.is_adhoc_metric(metric):
column = metric.get("column") or {}
Expand Down Expand Up @@ -377,7 +377,7 @@ def data_for_slices( # pylint: disable=too-many-locals
if utils.is_adhoc_column(column)
else column
for column_param in COLUMN_FORM_DATA_PARAMS
for column in utils.get_iterable(form_data.get(column_param) or [])
for column in utils.as_list(form_data.get(column_param) or [])
]
column_names.update(_columns)

Expand Down
10 changes: 5 additions & 5 deletions superset/daos/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.
from __future__ import annotations

from typing import Any, Generic, get_args, TypeVar
from typing import Any, cast, Generic, get_args, TypeVar

from flask_appbuilder.models.filters import BaseFilter
from flask_appbuilder.models.sqla import Model
Expand All @@ -30,7 +30,7 @@
DAOUpdateFailedError,
)
from superset.extensions import db
from superset.utils.core import get_iterable
from superset.utils.core import as_list

T = TypeVar("T", bound=Model)

Expand Down Expand Up @@ -197,7 +197,7 @@ def update(
return item # type: ignore

@classmethod
def delete(cls, items: T | list[T], commit: bool = True) -> None:
def delete(cls, item_or_items: T | list[T], commit: bool = True) -> None:
"""
Delete the specified item(s) including their associated relationships.

Expand All @@ -214,9 +214,9 @@ def delete(cls, items: T | list[T], commit: bool = True) -> None:
:raises DAODeleteFailedError: If the deletion failed
:see: https://docs.sqlalchemy.org/en/latest/orm/queryguide/dml.html
"""

items = cast(list[T], as_list(item_or_items))
try:
for item in get_iterable(items):
for item in items:
db.session.delete(item)

if commit:
Expand Down
9 changes: 6 additions & 3 deletions superset/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1578,12 +1578,15 @@ def split(
yield string[i:]


def get_iterable(x: Any) -> list[Any]:
T = TypeVar("T")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!



def as_list(x: T | list[T]) -> list[T]:
"""
Get an iterable (list) representation of the object.
Wrap an object in a list if it's not a list.

:param x: The object
:returns: An iterable representation
:returns: A list wrapping the object if it's not already a list
"""
return x if isinstance(x, list) else [x]

Expand Down
10 changes: 5 additions & 5 deletions tests/integration_tests/utils_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
format_timedelta,
GenericDataType,
get_form_data_token,
get_iterable,
as_list,
get_email_address_list,
get_stacktrace,
json_int_dttm_ser,
Expand Down Expand Up @@ -749,10 +749,10 @@ def test_get_or_create_db_existing_invalid_uri(self):
database = get_or_create_db("test_db", "sqlite:///superset.db")
assert database.sqlalchemy_uri == "sqlite:///superset.db"

def test_get_iterable(self):
self.assertListEqual(get_iterable(123), [123])
self.assertListEqual(get_iterable([123]), [123])
self.assertListEqual(get_iterable("foo"), ["foo"])
def test_as_list(self):
self.assertListEqual(as_list(123), [123])
self.assertListEqual(as_list([123]), [123])
self.assertListEqual(as_list("foo"), ["foo"])

@pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")
def test_build_extra_filters(self):
Expand Down