From cd8216abe699ba1a04850993ba017a92efd1e7bc Mon Sep 17 00:00:00 2001 From: Paul Dittamo <37558497+pvditt@users.noreply.github.com> Date: Wed, 9 Oct 2024 18:06:28 -0700 Subject: [PATCH] add support for mapping over remote launch plans (#2761) * add support for mapping over remote launch plans Signed-off-by: Paul Dittamo * add unit tests Signed-off-by: Paul Dittamo * lint Signed-off-by: Paul Dittamo * update var name Signed-off-by: Paul Dittamo * utilize TYPE_CHECKING to check for FlyteLaunchPlan Signed-off-by: Paul Dittamo * utilize adding remote node function Signed-off-by: Paul Dittamo * add unit tests Signed-off-by: Paul Dittamo * revert changes to local execute Signed-off-by: Paul Dittamo --------- Signed-off-by: Paul Dittamo --- flytekit/core/array_node.py | 79 +++++++++++++++----- flytekit/core/array_node_map_task.py | 11 ++- flytekit/core/promise.py | 25 ++++++- flytekit/models/interface.py | 23 ++++++ tests/flytekit/unit/core/test_array_node.py | 50 +++++++++++-- tests/flytekit/unit/core/test_promise.py | 8 ++ tests/flytekit/unit/models/test_interface.py | 20 +++++ 7 files changed, 183 insertions(+), 33 deletions(-) diff --git a/flytekit/core/array_node.py b/flytekit/core/array_node.py index 14c2d454c2..9a6f08689a 100644 --- a/flytekit/core/array_node.py +++ b/flytekit/core/array_node.py @@ -1,33 +1,41 @@ import math -from typing import Any, List, Optional, Set, Tuple, Union +from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union from flyteidl.core import workflow_pb2 as _core_workflow from flytekit.core import interface as flyte_interface from flytekit.core.context_manager import ExecutionState, FlyteContext -from flytekit.core.interface import transform_interface_to_list_interface, transform_interface_to_typed_interface +from flytekit.core.interface import ( + transform_interface_to_list_interface, + transform_interface_to_typed_interface, +) from flytekit.core.launch_plan import LaunchPlan from flytekit.core.node import Node from flytekit.core.promise import ( Promise, VoidPromise, create_and_link_node, + create_and_link_node_from_remote, flyte_entity_call_handler, translate_inputs_to_literals, ) from flytekit.core.task import TaskMetadata from flytekit.loggers import logger +from flytekit.models import interface as _interface_models from flytekit.models import literals as _literal_models from flytekit.models.core import workflow as _workflow_model from flytekit.models.literals import Literal, LiteralCollection, Scalar ARRAY_NODE_SUBNODE_NAME = "array_node_subnode" +if TYPE_CHECKING: + from flytekit.remote import FlyteLaunchPlan + class ArrayNode: def __init__( self, - target: LaunchPlan, + target: Union[LaunchPlan, "FlyteLaunchPlan"], execution_mode: _core_workflow.ArrayNode.ExecutionMode = _core_workflow.ArrayNode.FULL_STATE, bindings: Optional[List[_literal_models.Binding]] = None, concurrency: Optional[int] = None, @@ -47,6 +55,8 @@ def __init__( :param execution_mode: The execution mode for propeller to use when handling ArrayNode :param metadata: The metadata for the underlying entity """ + from flytekit.remote import FlyteLaunchPlan + self.target = target self._concurrency = concurrency self._execution_mode = execution_mode @@ -60,7 +70,10 @@ def __init__( self._min_success_ratio = min_success_ratio if min_success_ratio is not None else 1.0 self._min_successes = 0 - n_outputs = len(self.target.python_interface.outputs) + if self.target.python_interface: + n_outputs = len(self.target.python_interface.outputs) + else: + n_outputs = len(self.target.interface.outputs) if n_outputs > 1: raise ValueError("Only tasks with a single output are supported in map tasks.") @@ -68,13 +81,19 @@ def __init__( self._bound_inputs: Set[str] = set() output_as_list_of_optionals = min_success_ratio is not None and min_success_ratio != 1 and n_outputs == 1 - collection_interface = transform_interface_to_list_interface( - self.target.python_interface, self._bound_inputs, output_as_list_of_optionals - ) - self._collection_interface = collection_interface + + self._remote_interface = None + if self.target.python_interface: + self._python_interface = transform_interface_to_list_interface( + self.target.python_interface, self._bound_inputs, output_as_list_of_optionals + ) + elif self.target.interface: + self._remote_interface = self.target.interface.transform_interface_to_list() + else: + raise ValueError("No interface found for the target entity.") self.metadata = None - if isinstance(target, LaunchPlan): + if isinstance(target, LaunchPlan) or isinstance(target, FlyteLaunchPlan): if self._execution_mode != _core_workflow.ArrayNode.FULL_STATE: raise ValueError("Only execution version 1 is supported for LaunchPlans.") if metadata: @@ -98,7 +117,14 @@ def name(self) -> str: @property def python_interface(self) -> flyte_interface.Interface: # Part of SupportsNodeCreation interface - return self._collection_interface + return self._python_interface + + @property + def interface(self) -> _interface_models.TypedInterface: + # Required in get_serializable_node + if self._remote_interface: + return self._remote_interface + raise AttributeError("interface attribute is not available") @property def bindings(self) -> List[_literal_models.Binding]: @@ -115,6 +141,9 @@ def flyte_entity(self) -> Any: return self.target def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise]: + if self._remote_interface: + raise ValueError("Mapping over remote entities is not supported in local execution.") + outputs_expected = True if not self.python_interface.outputs: outputs_expected = False @@ -199,17 +228,27 @@ def __call__(self, *args, **kwargs): if not self._bindings: ctx = FlyteContext.current_context() # since a new entity with an updated list interface is not created, we have to work around the mismatch - # between the interface and the inputs - collection_interface = transform_interface_to_list_interface( - self.flyte_entity.python_interface, self._bound_inputs - ) - # don't link the node to the compilation state, since we don't want to add the subnode to the - # workflow as a node + # between the interface and the inputs. Also, don't link the node to the compilation state, + # since we don't want to add the subnode to the workflow as a node + if self._remote_interface: + bound_subnode = create_and_link_node_from_remote( + ctx, + entity=self.flyte_entity, + add_node_to_compilation_state=False, + overridden_interface=self._remote_interface, + **kwargs, + ) + self._bindings = bound_subnode.ref.node.bindings + return create_and_link_node_from_remote( + ctx, + entity=self, + **kwargs, + ) bound_subnode = create_and_link_node( ctx, entity=self.flyte_entity, add_node_to_compilation_state=False, - overridden_interface=collection_interface, + overridden_interface=self.python_interface, node_id=ARRAY_NODE_SUBNODE_NAME, **kwargs, ) @@ -218,7 +257,7 @@ def __call__(self, *args, **kwargs): def array_node( - target: Union[LaunchPlan], + target: Union[LaunchPlan, "FlyteLaunchPlan"], concurrency: Optional[int] = None, min_success_ratio: Optional[float] = None, min_successes: Optional[int] = None, @@ -237,7 +276,9 @@ def array_node( :return: A callable function that takes in keyword arguments and returns a Promise created by flyte_entity_call_handler """ - if not isinstance(target, LaunchPlan): + from flytekit.remote import FlyteLaunchPlan + + if not isinstance(target, LaunchPlan) and not isinstance(target, FlyteLaunchPlan): raise ValueError("Only LaunchPlans are supported for now.") node = ArrayNode( diff --git a/flytekit/core/array_node_map_task.py b/flytekit/core/array_node_map_task.py index 82bb0cbe68..970a5d9a83 100644 --- a/flytekit/core/array_node_map_task.py +++ b/flytekit/core/array_node_map_task.py @@ -5,7 +5,7 @@ import math import os # TODO: use flytekit logger from contextlib import contextmanager -from typing import Any, Dict, List, Optional, Set, Union, cast +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union, cast import typing_extensions from flyteidl.core import tasks_pb2 @@ -31,6 +31,9 @@ from flytekit.types.pickle.pickle import FlytePickleTransformer from flytekit.utils.asyn import loop_manager +if TYPE_CHECKING: + from flytekit.remote import FlyteLaunchPlan + class ArrayNodeMapTask(PythonTask): def __init__( @@ -359,7 +362,7 @@ def _raw_execute(self, **kwargs) -> Any: def map_task( - target: Union[LaunchPlan, PythonFunctionTask], + target: Union[LaunchPlan, PythonFunctionTask, "FlyteLaunchPlan"], concurrency: Optional[int] = None, min_successes: Optional[int] = None, min_success_ratio: float = 1.0, @@ -377,7 +380,9 @@ def map_task( :param min_successes: The minimum number of successful executions :param min_success_ratio: The minimum ratio of successful executions """ - if isinstance(target, LaunchPlan): + from flytekit.remote import FlyteLaunchPlan + + if isinstance(target, LaunchPlan) or isinstance(target, FlyteLaunchPlan): return array_node( target=target, concurrency=concurrency, diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index afd51cd7cc..a193c0702f 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -1070,6 +1070,9 @@ def extract_obj_name(name: str) -> str: def create_and_link_node_from_remote( ctx: FlyteContext, entity: HasFlyteInterface, + overridden_interface: Optional[_interface_models.TypedInterface] = None, + add_node_to_compilation_state: bool = True, + node_id: str = "", _inputs_not_allowed: Optional[Set[str]] = None, _ignorable_inputs: Optional[Set[str]] = None, **kwargs, @@ -1084,6 +1087,11 @@ def create_and_link_node_from_remote( :param ctx: FlyteContext :param entity: RemoteEntity + :param overridden_interface: utilize this interface instead of the one provided by the entity. This is useful for + ArrayNode as there's a mismatch between the underlying interface and inputs + :param add_node_to_compilation_state: bool that enables for nodes to be created but not linked to the workflow. This + is useful when creating nodes nested under other nodes such as ArrayNode + :param node_id: str if provided, this will be used as the node id. :param _inputs_not_allowed: Set of all variable names that should not be provided when using this entity. Useful for Launchplans with `fixed` inputs :param _ignorable_inputs: Set of all variable names that are optional, but if provided will be overridden. Useful @@ -1091,13 +1099,13 @@ def create_and_link_node_from_remote( :param kwargs: Dict[str, Any] default inputs passed from the user to this entity. Can be promises. :return: Optional[Union[Tuple[Promise], Promise, VoidPromise]] """ - if ctx.compilation_state is None: + if ctx.compilation_state is None and add_node_to_compilation_state: raise _user_exceptions.FlyteAssertion("Cannot create node when not compiling...") used_inputs = set() bindings = [] - typed_interface = entity.interface + typed_interface = overridden_interface or entity.interface if _inputs_not_allowed: inputs_not_allowed_specified = _inputs_not_allowed.intersection(kwargs.keys()) @@ -1148,14 +1156,23 @@ def create_and_link_node_from_remote( # These will be our core Nodes until we can amend the Promise to use NodeOutputs that reference our Nodes upstream_nodes = list(set([n for n in nodes if n.id != _common_constants.GLOBAL_INPUT_NODE_ID])) + # if not adding to compilation state, we don't need to generate a unique node id + node_id = node_id or ( + f"{ctx.compilation_state.prefix}n{len(ctx.compilation_state.nodes)}" + if add_node_to_compilation_state and ctx.compilation_state + else node_id + ) + flytekit_node = Node( - id=f"{ctx.compilation_state.prefix}n{len(ctx.compilation_state.nodes)}", + id=node_id, metadata=entity.construct_node_metadata(), bindings=sorted(bindings, key=lambda b: b.var), upstream_nodes=upstream_nodes, flyte_entity=entity, ) - ctx.compilation_state.add_node(flytekit_node) + + if add_node_to_compilation_state and ctx.compilation_state: + ctx.compilation_state.add_node(flytekit_node) if len(typed_interface.outputs) == 0: return VoidPromise(entity.name, NodeOutput(node=flytekit_node, var="placeholder")) diff --git a/flytekit/models/interface.py b/flytekit/models/interface.py index f80bfb9e52..005d6f4a89 100644 --- a/flytekit/models/interface.py +++ b/flytekit/models/interface.py @@ -2,6 +2,7 @@ from flyteidl.core import artifact_id_pb2 as art_id from flyteidl.core import interface_pb2 as _interface_pb2 +from flyteidl.core import types_pb2 as _types_pb2 from flytekit.models import common as _common from flytekit.models import literals as _literals @@ -64,6 +65,17 @@ def to_flyte_idl(self): artifact_tag=self.artifact_tag, ) + def to_flyte_idl_list(self): + """ + :rtype: flyteidl.core.interface_pb2.Variable + """ + return _interface_pb2.Variable( + type=_types_pb2.LiteralType(collection_type=self.type.to_flyte_idl()), + description=self.description, + artifact_partial_id=self.artifact_partial_id, + artifact_tag=self.artifact_tag, + ) + @classmethod def from_flyte_idl(cls, variable_proto) -> _interface_pb2.Variable: """ @@ -146,6 +158,17 @@ def from_flyte_idl(cls, proto: _interface_pb2.TypedInterface) -> "TypedInterface outputs={k: Variable.from_flyte_idl(v) for k, v in proto.outputs.variables.items()}, ) + def transform_interface_to_list(self) -> "TypedInterface": + """ + Takes a single task interface and interpolates it to an array interface - to allow performing distributed + python map like functions + """ + list_interface = _interface_pb2.TypedInterface( + inputs=_interface_pb2.VariableMap(variables={k: v.to_flyte_idl_list() for k, v in self.inputs.items()}), + outputs=_interface_pb2.VariableMap(variables={k: v.to_flyte_idl_list() for k, v in self.outputs.items()}), + ) + return self.from_flyte_idl(list_interface) + class Parameter(_common.FlyteIdlEntity): def __init__( diff --git a/tests/flytekit/unit/core/test_array_node.py b/tests/flytekit/unit/core/test_array_node.py index 2f4c3145ba..aced10158c 100644 --- a/tests/flytekit/unit/core/test_array_node.py +++ b/tests/flytekit/unit/core/test_array_node.py @@ -9,7 +9,9 @@ from flytekit.core.array_node import array_node from flytekit.core.array_node_map_task import map_task from flytekit.models.core import identifier as identifier_models -from flytekit.tools.translator import get_serializable +from flytekit.remote import FlyteLaunchPlan +from flytekit.remote.interface import TypedInterface +from flytekit.tools.translator import gather_dependent_entities, get_serializable @pytest.fixture @@ -40,13 +42,45 @@ def parent_wf(a: int, b: typing.Union[int, str], c: int = 2) -> int: lp = LaunchPlan.get_default_launch_plan(ctx, parent_wf) -@workflow -def grandparent_wf() -> typing.List[int]: - return array_node(lp, concurrency=10, min_success_ratio=0.9)(a=[1, 3, 5], b=["two", 4, "six"], c=[7, 8, 9]) +def get_grandparent_wf(serialization_settings): + @workflow + def grandparent_wf() -> typing.List[int]: + return array_node(lp, concurrency=10, min_success_ratio=0.9)(a=[1, 3, 5], b=["two", 4, "six"], c=[7, 8, 9]) + + return grandparent_wf + + +def get_grandparent_remote_wf(serialization_settings): + serialized = OrderedDict() + lp_model = get_serializable(serialized, serialization_settings, lp) + + task_templates, wf_specs, lp_specs = gather_dependent_entities(serialized) + for wf_id, spec in wf_specs.items(): + break + + remote_lp = FlyteLaunchPlan.promote_from_model(lp_model.id, lp_model.spec) + # To pretend that we've fetched this launch plan from Admin, also fill in the Flyte interface, which isn't + # part of the IDL object but is something FlyteRemote does + remote_lp._interface = TypedInterface.promote_from_model(spec.template.interface) + @workflow + def grandparent_remote_wf() -> typing.List[int]: + return array_node( + remote_lp, concurrency=10, min_success_ratio=0.9 + )(a=[1, 3, 5], b=["two", 4, "six"], c=[7, 8, 9]) + + return grandparent_remote_wf -def test_lp_serialization(serialization_settings): - wf_spec = get_serializable(OrderedDict(), serialization_settings, grandparent_wf) + +@pytest.mark.parametrize( + "target", + [ + get_grandparent_wf, + get_grandparent_remote_wf, + ], +) +def test_lp_serialization(target, serialization_settings): + wf_spec = get_serializable(OrderedDict(), serialization_settings, target(serialization_settings)) assert len(wf_spec.template.nodes) == 1 top_level = wf_spec.template.nodes[0] @@ -56,7 +90,9 @@ def test_lp_serialization(serialization_settings): assert binding.scalar.primitive.integer is not None assert top_level.inputs[1].var == "b" for binding in top_level.inputs[1].binding.collection.bindings: - assert binding.scalar.union is not None + assert (binding.scalar.union is not None or + binding.scalar.primitive.integer is not None or + binding.scalar.primitive.string_value is not None) assert len(top_level.inputs[1].binding.collection.bindings) == 3 assert top_level.inputs[2].var == "c" assert len(top_level.inputs[2].binding.collection.bindings) == 3 diff --git a/tests/flytekit/unit/core/test_promise.py b/tests/flytekit/unit/core/test_promise.py index 455d53a5eb..59faefdc38 100644 --- a/tests/flytekit/unit/core/test_promise.py +++ b/tests/flytekit/unit/core/test_promise.py @@ -77,9 +77,17 @@ def t1() -> None: def t2(a: int) -> int: return a + ctx = context_manager.FlyteContext.current_context().with_compilation_state(CompilationState(prefix="")) p = create_and_link_node_from_remote(ctx, t2, a=3) assert p.ref.var == "o0" assert len(p.ref.node.bindings) == 1 + assert len(ctx.compilation_state.nodes) == 1 + + ctx = context_manager.FlyteContext.current_context().with_compilation_state(CompilationState(prefix="")) + p = create_and_link_node_from_remote(ctx, t2, add_node_to_compilation_state=False, a=3) + assert p.ref.var == "o0" + assert len(p.ref.node.bindings) == 1 + assert len(ctx.compilation_state.nodes) == 0 def test_create_and_link_node_from_remote_ignore(): diff --git a/tests/flytekit/unit/models/test_interface.py b/tests/flytekit/unit/models/test_interface.py index 03f89f69e4..f3c3aa1100 100644 --- a/tests/flytekit/unit/models/test_interface.py +++ b/tests/flytekit/unit/models/test_interface.py @@ -12,6 +12,16 @@ def test_variable_type(literal_type): assert var == interface.Variable.from_flyte_idl(var.to_flyte_idl()) +@pytest.mark.parametrize("literal_type", LIST_OF_ALL_LITERAL_TYPES) +def test_variable_type_list(literal_type): + var = interface.Variable(type=literal_type, description="abc") + collection_var = interface.Variable( + type=types.LiteralType(collection_type=literal_type), + description="abc", + ) + assert collection_var == interface.Variable.from_flyte_idl(var.to_flyte_idl_list()) + + @pytest.mark.parametrize("literal_type", LIST_OF_ALL_LITERAL_TYPES) def test_typed_interface(literal_type): typed_interface = interface.TypedInterface( @@ -41,6 +51,16 @@ def test_typed_interface(literal_type): assert len(deserialized_typed_interface.inputs) == 1 assert len(deserialized_typed_interface.outputs) == 2 + deserialized_typed_interface_list = typed_interface.transform_interface_to_list() + assert deserialized_typed_interface_list.inputs["a"].type == types.LiteralType(collection_type=literal_type) + assert deserialized_typed_interface_list.outputs["b"].type == types.LiteralType(collection_type=literal_type) + assert deserialized_typed_interface_list.outputs["c"].type == types.LiteralType(collection_type=literal_type) + assert deserialized_typed_interface_list.inputs["a"].description == "description1" + assert deserialized_typed_interface_list.outputs["b"].description == "description2" + assert deserialized_typed_interface_list.outputs["c"].description == "description3" + assert len(deserialized_typed_interface_list.inputs) == 1 + assert len(deserialized_typed_interface_list.outputs) == 2 + def test_parameter(): v = interface.Variable(types.LiteralType(simple=types.SimpleType.BOOLEAN), "asdf asdf asdf")