Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backport 2845 v1.13 #2851

Merged
merged 1 commit into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 28 additions & 8 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,10 @@ def assert_type(self, expected_type: Type[DataClassJsonMixin], v: T):
else:
for f in dataclasses.fields(type(v)): # type: ignore
original_type = f.type
if f.name not in expected_fields_dict:
raise TypeTransformerFailedError(
f"Field '{f.name}' is not present in the expected dataclass fields {expected_type.__name__}"
)
expected_type = expected_fields_dict[f.name]

if UnionTransformer.is_optional_type(original_type):
Expand Down Expand Up @@ -796,7 +800,7 @@ def to_literal(
if type(python_val).__class__ != enum.EnumMeta:
raise TypeTransformerFailedError("Expected an enum")
if type(python_val.value) != str:
raise TypeTransformerFailedError("Only string-valued enums are supportedd")
raise TypeTransformerFailedError("Only string-valued enums are supported")

return Literal(scalar=Scalar(primitive=Primitive(string_value=python_val.value))) # type: ignore

Expand All @@ -808,6 +812,18 @@ def guess_python_type(self, literal_type: LiteralType) -> Type[enum.Enum]:
return enum.Enum("DynamicEnum", {f"{i}": i for i in literal_type.enum_type.values}) # type: ignore
raise ValueError(f"Enum transformer cannot reverse {literal_type}")

def assert_type(self, t: Type[enum.Enum], v: T):
if sys.version_info < (3, 10):
if not isinstance(v, enum.Enum):
raise TypeTransformerFailedError(f"Value {v} needs to be an Enum in 3.9")
if not isinstance(v, t):
raise TypeTransformerFailedError(f"Value {v} is not in Enum {t}")
return

val = v.value if isinstance(v, enum.Enum) else v
if val not in t:
raise TypeTransformerFailedError(f"Value {v} is not in Enum {t}")


def generate_attribute_list_from_dataclass_json_mixin(schema: dict, schema_name: typing.Any):
attribute_list = []
Expand Down Expand Up @@ -1193,7 +1209,7 @@ def literal_map_to_kwargs(
raise ValueError("At least one of python_types or literal_types must be provided")

if literal_types:
python_interface_inputs = {
python_interface_inputs: dict[str, Type[T]] = {
name: TypeEngine.guess_python_type(lt.type) for name, lt in literal_types.items()
}
else:
Expand Down Expand Up @@ -1272,7 +1288,7 @@ def guess_python_types(
return python_types

@classmethod
def guess_python_type(cls, flyte_type: LiteralType) -> type:
def guess_python_type(cls, flyte_type: LiteralType) -> Type[T]:
"""
Transforms a flyte-specific ``LiteralType`` to a regular python value.
"""
Expand Down Expand Up @@ -1542,13 +1558,17 @@ def assert_type(self, t: Type[T], v: T):
# this is an edge case
return
try:
super().assert_type(sub_type, v)
return
sub_trans: TypeTransformer = TypeEngine.get_transformer(sub_type)
if sub_trans.type_assertions_enabled:
sub_trans.assert_type(sub_type, v)
return
else:
return
except TypeTransformerFailedError:
continue
except TypeError:
continue
raise TypeTransformerFailedError(f"Value {v} is not of type {t}")
else:
super().assert_type(t, v)

def get_literal_type(self, t: Type[T]) -> Optional[LiteralType]:
t = get_underlying_type(t)
Expand Down Expand Up @@ -1806,7 +1826,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:

def guess_python_type(self, literal_type: LiteralType) -> Union[Type[dict], typing.Dict[Type, Type]]:
if literal_type.map_value_type:
mt = TypeEngine.guess_python_type(literal_type.map_value_type)
mt: Type = TypeEngine.guess_python_type(literal_type.map_value_type)
return typing.Dict[str, mt] # type: ignore

if literal_type.simple == SimpleType.STRUCT:
Expand Down
93 changes: 93 additions & 0 deletions tests/flytekit/unit/core/test_unions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import typing
from dataclasses import dataclass
from enum import Enum
import sys
import pytest

from flytekit.core.context_manager import FlyteContextManager
from flytekit.core.type_engine import TypeEngine, TypeTransformerFailedError


def test_asserting():
@dataclass
class A:
a: str = None

@dataclass
class B:
b: str = None

@dataclass
class C:
c: str = None

ctx = FlyteContextManager.current_context()

pt = typing.Union[A, B, str]
lt = TypeEngine.to_literal_type(pt)
# mimic a register/remote fetch
guessed = TypeEngine.guess_python_type(lt)

TypeEngine.to_literal(ctx, A("a"), guessed, lt)
TypeEngine.to_literal(ctx, B(b="bb"), guessed, lt)
TypeEngine.to_literal(ctx, "hello", guessed, lt)

with pytest.raises(TypeTransformerFailedError):
TypeEngine.to_literal(ctx, C("cc"), guessed, lt)

with pytest.raises(TypeTransformerFailedError):
TypeEngine.to_literal(ctx, 3, guessed, lt)


@pytest.mark.skipif(
sys.version_info < (3, 10), reason="enum checking only works in 3.10+"
)
def test_asserting_enum():
class Color(Enum):
RED = "one"
GREEN = "two"
BLUE = "blue"

lt = TypeEngine.to_literal_type(Color)
guessed = TypeEngine.guess_python_type(lt)
tf = TypeEngine.get_transformer(guessed)
tf.assert_type(guessed, "one")
tf.assert_type(guessed, guessed("two"))
tf.assert_type(Color, "one")

guessed2 = TypeEngine.guess_python_type(lt)
tf.assert_type(guessed, guessed2("two"))


@pytest.mark.skipif(
sys.version_info >= (3, 10), reason="3.9 enum testing"
)
def test_asserting_enum_39():
class Color(Enum):
RED = "one"
GREEN = "two"
BLUE = "blue"

lt = TypeEngine.to_literal_type(Color)
guessed = TypeEngine.guess_python_type(lt)
tf = TypeEngine.get_transformer(guessed)
tf.assert_type(guessed, guessed("two"))
tf.assert_type(Color, Color.GREEN)


@pytest.mark.sandbox_test
def test_with_remote():
from flytekit.remote.remote import FlyteRemote
from typing_extensions import Annotated, get_args
from flytekit.configuration import Config, Image, ImageConfig, SerializationSettings

r = FlyteRemote(
Config.auto(config_file="/Users/ytong/.flyte/config-sandbox.yaml"),
default_project="flytesnacks",
default_domain="development",
)
lp = r.fetch_launch_plan(name="yt_dbg.scratchpad.union_enums.wf", version="oppOd5jst-LWExhTLM0F2w")
guessed_union_type = TypeEngine.guess_python_type(lp.interface.inputs["x"].type)
guessed_enum = get_args(guessed_union_type)[0]
val = guessed_enum("one")
r.execute(lp, inputs={"x": val})