diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index d76a9c3da1..157b1a6a1d 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -988,8 +988,11 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp if ListTransformer.is_batchable(python_type): batchSize = len(python_val) # default batch size # parse annotated to get the number of items saved in a pickle file. - if get_origin(python_type) is Annotated and type(get_args(python_type)[1]) == int: - batchSize = get_args(python_type)[1] + if get_origin(python_type) is Annotated: + for annotation in get_args(python_type)[1:]: + if isinstance(annotation, int): + batchSize = annotation + break lit_list = [TypeEngine.to_literal(ctx, python_val[i : i + batchSize], FlytePickle, expected.collection_type) for i in range(0, len(python_val), batchSize)] # type: ignore else: t = self.get_sub_type(python_type) diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index d87493eb05..f04b61651d 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -18,7 +18,7 @@ from marshmallow_enum import LoadDumpOptions from marshmallow_jsonschema import JSONSchema from pandas._testing import assert_frame_equal -from typing_extensions import Annotated, get_args, get_origin +from typing_extensions import Annotated from flytekit import kwtypes from flytekit.core.annotation import FlyteAnnotation @@ -1576,27 +1576,22 @@ def test_file_ext_with_flyte_file_wrong_type(): assert str(e.value) == "Underlying type of File Extension must be of type " -def test_batch_pickle_list(): +@pytest.mark.parametrize( + "python_val, python_type, batch_size", + [ + ([{"a": {0: "foo"}}] * 5, typing.List[typing.Dict[str, FlytePickle]], 5), + ([{"a": {0: "foo"}}] * 5, Annotated[typing.List[typing.Dict[str, FlytePickle]], 2], 2), + ([{"a": {0: "foo"}}] * 6, Annotated[typing.List[typing.Dict[str, FlytePickle]], HashMethod(function=str), 2], 2), + ], +) +def test_batch_pickle_list(python_val, python_type, batch_size): from math import ceil - python_val = [{"a": {0: "foo"}}] * 5 - python_type_list = [ - typing.List[typing.Dict[str, FlytePickle]], - Annotated[typing.List[typing.Dict[str, FlytePickle]], 2], - ] - - for python_type in python_type_list: - batch_size = len(python_val) - if get_origin(python_type) is Annotated: - batch_size = get_args(python_type)[1] - - ctx = FlyteContext.current_context() - expected = TypeEngine.to_literal_type(python_type) - - lv = TypeEngine.to_literal(ctx, python_val, python_type, expected) - # For example, if the batch size is 2 and the length of the list is 5, the list should be split into ceil(5/3) = 3 chunks. - # By default, the batch_size is set to the length of the whole list. - assert len(lv.collection.literals) == ceil(len(python_val) / batch_size) - - pv = TypeEngine.to_python_value(ctx, lv, python_type) - assert pv == python_val + ctx = FlyteContext.current_context() + expected = TypeEngine.to_literal_type(python_type) + lv = TypeEngine.to_literal(ctx, python_val, python_type, expected) + # For example, if the batch size is 2 and the length of the list is 5, the list should be split into ceil(5/3) = 3 chunks. + # By default, the batch_size is set to the length of the whole list. + assert len(lv.collection.literals) == ceil(len(python_val) / batch_size) + pv = TypeEngine.to_python_value(ctx, lv, python_type) + assert pv == python_val