Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Annotations separate out serializing #326

Merged
merged 25 commits into from
Jan 15, 2021
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 7 additions & 48 deletions flytekit/annotated/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,16 @@
RegistrationSettings,
)
from flytekit.annotated.interface import Interface, transform_interface_to_typed_interface
from flytekit.annotated.node import create_and_link_node
from flytekit.annotated.promise import Promise, VoidPromise, create_task_output, translate_inputs_to_literals
from flytekit.annotated.promise import (
Promise,
VoidPromise,
create_and_link_node,
create_task_output,
translate_inputs_to_literals,
)
from flytekit.annotated.type_engine import TypeEngine
from flytekit.common.exceptions import user as _user_exceptions
from flytekit.common.tasks.sdk_runnable import ExecutionParameters
from flytekit.common.tasks.task import SdkTask
from flytekit.loggers import logger
from flytekit.models import dynamic_job as _dynamic_job
from flytekit.models import interface as _interface_models
Expand Down Expand Up @@ -113,10 +117,6 @@ def __init__(
self._interface = interface
self._metadata = metadata if metadata else TaskMetadata()

# This will get populated only at registration time, when we retrieve the rest of the environment variables like
# project/domain/version/image and anything else we might need from the environment in the future.
self._registerable_entity: Optional[SdkTask] = None

FlyteEntities.entities.append(self)

@property
Expand Down Expand Up @@ -239,22 +239,6 @@ def __call__(self, *args, **kwargs):
def compile(self, ctx: FlyteContext, *args, **kwargs):
raise Exception("not implemented")

def get_task_structure(self) -> SdkTask:
settings = FlyteContext.current_context().registration_settings
tk = SdkTask(
type=self.task_type,
metadata=self.metadata.to_taskmetadata_model(),
interface=self.interface,
custom=self.get_custom(settings),
container=self.get_container(settings),
)
# Reset just to make sure it's what we give it
tk.id._project = settings.project
tk.id._domain = settings.domain
tk.id._name = self.name
tk.id._version = settings.version
return tk

def get_container(self, settings: RegistrationSettings) -> _task_model.Container:
return None

Expand Down Expand Up @@ -434,31 +418,6 @@ def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any:
"""
return rval

def get_registerable_entity(self) -> SdkTask:
if self._registerable_entity is not None:
return self._registerable_entity
self._registerable_entity = self.get_task_structure()
return self._registerable_entity

def get_fast_registerable_entity(self) -> SdkTask:
entity = self.get_registerable_entity()
if entity.container is None:
# Containerless tasks are always fast registerable without modification
return entity

args = [
"pyflyte-fast-execute",
"--additional-distribution",
"{{ .remote_package_path }}",
"--dest-dir",
"{{ .dest_dir }}",
"--",
] + entity.container.args[:]

del entity._container.args[:]
entity._container.args.extend(args)
return entity

@property
def environment(self) -> Dict[str, str]:
return self._environment
44 changes: 5 additions & 39 deletions flytekit/annotated/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,32 +11,17 @@
ComparisonOps,
ConjunctionExpression,
ConjunctionOps,
NodeOutput,
Promise,
VoidPromise,
create_task_output,
)
from flytekit.common.promise import NodeOutput
from flytekit.models.core import condition as _core_cond
from flytekit.models.core import workflow as _core_wf
from flytekit.models.literals import Binding, BindingData, Literal, RetryStrategy
from flytekit.models.types import Error


def to_registrable_case(c: _core_wf.IfBlock) -> _core_wf.IfBlock:
if c is None:
raise ValueError("Cannot convert none cases to registrable")
return _core_wf.IfBlock(condition=c.condition, then_node=c.then_node.get_registerable_entity())


def to_registrable_cases(cases: typing.List[_core_wf.IfBlock]) -> Optional[typing.List[_core_wf.IfBlock]]:
if cases is None:
return None
ret_cases = []
for c in cases:
ret_cases.append(to_registrable_case(c))
return ret_cases


class BranchNode(object):
def __init__(self, name: str, ifelse_block: _core_wf.IfElseBlock):
self._name = name
Expand All @@ -47,25 +32,6 @@ def __init__(self, name: str, ifelse_block: _core_wf.IfElseBlock):
def name(self):
return self._name

def get_branch_node(self) -> _core_wf.BranchNode:
# We have to iterate through the blocks to convert the nodes from their current type to SDKNode
# TODO this should be cleaned up instead of mutation, we probaby should just create a new object
first = to_registrable_case(self._ifelse_block.case)
other = to_registrable_cases(self._ifelse_block.other)
else_node = None
if self._ifelse_block.else_node:
else_node = self._ifelse_block.else_node.get_registerable_entity()

return _core_wf.BranchNode(
if_else=_core_wf.IfElseBlock(case=first, other=other, else_node=else_node, error=self._ifelse_block.error)
)

def get_registerable_entity(self) -> _core_wf.BranchNode:
if self._registerable_entity is not None:
return self._registerable_entity
self._registerable_entity = self.get_branch_node()
return self._registerable_entity


class ConditionalSection(object):
def __init__(self, name: str):
Expand Down Expand Up @@ -144,7 +110,7 @@ def end_branch(self) -> Union[Condition, Promise]:
for p in promises:
if not p.is_ready:
bindings.append(Binding(var=p.var, binding=BindingData(promise=p.ref)))
upstream_nodes.add(p.ref.sdk_node)
upstream_nodes.add(p.ref.node)

n = Node(
id=f"{ctx.compilation_state.prefix}node-{len(ctx.compilation_state.nodes)}",
Expand Down Expand Up @@ -175,7 +141,7 @@ def _compute_outputs(self, n: Node) -> Union[Promise, Tuple[Promise], VoidPromis
if len(output_var_sets) > 1:
for x in output_var_sets[1:]:
curr = curr.intersection(x)
promises = [Promise(var=x, val=NodeOutput(sdk_node=n, sdk_type=None, var=x)) for x in curr]
promises = [Promise(var=x, val=NodeOutput(node=n, sdk_type=None, var=x)) for x in curr]
wild-endeavor marked this conversation as resolved.
Show resolved Hide resolved
# TODO: Is there a way to add the Python interface here? Currently, it's an optional arg.
return create_task_output(promises)

Expand Down Expand Up @@ -341,7 +307,7 @@ def transform_to_boolexpr(

def to_case_block(c: Case) -> (Union[_core_wf.IfBlock], typing.List[Promise]):
expr, promises = transform_to_boolexpr(c.expr)
n = c.output_promise.ref.sdk_node
n = c.output_promise.ref.node
return _core_wf.IfBlock(condition=expr, then_node=n), promises


Expand All @@ -364,7 +330,7 @@ def to_ifelse_block(node_id: str, cs: ConditionalSection) -> (_core_wf.IfElseBlo
node = None
err = None
if last_case.output_promise is not None:
node = last_case.output_promise.ref.sdk_node
node = last_case.output_promise.ref.node
else:
err = Error(failed_node_id=node_id, message=last_case.err if last_case.err else "Condition failed")
return (
Expand Down
6 changes: 1 addition & 5 deletions flytekit/annotated/container_task.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from enum import Enum
from typing import Any, Dict, List, Optional, Type

from flytekit.annotated.base_task import PythonTask, SdkTask, TaskMetadata
from flytekit.annotated.base_task import PythonTask, TaskMetadata
from flytekit.annotated.context_manager import RegistrationSettings
from flytekit.annotated.interface import Interface
from flytekit.common.tasks.raw_container import _get_container_definition
Expand Down Expand Up @@ -76,7 +76,3 @@ def get_container(self, settings: RegistrationSettings) -> _task_model.Container
),
environment=env,
)

def get_fast_registerable_entity(self) -> SdkTask:
# Raw container tasks are always fast registerable as-is.
return self.get_registerable_entity()
3 changes: 2 additions & 1 deletion flytekit/annotated/dynamic_workflow_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from flytekit.annotated.python_function_task import PythonFunctionTask
from flytekit.annotated.task import TaskPlugins
from flytekit.annotated.workflow import Workflow, WorkflowFailurePolicy, WorkflowMetadata, WorkflowMetadataDefaults
from flytekit.common.translator import get_serializable
from flytekit.loggers import logger
from flytekit.models import dynamic_job as _dynamic_job
from flytekit.models import literals as _literal_models
Expand Down Expand Up @@ -45,7 +46,7 @@ def compile_into_workflow(
self._wf.compile(**kwargs)

wf = self._wf
sdk_workflow = wf.get_registerable_entity()
sdk_workflow = get_serializable(ctx.registration_settings, wf)

# If no nodes were produced, let's just return the strict outputs
if len(sdk_workflow.nodes) == 0:
Expand Down
73 changes: 1 addition & 72 deletions flytekit/annotated/launch_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,12 @@
from flytekit.annotated import workflow as _annotated_workflow
from flytekit.annotated.context_manager import FlyteContext, FlyteEntities
from flytekit.annotated.interface import Interface, transform_inputs_to_parameters
from flytekit.annotated.node import create_and_link_node
from flytekit.annotated.promise import translate_inputs_to_literals
from flytekit.annotated.promise import create_and_link_node, translate_inputs_to_literals
from flytekit.annotated.reference_entity import LaunchPlanReference, ReferenceEntity
from flytekit.common.launch_plan import SdkLaunchPlan
from flytekit.models import common as _common_models
from flytekit.models import interface as _interface_models
from flytekit.models import launch_plan as _launch_plan_models
from flytekit.models import literals as _literal_models
from flytekit.models import schedule as _schedule_model
from flytekit.models.common import RawOutputDataConfig
from flytekit.models.core import identifier as _identifier_model


class LaunchPlan(object):
Expand Down Expand Up @@ -130,9 +125,6 @@ def __init__(
self._raw_output_data_config = raw_output_data_config
self._auth_role = auth_role

# This will eventually hold the registerable launch plan
self._registerable_entity: Optional[SdkLaunchPlan] = None

FlyteEntities.entities.append(self)

@property
Expand Down Expand Up @@ -199,43 +191,6 @@ def __call__(self, *args, **kwargs):
inputs.update(kwargs)
return self.workflow(*args, **inputs)

def get_registerable_entity(self) -> SdkLaunchPlan:
settings = FlyteContext.current_context().registration_settings
if self._registerable_entity is not None:
return self._registerable_entity

if self._auth_role:
auth_role = self._auth_role
else:
auth_role = None

sdk_workflow = self.workflow.get_registerable_entity()
self._registerable_entity = SdkLaunchPlan(
workflow_id=sdk_workflow.id,
entity_metadata=_launch_plan_models.LaunchPlanMetadata(
schedule=self.schedule, notifications=self.notifications,
),
default_inputs=self.parameters,
fixed_inputs=self.fixed_inputs,
labels=self.labels or _common_models.Labels({}),
annotations=self.annotations or _common_models.Annotations({}),
auth_role=auth_role, # TODO: Is None here okay?
raw_output_data_config=self.raw_output_data_config,
)

# These two things are normally set to None in the SdkLaunchPlan constructor and filled in by
# SdkRunnableLaunchPlan/the registration process, so we need to set them manually. The reason is because these
# fields are not part of the underlying LaunchPlanSpec
self._registerable_entity._interface = sdk_workflow.interface
self._registerable_entity._id = _identifier_model.Identifier(
resource_type=_identifier_model.ResourceType.LAUNCH_PLAN,
project=settings.project,
domain=settings.domain,
name=self.name,
version=settings.version,
)
return self._registerable_entity


class ReferenceLaunchPlan(ReferenceEntity, LaunchPlan):
"""
Expand All @@ -248,29 +203,3 @@ def __init__(
self, project: str, domain: str, name: str, version: str, inputs: Dict[str, Type], outputs: Dict[str, Type]
):
super().__init__(LaunchPlanReference(project, domain, name, version), inputs, outputs)

def get_registerable_entity(self) -> SdkLaunchPlan:
from flytekit.common.interface import TypedInterface

wf_id = _identifier_model.Identifier(_identifier_model.ResourceType.WORKFLOW, "", "", "", "")
sdk_lp = SdkLaunchPlan(
workflow_id=wf_id,
entity_metadata=_launch_plan_models.LaunchPlanMetadata(schedule=None, notifications=[]),
default_inputs=_interface_models.ParameterMap({}),
fixed_inputs=_literal_models.LiteralMap({}),
labels=_common_models.Labels({}),
annotations=_common_models.Annotations({}),
auth_role=_common_models.AuthRole(assumable_iam_role="fake:role"),
raw_output_data_config=RawOutputDataConfig(""),
)
# Because of how SdkNodes work, it needs one of these interfaces
# Hopefully this is more trickery that can be cleaned up in the future
sdk_lp._interface = TypedInterface.promote_from_model(self.typed_interface)
sdk_lp._id = self.id

# Make sure we don't serialize this
sdk_lp._has_registered = True
sdk_lp.assign_name(self.reference.id.name)
self._registerable_entity = sdk_lp

return self._registerable_entity
Loading