Skip to content

Commit

Permalink
Async type engine (#2752)
Browse files Browse the repository at this point in the history
Signed-off-by: Yee Hing Tong <wild-endeavor@users.noreply.github.com>
  • Loading branch information
wild-endeavor authored and kumare3 committed Nov 8, 2024
1 parent 857ff50 commit 65b214a
Show file tree
Hide file tree
Showing 23 changed files with 429 additions and 127 deletions.
3 changes: 2 additions & 1 deletion flytekit/core/array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from flytekit.tools.module_loader import load_object_from_module
from flytekit.types.pickle import pickle
from flytekit.types.pickle.pickle import FlytePickleTransformer
from flytekit.utils.asyn import loop_manager


class ArrayNodeMapTask(PythonTask):
Expand Down Expand Up @@ -253,7 +254,7 @@ def _literal_map_to_python_input(
v = literal_map.literals[k]
# If the input is offloaded, we need to unwrap it
if v.offloaded_metadata:
v = TypeEngine.unwrap_offloaded_literal(ctx, v)
v = loop_manager.run_sync(TypeEngine.unwrap_offloaded_literal, ctx, v)
if k not in self.bound_inputs:
# assert that v.collection is not None
if not v.collection or not isinstance(v.collection.literals, list):
Expand Down
38 changes: 25 additions & 13 deletions flytekit/core/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
from flytekit.models.documentation import Description, Documentation
from flytekit.models.interface import Variable
from flytekit.models.security import SecurityContext
from flytekit.utils.asyn import run_sync

DYNAMIC_PARTITIONS = "_uap"
MODEL_CARD = "_ucm"
Expand Down Expand Up @@ -608,7 +609,7 @@ def _literal_map_to_python_input(
) -> Dict[str, Any]:
return TypeEngine.literal_map_to_kwargs(ctx, literal_map, self.python_interface.inputs)

def _output_to_literal_map(self, native_outputs: Dict[int, Any], ctx: FlyteContext):
async def _output_to_literal_map(self, native_outputs: Dict[int, Any], ctx: FlyteContext):
expected_output_names = list(self._outputs_interface.keys())
if len(expected_output_names) == 1:
# Here we have to handle the fact that the task could've been declared with a typing.NamedTuple of
Expand All @@ -629,27 +630,35 @@ def _output_to_literal_map(self, native_outputs: Dict[int, Any], ctx: FlyteConte
with timeit("Translate the output to literals"):
literals = {}
omt = ctx.output_metadata_tracker
# Here is where we iterate through the outputs, need to call new type engine.
for i, (k, v) in enumerate(native_outputs_as_map.items()):
literal_type = self._outputs_interface[k].type
py_type = self.get_type_for_output_var(k, v)

if isinstance(v, tuple):
raise TypeError(f"Output({k}) in task '{self.name}' received a tuple {v}, instead of {py_type}")
try:
lit = TypeEngine.to_literal(ctx, v, py_type, literal_type)
literals[k] = lit
except Exception as e:
literals[k] = asyncio.create_task(TypeEngine.async_to_literal(ctx, v, py_type, literal_type))

await asyncio.gather(*literals.values(), return_exceptions=True)

for i, (k2, v2) in enumerate(literals.items()):
if v2.exception() is not None:
# only show the name of output key if it's user-defined (by default Flyte names these as "o<n>")
key = k if k != f"o{i}" else i
key = k2 if k2 != f"o{i}" else i
e: BaseException = v2.exception() # type: ignore # we know this is not optional
py_type = self.get_type_for_output_var(k2, native_outputs_as_map[k2])
e.args = (
f"Failed to convert outputs of task '{self.name}' at position {key}.\n"
f"Failed to convert type {type(native_outputs_as_map[expected_output_names[i]])} to type {py_type}.\n"
f"Error Message: {e.args[0]}.",
)
raise
# Now check if there is any output metadata associated with this output variable and attach it to the
# literal
if omt is not None:
raise e
literals[k2] = v2.result()

if omt is not None:
for i, (k, v) in enumerate(native_outputs_as_map.items()):
# Now check if there is any output metadata associated with this output variable and attach it to the
# literal
om = omt.get(v)
if om:
metadata = {}
Expand All @@ -669,7 +678,7 @@ def _output_to_literal_map(self, native_outputs: Dict[int, Any], ctx: FlyteConte
encoded = b64encode(s).decode("utf-8")
metadata[DYNAMIC_PARTITIONS] = encoded
if metadata:
lit.set_metadata(metadata)
literals[k].set_metadata(metadata) # type: ignore # we know these have been resolved

return _literal_models.LiteralMap(literals=literals), native_outputs_as_map

Expand Down Expand Up @@ -697,7 +706,7 @@ def _write_decks(self, native_inputs, native_outputs_as_map, ctx, new_user_param
async def _async_execute(self, native_inputs, native_outputs, ctx, exec_ctx, new_user_params):
native_outputs = await native_outputs
native_outputs = self.post_execute(new_user_params, native_outputs)
literals_map, native_outputs_as_map = self._output_to_literal_map(native_outputs, exec_ctx)
literals_map, native_outputs_as_map = await self._output_to_literal_map(native_outputs, exec_ctx)
self._write_decks(native_inputs, native_outputs_as_map, ctx, new_user_params)
return literals_map

Expand Down Expand Up @@ -787,7 +796,10 @@ def dispatch_execute(
return native_outputs

try:
literals_map, native_outputs_as_map = self._output_to_literal_map(native_outputs, exec_ctx)
with timeit("dispatch execute"):
literals_map, native_outputs_as_map = run_sync(
self._output_to_literal_map, native_outputs, exec_ctx
)
self._write_decks(native_inputs, native_outputs_as_map, ctx, new_user_params)
except (FlyteUploadDataException, FlyteDownloadDataException):
raise
Expand Down
38 changes: 24 additions & 14 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,10 @@
from flytekit.models.literals import Binary, Literal, Primitive, Scalar
from flytekit.models.task import Resources
from flytekit.models.types import SimpleType
from flytekit.utils.asyn import loop_manager, run_sync


def translate_inputs_to_literals(
async def _translate_inputs_to_literals(
ctx: FlyteContext,
incoming_values: Dict[str, Any],
flyte_interface_types: Dict[str, _interface_models.Variable],
Expand Down Expand Up @@ -94,16 +95,19 @@ def my_wf(in1: int, in2: int) -> int:
t = native_types[k]
try:
if type(v) is Promise:
v = resolve_attr_path_in_promise(v)
result[k] = TypeEngine.to_literal(ctx, v, t, var.type)
v = await resolve_attr_path_in_promise(v)
result[k] = await TypeEngine.async_to_literal(ctx, v, t, var.type)
except TypeTransformerFailedError as exc:
exc.args = (f"Failed argument '{k}': {exc.args[0]}",)
raise

return result


def resolve_attr_path_in_promise(p: Promise) -> Promise:
translate_inputs_to_literals = loop_manager.synced(_translate_inputs_to_literals)


async def resolve_attr_path_in_promise(p: Promise) -> Promise:
"""
resolve_attr_path_in_promise resolves the attribute path in a promise and returns a new promise with the resolved value
This is for local execution only. The remote execution will be resolved in flytepropeller.
Expand Down Expand Up @@ -145,7 +149,9 @@ def resolve_attr_path_in_promise(p: Promise) -> Promise:
new_st = resolve_attr_path_in_pb_struct(st, attr_path=p.attr_path[used:])
literal_type = TypeEngine.to_literal_type(type(new_st))
# Reconstruct the resolved result to flyte literal (because the resolved result might not be struct)
curr_val = TypeEngine.to_literal(FlyteContextManager.current_context(), new_st, type(new_st), literal_type)
curr_val = await TypeEngine.async_to_literal(
FlyteContextManager.current_context(), new_st, type(new_st), literal_type
)
elif type(curr_val.value.value) is Binary:
binary_idl_obj = curr_val.value.value
if binary_idl_obj.tag == _common_constants.MESSAGEPACK:
Expand Down Expand Up @@ -786,7 +792,7 @@ def __rshift__(self, other: Any):
return Output(*promises) # type: ignore


def binding_data_from_python_std(
async def binding_data_from_python_std(
ctx: _flyte_context.FlyteContext,
expected_literal_type: _type_models.LiteralType,
t_value: Any,
Expand Down Expand Up @@ -821,7 +827,8 @@ def binding_data_from_python_std(
# If the value is not a container type, then we can directly convert it to a scalar in the Union case.
# This pushes the handling of the Union types to the type engine.
if not isinstance(t_value, list) and not isinstance(t_value, dict):
scalar = TypeEngine.to_literal(ctx, t_value, t_value_type or type(t_value), expected_literal_type).scalar
lit = await TypeEngine.async_to_literal(ctx, t_value, t_value_type or type(t_value), expected_literal_type)
scalar = lit.scalar
return _literals_models.BindingData(scalar=scalar)

# If it is a container type, then we need to iterate over the variants in the Union type, try each one. This is
Expand All @@ -831,7 +838,7 @@ def binding_data_from_python_std(
try:
lt_type = expected_literal_type.union_type.variants[i]
python_type = get_args(t_value_type)[i] if t_value_type else None
return binding_data_from_python_std(ctx, lt_type, t_value, python_type, nodes)
return await binding_data_from_python_std(ctx, lt_type, t_value, python_type, nodes)
except Exception:
logger.debug(
f"failed to bind data {t_value} with literal type {expected_literal_type.union_type.variants[i]}."
Expand All @@ -844,7 +851,9 @@ def binding_data_from_python_std(
sub_type: Optional[type] = ListTransformer.get_sub_type_or_none(t_value_type)
collection = _literals_models.BindingDataCollection(
bindings=[
binding_data_from_python_std(ctx, expected_literal_type.collection_type, t, sub_type or type(t), nodes)
await binding_data_from_python_std(
ctx, expected_literal_type.collection_type, t, sub_type or type(t), nodes
)
for t in t_value
]
)
Expand All @@ -860,13 +869,13 @@ def binding_data_from_python_std(
f"this should be a Dictionary type and it is not: {type(t_value)} vs {expected_literal_type}"
)
if expected_literal_type.simple == _type_models.SimpleType.STRUCT:
lit = TypeEngine.to_literal(ctx, t_value, type(t_value), expected_literal_type)
lit = await TypeEngine.async_to_literal(ctx, t_value, type(t_value), expected_literal_type)
return _literals_models.BindingData(scalar=lit.scalar)
else:
_, v_type = DictTransformer.extract_types_or_metadata(t_value_type)
m = _literals_models.BindingDataMap(
bindings={
k: binding_data_from_python_std(
k: await binding_data_from_python_std(
ctx, expected_literal_type.map_value_type, v, v_type or type(v), nodes
)
for k, v in t_value.items()
Expand All @@ -883,8 +892,8 @@ def binding_data_from_python_std(
)

# This is the scalar case - e.g. my_task(in1=5)
scalar = TypeEngine.to_literal(ctx, t_value, t_value_type or type(t_value), expected_literal_type).scalar
return _literals_models.BindingData(scalar=scalar)
lit = await TypeEngine.async_to_literal(ctx, t_value, t_value_type or type(t_value), expected_literal_type)
return _literals_models.BindingData(scalar=lit.scalar)


def binding_from_python_std(
Expand All @@ -895,7 +904,8 @@ def binding_from_python_std(
t_value_type: type,
) -> Tuple[_literals_models.Binding, List[Node]]:
nodes: List[Node] = []
binding_data = binding_data_from_python_std(
binding_data = run_sync(
binding_data_from_python_std,
ctx,
expected_literal_type,
t_value,
Expand Down
Loading

0 comments on commit 65b214a

Please sign in to comment.