Skip to content

Commit

Permalink
remove binding_data_without_python_type and make python type optional…
Browse files Browse the repository at this point in the history
… in binding_data_from_python_std

Signed-off-by: Yee Hing Tong <wild-endeavor@users.noreply.github.com>
  • Loading branch information
wild-endeavor committed Feb 26, 2022
1 parent 1b8c67e commit 9cdaba9
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 66 deletions.
68 changes: 5 additions & 63 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,79 +548,21 @@ def __rshift__(self, other: Any):
return Output(*promises) # type: ignore


# TODO: Can we merge this with the python_std version and just make the Python type optional?
def binding_data_without_python_type(
ctx: _flyte_context.FlyteContext,
expected_literal_type: _type_models.LiteralType,
t_value: typing.Any,
) -> _literals_models.BindingData:
# This handles the case where the given value is the output of another task
if isinstance(t_value, Promise):
if not t_value.is_ready:
return _literals_models.BindingData(promise=t_value.ref)

elif isinstance(t_value, VoidPromise):
raise AssertionError(
f"Cannot pass output from task {t_value.task_name} that produces no outputs to a downstream task"
)

elif isinstance(t_value, list):
if expected_literal_type.collection_type is None:
raise AssertionError(f"this should be a list and it is not: {type(t_value)} vs {expected_literal_type}")
collection = _literals_models.BindingDataCollection(
bindings=[binding_data_without_python_type(ctx, expected_literal_type.collection_type, t) for t in t_value]
)
return _literals_models.BindingData(collection=collection)

elif isinstance(t_value, dict):
if (
expected_literal_type.map_value_type is None
and expected_literal_type.simple != _type_models.SimpleType.STRUCT
):
raise AssertionError(
f"this should be a Dictionary type and it is not: {type(t_value)} vs {expected_literal_type}"
)
if expected_literal_type.simple == _type_models.SimpleType.STRUCT:
lit = TypeEngine.to_literal(ctx, t_value, type(t_value), expected_literal_type)
return _literals_models.BindingData(scalar=lit.scalar)
else:
m = _literals_models.BindingDataMap(
bindings={
k: binding_data_without_python_type(ctx, expected_literal_type.map_value_type, v)
for k, v in t_value.items()
}
)
return _literals_models.BindingData(map=m)

elif isinstance(t_value, tuple):
raise AssertionError(
"Tuples are not a supported type for individual values in Flyte - got a tuple -"
f" {t_value}. If using named tuple in an inner task, please, de-reference the"
"actual attribute that you want to use. For example, in NamedTuple('OP', x=int) then"
"return v.x, instead of v, even if this has a single element"
)

# This is the scalar case - e.g. my_task(in1=5)
# This is the main difference from binding_data_from_python_std, we just take type() of the python value here
scalar = TypeEngine.to_literal(ctx, t_value, type(t_value), expected_literal_type).scalar
return _literals_models.BindingData(scalar=scalar)


def binding_from_flyte_std(
ctx: _flyte_context.FlyteContext,
var_name: str,
expected_literal_type: _type_models.LiteralType,
t_value: typing.Any,
) -> _literals_models.Binding:
binding_data = binding_data_without_python_type(ctx, expected_literal_type, t_value)
binding_data = binding_data_from_python_std(ctx, expected_literal_type, t_value, t_value_type=None)
return _literals_models.Binding(var=var_name, binding=binding_data)


def binding_data_from_python_std(
ctx: _flyte_context.FlyteContext,
expected_literal_type: _type_models.LiteralType,
t_value: typing.Any,
t_value_type: type,
t_value_type: Optional[type] = None,
) -> _literals_models.BindingData:
# This handles the case where the given value is the output of another task
if isinstance(t_value, Promise):
Expand All @@ -636,7 +578,7 @@ def binding_data_from_python_std(
if expected_literal_type.collection_type is None:
raise AssertionError(f"this should be a list and it is not: {type(t_value)} vs {expected_literal_type}")

sub_type = ListTransformer.get_sub_type(t_value_type)
sub_type = ListTransformer.get_sub_type(t_value_type) if t_value_type else None
collection = _literals_models.BindingDataCollection(
bindings=[
binding_data_from_python_std(ctx, expected_literal_type.collection_type, t, sub_type) for t in t_value
Expand All @@ -657,7 +599,7 @@ def binding_data_from_python_std(
lit = TypeEngine.to_literal(ctx, t_value, type(t_value), expected_literal_type)
return _literals_models.BindingData(scalar=lit.scalar)
else:
k_type, v_type = DictTransformer.get_dict_types(t_value_type)
k_type, v_type = DictTransformer.get_dict_types(t_value_type) if t_value_type else None, None
m = _literals_models.BindingDataMap(
bindings={
k: binding_data_from_python_std(ctx, expected_literal_type.map_value_type, v, v_type)
Expand All @@ -675,7 +617,7 @@ def binding_data_from_python_std(
)

# This is the scalar case - e.g. my_task(in1=5)
scalar = TypeEngine.to_literal(ctx, t_value, t_value_type, expected_literal_type).scalar
scalar = TypeEngine.to_literal(ctx, t_value, t_value_type or type(t_value), expected_literal_type).scalar
return _literals_models.BindingData(scalar=scalar)


Expand Down
6 changes: 3 additions & 3 deletions flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def get_serializable_workflow(
settings: SerializationSettings,
entity: WorkflowBase,
) -> admin_workflow_models.WorkflowSpec:
# TODO: Try to move up following config refactor
# TODO: Try to move up following config refactor - https://github.com/flyteorg/flyte/issues/2214
from flytekit.remote.workflow import FlyteWorkflow

# Get node models
Expand Down Expand Up @@ -249,7 +249,7 @@ def get_serializable_node(
if entity.flyte_entity is None:
raise Exception(f"Node {entity.id} has no flyte entity")

# TODO: Try to move back up following config refactor
# TODO: Try to move back up following config refactor - https://github.com/flyteorg/flyte/issues/2214
from flytekit.remote.launch_plan import FlyteLaunchPlan
from flytekit.remote.task import FlyteTask
from flytekit.remote.workflow import FlyteWorkflow
Expand Down Expand Up @@ -435,7 +435,7 @@ def get_serializable(
:return: The resulting control plane entity, in addition to being added to the mutable entity_mapping parameter
is also returned.
"""
# TODO: Try to replace following config refactor
# TODO: Try to replace following config refactor - https://github.com/flyteorg/flyte/issues/2214
from flytekit.remote.launch_plan import FlyteLaunchPlan
from flytekit.remote.task import FlyteTask
from flytekit.remote.workflow import FlyteWorkflow
Expand Down

0 comments on commit 9cdaba9

Please sign in to comment.