Skip to content

Commit

Permalink
Hydrate subworkflow template identifier (#315)
Browse files Browse the repository at this point in the history
  • Loading branch information
katrogan authored Jan 6, 2021
1 parent 7349158 commit 6658422
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 6 deletions.
7 changes: 4 additions & 3 deletions flytekit/clis/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def _hydrate_node(project: str, domain: str, version: str, node: _workflow_pb2.N
return node


def _hydrate_workflow_template(
def _hydrate_workflow_template_nodes(
project: str, domain: str, version: str, template: _workflow_pb2.WorkflowTemplate
) -> _workflow_pb2.WorkflowTemplate:
refreshed_nodes = []
Expand Down Expand Up @@ -169,10 +169,11 @@ def hydrate_registration_parameters(
# Workflow nodes that are defined inline with the workflows will be missing project/domain/version so we fill those
# in now.
# (entity is of type flyteidl.admin.workflow_pb2.WorkflowSpec)
entity.template.CopyFrom(_hydrate_workflow_template(project, domain, version, entity.template))
entity.template.CopyFrom(_hydrate_workflow_template_nodes(project, domain, version, entity.template))
refreshed_sub_workflows = []
for sub_workflow in entity.sub_workflows:
refreshed_sub_workflow = _hydrate_workflow_template(project, domain, version, sub_workflow)
refreshed_sub_workflow = _hydrate_workflow_template_nodes(project, domain, version, sub_workflow)
refreshed_sub_workflow.id.CopyFrom(_hydrate_identifier(project, domain, version, refreshed_sub_workflow.id))
refreshed_sub_workflows.append(refreshed_sub_workflow)
# Reassign subworkflows with the newly hydrated ones.
del entity.sub_workflows[:]
Expand Down
45 changes: 42 additions & 3 deletions tests/flytekit/unit/cli/test_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from flyteidl.core import workflow_pb2 as _core_workflow_pb2

from flytekit.clis import helpers
from flytekit.clis.helpers import _hydrate_identifier, _hydrate_workflow_template, hydrate_registration_parameters
from flytekit.clis.helpers import _hydrate_identifier, _hydrate_workflow_template_nodes, hydrate_registration_parameters
from flytekit.models import literals, types
from flytekit.models.interface import Parameter, ParameterMap, Variable

Expand Down Expand Up @@ -121,7 +121,7 @@ def test_hydrate_workflow_template():
),
)
)
hydrated_workflow_template = _hydrate_workflow_template("project", "domain", "12345", workflow_template)
hydrated_workflow_template = _hydrate_workflow_template_nodes("project", "domain", "12345", workflow_template)
assert len(hydrated_workflow_template.nodes) == 4
task_node_identifier = hydrated_workflow_template.nodes[0].task_node.reference_id
assert task_node_identifier.project == "project"
Expand Down Expand Up @@ -185,7 +185,7 @@ def test_hydrate_workflow_template__branch_node():
]
)
workflow_template.nodes.append(branch_node)
hydrated_workflow_template = _hydrate_workflow_template("project", "domain", "12345", workflow_template)
hydrated_workflow_template = _hydrate_workflow_template_nodes("project", "domain", "12345", workflow_template)
if_case_id = hydrated_workflow_template.nodes[0].branch_node.if_else.case.then_node.task_node.reference_id
assert if_case_id.project == "project"
assert if_case_id.domain == "domain"
Expand Down Expand Up @@ -353,3 +353,42 @@ def test_hydrate_registration_parameters__workflow_nothing_set():
assert workflow.template.nodes[0].task_node.reference_id == _identifier_pb2.Identifier(
resource_type=_identifier_pb2.TASK, project="project", domain="domain", name="task1", version="12345",
)


def test_hydrate_registration_parameters__subworkflows():
workflow_template = _core_workflow_pb2.WorkflowTemplate()
workflow_template.id.CopyFrom(_identifier_pb2.Identifier(resource_type=_identifier_pb2.WORKFLOW, name="workflow"))

sub_workflow_template = _core_workflow_pb2.WorkflowTemplate()
sub_workflow_template.id.CopyFrom(
_identifier_pb2.Identifier(resource_type=_identifier_pb2.WORKFLOW, name="subworkflow")
)
sub_workflow_template.nodes.append(
_core_workflow_pb2.Node(
id="task_node",
task_node=_core_workflow_pb2.TaskNode(
reference_id=_identifier_pb2.Identifier(resource_type=_identifier_pb2.TASK)
),
)
)
workflow_spec = _workflow_pb2.WorkflowSpec(template=workflow_template)
workflow_spec.sub_workflows.append(sub_workflow_template)

identifier, entity = hydrate_registration_parameters(
workflow_template.id, "project", "domain", "12345", workflow_spec
)
assert (
identifier
== _identifier_pb2.Identifier(
resource_type=_identifier_pb2.WORKFLOW,
project="project",
domain="domain",
name="workflow",
version="12345",
)
== entity.template.id
)

assert entity.sub_workflows[0].id == _identifier_pb2.Identifier(
resource_type=_identifier_pb2.WORKFLOW, project="project", domain="domain", name="subworkflow", version="12345",
)

0 comments on commit 6658422

Please sign in to comment.