Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Async type engine #2752

Merged
merged 30 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
8c7767b
wip
wild-endeavor Sep 14, 2024
c51a26e
partly working
wild-endeavor Sep 16, 2024
fc83997
union of dict of list not working because dicttransformer doesn't han…
wild-endeavor Sep 16, 2024
3fbb3d5
clean up async loop detection, make dict async
wild-endeavor Sep 16, 2024
ae48297
fix lints outside of promise
wild-endeavor Sep 16, 2024
5b35dc8
Async/simple te wip (#2753)
wild-endeavor Sep 17, 2024
ce839e8
fix some lint
wild-endeavor Sep 17, 2024
aa00c97
signal is used by asyncio
wild-endeavor Sep 17, 2024
1cbd92e
spell?
wild-endeavor Sep 17, 2024
8c1c8d9
unit tests
wild-endeavor Sep 18, 2024
d7a9fe0
skip two more tests
wild-endeavor Sep 18, 2024
f173719
try cbs
wild-endeavor Sep 18, 2024
c5a5a38
too much callback
wild-endeavor Sep 18, 2024
6a48da2
Flyte loop in FlyteContext & multi-threaded loops (#2759)
wild-endeavor Oct 1, 2024
9c5e6f9
merge master and resolve some of the conflicts
wild-endeavor Oct 2, 2024
faedf6e
merge conflicts resolved
wild-endeavor Oct 3, 2024
41bc372
Merge remote-tracking branch 'origin/master' into async/simple-te
wild-endeavor Oct 4, 2024
27e16ee
migrate over to the new loop manager
wild-endeavor Oct 4, 2024
3617300
update comments and make additional code actually run in parallel
wild-endeavor Oct 4, 2024
5b3a6e0
merge master and resolve conflicts
wild-endeavor Oct 4, 2024
f30f4d5
lint
wild-endeavor Oct 5, 2024
e163a44
Merge remote-tracking branch 'origin/master' into async/simple-te
wild-endeavor Oct 7, 2024
efa2aa2
uncomment test
wild-endeavor Oct 7, 2024
fb021e8
add a paramspec and change a couple things to run_sync
wild-endeavor Oct 7, 2024
f0476ed
run sync missing
wild-endeavor Oct 7, 2024
3c3b5b4
typing extensions
wild-endeavor Oct 8, 2024
686e2c7
debugging (#2794)
wild-endeavor Oct 9, 2024
b952f8d
sort the file list
wild-endeavor Oct 9, 2024
dce9bf6
remove unneeded exception
wild-endeavor Oct 9, 2024
592c64b
lint
wild-endeavor Oct 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion flytekit/core/array_node_map_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from flytekit.tools.module_loader import load_object_from_module
from flytekit.types.pickle import pickle
from flytekit.types.pickle.pickle import FlytePickleTransformer
from flytekit.utils.asyn import loop_manager
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this be imported as run_sync and then use that? (I prefer to hide the loop manager as much as possible)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i guess so. i was hoping to one day figure out type hinting, which doesn't work today in either case, but i think will never work with the run sync way. I tried briefly getting functool wraps to work but alas to no avail.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ohh it was missing a paramspec. it works now. and changed promise.py and base_task.py to just use the run_sync.



class ArrayNodeMapTask(PythonTask):
Expand Down Expand Up @@ -253,7 +254,7 @@ def _literal_map_to_python_input(
v = literal_map.literals[k]
# If the input is offloaded, we need to unwrap it
if v.offloaded_metadata:
v = TypeEngine.unwrap_offloaded_literal(ctx, v)
v = loop_manager.run_sync(TypeEngine.unwrap_offloaded_literal, ctx, v)
if k not in self.bound_inputs:
# assert that v.collection is not None
if not v.collection or not isinstance(v.collection.literals, list):
Expand Down
38 changes: 25 additions & 13 deletions flytekit/core/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
from flytekit.models.documentation import Description, Documentation
from flytekit.models.interface import Variable
from flytekit.models.security import SecurityContext
from flytekit.utils.asyn import run_sync

DYNAMIC_PARTITIONS = "_uap"
MODEL_CARD = "_ucm"
Expand Down Expand Up @@ -608,7 +609,7 @@ def _literal_map_to_python_input(
) -> Dict[str, Any]:
return TypeEngine.literal_map_to_kwargs(ctx, literal_map, self.python_interface.inputs)

def _output_to_literal_map(self, native_outputs: Dict[int, Any], ctx: FlyteContext):
async def _output_to_literal_map(self, native_outputs: Dict[int, Any], ctx: FlyteContext):
expected_output_names = list(self._outputs_interface.keys())
if len(expected_output_names) == 1:
# Here we have to handle the fact that the task could've been declared with a typing.NamedTuple of
Expand All @@ -629,27 +630,35 @@ def _output_to_literal_map(self, native_outputs: Dict[int, Any], ctx: FlyteConte
with timeit("Translate the output to literals"):
literals = {}
omt = ctx.output_metadata_tracker
# Here is where we iterate through the outputs, need to call new type engine.
for i, (k, v) in enumerate(native_outputs_as_map.items()):
literal_type = self._outputs_interface[k].type
py_type = self.get_type_for_output_var(k, v)

if isinstance(v, tuple):
raise TypeError(f"Output({k}) in task '{self.name}' received a tuple {v}, instead of {py_type}")
try:
lit = TypeEngine.to_literal(ctx, v, py_type, literal_type)
literals[k] = lit
except Exception as e:
literals[k] = asyncio.create_task(TypeEngine.async_to_literal(ctx, v, py_type, literal_type))

await asyncio.gather(*literals.values(), return_exceptions=True)

for i, (k2, v2) in enumerate(literals.items()):
if v2.exception() is not None:
# only show the name of output key if it's user-defined (by default Flyte names these as "o<n>")
key = k if k != f"o{i}" else i
key = k2 if k2 != f"o{i}" else i
e: BaseException = v2.exception() # type: ignore # we know this is not optional
py_type = self.get_type_for_output_var(k2, native_outputs_as_map[k2])
e.args = (
f"Failed to convert outputs of task '{self.name}' at position {key}.\n"
f"Failed to convert type {type(native_outputs_as_map[expected_output_names[i]])} to type {py_type}.\n"
f"Error Message: {e.args[0]}.",
)
raise
# Now check if there is any output metadata associated with this output variable and attach it to the
# literal
if omt is not None:
raise e
literals[k2] = v2.result()

if omt is not None:
for i, (k, v) in enumerate(native_outputs_as_map.items()):
# Now check if there is any output metadata associated with this output variable and attach it to the
# literal
om = omt.get(v)
if om:
metadata = {}
Expand All @@ -669,7 +678,7 @@ def _output_to_literal_map(self, native_outputs: Dict[int, Any], ctx: FlyteConte
encoded = b64encode(s).decode("utf-8")
metadata[DYNAMIC_PARTITIONS] = encoded
if metadata:
lit.set_metadata(metadata)
literals[k].set_metadata(metadata) # type: ignore # we know these have been resolved

return _literal_models.LiteralMap(literals=literals), native_outputs_as_map

Expand Down Expand Up @@ -697,7 +706,7 @@ def _write_decks(self, native_inputs, native_outputs_as_map, ctx, new_user_param
async def _async_execute(self, native_inputs, native_outputs, ctx, exec_ctx, new_user_params):
native_outputs = await native_outputs
native_outputs = self.post_execute(new_user_params, native_outputs)
literals_map, native_outputs_as_map = self._output_to_literal_map(native_outputs, exec_ctx)
literals_map, native_outputs_as_map = await self._output_to_literal_map(native_outputs, exec_ctx)
self._write_decks(native_inputs, native_outputs_as_map, ctx, new_user_params)
return literals_map

Expand Down Expand Up @@ -787,7 +796,10 @@ def dispatch_execute(
return native_outputs

try:
literals_map, native_outputs_as_map = self._output_to_literal_map(native_outputs, exec_ctx)
with timeit("dispatch execute"):
literals_map, native_outputs_as_map = run_sync(
self._output_to_literal_map, native_outputs, exec_ctx
)
self._write_decks(native_inputs, native_outputs_as_map, ctx, new_user_params)
except (FlyteUploadDataException, FlyteDownloadDataException):
raise
Expand Down
38 changes: 24 additions & 14 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,10 @@
from flytekit.models.literals import Binary, Literal, Primitive, Scalar
from flytekit.models.task import Resources
from flytekit.models.types import SimpleType
from flytekit.utils.asyn import loop_manager, run_sync


def translate_inputs_to_literals(
async def _translate_inputs_to_literals(
ctx: FlyteContext,
incoming_values: Dict[str, Any],
flyte_interface_types: Dict[str, _interface_models.Variable],
Expand Down Expand Up @@ -94,16 +95,19 @@ def my_wf(in1: int, in2: int) -> int:
t = native_types[k]
try:
if type(v) is Promise:
v = resolve_attr_path_in_promise(v)
result[k] = TypeEngine.to_literal(ctx, v, t, var.type)
v = await resolve_attr_path_in_promise(v)
result[k] = await TypeEngine.async_to_literal(ctx, v, t, var.type)
except TypeTransformerFailedError as exc:
exc.args = (f"Failed argument '{k}': {exc.args[0]}",)
raise

return result


def resolve_attr_path_in_promise(p: Promise) -> Promise:
translate_inputs_to_literals = loop_manager.synced(_translate_inputs_to_literals)


async def resolve_attr_path_in_promise(p: Promise) -> Promise:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we be concerned with backward compatibility with changing a function from sync to async?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this particular function is pretty limited in terms of callees so hopefully that's okay.

"""
resolve_attr_path_in_promise resolves the attribute path in a promise and returns a new promise with the resolved value
This is for local execution only. The remote execution will be resolved in flytepropeller.
Expand Down Expand Up @@ -145,7 +149,9 @@ def resolve_attr_path_in_promise(p: Promise) -> Promise:
new_st = resolve_attr_path_in_pb_struct(st, attr_path=p.attr_path[used:])
literal_type = TypeEngine.to_literal_type(type(new_st))
# Reconstruct the resolved result to flyte literal (because the resolved result might not be struct)
curr_val = TypeEngine.to_literal(FlyteContextManager.current_context(), new_st, type(new_st), literal_type)
curr_val = await TypeEngine.async_to_literal(
FlyteContextManager.current_context(), new_st, type(new_st), literal_type
)
elif type(curr_val.value.value) is Binary:
binary_idl_obj = curr_val.value.value
if binary_idl_obj.tag == _common_constants.MESSAGEPACK:
Expand Down Expand Up @@ -786,7 +792,7 @@ def __rshift__(self, other: Any):
return Output(*promises) # type: ignore


def binding_data_from_python_std(
async def binding_data_from_python_std(
ctx: _flyte_context.FlyteContext,
expected_literal_type: _type_models.LiteralType,
t_value: Any,
Expand Down Expand Up @@ -821,7 +827,8 @@ def binding_data_from_python_std(
# If the value is not a container type, then we can directly convert it to a scalar in the Union case.
# This pushes the handling of the Union types to the type engine.
if not isinstance(t_value, list) and not isinstance(t_value, dict):
scalar = TypeEngine.to_literal(ctx, t_value, t_value_type or type(t_value), expected_literal_type).scalar
lit = await TypeEngine.async_to_literal(ctx, t_value, t_value_type or type(t_value), expected_literal_type)
scalar = lit.scalar
return _literals_models.BindingData(scalar=scalar)

# If it is a container type, then we need to iterate over the variants in the Union type, try each one. This is
Expand All @@ -831,7 +838,7 @@ def binding_data_from_python_std(
try:
lt_type = expected_literal_type.union_type.variants[i]
python_type = get_args(t_value_type)[i] if t_value_type else None
return binding_data_from_python_std(ctx, lt_type, t_value, python_type, nodes)
return await binding_data_from_python_std(ctx, lt_type, t_value, python_type, nodes)
except Exception:
logger.debug(
f"failed to bind data {t_value} with literal type {expected_literal_type.union_type.variants[i]}."
Expand All @@ -844,7 +851,9 @@ def binding_data_from_python_std(
sub_type: Optional[type] = ListTransformer.get_sub_type_or_none(t_value_type)
collection = _literals_models.BindingDataCollection(
bindings=[
binding_data_from_python_std(ctx, expected_literal_type.collection_type, t, sub_type or type(t), nodes)
await binding_data_from_python_std(
ctx, expected_literal_type.collection_type, t, sub_type or type(t), nodes
)
for t in t_value
]
)
Expand All @@ -860,13 +869,13 @@ def binding_data_from_python_std(
f"this should be a Dictionary type and it is not: {type(t_value)} vs {expected_literal_type}"
)
if expected_literal_type.simple == _type_models.SimpleType.STRUCT:
lit = TypeEngine.to_literal(ctx, t_value, type(t_value), expected_literal_type)
lit = await TypeEngine.async_to_literal(ctx, t_value, type(t_value), expected_literal_type)
return _literals_models.BindingData(scalar=lit.scalar)
else:
_, v_type = DictTransformer.extract_types_or_metadata(t_value_type)
m = _literals_models.BindingDataMap(
bindings={
k: binding_data_from_python_std(
k: await binding_data_from_python_std(
ctx, expected_literal_type.map_value_type, v, v_type or type(v), nodes
)
for k, v in t_value.items()
Expand All @@ -883,8 +892,8 @@ def binding_data_from_python_std(
)

# This is the scalar case - e.g. my_task(in1=5)
scalar = TypeEngine.to_literal(ctx, t_value, t_value_type or type(t_value), expected_literal_type).scalar
return _literals_models.BindingData(scalar=scalar)
lit = await TypeEngine.async_to_literal(ctx, t_value, t_value_type or type(t_value), expected_literal_type)
return _literals_models.BindingData(scalar=lit.scalar)


def binding_from_python_std(
Expand All @@ -895,7 +904,8 @@ def binding_from_python_std(
t_value_type: type,
) -> Tuple[_literals_models.Binding, List[Node]]:
nodes: List[Node] = []
binding_data = binding_data_from_python_std(
binding_data = run_sync(
binding_data_from_python_std,
ctx,
expected_literal_type,
t_value,
Expand Down
Loading
Loading