Skip to content

Commit

Permalink
Fix mypy errors caught in 1.11.2 (flyteorg#2808)
Browse files Browse the repository at this point in the history
Signed-off-by: Eduardo Apolinario <eapolinario@users.noreply.github.com>
  • Loading branch information
eapolinario authored and otarabai committed Oct 15, 2024
1 parent 391e3da commit 8e82d6c
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 21 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.4.7
rev: v0.6.9
hooks:
# Run the linter.
- id: ruff
args: [--fix, --show-fixes, --output-format=full]
# Run the formatter.
- id: ruff-format
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.2.0
rev: v5.0.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
Expand Down
4 changes: 2 additions & 2 deletions flytekit/core/array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from flytekit.core.interface import transform_interface_to_list_interface
from flytekit.core.launch_plan import LaunchPlan
from flytekit.core.python_function_task import PythonFunctionTask, PythonInstanceTask
from flytekit.core.type_engine import TypeEngine, is_annotated
from flytekit.core.type_engine import TypeEngine, TypeTransformer, is_annotated
from flytekit.core.utils import timeit
from flytekit.loggers import logger
from flytekit.models import literals as _literal_models
Expand Down Expand Up @@ -80,7 +80,7 @@ def __init__(
for k, v in actual_task.python_interface.inputs.items():
if bound_inputs and k in bound_inputs:
continue
transformer = TypeEngine.get_transformer(v)
transformer: TypeTransformer = TypeEngine.get_transformer(v)
if isinstance(transformer, FlytePickleTransformer):
if is_annotated(v):
for annotation in typing_extensions.get_args(v)[1:]:
Expand Down
2 changes: 1 addition & 1 deletion flytekit/core/python_auto_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def get_container(self, settings: SerializationSettings) -> _task_model.Containe
return self._get_container(settings)

def _get_container(self, settings: SerializationSettings) -> _task_model.Container:
env = {}
env: Dict[str, str] = {}
for elem in (settings.env, self.environment):
if elem:
env.update(elem)
Expand Down
38 changes: 25 additions & 13 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from abc import ABC, abstractmethod
from collections import OrderedDict
from functools import lru_cache
from types import GenericAlias
from typing import Any, Dict, List, NamedTuple, Optional, Type, cast

import msgpack
Expand Down Expand Up @@ -434,7 +435,7 @@ class Test(DataClassJsonMixin):
"""

def __init__(self):
def __init__(self) -> None:
super().__init__("Object-Dataclass-Transformer", object)
self._decoder: Dict[Type, JSONDecoder] = dict()

Expand Down Expand Up @@ -922,7 +923,7 @@ def guess_python_type(self, literal_type: LiteralType) -> Type[enum.Enum]:


def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name: typing.Any):
attribute_list = []
attribute_list: typing.List[tuple[Any, GenericAlias]] = []
for property_key, property_val in schema["properties"].items():
if property_val.get("anyOf"):
property_type = property_val["anyOf"][0]["type"]
Expand All @@ -939,7 +940,12 @@ def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name:
sub_schemea = property_val["anyOf"][0]
sub_schemea_name = sub_schemea["title"]
attribute_list.append(
(property_key, convert_mashumaro_json_schema_to_python_class(sub_schemea, sub_schemea_name))
(
property_key,
typing.cast(
GenericAlias, convert_mashumaro_json_schema_to_python_class(sub_schemea, sub_schemea_name)
),
)
)
elif property_val.get("additionalProperties"):
attribute_list.append(
Expand All @@ -948,7 +954,12 @@ def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name:
else:
sub_schemea_name = property_val["title"]
attribute_list.append(
(property_key, convert_mashumaro_json_schema_to_python_class(property_val, sub_schemea_name))
(
property_key,
typing.cast(
GenericAlias, convert_mashumaro_json_schema_to_python_class(property_val, sub_schemea_name)
),
)
)
elif property_type == "enum":
attribute_list.append([property_key, str]) # type: ignore
Expand Down Expand Up @@ -2153,7 +2164,7 @@ def to_python_value(


def generate_attribute_list_from_dataclass_json(schema: dict, schema_name: typing.Any):
attribute_list = []
attribute_list: typing.List[tuple[Any, GenericAlias]] = []
for property_key, property_val in schema[schema_name]["properties"].items():
property_type = property_val["type"]
# Handle list
Expand All @@ -2163,10 +2174,15 @@ def generate_attribute_list_from_dataclass_json(schema: dict, schema_name: typin
elif property_type == "object":
if property_val.get("$ref"):
name = property_val["$ref"].split("/")[-1]
attribute_list.append((property_key, convert_marshmallow_json_schema_to_python_class(schema, name)))
attribute_list.append(
(
property_key,
typing.cast(GenericAlias, convert_marshmallow_json_schema_to_python_class(schema, name)),
)
)
elif property_val.get("additionalProperties"):
attribute_list.append(
(property_key, Dict[str, _get_element_type(property_val["additionalProperties"])]) # type: ignore[misc,index]
(property_key, typing.cast(GenericAlias, _get_element_type(property_val["additionalProperties"]))),
)
else:
attribute_list.append((property_key, Dict[str, _get_element_type(property_val)])) # type: ignore[misc,index]
Expand All @@ -2176,9 +2192,7 @@ def generate_attribute_list_from_dataclass_json(schema: dict, schema_name: typin
return attribute_list


def convert_marshmallow_json_schema_to_python_class(
schema: dict, schema_name: typing.Any
) -> Type[dataclasses.dataclass()]: # type: ignore
def convert_marshmallow_json_schema_to_python_class(schema: dict, schema_name: typing.Any) -> type:
"""
Generate a model class based on the provided JSON Schema
:param schema: dict representing valid JSON schema
Expand All @@ -2189,9 +2203,7 @@ def convert_marshmallow_json_schema_to_python_class(
return dataclass_json(dataclasses.make_dataclass(schema_name, attribute_list))


def convert_mashumaro_json_schema_to_python_class(
schema: dict, schema_name: typing.Any
) -> Type[dataclasses.dataclass()]: # type: ignore
def convert_mashumaro_json_schema_to_python_class(schema: dict, schema_name: typing.Any) -> type:
"""
Generate a model class based on the provided JSON Schema
:param schema: dict representing valid JSON schema
Expand Down
6 changes: 3 additions & 3 deletions flytekit/types/structured/structured_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import typing
from abc import ABC, abstractmethod
from dataclasses import dataclass, field, is_dataclass
from typing import Dict, Generator, List, Optional, Type, Union
from typing import Dict, Generator, Generic, List, Optional, Type, Union

import msgpack
from dataclasses_json import config
Expand Down Expand Up @@ -223,7 +223,7 @@ def extract_cols_and_format(
return t, ordered_dict_cols, fmt, pa_schema


class StructuredDatasetEncoder(ABC):
class StructuredDatasetEncoder(ABC, Generic[T]):
def __init__(
self,
python_type: Type[T],
Expand Down Expand Up @@ -290,7 +290,7 @@ def encode(
raise NotImplementedError


class StructuredDatasetDecoder(ABC):
class StructuredDatasetDecoder(ABC, Generic[DF]):
def __init__(
self,
python_type: Type[DF],
Expand Down

0 comments on commit 8e82d6c

Please sign in to comment.