Skip to content

Commit

Permalink
Type Mismatching while Serializing Dataclass with Union (flyteorg#2859)
Browse files Browse the repository at this point in the history
Signed-off-by: mao3267 <chenvincent610@gmail.com>
Signed-off-by: 400Ping <43886578+400Ping@users.noreply.github.com>
  • Loading branch information
mao3267 authored and 400Ping committed Nov 22, 2024
1 parent b529f8c commit 00d17a9
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 3 deletions.
27 changes: 24 additions & 3 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -783,12 +783,33 @@ def t1() -> DC:
"""
from flytekit.types.directory import FlyteDirectory
from flytekit.types.file import FlyteFile
from flytekit.types.structured import StructuredDataset

# Handle Optional
if UnionTransformer.is_optional_type(python_type):
if python_val is None:
return None
return self._make_dataclass_serializable(python_val, get_args(python_type)[0])

def get_expected_type(python_val: T, types: tuple) -> Type[T | None]:
if len(set(types) & {FlyteFile, FlyteDirectory, StructuredDataset}) > 1:
raise ValueError(
"Cannot have more than one Flyte type in the Union when attempting to use the string shortcut. Please specify the full object (e.g. FlyteFile(...)) instead of just passing a string."
)

for t in types:
try:
trans = TypeEngine.get_transformer(t) # type: ignore
if trans:
trans.assert_type(t, python_val)
return t
except Exception:
continue
return type(None)

# Get the expected type in the Union type
expected_type = type(None)
if python_val is not None:
expected_type = get_expected_type(python_val, get_args(python_type)) # type: ignore

return self._make_dataclass_serializable(python_val, expected_type)

if hasattr(python_type, "__origin__") and get_origin(python_type) is list:
if python_val is None:
Expand Down
14 changes: 14 additions & 0 deletions tests/flytekit/unit/core/test_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -1118,3 +1118,17 @@ def empty_nested_dc_wf() -> NestedFlyteTypes:

empty_nested_flyte_types = empty_nested_dc_wf()
DataclassTransformer().assert_type(NestedFlyteTypes, empty_nested_flyte_types)

def test_dataclass_serialize_with_multiple_dataclass_union():
@dataclass
class A():
x: int

@dataclass
class B():
x: FlyteFile

b = B(x="s3://my-bucket/my-file")
res = DataclassTransformer()._make_dataclass_serializable(b, Union[None, A, B])

assert res.x.path == "s3://my-bucket/my-file"
17 changes: 17 additions & 0 deletions tests/flytekit/unit/core/test_flytetypes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from dataclasses import dataclass
from flytekit.types.file import FlyteFile
from flytekit.types.structured.structured_dataset import StructuredDataset
from flytekit.core.type_engine import DataclassTransformer
from typing import Union
import pytest
import re

def test_dataclass_union_with_multiple_flytetypes_error():
@dataclass
class DC():
x: Union[None, StructuredDataset, FlyteFile]


dc = DC(x="s3://my-bucket/my-file")
with pytest.raises(ValueError, match=re.escape("Cannot have more than one Flyte type in the Union when attempting to use the string shortcut. Please specify the full object (e.g. FlyteFile(...)) instead of just passing a string.")):
DataclassTransformer()._make_dataclass_serializable(dc, DC)
4 changes: 4 additions & 0 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -957,6 +957,7 @@ class TestFileStruct(DataClassJsonMixin):
b: typing.Optional[FlyteFile]
b_prime: typing.Optional[FlyteFile]
c: typing.Union[FlyteFile, None]
c_prime: typing.Union[None, FlyteFile]
d: typing.List[FlyteFile]
e: typing.List[typing.Optional[FlyteFile]]
e_prime: typing.List[typing.Optional[FlyteFile]]
Expand All @@ -979,6 +980,7 @@ class TestFileStruct(DataClassJsonMixin):
b=f1,
b_prime=None,
c=f1,
c_prime=f1,
d=[f1],
e=[f1],
e_prime=[None],
Expand All @@ -1001,6 +1003,7 @@ class TestFileStruct(DataClassJsonMixin):
assert dict_obj["b"]["path"] == remote_path
assert dict_obj["b_prime"] is None
assert dict_obj["c"]["path"] == remote_path
assert dict_obj["c_prime"]["path"] == remote_path
assert dict_obj["d"][0]["path"] == remote_path
assert dict_obj["e"][0]["path"] == remote_path
assert dict_obj["e_prime"][0] is None
Expand All @@ -1018,6 +1021,7 @@ class TestFileStruct(DataClassJsonMixin):
assert o.b.remote_path == ot.b.remote_source
assert ot.b_prime is None
assert o.c.remote_path == ot.c.remote_source
assert o.c_prime.remote_path == ot.c_prime.remote_source
assert o.d[0].remote_path == ot.d[0].remote_source
assert o.e[0].remote_path == ot.e[0].remote_source
assert o.e_prime == [None]
Expand Down

0 comments on commit 00d17a9

Please sign in to comment.