Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve error message for missing return #2551

Merged
merged 32 commits into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added flytekit/_ast/__init__.py
Empty file.
25 changes: 25 additions & 0 deletions flytekit/_ast/parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import ast
import inspect
import typing


def get_function_param_location(func: typing.Callable, param_name: str) -> (int, int):
"""
Get the line and column number of the parameter in the source code of the function definition.
"""
# Get source code of the function
source_lines, start_line = inspect.getsourcelines(func)
source_code = "".join(source_lines)

# Parse the source code into an AST
module = ast.parse(source_code)

# Traverse the AST to find the function definition
for node in ast.walk(module):
if isinstance(node, ast.FunctionDef) and node.name == func.__name__:
for i, arg in enumerate(node.args.args):
if arg.arg == param_name:
# Calculate the line and column number of the parameter
line_number = start_line + node.lineno - 1
column_offset = arg.col_offset
return line_number, column_offset
27 changes: 12 additions & 15 deletions flytekit/clis/sdk_in_container/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@

from flytekit.core.constants import SOURCE_CODE
from flytekit.exceptions.base import FlyteException
from flytekit.exceptions.user import FlyteInvalidInputException
from flytekit.exceptions.user import FlyteCompilationException, FlyteInvalidInputException
from flytekit.exceptions.utils import annotate_exception_with_code
from flytekit.loggers import get_level_from_cli_verbosity, logger

