From 08b75222f11345c8071427837bb559d56f23d291 Mon Sep 17 00:00:00 2001 From: Chi-Sheng Liu Date: Fri, 7 Jun 2024 01:12:18 +0800 Subject: [PATCH] feat(bindings): Task arguments default value binding (#2401) flyteorg/flyte#5321 if the key is not in `kwargs` but in `interface.inputs_with_defaults`, add the value in `interface.inputs_with_defaults` to `kwargs`. Signed-off-by: Chi-Sheng Liu Signed-off-by: Jan Fiedler --- flytekit/core/promise.py | 56 +-- flytekit/core/type_engine.py | 2 +- tests/flytekit/unit/core/test_composition.py | 14 - tests/flytekit/unit/core/test_dynamic.py | 37 ++ tests/flytekit/unit/core/test_promise.py | 2 +- .../flytekit/unit/core/test_serialization.py | 450 +++++++++++++++++- .../unit/types/pickle/test_flyte_pickle.py | 52 +- .../test_structured_dataset.py | 40 +- 8 files changed, 607 insertions(+), 46 deletions(-) diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 931d036d02..afd3c069f6 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -4,10 +4,10 @@ import inspect from copy import deepcopy from enum import Enum -from typing import Any, Coroutine, Dict, List, Optional, Set, Tuple, Union, cast +from typing import Any, Coroutine, Dict, Hashable, List, Optional, Set, Tuple, Union, cast, get_args from google.protobuf import struct_pb2 as _struct -from typing_extensions import Protocol, get_args +from typing_extensions import Protocol from flytekit.core import constants as _common_constants from flytekit.core import context_manager as _flyte_context @@ -23,7 +23,13 @@ ) from flytekit.core.interface import Interface from flytekit.core.node import Node -from flytekit.core.type_engine import DictTransformer, ListTransformer, TypeEngine, TypeTransformerFailedError +from flytekit.core.type_engine import ( + DictTransformer, + ListTransformer, + TypeEngine, + TypeTransformerFailedError, + UnionTransformer, +) from flytekit.exceptions import user as _user_exceptions from flytekit.exceptions.user import FlytePromiseAttributeResolveException from flytekit.loggers import logger @@ -774,7 +780,13 @@ def binding_from_python_std( t_value_type: type, ) -> Tuple[_literals_models.Binding, List[Node]]: nodes: List[Node] = [] - binding_data = binding_data_from_python_std(ctx, expected_literal_type, t_value, t_value_type, nodes) + binding_data = binding_data_from_python_std( + ctx, + expected_literal_type, + t_value, + t_value_type, + nodes, + ) return _literals_models.Binding(var=var_name, binding=binding_data), nodes @@ -1060,32 +1072,22 @@ def create_and_link_node( for k in sorted(interface.inputs): var = typed_interface.inputs[k] + if var.type.simple == SimpleType.NONE: + raise TypeError("Arguments do not have type annotation") if k not in kwargs: - is_optional = False - if var.type.union_type: - for variant in var.type.union_type.variants: - if variant.simple == SimpleType.NONE: - val, _default = interface.inputs_with_defaults[k] - if _default is not None: - raise ValueError( - f"The default value for the optional type must be None, but got {_default}" - ) - is_optional = True - if not is_optional: - from flytekit.core.base_task import Task - + # interface.inputs_with_defaults[k][0] is the type of the default argument + # interface.inputs_with_defaults[k][1] is the value of the default argument + if k in interface.inputs_with_defaults and ( + interface.inputs_with_defaults[k][1] is not None + or UnionTransformer.is_optional_type(interface.inputs_with_defaults[k][0]) + ): + default_val = interface.inputs_with_defaults[k][1] + if not isinstance(default_val, Hashable): + raise _user_exceptions.FlyteAssertion("Cannot use non-hashable object as default argument") + kwargs[k] = default_val + else: error_msg = f"Input {k} of type {interface.inputs[k]} was not specified for function {entity.name}" - - _, _default = interface.inputs_with_defaults[k] - if isinstance(entity, Task) and _default is not None: - error_msg += ( - ". Flyte workflow syntax is a domain-specific language (DSL) for building execution graphs which " - "supports a subset of Python’s semantics. When calling tasks, all kwargs have to be provided." - ) - raise _user_exceptions.FlyteAssertion(error_msg) - else: - continue v = kwargs[k] # This check ensures that tuples are not passed into a function, as tuples are not supported by Flyte # Usually a Tuple will indicate that multiple outputs from a previous task were accidentally passed diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index f5beb53f52..55d6368e43 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1554,7 +1554,7 @@ def __init__(self): super().__init__("Typed Union", typing.Union) @staticmethod - def is_optional_type(t: Type[T]) -> bool: + def is_optional_type(t: Type) -> bool: """Return True if `t` is a Union or Optional type.""" return _is_union_type(t) or type(None) in get_args(t) diff --git a/tests/flytekit/unit/core/test_composition.py b/tests/flytekit/unit/core/test_composition.py index 6fe2b01e61..0073baec53 100644 --- a/tests/flytekit/unit/core/test_composition.py +++ b/tests/flytekit/unit/core/test_composition.py @@ -1,7 +1,5 @@ from typing import Dict, List, NamedTuple, Optional, Union -import pytest - from flytekit.core import launch_plan from flytekit.core.task import task from flytekit.core.workflow import workflow @@ -186,15 +184,3 @@ def wf(a: Optional[int] = 1) -> Optional[int]: return t2(a=a) assert wf() is None - - with pytest.raises(ValueError, match="The default value for the optional type must be None, but got 3"): - - @task() - def t3(c: Optional[int] = 3) -> Optional[int]: - ... - - @workflow - def wf(): - return t3() - - wf() diff --git a/tests/flytekit/unit/core/test_dynamic.py b/tests/flytekit/unit/core/test_dynamic.py index 7964548674..d3a7237391 100644 --- a/tests/flytekit/unit/core/test_dynamic.py +++ b/tests/flytekit/unit/core/test_dynamic.py @@ -95,6 +95,43 @@ def ranged_int_to_str(a: int) -> typing.List[str]: assert res == ["fast-2", "fast-3", "fast-4", "fast-5", "fast-6"] +@pytest.mark.parametrize( + "input_val,output_val", + [ + (4, 0), + (5, 5), + ], +) +def test_dynamic_local_default_args_task(input_val, output_val): + @task + def t1(a: int = 0) -> int: + return a + + @dynamic + def dt(a: int) -> int: + if a % 2 == 0: + return t1() + return t1(a=a) + + assert dt(a=input_val) == output_val + + with context_manager.FlyteContextManager.with_context( + context_manager.FlyteContextManager.current_context().with_serialization_settings(settings) + ) as ctx: + with context_manager.FlyteContextManager.with_context( + ctx.with_execution_state( + ctx.execution_state.with_params( + mode=ExecutionState.Mode.TASK_EXECUTION, + ) + ) + ) as ctx: + input_literal_map = TypeEngine.dict_to_literal_map(ctx, {"a": input_val}) + dynamic_job_spec = dt.dispatch_execute(ctx, input_literal_map) + assert len(dynamic_job_spec.nodes) == 1 + assert len(dynamic_job_spec.tasks) == 1 + assert dynamic_job_spec.nodes[0].inputs[0].binding.scalar.primitive is not None + + def test_nested_dynamic_local(): @task def t1(a: int) -> str: diff --git a/tests/flytekit/unit/core/test_promise.py b/tests/flytekit/unit/core/test_promise.py index e022c875e0..a6d223f21a 100644 --- a/tests/flytekit/unit/core/test_promise.py +++ b/tests/flytekit/unit/core/test_promise.py @@ -46,7 +46,7 @@ def t2(a: typing.Optional[int] = None) -> typing.Optional[int]: p = create_and_link_node(ctx, t2) assert p.ref.var == "o0" - assert len(p.ref.node.bindings) == 0 + assert len(p.ref.node.bindings) == 1 def test_create_and_link_node_from_remote(): diff --git a/tests/flytekit/unit/core/test_serialization.py b/tests/flytekit/unit/core/test_serialization.py index 88297f43f4..2fcf8bbd94 100644 --- a/tests/flytekit/unit/core/test_serialization.py +++ b/tests/flytekit/unit/core/test_serialization.py @@ -12,9 +12,20 @@ from flytekit.core.python_auto_container import get_registerable_container_image from flytekit.core.task import task from flytekit.core.workflow import workflow +from flytekit.exceptions.user import FlyteAssertion from flytekit.image_spec.image_spec import ImageBuildEngine, _calculate_deduped_hash_from_image_spec from flytekit.models.admin.workflow import WorkflowSpec -from flytekit.models.types import SimpleType +from flytekit.models.literals import ( + BindingData, + BindingDataCollection, + BindingDataMap, + Literal, + Primitive, + Scalar, + Union, + Void, +) +from flytekit.models.types import LiteralType, SimpleType, TypeStructure, UnionType from flytekit.tools.translator import get_serializable from flytekit.types.error.error import FlyteError @@ -495,3 +506,440 @@ def z(a: int, b: str) -> typing.Tuple[int, str]: assert task_spec.template.interface.inputs["a"].description == "foo" assert task_spec.template.interface.inputs["b"].description == "bar" assert task_spec.template.interface.outputs["o0"].description == "ramen" + + +def test_default_args_task_int_type(): + default_val = 0 + input_val = 100 + + @task + def t1(a: int = default_val) -> int: + return a + + @workflow + def wf_no_input() -> int: + return t1() + + @workflow + def wf_with_input() -> int: + return t1(a=input_val) + + @workflow + def wf_with_sub_wf() -> tuple[int, int]: + return (wf_no_input(), wf_with_input()) + + wf_no_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_no_input) + wf_with_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_with_input) + + assert wf_no_input_spec.template.nodes[0].inputs[0].binding.value == Scalar( + primitive=Primitive(integer=default_val) + ) + assert wf_with_input_spec.template.nodes[0].inputs[0].binding.value == Scalar( + primitive=Primitive(integer=input_val) + ) + + output_type = LiteralType(simple=SimpleType.INTEGER) + assert wf_no_input_spec.template.interface.outputs["o0"].type == output_type + assert wf_with_input_spec.template.interface.outputs["o0"].type == output_type + + assert wf_no_input() == default_val + assert wf_with_input() == input_val + assert wf_with_sub_wf() == (default_val, input_val) + + +def test_default_args_task_str_type(): + default_val = "" + input_val = "foo" + + @task + def t1(a: str = default_val) -> str: + return a + + @workflow + def wf_no_input() -> str: + return t1() + + @workflow + def wf_with_input() -> str: + return t1(a=input_val) + + @workflow + def wf_with_sub_wf() -> tuple[str, str]: + return (wf_no_input(), wf_with_input()) + + wf_no_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_no_input) + wf_with_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_with_input) + + assert wf_no_input_spec.template.nodes[0].inputs[0].binding.value == Scalar( + primitive=Primitive(string_value=default_val) + ) + assert wf_with_input_spec.template.nodes[0].inputs[0].binding.value == Scalar( + primitive=Primitive(string_value=input_val) + ) + + output_type = LiteralType(simple=SimpleType.STRING) + assert wf_no_input_spec.template.interface.outputs["o0"].type == output_type + assert wf_with_input_spec.template.interface.outputs["o0"].type == output_type + + assert wf_no_input() == default_val + assert wf_with_input() == input_val + assert wf_with_sub_wf() == (default_val, input_val) + + +def test_default_args_task_optional_int_type_default_none(): + default_val = None + input_val = 100 + + @task + def t1(a: typing.Optional[int] = default_val) -> typing.Optional[int]: + return a + + @workflow + def wf_no_input() -> typing.Optional[int]: + return t1() + + @workflow + def wf_with_input() -> typing.Optional[int]: + return t1(a=input_val) + + @workflow + def wf_with_sub_wf() -> tuple[typing.Optional[int], typing.Optional[int]]: + return (wf_no_input(), wf_with_input()) + + wf_no_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_no_input) + wf_with_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_with_input) + + assert wf_no_input_spec.template.nodes[0].inputs[0].binding.value == Scalar( + union=Union( + value=Literal( + scalar=Scalar( + none_type=Void(), + ), + ), + stored_type=LiteralType( + simple=SimpleType.NONE, + structure=TypeStructure(tag="none"), + ), + ), + ) + assert wf_with_input_spec.template.nodes[0].inputs[0].binding.value == Scalar( + union=Union( + value=Literal( + scalar=Scalar(primitive=Primitive(integer=input_val)), + ), + stored_type=LiteralType( + simple=SimpleType.INTEGER, + structure=TypeStructure(tag="int"), + ), + ), + ) + + output_type = LiteralType( + union_type=UnionType( + [ + LiteralType( + simple=SimpleType.INTEGER, + structure=TypeStructure(tag="int"), + ), + LiteralType( + simple=SimpleType.NONE, + structure=TypeStructure(tag="none"), + ), + ] + ) + ) + assert wf_no_input_spec.template.interface.outputs["o0"].type == output_type + assert wf_with_input_spec.template.interface.outputs["o0"].type == output_type + + assert wf_no_input() == default_val + assert wf_with_input() == input_val + assert wf_with_sub_wf() == (default_val, input_val) + + +def test_default_args_task_optional_int_type_default_int(): + default_val = 10 + input_val = 100 + + @task + def t1(a: typing.Optional[int] = default_val) -> typing.Optional[int]: + return a + + @workflow + def wf_no_input() -> typing.Optional[int]: + return t1() + + @workflow + def wf_with_input() -> typing.Optional[int]: + return t1(a=input_val) + + @workflow + def wf_with_sub_wf() -> tuple[typing.Optional[int], typing.Optional[int]]: + return (wf_no_input(), wf_with_input()) + + wf_no_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_no_input) + wf_with_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_with_input) + + assert wf_no_input_spec.template.nodes[0].inputs[0].binding.value == Scalar( + union=Union( + value=Literal( + scalar=Scalar( + primitive=Primitive(integer=default_val), + ), + ), + stored_type=LiteralType( + simple=SimpleType.INTEGER, + structure=TypeStructure(tag="int"), + ), + ), + ) + assert wf_with_input_spec.template.nodes[0].inputs[0].binding.value == Scalar( + union=Union( + value=Literal( + scalar=Scalar(primitive=Primitive(integer=input_val)), + ), + stored_type=LiteralType( + simple=SimpleType.INTEGER, + structure=TypeStructure(tag="int"), + ), + ), + ) + + output_type = LiteralType( + union_type=UnionType( + [ + LiteralType( + simple=SimpleType.INTEGER, + structure=TypeStructure(tag="int"), + ), + LiteralType( + simple=SimpleType.NONE, + structure=TypeStructure(tag="none"), + ), + ] + ) + ) + assert wf_no_input_spec.template.interface.outputs["o0"].type == output_type + assert wf_with_input_spec.template.interface.outputs["o0"].type == output_type + + assert wf_no_input() == default_val + assert wf_with_input() == input_val + assert wf_with_sub_wf() == (default_val, input_val) + + +def test_default_args_task_no_type_hint(): + @task + def t1(a=0) -> int: + return a + + @workflow + def wf_no_input() -> int: + return t1() + + @workflow + def wf_with_input() -> int: + return t1(a=100) + + with pytest.raises(TypeError, match="Arguments do not have type annotation"): + get_serializable(OrderedDict(), serialization_settings, wf_no_input) + with pytest.raises(TypeError, match="Arguments do not have type annotation"): + get_serializable(OrderedDict(), serialization_settings, wf_with_input) + + +def test_default_args_task_mismatch_type(): + @task + def t1(a: int = "foo") -> int: # type: ignore + return a + + @workflow + def wf_no_input() -> int: + return t1() + + @workflow + def wf_with_input() -> int: + return t1(a="bar") + + with pytest.raises(AssertionError, match="Failed to Bind variable"): + get_serializable(OrderedDict(), serialization_settings, wf_no_input) + with pytest.raises(AssertionError, match="Failed to Bind variable"): + get_serializable(OrderedDict(), serialization_settings, wf_with_input) + + +def test_default_args_task_list_type(): + input_val = [1, 2, 3] + + @task + def t1(a: list[int] = []) -> list[int]: + return a + + @workflow + def wf_no_input() -> list[int]: + return t1() + + @workflow + def wf_with_input() -> list[int]: + return t1(a=input_val) + + with pytest.raises(FlyteAssertion, match="Cannot use non-hashable object as default argument"): + get_serializable(OrderedDict(), serialization_settings, wf_no_input) + + wf_with_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_with_input) + + assert wf_with_input_spec.template.nodes[0].inputs[0].binding.value == BindingDataCollection( + bindings=[ + BindingData(scalar=Scalar(primitive=Primitive(integer=1))), + BindingData(scalar=Scalar(primitive=Primitive(integer=2))), + BindingData(scalar=Scalar(primitive=Primitive(integer=3))), + ] + ) + + assert wf_with_input_spec.template.interface.outputs["o0"].type == LiteralType( + collection_type=LiteralType(simple=SimpleType.INTEGER) + ) + + assert wf_with_input() == input_val + + +def test_default_args_task_dict_type(): + input_val = {"a": 1, "b": 2} + + @task + def t1(a: dict[str, int] = {}) -> dict[str, int]: + return a + + @workflow + def wf_no_input() -> dict[str, int]: + return t1() + + @workflow + def wf_with_input() -> dict[str, int]: + return t1(a=input_val) + + with pytest.raises(FlyteAssertion, match="Cannot use non-hashable object as default argument"): + get_serializable(OrderedDict(), serialization_settings, wf_no_input) + + wf_with_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_with_input) + + assert wf_with_input_spec.template.nodes[0].inputs[0].binding.value == BindingDataMap( + bindings={ + "a": BindingData(scalar=Scalar(primitive=Primitive(integer=1))), + "b": BindingData(scalar=Scalar(primitive=Primitive(integer=2))), + } + ) + + assert wf_with_input_spec.template.interface.outputs["o0"].type == LiteralType( + map_value_type=LiteralType(simple=SimpleType.INTEGER) + ) + + assert wf_with_input() == input_val + + +def test_default_args_task_optional_list_type_default_none(): + default_val = None + input_val = [1, 2, 3] + + @task + def t1(a: typing.Optional[typing.List[int]] = default_val) -> typing.Optional[typing.List[int]]: + return a + + @workflow + def wf_no_input() -> typing.Optional[typing.List[int]]: + return t1() + + @workflow + def wf_with_input() -> typing.Optional[typing.List[int]]: + return t1(a=input_val) + + @workflow + def wf_with_sub_wf() -> tuple[typing.Optional[typing.List[int]], typing.Optional[typing.List[int]]]: + return (wf_no_input(), wf_with_input()) + + wf_no_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_no_input) + wf_with_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_with_input) + + assert wf_no_input_spec.template.nodes[0].inputs[0].binding.value == Scalar( + union=Union( + value=Literal( + scalar=Scalar( + none_type=Void(), + ), + ), + stored_type=LiteralType( + simple=SimpleType.NONE, + structure=TypeStructure(tag="none"), + ), + ), + ) + assert wf_with_input_spec.template.nodes[0].inputs[0].binding.value == BindingDataCollection( + bindings=[ + BindingData(scalar=Scalar(primitive=Primitive(integer=1))), + BindingData(scalar=Scalar(primitive=Primitive(integer=2))), + BindingData(scalar=Scalar(primitive=Primitive(integer=3))), + ] + ) + + output_type = LiteralType( + union_type=UnionType( + [ + LiteralType( + collection_type=LiteralType(simple=SimpleType.INTEGER), + structure=TypeStructure(tag="Typed List"), + ), + LiteralType( + simple=SimpleType.NONE, + structure=TypeStructure(tag="none"), + ), + ] + ) + ) + assert wf_no_input_spec.template.interface.outputs["o0"].type == output_type + assert wf_with_input_spec.template.interface.outputs["o0"].type == output_type + + assert wf_no_input() == default_val + assert wf_with_input() == input_val + assert wf_with_sub_wf() == (default_val, input_val) + + +def test_default_args_task_optional_list_type_default_list(): + input_val = [1, 2, 3] + + @task + def t1(a: typing.Optional[typing.List[int]] = []) -> typing.Optional[typing.List[int]]: + return a + + @workflow + def wf_no_input() -> typing.Optional[typing.List[int]]: + return t1() + + @workflow + def wf_with_input() -> typing.Optional[typing.List[int]]: + return t1(a=input_val) + + with pytest.raises(FlyteAssertion, match="Cannot use non-hashable object as default argument"): + get_serializable(OrderedDict(), serialization_settings, wf_no_input) + + wf_with_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_with_input) + + assert wf_with_input_spec.template.nodes[0].inputs[0].binding.value == BindingDataCollection( + bindings=[ + BindingData(scalar=Scalar(primitive=Primitive(integer=1))), + BindingData(scalar=Scalar(primitive=Primitive(integer=2))), + BindingData(scalar=Scalar(primitive=Primitive(integer=3))), + ] + ) + + assert wf_with_input_spec.template.interface.outputs["o0"].type == LiteralType( + union_type=UnionType( + [ + LiteralType( + collection_type=LiteralType(simple=SimpleType.INTEGER), + structure=TypeStructure(tag="Typed List"), + ), + LiteralType( + simple=SimpleType.NONE, + structure=TypeStructure(tag="none"), + ), + ] + ) + ) + + assert wf_with_input() == input_val diff --git a/tests/flytekit/unit/types/pickle/test_flyte_pickle.py b/tests/flytekit/unit/types/pickle/test_flyte_pickle.py index 7c2da727c1..53cdc7dc20 100644 --- a/tests/flytekit/unit/types/pickle/test_flyte_pickle.py +++ b/tests/flytekit/unit/types/pickle/test_flyte_pickle.py @@ -1,7 +1,7 @@ import sys from collections import OrderedDict from collections.abc import Sequence -from typing import Dict, List, Union +from typing import Any, Dict, List, Union import numpy as np import pytest @@ -11,6 +11,7 @@ from flytekit.configuration import Image, ImageConfig from flytekit.core import context_manager from flytekit.core.task import task +from flytekit.core.workflow import workflow from flytekit.models.core.types import BlobType from flytekit.models.literals import BlobMetadata from flytekit.models.types import LiteralType @@ -126,3 +127,52 @@ def t1(a: int) -> Annotated[Foo, a1(a="bar")]: task_spec = get_serializable(OrderedDict(), serialization_settings, t1) md = task_spec.template.interface.outputs["o0"].type.metadata["python_class_name"] assert "0x" not in str(md) + + +def test_default_args_task(): + default_val = 123 + input_val = "foo" + + @task + def t1(a: Any = default_val) -> Any: + return a + + @workflow + def wf_no_input() -> Any: + return t1() + + @workflow + def wf_with_input() -> Any: + return t1(a=input_val) + + @workflow + def wf_with_sub_wf() -> tuple[Any, Any]: + return (wf_no_input(), wf_with_input()) + + wf_no_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_no_input) + wf_with_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_with_input) + + metadata = BlobMetadata( + type=BlobType( + format="PythonPickle", + dimensionality=BlobType.BlobDimensionality.SINGLE, + ), + ) + assert wf_no_input_spec.template.nodes[0].inputs[0].binding.value.blob.metadata == metadata + assert wf_with_input_spec.template.nodes[0].inputs[0].binding.value.blob.metadata == metadata + + output_type = LiteralType( + blob=BlobType( + format="PythonPickle", + dimensionality=BlobType.BlobDimensionality.SINGLE, + ), + metadata={ + "python_class_name": "typing.Any", + }, + ) + assert wf_no_input_spec.template.interface.outputs["o0"].type == output_type + assert wf_with_input_spec.template.interface.outputs["o0"].type == output_type + + assert wf_no_input() == default_val + assert wf_with_input() == input_val + assert wf_with_sub_wf() == (default_val, input_val) diff --git a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset.py b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset.py index 9a5628af0f..18c3ce82db 100644 --- a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset.py +++ b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset.py @@ -1,6 +1,7 @@ import os import tempfile import typing +from collections import OrderedDict import google.cloud.bigquery import pyarrow as pa @@ -16,10 +17,12 @@ from flytekit.core.task import task from flytekit.core.type_engine import TypeEngine from flytekit.core.workflow import workflow +from flytekit.exceptions.user import FlyteAssertion from flytekit.lazy_import.lazy_module import is_imported from flytekit.models import literals from flytekit.models.literals import StructuredDatasetMetadata -from flytekit.models.types import SchemaType, SimpleType, StructuredDatasetType +from flytekit.models.types import LiteralType, SchemaType, SimpleType, StructuredDatasetType +from flytekit.tools.translator import get_serializable from flytekit.types.structured.structured_dataset import ( PARQUET, StructuredDataset, @@ -545,3 +548,38 @@ def test_reregister_encoder(): df_literal_type = TypeEngine.to_literal_type(pd.DataFrame) TypeEngine.to_literal(ctx, sd, python_type=pd.DataFrame, expected=df_literal_type) + + +def test_default_args_task(): + input_val = generate_pandas() + + @task + def t1(a: pd.DataFrame = pd.DataFrame()) -> pd.DataFrame: + return a + + @workflow + def wf_no_input() -> pd.DataFrame: + return t1() + + @workflow + def wf_with_input() -> pd.DataFrame: + return t1(a=input_val) + + with pytest.raises(FlyteAssertion, match="Cannot use non-hashable object as default argument"): + get_serializable(OrderedDict(), serialization_settings, wf_no_input) + + wf_with_input_spec = get_serializable(OrderedDict(), serialization_settings, wf_with_input) + + assert wf_with_input_spec.template.nodes[0].inputs[ + 0 + ].binding.value.structured_dataset.metadata == StructuredDatasetMetadata( + structured_dataset_type=StructuredDatasetType( + format="parquet", + ), + ) + + assert wf_with_input_spec.template.interface.outputs["o0"].type == LiteralType( + structured_dataset_type=StructuredDatasetType() + ) + + pd.testing.assert_frame_equal(wf_with_input(), input_val)