Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Annotations wf inputs #245

Merged
merged 10 commits into from
Nov 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions flytekit/annotated/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from flytekit.annotated import interface as flyte_interface
from flytekit.annotated.context_manager import FlyteContext
from flytekit.annotated.promise import Promise, binding_from_python_std, create_task_output
from flytekit.common import constants as _common_constants
from flytekit.common.exceptions import user as _user_exceptions
from flytekit.common.nodes import SdkNode
from flytekit.common.promise import NodeOutput as _NodeOutput
Expand Down Expand Up @@ -46,9 +47,11 @@ def get_registerable_entity(self) -> SdkNode:
from flytekit.annotated.workflow import Workflow

if self._flyte_entity is None:
raise Exception("Node flyte entity none")
raise Exception(f"Node {self.id} has no flyte entity")

sdk_nodes = [n.get_registerable_entity() for n in self._upstream_nodes]
sdk_nodes = [
n.get_registerable_entity() for n in self._upstream_nodes if n.id != _common_constants.GLOBAL_INPUT_NODE_ID
]

if isinstance(self._flyte_entity, PythonTask):
self._sdk_node = SdkNode(
Expand Down Expand Up @@ -153,7 +156,13 @@ def create_and_link_node(
# Detect upstream nodes
# These will be our annotated Nodes until we can amend the Promise to use NodeOutputs that reference our Nodes
upstream_nodes = list(
set([input_val.ref.sdk_node for input_val in kwargs.values() if isinstance(input_val, Promise)])
set(
[
input_val.ref.sdk_node
for input_val in kwargs.values()
if isinstance(input_val, Promise) and input_val.ref.node_id != _common_constants.GLOBAL_INPUT_NODE_ID
]
)
)

node_metadata = _workflow_model.NodeMetadata(
Expand Down
8 changes: 1 addition & 7 deletions flytekit/annotated/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,12 +389,8 @@ def binding_data_from_python_std(
t_value: typing.Any,
t_value_type: type,
) -> _literals_models.BindingData:
# This handles the case where the incoming value is a workflow-level input
if isinstance(t_value, _type_models.OutputReference):
return _literals_models.BindingData(promise=t_value)

# This handles the case where the given value is the output of another task
elif isinstance(t_value, Promise):
if isinstance(t_value, Promise):
if not t_value.is_ready:
return _literals_models.BindingData(promise=t_value.ref)

Expand Down Expand Up @@ -434,8 +430,6 @@ def binding_data_from_python_std(
return _literals_models.BindingData(map=m)

# This is the scalar case - e.g. my_task(in1=5)
# Question: Haytham/Ketan - Is it okay for me to rely on the expected idl type, which comes from the task's
# interface, to derive the scalar value?
scalar = TypeEngine.to_literal(ctx, t_value, t_value_type, expected_literal_type).scalar
return _literals_models.BindingData(scalar=scalar)

Expand Down
33 changes: 19 additions & 14 deletions flytekit/annotated/workflow.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import inspect
import typing
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
Expand All @@ -11,18 +13,22 @@
transform_interface_to_typed_interface,
transform_signature_to_interface,
)
from flytekit.annotated.node import create_and_link_node
from flytekit.annotated.node import Node, create_and_link_node
from flytekit.annotated.promise import Promise, create_task_output
from flytekit.annotated.type_engine import TypeEngine
from flytekit.common import constants as _common_constants
from flytekit.common.promise import NodeOutput as _NodeOutput
from flytekit.common.workflow import SdkWorkflow as _SdkWorkflow
from flytekit.loggers import logger
from flytekit.models import interface as _interface_models
from flytekit.models import literals as _literal_models
from flytekit.models import types as _type_models
from flytekit.models.core import identifier as _identifier_model
from flytekit.models.core import workflow as _workflow_model

# Placeholder node that carries the global input node id; it has no flyte entity, bindings, or upstream nodes.
# construct_input_promises uses it as the sdk_node behind every workflow-level input promise.
GLOBAL_START_NODE = Node(
    id=_common_constants.GLOBAL_INPUT_NODE_ID,
    metadata=None,
    bindings=[],
    upstream_nodes=[],
    flyte_entity=None,
)


def _workflow_fn_outputs_to_promise(
ctx: FlyteContext,
Expand Down Expand Up @@ -64,6 +70,15 @@ def _workflow_fn_outputs_to_promise(
return return_vals


def construct_input_promises(workflow: Workflow, inputs: List[str]):
    """
    Build one Promise per workflow input name, each wrapping a NodeOutput that points at GLOBAL_START_NODE.

    :param workflow: Not read by this function.
    :param inputs: Names of the workflow inputs to create promises for.
    :return: Dict mapping each input name to the Promise created for it.
    """
    promises = {}
    for name in inputs:
        # Every workflow-level input is modelled as an output variable of the global start node.
        start_node_output = _NodeOutput(sdk_node=GLOBAL_START_NODE, sdk_type=None, var=name)
        promises[name] = Promise(var=name, val=start_node_output)
    return promises


class Workflow(object):
"""
When you assign a name to a node.
Expand Down Expand Up @@ -107,16 +122,6 @@ def short_name(self) -> str:
def interface(self) -> _interface_models.TypedInterface:
return self._interface

def _construct_input_promises(self) -> Dict[str, _type_models.OutputReference]:
"""
This constructs input promises for all the inputs of the workflow, binding them to the global
input node id which you should think about as the start node.
"""
return {
k: _type_models.OutputReference(_common_constants.GLOBAL_INPUT_NODE_ID, k)
for k in self.interface.inputs.keys()
}

def compile(self, **kwargs):
"""
Supply static Python native values in the kwargs if you want them to be used in the compilation. This mimics
Expand All @@ -128,7 +133,7 @@ def compile(self, **kwargs):
prefix = f"{ctx.compilation_state.prefix}-{self.short_name}-" if ctx.compilation_state is not None else None
with ctx.new_compilation_context(prefix=prefix) as comp_ctx:
# Construct the default input promise bindings, but then override with the provided inputs, if any
input_kwargs = self._construct_input_promises()
input_kwargs = construct_input_promises(self, [k for k in self.interface.inputs.keys()])
input_kwargs.update(kwargs)
workflow_outputs = self._workflow_function(**input_kwargs)
all_nodes.extend(comp_ctx.compilation_state.nodes)
Expand Down Expand Up @@ -265,7 +270,7 @@ def get_registerable_entity(self) -> _SdkWorkflow:
)

# Translate nodes
sdk_nodes = [n.get_registerable_entity() for n in self._nodes]
sdk_nodes = [n.get_registerable_entity() for n in self._nodes if n.id != _common_constants.GLOBAL_INPUT_NODE_ID]

self._registerable_entity = _SdkWorkflow(
nodes=sdk_nodes,
Expand Down
2 changes: 1 addition & 1 deletion flytekit/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def _dnsify(value): # type: (str) -> str
res += "-"
res += ch.lower()

if res[-1] == "-":
if len(res) > 0 and res[-1] == "-":
res = res[: len(res) - 1]

return res
Expand Down