From d525cd9f5bee200203b5effa50fbdcd975e5ea59 Mon Sep 17 00:00:00 2001 From: Ketan Umare <16888709+kumare3@users.noreply.github.com> Date: Mon, 18 Jan 2021 13:51:57 -0800 Subject: [PATCH] Avoid writing tasks inside functions by raising errors (#328) --- flytekit/annotated/python_function_task.py | 42 +++++++++++++++++++ tests/flytekit/unit/annotated/tasks.py | 11 +++++ ...t_task.py => test_python_function_task.py} | 30 ++++++++++++- 3 files changed, 82 insertions(+), 1 deletion(-) create mode 100644 tests/flytekit/unit/annotated/tasks.py rename tests/flytekit/unit/annotated/{test_task.py => test_python_function_task.py} (81%) diff --git a/flytekit/annotated/python_function_task.py b/flytekit/annotated/python_function_task.py index e16a414651..7cdbfc7c78 100644 --- a/flytekit/annotated/python_function_task.py +++ b/flytekit/annotated/python_function_task.py @@ -127,6 +127,38 @@ def get_container(self, settings: SerializationSettings) -> _task_model.Containe ) +def isnested(func) -> bool: + """ + Returns true if a function is local to another function and is not accessible through a module + + This would essentially be any function with a `..` (defined within a function) e.g. + + .. code:: python + + def foo(): + def foo_inner(): + pass + pass + + In the above example `foo_inner` is the local function or a nested function. + """ + return func.__code__.co_flags & inspect.CO_NESTED != 0 + + +def istestfunction(func) -> bool: + """ + Returns true if the function is defined in a test module. A test module has to have `test_` as the prefix. + False in all other cases + """ + mod = inspect.getmodule(func) + if mod: + mod_name = mod.__name__ + if "." in mod_name: + mod_name = mod_name.split(".")[-1] + return mod_name.startswith("test_") + return False + + class PythonFunctionTask(PythonAutoContainerTask[T]): """ A Python Function task should be used as the base for all extensions that have a python function. It will @@ -167,6 +199,12 @@ def __init__( """ if task_function is None: raise ValueError("TaskFunction is a required parameter for PythonFunctionTask") + if not istestfunction(func=task_function) and isnested(func=task_function): + raise ValueError( + "TaskFunction cannot be a nested/inner or local function. " + "It should be accessible at a module level for Flyte to execute it. Test modules with " + "names begining with `test_` are allowed to have nested tasks" + ) self._native_interface = transform_signature_to_interface(inspect.signature(task_function)) mutated_interface = self._native_interface.remove_inputs(ignore_input_vars) super().__init__( @@ -183,6 +221,10 @@ def __init__( def execution_mode(self) -> ExecutionBehavior: return self._execution_mode + @property + def task_function(self): + return self._task_function + def execute(self, **kwargs) -> Any: """ This method will be invoked to execute the task. If you do decide to override this method you must also diff --git a/tests/flytekit/unit/annotated/tasks.py b/tests/flytekit/unit/annotated/tasks.py new file mode 100644 index 0000000000..9385b0f08d --- /dev/null +++ b/tests/flytekit/unit/annotated/tasks.py @@ -0,0 +1,11 @@ +""" +NOTE: This file exists to test failure in case of instantiating a task inside a function in a non test_* module. +Refer to test_python_function_task.py to look at the usage. +""" +from flytekit import task + + +def tasks(): + @task + def foo(): + pass diff --git a/tests/flytekit/unit/annotated/test_task.py b/tests/flytekit/unit/annotated/test_python_function_task.py similarity index 81% rename from tests/flytekit/unit/annotated/test_task.py rename to tests/flytekit/unit/annotated/test_python_function_task.py index 3f8320f52a..155e0bd884 100644 --- a/tests/flytekit/unit/annotated/test_task.py +++ b/tests/flytekit/unit/annotated/test_python_function_task.py @@ -2,7 +2,35 @@ from flytekit import task from flytekit.annotated.context_manager import Image, ImageConfig, SerializationSettings -from flytekit.annotated.python_function_task import PythonFunctionTask, get_registerable_container_image +from flytekit.annotated.python_function_task import ( + PythonFunctionTask, + get_registerable_container_image, + isnested, + istestfunction, +) +from tests.flytekit.unit.annotated import tasks + + +def foo(): + pass + + +def test_isnested(): + def inner_foo(): + pass + + assert isnested(foo) is False + assert isnested(inner_foo) is True + + # Uses tasks.tasks method + with pytest.raises(ValueError): + tasks.tasks() + + +def test_istestfunction(): + assert istestfunction(foo) is True + assert istestfunction(isnested) is False + assert istestfunction(tasks.tasks) is False def test_container_image_conversion():