Skip to content

Commit

Permalink
FlyteRemote fetch of conditional nodes (#772)
Browse files Browse the repository at this point in the history
Signed-off-by: Yee Hing Tong <wild-endeavor@users.noreply.github.com>
  • Loading branch information
wild-endeavor authored Mar 3, 2022
1 parent 0da523c commit 8a5d848
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 23 deletions.
1 change: 0 additions & 1 deletion flytekit/clis/flyte_cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,6 @@ def _render_schedule_expr(lp):
)
_insecure_option = _click.option(*_INSECURE_FLAGS, is_flag=True, help="Do not use SSL")
_urn_option = _click.option("-u", "--urn", required=True, help="The unique identifier for an entity.")

_optional_urn_option = _click.option("-u", "--urn", required=False, help="The unique identifier for an entity.")

_host_option = _click.option(
Expand Down
2 changes: 1 addition & 1 deletion flytekit/models/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def else_node(self):
def error(self):
"""
An error to throw in case none of the branches were taken.
:rtype: flytekit.models.core.errors.ContainerError
:rtype: flytekit.models.types.Error
"""

return self._error
Expand Down
35 changes: 34 additions & 1 deletion flytekit/remote/component_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def promote_from_model(

if base_model.reference_id in tasks:
task = tasks[base_model.reference_id]
_logging.info(f"Found existing task template for {task.id}, will not retrieve from Admin")
_logging.debug(f"Found existing task template for {task.id}, will not retrieve from Admin")
flyte_task = FlyteTask.promote_from_model(task)
return cls(flyte_task)

Expand Down Expand Up @@ -124,3 +124,36 @@ def promote_from_model(
raise _system_exceptions.FlyteSystemException(
"Bad workflow node model, neither subworkflow nor launchplan specified."
)


class FlyteBranchNode(_workflow_model.BranchNode):
def __init__(self, if_else: _workflow_model.IfElseBlock):
super().__init__(if_else)

@classmethod
def promote_from_model(
cls,
base_model: _workflow_model.BranchNode,
sub_workflows: Dict[id_models.Identifier, _workflow_model.WorkflowTemplate],
node_launch_plans: Dict[id_models.Identifier, _launch_plan_model.LaunchPlanSpec],
tasks: Dict[id_models.Identifier, _task_model.TaskTemplate],
) -> "FlyteBranchNode":

from flytekit.remote.nodes import FlyteNode

block = base_model.if_else

else_node = None
if block.else_node:
else_node = FlyteNode.promote_from_model(block.else_node, sub_workflows, node_launch_plans, tasks)

block.case._then_node = FlyteNode.promote_from_model(
block.case.then_node, sub_workflows, node_launch_plans, tasks
)

for o in block.other:
o._then_node = FlyteNode.promote_from_model(o.then_node, sub_workflows, node_launch_plans, tasks)

new_if_else_block = _workflow_model.IfElseBlock(block.case, block.other, else_node, block.error)

return cls(new_if_else_block)
7 changes: 7 additions & 0 deletions flytekit/remote/launch_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@ def __init__(self, id, *args, **kwargs):
def name(self) -> str:
return self._name

# If fetched when creating this object, can store it here.
self._flyte_workflow = None

@property
def flyte_workflow(self) -> Optional["FlyteWorkflow"]:
return self._flyte_workflow

@classmethod
def promote_from_model(
cls, id: id_models.Identifier, model: _launch_plan_models.LaunchPlanSpec
Expand Down
28 changes: 21 additions & 7 deletions flytekit/remote/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,19 @@ def __init__(
flyte_task: Optional["FlyteTask"] = None,
flyte_workflow: Optional["FlyteWorkflow"] = None,
flyte_launch_plan: Optional["FlyteLaunchPlan"] = None,
flyte_branch=None,
flyte_branch_node: Optional["FlyteBranchNode"] = None,
):
non_none_entities = list(filter(None, [flyte_task, flyte_workflow, flyte_launch_plan, flyte_branch]))
# todo: flyte_branch_node is the only non-entity here, feels wrong, it should probably be a Condition
# or the other ones changed.
non_none_entities = list(filter(None, [flyte_task, flyte_workflow, flyte_launch_plan, flyte_branch_node]))
if len(non_none_entities) != 1:
raise _user_exceptions.FlyteAssertion(
"An Flyte node must have one underlying entity specified at once. Received the following "
"entities: {}".format(non_none_entities)
)
self._flyte_entity = flyte_task or flyte_workflow or flyte_launch_plan or flyte_branch
# todo: wip - flyte_branch_node is a hack, it should be a Condition, but backing out a Condition object from
# the compiled IfElseBlock is cumbersome, shouldn't do it if we can get away with it.
self._flyte_entity = flyte_task or flyte_workflow or flyte_launch_plan or flyte_branch_node

workflow_node = None
if flyte_workflow is not None:
Expand All @@ -46,7 +50,6 @@ def __init__(
task_node = None
if flyte_task:
task_node = _component_nodes.FlyteTaskNode(flyte_task)
branch_node = None

super(FlyteNode, self).__init__(
id=id,
Expand All @@ -56,7 +59,7 @@ def __init__(
output_aliases=[],
task_node=task_node,
workflow_node=workflow_node,
branch_node=branch_node,
branch_node=flyte_branch_node,
)
self._upstream = upstream_nodes

Expand All @@ -78,7 +81,7 @@ def promote_from_model(
_logging.warning(f"Should not call promote from model on a start node or end node {model}")
return None

flyte_task_node, flyte_workflow_node = None, None
flyte_task_node, flyte_workflow_node, flyte_branch_node = None, None, None
if model.task_node is not None:
flyte_task_node = _component_nodes.FlyteTaskNode.promote_from_model(model.task_node, tasks)
elif model.workflow_node is not None:
Expand All @@ -88,7 +91,10 @@ def promote_from_model(
node_launch_plans,
tasks,
)
# TODO: Implement branch node https://github.com/flyteorg/flyte/issues/1116
elif model.branch_node is not None:
flyte_branch_node = _component_nodes.FlyteBranchNode.promote_from_model(
model.branch_node, sub_workflows, node_launch_plans, tasks
)
else:
raise _system_exceptions.FlyteSystemException(
f"Bad Node model, neither task nor workflow detected, node: {model}"
Expand Down Expand Up @@ -132,6 +138,14 @@ def promote_from_model(
raise _system_exceptions.FlyteSystemException(
"Bad FlyteWorkflowNode model, both launch plan and workflow are None"
)
elif flyte_branch_node is not None:
return cls(
id=node_model_id,
upstream_nodes=[], # set downstream, model doesn't contain this information
bindings=model.inputs,
metadata=model.metadata,
flyte_branch_node=flyte_branch_node,
)
raise _system_exceptions.FlyteSystemException("Bad FlyteNode model, both task and workflow nodes are empty")

@property
Expand Down
34 changes: 21 additions & 13 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,7 @@ def fetch_launch_plan(
wf_id = flyte_launch_plan.workflow_id
workflow = self.fetch_workflow(wf_id.project, wf_id.domain, wf_id.name, wf_id.version)
flyte_launch_plan._interface = workflow.interface
flyte_launch_plan._flyte_workflow = workflow
flyte_launch_plan.guessed_python_interface = Interface(
inputs=TypeEngine.guess_python_types(flyte_launch_plan.interface.inputs),
outputs=TypeEngine.guess_python_types(flyte_launch_plan.interface.outputs),
Expand Down Expand Up @@ -1053,6 +1054,7 @@ def sync_workflow_execution(
if execution.spec.launch_plan.resource_type == ResourceType.TASK:
# This condition is only true for single-task executions
flyte_entity = self.fetch_task(lp_id.project, lp_id.domain, lp_id.name, lp_id.version)
node_interface = flyte_entity.interface
if sync_nodes:
# Need to construct the mapping. There should've been returned exactly three nodes, a start,
# an end, and a task node.
Expand Down Expand Up @@ -1080,21 +1082,23 @@ def sync_workflow_execution(
)
else:
# This is the default case, an execution of a normal workflow through a launch plan
wf_id = self.fetch_launch_plan(lp_id.project, lp_id.domain, lp_id.name, lp_id.version).workflow_id
flyte_entity = self.fetch_workflow(wf_id.project, wf_id.domain, wf_id.name, wf_id.version)
execution._flyte_workflow = flyte_entity
node_mapping = flyte_entity._node_map
fetched_lp = self.fetch_launch_plan(lp_id.project, lp_id.domain, lp_id.name, lp_id.version)
node_interface = fetched_lp.flyte_workflow.interface
execution._flyte_workflow = fetched_lp.flyte_workflow
node_mapping = fetched_lp.flyte_workflow._node_map

# update node executions (if requested), and inputs/outputs
if sync_nodes:
node_execs = {}
for n in underlying_node_executions:
node_execs[n.id.node_id] = self.sync_node_execution(n, node_mapping)
execution._node_executions = node_execs
return self._assign_inputs_and_outputs(execution, execution_data, flyte_entity.interface)
return self._assign_inputs_and_outputs(execution, execution_data, node_interface)

def sync_node_execution(
self, execution: FlyteNodeExecution, node_mapping: typing.Dict[str, FlyteNode]
self,
execution: FlyteNodeExecution,
node_mapping: typing.Dict[str, FlyteNode],
) -> FlyteNodeExecution:
"""
Get data backing a node execution. These FlyteNodeExecution objects should've come from Admin with the model
Expand Down Expand Up @@ -1201,13 +1205,9 @@ def sync_node_execution(
for t in iterate_task_executions(self.client, execution.id)
]
execution._interface = dynamic_flyte_wf.interface
else:
# If it does not, then it should be a static subworkflow
if not isinstance(execution._node.flyte_entity, FlyteWorkflow):
remote_logger.error(
f"NE {execution} entity should be a workflow, {type(execution._node)}, {execution._node}"
)
raise Exception(f"Node entity has type {type(execution._node)}")

# Handle the case where it's a static subworkflow
elif isinstance(execution._node.flyte_entity, FlyteWorkflow):
sub_flyte_workflow = execution._node.flyte_entity
sub_node_mapping = {n.id: n for n in sub_flyte_workflow.flyte_nodes}
execution._underlying_node_executions = [
Expand All @@ -1216,6 +1216,14 @@ def sync_node_execution(
]
execution._interface = sub_flyte_workflow.interface

# Handle the case where it's a branch node
elif execution._node.branch_node is not None:
remote_logger.debug("Skipping remote node execution for now")
return execution
else:
remote_logger.error(f"NE {execution} undeterminable, {type(execution._node)}, {execution._node}")
raise Exception(f"Node execution undeterminable, entity has type {type(execution._node)}")

# This is the plain ol' task execution case
else:
execution._task_executions = [
Expand Down

0 comments on commit 8a5d848

Please sign in to comment.