diff --git a/flytekit/annotated/node.py b/flytekit/annotated/node.py index dad5ff6c5f..982234e3ed 100644 --- a/flytekit/annotated/node.py +++ b/flytekit/annotated/node.py @@ -6,6 +6,7 @@ from flytekit.annotated import interface as flyte_interface from flytekit.annotated.context_manager import FlyteContext from flytekit.annotated.promise import Promise, binding_from_python_std, create_task_output +from flytekit.common import constants as _common_constants from flytekit.common.exceptions import user as _user_exceptions from flytekit.common.nodes import SdkNode from flytekit.common.promise import NodeOutput as _NodeOutput @@ -46,9 +47,11 @@ def get_registerable_entity(self) -> SdkNode: from flytekit.annotated.workflow import Workflow if self._flyte_entity is None: - raise Exception("Node flyte entity none") + raise Exception(f"Node {self.id} has no flyte entity") - sdk_nodes = [n.get_registerable_entity() for n in self._upstream_nodes] + sdk_nodes = [ + n.get_registerable_entity() for n in self._upstream_nodes if n.id != _common_constants.GLOBAL_INPUT_NODE_ID + ] if isinstance(self._flyte_entity, PythonTask): self._sdk_node = SdkNode( @@ -153,7 +156,13 @@ def create_and_link_node( # Detect upstream nodes # These will be our annotated Nodes until we can amend the Promise to use NodeOutputs that reference our Nodes upstream_nodes = list( - set([input_val.ref.sdk_node for input_val in kwargs.values() if isinstance(input_val, Promise)]) + set( + [ + input_val.ref.sdk_node + for input_val in kwargs.values() + if isinstance(input_val, Promise) and input_val.ref.node_id != _common_constants.GLOBAL_INPUT_NODE_ID + ] + ) ) node_metadata = _workflow_model.NodeMetadata( diff --git a/flytekit/annotated/promise.py b/flytekit/annotated/promise.py index b87d4fbb7d..fa01f5980a 100644 --- a/flytekit/annotated/promise.py +++ b/flytekit/annotated/promise.py @@ -389,12 +389,8 @@ def binding_data_from_python_std( t_value: typing.Any, t_value_type: type, ) -> _literals_models.BindingData: - # This handles the case where the incoming value is a workflow-level input - if isinstance(t_value, _type_models.OutputReference): - return _literals_models.BindingData(promise=t_value) - # This handles the case where the given value is the output of another task - elif isinstance(t_value, Promise): + if isinstance(t_value, Promise): if not t_value.is_ready: return _literals_models.BindingData(promise=t_value.ref) @@ -434,8 +430,6 @@ def binding_data_from_python_std( return _literals_models.BindingData(map=m) # This is the scalar case - e.g. my_task(in1=5) - # Question: Haytham/Ketan - Is it okay for me to rely on the expected idl type, which comes from the task's - # interface, to derive the scalar value? scalar = TypeEngine.to_literal(ctx, t_value, t_value_type, expected_literal_type).scalar return _literals_models.BindingData(scalar=scalar) diff --git a/flytekit/annotated/workflow.py b/flytekit/annotated/workflow.py index 91e7aa2ce6..09483c7278 100644 --- a/flytekit/annotated/workflow.py +++ b/flytekit/annotated/workflow.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import inspect import typing from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -11,18 +13,22 @@ transform_interface_to_typed_interface, transform_signature_to_interface, ) -from flytekit.annotated.node import create_and_link_node +from flytekit.annotated.node import Node, create_and_link_node from flytekit.annotated.promise import Promise, create_task_output from flytekit.annotated.type_engine import TypeEngine from flytekit.common import constants as _common_constants +from flytekit.common.promise import NodeOutput as _NodeOutput from flytekit.common.workflow import SdkWorkflow as _SdkWorkflow from flytekit.loggers import logger from flytekit.models import interface as _interface_models from flytekit.models import literals as _literal_models -from flytekit.models import types as _type_models from flytekit.models.core import identifier as _identifier_model from flytekit.models.core import workflow as _workflow_model +GLOBAL_START_NODE = Node( + id=_common_constants.GLOBAL_INPUT_NODE_ID, metadata=None, bindings=[], upstream_nodes=[], flyte_entity=None, +) + def _workflow_fn_outputs_to_promise( ctx: FlyteContext, @@ -64,6 +70,15 @@ def _workflow_fn_outputs_to_promise( return return_vals +def construct_input_promises(workflow: Workflow, inputs: List[str]): + return { + input_name: Promise( + var=input_name, val=_NodeOutput(sdk_node=GLOBAL_START_NODE, sdk_type=None, var=input_name,), + ) + for input_name in inputs + } + + class Workflow(object): """ When you assign a name to a node. @@ -107,16 +122,6 @@ def short_name(self) -> str: def interface(self) -> _interface_models.TypedInterface: return self._interface - def _construct_input_promises(self) -> Dict[str, _type_models.OutputReference]: - """ - This constructs input promises for all the inputs of the workflow, binding them to the global - input node id which you should think about as the start node. - """ - return { - k: _type_models.OutputReference(_common_constants.GLOBAL_INPUT_NODE_ID, k) - for k in self.interface.inputs.keys() - } - def compile(self, **kwargs): """ Supply static Python native values in the kwargs if you want them to be used in the compilation. This mimics @@ -128,7 +133,7 @@ def compile(self, **kwargs): prefix = f"{ctx.compilation_state.prefix}-{self.short_name}-" if ctx.compilation_state is not None else None with ctx.new_compilation_context(prefix=prefix) as comp_ctx: # Construct the default input promise bindings, but then override with the provided inputs, if any - input_kwargs = self._construct_input_promises() + input_kwargs = construct_input_promises(self, [k for k in self.interface.inputs.keys()]) input_kwargs.update(kwargs) workflow_outputs = self._workflow_function(**input_kwargs) all_nodes.extend(comp_ctx.compilation_state.nodes) @@ -265,7 +270,7 @@ def get_registerable_entity(self) -> _SdkWorkflow: ) # Translate nodes - sdk_nodes = [n.get_registerable_entity() for n in self._nodes] + sdk_nodes = [n.get_registerable_entity() for n in self._nodes if n.id != _common_constants.GLOBAL_INPUT_NODE_ID] self._registerable_entity = _SdkWorkflow( nodes=sdk_nodes, diff --git a/flytekit/common/utils.py b/flytekit/common/utils.py index ac8216d075..a61ee42a37 100644 --- a/flytekit/common/utils.py +++ b/flytekit/common/utils.py @@ -43,7 +43,7 @@ def _dnsify(value): # type: (str) -> str res += "-" res += ch.lower() - if res[-1] == "-": + if len(res) > 0 and res[-1] == "-": res = res[: len(res) - 1] return res