Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make FlyteFile compatible with Annotated[..., HashMethod] #1544

Merged
merged 6 commits into from
Mar 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions flytekit/types/file/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from dataclasses_json import config, dataclass_json
from marshmallow import fields
from typing_extensions import Annotated, get_args, get_origin

from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError
Expand Down Expand Up @@ -335,6 +336,10 @@ def to_literal(
if python_val is None:
raise TypeTransformerFailedError("None value cannot be converted to a file.")

# Correctly handle `Annotated[FlyteFile, ...]` by extracting the origin type
if get_origin(python_type) is Annotated:
python_type = get_args(python_type)[0]

if not (python_type is os.PathLike or issubclass(python_type, FlyteFile)):
raise ValueError(f"Incorrect type {python_type}, must be either a FlyteFile or os.PathLike")

Expand Down
17 changes: 17 additions & 0 deletions tests/flytekit/unit/core/test_flyte_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
from unittest.mock import MagicMock

import pytest
from typing_extensions import Annotated

import flytekit.configuration
from flytekit.configuration import Config, Image, ImageConfig
from flytekit.core.context_manager import ExecutionState, FlyteContextManager
from flytekit.core.data_persistence import FileAccessProvider, flyte_tmp_dir
from flytekit.core.dynamic_workflow_task import dynamic
from flytekit.core.hash import HashMethod
from flytekit.core.launch_plan import LaunchPlan
from flytekit.core.task import task
from flytekit.core.type_engine import TypeEngine
Expand Down Expand Up @@ -433,6 +435,21 @@ def wf(path: str) -> os.PathLike:
assert flyte_tmp_dir in wf(path="s3://somewhere").path


def test_flyte_file_annotated_hashmethod(local_dummy_file):
def calc_hash(ff: FlyteFile) -> str:
return str(ff.path)

@task
def t1(path: str) -> Annotated[FlyteFile, HashMethod(calc_hash)]:
return FlyteFile(path)

@workflow
def wf(path: str) -> None:
t1(path=path)

wf(path=local_dummy_file)


@pytest.mark.sandbox_test
def test_file_open_things():
@task
Expand Down