From 2672d24fe07528e6393b05b2c22311b3eb08e842 Mon Sep 17 00:00:00 2001 From: Kevin Su Date: Wed, 24 Jul 2024 04:15:04 +0800 Subject: [PATCH] refactor(core): Enhance return type extraction logic (#2598) Signed-off-by: Kevin Su Signed-off-by: mao3267 --- flytekit/core/interface.py | 4 +++- tests/flytekit/unit/core/test_interface.py | 11 ++++++++++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/flytekit/core/interface.py b/flytekit/core/interface.py index 65fd4fed6a..ebf1921871 100644 --- a/flytekit/core/interface.py +++ b/flytekit/core/interface.py @@ -507,7 +507,9 @@ def t(a: int, b: str) -> Dict[str, int]: ... # This statement results in true for typing.Namedtuple, single and void return types, so this # handles Options 1, 2. Even though NamedTuple for us is multi-valued, it's a single value for Python - if isinstance(return_annotation, Type) or isinstance(return_annotation, TypeVar): # type: ignore + if hasattr(return_annotation, "__bases__") and ( + isinstance(return_annotation, Type) or isinstance(return_annotation, TypeVar) # type: ignore + ): # isinstance / issubclass does not work for Namedtuple. # Options 1 and 2 bases = return_annotation.__bases__ # type: ignore diff --git a/tests/flytekit/unit/core/test_interface.py b/tests/flytekit/unit/core/test_interface.py index fb0d1e6816..d3b994e508 100644 --- a/tests/flytekit/unit/core/test_interface.py +++ b/tests/flytekit/unit/core/test_interface.py @@ -4,7 +4,7 @@ from typing import Dict, List import pytest -from typing_extensions import Annotated # type: ignore +from typing_extensions import Annotated, TypeVar # type: ignore from flytekit import map_task, task from flytekit.core import context_manager @@ -96,6 +96,15 @@ def t(a: int, b: str) -> Dict[str, int]: assert len(return_type) == 1 assert return_type["o0"] == Dict[str, int] + VST = TypeVar("VST") + + def t(a: int, b: str) -> VST: # type: ignore + ... + + return_type = extract_return_annotation(typing.get_type_hints(t).get("return", None)) + assert len(return_type) == 1 + assert return_type["o0"] == VST + def test_named_tuples(): nt1 = typing.NamedTuple("NT1", x_str=str, y_int=int)