Skip to content

Commit

Permalink
handle HashMethod case
Browse files Browse the repository at this point in the history
Signed-off-by: Yicheng-Lu-llll <luyc58576@gmail.com>
  • Loading branch information
Yicheng-Lu-llll committed Mar 13, 2023
1 parent 66f36e8 commit 01aba67
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 25 deletions.
7 changes: 5 additions & 2 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -988,8 +988,11 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp
if ListTransformer.is_batchable(python_type):
batchSize = len(python_val) # default batch size
# parse annotated to get the number of items saved in a pickle file.
if get_origin(python_type) is Annotated and type(get_args(python_type)[1]) == int:
batchSize = get_args(python_type)[1]
if get_origin(python_type) is Annotated:
for annotation in get_args(python_type)[1:]:
if isinstance(annotation, int):
batchSize = annotation
break
lit_list = [TypeEngine.to_literal(ctx, python_val[i : i + batchSize], FlytePickle, expected.collection_type) for i in range(0, len(python_val), batchSize)] # type: ignore
else:
t = self.get_sub_type(python_type)
Expand Down
41 changes: 18 additions & 23 deletions tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from marshmallow_enum import LoadDumpOptions
from marshmallow_jsonschema import JSONSchema
from pandas._testing import assert_frame_equal
from typing_extensions import Annotated, get_args, get_origin
from typing_extensions import Annotated

from flytekit import kwtypes
from flytekit.core.annotation import FlyteAnnotation
Expand Down Expand Up @@ -1576,27 +1576,22 @@ def test_file_ext_with_flyte_file_wrong_type():
assert str(e.value) == "Underlying type of File Extension must be of type <str>"


def test_batch_pickle_list():
@pytest.mark.parametrize(
"python_val, python_type, batch_size",
[
([{"a": {0: "foo"}}] * 5, typing.List[typing.Dict[str, FlytePickle]], 5),
([{"a": {0: "foo"}}] * 5, Annotated[typing.List[typing.Dict[str, FlytePickle]], 2], 2),
([{"a": {0: "foo"}}] * 6, Annotated[typing.List[typing.Dict[str, FlytePickle]], HashMethod(function=str), 2], 2),
],
)
def test_batch_pickle_list(python_val, python_type, batch_size):
from math import ceil

python_val = [{"a": {0: "foo"}}] * 5
python_type_list = [
typing.List[typing.Dict[str, FlytePickle]],
Annotated[typing.List[typing.Dict[str, FlytePickle]], 2],
]

for python_type in python_type_list:
batch_size = len(python_val)
if get_origin(python_type) is Annotated:
batch_size = get_args(python_type)[1]

ctx = FlyteContext.current_context()
expected = TypeEngine.to_literal_type(python_type)

lv = TypeEngine.to_literal(ctx, python_val, python_type, expected)
# For example, if the batch size is 2 and the length of the list is 5, the list should be split into ceil(5/2) = 3 chunks.
# By default, the batch_size is set to the length of the whole list.
assert len(lv.collection.literals) == ceil(len(python_val) / batch_size)

pv = TypeEngine.to_python_value(ctx, lv, python_type)
assert pv == python_val
ctx = FlyteContext.current_context()
expected = TypeEngine.to_literal_type(python_type)
lv = TypeEngine.to_literal(ctx, python_val, python_type, expected)
# For example, if the batch size is 2 and the length of the list is 5, the list should be split into ceil(5/2) = 3 chunks.
# By default, the batch_size is set to the length of the whole list.
assert len(lv.collection.literals) == ceil(len(python_val) / batch_size)
pv = TypeEngine.to_python_value(ctx, lv, python_type)
assert pv == python_val

0 comments on commit 01aba67

Please sign in to comment.