Skip to content

Commit

Permalink
Task metadata (#300)
Browse files Browse the repository at this point in the history
* Switching from using model Metadata -> TaskMetadata (#298)

* Switching from using model Metadata -> TaskMetadata

  TaskMetadata will be maintained as a shadow and allows decoupling of
protocol buffer types from contributor code and user code. This allows
more flexibility

* addressed comments

* unit test fix

* Formatting fixed
  • Loading branch information
kumare3 authored Dec 30, 2020
1 parent 4468f08 commit f507049
Show file tree
Hide file tree
Showing 24 changed files with 194 additions and 136 deletions.
4 changes: 2 additions & 2 deletions flytekit/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import flytekit.plugins # noqa: F401
from flytekit.annotated.base_sql_task import SQLTask
from flytekit.annotated.base_task import kwtypes
from flytekit.annotated.base_task import TaskMetadata, kwtypes
from flytekit.annotated.container_task import ContainerTask
from flytekit.annotated.context_manager import FlyteContext
from flytekit.annotated.dynamic_workflow_task import dynamic
Expand All @@ -9,7 +9,7 @@
from flytekit.annotated.reference import get_reference_entity
from flytekit.annotated.reference_entity import TaskReference, WorkflowReference
from flytekit.annotated.reference_task import reference_task
from flytekit.annotated.task import metadata, task
from flytekit.annotated.task import task
from flytekit.annotated.workflow import workflow
from flytekit.loggers import logger

Expand Down
13 changes: 6 additions & 7 deletions flytekit/annotated/base_sql_task.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import re
from typing import Any, Dict, Type, TypeVar
from typing import Any, Dict, Optional, Type, TypeVar

from flytekit.annotated.base_task import PythonTask
from flytekit.annotated.base_task import PythonTask, TaskMetadata
from flytekit.annotated.interface import Interface
from flytekit.models import task as _task_model

T = TypeVar("T")

Expand All @@ -18,19 +17,19 @@ class SQLTask(PythonTask[T]):
def __init__(
self,
name: str,
metadata: _task_model.TaskMetadata,
query_template: str,
inputs: Dict[str, Type],
task_type="sql_task",
inputs: Optional[Dict[str, Type]] = None,
metadata: Optional[TaskMetadata] = None,
task_config: T = None,
outputs: Dict[str, Type] = None,
task_type="sql_task",
*args,
**kwargs,
):
super().__init__(
task_type=task_type,
name=name,
interface=Interface(inputs=inputs, outputs=outputs or {}),
interface=Interface(inputs=inputs or {}, outputs=outputs or {}),
metadata=metadata,
task_config=task_config,
*args,
Expand Down
76 changes: 67 additions & 9 deletions flytekit/annotated/base_task.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import collections
import datetime
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Generic, Optional, Tuple, Type, TypeVar, Union

from flytekit.annotated.context_manager import (
Expand Down Expand Up @@ -33,6 +35,60 @@ def kwtypes(**kwargs) -> Dict[str, Type]:
return d


@dataclass
class TaskMetadata(object):
    """
    Create Metadata to be associated with this Task

    Args:
        cache: Boolean that indicates if caching should be enabled
        cache_version: Version string to be used for the cached value
        interruptable: Boolean that indicates that this task can be interrupted and/or scheduled on nodes
            with lower QoS guarantees. This will directly reduce the `$`/`execution cost` associated,
            at the cost of performance penalties due to potential interruptions
        deprecated: A string that can be used to provide a warning message for deprecated task. Absence / empty str
            indicates that the task is active and not deprecated
        retries: for retries=n; n > 0, on failures of this task, the task will be retried at-least n number of times.
        timeout: the max amount of time for which one execution of this task should be executed for. The execution
            will be terminated if the runtime exceeds the given timeout (approximately). May be given as a
            ``datetime.timedelta`` or as an ``int`` number of seconds; an ``int`` is converted to a
            ``datetime.timedelta`` on construction.

    Raises:
        ValueError: if ``timeout`` is neither ``None``, an ``int`` nor a ``datetime.timedelta``, or if
            ``cache`` is enabled without a ``cache_version``.
    """

    cache: bool = False
    cache_version: str = ""
    interruptable: bool = False
    deprecated: str = ""
    retries: int = 0
    timeout: Optional[Union[datetime.timedelta, int]] = None

    def __post_init__(self):
        # Compare against None instead of relying on truthiness: a falsy value such as ``0`` must still be
        # normalized to a timedelta, and a falsy value of the wrong type (e.g. ``""``) must still be rejected.
        if self.timeout is not None:
            if isinstance(self.timeout, int):
                self.timeout = datetime.timedelta(seconds=self.timeout)
            elif not isinstance(self.timeout, datetime.timedelta):
                raise ValueError("timeout should be duration represented as either a datetime.timedelta or int seconds")
        if self.cache and not self.cache_version:
            raise ValueError("Caching is enabled ``cache=True`` but ``cache_version`` is not set.")

    @property
    def retry_strategy(self) -> _literal_models.RetryStrategy:
        """
        The configured ``retries`` count wrapped in a ``_literal_models.RetryStrategy``
        """
        return _literal_models.RetryStrategy(self.retries)

    def to_taskmetadata_model(self) -> _task_model.TaskMetadata:
        """
        Converts to _task_model.TaskMetadata

        ``cache`` / ``cache_version`` are passed as ``discoverable`` / ``discovery_version``, ``interruptable``
        as ``interruptible`` and ``deprecated`` as ``deprecated_error_message``.
        """
        return _task_model.TaskMetadata(
            discoverable=self.cache,
            # TODO Fix the version circular dependency before beta
            runtime=_task_model.RuntimeMetadata(_task_model.RuntimeMetadata.RuntimeType.FLYTE_SDK, "0.16.0", "python"),
            timeout=self.timeout,
            retries=self.retry_strategy,
            interruptible=self.interruptable,
            discovery_version=self.cache_version,
            deprecated_error_message=self.deprecated,
        )


# This is the least abstract task. It will have access to the loaded Python function
# itself if run locally, so it will always be a Python task.
# This is analogous to the current SdkRunnableTask. Need to analyze the benefits of duplicating the class versus
Expand All @@ -48,15 +104,15 @@ def __init__(
self,
task_type: str,
name: str,
interface: _interface_models.TypedInterface,
metadata: _task_model.TaskMetadata,
interface: Optional[_interface_models.TypedInterface] = None,
metadata: Optional[TaskMetadata] = None,
*args,
**kwargs,
):
self._task_type = task_type
self._name = name
self._interface = interface
self._metadata = metadata
self._metadata = metadata if metadata else TaskMetadata()

# This will get populated only at registration time, when we retrieve the rest of the environment variables like
# project/domain/version/image and anything else we might need from the environment in the future.
Expand All @@ -69,7 +125,7 @@ def interface(self) -> _interface_models.TypedInterface:
return self._interface

@property
def metadata(self) -> _task_model.TaskMetadata:
def metadata(self) -> TaskMetadata:
return self._metadata

@property
Expand Down Expand Up @@ -175,7 +231,7 @@ def get_task_structure(self) -> SdkTask:
settings = FlyteContext.current_context().registration_settings
tk = SdkTask(
type=self.task_type,
metadata=self.metadata,
metadata=self.metadata.to_taskmetadata_model(),
interface=self.interface,
custom=self.get_custom(settings),
container=self.get_container(settings),
Expand Down Expand Up @@ -233,9 +289,9 @@ def __init__(
self,
task_type: str,
name: str,
interface: Interface,
metadata: _task_model.TaskMetadata,
task_config: T,
interface: Optional[Interface] = None,
metadata: Optional[TaskMetadata] = None,
*args,
**kwargs,
):
Expand Down Expand Up @@ -268,7 +324,7 @@ def compile(self, ctx: FlyteContext, *args, **kwargs):
entity=self,
interface=self.python_interface,
timeout=self.metadata.timeout,
retry_strategy=self.metadata.retries,
retry_strategy=self.metadata.retry_strategy,
**kwargs,
)

Expand All @@ -289,7 +345,9 @@ def dispatch_execute(

# Create another execution context with the new user params, but let's keep the same working dir
with ctx.new_execution_context(
mode=ctx.execution_state.mode, execution_params=new_user_params, working_dir=ctx.execution_state.working_dir
mode=ctx.execution_state.mode,
execution_params=new_user_params,
working_dir=ctx.execution_state.working_dir,
) as exec_ctx:
# TODO We could support default values here too - but not part of the plan right now
# Translate the input literals to Python native
Expand Down
8 changes: 4 additions & 4 deletions flytekit/annotated/container_task.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from enum import Enum
from typing import Any, Dict, List, Type
from typing import Any, Dict, List, Optional, Type

from flytekit.annotated.base_task import PythonTask
from flytekit.annotated.base_task import PythonTask, TaskMetadata
from flytekit.annotated.context_manager import RegistrationSettings
from flytekit.annotated.interface import Interface
from flytekit.common.tasks.raw_container import _get_container_definition
Expand All @@ -26,9 +26,9 @@ def __init__(
self,
name: str,
image: str,
metadata: _task_model.TaskMetadata,
inputs: Dict[str, Type],
command: List[str],
inputs: Optional[Dict[str, Type]] = None,
metadata: Optional[TaskMetadata] = None,
arguments: List[str] = None,
outputs: Dict[str, Type] = None,
input_data_dir: str = None,
Expand Down
6 changes: 3 additions & 3 deletions flytekit/annotated/dynamic_workflow_task.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import functools
from typing import Any, Callable, Union
from typing import Any, Callable, Optional, Union

from flytekit import TaskMetadata
from flytekit.annotated import task
from flytekit.annotated.context_manager import ExecutionState, FlyteContext
from flytekit.annotated.python_function_task import PythonFunctionTask
Expand All @@ -9,7 +10,6 @@
from flytekit.loggers import logger
from flytekit.models import dynamic_job as _dynamic_job
from flytekit.models import literals as _literal_models
from flytekit.models import task as _task_model


class _Dynamic(object):
Expand All @@ -21,7 +21,7 @@ def __init__(
self,
task_config: _Dynamic,
dynamic_workflow_function: Callable,
metadata: _task_model.TaskMetadata,
metadata: Optional[TaskMetadata] = None,
*args,
**kwargs,
):
Expand Down
4 changes: 3 additions & 1 deletion flytekit/annotated/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,9 @@ def transform_inputs_to_parameters(
return _interface_models.ParameterMap(params)


def transform_interface_to_typed_interface(interface: Interface) -> _interface_models.TypedInterface:
def transform_interface_to_typed_interface(
interface: typing.Optional[Interface],
) -> typing.Optional[_interface_models.TypedInterface]:
"""
Transform the given simple python native interface to FlyteIDL's interface
"""
Expand Down
7 changes: 3 additions & 4 deletions flytekit/annotated/map_task.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from typing import Any
from typing import Any, Optional

from flytekit.annotated.base_task import PythonTask
from flytekit.annotated.base_task import PythonTask, TaskMetadata
from flytekit.annotated.interface import transform_interface_to_list_interface
from flytekit.models import task as _task_model


class MapPythonTask(PythonTask):
Expand All @@ -14,7 +13,7 @@ class MapPythonTask(PythonTask):
To do this we might have to give up on supporting lambda functions initially
"""

def __init__(self, tk: PythonTask, metadata: _task_model.TaskMetadata, *args, **kwargs):
def __init__(self, tk: PythonTask, metadata: Optional[TaskMetadata] = None, *args, **kwargs):
collection_interface = transform_interface_to_list_interface(tk.python_interface)
name = "mapper_" + tk.name
self._run_task = tk
Expand Down
1 change: 1 addition & 0 deletions flytekit/annotated/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def with_overrides(self, *args, **kwargs):
self._aliases.append(_workflow_model.Alias(var=k, alias=v))


# TODO we should accept TaskMetadata here and then extract whatever fields we want into NodeMetadata
def create_and_link_node(
ctx: FlyteContext,
entity,
Expand Down
6 changes: 3 additions & 3 deletions flytekit/annotated/python_function_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import re
from typing import Any, Callable, Dict, List, Optional, TypeVar

from flytekit.annotated.base_task import PythonTask
from flytekit.annotated.base_task import PythonTask, TaskMetadata
from flytekit.annotated.context_manager import ImageConfig, RegistrationSettings
from flytekit.annotated.interface import Interface, transform_signature_to_interface
from flytekit.annotated.resources import Resources, ResourceSpec
Expand Down Expand Up @@ -61,9 +61,9 @@ def __init__(
self,
task_config: T,
task_function: Callable,
metadata: _task_model.TaskMetadata,
ignore_input_vars: List[str] = None,
task_type="python-task",
metadata: Optional[TaskMetadata] = None,
ignore_input_vars: List[str] = None,
container_image: str = None,
requests: Resources = None,
limits: Resources = None,
Expand Down
8 changes: 6 additions & 2 deletions flytekit/annotated/reference_task.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import inspect
from typing import Callable, Dict, Optional, Type, Union

from flytekit import TaskMetadata
from flytekit.annotated.interface import transform_signature_to_interface
from flytekit.annotated.python_function_task import PythonFunctionTask
from flytekit.annotated.reference_entity import ReferenceEntity, TaskReference
from flytekit.annotated.task import metadata as get_empty_metadata
from flytekit.common.tasks.task import SdkTask


Expand All @@ -25,7 +25,11 @@ def get_task_structure(self) -> SdkTask:
# settings = FlyteContext.current_context().registration_settings
# This is a dummy sdk task, hopefully when we clean up
tk = SdkTask(
type="ignore", metadata=get_empty_metadata(), interface=self.typed_interface, custom={}, container=None,
type="ignore",
metadata=TaskMetadata().to_taskmetadata_model(),
interface=self.typed_interface,
custom={},
container=None,
)
# Reset id to ensure it matches user input
tk._id = self.id
Expand Down
Loading

0 comments on commit f507049

Please sign in to comment.