Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for mapping over remote launch plans #2761

Merged
merged 9 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 27 additions & 18 deletions flytekit/core/array_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@

from flytekit.core import interface as flyte_interface
from flytekit.core.context_manager import ExecutionState, FlyteContext
from flytekit.core.interface import transform_interface_to_list_interface, transform_interface_to_typed_interface
from flytekit.core.interface import (
transform_interface_to_list_interface,
transform_interface_to_typed_interface,
transform_typed_interface_to_interface,
)
from flytekit.core.launch_plan import LaunchPlan
from flytekit.core.node import Node
from flytekit.core.promise import (
Expand All @@ -27,7 +31,7 @@
class ArrayNode:
def __init__(
self,
target: LaunchPlan,
target: Any,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use Union to describe the types? I'd rather keep ArrayNode well-typed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, an update for this is on the way. I was working around a circular import error; I can use `if TYPE_CHECKING` to keep the annotation well-typed.

execution_mode: _core_workflow.ArrayNode.ExecutionMode = _core_workflow.ArrayNode.FULL_STATE,
bindings: Optional[List[_literal_models.Binding]] = None,
concurrency: Optional[int] = None,
Expand All @@ -47,6 +51,8 @@ def __init__(
:param execution_mode: The execution mode for propeller to use when handling ArrayNode
:param metadata: The metadata for the underlying entity
"""
from flytekit.remote import FlyteLaunchPlan

self.target = target
self._concurrency = concurrency
self._execution_mode = execution_mode
Expand All @@ -60,7 +66,12 @@ def __init__(
self._min_success_ratio = min_success_ratio if min_success_ratio is not None else 1.0
self._min_successes = 0

n_outputs = len(self.target.python_interface.outputs)
target_interface = self.target.python_interface or transform_typed_interface_to_interface(self.target.interface)
if target_interface is None:
raise ValueError("No interface found for the target entity.")
self._target_interface: flyte_interface.Interface = target_interface

n_outputs = len(self._target_interface.outputs)
if n_outputs > 1:
raise ValueError("Only tasks with a single output are supported in map tasks.")

Expand All @@ -69,12 +80,12 @@ def __init__(

output_as_list_of_optionals = min_success_ratio is not None and min_success_ratio != 1 and n_outputs == 1
collection_interface = transform_interface_to_list_interface(
self.target.python_interface, self._bound_inputs, output_as_list_of_optionals
self._target_interface, self._bound_inputs, output_as_list_of_optionals
)
self._collection_interface = collection_interface

self.metadata = None
if isinstance(target, LaunchPlan):
if isinstance(target, LaunchPlan) or isinstance(target, FlyteLaunchPlan):
if self._execution_mode != _core_workflow.ArrayNode.FULL_STATE:
raise ValueError("Only execution version 1 is supported for LaunchPlans.")
if metadata:
Expand Down Expand Up @@ -124,12 +135,12 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr
k = binding.var
if k not in self._bound_inputs:
v = kwargs[k]
if isinstance(v, list) and len(v) > 0 and isinstance(v[0], self.target.python_interface.inputs[k]):
if isinstance(v, list) and len(v) > 0 and isinstance(v[0], self._target_interface.inputs[k]):
mapped_entity_count = len(v)
break
else:
raise ValueError(
f"Expected a list of {self.target.python_interface.inputs[k]} but got {type(v)} instead."
f"Expected a list of {self._target_interface.inputs[k]} but got {type(v)} instead."
)

failed_count = 0
Expand All @@ -150,12 +161,12 @@ def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Pr
single_instance_inputs[k] = kwargs[k]

# translate Python native inputs to Flyte literals
typed_interface = transform_interface_to_typed_interface(self.target.python_interface)
typed_interface = transform_interface_to_typed_interface(self._target_interface)
literal_map = translate_inputs_to_literals(
ctx,
incoming_values=single_instance_inputs,
flyte_interface_types={} if typed_interface is None else typed_interface.inputs,
native_types=self.target.python_interface.inputs,
native_types=self._target_interface.inputs,
)
kwargs_literals = {k1: Promise(var=k1, val=v1) for k1, v1 in literal_map.items()}

Expand Down Expand Up @@ -199,17 +210,13 @@ def __call__(self, *args, **kwargs):
if not self._bindings:
ctx = FlyteContext.current_context()
# since a new entity with an updated list interface is not created, we have to work around the mismatch
# between the interface and the inputs
collection_interface = transform_interface_to_list_interface(
self.flyte_entity.python_interface, self._bound_inputs
)
# don't link the node to the compilation state, since we don't want to add the subnode to the
# workflow as a node
# between the interface and the inputs. Also, don't link the node to the compilation state,
# since we don't want to add the subnode to the workflow as a node
bound_subnode = create_and_link_node(
ctx,
entity=self.flyte_entity,
add_node_to_compilation_state=False,
overridden_interface=collection_interface,
overridden_interface=self.python_interface,
node_id=ARRAY_NODE_SUBNODE_NAME,
**kwargs,
)
Expand All @@ -218,7 +225,7 @@ def __call__(self, *args, **kwargs):


def array_node(
target: Union[LaunchPlan],
target: Any,
concurrency: Optional[int] = None,
min_success_ratio: Optional[float] = None,
min_successes: Optional[int] = None,
Expand All @@ -237,7 +244,9 @@ def array_node(
:return: A callable function that takes in keyword arguments and returns a Promise created by
flyte_entity_call_handler
"""
if not isinstance(target, LaunchPlan):
from flytekit.remote import FlyteLaunchPlan

if not isinstance(target, LaunchPlan) and not isinstance(target, FlyteLaunchPlan):
raise ValueError("Only LaunchPlans are supported for now.")

node = ArrayNode(
Expand Down
6 changes: 4 additions & 2 deletions flytekit/core/array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ def _raw_execute(self, **kwargs) -> Any:


def map_task(
target: Union[LaunchPlan, PythonFunctionTask],
target: Any,
concurrency: Optional[int] = None,
min_successes: Optional[int] = None,
min_success_ratio: float = 1.0,
Expand All @@ -374,7 +374,9 @@ def map_task(
:param min_successes: The minimum number of successful executions
:param min_success_ratio: The minimum ratio of successful executions
"""
if isinstance(target, LaunchPlan):
from flytekit.remote import FlyteLaunchPlan

if isinstance(target, LaunchPlan) or isinstance(target, FlyteLaunchPlan):
return array_node(
target=target,
concurrency=concurrency,
Expand Down
24 changes: 24 additions & 0 deletions flytekit/core/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,30 @@ def transform_inputs_to_parameters(
return _interface_models.ParameterMap(params)


def transform_typed_interface_to_interface(
    typed_interface: typing.Optional[_interface_models.TypedInterface],
) -> typing.Optional[Interface]:
    """
    Convert a FlyteIDL typed interface into a simple Python-native interface.

    The literal type of every input and output variable is passed through
    ``TypeEngine.guess_python_type`` to obtain the Python type used in the result.

    :param typed_interface: the typed interface to convert, or None
    :return: None when no typed interface is given, otherwise an Interface whose inputs and
        outputs map each variable name to its guessed Python type
    """
    if typed_interface is None:
        return None

    def _guess_types(variables):
        # Map each variable name to the Python type guessed from its literal type.
        return {name: TypeEngine.guess_python_type(var.type) for name, var in variables.items()}

    return Interface(
        inputs=_guess_types(typed_interface.inputs),
        outputs=_guess_types(typed_interface.outputs),
    )


def transform_interface_to_typed_interface(
interface: typing.Optional[Interface],
allow_partial_artifact_id_binding: bool = False,
Expand Down
46 changes: 40 additions & 6 deletions tests/flytekit/unit/core/test_array_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
from flytekit.core.array_node import array_node
from flytekit.core.array_node_map_task import map_task
from flytekit.models.core import identifier as identifier_models
from flytekit.tools.translator import get_serializable
from flytekit.remote import FlyteLaunchPlan
from flytekit.remote.interface import TypedInterface
from flytekit.tools.translator import gather_dependent_entities, get_serializable


@pytest.fixture
Expand Down Expand Up @@ -38,13 +40,45 @@ def parent_wf(a: int, b: typing.Union[int, str], c: int = 2) -> int:
lp = LaunchPlan.get_default_launch_plan(current_context(), parent_wf)


@workflow
def grandparent_wf() -> typing.List[int]:
return array_node(lp, concurrency=10, min_success_ratio=0.9)(a=[1, 3, 5], b=["two", 4, "six"], c=[7, 8, 9])
def get_grandparent_wf(serialization_settings):
    """
    Build a workflow that maps over the module-level local launch plan ``lp``.

    :param serialization_settings: accepted so this factory has the same signature as the
        remote variant; it is not used here
    :return: the workflow wrapping the array node
    """

    @workflow
    def grandparent_wf() -> typing.List[int]:
        # Create the mapped entity first, then call it with the list inputs.
        mapped_lp = array_node(lp, concurrency=10, min_success_ratio=0.9)
        return mapped_lp(a=[1, 3, 5], b=["two", 4, "six"], c=[7, 8, 9])

    return grandparent_wf


def get_grandparent_remote_wf(serialization_settings):
    """
    Build a workflow that maps over a FlyteLaunchPlan made to look like it was fetched from Admin.

    The module-level launch plan ``lp`` is serialized, promoted to a ``FlyteLaunchPlan`` and
    given a typed interface taken from the first serialized workflow spec.

    :param serialization_settings: settings used to serialize the local launch plan
    :return: the workflow wrapping the array node over the promoted launch plan
    :raises ValueError: if serializing the launch plan produced no workflow spec
    """
    serialized = OrderedDict()
    lp_model = get_serializable(serialized, serialization_settings, lp)

    # Only the workflow specs are needed; the task templates and launch plan specs are ignored.
    _, wf_specs, _ = gather_dependent_entities(serialized)
    # Take the first workflow spec, failing loudly instead of leaving `spec` unbound.
    spec = next(iter(wf_specs.values()), None)
    if spec is None:
        raise ValueError("No workflow spec was produced while serializing the launch plan.")

    remote_lp = FlyteLaunchPlan.promote_from_model(lp_model.id, lp_model.spec)
    # To pretend that we've fetched this launch plan from Admin, also fill in the Flyte interface, which isn't
    # part of the IDL object but is something FlyteRemote does
    remote_lp._interface = TypedInterface.promote_from_model(spec.template.interface)

    @workflow
    def grandparent_remote_wf() -> typing.List[int]:
        return array_node(
            remote_lp, concurrency=10, min_success_ratio=0.9
        )(a=[1, 3, 5], b=["two", 4, "six"], c=[7, 8, 9])

    return grandparent_remote_wf

def test_lp_serialization(serialization_settings):
wf_spec = get_serializable(OrderedDict(), serialization_settings, grandparent_wf)

@pytest.mark.parametrize(
"target",
[
get_grandparent_wf,
get_grandparent_remote_wf,
],
)
def test_lp_serialization(target, serialization_settings):
wf_spec = get_serializable(OrderedDict(), serialization_settings, target(serialization_settings))
assert len(wf_spec.template.nodes) == 1

top_level = wf_spec.template.nodes[0]
Expand Down
28 changes: 28 additions & 0 deletions tests/flytekit/unit/core/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
transform_inputs_to_parameters,
transform_interface_to_list_interface,
transform_interface_to_typed_interface,
transform_typed_interface_to_interface,
transform_variable_map,
)
from flytekit.models.core import types as _core_types
from flytekit.models.literals import Void
from flytekit.types.file import FlyteFile
from flytekit.types.pickle.pickle import FlytePickle


def test_extract_only():
Expand Down Expand Up @@ -409,3 +411,29 @@ def t() -> str:
mt = map_task(t, min_success_ratio=min_success_ratio)

assert mt.python_interface.outputs["o0"] == typing.List[expected_type]


@pytest.mark.parametrize(
    "type_interface, expected_inputs, expected_outputs",
    [
        (str, {"a": str}, {"o0": str}),
        (int, {"a": int}, {"o0": int}),
        (bool, {"a": bool}, {"o0": bool}),
        (typing.List[str], {"a": typing.List[str]}, {"o0": typing.List[str]}),
        (
            typing.List[typing.List[int]],
            {"a": typing.List[typing.List[int]]},
            {"o0": typing.List[typing.List[int]]},
        ),
        (typing.Union[str, int], {"a": typing.Union[str, int]}, {"o0": typing.Union[str, int]}),
        (typing.Dict[str, int], {"a": typing.Dict[str, int]}, {"o0": typing.Dict[str, int]}),
        (typing.Optional[float], {"a": typing.Optional[float]}, {"o0": typing.Optional[float]}),
        (bytes, {"a": FlytePickle}, {"o0": FlytePickle}),
    ],
)
def test_transform_typed_interface_to_interface(type_interface, expected_inputs, expected_outputs):
    # Round trip: Python interface -> typed interface -> Python interface, then compare the
    # recovered input and output types with the expected ones.
    @task
    def t(a: type_interface) -> type_interface:
        return a

    round_tripped = transform_typed_interface_to_interface(
        transform_interface_to_typed_interface(t.python_interface)
    )

    assert round_tripped.inputs == expected_inputs
    assert round_tripped.outputs == expected_outputs
Loading