diff --git a/flytekit/core/map_task.py b/flytekit/core/map_task.py index f760be5d3c..4ec9f64a2d 100644 --- a/flytekit/core/map_task.py +++ b/flytekit/core/map_task.py @@ -209,7 +209,7 @@ def _raw_execute(self, **kwargs) -> Any: return outputs -def map_task(task_function: PythonFunctionTask, concurrency: int = None, min_success_ratio: float = None, **kwargs): +def map_task(task_function: PythonFunctionTask, concurrency: int = 0, min_success_ratio: float = 1.0, **kwargs): """ Use a map task for parallelizable tasks that run across a list of an input type. A map task can be composed of any individual :py:class:`flytekit.PythonFunctionTask`. @@ -231,7 +231,7 @@ def map_task(task_function: PythonFunctionTask, concurrency: int = None, min_suc :param task_function: This argument is implicitly passed and represents the repeatable function :param concurrency: If specified, this limits the number of mapped tasks than can run in parallel to the given batch size. If the size of the input exceeds the concurrency value, then multiple batches will be run serially until - all inputs are processed. + all inputs are processed. If left unspecified, this means unbounded concurrency. :param min_success_ratio: If specified, this determines the minimum fraction of total jobs which can complete successfully before terminating this task and marking it successful. diff --git a/flytekit/models/array_job.py b/flytekit/models/array_job.py index 4e4bf99cc7..2c86acdd7e 100644 --- a/flytekit/models/array_job.py +++ b/flytekit/models/array_job.py @@ -70,13 +70,21 @@ def to_dict(self): """ :rtype: dict[T, Text] """ - return _json_format.MessageToDict( - _array_job.ArrayJob( + array_job = None + if self.min_successes is not None: + array_job = _array_job.ArrayJob( parallelism=self.parallelism, size=self.size, min_successes=self.min_successes, ) - ) + elif self.min_success_ratio is not None: + array_job = _array_job.ArrayJob( + parallelism=self.parallelism, + size=self.size, + min_success_ratio=self.min_success_ratio, + ) + + return _json_format.MessageToDict(array_job) @classmethod def from_dict(cls, idl_dict): @@ -86,8 +94,15 @@ def from_dict(cls, idl_dict): """ pb2_object = _json_format.Parse(_json.dumps(idl_dict), _array_job.ArrayJob()) - return cls( - parallelism=pb2_object.parallelism, - size=pb2_object.size, - min_successes=pb2_object.min_successes, - ) + if pb2_object.HasField("min_successes"): + return cls( + parallelism=pb2_object.parallelism, + size=pb2_object.size, + min_successes=pb2_object.min_successes, + ) + else: + return cls( + parallelism=pb2_object.parallelism, + size=pb2_object.size, + min_success_ratio=pb2_object.min_success_ratio, + ) diff --git a/tests/flytekit/unit/core/test_map_task.py b/tests/flytekit/unit/core/test_map_task.py index d1f95852c1..4eb44d6e76 100644 --- a/tests/flytekit/unit/core/test_map_task.py +++ b/tests/flytekit/unit/core/test_map_task.py @@ -12,6 +12,18 @@ from flytekit.tools.translator import get_serializable +@pytest.fixture +def serialization_settings(): + default_img = Image(name="default", fqn="test", tag="tag") + return context_manager.SerializationSettings( + project="project", + domain="domain", + version="version", + env=None, + image_config=ImageConfig(default_image=default_img, images=[default_img]), + ) + + @task def t1(a: int) -> str: b = a + 2 @@ -54,18 +66,12 @@ def test_map_task_types(): _ = map_task(t1, metadata=TaskMetadata(retries=1))(a=["invalid", "args"]) -def test_serialization(): +def test_serialization(serialization_settings): maptask = map_task(t1, metadata=TaskMetadata(retries=1)) - default_img = Image(name="default", fqn="test", tag="tag") - serialization_settings = context_manager.SerializationSettings( - project="project", - domain="domain", - version="version", - env=None, - image_config=ImageConfig(default_image=default_img, images=[default_img]), - ) task_spec = get_serializable(OrderedDict(), serialization_settings, maptask) + # By default all map_task tasks will have their custom fields set. + assert task_spec.template.custom["minSuccessRatio"] == 1.0 assert task_spec.template.type == "container_array" assert task_spec.template.task_type_version == 1 assert task_spec.template.container.args == [ @@ -90,7 +96,23 @@ def test_serialization(): ] -def test_serialization_workflow_def(): +@pytest.mark.parametrize( + "custom_fields_dict, expected_custom_fields", + [ + ({}, {"minSuccessRatio": 1.0}), + ({"concurrency": 99}, {"parallelism": "99", "minSuccessRatio": 1.0}), + ({"min_success_ratio": 0.271828}, {"minSuccessRatio": 0.271828}), + ({"concurrency": 42, "min_success_ratio": 0.31415}, {"parallelism": "42", "minSuccessRatio": 0.31415}), + ], +) +def test_serialization_of_custom_fields(custom_fields_dict, expected_custom_fields, serialization_settings): + maptask = map_task(t1, **custom_fields_dict) + task_spec = get_serializable(OrderedDict(), serialization_settings, maptask) + + assert task_spec.template.custom == expected_custom_fields + + +def test_serialization_workflow_def(serialization_settings): @task def complex_task(a: int) -> str: b = a + 2 @@ -106,14 +128,6 @@ def w1(a: typing.List[int]) -> typing.List[str]: def w2(a: typing.List[int]) -> typing.List[str]: return map_task(complex_task, metadata=TaskMetadata(retries=2))(a=a) - default_img = Image(name="default", fqn="test", tag="tag") - serialization_settings = context_manager.SerializationSettings( - project="project", - domain="domain", - version="version", - env=None, - image_config=ImageConfig(default_image=default_img, images=[default_img]), - ) serialized_control_plane_entities = OrderedDict() wf1_spec = get_serializable(serialized_control_plane_entities, serialization_settings, w1) assert wf1_spec.template is not None