Skip to content

Commit

Permalink
refactor(core): Enhance return type extraction logic (flyteorg#2598)
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <pingsutw@apache.org>
Signed-off-by: mao3267 <chenvincent610@gmail.com>
  • Loading branch information
pingsutw authored and mao3267 committed Aug 2, 2024
1 parent f20b466 commit 2672d24
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
4 changes: 3 additions & 1 deletion flytekit/core/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,9 @@ def t(a: int, b: str) -> Dict[str, int]: ...

# This statement results in true for typing.Namedtuple, single and void return types, so this
# handles Options 1, 2. Even though NamedTuple for us is multi-valued, it's a single value for Python
if isinstance(return_annotation, Type) or isinstance(return_annotation, TypeVar): # type: ignore
if hasattr(return_annotation, "__bases__") and (
isinstance(return_annotation, Type) or isinstance(return_annotation, TypeVar) # type: ignore
):
# isinstance / issubclass does not work for Namedtuple.
# Options 1 and 2
bases = return_annotation.__bases__ # type: ignore
Expand Down
11 changes: 10 additions & 1 deletion tests/flytekit/unit/core/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Dict, List

import pytest
from typing_extensions import Annotated # type: ignore
from typing_extensions import Annotated, TypeVar # type: ignore

from flytekit import map_task, task
from flytekit.core import context_manager
Expand Down Expand Up @@ -96,6 +96,15 @@ def t(a: int, b: str) -> Dict[str, int]:
assert len(return_type) == 1
assert return_type["o0"] == Dict[str, int]

VST = TypeVar("VST")

def t(a: int, b: str) -> VST: # type: ignore
...

return_type = extract_return_annotation(typing.get_type_hints(t).get("return", None))
assert len(return_type) == 1
assert return_type["o0"] == VST


def test_named_tuples():
nt1 = typing.NamedTuple("NT1", x_str=str, y_int=int)
Expand Down

0 comments on commit 2672d24

Please sign in to comment.