Skip to content

Commit

Permalink
[mypy] Enforcing typing for charts (#9411)
Browse files Browse the repository at this point in the history
Co-authored-by: John Bodley <john.bodley@airbnb.com>
  • Loading branch information
john-bodley and John Bodley authored Mar 29, 2020
1 parent 2e81e27 commit ec795a4
Show file tree
Hide file tree
Showing 14 changed files with 41 additions and 29 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ order_by_type = false
ignore_missing_imports = true
no_implicit_optional = true

[mypy-superset.db_engine_specs.*]
[mypy-superset.charts.*,superset.db_engine_specs.*]
check_untyped_defs = true
disallow_untyped_calls = true
disallow_untyped_defs = true
5 changes: 4 additions & 1 deletion superset/charts/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any

from flask import g, request, Response
from flask_appbuilder.api import expose, protect, rison, safe
Expand Down Expand Up @@ -287,7 +288,9 @@ def delete(self, pk: int) -> Response: # pylint: disable=arguments-differ
@protect()
@safe
@rison(get_delete_ids_schema)
def bulk_delete(self, **kwargs) -> Response: # pylint: disable=arguments-differ
def bulk_delete(
self, **kwargs: Any
) -> Response: # pylint: disable=arguments-differ
"""Delete bulk Charts
---
delete:
Expand Down
2 changes: 1 addition & 1 deletion superset/charts/commands/bulk_delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(self, user: User, model_ids: List[int]):
self._model_ids = model_ids
self._models: Optional[List[Slice]] = None

def run(self):
def run(self) -> None:
self.validate()
try:
ChartDAO.bulk_delete(self._models)
Expand Down
3 changes: 2 additions & 1 deletion superset/charts/commands/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import logging
from typing import Dict, List, Optional

from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User
from marshmallow import ValidationError

Expand All @@ -39,7 +40,7 @@ def __init__(self, user: User, data: Dict):
self._actor = user
self._properties = data.copy()

def run(self):
def run(self) -> Model:
self.validate()
try:
chart = ChartDAO.create(self._properties)
Expand Down
3 changes: 2 additions & 1 deletion superset/charts/commands/delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import logging
from typing import Optional

from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User

from superset.charts.commands.exceptions import (
Expand All @@ -40,7 +41,7 @@ def __init__(self, user: User, model_id: int):
self._model_id = model_id
self._model: Optional[SqlaTable] = None

def run(self):
def run(self) -> Model:
self.validate()
try:
chart = ChartDAO.delete(self._model)
Expand Down
6 changes: 3 additions & 3 deletions superset/charts/commands/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ class DatabaseNotFoundValidationError(ValidationError):
Marshmallow validation error for database does not exist
"""

def __init__(self):
def __init__(self) -> None:
super().__init__(_("Database does not exist"), field_names=["database"])


Expand All @@ -41,7 +41,7 @@ class DashboardsNotFoundValidationError(ValidationError):
Marshmallow validation error for dashboards don't exist
"""

def __init__(self):
def __init__(self) -> None:
super().__init__(_("Dashboards do not exist"), field_names=["dashboards"])


Expand All @@ -50,7 +50,7 @@ class DatasourceTypeUpdateRequiredValidationError(ValidationError):
Marshmallow validation error for dashboards don't exist
"""

def __init__(self):
def __init__(self) -> None:
super().__init__(
_("Datasource type is required when datasource_id is given"),
field_names=["datasource_type"],
Expand Down
3 changes: 2 additions & 1 deletion superset/charts/commands/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import logging
from typing import Dict, List, Optional

from flask_appbuilder.models.sqla import Model
from flask_appbuilder.security.sqla.models import User
from marshmallow import ValidationError

Expand Down Expand Up @@ -47,7 +48,7 @@ def __init__(self, user: User, model_id: int, data: Dict):
self._properties = data.copy()
self._model: Optional[SqlaTable] = None

def run(self):
def run(self) -> Model:
self.validate()
try:
chart = ChartDAO.update(self._model, self._properties)
Expand Down
15 changes: 8 additions & 7 deletions superset/charts/dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import List
from typing import List, Optional

from sqlalchemy.exc import SQLAlchemyError

Expand All @@ -32,13 +32,14 @@ class ChartDAO(BaseDAO):
base_filter = ChartFilter

@staticmethod
def bulk_delete(models: List[Slice], commit=True):
item_ids = [model.id for model in models]
def bulk_delete(models: Optional[List[Slice]], commit: bool = True) -> None:
item_ids = [model.id for model in models] if models else []
# bulk delete, first delete related data
for model in models:
model.owners = []
model.dashboards = []
db.session.merge(model)
if models:
for model in models:
model.owners = []
model.dashboards = []
db.session.merge(model)
# bulk delete itself
try:
db.session.query(Slice).filter(Slice.id.in_(item_ids)).delete(
Expand Down
5 changes: 4 additions & 1 deletion superset/charts/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,17 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any

from sqlalchemy import or_
from sqlalchemy.orm.query import Query

from superset import security_manager
from superset.views.base import BaseFilter


class ChartFilter(BaseFilter): # pylint: disable=too-few-public-methods
def apply(self, query, value):
def apply(self, query: Query, value: Any) -> Query:
if security_manager.all_datasource_access():
return query
perms = security_manager.user_view_menu_names("datasource_access")
Expand Down
3 changes: 2 additions & 1 deletion superset/charts/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Union

from marshmallow import fields, Schema, ValidationError
from marshmallow.validate import Length
Expand All @@ -24,7 +25,7 @@
get_delete_ids_schema = {"type": "array", "items": {"type": "integer"}}


def validate_json(value):
def validate_json(value: Union[bytes, bytearray, str]) -> None:
try:
utils.validate_json(value)
except SupersetException:
Expand Down
10 changes: 5 additions & 5 deletions superset/commands/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import List
from typing import Any, Dict, List

from flask_babel import lazy_gettext as _
from marshmallow import ValidationError
Expand All @@ -36,8 +36,8 @@ class CommandInvalidError(CommandException):

status = 422

def __init__(self, message=""):
self._invalid_exceptions = list()
def __init__(self, message="") -> None:
self._invalid_exceptions: List[ValidationError] = []
super().__init__(self.message)

def add(self, exception: ValidationError):
Expand All @@ -46,8 +46,8 @@ def add(self, exception: ValidationError):
def add_list(self, exceptions: List[ValidationError]):
self._invalid_exceptions.extend(exceptions)

def normalized_messages(self):
errors = {}
def normalized_messages(self) -> Dict[Any, Any]:
errors: Dict[Any, Any] = {}
for exception in self._invalid_exceptions:
errors.update(exception.normalized_messages())
return errors
Expand Down
4 changes: 2 additions & 2 deletions superset/dao/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def find_by_ids(cls, model_ids: List[int]) -> List[Model]:
return query.all()

@classmethod
def create(cls, properties: Dict, commit=True) -> Optional[Model]:
def create(cls, properties: Dict, commit: bool = True) -> Model:
"""
Generic for creating models
:raises: DAOCreateFailedError
Expand All @@ -95,7 +95,7 @@ def create(cls, properties: Dict, commit=True) -> Optional[Model]:
return model

@classmethod
def update(cls, model: Model, properties: Dict, commit=True) -> Optional[Model]:
def update(cls, model: Model, properties: Dict, commit: bool = True) -> Model:
"""
Generic update a model
:raises: DAOCreateFailedError
Expand Down
2 changes: 1 addition & 1 deletion superset/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ def get_datasource_full_name(database_name, datasource_name, schema=None):
return "[{}].[{}].[{}]".format(database_name, schema, datasource_name)


def validate_json(obj):
def validate_json(obj: Union[bytes, bytearray, str]) -> None:
if obj:
try:
json.loads(obj)
Expand Down
7 changes: 4 additions & 3 deletions superset/views/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import logging
import traceback
from datetime import datetime
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional

import simplejson as json
import yaml
Expand All @@ -27,6 +27,7 @@
from flask_appbuilder.actions import action
from flask_appbuilder.forms import DynamicForm
from flask_appbuilder.models.sqla.filters import BaseFilter
from flask_appbuilder.security.sqla.models import User
from flask_appbuilder.widgets import ListWidget
from flask_babel import get_locale, gettext as __, lazy_gettext as _
from flask_wtf.form import FlaskForm
Expand Down Expand Up @@ -365,7 +366,7 @@ class CsvResponse(Response): # pylint: disable=too-many-ancestors
charset = conf["CSV_EXPORT"].get("encoding", "utf-8")


def check_ownership(obj, raise_if_false=True):
def check_ownership(obj: Any, raise_if_false: bool = True) -> bool:
"""Meant to be used in `pre_update` hooks on models to enforce ownership
Admin have all access, and other users need to be referenced on either
Expand All @@ -392,7 +393,7 @@ def check_ownership(obj, raise_if_false=True):
orig_obj = scoped_session.query(obj.__class__).filter_by(id=obj.id).first()

# Making a list of owners that works across ORM models
owners = []
owners: List[User] = []
if hasattr(orig_obj, "owners"):
owners += orig_obj.owners
if hasattr(orig_obj, "owner"):
Expand Down

0 comments on commit ec795a4

Please sign in to comment.