Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Type Mismatching while Serializing Dataclass with Union #2859

Merged
merged 8 commits into from
Nov 6, 2024
Merged
27 changes: 24 additions & 3 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,12 +676,33 @@ def _make_dataclass_serializable(self, python_val: T, python_type: Type[T]) -> t
"""
from flytekit.types.directory import FlyteDirectory
from flytekit.types.file import FlyteFile
from flytekit.types.structured import StructuredDataset

# Handle Optional
if UnionTransformer.is_optional_type(python_type):
if python_val is None:
return None
return self._make_dataclass_serializable(python_val, get_args(python_type)[0])

def get_expected_type(python_val: T, types: tuple) -> Type[T | None]:
    """Pick the member of ``types`` that ``python_val`` should be treated as.

    The union members are tried in order; the first one whose transformer's
    ``assert_type`` accepts ``python_val`` is returned. If no member
    accepts the value, ``type(None)`` is returned.

    Raises:
        ValueError: if more than one of FlyteFile, FlyteDirectory and
            StructuredDataset appears in ``types``.
    """
    flyte_shortcut_types = {FlyteFile, FlyteDirectory, StructuredDataset}
    if len(flyte_shortcut_types.intersection(types)) > 1:
        raise ValueError(
            "Cannot have more than one Flyte type in the Union when attempting to use the string shortcut. Please specify the full object (e.g. FlyteFile(...)) instead of just passing a string."
        )

    for candidate in types:
        try:
            transformer = TypeEngine.get_transformer(candidate)  # type: ignore
            if not transformer:
                continue
            transformer.assert_type(candidate, python_val)
        except Exception:
            # Any failure for this member just means "not this one"; move on to the next.
            continue
        else:
            return candidate
    return type(None)

# Get the expected type in the Union type
expected_type = type(None)
if python_val is not None:
expected_type = get_expected_type(python_val, get_args(python_type)) # type: ignore

return self._make_dataclass_serializable(python_val, expected_type)

if hasattr(python_type, "__origin__") and get_origin(python_type) is list:
if python_val is None:
Expand Down
14 changes: 14 additions & 0 deletions tests/flytekit/unit/core/test_dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -1118,3 +1118,17 @@ def empty_nested_dc_wf() -> NestedFlyteTypes:

empty_nested_flyte_types = empty_nested_dc_wf()
DataclassTransformer().assert_type(NestedFlyteTypes, empty_nested_flyte_types)

def test_dataclass_serialize_with_multiple_dataclass_union():
    """Serialize a ``B`` instance against ``Union[None, A, B]`` and check its file path."""

    @dataclass
    class A:
        x: int

    @dataclass
    class B:
        x: FlyteFile

    remote_uri = "s3://my-bucket/my-file"
    union_type = Union[None, A, B]

    serialized = DataclassTransformer()._make_dataclass_serializable(B(x=remote_uri), union_type)

    assert serialized.x.path == remote_uri
17 changes: 17 additions & 0 deletions tests/flytekit/unit/core/test_flytetypes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from dataclasses import dataclass
from flytekit.types.file import FlyteFile
from flytekit.types.structured.structured_dataset import StructuredDataset
from flytekit.core.type_engine import DataclassTransformer
from typing import Union
import pytest
import re

def test_dataclass_union_with_multiple_flytetypes_error():
    """A string value for a Union listing two Flyte types must raise ValueError."""

    @dataclass
    class DC:
        x: Union[None, StructuredDataset, FlyteFile]

    expected_message = re.escape(
        "Cannot have more than one Flyte type in the Union when attempting to use the string shortcut. Please specify the full object (e.g. FlyteFile(...)) instead of just passing a string."
    )
    dc = DC(x="s3://my-bucket/my-file")

    with pytest.raises(ValueError, match=expected_message):
        DataclassTransformer()._make_dataclass_serializable(dc, DC)
4 changes: 4 additions & 0 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -967,6 +967,7 @@ class TestFileStruct(DataClassJsonMixin):
b: typing.Optional[FlyteFile]
b_prime: typing.Optional[FlyteFile]
c: typing.Union[FlyteFile, None]
c_prime: typing.Union[None, FlyteFile]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you make this

Suggested change
c_prime: typing.Union[None, FlyteFile]
c_prime: typing.Union[None, StructuredDataset, int, FlyteFile]

Copy link
Contributor

@wild-endeavor wild-endeavor Oct 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, can you write one more unit test for me, please? (And add it under test_dataclass.py; this file is getting too big.)

@dataclass
class A():
  x: int

@dataclass
class B():
   x: FlyteFile

Then call _make_dataclass_serializable on Union[None, A, B], where b = B(x="s3://tmp") or something similar.

Copy link
Contributor Author

@mao3267 mao3267 Oct 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Although we didn’t expect typing.Union[None, StructuredDataset, int, FlyteFile] to work in our tests because of the ambiguity between multiple Flyte types, we anticipated that something like typing.Union[None, int, FlyteFile] might work correctly. However, it seems that when the MessagePack decoder decodes a binary into a Python value, it fails to identify the target type for Union types with more than two members, resulting in a dictionary rather than a FlyteFile. cc @Future-Outlier

d: typing.List[FlyteFile]
e: typing.List[typing.Optional[FlyteFile]]
e_prime: typing.List[typing.Optional[FlyteFile]]
Expand All @@ -989,6 +990,7 @@ class TestFileStruct(DataClassJsonMixin):
b=f1,
b_prime=None,
c=f1,
c_prime=f1,
d=[f1],
e=[f1],
e_prime=[None],
Expand All @@ -1011,6 +1013,7 @@ class TestFileStruct(DataClassJsonMixin):
assert dict_obj["b"]["path"] == remote_path
assert dict_obj["b_prime"] is None
assert dict_obj["c"]["path"] == remote_path
assert dict_obj["c_prime"]["path"] == remote_path
assert dict_obj["d"][0]["path"] == remote_path
assert dict_obj["e"][0]["path"] == remote_path
assert dict_obj["e_prime"][0] is None
Expand All @@ -1028,6 +1031,7 @@ class TestFileStruct(DataClassJsonMixin):
assert o.b.remote_path == ot.b.remote_source
assert ot.b_prime is None
assert o.c.remote_path == ot.c.remote_source
assert o.c_prime.remote_path == ot.c_prime.remote_source
assert o.d[0].remote_path == ot.d[0].remote_source
assert o.e[0].remote_path == ot.e[0].remote_source
assert o.e_prime == [None]
Expand Down
Loading