project_option = click.Option(
Expand Down Expand Up @@ -130,12 +131,14 @@ def pretty_print_traceback(e: Exception, verbosity: int = 1):
else:
raise ValueError(f"Verbosity level must be between 0 and 2. Got {verbosity}")

if hasattr(e, SOURCE_CODE):
# TODO: Use other way to check if the background is light or dark
theme = "emacs" if "LIGHT_BACKGROUND" in os.environ else "monokai"
syntax = Syntax(getattr(e, SOURCE_CODE), "python", theme=theme, background_color="default")
panel = Panel(syntax, border_style="red", title=type(e).__name__, title_align="left")
console.print(panel, no_wrap=False)
if isinstance(e, FlyteCompilationException):
e = annotate_exception_with_code(e, e.fn, e.param_name)
if hasattr(e, SOURCE_CODE):
# TODO: Use other way to check if the background is light or dark
theme = "emacs" if "LIGHT_BACKGROUND" in os.environ else "monokai"
syntax = Syntax(getattr(e, SOURCE_CODE), "python", theme=theme, background_color="default")
panel = Panel(syntax, border_style="red", title=e._ERROR_CODE, title_align="left")
console.print(panel, no_wrap=False)


def pretty_print_exception(e: Exception, verbosity: int = 1):
Expand All @@ -161,20 +164,14 @@ def pretty_print_exception(e: Exception, verbosity: int = 1):
pretty_print_grpc_error(cause)
else:
pretty_print_traceback(e, verbosity)
else:
pretty_print_traceback(e, verbosity)
return

if isinstance(e, grpc.RpcError):
pretty_print_grpc_error(e)
return

if isinstance(e, AssertionError):
click.secho(f"Assertion Error: {e}", fg="red")
return

if isinstance(e, ValueError):
click.secho(f"Value Error: {e}", fg="red")
return

pretty_print_traceback(e, verbosity)


Expand Down
24 changes: 20 additions & 4 deletions flytekit/core/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,16 @@

from flytekit.core import context_manager
from flytekit.core.artifact import Artifact, ArtifactIDSpecification, ArtifactQuery
from flytekit.core.context_manager import FlyteContextManager
from flytekit.core.docstring import Docstring
from flytekit.core.sentinel import DYNAMIC_INPUT_BINDING
from flytekit.core.type_engine import TypeEngine, UnionTransformer
from flytekit.exceptions.user import FlyteValidationException
from flytekit.exceptions.utils import annotate_exception_with_code
from flytekit.core.utils import has_return_statement
from flytekit.exceptions.user import (
FlyteMissingReturnValueException,
FlyteMissingTypeException,
FlyteValidationException,
)
from flytekit.loggers import developer_logger, logger
from flytekit.models import interface as _interface_models
from flytekit.models.literals import Literal, Scalar, Void
Expand Down Expand Up @@ -375,15 +380,26 @@ def transform_function_to_interface(fn: typing.Callable, docstring: Optional[Doc
signature = inspect.signature(fn)
return_annotation = type_hints.get("return", None)

ctx = FlyteContextManager.current_context()
# Only check if the task/workflow has a return statement at compile time locally.
if (
ctx.execution_state
and ctx.execution_state.mode is None
and return_annotation
and type(None) not in get_args(return_annotation)
and return_annotation is not type(None)
and has_return_statement(fn) is False
):
raise FlyteMissingReturnValueException(fn=fn)

outputs = extract_return_annotation(return_annotation)
for k, v in outputs.items():
outputs[k] = v # type: ignore
inputs: Dict[str, Tuple[Type, Any]] = OrderedDict()
for k, v in signature.parameters.items(): # type: ignore
annotation = type_hints.get(k, None)
if annotation is None:
err_msg = f"'{k}' has no type. Please add a type annotation to the input parameter."
raise annotate_exception_with_code(TypeError(err_msg), fn, k)
raise FlyteMissingTypeException(fn=fn, param_name=k)
default = v.default if v.default is not inspect.Parameter.empty else None
# Inputs with default values are currently ignored, we may want to look into that in the future
inputs[k] = (annotation, default) # type: ignore
Expand Down
12 changes: 12 additions & 0 deletions flytekit/core/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import datetime
import inspect
import os
import shutil
import tempfile
import time
import typing
from abc import ABC, abstractmethod
from functools import wraps
from hashlib import sha224 as _sha224
Expand Down Expand Up @@ -381,3 +383,13 @@ def get_extra_config(self):
Get the config of the decorator.
"""
pass


def has_return_statement(func: typing.Callable) -> bool:
source_lines = inspect.getsourcelines(func)[0]
for line in source_lines:
if "return" in line.strip():
return True
if "yield" in line.strip():
return True
return False
22 changes: 22 additions & 0 deletions flytekit/exceptions/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,3 +97,25 @@ def __init__(self, request: typing.Any):

class FlytePromiseAttributeResolveException(FlyteAssertion):
_ERROR_CODE = "USER:PromiseAttributeResolveError"


class FlyteCompilationException(FlyteUserException):
_ERROR_CODE = "USER:CompileError"

def __init__(self, fn: typing.Callable, param_name: typing.Optional[str] = None):
self.fn = fn
self.param_name = param_name


class FlyteMissingTypeException(FlyteCompilationException):
_ERROR_CODE = "USER:MissingTypeError"

def __str__(self):
return f"'{self.param_name}' has no type. Please add a type annotation to the input parameter."


class FlyteMissingReturnValueException(FlyteCompilationException):
_ERROR_CODE = "USER:MissingReturnValueError"

def __str__(self):
return f"{self.fn.__name__} function must return a value. Please add a return statement at the end of the function."
34 changes: 9 additions & 25 deletions flytekit/exceptions/utils.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,28 @@
import ast
import inspect
import typing

from flytekit._ast.parser import get_function_param_location
from flytekit.core.constants import SOURCE_CODE
from flytekit.exceptions.user import FlyteUserException


def get_function_param_location(func: typing.Callable, param_name: str) -> (int, int):
"""
Get the line and column number of the parameter in the source code of the function definition.
"""
# Get source code of the function
source_lines, start_line = inspect.getsourcelines(func)
source_code = "".join(source_lines)

# Parse the source code into an AST
module = ast.parse(source_code)

# Traverse the AST to find the function definition
for node in ast.walk(module):
if isinstance(node, ast.FunctionDef) and node.name == func.__name__:
for i, arg in enumerate(node.args.args):
if arg.arg == param_name:
# Calculate the line and column number of the parameter
line_number = start_line + node.lineno - 1
column_offset = arg.col_offset
return line_number, column_offset


def get_source_code_from_fn(fn: typing.Callable, param_name: str) -> (str, int):
def get_source_code_from_fn(fn: typing.Callable, param_name: typing.Optional[str] = None) -> (str, int):
"""
Get the source code of the function and the column offset of the parameter defined in the input signature.
"""
lines, start_line = inspect.getsourcelines(fn)
if param_name is None:
return "".join(f"{start_line + i} {lines[i]}" for i in range(len(lines))), 0

target_line_no, column_offset = get_function_param_location(fn, param_name)
line_index = target_line_no - start_line
source_code = "".join(f"{start_line + i} {lines[i]}" for i in range(line_index + 1))
return source_code, column_offset


def annotate_exception_with_code(exception: Exception, fn: typing.Callable, param_name: str) -> Exception:
def annotate_exception_with_code(
exception: FlyteUserException, fn: typing.Callable, param_name: typing.Optional[str] = None
) -> FlyteUserException:
"""
Annotate the exception with the source code, and will be printed in the rich panel.
@param exception: The exception to be annotated.
Expand Down
10 changes: 6 additions & 4 deletions tests/flytekit/integration/remote/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,14 +390,16 @@ def test_fetch_not_exist_launch_plan(register):


def test_execute_reference_task(register):
nt = typing.NamedTuple("OutputsBC", [("t1_int_output", int), ("c", str)])

@reference_task(
project=PROJECT,
domain=DOMAIN,
name="basic.basic_workflow.t1",
version=VERSION,
)
def t1(a: int) -> typing.NamedTuple("OutputsBC", t1_int_output=int, c=str):
...
def t1(a: int) -> nt:
return nt(t1_int_output=a + 2, c="world")

remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN)
execution = remote.execute(
Expand All @@ -424,7 +426,7 @@ def test_execute_reference_workflow(register):
version=VERSION,
)
def my_wf(a: int, b: str) -> (int, str):
...
return a + 2, b + "world"

remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN)
execution = remote.execute(
Expand All @@ -451,7 +453,7 @@ def test_execute_reference_launchplan(register):
version=VERSION,
)
def my_wf(a: int, b: str) -> (int, str):
...
return 3, "world"

remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN)
execution = remote.execute(
Expand Down
2 changes: 2 additions & 0 deletions tests/flytekit/unit/bin/test_python_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def test_dispatch_execute_user_error_non_recov(mock_write_to_file, mock_upload_d
def t1(a: int) -> str:
# Should be interpreted as a non-recoverable user error
raise ValueError(f"some exception {a}")
return "hello"

ctx = context_manager.FlyteContext.current_context()
with context_manager.FlyteContextManager.with_context(
Expand Down Expand Up @@ -242,6 +243,7 @@ def t1(a: int) -> str:
def my_subwf(a: int) -> typing.List[str]:
# This also tests the dynamic/compile path
raise user_exceptions.FlyteRecoverableException(f"recoverable {a}")
return ["1", "2"]

ctx = context_manager.FlyteContext.current_context()
with context_manager.FlyteContextManager.with_context(
Expand Down
4 changes: 2 additions & 2 deletions tests/flytekit/unit/core/flyte_functools/test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,10 @@ def test_unwrapped_task():
error = completed_process.stderr
error_str = ""
for line in error.strip().split("\n"):
if line.startswith("TypeError"):
if line.startswith("FlyteMissingTypeException"):
error_str += line
assert error_str != ""
assert error_str.startswith("TypeError: 'args' has no type. Please add a type annotation to the input")
assert "'args' has no type. Please add a type annotation" in error_str


@pytest.mark.parametrize("script", ["nested_function.py", "nested_wrapped_function.py"])
Expand Down
2 changes: 1 addition & 1 deletion tests/flytekit/unit/core/test_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ def test_basic_option_a3():

@task
def t3(b_value: str) -> Annotated[pd.DataFrame, a3]:
...
return pd.DataFrame({"a": [1, 2, 3], "b": [b_value, b_value, b_value]})

entities = OrderedDict()
t3_s = get_serializable(entities, serialization_settings, t3)
Expand Down
4 changes: 2 additions & 2 deletions tests/flytekit/unit/core/test_imperative.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ def ref_t1(
dataframe: pd.DataFrame,
imputation_method: str = "median",
) -> pd.DataFrame:
...
return dataframe

@reference_task(
project="flytesnacks",
Expand All @@ -340,7 +340,7 @@ def ref_t2(
split_mask: int,
num_features: int,
) -> pd.DataFrame:
...
return dataframe

wb = ImperativeWorkflow(name="core.feature_engineering.workflow.fe_wf")
wb.add_workflow_input("sqlite_archive", FlyteFile[typing.TypeVar("sqlite")])
Expand Down
Loading
Loading