diff --git a/Dockerfile.external-plugin-service b/Dockerfile.external-plugin-service index 8a79a31720..2194f5de23 100644 --- a/Dockerfile.external-plugin-service +++ b/Dockerfile.external-plugin-service @@ -4,7 +4,6 @@ MAINTAINER Flyte Team LABEL org.opencontainers.image.source=https://github.com/flyteorg/flytekit ARG VERSION -RUN pip install -U flytekit==$VERSION \ - flytekitplugins-bigquery==$VERSION \ +RUN pip install -U flytekit==$VERSION flytekitplugins-bigquery==$VERSION CMD pyflyte serve --port 8000 diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 0ea1412729..2a3687c06c 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -21,6 +21,7 @@ from flytekit.core.node import Node from flytekit.core.type_engine import DictTransformer, ListTransformer, TypeEngine, TypeTransformerFailedError from flytekit.exceptions import user as _user_exceptions +from flytekit.loggers import logger from flytekit.models import interface as _interface_models from flytekit.models import literals as _literal_models from flytekit.models import literals as _literals_models @@ -618,10 +619,21 @@ def binding_data_from_python_std( f"Cannot pass output from task {t_value.task_name} that produces no outputs to a downstream task" ) - elif isinstance(t_value, list): - if expected_literal_type.collection_type is None: - raise AssertionError(f"this should be a list and it is not: {type(t_value)} vs {expected_literal_type}") + elif expected_literal_type.union_type is not None: + for i in range(len(expected_literal_type.union_type.variants)): + try: + lt_type = expected_literal_type.union_type.variants[i] + python_type = get_args(t_value_type)[i] if t_value_type else None + return binding_data_from_python_std(ctx, lt_type, t_value, python_type) + except Exception: + logger.debug( + f"failed to bind data {t_value} with literal type {expected_literal_type.union_type.variants[i]}." + ) + raise AssertionError( + f"Failed to bind data {t_value} with literal type {expected_literal_type.union_type.variants}." + ) + elif isinstance(t_value, list): sub_type: Optional[type] = ListTransformer.get_sub_type(t_value_type) if t_value_type else None collection = _literals_models.BindingDataCollection( bindings=[ diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index 465d74cd66..3212966601 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -1,10 +1,13 @@ from __future__ import annotations +import typing from dataclasses import dataclass from enum import Enum from functools import update_wrapper from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast +from typing_extensions import get_args + from flytekit.core import constants as _common_constants from flytekit.core.base_task import PythonTask from flytekit.core.class_based_resolver import ClassStorageTaskResolver @@ -32,14 +35,16 @@ from flytekit.core.python_auto_container import PythonAutoContainerTask from flytekit.core.reference_entity import ReferenceEntity, WorkflowReference from flytekit.core.tracker import extract_task_module -from flytekit.core.type_engine import TypeEngine, TypeTransformerFailedError +from flytekit.core.type_engine import TypeEngine, TypeTransformerFailedError, UnionTransformer from flytekit.exceptions import scopes as exception_scopes from flytekit.exceptions.user import FlyteValidationException, FlyteValueException from flytekit.loggers import logger from flytekit.models import interface as _interface_models from flytekit.models import literals as _literal_models +from flytekit.models import types as type_models from flytekit.models.core import workflow as _workflow_model from flytekit.models.documentation import Description, Documentation +from flytekit.models.types import TypeStructure GLOBAL_START_NODE = Node( id=_common_constants.GLOBAL_INPUT_NODE_ID, @@ -49,6 +54,8 @@ flyte_entity=None, ) +T = typing.TypeVar("T") + class WorkflowFailurePolicy(Enum): """ @@ -272,24 +279,63 @@ def execute(self, **kwargs): def compile(self, **kwargs): pass - def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, None]: - # This is done to support the invariant that Workflow local executions always work with Promise objects - # holding Flyte literal values. Even in a wf, a user can call a sub-workflow with a Python native value. - for k, v in kwargs.items(): - if not isinstance(v, Promise): - t = self.python_interface.inputs[k] + def ensure_literal( + self, ctx, py_type: Type[T], input_type: type_models.LiteralType, python_value: Any + ) -> _literal_models.Literal: + """ + This function will attempt to convert a python value to a literal. If the python value is a promise, it will + return the promise's value. + """ + if input_type.union_type is not None: + if python_value is None and UnionTransformer.is_optional_type(py_type): + return _literal_models.Literal(scalar=_literal_models.Scalar(none_type=_literal_models.Void())) + for i in range(len(input_type.union_type.variants)): + lt_type = input_type.union_type.variants[i] + python_type = get_args(py_type)[i] try: - kwargs[k] = Promise(var=k, val=TypeEngine.to_literal(ctx, v, t, self.interface.inputs[k].type)) + final_lt = self.ensure_literal(ctx, python_type, lt_type, python_value) + lt_type._structure = TypeStructure(tag=TypeEngine.get_transformer(python_type).name) + return _literal_models.Literal( + scalar=_literal_models.Scalar(union=_literal_models.Union(value=final_lt, stored_type=lt_type)) + ) + except Exception as e: + logger.debug(f"Failed to convert {python_value} to {lt_type} with error {e}") + raise TypeError(f"Failed to convert {python_value} to {input_type}") + if isinstance(python_value, list) and input_type.collection_type: + collection_lit_type = input_type.collection_type + collection_py_type = get_args(py_type)[0] + xx = [self.ensure_literal(ctx, collection_py_type, collection_lit_type, pv) for pv in python_value] + return _literal_models.Literal(collection=_literal_models.LiteralCollection(literals=xx)) + elif isinstance(python_value, dict) and input_type.map_value_type: + mapped_lit_type = input_type.map_value_type + mapped_py_type = get_args(py_type)[1] + xx = {k: self.ensure_literal(ctx, mapped_py_type, mapped_lit_type, v) for k, v in python_value.items()} # type: ignore + return _literal_models.Literal(map=_literal_models.LiteralMap(literals=xx)) + # It is a scalar, convert to Promise if necessary. + else: + if isinstance(python_value, Promise): + return python_value.val + if not isinstance(python_value, Promise): + try: + res = TypeEngine.to_literal(ctx, python_value, py_type, input_type) + return res except TypeTransformerFailedError as exc: raise TypeError( - f"Failed to convert input argument '{k}' of workflow '{self.name}':\n {exc}" + f"Failed to convert input '{python_value}' of workflow '{self.name}':\n {exc}" ) from exc - # The output of this will always be a combination of Python native values and Promises containing Flyte - # Literals. + def local_execute(self, ctx: FlyteContext, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromise, None]: + # This is done to support the invariant that Workflow local executions always work with Promise objects + # holding Flyte literal values. Even in a wf, a user can call a sub-workflow with a Python native value. + for k, v in kwargs.items(): + py_type = self.python_interface.inputs[k] + lit_type = self.interface.inputs[k].type + kwargs[k] = Promise(var=k, val=self.ensure_literal(ctx, py_type, lit_type, v)) + + # The output of this will always be a combination of Python native values and Promises containing Flyte + # Literals. self.compile() function_outputs = self.execute(**kwargs) - # First handle the empty return case. # A workflow function may return a task that doesn't return anything # def wf(): diff --git a/tests/flytekit/unit/core/test_type_conversion_errors.py b/tests/flytekit/unit/core/test_type_conversion_errors.py index 807bbcad22..dda19dd126 100644 --- a/tests/flytekit/unit/core/test_type_conversion_errors.py +++ b/tests/flytekit/unit/core/test_type_conversion_errors.py @@ -91,11 +91,9 @@ def test_workflow_with_task_error(correct_input): def test_workflow_with_input_error(incorrect_input): with pytest.raises( TypeError, - match=( - r"Encountered error while executing workflow '{}':\n" - r" Failed to convert input argument 'a' of workflow '.+':\n" - r" Expected value of type \ but got .+ of type" - ).format(wf_with_output_error.name), + match=(r"Encountered error while executing workflow '{}':\n" r" Failed to convert input").format( + wf_with_output_error.name + ), ): wf_with_output_error(a=incorrect_input) diff --git a/tests/flytekit/unit/core/test_workflows.py b/tests/flytekit/unit/core/test_workflows.py index 90a8c712e6..7bcbcb8ea3 100644 --- a/tests/flytekit/unit/core/test_workflows.py +++ b/tests/flytekit/unit/core/test_workflows.py @@ -136,6 +136,89 @@ def wf(b: int) -> nt: assert x == (7, 7) +def test_sub_wf_varying_types(): + @task + def t1l( + a: typing.List[typing.Dict[str, typing.List[int]]], + b: typing.Dict[str, typing.List[int]], + c: typing.Union[typing.List[typing.Dict[str, typing.List[int]]], typing.Dict[str, typing.List[int]], int], + d: int, + ) -> str: + xx = ",".join([f"{k}:{v}" for d in a for k, v in d.items()]) + yy = ",".join([f"{k}: {i}" for k, v in b.items() for i in v]) + if isinstance(c, list): + zz = ",".join([f"{k}:{v}" for d in c for k, v in d.items()]) + elif isinstance(c, dict): + zz = ",".join([f"{k}: {i}" for k, v in c.items() for i in v]) + else: + zz = str(c) + return f"First: {xx} Second: {yy} Third: {zz} Int: {d}" + + @task + def get_int() -> int: + return 1 + + @workflow + def subwf( + a: typing.List[typing.Dict[str, typing.List[int]]], + b: typing.Dict[str, typing.List[int]], + c: typing.Union[typing.List[typing.Dict[str, typing.List[int]]], typing.Dict[str, typing.List[int]]], + d: int, + ) -> str: + return t1l(a=a, b=b, c=c, d=d) + + @workflow + def wf() -> str: + ds = [ + {"first_map_a": [42], "first_map_b": [get_int(), 2]}, + { + "second_map_c": [33], + "second_map_d": [9, 99], + }, + ] + ll = { + "ll_1": [get_int(), get_int(), get_int()], + "ll_2": [4, 5, 6], + } + out = subwf(a=ds, b=ll, c=ds, d=get_int()) + return out + + wf.compile() + x = wf() + expected = ( + "First: first_map_a:[42],first_map_b:[1, 2],second_map_c:[33],second_map_d:[9, 99] " + "Second: ll_1: 1,ll_1: 1,ll_1: 1,ll_2: 4,ll_2: 5,ll_2: 6 " + "Third: first_map_a:[42],first_map_b:[1, 2],second_map_c:[33],second_map_d:[9, 99] " + "Int: 1" + ) + assert x == expected + + @workflow + def wf() -> str: + ds = [ + {"first_map_a": [42], "first_map_b": [get_int(), 2]}, + { + "second_map_c": [33], + "second_map_d": [9, 99], + }, + ] + ll = { + "ll_1": [get_int(), get_int(), get_int()], + "ll_2": [4, 5, 6], + } + out = subwf(a=ds, b=ll, c=ll, d=get_int()) + return out + + x = wf() + expected = ( + "First: first_map_a:[42],first_map_b:[1, 2],second_map_c:[33],second_map_d:[9, 99] " + "Second: ll_1: 1,ll_1: 1,ll_1: 1,ll_2: 4,ll_2: 5,ll_2: 6 " + "Third: ll_1: 1,ll_1: 1,ll_1: 1,ll_2: 4,ll_2: 5,ll_2: 6 " + "Int: 1" + ) + assert x == expected + + def test_unexpected_outputs(): @task def t1(a: int) -> int: