Skip to content

Commit

Permalink
add support for mapping over remote launch plans (#2761)
Browse files Browse the repository at this point in the history
* add support for mapping over remote launch plans

Signed-off-by: Paul Dittamo <pvdittamo@gmail.com>

* add unit tests

Signed-off-by: Paul Dittamo <pvdittamo@gmail.com>

* lint

Signed-off-by: Paul Dittamo <pvdittamo@gmail.com>

* update var name

Signed-off-by: Paul Dittamo <pvdittamo@gmail.com>

* utilize TYPE_CHECKING to check for FlyteLaunchPlan

Signed-off-by: Paul Dittamo <pvdittamo@gmail.com>

* utilize adding remote node function

Signed-off-by: Paul Dittamo <pvdittamo@gmail.com>

* add unit tests

Signed-off-by: Paul Dittamo <pvdittamo@gmail.com>

* revert changes to local execute

Signed-off-by: Paul Dittamo <pvdittamo@gmail.com>

---------

Signed-off-by: Paul Dittamo <pvdittamo@gmail.com>
  • Loading branch information
pvditt authored Oct 10, 2024
1 parent f759a3c commit cd8216a
Show file tree
Hide file tree
Showing 7 changed files with 183 additions and 33 deletions.
79 changes: 60 additions & 19 deletions flytekit/core/array_node.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,41 @@
import math
from typing import Any, List, Optional, Set, Tuple, Union
from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union

from flyteidl.core import workflow_pb2 as _core_workflow

from flytekit.core import interface as flyte_interface
from flytekit.core.context_manager import ExecutionState, FlyteContext
from flytekit.core.interface import transform_interface_to_list_interface, transform_interface_to_typed_interface
from flytekit.core.interface import (
transform_interface_to_list_interface,
transform_interface_to_typed_interface,
)
from flytekit.core.launch_plan import LaunchPlan
from flytekit.core.node import Node
from flytekit.core.promise import (
Promise,
VoidPromise,
create_and_link_node,
create_and_link_node_from_remote,
flyte_entity_call_handler,
translate_inputs_to_literals,
)
from flytekit.core.task import TaskMetadata
from flytekit.loggers import logger
from flytekit.models import interface as _interface_models
from flytekit.models import literals as _literal_models
from flytekit.models.core import workflow as _workflow_model
from flytekit.models.literals import Literal, LiteralCollection, Scalar

ARRAY_NODE_SUBNODE_NAME = "array_node_subnode"

if TYPE_CHECKING:
from flytekit.remote import FlyteLaunchPlan


class ArrayNode:
def __init__(
self,
target: LaunchPlan,
target: Union[LaunchPlan, "FlyteLaunchPlan"],
execution_mode: _core_workflow.ArrayNode.ExecutionMode = _core_workflow.ArrayNode.FULL_STATE,
bindings: Optional[List[_literal_models.Binding]] = None,
concurrency: Optional[int] = None,
Expand All @@ -47,6 +55,8 @@ def __init__(
:param execution_mode: The execution mode for propeller to use when handling ArrayNode
:param metadata: The metadata for the underlying entity
"""
from flytekit.remote import FlyteLaunchPlan

self.target = target
self._concurrency = concurrency
self._execution_mode = execution_mode
Expand All @@ -60,21 +70,30 @@ def __init__(
self._min_success_ratio = min_success_ratio if min_success_ratio is not None else 1.0
self._min_successes = 0

n_outputs = len(self.target.python_interface.outputs)
if self.target.python_interface:
n_outputs = len(self.target.python_interface.outputs)
else:
n_outputs = len(self.target.interface.outputs)
if n_outputs > 1:
raise ValueError("Only tasks with a single output are supported in map tasks.")

# TODO - bound inputs are not supported at the moment
self._bound_inputs: Set[str] = set()

output_as_list_of_optionals = min_success_ratio is not None and min_success_ratio != 1 and n_outputs == 1
collection_interface = transform_interface_to_list_interface(
self.target.python_interface, self._bound_inputs, output_as_list_of_optionals
)
self._collection_interface = collection_interface

self._remote_interface = None
if self.target.python_interface:
self._python_interface = transform_interface_to_list_interface(
self.target.python_interface, self._bound_inputs, output_as_list_of_optionals
)
elif self.target.interface:
self._remote_interface = self.target.interface.transform_interface_to_list()
else:
raise ValueError("No interface found for the target entity.")

self.metadata = None
if isinstance(target, LaunchPlan):
if isinstance(target, LaunchPlan) or isinstance(target, FlyteLaunchPlan):
if self._execution_mode != _core_workflow.ArrayNode.FULL_STATE:
raise ValueError("Only execution version 1 is supported for LaunchPlans.")
if metadata:
Expand All @@ -98,7 +117,14 @@ def name(self) -> str:
@property
def python_interface(self) -> flyte_interface.Interface:
# Part of SupportsNodeCreation interface
return self._collection_interface
return self._python_interface

@property
def interface(self) -> _interface_models.TypedInterface:
# Required in get_serializable_node
if self._remote_interface:
return self._remote_interface
raise AttributeError("interface attribute is not available")

@property
def bindings(self) -> List[_literal_models.Binding]:
Expand All @@ -115,6 +141,9 @@ def flyte_entity(self) -> Any:
return self.target

def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise]:
if self._remote_interface:
raise ValueError("Mapping over remote entities is not supported in local execution.")

outputs_expected = True
if not self.python_interface.outputs:
outputs_expected = False
Expand Down Expand Up @@ -199,17 +228,27 @@ def __call__(self, *args, **kwargs):
if not self._bindings:
ctx = FlyteContext.current_context()
# since a new entity with an updated list interface is not created, we have to work around the mismatch
# between the interface and the inputs
collection_interface = transform_interface_to_list_interface(
self.flyte_entity.python_interface, self._bound_inputs
)
# don't link the node to the compilation state, since we don't want to add the subnode to the
# workflow as a node
# between the interface and the inputs. Also, don't link the node to the compilation state,
# since we don't want to add the subnode to the workflow as a node
if self._remote_interface:
bound_subnode = create_and_link_node_from_remote(
ctx,
entity=self.flyte_entity,
add_node_to_compilation_state=False,
overridden_interface=self._remote_interface,
**kwargs,
)
self._bindings = bound_subnode.ref.node.bindings
return create_and_link_node_from_remote(
ctx,
entity=self,
**kwargs,
)
bound_subnode = create_and_link_node(
ctx,
entity=self.flyte_entity,
add_node_to_compilation_state=False,
overridden_interface=collection_interface,
overridden_interface=self.python_interface,
node_id=ARRAY_NODE_SUBNODE_NAME,
**kwargs,
)
Expand All @@ -218,7 +257,7 @@ def __call__(self, *args, **kwargs):


def array_node(
target: Union[LaunchPlan],
target: Union[LaunchPlan, "FlyteLaunchPlan"],
concurrency: Optional[int] = None,
min_success_ratio: Optional[float] = None,
min_successes: Optional[int] = None,
Expand All @@ -237,7 +276,9 @@ def array_node(
:return: A callable function that takes in keyword arguments and returns a Promise created by
flyte_entity_call_handler
"""
if not isinstance(target, LaunchPlan):
from flytekit.remote import FlyteLaunchPlan

if not isinstance(target, LaunchPlan) and not isinstance(target, FlyteLaunchPlan):
raise ValueError("Only LaunchPlans are supported for now.")

node = ArrayNode(
Expand Down
11 changes: 8 additions & 3 deletions flytekit/core/array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import math
import os # TODO: use flytekit logger
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Set, Union, cast
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union, cast

import typing_extensions
from flyteidl.core import tasks_pb2
Expand All @@ -31,6 +31,9 @@
from flytekit.types.pickle.pickle import FlytePickleTransformer
from flytekit.utils.asyn import loop_manager

if TYPE_CHECKING:
from flytekit.remote import FlyteLaunchPlan


class ArrayNodeMapTask(PythonTask):
def __init__(
Expand Down Expand Up @@ -359,7 +362,7 @@ def _raw_execute(self, **kwargs) -> Any:


def map_task(
target: Union[LaunchPlan, PythonFunctionTask],
target: Union[LaunchPlan, PythonFunctionTask, "FlyteLaunchPlan"],
concurrency: Optional[int] = None,
min_successes: Optional[int] = None,
min_success_ratio: float = 1.0,
Expand All @@ -377,7 +380,9 @@ def map_task(
:param min_successes: The minimum number of successful executions
:param min_success_ratio: The minimum ratio of successful executions
"""
if isinstance(target, LaunchPlan):
from flytekit.remote import FlyteLaunchPlan

if isinstance(target, LaunchPlan) or isinstance(target, FlyteLaunchPlan):
return array_node(
target=target,
concurrency=concurrency,
Expand Down
25 changes: 21 additions & 4 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -1070,6 +1070,9 @@ def extract_obj_name(name: str) -> str:
def create_and_link_node_from_remote(
ctx: FlyteContext,
entity: HasFlyteInterface,
overridden_interface: Optional[_interface_models.TypedInterface] = None,
add_node_to_compilation_state: bool = True,
node_id: str = "",
_inputs_not_allowed: Optional[Set[str]] = None,
_ignorable_inputs: Optional[Set[str]] = None,
**kwargs,
Expand All @@ -1084,20 +1087,25 @@ def create_and_link_node_from_remote(
:param ctx: FlyteContext
:param entity: RemoteEntity
:param overridden_interface: utilize this interface instead of the one provided by the entity. This is useful for
ArrayNode as there's a mismatch between the underlying interface and inputs
:param add_node_to_compilation_state: bool that enables for nodes to be created but not linked to the workflow. This
is useful when creating nodes nested under other nodes such as ArrayNode
:param node_id: str if provided, this will be used as the node id.
:param _inputs_not_allowed: Set of all variable names that should not be provided when using this entity.
Useful for Launchplans with `fixed` inputs
:param _ignorable_inputs: Set of all variable names that are optional, but if provided will be overridden. Useful
for launchplans with `default` inputs
:param kwargs: Dict[str, Any] default inputs passed from the user to this entity. Can be promises.
:return: Optional[Union[Tuple[Promise], Promise, VoidPromise]]
"""
if ctx.compilation_state is None:
if ctx.compilation_state is None and add_node_to_compilation_state:
raise _user_exceptions.FlyteAssertion("Cannot create node when not compiling...")

used_inputs = set()
bindings = []

typed_interface = entity.interface
typed_interface = overridden_interface or entity.interface

if _inputs_not_allowed:
inputs_not_allowed_specified = _inputs_not_allowed.intersection(kwargs.keys())
Expand Down Expand Up @@ -1148,14 +1156,23 @@ def create_and_link_node_from_remote(
# These will be our core Nodes until we can amend the Promise to use NodeOutputs that reference our Nodes
upstream_nodes = list(set([n for n in nodes if n.id != _common_constants.GLOBAL_INPUT_NODE_ID]))

# if not adding to compilation state, we don't need to generate a unique node id
node_id = node_id or (
f"{ctx.compilation_state.prefix}n{len(ctx.compilation_state.nodes)}"
if add_node_to_compilation_state and ctx.compilation_state
else node_id
)

flytekit_node = Node(
id=f"{ctx.compilation_state.prefix}n{len(ctx.compilation_state.nodes)}",
id=node_id,
metadata=entity.construct_node_metadata(),
bindings=sorted(bindings, key=lambda b: b.var),
upstream_nodes=upstream_nodes,
flyte_entity=entity,
)
ctx.compilation_state.add_node(flytekit_node)

if add_node_to_compilation_state and ctx.compilation_state:
ctx.compilation_state.add_node(flytekit_node)

if len(typed_interface.outputs) == 0:
return VoidPromise(entity.name, NodeOutput(node=flytekit_node, var="placeholder"))
Expand Down
23 changes: 23 additions & 0 deletions flytekit/models/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from flyteidl.core import artifact_id_pb2 as art_id
from flyteidl.core import interface_pb2 as _interface_pb2
from flyteidl.core import types_pb2 as _types_pb2

from flytekit.models import common as _common
from flytekit.models import literals as _literals
Expand Down Expand Up @@ -64,6 +65,17 @@ def to_flyte_idl(self):
artifact_tag=self.artifact_tag,
)

def to_flyte_idl_list(self):
"""
:rtype: flyteidl.core.interface_pb2.Variable
"""
return _interface_pb2.Variable(
type=_types_pb2.LiteralType(collection_type=self.type.to_flyte_idl()),
description=self.description,
artifact_partial_id=self.artifact_partial_id,
artifact_tag=self.artifact_tag,
)

@classmethod
def from_flyte_idl(cls, variable_proto) -> _interface_pb2.Variable:
"""
Expand Down Expand Up @@ -146,6 +158,17 @@ def from_flyte_idl(cls, proto: _interface_pb2.TypedInterface) -> "TypedInterface
outputs={k: Variable.from_flyte_idl(v) for k, v in proto.outputs.variables.items()},
)

def transform_interface_to_list(self) -> "TypedInterface":
"""
Takes a single task interface and interpolates it to an array interface - to allow performing distributed
python map like functions
"""
list_interface = _interface_pb2.TypedInterface(
inputs=_interface_pb2.VariableMap(variables={k: v.to_flyte_idl_list() for k, v in self.inputs.items()}),
outputs=_interface_pb2.VariableMap(variables={k: v.to_flyte_idl_list() for k, v in self.outputs.items()}),
)
return self.from_flyte_idl(list_interface)


class Parameter(_common.FlyteIdlEntity):
def __init__(
Expand Down
50 changes: 43 additions & 7 deletions tests/flytekit/unit/core/test_array_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
from flytekit.core.array_node import array_node
from flytekit.core.array_node_map_task import map_task
from flytekit.models.core import identifier as identifier_models
from flytekit.tools.translator import get_serializable
from flytekit.remote import FlyteLaunchPlan
from flytekit.remote.interface import TypedInterface
from flytekit.tools.translator import gather_dependent_entities, get_serializable


@pytest.fixture
Expand Down Expand Up @@ -40,13 +42,45 @@ def parent_wf(a: int, b: typing.Union[int, str], c: int = 2) -> int:
lp = LaunchPlan.get_default_launch_plan(ctx, parent_wf)


@workflow
def grandparent_wf() -> typing.List[int]:
return array_node(lp, concurrency=10, min_success_ratio=0.9)(a=[1, 3, 5], b=["two", 4, "six"], c=[7, 8, 9])
def get_grandparent_wf(serialization_settings):
@workflow
def grandparent_wf() -> typing.List[int]:
return array_node(lp, concurrency=10, min_success_ratio=0.9)(a=[1, 3, 5], b=["two", 4, "six"], c=[7, 8, 9])

return grandparent_wf


def get_grandparent_remote_wf(serialization_settings):
serialized = OrderedDict()
lp_model = get_serializable(serialized, serialization_settings, lp)

task_templates, wf_specs, lp_specs = gather_dependent_entities(serialized)
for wf_id, spec in wf_specs.items():
break

remote_lp = FlyteLaunchPlan.promote_from_model(lp_model.id, lp_model.spec)
# To pretend that we've fetched this launch plan from Admin, also fill in the Flyte interface, which isn't
# part of the IDL object but is something FlyteRemote does
remote_lp._interface = TypedInterface.promote_from_model(spec.template.interface)

@workflow
def grandparent_remote_wf() -> typing.List[int]:
return array_node(
remote_lp, concurrency=10, min_success_ratio=0.9
)(a=[1, 3, 5], b=["two", 4, "six"], c=[7, 8, 9])

return grandparent_remote_wf

def test_lp_serialization(serialization_settings):
wf_spec = get_serializable(OrderedDict(), serialization_settings, grandparent_wf)

@pytest.mark.parametrize(
"target",
[
get_grandparent_wf,
get_grandparent_remote_wf,
],
)
def test_lp_serialization(target, serialization_settings):
wf_spec = get_serializable(OrderedDict(), serialization_settings, target(serialization_settings))
assert len(wf_spec.template.nodes) == 1

top_level = wf_spec.template.nodes[0]
Expand All @@ -56,7 +90,9 @@ def test_lp_serialization(serialization_settings):
assert binding.scalar.primitive.integer is not None
assert top_level.inputs[1].var == "b"
for binding in top_level.inputs[1].binding.collection.bindings:
assert binding.scalar.union is not None
assert (binding.scalar.union is not None or
binding.scalar.primitive.integer is not None or
binding.scalar.primitive.string_value is not None)
assert len(top_level.inputs[1].binding.collection.bindings) == 3
assert top_level.inputs[2].var == "c"
assert len(top_level.inputs[2].binding.collection.bindings) == 3
Expand Down
Loading

0 comments on commit cd8216a

Please sign in to comment.