Skip to content

Commit

Permalink
vscode decorator for the dynamic task (flyteorg#2994)
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <pingsutw@apache.org>
Signed-off-by: Eduardo Apolinario <eapolinario@users.noreply.github.com>
Co-authored-by: Eduardo Apolinario <eapolinario@users.noreply.github.com>
  • Loading branch information
2 people authored and shuyingliang committed Dec 20, 2024
1 parent 92c8ddb commit e8373a8
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 7 deletions.
4 changes: 3 additions & 1 deletion flytekit/core/context_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,7 +735,9 @@ def new_compilation_state(self, prefix: str = "") -> CompilationState:
Creates and returns a default compilation state. For most of the code this should be the entrypoint
of compilation, otherwise the code should always uses - with_compilation_state
"""
return CompilationState(prefix=prefix)
from flytekit.core.python_auto_container import default_task_resolver

return CompilationState(prefix=prefix, task_resolver=default_task_resolver)

def new_execution_state(self, working_dir: Optional[os.PathLike] = None) -> ExecutionState:
"""
Expand Down
1 change: 0 additions & 1 deletion flytekit/core/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,6 @@ def extract_task_module(f: Union[Callable, TrackedInstance]) -> Tuple[str, str,
:param f: A task or any other callable
:return: [name to use: str, module_name: str, function_name: str, full_path: str]
"""

if isinstance(f, TrackedInstance):
if hasattr(f, "task_function"):
mod, mod_name, name = _task_module_from_callable(f.task_function)
Expand Down
15 changes: 12 additions & 3 deletions flytekit/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,9 +696,13 @@ def task_name(self, t: PythonAutoContainerTask) -> str: # type: ignore
return f"{self.name}.{t.__module__}.{t.name}"

def _validate_add_on_failure_handler(self, ctx: FlyteContext, prefix: str, wf_args: Dict[str, Promise]):
# Compare
resolver = (
ctx.compilation_state.task_resolver
if ctx.compilation_state and ctx.compilation_state.task_resolver
else self
)
with FlyteContextManager.with_context(
ctx.with_compilation_state(CompilationState(prefix=prefix, task_resolver=self))
ctx.with_compilation_state(CompilationState(prefix=prefix, task_resolver=resolver))
) as inner_comp_ctx:
# Now lets compile the failure-node if it exists
if self.on_failure:
Expand Down Expand Up @@ -736,9 +740,14 @@ def compile(self, **kwargs):
ctx = FlyteContextManager.current_context()
all_nodes = []
prefix = ctx.compilation_state.prefix if ctx.compilation_state is not None else ""
resolver = (
ctx.compilation_state.task_resolver
if ctx.compilation_state and ctx.compilation_state.task_resolver
else self
)

with FlyteContextManager.with_context(
ctx.with_compilation_state(CompilationState(prefix=prefix, task_resolver=self))
ctx.with_compilation_state(CompilationState(prefix=prefix, task_resolver=resolver))
) as comp_ctx:
# Construct the default input promise bindings, but then override with the provided inputs, if any
input_kwargs = construct_input_promises([k for k in self.interface.inputs.keys()])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

import mock
import pytest

from flytekit.core import context_manager
from flytekit.core.python_auto_container import default_task_resolver
from flytekitplugins.flyteinteractive import (
CODE_TOGETHER_CONFIG,
CODE_TOGETHER_EXTENSION,
Expand All @@ -24,9 +27,9 @@
is_extension_installed,
)

from flytekit import task, workflow
from flytekit import task, workflow, dynamic
from flytekit.configuration import Image, ImageConfig, SerializationSettings
from flytekit.core.context_manager import ExecutionState
from flytekit.core.context_manager import ExecutionState, FlyteContextManager
from flytekit.tools.translator import get_serializable_task


Expand Down Expand Up @@ -402,3 +405,35 @@ def test_get_installed_extensions_failed(mock_run):

expected_extensions = []
assert installed_extensions == expected_extensions


def test_vscode_with_dynamic(vscode_patches):
(
mock_process,
mock_prepare_interactive_python,
mock_exit_handler,
mock_download_vscode,
mock_signal,
mock_prepare_resume_task_python,
mock_prepare_launch_json,
) = vscode_patches

mock_exit_handler.return_value = None

@task()
def train():
print("forward")
print("backward")

@dynamic()
@vscode
def d1():
print("dynamic", flush=True)
train()

ctx = FlyteContextManager.current_context()
with context_manager.FlyteContextManager.with_context(
ctx.with_execution_state(ctx.execution_state.with_params(mode=ExecutionState.Mode.TASK_EXECUTION))
):
d1()
assert d1.task_resolver == default_task_resolver

0 comments on commit e8373a8

Please sign in to comment.