Skip to content

Commit

Permalink
Annotations separate out serializing (#326)
Browse files Browse the repository at this point in the history
  • Loading branch information
wild-endeavor authored Jan 15, 2021
1 parent 531c059 commit 452f682
Show file tree
Hide file tree
Showing 22 changed files with 781 additions and 646 deletions.
55 changes: 7 additions & 48 deletions flytekit/annotated/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,16 @@
RegistrationSettings,
)
from flytekit.annotated.interface import Interface, transform_interface_to_typed_interface
from flytekit.annotated.node import create_and_link_node
from flytekit.annotated.promise import Promise, VoidPromise, create_task_output, translate_inputs_to_literals
from flytekit.annotated.promise import (
Promise,
VoidPromise,
create_and_link_node,
create_task_output,
translate_inputs_to_literals,
)
from flytekit.annotated.type_engine import TypeEngine
from flytekit.common.exceptions import user as _user_exceptions
from flytekit.common.tasks.sdk_runnable import ExecutionParameters
from flytekit.common.tasks.task import SdkTask
from flytekit.loggers import logger
from flytekit.models import dynamic_job as _dynamic_job
from flytekit.models import interface as _interface_models
Expand Down Expand Up @@ -113,10 +117,6 @@ def __init__(
self._interface = interface
self._metadata = metadata if metadata else TaskMetadata()

# This will get populated only at registration time, when we retrieve the rest of the environment variables like
# project/domain/version/image and anything else we might need from the environment in the future.
self._registerable_entity: Optional[SdkTask] = None

FlyteEntities.entities.append(self)

@property
Expand Down Expand Up @@ -239,22 +239,6 @@ def __call__(self, *args, **kwargs):
def compile(self, ctx: FlyteContext, *args, **kwargs):
raise Exception("not implemented")

def get_task_structure(self) -> SdkTask:
settings = FlyteContext.current_context().registration_settings
tk = SdkTask(
type=self.task_type,
metadata=self.metadata.to_taskmetadata_model(),
interface=self.interface,
custom=self.get_custom(settings),
container=self.get_container(settings),
)
# Reset just to make sure it's what we give it
tk.id._project = settings.project
tk.id._domain = settings.domain
tk.id._name = self.name
tk.id._version = settings.version
return tk

def get_container(self, settings: RegistrationSettings) -> _task_model.Container:
return None

Expand Down Expand Up @@ -434,31 +418,6 @@ def post_execute(self, user_params: ExecutionParameters, rval: Any) -> Any:
"""
return rval

def get_registerable_entity(self) -> SdkTask:
if self._registerable_entity is not None:
return self._registerable_entity
self._registerable_entity = self.get_task_structure()
return self._registerable_entity

def get_fast_registerable_entity(self) -> SdkTask:
entity = self.get_registerable_entity()
if entity.container is None:
# Containerless tasks are always fast registerable without modification
return entity

args = [
"pyflyte-fast-execute",
"--additional-distribution",
"{{ .remote_package_path }}",
"--dest-dir",
"{{ .dest_dir }}",
"--",
] + entity.container.args[:]

del entity._container.args[:]
entity._container.args.extend(args)
return entity

@property
def environment(self) -> Dict[str, str]:
return self._environment
44 changes: 5 additions & 39 deletions flytekit/annotated/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,32 +11,17 @@
ComparisonOps,
ConjunctionExpression,
ConjunctionOps,
NodeOutput,
Promise,
VoidPromise,
create_task_output,
)
from flytekit.common.promise import NodeOutput
from flytekit.models.core import condition as _core_cond
from flytekit.models.core import workflow as _core_wf
from flytekit.models.literals import Binding, BindingData, Literal, RetryStrategy
from flytekit.models.types import Error


def to_registrable_case(c: _core_wf.IfBlock) -> _core_wf.IfBlock:
if c is None:
raise ValueError("Cannot convert none cases to registrable")
return _core_wf.IfBlock(condition=c.condition, then_node=c.then_node.get_registerable_entity())


def to_registrable_cases(cases: typing.List[_core_wf.IfBlock]) -> Optional[typing.List[_core_wf.IfBlock]]:
if cases is None:
return None
ret_cases = []
for c in cases:
ret_cases.append(to_registrable_case(c))
return ret_cases


class BranchNode(object):
def __init__(self, name: str, ifelse_block: _core_wf.IfElseBlock):
self._name = name
Expand All @@ -47,25 +32,6 @@ def __init__(self, name: str, ifelse_block: _core_wf.IfElseBlock):
def name(self):
return self._name

def get_branch_node(self) -> _core_wf.BranchNode:
# We have to iterate through the blocks to convert the nodes from their current type to SDKNode
# TODO this should be cleaned up instead of mutation, we probaby should just create a new object
first = to_registrable_case(self._ifelse_block.case)
other = to_registrable_cases(self._ifelse_block.other)
else_node = None
if self._ifelse_block.else_node:
else_node = self._ifelse_block.else_node.get_registerable_entity()

return _core_wf.BranchNode(
if_else=_core_wf.IfElseBlock(case=first, other=other, else_node=else_node, error=self._ifelse_block.error)
)

def get_registerable_entity(self) -> _core_wf.BranchNode:
if self._registerable_entity is not None:
return self._registerable_entity
self._registerable_entity = self.get_branch_node()
return self._registerable_entity


class ConditionalSection(object):
def __init__(self, name: str):
Expand Down Expand Up @@ -144,7 +110,7 @@ def end_branch(self) -> Union[Condition, Promise]:
for p in promises:
if not p.is_ready:
bindings.append(Binding(var=p.var, binding=BindingData(promise=p.ref)))
upstream_nodes.add(p.ref.sdk_node)
upstream_nodes.add(p.ref.node)

n = Node(
id=f"{ctx.compilation_state.prefix}node-{len(ctx.compilation_state.nodes)}",
Expand Down Expand Up @@ -175,7 +141,7 @@ def _compute_outputs(self, n: Node) -> Union[Promise, Tuple[Promise], VoidPromis
if len(output_var_sets) > 1:
for x in output_var_sets[1:]:
curr = curr.intersection(x)
promises = [Promise(var=x, val=NodeOutput(sdk_node=n, sdk_type=None, var=x)) for x in curr]
promises = [Promise(var=x, val=NodeOutput(node=n, var=x)) for x in curr]
# TODO: Is there a way to add the Python interface here? Currently, it's an optional arg.
return create_task_output(promises)

Expand Down Expand Up @@ -341,7 +307,7 @@ def transform_to_boolexpr(

def to_case_block(c: Case) -> (Union[_core_wf.IfBlock], typing.List[Promise]):
expr, promises = transform_to_boolexpr(c.expr)
n = c.output_promise.ref.sdk_node
n = c.output_promise.ref.node
return _core_wf.IfBlock(condition=expr, then_node=n), promises


Expand All @@ -364,7 +330,7 @@ def to_ifelse_block(node_id: str, cs: ConditionalSection) -> (_core_wf.IfElseBlo
node = None
err = None
if last_case.output_promise is not None:
node = last_case.output_promise.ref.sdk_node
node = last_case.output_promise.ref.node
else:
err = Error(failed_node_id=node_id, message=last_case.err if last_case.err else "Condition failed")
return (
Expand Down
6 changes: 1 addition & 5 deletions flytekit/annotated/container_task.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from enum import Enum
from typing import Any, Dict, List, Optional, Type

from flytekit.annotated.base_task import PythonTask, SdkTask, TaskMetadata
from flytekit.annotated.base_task import PythonTask, TaskMetadata
from flytekit.annotated.context_manager import RegistrationSettings
from flytekit.annotated.interface import Interface
from flytekit.common.tasks.raw_container import _get_container_definition
Expand Down Expand Up @@ -76,7 +76,3 @@ def get_container(self, settings: RegistrationSettings) -> _task_model.Container
),
environment=env,
)

def get_fast_registerable_entity(self) -> SdkTask:
# Raw container tasks are always fast registerable as-is.
return self.get_registerable_entity()
73 changes: 1 addition & 72 deletions flytekit/annotated/launch_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,12 @@
from flytekit.annotated import workflow as _annotated_workflow
from flytekit.annotated.context_manager import FlyteContext, FlyteEntities
from flytekit.annotated.interface import Interface, transform_inputs_to_parameters
from flytekit.annotated.node import create_and_link_node
from flytekit.annotated.promise import translate_inputs_to_literals
from flytekit.annotated.promise import create_and_link_node, translate_inputs_to_literals
from flytekit.annotated.reference_entity import LaunchPlanReference, ReferenceEntity
from flytekit.common.launch_plan import SdkLaunchPlan
from flytekit.models import common as _common_models
from flytekit.models import interface as _interface_models
from flytekit.models import launch_plan as _launch_plan_models
from flytekit.models import literals as _literal_models
from flytekit.models import schedule as _schedule_model
from flytekit.models.common import RawOutputDataConfig
from flytekit.models.core import identifier as _identifier_model


class LaunchPlan(object):
Expand Down Expand Up @@ -130,9 +125,6 @@ def __init__(
self._raw_output_data_config = raw_output_data_config
self._auth_role = auth_role

# This will eventually hold the registerable launch plan
self._registerable_entity: Optional[SdkLaunchPlan] = None

FlyteEntities.entities.append(self)

@property
Expand Down Expand Up @@ -199,43 +191,6 @@ def __call__(self, *args, **kwargs):
inputs.update(kwargs)
return self.workflow(*args, **inputs)

def get_registerable_entity(self) -> SdkLaunchPlan:
settings = FlyteContext.current_context().registration_settings
if self._registerable_entity is not None:
return self._registerable_entity

if self._auth_role:
auth_role = self._auth_role
else:
auth_role = None

sdk_workflow = self.workflow.get_registerable_entity()
self._registerable_entity = SdkLaunchPlan(
workflow_id=sdk_workflow.id,
entity_metadata=_launch_plan_models.LaunchPlanMetadata(
schedule=self.schedule, notifications=self.notifications,
),
default_inputs=self.parameters,
fixed_inputs=self.fixed_inputs,
labels=self.labels or _common_models.Labels({}),
annotations=self.annotations or _common_models.Annotations({}),
auth_role=auth_role, # TODO: Is None here okay?
raw_output_data_config=self.raw_output_data_config,
)

# These two things are normally set to None in the SdkLaunchPlan constructor and filled in by
# SdkRunnableLaunchPlan/the registration process, so we need to set them manually. The reason is because these
# fields are not part of the underlying LaunchPlanSpec
self._registerable_entity._interface = sdk_workflow.interface
self._registerable_entity._id = _identifier_model.Identifier(
resource_type=_identifier_model.ResourceType.LAUNCH_PLAN,
project=settings.project,
domain=settings.domain,
name=self.name,
version=settings.version,
)
return self._registerable_entity


class ReferenceLaunchPlan(ReferenceEntity, LaunchPlan):
"""
Expand All @@ -248,29 +203,3 @@ def __init__(
self, project: str, domain: str, name: str, version: str, inputs: Dict[str, Type], outputs: Dict[str, Type]
):
super().__init__(LaunchPlanReference(project, domain, name, version), inputs, outputs)

def get_registerable_entity(self) -> SdkLaunchPlan:
from flytekit.common.interface import TypedInterface

wf_id = _identifier_model.Identifier(_identifier_model.ResourceType.WORKFLOW, "", "", "", "")
sdk_lp = SdkLaunchPlan(
workflow_id=wf_id,
entity_metadata=_launch_plan_models.LaunchPlanMetadata(schedule=None, notifications=[]),
default_inputs=_interface_models.ParameterMap({}),
fixed_inputs=_literal_models.LiteralMap({}),
labels=_common_models.Labels({}),
annotations=_common_models.Annotations({}),
auth_role=_common_models.AuthRole(assumable_iam_role="fake:role"),
raw_output_data_config=RawOutputDataConfig(""),
)
# Because of how SdkNodes work, it needs one of these interfaces
# Hopefully this is more trickery that can be cleaned up in the future
sdk_lp._interface = TypedInterface.promote_from_model(self.typed_interface)
sdk_lp._id = self.id

# Make sure we don't serialize this
sdk_lp._has_registered = True
sdk_lp.assign_name(self.reference.id.name)
self._registerable_entity = sdk_lp

return self._registerable_entity
Loading

0 comments on commit 452f682

Please sign in to comment.