Skip to content

Commit

Permalink
Avoid writing tasks inside functions by raising errors (#328)
Browse files Browse the repository at this point in the history
  • Loading branch information
kumare3 authored Jan 18, 2021
1 parent d11ee29 commit d525cd9
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 1 deletion.
42 changes: 42 additions & 0 deletions flytekit/annotated/python_function_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,38 @@ def get_container(self, settings: SerializationSettings) -> _task_model.Containe
)


def isnested(func) -> bool:
"""
Returns true if a function is local to another function and is not accessible through a module
This would essentially be any function with a `.<local>.` (defined within a function) e.g.
.. code:: python
def foo():
def foo_inner():
pass
pass
In the above example `foo_inner` is the local function or a nested function.
"""
return func.__code__.co_flags & inspect.CO_NESTED != 0


def istestfunction(func) -> bool:
"""
Returns true if the function is defined in a test module. A test module has to have `test_` as the prefix.
False in all other cases
"""
mod = inspect.getmodule(func)
if mod:
mod_name = mod.__name__
if "." in mod_name:
mod_name = mod_name.split(".")[-1]
return mod_name.startswith("test_")
return False


class PythonFunctionTask(PythonAutoContainerTask[T]):
"""
A Python Function task should be used as the base for all extensions that have a python function. It will
Expand Down Expand Up @@ -167,6 +199,12 @@ def __init__(
"""
if task_function is None:
raise ValueError("TaskFunction is a required parameter for PythonFunctionTask")
if not istestfunction(func=task_function) and isnested(func=task_function):
raise ValueError(
"TaskFunction cannot be a nested/inner or local function. "
"It should be accessible at a module level for Flyte to execute it. Test modules with "
"names begining with `test_` are allowed to have nested tasks"
)
self._native_interface = transform_signature_to_interface(inspect.signature(task_function))
mutated_interface = self._native_interface.remove_inputs(ignore_input_vars)
super().__init__(
Expand All @@ -183,6 +221,10 @@ def __init__(
def execution_mode(self) -> ExecutionBehavior:
return self._execution_mode

@property
def task_function(self):
return self._task_function

def execute(self, **kwargs) -> Any:
"""
This method will be invoked to execute the task. If you do decide to override this method you must also
Expand Down
11 changes: 11 additions & 0 deletions tests/flytekit/unit/annotated/tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
"""
NOTE: This file exists to test failure in case of instantiating a task inside a function in a non test_* module.
Refer to test_python_function_task.py to look at the usage.
"""
from flytekit import task


def tasks():
@task
def foo():
pass
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,35 @@

from flytekit import task
from flytekit.annotated.context_manager import Image, ImageConfig, SerializationSettings
from flytekit.annotated.python_function_task import PythonFunctionTask, get_registerable_container_image
from flytekit.annotated.python_function_task import (
PythonFunctionTask,
get_registerable_container_image,
isnested,
istestfunction,
)
from tests.flytekit.unit.annotated import tasks


def foo():
pass


def test_isnested():
def inner_foo():
pass

assert isnested(foo) is False
assert isnested(inner_foo) is True

# Uses tasks.tasks method
with pytest.raises(ValueError):
tasks.tasks()


def test_istestfunction():
assert istestfunction(foo) is True
assert istestfunction(isnested) is False
assert istestfunction(tasks.tasks) is False


def test_container_image_conversion():
Expand Down

0 comments on commit d525cd9

Please sign in to comment.