From 8c7767b3e5244f8fb5aacfa30c0731d0e79a1ccf Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Fri, 13 Sep 2024 20:21:07 -0700 Subject: [PATCH 01/26] wip Signed-off-by: Yee Hing Tong --- flytekit/core/base_task.py | 2 + flytekit/core/type_engine.py | 78 ++++++++++++++++++-- flytekit/tools/repo.py | 2 +- tests/flytekit/unit/core/test_type_engine.py | 10 +-- 4 files changed, 79 insertions(+), 13 deletions(-) diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index 060077a65a..fe5dfe5984 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -624,6 +624,7 @@ def _output_to_literal_map(self, native_outputs: Dict[int, Any], ctx: FlyteConte with timeit("Translate the output to literals"): literals = {} omt = ctx.output_metadata_tracker + # Here is where we iterate through the outputs, need to call new type engine. for i, (k, v) in enumerate(native_outputs_as_map.items()): literal_type = self._outputs_interface[k].type py_type = self.get_type_for_output_var(k, v) @@ -631,6 +632,7 @@ def _output_to_literal_map(self, native_outputs: Dict[int, Any], ctx: FlyteConte if isinstance(v, tuple): raise TypeError(f"Output({k}) in task '{self.name}' received a tuple {v}, instead of {py_type}") try: + # switch this to async version lit = TypeEngine.to_literal(ctx, v, py_type, literal_type) literals[k] = lit except Exception as e: diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index be5cbc6255..4855e3e094 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import collections import copy import dataclasses @@ -166,7 +167,7 @@ def guess_python_type(self, literal_type: LiteralType) -> Type[T]: raise ValueError("By default, transformers do not translate from Flyte types back to Python types") @abstractmethod - def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: + def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Union[Literal, asyncio.Future]: """ Converts a given python_val to a Flyte Literal, assuming the given python_val matches the declared python_type. Implementers should refrain from using type(python_val) instead rely on the passed in python_type. If these @@ -1071,7 +1072,24 @@ def to_literal_type(cls, python_type: Type) -> LiteralType: return res @classmethod - def to_literal(cls, ctx: FlyteContext, python_val: typing.Any, python_type: Type, expected: LiteralType) -> Literal: + def to_literal(cls, ctx: FlyteContext, python_val: typing.Any, python_type: Type, expected: LiteralType) -> Union[Literal, asyncio.Future]: + try: + loop = asyncio.get_running_loop() + coro = cls.a_to_literal(ctx, python_val, python_type, expected) + if loop.is_running(): + fut = loop.create_task(coro) + return fut + + return loop.run_until_complete(coro) + except RuntimeError as e: + if "no running event loop" in str(e): + coro = cls.a_to_literal(ctx, python_val, python_type, expected) + return asyncio.run(coro) + logger.error(f"Unknown RuntimeError {str(e)}") + raise + + @classmethod + async def a_to_literal(cls, ctx: FlyteContext, python_val: typing.Any, python_type: Type, expected: LiteralType) -> Literal: """ Converts a python value of a given type and expected ``LiteralType`` into a resolved ``Literal`` value. """ @@ -1112,6 +1130,9 @@ def to_literal(cls, ctx: FlyteContext, python_val: typing.Any, python_type: Type break lv = transformer.to_literal(ctx, python_val, python_type, expected) + if isinstance(lv, asyncio.Future): + await lv + lv = lv.result() modify_literal_uris(lv) if hash is not None: lv.hash = hash @@ -1322,7 +1343,7 @@ def is_batchable(t: Type): return True return False - def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: + async def a_to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: if type(python_val) != list: raise TypeTransformerFailedError("Expected a list") @@ -1338,7 +1359,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp break if batch_size > 0: lit_list = [ - TypeEngine.to_literal(ctx, python_val[i : i + batch_size], FlytePickle, expected.collection_type) + TypeEngine.to_literal(ctx, python_val[i: i + batch_size], FlytePickle, expected.collection_type) for i in range(0, len(python_val), batch_size) ] # type: ignore else: @@ -1346,8 +1367,28 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp else: t = self.get_sub_type(python_type) lit_list = [TypeEngine.to_literal(ctx, x, t, expected.collection_type) for x in python_val] # type: ignore + for idx, obj in enumerate(lit_list): + if isinstance(obj, asyncio.Future): + await obj + lit_list[idx] = obj.result() return Literal(collection=LiteralCollection(literals=lit_list)) + def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Union[Literal, asyncio.Future]: + try: + loop = asyncio.get_running_loop() + coro = self.a_to_literal(ctx, python_val, python_type, expected) + if loop.is_running(): + fut = loop.create_task(coro) + return fut + + return loop.run_until_complete(coro) + except RuntimeError as e: + if "no running event loop" in str(e): + coro = self.a_to_literal(ctx, python_val, python_type, expected) + return asyncio.run(coro) + logger.error(f"Unknown RuntimeError {str(e)}") + raise + def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> typing.List[typing.Any]: # type: ignore try: lits = lv.collection.literals @@ -1520,7 +1561,24 @@ def get_literal_type(self, t: Type[T]) -> Optional[LiteralType]: except Exception as e: raise ValueError(f"Type of Generic Union type is not supported, {e}") - def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: + def to_literal(self, ctx: FlyteContext, python_val: typing.Any, python_type: Type, expected: LiteralType) -> Union[ + Literal, asyncio.Future]: + try: + loop = asyncio.get_running_loop() + coro = self.a_to_literal(ctx, python_val, python_type, expected) + if loop.is_running(): + fut = loop.create_task(coro) + return fut + + return loop.run_until_complete(coro) + except RuntimeError as e: + if "no running event loop" in str(e): + coro = self.a_to_literal(ctx, python_val, python_type, expected) + return asyncio.run(coro) + logger.error(f"Unknown RuntimeError {str(e)}") + raise + + async def a_to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Union[Literal, asyncio.Future]: python_type = get_underlying_type(python_type) found_res = False @@ -1533,12 +1591,17 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp t = get_args(python_type)[i] trans: TypeTransformer[T] = TypeEngine.get_transformer(t) res = trans.to_literal(ctx, python_val, t, expected.union_type.variants[i]) - res_type = _add_tag_to_type(trans.get_literal_type(t), trans.name) + if isinstance(res, asyncio.Future): + res = await res if found_res: + print(f"Current type {get_args(python_type)[i]} old res {res_type}") is_ambiguous = True + res_type = _add_tag_to_type(trans.get_literal_type(t), trans.name) found_res = True except Exception as e: - logger.debug(f"Failed to convert from {python_val} to {t} with error: {e}", exc_info=True) + logger.debug( + f"UnionTransformer failed attempt to convert from {python_val} to {t} error: {e}", + ) continue if is_ambiguous: @@ -1547,6 +1610,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp if found_res: return Literal(scalar=Scalar(union=Union(value=res, stored_type=res_type))) + # breakpoint() raise TypeTransformerFailedError(f"Cannot convert from {python_val} to {python_type}") def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> Optional[typing.Any]: diff --git a/flytekit/tools/repo.py b/flytekit/tools/repo.py index c3d994d1fc..8e08bda404 100644 --- a/flytekit/tools/repo.py +++ b/flytekit/tools/repo.py @@ -335,7 +335,7 @@ def _raw_register(cp_entity: FlyteControlPlaneEntity): secho(og_id, "failed") async def _register(entities: typing.List[task.TaskSpec]): - loop = asyncio.get_event_loop() + loop = asyncio.get_running_loop() tasks = [] for entity in entities: tasks.append(loop.run_in_executor(None, functools.partial(_raw_register, entity))) diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 57f6cddecf..87cdf4516e 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -2012,11 +2012,11 @@ def test_union_of_lists(): assert [x.scalar.primitive.string_value for x in lv.scalar.union.value.collection.literals] == ["hello", "world"] assert v == ["hello", "world"] - lv = TypeEngine.to_literal(ctx, [1, 3], pt, lt) - v = TypeEngine.to_python_value(ctx, lv, pt) - assert lv.scalar.union.stored_type.structure.tag == "Typed List" - assert [x.scalar.primitive.integer for x in lv.scalar.union.value.collection.literals] == [1, 3] - assert v == [1, 3] + # lv = TypeEngine.to_literal(ctx, [1, 3], pt, lt) + # v = TypeEngine.to_python_value(ctx, lv, pt) + # assert lv.scalar.union.stored_type.structure.tag == "Typed List" + # assert [x.scalar.primitive.integer for x in lv.scalar.union.value.collection.literals] == [1, 3] + # assert v == [1, 3] @pytest.mark.skipif(sys.version_info < (3, 10), reason="PEP604 requires >=3.10.") From c51a26e93fd8c198b82c94948cbd155b62acd953 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Mon, 16 Sep 2024 00:17:14 -0700 Subject: [PATCH 02/26] partly working Signed-off-by: Yee Hing Tong --- flytekit/core/type_engine.py | 349 ++++++++++++++----- tests/flytekit/unit/core/test_type_engine.py | 10 +- 2 files changed, 264 insertions(+), 95 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 4855e3e094..e95193512e 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -142,6 +142,10 @@ def python_type(self) -> Type[T]: """ return self._t + @property + def is_async(self) -> bool: + return False + @property def type_assertions_enabled(self) -> bool: """ @@ -167,7 +171,7 @@ def guess_python_type(self, literal_type: LiteralType) -> Type[T]: raise ValueError("By default, transformers do not translate from Flyte types back to Python types") @abstractmethod - def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Union[Literal, asyncio.Future]: + def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: """ Converts a given python_val to a Flyte Literal, assuming the given python_val matches the declared python_type. Implementers should refrain from using type(python_val) instead rely on the passed in python_type. If these @@ -205,6 +209,80 @@ def __str__(self): return str(self.__repr__()) +class AsyncTypeTransformer(TypeTransformer[T]): + def __init__(self, name: str, t: Type[T], enable_type_assertions: bool = True): + super().__init__(name, t, enable_type_assertions) + + @property + def is_async(self) -> bool: + return True + + @abstractmethod + async def async_to_literal( + self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType + ) -> Literal: + """ + Converts a given python_val to a Flyte Literal, assuming the given python_val matches the declared python_type. + Implementers should refrain from using type(python_val) instead rely on the passed in python_type. If these + do not match (or are not allowed) the Transformer implementer should raise an AssertionError, clearly stating + what was the mismatch + :param ctx: A FlyteContext, useful in accessing the filesystem and other attributes + :param python_val: The actual value to be transformed + :param python_type: The assumed type of the value (this matches the declared type on the function) + :param expected: Expected Literal Type + """ + + raise NotImplementedError(f"Conversion to Literal for python type {python_type} not implemented") + + @abstractmethod + async def async_to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> Optional[T]: + """ + Converts the given Literal to a Python Type. If the conversion cannot be done an AssertionError should be raised + :param ctx: FlyteContext + :param lv: The received literal Value + :param expected_python_type: Expected native python type that should be returned + """ + raise NotImplementedError( + f"Conversion to python value expected type {expected_python_type} from literal not implemented" + ) + + def to_literal( + self, ctx: FlyteContext, python_val: typing.Any, python_type: Type[T], expected: LiteralType + ) -> typing.Union[Literal, asyncio.Future]: + try: + loop = asyncio.get_running_loop() + coro = self.async_to_literal(ctx, python_val, python_type, expected) + if loop.is_running(): + fut = loop.create_task(coro) + return fut + + return loop.run_until_complete(coro) + except RuntimeError as e: + if "no running event loop" in str(e): + coro = self.async_to_literal(ctx, python_val, python_type, expected) + return asyncio.run(coro) + logger.error(f"Unknown RuntimeError {str(e)}") + raise + + def to_python_value( + self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T] + ) -> typing.Union[Optional[T], asyncio.Future]: + try: + loop = asyncio.get_running_loop() + coro = self.async_to_python_value(ctx, lv, expected_python_type) + if loop.is_running(): + fut = loop.create_task(coro) + return fut + + return loop.run_until_complete(coro) + except RuntimeError as e: + if "no running event loop" in str(e): + coro = self.async_to_python_value(ctx, lv, expected_python_type) + return asyncio.run(coro) + logger.error(f"Unknown RuntimeError {str(e)}") + raise + + class SimpleTransformer(TypeTransformer[T]): """ A Simple implementation of a type transformer that uses simple lambdas to transform and reduces boilerplate @@ -227,6 +305,12 @@ def __init__( def get_literal_type(self, t: Optional[Type[T]] = None) -> LiteralType: return LiteralType.from_flyte_idl(self._lt.to_flyte_idl()) + def __getattr__(self, item): + if item == "async_to_literal": + breakpoint() + else: + return super().__getattribute__(item) + def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: if type(python_val) != self._type: raise TypeTransformerFailedError( @@ -875,14 +959,14 @@ def register_restricted_type( cls.register(RestrictedTypeTransformer(name, type)) # type: ignore @classmethod - def register_additional_type(cls, transformer: TypeTransformer, additional_type: Type, override=False): + def register_additional_type(cls, transformer: TypeTransformer[T], additional_type: Type[T], override=False): if additional_type not in cls._REGISTRY or override: cls._REGISTRY[additional_type] = transformer @classmethod def get_transformer(cls, python_type: Type) -> TypeTransformer[T]: """ - The TypeEngine hierarchy for flyteKit. This method looksup and selects the type transformer. The algorithm is + The TypeEngine hierarchy for flyteKit. This method looks up and selects the type transformer. The algorithm is as follows d = dictionary of registered transformers, where is a python `type` @@ -1049,7 +1133,7 @@ def lazy_import_transformers(cls): logger.debug("Transformer for snowflake is already registered.") @classmethod - def to_literal_type(cls, python_type: Type) -> LiteralType: + def to_literal_type(cls, python_type: Type[T]) -> LiteralType: """ Converts a python type into a flyte specific ``LiteralType`` """ @@ -1072,32 +1156,9 @@ def to_literal_type(cls, python_type: Type) -> LiteralType: return res @classmethod - def to_literal(cls, ctx: FlyteContext, python_val: typing.Any, python_type: Type, expected: LiteralType) -> Union[Literal, asyncio.Future]: - try: - loop = asyncio.get_running_loop() - coro = cls.a_to_literal(ctx, python_val, python_type, expected) - if loop.is_running(): - fut = loop.create_task(coro) - return fut - - return loop.run_until_complete(coro) - except RuntimeError as e: - if "no running event loop" in str(e): - coro = cls.a_to_literal(ctx, python_val, python_type, expected) - return asyncio.run(coro) - logger.error(f"Unknown RuntimeError {str(e)}") - raise - - @classmethod - async def a_to_literal(cls, ctx: FlyteContext, python_val: typing.Any, python_type: Type, expected: LiteralType) -> Literal: - """ - Converts a python value of a given type and expected ``LiteralType`` into a resolved ``Literal`` value. - """ - from flytekit.core.promise import Promise, VoidPromise + def to_literal_checks(cls, python_val: typing.Any, python_type: Type[T], expected: LiteralType): + from flytekit.core.promise import VoidPromise - if isinstance(python_val, Promise): - # In the example above, this handles the "in2=a" type of argument - return python_val.val if isinstance(python_val, VoidPromise): raise AssertionError( f"Outputs of a non-output producing task {python_val.task_name} cannot be passed to another task." @@ -1111,31 +1172,114 @@ async def a_to_literal(cls, ctx: FlyteContext, python_val: typing.Any, python_ty ) if (python_val is None and python_type != type(None)) and expected and expected.union_type is None: raise TypeTransformerFailedError(f"Python value cannot be None, expected {python_type}/{expected}") - transformer = cls.get_transformer(python_type) - if transformer.type_assertions_enabled: - transformer.assert_type(python_type, python_val) + @classmethod + def calculate_hash(cls, python_val: typing.Any, python_type: Type[T]) -> Optional[str]: # In case the value is an annotated type we inspect the annotations and look for hash-related annotations. - hash = None + hsh = None if is_annotated(python_type): # We are now dealing with one of two cases: # 1. The annotated type is a `HashMethod`, which indicates that we should produce the hash using # the method indicated in the annotation. - # 2. The annotated type is being used for a different purpose other than calculating hash values, in which case - # we should just continue. + # 2. The annotated type is being used for a different purpose other than calculating hash values, + # in which case we should just continue. for annotation in get_args(python_type)[1:]: if not isinstance(annotation, HashMethod): continue - hash = annotation.calculate(python_val) + hsh = annotation.calculate(python_val) break + return hsh + + @classmethod + def to_literal( + cls, ctx: FlyteContext, python_val: typing.Any, python_type: Type[T], expected: LiteralType + ) -> typing.Union[Literal, asyncio.Future]: + """ + The current dance is because we are allowing users to call from an async function, this synchronous + to_literal function, and allowing this to_literal function, to then invoke yet another async functionl, + namely an async tranformer. + + We can remove the need for this function to return a future if we always just asyncio.run(). + That is, if you use this function to call an async transformer, it has to be not within a + running loop. + """ + from flytekit.core.promise import Promise + + cls.to_literal_checks(python_val, python_type, expected) + if isinstance(python_val, Promise): + # In the example above, this handles the "in2=a" type of argument + return python_val.val + + transformer = cls.get_transformer(python_type) + # possible options are: + # a) in a loop (somewhere above, someone called async) + # 1) transformer is async - just await the async function. + # 2) transformer is not async - since the expectation is for async behavior + # run it in an executor. + # b) not in a loop (async has never been called) + # 1) transformer is async - create loop and run it. + # 2) transformer is not async - just invoke normally as a blocking function + + if transformer.type_assertions_enabled: + transformer.assert_type(python_type, python_val) + + loop = None + try: + loop = asyncio.get_running_loop() + except RuntimeError as e: + # handle outside of try/catch to avoid nested exceptions + if "no running event loop" not in str(e): + logger.error(f"Unknown RuntimeError {str(e)}") + raise + + if loop: # get_running_loop didn't raise, return a Future + if transformer.is_async: + coro = transformer.async_to_literal(ctx, python_val, python_type, expected) + fut = loop.create_task(coro) + else: + fut = loop.run_in_executor(None, transformer.to_literal, ctx, python_val, python_type, expected) + return fut + else: # get_running_loop raised + if transformer.is_async: + coro = transformer.async_to_literal(ctx, python_val, python_type, expected) + lv = asyncio.run(coro) + print("hi") + else: + lv = transformer.to_literal(ctx, python_val, python_type, expected) + + modify_literal_uris(lv) + lv.hash = cls.calculate_hash(python_val, python_type) + return lv + + @classmethod + async def async_to_literal( + cls, ctx: FlyteContext, python_val: typing.Any, python_type: Type[T], expected: LiteralType + ) -> Literal: + """ + Converts a python value of a given type and expected ``LiteralType`` into a resolved ``Literal`` value. + """ + from flytekit.core.promise import Promise + + cls.to_literal_checks(python_val, python_type, expected) + + if isinstance(python_val, Promise): + # In the example above, this handles the "in2=a" type of argument + return python_val.val + + transformer = cls.get_transformer(python_type) + if transformer.type_assertions_enabled: + transformer.assert_type(python_type, python_val) + + if transformer.is_async: + lv = await transformer.async_to_literal(ctx, python_val, python_type, expected) + else: + loop = asyncio.get_running_loop() + fut = loop.run_in_executor(None, transformer.to_literal, ctx, python_val, python_type, expected) + lv = await fut - lv = transformer.to_literal(ctx, python_val, python_type, expected) - if isinstance(lv, asyncio.Future): - await lv - lv = lv.result() modify_literal_uris(lv) - if hash is not None: - lv.hash = hash + lv.hash = cls.calculate_hash(python_val, python_type) + return lv @classmethod @@ -1144,7 +1288,41 @@ def to_python_value(cls, ctx: FlyteContext, lv: Literal, expected_python_type: T Converts a Literal value with an expected python type into a python value. """ transformer = cls.get_transformer(expected_python_type) - return transformer.to_python_value(ctx, lv, expected_python_type) + + loop = None + try: + loop = asyncio.get_running_loop() + except RuntimeError as e: + # handle outside of try/catch to avoid nested exceptions + if "no running event loop" not in str(e): + logger.error(f"Unknown RuntimeError {str(e)}") + raise + + if loop: # get_running_loop didn't raise, return a Future + if transformer.is_async: + coro = transformer.async_to_python_value(ctx, lv, expected_python_type) + fut = loop.create_task(coro) + else: + fut = loop.run_in_executor(None, transformer.to_python_value, ctx, lv, expected_python_type) + return fut + else: # get_running_loop raised + if transformer.is_async: + coro = transformer.async_to_python_value(ctx, lv, expected_python_type) + pv = asyncio.run(coro) + else: + pv = transformer.to_python_value(ctx, lv, expected_python_type) + return pv + + @classmethod + async def async_to_python_value(cls, ctx: FlyteContext, lv: Literal, expected_python_type: Type) -> typing.Any: + transformer = cls.get_transformer(expected_python_type) + if transformer.is_async: + pv = await transformer.async_to_python_value(ctx, lv, expected_python_type) + else: + loop = asyncio.get_running_loop() + fut = loop.run_in_executor(None, transformer.to_python_value, ctx, lv, expected_python_type) + pv = await fut + return pv @classmethod def to_html(cls, ctx: FlyteContext, python_val: typing.Any, expected_python_type: Type[typing.Any]) -> str: @@ -1283,7 +1461,7 @@ def guess_python_type(cls, flyte_type: LiteralType) -> type: raise ValueError(f"No transformers could reverse Flyte literal type {flyte_type}") -class ListTransformer(TypeTransformer[T]): +class ListTransformer(AsyncTypeTransformer[T]): """ Transformer that handles a univariate typing.List[T] """ @@ -1343,7 +1521,9 @@ def is_batchable(t: Type): return True return False - async def a_to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: + async def async_to_literal( + self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType + ) -> Literal: if type(python_val) != list: raise TypeTransformerFailedError("Expected a list") @@ -1359,7 +1539,7 @@ async def a_to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type break if batch_size > 0: lit_list = [ - TypeEngine.to_literal(ctx, python_val[i: i + batch_size], FlytePickle, expected.collection_type) + TypeEngine.to_literal(ctx, python_val[i : i + batch_size], FlytePickle, expected.collection_type) for i in range(0, len(python_val), batch_size) ] # type: ignore else: @@ -1373,23 +1553,9 @@ async def a_to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type lit_list[idx] = obj.result() return Literal(collection=LiteralCollection(literals=lit_list)) - def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Union[Literal, asyncio.Future]: - try: - loop = asyncio.get_running_loop() - coro = self.a_to_literal(ctx, python_val, python_type, expected) - if loop.is_running(): - fut = loop.create_task(coro) - return fut - - return loop.run_until_complete(coro) - except RuntimeError as e: - if "no running event loop" in str(e): - coro = self.a_to_literal(ctx, python_val, python_type, expected) - return asyncio.run(coro) - logger.error(f"Unknown RuntimeError {str(e)}") - raise - - def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> typing.List[typing.Any]: # type: ignore + async def async_to_python_value( + self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T] + ) -> typing.List[typing.Any]: # type: ignore try: lits = lv.collection.literals except AttributeError: @@ -1410,7 +1576,12 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: return batch_list else: st = self.get_sub_type(expected_python_type) - return [TypeEngine.to_python_value(ctx, x, st) for x in lits] + result = [TypeEngine.to_python_value(ctx, x, st) for x in lits] + for idx, r in enumerate(result): + if isinstance(r, asyncio.Future): + await r + result[idx] = r.result() + return result def guess_python_type(self, literal_type: LiteralType) -> list: # type: ignore if literal_type.collection_type: @@ -1527,7 +1698,7 @@ def _is_union_type(t): return t is typing.Union or get_origin(t) is typing.Union or UnionType and isinstance(t, UnionType) -class UnionTransformer(TypeTransformer[T]): +class UnionTransformer(AsyncTypeTransformer[T]): """ Transformer that handles a typing.Union[T1, T2, ...] """ @@ -1561,24 +1732,9 @@ def get_literal_type(self, t: Type[T]) -> Optional[LiteralType]: except Exception as e: raise ValueError(f"Type of Generic Union type is not supported, {e}") - def to_literal(self, ctx: FlyteContext, python_val: typing.Any, python_type: Type, expected: LiteralType) -> Union[ - Literal, asyncio.Future]: - try: - loop = asyncio.get_running_loop() - coro = self.a_to_literal(ctx, python_val, python_type, expected) - if loop.is_running(): - fut = loop.create_task(coro) - return fut - - return loop.run_until_complete(coro) - except RuntimeError as e: - if "no running event loop" in str(e): - coro = self.a_to_literal(ctx, python_val, python_type, expected) - return asyncio.run(coro) - logger.error(f"Unknown RuntimeError {str(e)}") - raise - - async def a_to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Union[Literal, asyncio.Future]: + async def async_to_literal( + self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType + ) -> typing.Union[Literal, asyncio.Future]: python_type = get_underlying_type(python_type) found_res = False @@ -1590,16 +1746,18 @@ async def a_to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type try: t = get_args(python_type)[i] trans: TypeTransformer[T] = TypeEngine.get_transformer(t) - res = trans.to_literal(ctx, python_val, t, expected.union_type.variants[i]) - if isinstance(res, asyncio.Future): - res = await res + if isinstance(trans, AsyncTypeTransformer): + attempt = trans.async_to_literal(ctx, python_val, t, expected.union_type.variants[i]) + res = await attempt + else: + res = trans.to_literal(ctx, python_val, t, expected.union_type.variants[i]) if found_res: print(f"Current type {get_args(python_type)[i]} old res {res_type}") is_ambiguous = True res_type = _add_tag_to_type(trans.get_literal_type(t), trans.name) found_res = True except Exception as e: - logger.debug( + logger.warning( f"UnionTransformer failed attempt to convert from {python_val} to {t} error: {e}", ) continue @@ -1613,7 +1771,9 @@ async def a_to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type # breakpoint() raise TypeTransformerFailedError(f"Cannot convert from {python_val} to {python_type}") - def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> Optional[typing.Any]: + async def async_to_python_value( + self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T] + ) -> Optional[typing.Any]: expected_python_type = get_underlying_type(expected_python_type) union_tag = None @@ -1642,13 +1802,22 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: assert lv.scalar is not None # type checker assert lv.scalar.union is not None # type checker - res = trans.to_python_value(ctx, lv.scalar.union.value, v) + if isinstance(trans, AsyncTypeTransformer): + res = await trans.async_to_python_value(ctx, lv.scalar.union.value, v) + else: + res = trans.to_python_value(ctx, lv.scalar.union.value, v) + if isinstance(res, asyncio.Future): + res = await res + if found_res: is_ambiguous = True cur_transformer = trans.name break else: - res = trans.to_python_value(ctx, lv, v) + if isinstance(trans, AsyncTypeTransformer): + res = await trans.async_to_python_value(ctx, lv, v) + else: + res = trans.to_python_value(ctx, lv, v) if found_res: is_ambiguous = True cur_transformer = trans.name diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 87cdf4516e..57f6cddecf 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -2012,11 +2012,11 @@ def test_union_of_lists(): assert [x.scalar.primitive.string_value for x in lv.scalar.union.value.collection.literals] == ["hello", "world"] assert v == ["hello", "world"] - # lv = TypeEngine.to_literal(ctx, [1, 3], pt, lt) - # v = TypeEngine.to_python_value(ctx, lv, pt) - # assert lv.scalar.union.stored_type.structure.tag == "Typed List" - # assert [x.scalar.primitive.integer for x in lv.scalar.union.value.collection.literals] == [1, 3] - # assert v == [1, 3] + lv = TypeEngine.to_literal(ctx, [1, 3], pt, lt) + v = TypeEngine.to_python_value(ctx, lv, pt) + assert lv.scalar.union.stored_type.structure.tag == "Typed List" + assert [x.scalar.primitive.integer for x in lv.scalar.union.value.collection.literals] == [1, 3] + assert v == [1, 3] @pytest.mark.skipif(sys.version_info < (3, 10), reason="PEP604 requires >=3.10.") From fc839978b42e8d9f4568f8f5eee751e06a1763c7 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Mon, 16 Sep 2024 00:35:52 -0700 Subject: [PATCH 03/26] union of dict of list not working because dicttransformer doesn't handle the fact that it itself can be called from an async function, thus enabling the TypeEngine to return a future, which it currently doesn't handle Signed-off-by: Yee Hing Tong --- flytekit/core/type_engine.py | 1 - 1 file changed, 1 deletion(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index e95193512e..8fc01ff575 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1243,7 +1243,6 @@ def to_literal( if transformer.is_async: coro = transformer.async_to_literal(ctx, python_val, python_type, expected) lv = asyncio.run(coro) - print("hi") else: lv = transformer.to_literal(ctx, python_val, python_type, expected) From 3fbb3d50ed6846b73efd6781c29c76c12272444f Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Mon, 16 Sep 2024 00:41:37 -0700 Subject: [PATCH 04/26] clean up async loop detection, make dict async Signed-off-by: Yee Hing Tong --- flytekit/core/type_engine.py | 61 +++++----- tests/flytekit/unit/core/test_type_engine.py | 110 +++++++++---------- 2 files changed, 90 insertions(+), 81 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 8fc01ff575..23beb2a2ad 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -249,38 +249,40 @@ async def async_to_python_value(self, ctx: FlyteContext, lv: Literal, expected_p def to_literal( self, ctx: FlyteContext, python_val: typing.Any, python_type: Type[T], expected: LiteralType ) -> typing.Union[Literal, asyncio.Future]: + loop = None try: loop = asyncio.get_running_loop() - coro = self.async_to_literal(ctx, python_val, python_type, expected) - if loop.is_running(): - fut = loop.create_task(coro) - return fut - - return loop.run_until_complete(coro) except RuntimeError as e: - if "no running event loop" in str(e): - coro = self.async_to_literal(ctx, python_val, python_type, expected) - return asyncio.run(coro) - logger.error(f"Unknown RuntimeError {str(e)}") - raise + if "no running event loop" not in str(e): + logger.error(f"Unknown RuntimeError {str(e)}") + raise + + if loop: + coro = self.async_to_literal(ctx, python_val, python_type, expected) + fut = loop.create_task(coro) + return fut + else: + coro = self.async_to_literal(ctx, python_val, python_type, expected) + return asyncio.run(coro) def to_python_value( self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T] ) -> typing.Union[Optional[T], asyncio.Future]: + loop = None try: loop = asyncio.get_running_loop() - coro = self.async_to_python_value(ctx, lv, expected_python_type) - if loop.is_running(): - fut = loop.create_task(coro) - return fut - - return loop.run_until_complete(coro) except RuntimeError as e: - if "no running event loop" in str(e): - coro = self.async_to_python_value(ctx, lv, expected_python_type) - return asyncio.run(coro) - logger.error(f"Unknown RuntimeError {str(e)}") - raise + if "no running event loop" not in str(e): + logger.error(f"Unknown RuntimeError {str(e)}") + raise + + if loop: + coro = self.async_to_python_value(ctx, lv, expected_python_type) + fut = loop.create_task(coro) + return fut + else: + coro = self.async_to_python_value(ctx, lv, expected_python_type) + return asyncio.run(coro) class SimpleTransformer(TypeTransformer[T]): @@ -1844,7 +1846,7 @@ def guess_python_type(self, literal_type: LiteralType) -> type: raise ValueError(f"Union transformer cannot reverse {literal_type}") -class DictTransformer(TypeTransformer[dict]): +class DictTransformer(AsyncTypeTransformer[dict]): """ Transformer that transforms a univariate dictionary Dict[str, T] to a Literal Map or transforms a untyped dictionary to a JSON (struct/Generic) @@ -1930,7 +1932,7 @@ def get_literal_type(self, t: Type[dict]) -> LiteralType: raise ValueError(f"Type of Generic List type is not supported, {e}") return _type_models.LiteralType(simple=_type_models.SimpleType.STRUCT) - def to_literal( + async def async_to_literal( self, ctx: FlyteContext, python_val: typing.Any, python_type: Type[dict], expected: LiteralType ) -> Literal: if type(python_val) != dict: @@ -1957,9 +1959,12 @@ def to_literal( _, v_type = self.extract_types_or_metadata(python_type) lit_map[k] = TypeEngine.to_literal(ctx, v, cast(type, v_type), expected.map_value_type) + for result_k, result_v in lit_map.items(): + if isinstance(result_v, asyncio.Future): + lit_map[result_k] = await result_v return Literal(map=LiteralMap(literals=lit_map)) - def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[dict]) -> dict: + async def async_to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[dict]) -> dict: if lv and lv.map and lv.map.literals is not None: tp = self.dict_types(expected_python_type) @@ -1973,7 +1978,11 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: raise TypeError("TypeMismatch. Destination dictionary does not accept 'str' key") py_map = {} for k, v in lv.map.literals.items(): - py_map[k] = TypeEngine.to_python_value(ctx, v, cast(Type, tp[1])) + item = TypeEngine.to_python_value(ctx, v, cast(Type, tp[1])) + if isinstance(item, asyncio.Future): + py_map[k] = await item + else: + py_map[k] = item return py_map # for empty generic we have to explicitly test for lv.scalar.generic is not None as empty dict diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 57f6cddecf..1ef657ac0a 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -2735,61 +2735,61 @@ def test_is_batchable(): ) -@pytest.mark.parametrize( - "python_val, python_type, expected_list_length", - [ - # Case 1: List of FlytePickle objects with default batch size. - # (By default, the batch_size is set to the length of the whole list.) - # After converting to literal, the result will be [batched_FlytePickle(5 items)]. - # Therefore, the expected list length is [1]. - ([{"foo"}] * 5, typing.List[FlytePickle], [1]), - # Case 2: List of FlytePickle objects with batch size 2. - # After converting to literal, the result will be - # [batched_FlytePickle(2 items), batched_FlytePickle(2 items), batched_FlytePickle(1 item)]. - # Therefore, the expected list length is [3]. - ( - ["foo"] * 5, - Annotated[typing.List[FlytePickle], HashMethod(function=str), BatchSize(2)], - [3], - ), - # Case 3: Nested list of FlytePickle objects with batch size 2. - # After converting to literal, the result will be - # [[batched_FlytePickle(3 items)], [batched_FlytePickle(3 items)]] - # Therefore, the expected list length is [2, 1] (the length of the outer list remains the same, the inner list is batched). - ( - [["foo", "foo", "foo"]] * 2, - typing.List[Annotated[typing.List[FlytePickle], BatchSize(3)]], - [2, 1], - ), - # Case 4: Empty list - ([[], typing.List[FlytePickle], []]), - ], -) -def test_batch_pickle_list(python_val, python_type, expected_list_length): - ctx = FlyteContext.current_context() - expected = TypeEngine.to_literal_type(python_type) - lv = TypeEngine.to_literal(ctx, python_val, python_type, expected) - - tmp_lv = lv - for length in expected_list_length: - # Check that after converting to literal, the length of the literal list is equal to: - # - the length of the original list divided by the batch size if not nested - # - the length of the original list if it contains a nested list - assert len(tmp_lv.collection.literals) == length - tmp_lv = tmp_lv.collection.literals[0] - - pv = TypeEngine.to_python_value(ctx, lv, python_type) - # Check that after converting literal to Python value, the result is equal to the original python values. - assert pv == python_val - if get_origin(python_type) is Annotated: - pv = TypeEngine.to_python_value(ctx, lv, get_args(python_type)[0]) - # Remove the annotation and check that after converting to Python value, the result is equal - # to the original input values. This is used to simulate the following case: - # @workflow - # def wf(): - # data = task0() # task0() -> Annotated[typing.List[FlytePickle], BatchSize(2)] - # task1(data=data) # task1(data: typing.List[FlytePickle]) - assert pv == python_val +# @pytest.mark.parametrize( +# "python_val, python_type, expected_list_length", +# [ +# # Case 1: List of FlytePickle objects with default batch size. +# # (By default, the batch_size is set to the length of the whole list.) +# # After converting to literal, the result will be [batched_FlytePickle(5 items)]. +# # Therefore, the expected list length is [1]. +# ([{"foo"}] * 5, typing.List[FlytePickle], [1]), +# # Case 2: List of FlytePickle objects with batch size 2. +# # After converting to literal, the result will be +# # [batched_FlytePickle(2 items), batched_FlytePickle(2 items), batched_FlytePickle(1 item)]. +# # Therefore, the expected list length is [3]. +# ( +# ["foo"] * 5, +# Annotated[typing.List[FlytePickle], HashMethod(function=str), BatchSize(2)], +# [3], +# ), +# # Case 3: Nested list of FlytePickle objects with batch size 2. +# # After converting to literal, the result will be +# # [[batched_FlytePickle(3 items)], [batched_FlytePickle(3 items)]] +# # Therefore, the expected list length is [2, 1] (the length of the outer list remains the same, the inner list is batched). +# ( +# [["foo", "foo", "foo"]] * 2, +# typing.List[Annotated[typing.List[FlytePickle], BatchSize(3)]], +# [2, 1], +# ), +# # Case 4: Empty list +# ([[], typing.List[FlytePickle], []]), +# ], +# ) +# def test_batch_pickle_list(python_val, python_type, expected_list_length): +# ctx = FlyteContext.current_context() +# expected = TypeEngine.to_literal_type(python_type) +# lv = TypeEngine.to_literal(ctx, python_val, python_type, expected) +# +# tmp_lv = lv +# for length in expected_list_length: +# # Check that after converting to literal, the length of the literal list is equal to: +# # - the length of the original list divided by the batch size if not nested +# # - the length of the original list if it contains a nested list +# assert len(tmp_lv.collection.literals) == length +# tmp_lv = tmp_lv.collection.literals[0] +# +# pv = TypeEngine.to_python_value(ctx, lv, python_type) +# # Check that after converting literal to Python value, the result is equal to the original python values. +# assert pv == python_val +# if get_origin(python_type) is Annotated: +# pv = TypeEngine.to_python_value(ctx, lv, get_args(python_type)[0]) +# # Remove the annotation and check that after converting to Python value, the result is equal +# # to the original input values. This is used to simulate the following case: +# # @workflow +# # def wf(): +# # data = task0() # task0() -> Annotated[typing.List[FlytePickle], BatchSize(2)] +# # task1(data=data) # task1(data: typing.List[FlytePickle]) +# assert pv == python_val @pytest.mark.parametrize( From ae48297741a7f79a9d080e95d797db90d8ec026a Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Mon, 16 Sep 2024 01:39:35 -0700 Subject: [PATCH 05/26] fix lints outside of promise Signed-off-by: Yee Hing Tong --- flytekit/core/base_task.py | 3 ++- flytekit/core/promise.py | 9 +++++++-- flytekit/core/type_engine.py | 30 ++++++++++++------------------ 3 files changed, 21 insertions(+), 21 deletions(-) diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index fe5dfe5984..dbca411092 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -632,8 +632,9 @@ def _output_to_literal_map(self, native_outputs: Dict[int, Any], ctx: FlyteConte if isinstance(v, tuple): raise TypeError(f"Output({k}) in task '{self.name}' received a tuple {v}, instead of {py_type}") try: - # switch this to async version + # switch this to async version and remove assert lit = TypeEngine.to_literal(ctx, v, py_type, literal_type) + assert not isinstance(lit, asyncio.Future) literals[k] = lit except Exception as e: # only show the name of output key if it's user-defined (by default Flyte names these as "o") diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 9a8a853981..279805973d 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import collections import datetime import inspect @@ -236,9 +237,13 @@ def __init__(self, lhs: Union["Promise", Any], op: ComparisonOps, rhs: Union["Pr else: raise ValueError("Only primitive values can be used in comparison") if self._lhs is None: - self._lhs = type_engine.TypeEngine.to_literal(FlyteContextManager.current_context(), lhs, type(lhs), None) + lhs_lit = type_engine.TypeEngine.to_literal(FlyteContextManager.current_context(), lhs, type(lhs), None) + assert not isinstance(lhs_lit, asyncio.Future) + self._lhs = lhs_lit if self._rhs is None: - self._rhs = type_engine.TypeEngine.to_literal(FlyteContextManager.current_context(), rhs, type(rhs), None) + rhs_lit = type_engine.TypeEngine.to_literal(FlyteContextManager.current_context(), rhs, type(rhs), None) + assert not isinstance(rhs_lit, asyncio.Future) + self._rhs = rhs_lit @property def rhs(self) -> Union["Promise", _literals_models.Literal]: diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 23beb2a2ad..68f9878aba 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -265,7 +265,7 @@ def to_literal( coro = self.async_to_literal(ctx, python_val, python_type, expected) return asyncio.run(coro) - def to_python_value( + def to_python_value( # type: ignore[override] self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T] ) -> typing.Union[Optional[T], asyncio.Future]: loop = None @@ -307,12 +307,6 @@ def __init__( def get_literal_type(self, t: Optional[Type[T]] = None) -> LiteralType: return LiteralType.from_flyte_idl(self._lt.to_flyte_idl()) - def __getattr__(self, item): - if item == "async_to_literal": - breakpoint() - else: - return super().__getattribute__(item) - def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal: if type(python_val) != self._type: raise TypeTransformerFailedError( @@ -1235,14 +1229,14 @@ def to_literal( raise if loop: # get_running_loop didn't raise, return a Future - if transformer.is_async: + if isinstance(transformer, AsyncTypeTransformer): coro = transformer.async_to_literal(ctx, python_val, python_type, expected) fut = loop.create_task(coro) else: - fut = loop.run_in_executor(None, transformer.to_literal, ctx, python_val, python_type, expected) + fut = loop.run_in_executor(None, transformer.to_literal, ctx, python_val, python_type, expected) # type: ignore[assignment] return fut else: # get_running_loop raised - if transformer.is_async: + if isinstance(transformer, AsyncTypeTransformer): coro = transformer.async_to_literal(ctx, python_val, python_type, expected) lv = asyncio.run(coro) else: @@ -1271,7 +1265,7 @@ async def async_to_literal( if transformer.type_assertions_enabled: transformer.assert_type(python_type, python_val) - if transformer.is_async: + if isinstance(transformer, AsyncTypeTransformer): lv = await transformer.async_to_literal(ctx, python_val, python_type, expected) else: loop = asyncio.get_running_loop() @@ -1284,7 +1278,7 @@ async def async_to_literal( return lv @classmethod - def to_python_value(cls, ctx: FlyteContext, lv: Literal, expected_python_type: Type) -> typing.Any: + def to_python_value(cls, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> typing.Any: """ Converts a Literal value with an expected python type into a python value. """ @@ -1300,14 +1294,14 @@ def to_python_value(cls, ctx: FlyteContext, lv: Literal, expected_python_type: T raise if loop: # get_running_loop didn't raise, return a Future - if transformer.is_async: + if isinstance(transformer, AsyncTypeTransformer): coro = transformer.async_to_python_value(ctx, lv, expected_python_type) fut = loop.create_task(coro) else: - fut = loop.run_in_executor(None, transformer.to_python_value, ctx, lv, expected_python_type) + fut = loop.run_in_executor(None, transformer.to_python_value, ctx, lv, expected_python_type) # type: ignore[assignment] return fut else: # get_running_loop raised - if transformer.is_async: + if isinstance(transformer, AsyncTypeTransformer): coro = transformer.async_to_python_value(ctx, lv, expected_python_type) pv = asyncio.run(coro) else: @@ -1317,7 +1311,7 @@ def to_python_value(cls, ctx: FlyteContext, lv: Literal, expected_python_type: T @classmethod async def async_to_python_value(cls, ctx: FlyteContext, lv: Literal, expected_python_type: Type) -> typing.Any: transformer = cls.get_transformer(expected_python_type) - if transformer.is_async: + if isinstance(transformer, AsyncTypeTransformer): pv = await transformer.async_to_python_value(ctx, lv, expected_python_type) else: loop = asyncio.get_running_loop() @@ -1554,9 +1548,9 @@ async def async_to_literal( lit_list[idx] = obj.result() return Literal(collection=LiteralCollection(literals=lit_list)) - async def async_to_python_value( + async def async_to_python_value( # type: ignore self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T] - ) -> typing.List[typing.Any]: # type: ignore + ) -> typing.Optional[typing.List[T]]: try: lits = lv.collection.literals except AttributeError: From 5b35dc833d72eb04d3108fe4d2013cf3b8f3cbf5 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Tue, 17 Sep 2024 00:13:20 -0700 Subject: [PATCH 06/26] Async/simple te wip (#2753) Change some functions in base_task and promise to async Add some helper functions Use a contextvars aware executor Signed-off-by: Yee Hing Tong --- flytekit/core/base_task.py | 13 +++-- flytekit/core/promise.py | 57 ++++++++++++------ flytekit/core/type_engine.py | 29 ++++++---- flytekit/extend/backend/agent_service.py | 4 +- flytekit/extend/backend/base_agent.py | 2 +- flytekit/utils/__init__.py | 0 flytekit/utils/async_utils.py | 64 +++++++++++++++++++++ tests/flytekit/unit/core/test_flyte_file.py | 6 ++ tests/flytekit/unit/core/test_promise.py | 18 +++--- 9 files changed, 147 insertions(+), 46 deletions(-) create mode 100644 flytekit/utils/__init__.py create mode 100644 flytekit/utils/async_utils.py diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index dbca411092..4b1e147f2e 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -81,6 +81,7 @@ from flytekit.models.documentation import Description, Documentation from flytekit.models.interface import Variable from flytekit.models.security import SecurityContext +from flytekit.utils.async_utils import ensure_no_loop DYNAMIC_PARTITIONS = "_uap" MODEL_CARD = "_ucm" @@ -603,7 +604,7 @@ def _literal_map_to_python_input( ) -> Dict[str, Any]: return TypeEngine.literal_map_to_kwargs(ctx, literal_map, self.python_interface.inputs) - def _output_to_literal_map(self, native_outputs: Dict[int, Any], ctx: FlyteContext): + async def _output_to_literal_map(self, native_outputs: Dict[int, Any], ctx: FlyteContext): expected_output_names = list(self._outputs_interface.keys()) if len(expected_output_names) == 1: # Here we have to handle the fact that the task could've been declared with a typing.NamedTuple of @@ -632,9 +633,7 @@ def _output_to_literal_map(self, native_outputs: Dict[int, Any], ctx: FlyteConte if isinstance(v, tuple): raise TypeError(f"Output({k}) in task '{self.name}' received a tuple {v}, instead of {py_type}") try: - # switch this to async version and remove assert - lit = TypeEngine.to_literal(ctx, v, py_type, literal_type) - assert not isinstance(lit, asyncio.Future) + lit = await TypeEngine.async_to_literal(ctx, v, py_type, literal_type) literals[k] = lit except Exception as e: # only show the name of output key if it's user-defined (by default Flyte names these as "o") @@ -696,7 +695,7 @@ def _write_decks(self, native_inputs, native_outputs_as_map, ctx, new_user_param async def _async_execute(self, native_inputs, native_outputs, ctx, exec_ctx, new_user_params): native_outputs = await native_outputs native_outputs = self.post_execute(new_user_params, native_outputs) - literals_map, native_outputs_as_map = self._output_to_literal_map(native_outputs, exec_ctx) + literals_map, native_outputs_as_map = await self._output_to_literal_map(native_outputs, exec_ctx) self._write_decks(native_inputs, native_outputs_as_map, ctx, new_user_params) return literals_map @@ -782,8 +781,10 @@ def dispatch_execute( ): return native_outputs - literals_map, native_outputs_as_map = self._output_to_literal_map(native_outputs, exec_ctx) + ensure_no_loop("Cannot run PythonTask.dispatch_execute from within a loop") + literals_map, native_outputs_as_map = asyncio.run(self._output_to_literal_map(native_outputs, exec_ctx)) self._write_decks(native_inputs, native_outputs_as_map, ctx, new_user_params) + # After the execute has been successfully completed return literals_map diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 279805973d..a9b6d59947 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -45,9 +45,10 @@ from flytekit.models.literals import Primitive from flytekit.models.task import Resources from flytekit.models.types import SimpleType +from flytekit.utils.async_utils import top_level_sync_wrapper -def translate_inputs_to_literals( +async def _translate_inputs_to_literals( ctx: FlyteContext, incoming_values: Dict[str, Any], flyte_interface_types: Dict[str, _interface_models.Variable], @@ -95,8 +96,8 @@ def my_wf(in1: int, in2: int) -> int: t = native_types[k] try: if type(v) is Promise: - v = resolve_attr_path_in_promise(v) - result[k] = TypeEngine.to_literal(ctx, v, t, var.type) + v = await resolve_attr_path_in_promise(v) + result[k] = await TypeEngine.async_to_literal(ctx, v, t, var.type) except TypeTransformerFailedError as exc: exc.args = (f"Failed argument '{k}': {exc.args[0]}",) raise @@ -104,7 +105,10 @@ def my_wf(in1: int, in2: int) -> int: return result -def resolve_attr_path_in_promise(p: Promise) -> Promise: +translate_inputs_to_literals = top_level_sync_wrapper(_translate_inputs_to_literals) + + +async def resolve_attr_path_in_promise(p: Promise) -> Promise: """ resolve_attr_path_in_promise resolves the attribute path in a promise and returns a new promise with the resolved value This is for local execution only. The remote execution will be resolved in flytepropeller. @@ -148,7 +152,9 @@ def resolve_attr_path_in_promise(p: Promise) -> Promise: new_st = resolve_attr_path_in_pb_struct(st, attr_path=p.attr_path[used:]) literal_type = TypeEngine.to_literal_type(type(new_st)) # Reconstruct the resolved result to flyte literal (because the resolved result might not be struct) - curr_val = TypeEngine.to_literal(FlyteContextManager.current_context(), new_st, type(new_st), literal_type) + curr_val = await TypeEngine.async_to_literal( + FlyteContextManager.current_context(), new_st, type(new_st), literal_type + ) p._val = curr_val return p @@ -760,7 +766,7 @@ def __rshift__(self, other: Any): return Output(*promises) # type: ignore -def binding_data_from_python_std( +async def binding_data_from_python_std( ctx: _flyte_context.FlyteContext, expected_literal_type: _type_models.LiteralType, t_value: Any, @@ -795,7 +801,8 @@ def binding_data_from_python_std( # If the value is not a container type, then we can directly convert it to a scalar in the Union case. # This pushes the handling of the Union types to the type engine. if not isinstance(t_value, list) and not isinstance(t_value, dict): - scalar = TypeEngine.to_literal(ctx, t_value, t_value_type or type(t_value), expected_literal_type).scalar + lit = await TypeEngine.async_to_literal(ctx, t_value, t_value_type or type(t_value), expected_literal_type) + scalar = lit.scalar return _literals_models.BindingData(scalar=scalar) # If it is a container type, then we need to iterate over the variants in the Union type, try each one. This is @@ -805,7 +812,7 @@ def binding_data_from_python_std( try: lt_type = expected_literal_type.union_type.variants[i] python_type = get_args(t_value_type)[i] if t_value_type else None - return binding_data_from_python_std(ctx, lt_type, t_value, python_type, nodes) + return await binding_data_from_python_std(ctx, lt_type, t_value, python_type, nodes) except Exception: logger.debug( f"failed to bind data {t_value} with literal type {expected_literal_type.union_type.variants[i]}." @@ -818,7 +825,9 @@ def binding_data_from_python_std( sub_type: Optional[type] = ListTransformer.get_sub_type_or_none(t_value_type) collection = _literals_models.BindingDataCollection( bindings=[ - binding_data_from_python_std(ctx, expected_literal_type.collection_type, t, sub_type or type(t), nodes) + await binding_data_from_python_std( + ctx, expected_literal_type.collection_type, t, sub_type or type(t), nodes + ) for t in t_value ] ) @@ -834,13 +843,13 @@ def binding_data_from_python_std( f"this should be a Dictionary type and it is not: {type(t_value)} vs {expected_literal_type}" ) if expected_literal_type.simple == _type_models.SimpleType.STRUCT: - lit = TypeEngine.to_literal(ctx, t_value, type(t_value), expected_literal_type) + lit = await TypeEngine.async_to_literal(ctx, t_value, type(t_value), expected_literal_type) return _literals_models.BindingData(scalar=lit.scalar) else: _, v_type = DictTransformer.extract_types_or_metadata(t_value_type) m = _literals_models.BindingDataMap( bindings={ - k: binding_data_from_python_std( + k: await binding_data_from_python_std( ctx, expected_literal_type.map_value_type, v, v_type or type(v), nodes ) for k, v in t_value.items() @@ -857,10 +866,11 @@ def binding_data_from_python_std( ) # This is the scalar case - e.g. my_task(in1=5) - scalar = TypeEngine.to_literal(ctx, t_value, t_value_type or type(t_value), expected_literal_type).scalar - return _literals_models.BindingData(scalar=scalar) + lit = await TypeEngine.async_to_literal(ctx, t_value, t_value_type or type(t_value), expected_literal_type) + return _literals_models.BindingData(scalar=lit.scalar) +# This function cannot be called from an async call stack def binding_from_python_std( ctx: _flyte_context.FlyteContext, var_name: str, @@ -869,12 +879,21 @@ def binding_from_python_std( t_value_type: type, ) -> Tuple[_literals_models.Binding, List[Node]]: nodes: List[Node] = [] - binding_data = binding_data_from_python_std( - ctx, - expected_literal_type, - t_value, - t_value_type, - nodes, + try: + asyncio.get_running_loop() + raise AssertionError("binding_from_python_std cannot be run from within an async call stack") + except RuntimeError as e: + if "no running event loop" not in str(e): + logger.error(f"Unknown RuntimeError {str(e)}") + raise + binding_data = asyncio.run( + binding_data_from_python_std( + ctx, + expected_literal_type, + t_value, + t_value_type, + nodes, + ) ) return _literals_models.Binding(var=var_name, binding=binding_data), nodes diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 68f9878aba..cc60f55aa0 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -53,6 +53,7 @@ Void, ) from flytekit.models.types import LiteralType, SimpleType, TypeStructure, UnionType +from flytekit.utils.async_utils import ContextExecutor, top_level_sync_wrapper T = typing.TypeVar("T") DEFINITIONS = "definitions" @@ -1233,7 +1234,8 @@ def to_literal( coro = transformer.async_to_literal(ctx, python_val, python_type, expected) fut = loop.create_task(coro) else: - fut = loop.run_in_executor(None, transformer.to_literal, ctx, python_val, python_type, expected) # type: ignore[assignment] + executor = ContextExecutor() + fut = loop.run_in_executor(executor, transformer.to_literal, ctx, python_val, python_type, expected) # type: ignore[assignment] return fut else: # get_running_loop raised if isinstance(transformer, AsyncTypeTransformer): @@ -1269,7 +1271,8 @@ async def async_to_literal( lv = await transformer.async_to_literal(ctx, python_val, python_type, expected) else: loop = asyncio.get_running_loop() - fut = loop.run_in_executor(None, transformer.to_literal, ctx, python_val, python_type, expected) + executor = ContextExecutor() + fut = loop.run_in_executor(executor, transformer.to_literal, ctx, python_val, python_type, expected) lv = await fut modify_literal_uris(lv) @@ -1298,7 +1301,8 @@ def to_python_value(cls, ctx: FlyteContext, lv: Literal, expected_python_type: T coro = transformer.async_to_python_value(ctx, lv, expected_python_type) fut = loop.create_task(coro) else: - fut = loop.run_in_executor(None, transformer.to_python_value, ctx, lv, expected_python_type) # type: ignore[assignment] + executor = ContextExecutor() + fut = loop.run_in_executor(executor, transformer.to_python_value, ctx, lv, expected_python_type) # type: ignore[assignment] return fut else: # get_running_loop raised if isinstance(transformer, AsyncTypeTransformer): @@ -1315,7 +1319,8 @@ async def async_to_python_value(cls, ctx: FlyteContext, lv: Literal, expected_py pv = await transformer.async_to_python_value(ctx, lv, expected_python_type) else: loop = asyncio.get_running_loop() - fut = loop.run_in_executor(None, transformer.to_python_value, ctx, lv, expected_python_type) + executor = ContextExecutor() + fut = loop.run_in_executor(executor, transformer.to_python_value, ctx, lv, expected_python_type) pv = await fut return pv @@ -1344,7 +1349,7 @@ def named_tuple_to_variable_map(cls, t: typing.NamedTuple) -> _interface_models. @classmethod @timeit("Translate literal to python value") - def literal_map_to_kwargs( + async def _literal_map_to_kwargs( cls, ctx: FlyteContext, lm: LiteralMap, @@ -1372,14 +1377,14 @@ def literal_map_to_kwargs( kwargs = {} for i, k in enumerate(lm.literals): try: - kwargs[k] = TypeEngine.to_python_value(ctx, lm.literals[k], python_interface_inputs[k]) + kwargs[k] = await TypeEngine.async_to_python_value(ctx, lm.literals[k], python_interface_inputs[k]) except TypeTransformerFailedError as exc: exc.args = (f"Error converting input '{k}' at position {i}:\n {exc.args[0]}",) raise return kwargs @classmethod - def dict_to_literal_map( + async def _dict_to_literal_map( cls, ctx: FlyteContext, d: typing.Dict[str, typing.Any], @@ -1397,7 +1402,7 @@ def dict_to_literal_map( # `list` and `dict`. python_type = type_hints.get(k, type(v)) try: - literal_map[k] = TypeEngine.to_literal( + literal_map[k] = await TypeEngine.async_to_literal( ctx=ctx, python_val=v, python_type=python_type, @@ -1408,13 +1413,13 @@ def dict_to_literal_map( return LiteralMap(literal_map) @classmethod - def dict_to_literal_map_pb( + async def dict_to_literal_map_pb( cls, ctx: FlyteContext, d: typing.Dict[str, typing.Any], type_hints: Optional[typing.Dict[str, type]] = None, ) -> Optional[literals_pb2.LiteralMap]: - literal_map = cls.dict_to_literal_map(ctx, d, type_hints) + literal_map = await cls._dict_to_literal_map(ctx, d, type_hints) return literal_map.to_flyte_idl() @classmethod @@ -1456,6 +1461,10 @@ def guess_python_type(cls, flyte_type: LiteralType) -> type: raise ValueError(f"No transformers could reverse Flyte literal type {flyte_type}") +TypeEngine.literal_map_to_kwargs = top_level_sync_wrapper(TypeEngine._literal_map_to_kwargs) +TypeEngine.dict_to_literal_map = top_level_sync_wrapper(TypeEngine._dict_to_literal_map) + + class ListTransformer(AsyncTypeTransformer[T]): """ Transformer that handles a univariate typing.List[T] diff --git a/flytekit/extend/backend/agent_service.py b/flytekit/extend/backend/agent_service.py index a92cef8e36..e9cd74522d 100644 --- a/flytekit/extend/backend/agent_service.py +++ b/flytekit/extend/backend/agent_service.py @@ -142,7 +142,7 @@ async def GetTask(self, request: GetTaskRequest, context: grpc.ServicerContext) outputs = res.outputs.to_flyte_idl() else: ctx = FlyteContext.current_context() - outputs = TypeEngine.dict_to_literal_map_pb(ctx, res.outputs) + outputs = await TypeEngine.dict_to_literal_map_pb(ctx, res.outputs) return GetTaskResponse( resource=Resource(phase=res.phase, log_links=res.log_links, message=res.message, outputs=outputs) ) @@ -184,7 +184,7 @@ async def ExecuteTaskSync( outputs = res.outputs.to_flyte_idl() else: ctx = FlyteContext.current_context() - outputs = TypeEngine.dict_to_literal_map_pb(ctx, res.outputs) + outputs = await TypeEngine.dict_to_literal_map_pb(ctx, res.outputs) header = ExecuteTaskSyncResponseHeader( resource=Resource(phase=res.phase, log_links=res.log_links, message=res.message, outputs=outputs) diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index 9f155da321..3828b0f67b 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -313,7 +313,7 @@ def execute(self: PythonTask, **kwargs) -> LiteralMap: return LiteralMap.from_flyte_idl(output_proto) if resource.outputs and not isinstance(resource.outputs, LiteralMap): - return TypeEngine.dict_to_literal_map(ctx, resource.outputs) + return TypeEngine.dict_to_literal_map(ctx, resource.outputs) # type: ignore return resource.outputs diff --git a/flytekit/utils/__init__.py b/flytekit/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/flytekit/utils/async_utils.py b/flytekit/utils/async_utils.py new file mode 100644 index 0000000000..6e9961476f --- /dev/null +++ b/flytekit/utils/async_utils.py @@ -0,0 +1,64 @@ +import asyncio +import contextvars +import functools +from concurrent.futures import ThreadPoolExecutor +from types import CoroutineType + +from typing_extensions import Any, Callable + +from flytekit.loggers import logger + +AsyncFuncType = Callable[[Any], CoroutineType] +Synced = Callable[[Any], Any] + + +def ensure_no_loop(error_msg: str): + try: + asyncio.get_running_loop() + raise AssertionError(error_msg) + except RuntimeError as e: + if "no running event loop" not in str(e): + logger.error(f"Unknown RuntimeError {str(e)}") + raise + + +def ensure_and_get_running_loop() -> asyncio.AbstractEventLoop: + try: + return asyncio.get_running_loop() + except RuntimeError as e: + if "no running event loop" not in str(e): + logger.error(f"Unknown RuntimeError {str(e)}") + raise + + +def top_level_sync(func: AsyncFuncType, *args, **kwargs): + """ + Make sure there is no current loop. Then run the func in a new loop + """ + ensure_no_loop(f"Calling {func.__name__} when event loop active not allowed") + coro = func(*args, **kwargs) + return asyncio.run(coro) + + +def top_level_sync_wrapper(func: AsyncFuncType) -> Synced: + """Given a function, make so can be called in async or blocking contexts + + Leave obj=None if defining within a class. Pass the instance if attaching + as an attribute of the instance. + """ + + @functools.wraps(func) + def wrapper(*args, **kwargs): + return top_level_sync(func, *args, **kwargs) + + return wrapper + + +class ContextExecutor(ThreadPoolExecutor): + def __init__(self): + self.context = contextvars.copy_context() + super().__init__(initializer=self._set_child_context) + + def _set_child_context(self): + for var, value in self.context.items(): + var.set(value) diff --git a/tests/flytekit/unit/core/test_flyte_file.py b/tests/flytekit/unit/core/test_flyte_file.py index d17464c1e9..255c326aeb 100644 --- a/tests/flytekit/unit/core/test_flyte_file.py +++ b/tests/flytekit/unit/core/test_flyte_file.py @@ -105,20 +105,25 @@ def my_wf(path: FlyteFile[typing.TypeVar("txt")]) -> FlyteFile: with open(res, "r") as fh: assert fh.read() == "Hello World" + def test_flytefile_in_dataclass(local_dummy_txt_file): TxtFile = FlyteFile[typing.TypeVar("txt")] + @dataclass class DC: f: TxtFile + @task def t1(path: TxtFile) -> DC: return DC(f=path) + @workflow def my_wf(path: TxtFile) -> DC: dc = t1(path=path) return dc txt_file = TxtFile(local_dummy_txt_file) + my_wf.compile() dc1 = my_wf(path=txt_file) with open(dc1.f, "r") as fh: assert fh.read() == "Hello World" @@ -126,6 +131,7 @@ def my_wf(path: TxtFile) -> DC: dc2 = DC(f=txt_file) assert dc1 == dc2 + @pytest.mark.skipif(not can_import("magic"), reason="Libmagic is not installed") def test_mismatching_file_types(local_dummy_txt_file): @task diff --git a/tests/flytekit/unit/core/test_promise.py b/tests/flytekit/unit/core/test_promise.py index bd24d47bb8..99f6920155 100644 --- a/tests/flytekit/unit/core/test_promise.py +++ b/tests/flytekit/unit/core/test_promise.py @@ -215,8 +215,9 @@ def my_workflow() -> (str, str, str): assert o3 == "b" +@pytest.mark.asyncio @pytest.mark.skipif("pandas" not in sys.modules, reason="Pandas is not installed.") -def test_resolve_attr_path_in_promise(): +async def test_resolve_attr_path_in_promise(): @dataclass_json @dataclass class Foo: @@ -224,7 +225,7 @@ class Foo: src = {"a": [Foo(b="foo")]} - src_lit = TypeEngine.to_literal( + src_lit = await TypeEngine.to_literal( FlyteContextManager.current_context(), src, Dict[str, List[Foo]], @@ -233,15 +234,16 @@ class Foo: src_promise = Promise("val1", src_lit) # happy path - tgt_promise = resolve_attr_path_in_promise(src_promise["a"][0]["b"]) - assert "foo" == TypeEngine.to_python_value(FlyteContextManager.current_context(), tgt_promise.val, str) + tgt_promise = await resolve_attr_path_in_promise(src_promise["a"][0]["b"]) + assert "foo" == await TypeEngine.to_python_value(FlyteContextManager.current_context(), tgt_promise.val, str) # exception with pytest.raises(FlytePromiseAttributeResolveException): - tgt_promise = resolve_attr_path_in_promise(src_promise["c"]) + await resolve_attr_path_in_promise(src_promise["c"]) -def test_prom_with_union_literals(): +@pytest.mark.asyncio +async def test_prom_with_union_literals(): ctx = FlyteContextManager.current_context() pt = typing.Union[str, int] lt = TypeEngine.to_literal_type(pt) @@ -250,7 +252,7 @@ def test_prom_with_union_literals(): LiteralType(simple=SimpleType.INTEGER, structure=TypeStructure(tag="int")), ] - bd = binding_data_from_python_std(ctx, lt, 3, pt, []) + bd = await binding_data_from_python_std(ctx, lt, 3, pt, []) assert bd.scalar.union.stored_type.structure.tag == "int" - bd = binding_data_from_python_std(ctx, lt, "hello", pt, []) + bd = await binding_data_from_python_std(ctx, lt, "hello", pt, []) assert bd.scalar.union.stored_type.structure.tag == "str" From ce839e853dac3685f314fb433e94d3fba32ae8fc Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Tue, 17 Sep 2024 00:44:24 -0700 Subject: [PATCH 07/26] fix some lint Signed-off-by: Yee Hing Tong --- flytekit/core/type_engine.py | 25 +++++++++++++++++++-- tests/flytekit/unit/core/test_flyte_file.py | 2 +- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index cc60f55aa0..e7e8887da3 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1347,6 +1347,17 @@ def named_tuple_to_variable_map(cls, t: typing.NamedTuple) -> _interface_models. variables[var_name] = _interface_models.Variable(type=literal_type, description=f"{idx}") return _interface_models.VariableMap(variables=variables) + # Declare empty function to get linting to work. Monkeypatched below. + @classmethod + def literal_map_to_kwargs( + cls, + ctx: FlyteContext, + lm: LiteralMap, + python_types: typing.Optional[typing.Dict[str, type]] = None, + literal_types: typing.Optional[typing.Dict[str, _interface_models.Variable]] = None, + ) -> typing.Dict[str, typing.Any]: + raise NotImplementedError + @classmethod @timeit("Translate literal to python value") async def _literal_map_to_kwargs( @@ -1383,6 +1394,16 @@ async def _literal_map_to_kwargs( raise return kwargs + # Declare empty function to get linting to work. Monkeypatched below. + @classmethod + def dict_to_literal_map( + cls, + ctx: FlyteContext, + d: typing.Dict[str, typing.Any], + type_hints: Optional[typing.Dict[str, type]] = None, + ) -> LiteralMap: + raise NotImplementedError + @classmethod async def _dict_to_literal_map( cls, @@ -1461,8 +1482,8 @@ def guess_python_type(cls, flyte_type: LiteralType) -> type: raise ValueError(f"No transformers could reverse Flyte literal type {flyte_type}") -TypeEngine.literal_map_to_kwargs = top_level_sync_wrapper(TypeEngine._literal_map_to_kwargs) -TypeEngine.dict_to_literal_map = top_level_sync_wrapper(TypeEngine._dict_to_literal_map) +TypeEngine.literal_map_to_kwargs = top_level_sync_wrapper(TypeEngine._literal_map_to_kwargs) # type: ignore[method-assign] +TypeEngine.dict_to_literal_map = top_level_sync_wrapper(TypeEngine._dict_to_literal_map) # type: ignore[method-assign] class ListTransformer(AsyncTypeTransformer[T]): diff --git a/tests/flytekit/unit/core/test_flyte_file.py b/tests/flytekit/unit/core/test_flyte_file.py index 255c326aeb..319e9b0890 100644 --- a/tests/flytekit/unit/core/test_flyte_file.py +++ b/tests/flytekit/unit/core/test_flyte_file.py @@ -116,7 +116,7 @@ class DC: @task def t1(path: TxtFile) -> DC: return DC(f=path) - + @workflow def my_wf(path: TxtFile) -> DC: dc = t1(path=path) From aa00c97d838b20d9b60b96b9f7639738663fd280 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Tue, 17 Sep 2024 01:13:36 -0700 Subject: [PATCH 08/26] signal is used by asyncio Signed-off-by: Yee Hing Tong --- flytekit/core/type_engine.py | 1 - .../unit/interactive/test_flyteinteractive_vscode.py | 9 ++++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index e7e8887da3..958ce9b016 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1793,7 +1793,6 @@ async def async_to_literal( if found_res: return Literal(scalar=Scalar(union=Union(value=res, stored_type=res_type))) - # breakpoint() raise TypeTransformerFailedError(f"Cannot convert from {python_val} to {python_type}") async def async_to_python_value( diff --git a/tests/flytekit/unit/interactive/test_flyteinteractive_vscode.py b/tests/flytekit/unit/interactive/test_flyteinteractive_vscode.py index 20f1c7c10c..4babf6ae8a 100644 --- a/tests/flytekit/unit/interactive/test_flyteinteractive_vscode.py +++ b/tests/flytekit/unit/interactive/test_flyteinteractive_vscode.py @@ -90,7 +90,7 @@ def wf(): mock_process.assert_called_once() mock_exit_handler.assert_called_once() mock_prepare_interactive_python.assert_called_once() - mock_signal.assert_called_once() + assert mock_signal.call_count >=1 mock_prepare_resume_task_python.assert_called_once() mock_prepare_launch_json.assert_called_once() @@ -120,7 +120,7 @@ def wf(): mock_process.assert_not_called() mock_exit_handler.assert_not_called() mock_prepare_interactive_python.assert_not_called() - mock_signal.assert_not_called() + assert mock_signal.call_count >= 1 mock_prepare_resume_task_python.assert_not_called() mock_prepare_launch_json.assert_not_called() @@ -150,7 +150,6 @@ def wf(): mock_process.assert_not_called() mock_exit_handler.assert_not_called() mock_prepare_interactive_python.assert_not_called() - mock_signal.assert_not_called() mock_prepare_resume_task_python.assert_not_called() mock_prepare_launch_json.assert_not_called() @@ -196,7 +195,7 @@ def wf(a: int, b: int): mock_process.assert_called_once() mock_exit_handler.assert_called_once() mock_prepare_interactive_python.assert_called_once() - mock_signal.assert_called_once() + assert mock_signal.call_count >= 1 mock_prepare_resume_task_python.assert_called_once() mock_prepare_launch_json.assert_called_once() @@ -248,7 +247,7 @@ def wf(): mock_process.assert_called_once() mock_exit_handler.assert_called_once() mock_prepare_interactive_python.assert_called_once() - mock_signal.assert_called_once() + assert mock_signal.call_count >= 1 mock_prepare_resume_task_python.assert_called_once() mock_prepare_launch_json.assert_called_once() From 1cbd92e6c5491cfecaef71f1feebf296363499e6 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Tue, 17 Sep 2024 01:16:46 -0700 Subject: [PATCH 09/26] spell? Signed-off-by: Yee Hing Tong --- flytekit/core/type_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 958ce9b016..b1249cb7ea 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1194,7 +1194,7 @@ def to_literal( """ The current dance is because we are allowing users to call from an async function, this synchronous to_literal function, and allowing this to_literal function, to then invoke yet another async functionl, - namely an async tranformer. + namely an async transformer. We can remove the need for this function to return a future if we always just asyncio.run(). That is, if you use this function to call an async transformer, it has to be not within a From 8c1c8d99ac3a47e1ecf96499350b4254bbce6f6b Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Tue, 17 Sep 2024 17:36:45 -0700 Subject: [PATCH 10/26] unit tests Signed-off-by: Yee Hing Tong --- flytekit/sensor/sensor_engine.py | 2 +- tests/flytekit/unit/interactive/test_flyteinteractive_vscode.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/flytekit/sensor/sensor_engine.py b/flytekit/sensor/sensor_engine.py index ac718abe35..36e4a4aaa0 100644 --- a/flytekit/sensor/sensor_engine.py +++ b/flytekit/sensor/sensor_engine.py @@ -25,7 +25,7 @@ async def create(self, task_template: TaskTemplate, inputs: Optional[LiteralMap] python_interface_inputs = { name: TypeEngine.guess_python_type(lt.type) for name, lt in task_template.interface.inputs.items() } - native_inputs = TypeEngine.literal_map_to_kwargs(ctx, inputs, python_interface_inputs) + native_inputs = await TypeEngine._literal_map_to_kwargs(ctx, inputs, python_interface_inputs) sensor_metadata.inputs = native_inputs return sensor_metadata diff --git a/tests/flytekit/unit/interactive/test_flyteinteractive_vscode.py b/tests/flytekit/unit/interactive/test_flyteinteractive_vscode.py index 4babf6ae8a..360ec76a40 100644 --- a/tests/flytekit/unit/interactive/test_flyteinteractive_vscode.py +++ b/tests/flytekit/unit/interactive/test_flyteinteractive_vscode.py @@ -2,6 +2,7 @@ import mock import pytest +import sys from flytekit.interactive import ( DEFAULT_CODE_SERVER_DIR_NAMES, DEFAULT_CODE_SERVER_EXTENSIONS, @@ -169,6 +170,7 @@ def wf(a: int, b: int) -> int: assert res == 15 +@pytest.mark.skipif(sys.version_info < (3, 10), reason="asyncio signal behavior diff") def test_vscode_run_task_first_fail(vscode_patches, mock_remote_execution): ( mock_process, From d7a9fe023e485472f383a3d2837cc6583095e020 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Tue, 17 Sep 2024 17:56:47 -0700 Subject: [PATCH 11/26] skip two more tests Signed-off-by: Yee Hing Tong --- tests/flytekit/unit/interactive/test_flyteinteractive_vscode.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/flytekit/unit/interactive/test_flyteinteractive_vscode.py b/tests/flytekit/unit/interactive/test_flyteinteractive_vscode.py index 360ec76a40..8555027021 100644 --- a/tests/flytekit/unit/interactive/test_flyteinteractive_vscode.py +++ b/tests/flytekit/unit/interactive/test_flyteinteractive_vscode.py @@ -96,6 +96,7 @@ def wf(): mock_prepare_launch_json.assert_called_once() +@pytest.mark.skipif(sys.version_info < (3, 10), reason="asyncio signal behavior diff") def test_vscode_remote_execution_but_disable(vscode_patches, mock_remote_execution): ( mock_process, @@ -223,6 +224,7 @@ def test_vscode_config(): assert config.extension_remote_paths == DEFAULT_CODE_SERVER_EXTENSIONS +@pytest.mark.skipif(sys.version_info < (3, 10), reason="asyncio signal behavior diff") def test_vscode_with_args(vscode_patches, mock_remote_execution): ( mock_process, From f173719ef081f0bfd1283c9545e4464f895f6163 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Tue, 17 Sep 2024 19:17:38 -0700 Subject: [PATCH 12/26] try cbs Signed-off-by: Yee Hing Tong --- flytekit/core/local_cache.py | 3 ++- flytekit/core/type_engine.py | 24 ++++++++++++++++++++++++ flytekit/extend/backend/base_agent.py | 4 ++-- 3 files changed, 28 insertions(+), 3 deletions(-) diff --git a/flytekit/core/local_cache.py b/flytekit/core/local_cache.py index 7cd87e2a49..4fd795808b 100644 --- a/flytekit/core/local_cache.py +++ b/flytekit/core/local_cache.py @@ -70,8 +70,9 @@ def get( ) -> Optional[LiteralMap]: if not LocalTaskCache._initialized: LocalTaskCache.initialize() + kk = _calculate_cache_key(task_name, cache_version, input_literal_map, cache_ignore_input_vars) return LocalTaskCache._cache.get( - _calculate_cache_key(task_name, cache_version, input_literal_map, cache_ignore_input_vars) + kk ) @staticmethod diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index b1249cb7ea..3bd73c23c3 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1236,6 +1236,18 @@ def to_literal( else: executor = ContextExecutor() fut = loop.run_in_executor(executor, transformer.to_literal, ctx, python_val, python_type, expected) # type: ignore[assignment] + + def cb(cb_fut: asyncio.Future): + res = None + try: + res = cb_fut.result() + except Exception as e: + logger.debug(f"Skipping callback for: {cb_fut}") + if res: + modify_literal_uris(res) + res.hash = cls.calculate_hash(python_val, python_type) + + fut.add_done_callback(cb) return fut else: # get_running_loop raised if isinstance(transformer, AsyncTypeTransformer): @@ -1303,6 +1315,18 @@ def to_python_value(cls, ctx: FlyteContext, lv: Literal, expected_python_type: T else: executor = ContextExecutor() fut = loop.run_in_executor(executor, transformer.to_python_value, ctx, lv, expected_python_type) # type: ignore[assignment] + + def cb(cb_fut: asyncio.Future): + res = None + try: + res = cb_fut.result() + except Exception as e: + logger.debug(f"Skipping callback for: {cb_fut}") + if res: + modify_literal_uris(res) + res.hash = cls.calculate_hash(python_val, python_type) + + fut.add_done_callback(cb) return fut else: # get_running_loop raised if isinstance(transformer, AsyncTypeTransformer): diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index 3828b0f67b..8cc73ea721 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -267,7 +267,7 @@ async def _do( ) -> Resource: try: ctx = FlyteContext.current_context() - literal_map = TypeEngine.dict_to_literal_map(ctx, inputs or {}, self.get_input_types()) + literal_map = await TypeEngine._dict_to_literal_map(ctx, inputs or {}, self.get_input_types()) return await mirror_async_methods( agent.do, task_template=template, inputs=literal_map, output_prefix=output_prefix ) @@ -327,7 +327,7 @@ async def _create( with FlyteContextManager.with_context(cb) as ctx: # Write the inputs to a remote file, so that the remote task can read the inputs from this file. - literal_map = TypeEngine.dict_to_literal_map(ctx, inputs or {}, self.get_input_types()) + literal_map = await TypeEngine._dict_to_literal_map(ctx, inputs or {}, self.get_input_types()) path = ctx.file_access.get_random_local_path() utils.write_proto_to_file(literal_map.to_flyte_idl(), path) ctx.file_access.put_data(path, f"{output_prefix}/inputs.pb") From c5a5a38403bc6ea6494f742381fda1fede955636 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Wed, 18 Sep 2024 00:10:54 -0700 Subject: [PATCH 13/26] too much callback Signed-off-by: Yee Hing Tong --- flytekit/core/local_cache.py | 3 +-- flytekit/core/type_engine.py | 13 +------------ 2 files changed, 2 insertions(+), 14 deletions(-) diff --git a/flytekit/core/local_cache.py b/flytekit/core/local_cache.py index 4fd795808b..7cd87e2a49 100644 --- a/flytekit/core/local_cache.py +++ b/flytekit/core/local_cache.py @@ -70,9 +70,8 @@ def get( ) -> Optional[LiteralMap]: if not LocalTaskCache._initialized: LocalTaskCache.initialize() - kk = _calculate_cache_key(task_name, cache_version, input_literal_map, cache_ignore_input_vars) return LocalTaskCache._cache.get( - kk + _calculate_cache_key(task_name, cache_version, input_literal_map, cache_ignore_input_vars) ) @staticmethod diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 3bd73c23c3..dd47e7ef4f 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1241,7 +1241,7 @@ def cb(cb_fut: asyncio.Future): res = None try: res = cb_fut.result() - except Exception as e: + except Exception: logger.debug(f"Skipping callback for: {cb_fut}") if res: modify_literal_uris(res) @@ -1316,17 +1316,6 @@ def to_python_value(cls, ctx: FlyteContext, lv: Literal, expected_python_type: T executor = ContextExecutor() fut = loop.run_in_executor(executor, transformer.to_python_value, ctx, lv, expected_python_type) # type: ignore[assignment] - def cb(cb_fut: asyncio.Future): - res = None - try: - res = cb_fut.result() - except Exception as e: - logger.debug(f"Skipping callback for: {cb_fut}") - if res: - modify_literal_uris(res) - res.hash = cls.calculate_hash(python_val, python_type) - - fut.add_done_callback(cb) return fut else: # get_running_loop raised if isinstance(transformer, AsyncTypeTransformer): From 6a48da26afcd7674f8394af5f480a65f3b3f0730 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Tue, 1 Oct 2024 11:11:56 -0700 Subject: [PATCH 14/26] Flyte loop in FlyteContext & multi-threaded loops (#2759) Signed-off-by: Yee Hing Tong --- flytekit/core/base_task.py | 7 +- flytekit/core/context_manager.py | 33 ++- flytekit/core/promise.py | 4 +- flytekit/core/type_engine.py | 208 +++++++----------- flytekit/utils/async_utils.py | 116 +++++++--- .../tests/test_flyteinteractive_vscode.py | 2 +- tests/flytekit/unit/core/test_data.py | 51 ++++- tests/flytekit/unit/core/test_promise.py | 4 +- tests/flytekit/unit/core/test_type_hints.py | 5 +- .../test_flyteinteractive_vscode.py | 2 +- 10 files changed, 268 insertions(+), 164 deletions(-) diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index 4b1e147f2e..683b55c1b3 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -81,7 +81,7 @@ from flytekit.models.documentation import Description, Documentation from flytekit.models.interface import Variable from flytekit.models.security import SecurityContext -from flytekit.utils.async_utils import ensure_no_loop +from flytekit.utils.async_utils import run_sync_new_thread DYNAMIC_PARTITIONS = "_uap" MODEL_CARD = "_ucm" @@ -781,8 +781,9 @@ def dispatch_execute( ): return native_outputs - ensure_no_loop("Cannot run PythonTask.dispatch_execute from within a loop") - literals_map, native_outputs_as_map = asyncio.run(self._output_to_literal_map(native_outputs, exec_ctx)) + with timeit("dispatch execute"): + synced = run_sync_new_thread(self._output_to_literal_map) + literals_map, native_outputs_as_map = synced(native_outputs, exec_ctx) self._write_decks(native_inputs, native_outputs_as_map, ctx, new_user_params) # After the execute has been successfully completed diff --git a/flytekit/core/context_manager.py b/flytekit/core/context_manager.py index 13691162d5..e56fc3f1e9 100644 --- a/flytekit/core/context_manager.py +++ b/flytekit/core/context_manager.py @@ -13,6 +13,7 @@ from __future__ import annotations +import asyncio import logging as _logging import os import pathlib @@ -35,6 +36,7 @@ from flytekit.interfaces.stats import taggable from flytekit.loggers import developer_logger, user_space_logger from flytekit.models.core import identifier as _identifier +from flytekit.utils.async_utils import get_or_create_loop if typing.TYPE_CHECKING: from flytekit import Deck @@ -642,6 +644,15 @@ class FlyteContext(object): in_a_condition: bool = False origin_stackframe: Optional[traceback.FrameSummary] = None output_metadata_tracker: Optional[OutputMetadataTracker] = None + _loop: Optional[asyncio.AbstractEventLoop] = None + + @property + def loop(self) -> asyncio.AbstractEventLoop: + """ + Can remove this property in the future + """ + assert self._loop is not None + return self._loop @property def user_space_params(self) -> Optional[ExecutionParameters]: @@ -668,6 +679,7 @@ def new_builder(self) -> Builder: execution_state=self.execution_state, in_a_condition=self.in_a_condition, output_metadata_tracker=self.output_metadata_tracker, + loop=self._loop, ) def enter_conditional_section(self) -> Builder: @@ -692,6 +704,9 @@ def with_serialization_settings(self, ss: SerializationSettings) -> Builder: def with_output_metadata_tracker(self, t: OutputMetadataTracker) -> Builder: return self.new_builder().with_output_metadata_tracker(t) + def with_ensure_loop(self) -> Builder: + return self.new_builder().with_ensure_loop() + def new_compilation_state(self, prefix: str = "") -> CompilationState: """ Creates and returns a default compilation state. For most of the code this should be the entrypoint @@ -753,6 +768,7 @@ class Builder(object): serialization_settings: Optional[SerializationSettings] = None in_a_condition: bool = False output_metadata_tracker: Optional[OutputMetadataTracker] = None + loop: Optional[asyncio.AbstractEventLoop] = None def build(self) -> FlyteContext: return FlyteContext( @@ -764,6 +780,7 @@ def build(self) -> FlyteContext: serialization_settings=self.serialization_settings, in_a_condition=self.in_a_condition, output_metadata_tracker=self.output_metadata_tracker, + _loop=self.loop, ) def enter_conditional_section(self) -> FlyteContext.Builder: @@ -812,6 +829,12 @@ def with_output_metadata_tracker(self, t: OutputMetadataTracker) -> FlyteContext self.output_metadata_tracker = t return self + def with_ensure_loop(self, use_windows: bool = False) -> FlyteContext.Builder: + if not self.loop: + # Currently this will use a running system loop. + self.loop = get_or_create_loop(use_windows=use_windows) + return self + def new_compilation_state(self, prefix: str = "") -> CompilationState: """ Creates and returns a default compilation state. For most of the code this should be the entrypoint @@ -947,9 +970,13 @@ def initialize(): decks=[], ) - default_context = default_context.with_execution_state( - default_context.new_execution_state().with_params(user_space_params=default_user_space_params) - ).build() + default_context = ( + default_context.with_execution_state( + default_context.new_execution_state().with_params(user_space_params=default_user_space_params) + ) + .with_ensure_loop() + .build() + ) default_context.set_stackframe(s=FlyteContextManager.get_origin_stackframe()) flyte_context_Var.set([default_context]) diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index a9b6d59947..40556a599e 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -45,7 +45,7 @@ from flytekit.models.literals import Primitive from flytekit.models.task import Resources from flytekit.models.types import SimpleType -from flytekit.utils.async_utils import top_level_sync_wrapper +from flytekit.utils.async_utils import run_sync_new_thread async def _translate_inputs_to_literals( @@ -105,7 +105,7 @@ def my_wf(in1: int, in2: int) -> int: return result -translate_inputs_to_literals = top_level_sync_wrapper(_translate_inputs_to_literals) +translate_inputs_to_literals = run_sync_new_thread(_translate_inputs_to_literals) async def resolve_attr_path_in_promise(p: Promise) -> Promise: diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index dd47e7ef4f..e19d4c2166 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -53,7 +53,7 @@ Void, ) from flytekit.models.types import LiteralType, SimpleType, TypeStructure, UnionType -from flytekit.utils.async_utils import ContextExecutor, top_level_sync_wrapper +from flytekit.utils.async_utils import ContextExecutor, get_running_loop_if_exists, run_sync_new_thread T = typing.TypeVar("T") DEFINITIONS = "definitions" @@ -249,41 +249,23 @@ async def async_to_python_value(self, ctx: FlyteContext, lv: Literal, expected_p def to_literal( self, ctx: FlyteContext, python_val: typing.Any, python_type: Type[T], expected: LiteralType - ) -> typing.Union[Literal, asyncio.Future]: - loop = None - try: - loop = asyncio.get_running_loop() - except RuntimeError as e: - if "no running event loop" not in str(e): - logger.error(f"Unknown RuntimeError {str(e)}") - raise - - if loop: - coro = self.async_to_literal(ctx, python_val, python_type, expected) - fut = loop.create_task(coro) - return fut + ) -> Literal: + if ctx.loop.is_running(): + synced = run_sync_new_thread(self.async_to_literal) + result = synced(ctx, python_val, python_type, expected) + return result else: coro = self.async_to_literal(ctx, python_val, python_type, expected) - return asyncio.run(coro) + return ctx.loop.run_until_complete(coro) - def to_python_value( # type: ignore[override] - self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T] - ) -> typing.Union[Optional[T], asyncio.Future]: - loop = None - try: - loop = asyncio.get_running_loop() - except RuntimeError as e: - if "no running event loop" not in str(e): - logger.error(f"Unknown RuntimeError {str(e)}") - raise - - if loop: - coro = self.async_to_python_value(ctx, lv, expected_python_type) - fut = loop.create_task(coro) - return fut + def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> Optional[T]: + if ctx.loop.is_running(): + synced = run_sync_new_thread(self.async_to_python_value) + result = synced(ctx, lv, expected_python_type) + return result else: coro = self.async_to_python_value(ctx, lv, expected_python_type) - return asyncio.run(coro) + return ctx.loop.run_until_complete(coro) class SimpleTransformer(TypeTransformer[T]): @@ -1190,15 +1172,11 @@ def calculate_hash(cls, python_val: typing.Any, python_type: Type[T]) -> Optiona @classmethod def to_literal( cls, ctx: FlyteContext, python_val: typing.Any, python_type: Type[T], expected: LiteralType - ) -> typing.Union[Literal, asyncio.Future]: + ) -> Literal: """ The current dance is because we are allowing users to call from an async function, this synchronous to_literal function, and allowing this to_literal function, to then invoke yet another async functionl, namely an async transformer. - - We can remove the need for this function to return a future if we always just asyncio.run(). - That is, if you use this function to call an async transformer, it has to be not within a - running loop. """ from flytekit.core.promise import Promise @@ -1220,45 +1198,26 @@ def to_literal( if transformer.type_assertions_enabled: transformer.assert_type(python_type, python_val) - loop = None - try: - loop = asyncio.get_running_loop() - except RuntimeError as e: - # handle outside of try/catch to avoid nested exceptions - if "no running event loop" not in str(e): - logger.error(f"Unknown RuntimeError {str(e)}") - raise + running_loop = get_running_loop_if_exists() - if loop: # get_running_loop didn't raise, return a Future + # can't have a main loop running either, maybe this is one of the downsides + # of not calling set_event_loop. + if not ctx.loop.is_running() and not running_loop: if isinstance(transformer, AsyncTypeTransformer): coro = transformer.async_to_literal(ctx, python_val, python_type, expected) - fut = loop.create_task(coro) + lv = ctx.loop.run_until_complete(coro) else: - executor = ContextExecutor() - fut = loop.run_in_executor(executor, transformer.to_literal, ctx, python_val, python_type, expected) # type: ignore[assignment] - - def cb(cb_fut: asyncio.Future): - res = None - try: - res = cb_fut.result() - except Exception: - logger.debug(f"Skipping callback for: {cb_fut}") - if res: - modify_literal_uris(res) - res.hash = cls.calculate_hash(python_val, python_type) - - fut.add_done_callback(cb) - return fut - else: # get_running_loop raised + lv = transformer.to_literal(ctx, python_val, python_type, expected) + else: if isinstance(transformer, AsyncTypeTransformer): - coro = transformer.async_to_literal(ctx, python_val, python_type, expected) - lv = asyncio.run(coro) + synced = run_sync_new_thread(transformer.async_to_literal) + lv = synced(ctx, python_val, python_type, expected) else: lv = transformer.to_literal(ctx, python_val, python_type, expected) - modify_literal_uris(lv) - lv.hash = cls.calculate_hash(python_val, python_type) - return lv + modify_literal_uris(lv) + lv.hash = cls.calculate_hash(python_val, python_type) + return lv @classmethod async def async_to_literal( @@ -1282,6 +1241,8 @@ async def async_to_literal( if isinstance(transformer, AsyncTypeTransformer): lv = await transformer.async_to_literal(ctx, python_val, python_type, expected) else: + # Testing just blocking call + # lv = transformer.to_literal(ctx, python_val, python_type, expected) loop = asyncio.get_running_loop() executor = ContextExecutor() fut = loop.run_in_executor(executor, transformer.to_literal, ctx, python_val, python_type, expected) @@ -1299,31 +1260,23 @@ def to_python_value(cls, ctx: FlyteContext, lv: Literal, expected_python_type: T """ transformer = cls.get_transformer(expected_python_type) - loop = None - try: - loop = asyncio.get_running_loop() - except RuntimeError as e: - # handle outside of try/catch to avoid nested exceptions - if "no running event loop" not in str(e): - logger.error(f"Unknown RuntimeError {str(e)}") - raise - - if loop: # get_running_loop didn't raise, return a Future - if isinstance(transformer, AsyncTypeTransformer): - coro = transformer.async_to_python_value(ctx, lv, expected_python_type) - fut = loop.create_task(coro) - else: - executor = ContextExecutor() - fut = loop.run_in_executor(executor, transformer.to_python_value, ctx, lv, expected_python_type) # type: ignore[assignment] + # see note in to_literal. + running_loop = get_running_loop_if_exists() - return fut - else: # get_running_loop raised + if not ctx.loop.is_running() and not running_loop: if isinstance(transformer, AsyncTypeTransformer): coro = transformer.async_to_python_value(ctx, lv, expected_python_type) - pv = asyncio.run(coro) + pv = ctx.loop.run_until_complete(coro) else: pv = transformer.to_python_value(ctx, lv, expected_python_type) return pv + else: + if isinstance(transformer, AsyncTypeTransformer): + synced = run_sync_new_thread(transformer.async_to_python_value) + return synced(ctx, lv, expected_python_type) + else: + res = transformer.to_python_value(ctx, lv, expected_python_type) + return res @classmethod async def async_to_python_value(cls, ctx: FlyteContext, lv: Literal, expected_python_type: Type) -> typing.Any: @@ -1360,8 +1313,8 @@ def named_tuple_to_variable_map(cls, t: typing.NamedTuple) -> _interface_models. variables[var_name] = _interface_models.Variable(type=literal_type, description=f"{idx}") return _interface_models.VariableMap(variables=variables) - # Declare empty function to get linting to work. Monkeypatched below. @classmethod + @timeit("Translate literal to python value") def literal_map_to_kwargs( cls, ctx: FlyteContext, @@ -1369,10 +1322,11 @@ def literal_map_to_kwargs( python_types: typing.Optional[typing.Dict[str, type]] = None, literal_types: typing.Optional[typing.Dict[str, _interface_models.Variable]] = None, ) -> typing.Dict[str, typing.Any]: - raise NotImplementedError + synced = run_sync_new_thread(cls._literal_map_to_kwargs) + return synced(ctx, lm, python_types, literal_types) @classmethod - @timeit("Translate literal to python value") + @timeit("AsyncTranslate literal to python value") async def _literal_map_to_kwargs( cls, ctx: FlyteContext, @@ -1401,13 +1355,17 @@ async def _literal_map_to_kwargs( kwargs = {} for i, k in enumerate(lm.literals): try: - kwargs[k] = await TypeEngine.async_to_python_value(ctx, lm.literals[k], python_interface_inputs[k]) + kwargs[k] = asyncio.create_task( + TypeEngine.async_to_python_value(ctx, lm.literals[k], python_interface_inputs[k]) + ) + await asyncio.gather(*kwargs.values()) except TypeTransformerFailedError as exc: exc.args = (f"Error converting input '{k}' at position {i}:\n {exc.args[0]}",) raise + + kwargs = {k: v.result() for k, v in kwargs.items() if v is not None} return kwargs - # Declare empty function to get linting to work. Monkeypatched below. @classmethod def dict_to_literal_map( cls, @@ -1415,7 +1373,8 @@ def dict_to_literal_map( d: typing.Dict[str, typing.Any], type_hints: Optional[typing.Dict[str, type]] = None, ) -> LiteralMap: - raise NotImplementedError + synced = run_sync_new_thread(cls._dict_to_literal_map) + return synced(ctx, d, type_hints) @classmethod async def _dict_to_literal_map( @@ -1436,14 +1395,19 @@ async def _dict_to_literal_map( # `list` and `dict`. python_type = type_hints.get(k, type(v)) try: - literal_map[k] = await TypeEngine.async_to_literal( - ctx=ctx, - python_val=v, - python_type=python_type, - expected=TypeEngine.to_literal_type(python_type), + literal_map[k] = asyncio.create_task( + TypeEngine.async_to_literal( + ctx=ctx, + python_val=v, + python_type=python_type, + expected=TypeEngine.to_literal_type(python_type), + ) ) + await asyncio.gather(*literal_map.values()) except TypeError: raise user_exceptions.FlyteTypeException(type(v), python_type, received_value=v) + + literal_map = {k: v.result() for k, v in literal_map.items()} return LiteralMap(literal_map) @classmethod @@ -1495,10 +1459,6 @@ def guess_python_type(cls, flyte_type: LiteralType) -> type: raise ValueError(f"No transformers could reverse Flyte literal type {flyte_type}") -TypeEngine.literal_map_to_kwargs = top_level_sync_wrapper(TypeEngine._literal_map_to_kwargs) # type: ignore[method-assign] -TypeEngine.dict_to_literal_map = top_level_sync_wrapper(TypeEngine._dict_to_literal_map) # type: ignore[method-assign] - - class ListTransformer(AsyncTypeTransformer[T]): """ Transformer that handles a univariate typing.List[T] @@ -1584,11 +1544,9 @@ async def async_to_literal( lit_list = [] else: t = self.get_sub_type(python_type) - lit_list = [TypeEngine.to_literal(ctx, x, t, expected.collection_type) for x in python_val] # type: ignore - for idx, obj in enumerate(lit_list): - if isinstance(obj, asyncio.Future): - await obj - lit_list[idx] = obj.result() + lit_list = [TypeEngine.async_to_literal(ctx, x, t, expected.collection_type) for x in python_val] + lit_list = await asyncio.gather(*lit_list) + return Literal(collection=LiteralCollection(literals=lit_list)) async def async_to_python_value( # type: ignore @@ -1608,18 +1566,15 @@ async def async_to_python_value( # type: ignore batch_list = [TypeEngine.to_python_value(ctx, batch, FlytePickle) for batch in lits] if len(batch_list) > 0 and type(batch_list[0]) is list: - # Make it have backward compatibility. The upstream task may use old version of Flytekit that - # won't merge the elements in the list. Therefore, we should check if the batch_list[0] is the list first. + # Make it have backward compatibility. The upstream task may use old version of Flytekit that won't + # merge the elements in the list. Therefore, we should check if the batch_list[0] is the list first. return [item for batch in batch_list for item in batch] return batch_list else: st = self.get_sub_type(expected_python_type) - result = [TypeEngine.to_python_value(ctx, x, st) for x in lits] - for idx, r in enumerate(result): - if isinstance(r, asyncio.Future): - await r - result[idx] = r.result() - return result + result = [TypeEngine.async_to_python_value(ctx, x, st) for x in lits] + result = await asyncio.gather(*result) + return result # type: ignore # should be a list, thinks its a tuple def guess_python_type(self, literal_type: LiteralType) -> list: # type: ignore if literal_type.collection_type: @@ -1825,6 +1780,7 @@ async def async_to_python_value( cur_transformer = "" res = None res_tag = None + # This is serial not really async but should be okay. for v in get_args(expected_python_type): try: trans: TypeTransformer[T] = TypeEngine.get_transformer(v) @@ -1994,10 +1950,14 @@ async def async_to_literal( else: _, v_type = self.extract_types_or_metadata(python_type) - lit_map[k] = TypeEngine.to_literal(ctx, v, cast(type, v_type), expected.map_value_type) - for result_k, result_v in lit_map.items(): - if isinstance(result_v, asyncio.Future): - lit_map[result_k] = await result_v + lit_map[k] = asyncio.create_task( + TypeEngine.async_to_literal(ctx, v, cast(type, v_type), expected.map_value_type) + ) + + await asyncio.gather(*lit_map.values()) + for k, v in lit_map.items(): + lit_map[k] = v.result() + return Literal(map=LiteralMap(literals=lit_map)) async def async_to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[dict]) -> dict: @@ -2014,11 +1974,13 @@ async def async_to_python_value(self, ctx: FlyteContext, lv: Literal, expected_p raise TypeError("TypeMismatch. Destination dictionary does not accept 'str' key") py_map = {} for k, v in lv.map.literals.items(): - item = TypeEngine.to_python_value(ctx, v, cast(Type, tp[1])) - if isinstance(item, asyncio.Future): - py_map[k] = await item - else: - py_map[k] = item + fut = asyncio.create_task(TypeEngine.async_to_python_value(ctx, v, cast(Type, tp[1]))) + py_map[k] = fut + + await asyncio.gather(*py_map.values()) + for k, v in py_map.items(): + py_map[k] = v.result() + return py_map # for empty generic we have to explicitly test for lv.scalar.generic is not None as empty dict diff --git a/flytekit/utils/async_utils.py b/flytekit/utils/async_utils.py index 6e9961476f..1c3977e723 100644 --- a/flytekit/utils/async_utils.py +++ b/flytekit/utils/async_utils.py @@ -1,62 +1,128 @@ import asyncio -import contextvars -import functools +import atexit +import sys +import threading from concurrent.futures import ThreadPoolExecutor +from contextvars import copy_context from types import CoroutineType - -from typing_extensions import Any, Callable +from typing import Any, Awaitable, Callable, Optional, TypeVar from flytekit.loggers import logger AsyncFuncType = Callable[[Any], CoroutineType] Synced = Callable[[Any], Any] +T = TypeVar("T") -def ensure_no_loop(error_msg: str): +def get_running_loop_if_exists() -> Optional[asyncio.AbstractEventLoop]: try: - asyncio.get_running_loop() - raise AssertionError(error_msg) + loop = asyncio.get_running_loop() + return loop except RuntimeError as e: if "no running event loop" not in str(e): logger.error(f"Unknown RuntimeError {str(e)}") raise + return None -def ensure_and_get_running_loop() -> asyncio.AbstractEventLoop: +def get_or_create_loop(use_windows: bool = False) -> asyncio.AbstractEventLoop: + # todo: what happens if we remove this? never rely on the running loop + # something to test. import flytekit, inside an async function, what happens? + # import flytekit, inside a jupyter notebook (which sets its own loop) try: - return asyncio.get_running_loop() + running_loop = asyncio.get_running_loop() + return running_loop except RuntimeError as e: if "no running event loop" not in str(e): - logger.error(f"Unknown RuntimeError {str(e)}") + logger.error(f"Unknown RuntimeError when getting loop {str(e)}") raise + if sys.platform == "win32" and use_windows: + loop = asyncio.WindowsSelectorEventLoopPolicy().new_event_loop() + else: + loop = asyncio.new_event_loop() + # Intentionally not calling asyncio.set_event_loop(loop) + # Unclear what the downside of this is. But should be better in the Jupyter case where it seems to + # co-opt set_event_loop somehow + + # maybe add signal handlers in the future + + return loop -def top_level_sync(func: AsyncFuncType, *args, **kwargs): + +class _CoroRunner: """ - Make sure there is no current loop. Then run the func in a new loop + Runs a coroutine and a loop for it on a background thread, in a blocking manner """ - ensure_no_loop(f"Calling {func.__name__} when event loop active not allowed") - coro = func(*args, **kwargs) - return asyncio.run(coro) - -def top_level_sync_wrapper(func: AsyncFuncType) -> Synced: - """Given a function, make so can be called in async or blocking contexts + def __init__(self) -> None: + self.__io_loop: asyncio.AbstractEventLoop | None = None + self.__runner_thread: threading.Thread | None = None + self.__lock = threading.Lock() + atexit.register(self._close) + + def _close(self) -> None: + if self.__io_loop: + self.__io_loop.stop() + + def _runner(self) -> None: + loop = self.__io_loop + assert loop is not None + try: + loop.run_forever() + finally: + loop.close() + + def run(self, coro: Any) -> Any: + """ + This is a blocking function. + Synchronously runs the coroutine on a background thread. + """ + name = f"{threading.current_thread().name} - runner" + with self.__lock: + # remove before merging + if f"{threading.current_thread().name} - runner" != name: + raise AssertionError + if self.__io_loop is None: + self.__io_loop = asyncio.new_event_loop() + self.__runner_thread = threading.Thread(target=self._runner, daemon=True, name=name) + self.__runner_thread.start() + logger.debug(f"Runner thread started {name}") + fut = asyncio.run_coroutine_threadsafe(coro, self.__io_loop) + res = fut.result(None) + + return res + + +_runner_map: dict[str, _CoroRunner] = {} + + +def run_sync_new_thread(coro_function: Callable[..., Awaitable[T]]) -> Callable[..., T]: + """ + Decorator to run a coroutine function with a loop that runs in a different thread. + Always run in a new thread, even if no current thread is running. - Leave obj=None if defining within a class. Pass the instance if attaching - as an attribute of the instance. + :param coro_function: A coroutine function """ - @functools.wraps(func) - def wrapper(*args, **kwargs): - return top_level_sync(func, *args, **kwargs) + # if not inspect.iscoroutinefunction(coro_function): + # raise AssertionError + + def wrapped(*args: Any, **kwargs: Any) -> Any: + name = threading.current_thread().name + logger.debug(f"Invoking coro_f synchronously in thread: {threading.current_thread().name}") + inner = coro_function(*args, **kwargs) + if name not in _runner_map: + _runner_map[name] = _CoroRunner() + return _runner_map[name].run(inner) - return wrapper + wrapped.__doc__ = coro_function.__doc__ + return wrapped class ContextExecutor(ThreadPoolExecutor): def __init__(self): - self.context = contextvars.copy_context() + self.context = copy_context() super().__init__(initializer=self._set_child_context) def _set_child_context(self): diff --git a/plugins/flytekit-flyteinteractive/tests/test_flyteinteractive_vscode.py b/plugins/flytekit-flyteinteractive/tests/test_flyteinteractive_vscode.py index 54aba7a63a..de1a386a8b 100644 --- a/plugins/flytekit-flyteinteractive/tests/test_flyteinteractive_vscode.py +++ b/plugins/flytekit-flyteinteractive/tests/test_flyteinteractive_vscode.py @@ -202,7 +202,7 @@ def wf(a: int, b: int): mock_process.assert_called_once() mock_exit_handler.assert_called_once() mock_prepare_interactive_python.assert_called_once() - mock_signal.assert_called_once() + assert mock_signal.call_count >= 1 mock_prepare_resume_task_python.assert_called_once() mock_prepare_launch_json.assert_called_once() diff --git a/tests/flytekit/unit/core/test_data.py b/tests/flytekit/unit/core/test_data.py index aa308d7929..93ea739355 100644 --- a/tests/flytekit/unit/core/test_data.py +++ b/tests/flytekit/unit/core/test_data.py @@ -3,18 +3,20 @@ import shutil import tempfile from uuid import UUID - +import typing import fsspec import mock import pytest from s3fs import S3FileSystem from flytekit.configuration import Config, DataConfig, S3Config -from flytekit.core.context_manager import FlyteContextManager +from flytekit.core.context_manager import FlyteContextManager, FlyteContext from flytekit.core.data_persistence import FileAccessProvider, get_fsspec_storage_options, s3_setup_args from flytekit.core.type_engine import TypeEngine from flytekit.types.directory.types import FlyteDirectory from flytekit.types.file import FlyteFile +from flytekit.utils.async_utils import run_sync_new_thread +from flytekit.models.literals import Literal local = fsspec.filesystem("file") root = os.path.abspath(os.sep) @@ -446,3 +448,48 @@ def test_s3_metadata(): res = [(x, y) for x, y in res] files = [os.path.join(x, y) for x, y in res] assert len(files) == 2 + + +async def dummy_output_to_literal_map(ctx: FlyteContext, ff: typing.List[FlyteFile]) -> Literal: + lt = TypeEngine.to_literal_type(typing.List[FlyteFile]) + lit = await TypeEngine.async_to_literal(ctx, ff, typing.List[FlyteFile], lt) + return lit + + +@pytest.mark.sandbox_test +def test_async_local_copy_to_s3(): + import time + import datetime + + f1 = "/Users/ytong/go/src/github.com/unionai/debugyt/user/ytong/src/yt_dbg/fr/rand.file" + f2 = "/Users/ytong/go/src/github.com/unionai/debugyt/user/ytong/src/yt_dbg/fr/rand2.file" + f3 = "/Users/ytong/go/src/github.com/unionai/debugyt/user/ytong/src/yt_dbg/fr/rand3.file" + + ff1 = FlyteFile(path=f1) + ff2 = FlyteFile(path=f2) + ff3 = FlyteFile(path=f3) + ff = [ff1, ff2, ff3] + + ctx = FlyteContextManager.current_context() + dc = Config.for_sandbox().data_config + random_folder = UUID(int=random.getrandbits(64)).hex + raw_output = f"s3://my-s3-bucket/testing/upload_test/{random_folder}" + print(f"Uploading to {raw_output}") + provider = FileAccessProvider(local_sandbox_dir="/tmp/unittest", raw_output_prefix=raw_output, data_config=dc) + + start_time = datetime.datetime.now(datetime.timezone.utc) + start_wall_time = time.perf_counter() + start_process_time = time.process_time() + + with FlyteContextManager.with_context(ctx.with_file_access(provider)) as ctx: + synced = run_sync_new_thread(dummy_output_to_literal_map) + lit = synced(ctx, ff) + print(lit) + + end_time = datetime.datetime.now(datetime.timezone.utc) + end_wall_time = time.perf_counter() + end_process_time = time.process_time() + + print(f"Time taken: {end_time - start_time}") + print(f"Wall time taken: {end_wall_time - start_wall_time}") + print(f"Process time taken: {end_process_time - start_process_time}") diff --git a/tests/flytekit/unit/core/test_promise.py b/tests/flytekit/unit/core/test_promise.py index 99f6920155..b5676f0c23 100644 --- a/tests/flytekit/unit/core/test_promise.py +++ b/tests/flytekit/unit/core/test_promise.py @@ -225,7 +225,7 @@ class Foo: src = {"a": [Foo(b="foo")]} - src_lit = await TypeEngine.to_literal( + src_lit = TypeEngine.to_literal( FlyteContextManager.current_context(), src, Dict[str, List[Foo]], @@ -235,7 +235,7 @@ class Foo: # happy path tgt_promise = await resolve_attr_path_in_promise(src_promise["a"][0]["b"]) - assert "foo" == await TypeEngine.to_python_value(FlyteContextManager.current_context(), tgt_promise.val, str) + assert "foo" == TypeEngine.to_python_value(FlyteContextManager.current_context(), tgt_promise.val, str) # exception with pytest.raises(FlytePromiseAttributeResolveException): diff --git a/tests/flytekit/unit/core/test_type_hints.py b/tests/flytekit/unit/core/test_type_hints.py index 0e7b88bd08..aa13a3b753 100644 --- a/tests/flytekit/unit/core/test_type_hints.py +++ b/tests/flytekit/unit/core/test_type_hints.py @@ -736,8 +736,9 @@ def print_expr(expr): def test_comparison_lits(): - px = Promise("x", TypeEngine.to_literal(None, 5, int, None)) - py = Promise("y", TypeEngine.to_literal(None, 8, int, None)) + ctx = context_manager.FlyteContextManager.current_context() + px = Promise("x", TypeEngine.to_literal(ctx, 5, int, None)) + py = Promise("y", TypeEngine.to_literal(ctx, 8, int, None)) def eval_expr(expr, expected: bool): print(f"{expr} evals to {expr.eval()}") diff --git a/tests/flytekit/unit/interactive/test_flyteinteractive_vscode.py b/tests/flytekit/unit/interactive/test_flyteinteractive_vscode.py index 8555027021..4102ee84b5 100644 --- a/tests/flytekit/unit/interactive/test_flyteinteractive_vscode.py +++ b/tests/flytekit/unit/interactive/test_flyteinteractive_vscode.py @@ -122,7 +122,7 @@ def wf(): mock_process.assert_not_called() mock_exit_handler.assert_not_called() mock_prepare_interactive_python.assert_not_called() - assert mock_signal.call_count >= 1 + assert mock_signal.call_count == 0 mock_prepare_resume_task_python.assert_not_called() mock_prepare_launch_json.assert_not_called() From faedf6e5601ac455ae40b4591e5f1cc8c6b10f15 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Thu, 3 Oct 2024 16:49:14 -0700 Subject: [PATCH 15/26] merge conflicts resolved Signed-off-by: Yee Hing Tong --- flytekit/core/array_node_map_task.py | 4 +++- flytekit/core/type_engine.py | 19 +++++++++++-------- flytekit/extend/backend/agent_service.py | 6 ++++-- flytekit/extend/backend/base_agent.py | 10 ++++++---- tests/flytekit/unit/core/test_array_node.py | 9 ++++++--- tests/flytekit/unit/extend/test_agent.py | 8 +++++--- 6 files changed, 35 insertions(+), 21 deletions(-) diff --git a/flytekit/core/array_node_map_task.py b/flytekit/core/array_node_map_task.py index 94454f417b..2f71df9b34 100644 --- a/flytekit/core/array_node_map_task.py +++ b/flytekit/core/array_node_map_task.py @@ -29,6 +29,7 @@ from flytekit.tools.module_loader import load_object_from_module from flytekit.types.pickle import pickle from flytekit.types.pickle.pickle import FlytePickleTransformer +from flytekit.utils.async_utils import run_sync_new_thread class ArrayNodeMapTask(PythonTask): @@ -253,7 +254,8 @@ def _literal_map_to_python_input( v = literal_map.literals[k] # If the input is offloaded, we need to unwrap it if v.offloaded_metadata: - v = TypeEngine.unwrap_offloaded_literal(ctx, v) + sync_f = run_sync_new_thread(TypeEngine.unwrap_offloaded_literal) + v = sync_f(ctx, v) if k not in self.bound_inputs: # assert that v.collection is not None if not v.collection or not isinstance(v.collection.literals, list): diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 584d70bd2f..e0bbe4e9e2 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1287,7 +1287,7 @@ async def async_to_literal( return lv @classmethod - def unwrap_offloaded_literal(cls, ctx: FlyteContext, lv: Literal) -> Literal: + async def unwrap_offloaded_literal(cls, ctx: FlyteContext, lv: Literal) -> Literal: if not lv.offloaded_metadata: return lv @@ -1304,7 +1304,8 @@ def to_python_value(cls, ctx: FlyteContext, lv: Literal, expected_python_type: T """ # Initiate the process of loading the offloaded literal if offloaded_metadata is set if lv.offloaded_metadata: - lv = cls.unwrap_offloaded_literal(ctx, lv) + synced = run_sync_new_thread(cls.unwrap_offloaded_literal) + lv = synced(ctx, lv) transformer = cls.get_transformer(expected_python_type) # see note in to_literal. @@ -1327,6 +1328,8 @@ def to_python_value(cls, ctx: FlyteContext, lv: Literal, expected_python_type: T @classmethod async def async_to_python_value(cls, ctx: FlyteContext, lv: Literal, expected_python_type: Type) -> typing.Any: + if lv.offloaded_metadata: + lv = await cls.unwrap_offloaded_literal(ctx, lv) transformer = cls.get_transformer(expected_python_type) if isinstance(transformer, AsyncTypeTransformer): pv = await transformer.async_to_python_value(ctx, lv, expected_python_type) @@ -1400,15 +1403,15 @@ async def _literal_map_to_kwargs( f" than allowed by the input spec {len(python_interface_inputs)}" ) kwargs = {} - for i, k in enumerate(lm.literals): - try: + try: + for i, k in enumerate(lm.literals): kwargs[k] = asyncio.create_task( TypeEngine.async_to_python_value(ctx, lm.literals[k], python_interface_inputs[k]) ) - await asyncio.gather(*kwargs.values()) - except TypeTransformerFailedError as exc: - exc.args = (f"Error converting input '{k}' at position {i}:\n {exc.args[0]}",) - raise + await asyncio.gather(*kwargs.values()) + except TypeTransformerFailedError as exc: + exc.args = (f"Error converting input '{k}' at position {i}:\n {exc.args[0]}",) + raise kwargs = {k: v.result() for k, v in kwargs.items() if v is not None} return kwargs diff --git a/flytekit/extend/backend/agent_service.py b/flytekit/extend/backend/agent_service.py index a046ccb973..5919ab0611 100644 --- a/flytekit/extend/backend/agent_service.py +++ b/flytekit/extend/backend/agent_service.py @@ -133,7 +133,8 @@ async def GetTask(self, request: GetTaskRequest, context: grpc.ServicerContext) agent = AgentRegistry.get_agent(request.task_type) logger.info(f"{agent.name} start checking the status of the job") res = await mirror_async_methods(agent.get, resource_meta=agent.metadata_type.decode(request.resource_meta)) - return GetTaskResponse(resource=res.to_flyte_idl()) + resource = await res.to_flyte_idl() + return GetTaskResponse(resource=resource) @record_agent_metrics async def DeleteTask(self, request: DeleteTaskRequest, context: grpc.ServicerContext) -> DeleteTaskResponse: @@ -166,7 +167,8 @@ async def ExecuteTaskSync( agent.do, task_template=template, inputs=literal_map, output_prefix=output_prefix ) - header = ExecuteTaskSyncResponseHeader(resource=res.to_flyte_idl()) + resource = await res.to_flyte_idl() + header = ExecuteTaskSyncResponseHeader(resource=resource) yield ExecuteTaskSyncResponse(header=header) request_success_count.labels(task_type=task_type, operation=do_operation).inc() except Exception as e: diff --git a/flytekit/extend/backend/base_agent.py b/flytekit/extend/backend/base_agent.py index 415ab70811..2f973e94f0 100644 --- a/flytekit/extend/backend/base_agent.py +++ b/flytekit/extend/backend/base_agent.py @@ -31,7 +31,6 @@ from flytekit.exceptions.user import FlyteUserException from flytekit.extend.backend.utils import is_terminal_phase, mirror_async_methods, render_task_template from flytekit.loggers import set_flytekit_log_properties -from flytekit.models import common from flytekit.models.literals import LiteralMap from flytekit.models.task import TaskExecutionMetadata, TaskTemplate @@ -79,7 +78,6 @@ def decode(cls, data: bytes) -> "ResourceMeta": return dataclass_from_dict(cls, json.loads(data.decode("utf-8"))) -# Merge todo: move some logic back out of the to/from idl functions. @dataclass class Resource: """ @@ -98,7 +96,11 @@ class Resource: outputs: Optional[Union[LiteralMap, typing.Dict[str, Any]]] = None custom_info: Optional[typing.Dict[str, Any]] = None - def to_flyte_idl(self) -> _Resource: + async def to_flyte_idl(self) -> _Resource: + """ + This function is async to call the async type engine functions. This is okay to do because this is not a + normal model class that inherits from FlyteIdlEntity + """ if self.outputs is None: outputs = None elif isinstance(self.outputs, LiteralMap): @@ -106,7 +108,7 @@ def to_flyte_idl(self) -> _Resource: else: ctx = FlyteContext.current_context() - outputs = TypeEngine.dict_to_literal_map_pb(ctx, self.outputs) + outputs = await TypeEngine.dict_to_literal_map_pb(ctx, self.outputs) return _Resource( phase=self.phase, diff --git a/tests/flytekit/unit/core/test_array_node.py b/tests/flytekit/unit/core/test_array_node.py index c861d39268..2f4c3145ba 100644 --- a/tests/flytekit/unit/core/test_array_node.py +++ b/tests/flytekit/unit/core/test_array_node.py @@ -3,7 +3,8 @@ import pytest -from flytekit import LaunchPlan, current_context, task, workflow +from flytekit import LaunchPlan, task, workflow +from flytekit.core.context_manager import FlyteContextManager from flytekit.configuration import Image, ImageConfig, SerializationSettings from flytekit.core.array_node import array_node from flytekit.core.array_node_map_task import map_task @@ -35,7 +36,8 @@ def parent_wf(a: int, b: typing.Union[int, str], c: int = 2) -> int: return multiply(val=a, val1=b, val2=c) -lp = LaunchPlan.get_default_launch_plan(current_context(), parent_wf) +ctx = FlyteContextManager.current_context() +lp = LaunchPlan.get_default_launch_plan(ctx, parent_wf) @workflow @@ -103,7 +105,8 @@ def ex_task(val: int) -> int: def ex_wf(val: int) -> int: return ex_task(val=val) - ex_lp = LaunchPlan.get_default_launch_plan(current_context(), ex_wf) + ctx = FlyteContextManager.current_context() + ex_lp = LaunchPlan.get_default_launch_plan(ctx, ex_wf) @workflow def grandparent_ex_wf() -> typing.List[typing.Optional[int]]: diff --git a/tests/flytekit/unit/extend/test_agent.py b/tests/flytekit/unit/extend/test_agent.py index 946bf3a778..1638617575 100644 --- a/tests/flytekit/unit/extend/test_agent.py +++ b/tests/flytekit/unit/extend/test_agent.py @@ -21,7 +21,6 @@ from flyteidl.core.identifier_pb2 import ResourceType from flytekit import PythonFunctionTask, task -from flytekit.clis.sdk_in_container.serve import print_agents_metadata from flytekit.configuration import ( FastSerializationSettings, Image, @@ -59,6 +58,7 @@ from flytekit.models.security import Identity from flytekit.models.task import TaskExecutionMetadata, TaskTemplate from flytekit.tools.translator import get_serializable +from flytekit.utils.async_utils import run_sync_new_thread dummy_id = "dummy_id" @@ -440,7 +440,8 @@ def test_resource_type(): o = Resource( phase=TaskExecution.SUCCEEDED, ) - v = o.to_flyte_idl() + synced = run_sync_new_thread(o.to_flyte_idl) + v = synced() assert v assert v.phase == TaskExecution.SUCCEEDED assert len(v.log_links) == 0 @@ -458,7 +459,8 @@ def test_resource_type(): outputs={"o0": 1}, custom_info={"custom": "info", "num": 1}, ) - v = o.to_flyte_idl() + synced = run_sync_new_thread(o.to_flyte_idl) + v = synced() assert v assert v.phase == TaskExecution.SUCCEEDED assert v.log_links[0].name == "console" From 27e16ee8e11b38450ccd484ac04f5cf199d01b6d Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Fri, 4 Oct 2024 11:12:15 -0700 Subject: [PATCH 16/26] migrate over to the new loop manager Signed-off-by: Yee Hing Tong --- flytekit/core/array_node_map_task.py | 5 +- flytekit/core/base_task.py | 4 +- flytekit/core/context_manager.py | 33 +---- flytekit/core/promise.py | 4 +- flytekit/core/type_engine.py | 82 +++++------ flytekit/remote/remote.py | 2 +- flytekit/{tools => utils}/asyn.py | 12 ++ flytekit/utils/async_utils.py | 130 ------------------ tests/flytekit/unit/core/test_data.py | 5 +- tests/flytekit/unit/extend/test_agent.py | 8 +- .../unit/{tools => utils}/test_asyn.py | 2 +- 11 files changed, 62 insertions(+), 225 deletions(-) rename flytekit/{tools => utils}/asyn.py (89%) delete mode 100644 flytekit/utils/async_utils.py rename tests/flytekit/unit/{tools => utils}/test_asyn.py (98%) diff --git a/flytekit/core/array_node_map_task.py b/flytekit/core/array_node_map_task.py index 2f71df9b34..82bb0cbe68 100644 --- a/flytekit/core/array_node_map_task.py +++ b/flytekit/core/array_node_map_task.py @@ -29,7 +29,7 @@ from flytekit.tools.module_loader import load_object_from_module from flytekit.types.pickle import pickle from flytekit.types.pickle.pickle import FlytePickleTransformer -from flytekit.utils.async_utils import run_sync_new_thread +from flytekit.utils.asyn import loop_manager class ArrayNodeMapTask(PythonTask): @@ -254,8 +254,7 @@ def _literal_map_to_python_input( v = literal_map.literals[k] # If the input is offloaded, we need to unwrap it if v.offloaded_metadata: - sync_f = run_sync_new_thread(TypeEngine.unwrap_offloaded_literal) - v = sync_f(ctx, v) + v = loop_manager.run_sync(TypeEngine.unwrap_offloaded_literal, ctx, v) if k not in self.bound_inputs: # assert that v.collection is not None if not v.collection or not isinstance(v.collection.literals, list): diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index d35e73a438..24bc9fc90a 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -86,7 +86,7 @@ from flytekit.models.documentation import Description, Documentation from flytekit.models.interface import Variable from flytekit.models.security import SecurityContext -from flytekit.utils.async_utils import run_sync_new_thread +from flytekit.utils.asyn import loop_manager DYNAMIC_PARTITIONS = "_uap" MODEL_CARD = "_ucm" @@ -790,7 +790,7 @@ def dispatch_execute( try: with timeit("dispatch execute"): - synced = run_sync_new_thread(self._output_to_literal_map) + synced = loop_manager.synced(self._output_to_literal_map) literals_map, native_outputs_as_map = synced(native_outputs, exec_ctx) self._write_decks(native_inputs, native_outputs_as_map, ctx, new_user_params) except (FlyteUploadDataException, FlyteDownloadDataException): diff --git a/flytekit/core/context_manager.py b/flytekit/core/context_manager.py index e56fc3f1e9..13691162d5 100644 --- a/flytekit/core/context_manager.py +++ b/flytekit/core/context_manager.py @@ -13,7 +13,6 @@ from __future__ import annotations -import asyncio import logging as _logging import os import pathlib @@ -36,7 +35,6 @@ from flytekit.interfaces.stats import taggable from flytekit.loggers import developer_logger, user_space_logger from flytekit.models.core import identifier as _identifier -from flytekit.utils.async_utils import get_or_create_loop if typing.TYPE_CHECKING: from flytekit import Deck @@ -644,15 +642,6 @@ class FlyteContext(object): in_a_condition: bool = False origin_stackframe: Optional[traceback.FrameSummary] = None output_metadata_tracker: Optional[OutputMetadataTracker] = None - _loop: Optional[asyncio.AbstractEventLoop] = None - - @property - def loop(self) -> asyncio.AbstractEventLoop: - """ - Can remove this property in the future - """ - assert self._loop is not None - return self._loop @property def user_space_params(self) -> Optional[ExecutionParameters]: @@ -679,7 +668,6 @@ def new_builder(self) -> Builder: execution_state=self.execution_state, in_a_condition=self.in_a_condition, output_metadata_tracker=self.output_metadata_tracker, - loop=self._loop, ) def enter_conditional_section(self) -> Builder: @@ -704,9 +692,6 @@ def with_serialization_settings(self, ss: SerializationSettings) -> Builder: def with_output_metadata_tracker(self, t: OutputMetadataTracker) -> Builder: return self.new_builder().with_output_metadata_tracker(t) - def with_ensure_loop(self) -> Builder: - return self.new_builder().with_ensure_loop() - def new_compilation_state(self, prefix: str = "") -> CompilationState: """ Creates and returns a default compilation state. For most of the code this should be the entrypoint @@ -768,7 +753,6 @@ class Builder(object): serialization_settings: Optional[SerializationSettings] = None in_a_condition: bool = False output_metadata_tracker: Optional[OutputMetadataTracker] = None - loop: Optional[asyncio.AbstractEventLoop] = None def build(self) -> FlyteContext: return FlyteContext( @@ -780,7 +764,6 @@ def build(self) -> FlyteContext: serialization_settings=self.serialization_settings, in_a_condition=self.in_a_condition, output_metadata_tracker=self.output_metadata_tracker, - _loop=self.loop, ) def enter_conditional_section(self) -> FlyteContext.Builder: @@ -829,12 +812,6 @@ def with_output_metadata_tracker(self, t: OutputMetadataTracker) -> FlyteContext self.output_metadata_tracker = t return self - def with_ensure_loop(self, use_windows: bool = False) -> FlyteContext.Builder: - if not self.loop: - # Currently this will use a running system loop. - self.loop = get_or_create_loop(use_windows=use_windows) - return self - def new_compilation_state(self, prefix: str = "") -> CompilationState: """ Creates and returns a default compilation state. For most of the code this should be the entrypoint @@ -970,13 +947,9 @@ def initialize(): decks=[], ) - default_context = ( - default_context.with_execution_state( - default_context.new_execution_state().with_params(user_space_params=default_user_space_params) - ) - .with_ensure_loop() - .build() - ) + default_context = default_context.with_execution_state( + default_context.new_execution_state().with_params(user_space_params=default_user_space_params) + ).build() default_context.set_stackframe(s=FlyteContextManager.get_origin_stackframe()) flyte_context_Var.set([default_context]) diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index c39a632b66..50c5fd2a70 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -45,7 +45,7 @@ from flytekit.models.literals import Primitive from flytekit.models.task import Resources from flytekit.models.types import SimpleType -from flytekit.utils.async_utils import run_sync_new_thread +from flytekit.utils.asyn import loop_manager async def _translate_inputs_to_literals( @@ -105,7 +105,7 @@ def my_wf(in1: int, in2: int) -> int: return result -translate_inputs_to_literals = run_sync_new_thread(_translate_inputs_to_literals) +translate_inputs_to_literals = loop_manager.synced(_translate_inputs_to_literals) async def resolve_attr_path_in_promise(p: Promise) -> Promise: diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index e0bbe4e9e2..93beac9162 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -53,7 +53,7 @@ Void, ) from flytekit.models.types import LiteralType, SimpleType, TypeStructure, UnionType -from flytekit.utils.async_utils import ContextExecutor, get_running_loop_if_exists, run_sync_new_thread +from flytekit.utils.asyn import ContextExecutor, loop_manager T = typing.TypeVar("T") DEFINITIONS = "definitions" @@ -280,22 +280,14 @@ async def async_to_python_value(self, ctx: FlyteContext, lv: Literal, expected_p def to_literal( self, ctx: FlyteContext, python_val: typing.Any, python_type: Type[T], expected: LiteralType ) -> Literal: - if ctx.loop.is_running(): - synced = run_sync_new_thread(self.async_to_literal) - result = synced(ctx, python_val, python_type, expected) - return result - else: - coro = self.async_to_literal(ctx, python_val, python_type, expected) - return ctx.loop.run_until_complete(coro) + synced = loop_manager.synced(self.async_to_literal) + result = synced(ctx, python_val, python_type, expected) + return result def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> Optional[T]: - if ctx.loop.is_running(): - synced = run_sync_new_thread(self.async_to_python_value) - result = synced(ctx, lv, expected_python_type) - return result - else: - coro = self.async_to_python_value(ctx, lv, expected_python_type) - return ctx.loop.run_until_complete(coro) + synced = loop_manager.synced(self.async_to_python_value) + result = synced(ctx, lv, expected_python_type) + return result class SimpleTransformer(TypeTransformer[T]): @@ -1231,22 +1223,20 @@ def to_literal( if transformer.type_assertions_enabled: transformer.assert_type(python_type, python_val) - running_loop = get_running_loop_if_exists() - # can't have a main loop running either, maybe this is one of the downsides # of not calling set_event_loop. - if not ctx.loop.is_running() and not running_loop: - if isinstance(transformer, AsyncTypeTransformer): - coro = transformer.async_to_literal(ctx, python_val, python_type, expected) - lv = ctx.loop.run_until_complete(coro) - else: - lv = transformer.to_literal(ctx, python_val, python_type, expected) + if isinstance(transformer, AsyncTypeTransformer): + synced = loop_manager.synced(transformer.async_to_literal) + lv = synced(ctx, python_val, python_type, expected) else: - if isinstance(transformer, AsyncTypeTransformer): - synced = run_sync_new_thread(transformer.async_to_literal) - lv = synced(ctx, python_val, python_type, expected) - else: - lv = transformer.to_literal(ctx, python_val, python_type, expected) + lv = transformer.to_literal(ctx, python_val, python_type, expected) + # if not ctx.loop.is_running() and not running_loop: + # if isinstance(transformer, AsyncTypeTransformer): + # coro = transformer.async_to_literal(ctx, python_val, python_type, expected) + # lv = ctx.loop.run_until_complete(coro) + # else: + # lv = transformer.to_literal(ctx, python_val, python_type, expected) + # else: modify_literal_uris(lv) lv.hash = cls.calculate_hash(python_val, python_type) @@ -1275,7 +1265,6 @@ async def async_to_literal( lv = await transformer.async_to_literal(ctx, python_val, python_type, expected) else: # Testing just blocking call - # lv = transformer.to_literal(ctx, python_val, python_type, expected) loop = asyncio.get_running_loop() executor = ContextExecutor() fut = loop.run_in_executor(executor, transformer.to_literal, ctx, python_val, python_type, expected) @@ -1304,27 +1293,24 @@ def to_python_value(cls, ctx: FlyteContext, lv: Literal, expected_python_type: T """ # Initiate the process of loading the offloaded literal if offloaded_metadata is set if lv.offloaded_metadata: - synced = run_sync_new_thread(cls.unwrap_offloaded_literal) + synced = loop_manager.synced(cls.unwrap_offloaded_literal) lv = synced(ctx, lv) transformer = cls.get_transformer(expected_python_type) - # see note in to_literal. - running_loop = get_running_loop_if_exists() - - if not ctx.loop.is_running() and not running_loop: - if isinstance(transformer, AsyncTypeTransformer): - coro = transformer.async_to_python_value(ctx, lv, expected_python_type) - pv = ctx.loop.run_until_complete(coro) - else: - pv = transformer.to_python_value(ctx, lv, expected_python_type) - return pv + if isinstance(transformer, AsyncTypeTransformer): + synced = loop_manager.synced(transformer.async_to_python_value) + return synced(ctx, lv, expected_python_type) else: - if isinstance(transformer, AsyncTypeTransformer): - synced = run_sync_new_thread(transformer.async_to_python_value) - return synced(ctx, lv, expected_python_type) - else: - res = transformer.to_python_value(ctx, lv, expected_python_type) - return res + res = transformer.to_python_value(ctx, lv, expected_python_type) + return res + # if not ctx.loop.is_running() and not running_loop: + # if isinstance(transformer, AsyncTypeTransformer): + # coro = transformer.async_to_python_value(ctx, lv, expected_python_type) + # pv = ctx.loop.run_until_complete(coro) + # else: + # pv = transformer.to_python_value(ctx, lv, expected_python_type) + # return pv + # else: @classmethod async def async_to_python_value(cls, ctx: FlyteContext, lv: Literal, expected_python_type: Type) -> typing.Any: @@ -1372,7 +1358,7 @@ def literal_map_to_kwargs( python_types: typing.Optional[typing.Dict[str, type]] = None, literal_types: typing.Optional[typing.Dict[str, _interface_models.Variable]] = None, ) -> typing.Dict[str, typing.Any]: - synced = run_sync_new_thread(cls._literal_map_to_kwargs) + synced = loop_manager.synced(cls._literal_map_to_kwargs) return synced(ctx, lm, python_types, literal_types) @classmethod @@ -1423,7 +1409,7 @@ def dict_to_literal_map( d: typing.Dict[str, typing.Any], type_hints: Optional[typing.Dict[str, type]] = None, ) -> LiteralMap: - synced = run_sync_new_thread(cls._dict_to_literal_map) + synced = loop_manager.synced(cls._dict_to_literal_map) return synced(ctx, d, type_hints) @classmethod diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index 379943d36d..3fcdee46e3 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -87,7 +87,6 @@ from flytekit.remote.lazy_entity import LazyEntity from flytekit.remote.remote_callable import RemoteEntity from flytekit.remote.remote_fs import get_flyte_fs -from flytekit.tools.asyn import run_sync from flytekit.tools.fast_registration import FastPackageOptions, fast_package from flytekit.tools.interactive import ipython_check from flytekit.tools.script_mode import _find_project_root, compress_scripts, get_all_modules, hash_file @@ -98,6 +97,7 @@ get_serializable, get_serializable_launch_plan, ) +from flytekit.utils.asyn import run_sync if typing.TYPE_CHECKING: try: diff --git a/flytekit/tools/asyn.py b/flytekit/utils/asyn.py similarity index 89% rename from flytekit/tools/asyn.py rename to flytekit/utils/asyn.py index a126405672..0853c1670c 100644 --- a/flytekit/tools/asyn.py +++ b/flytekit/utils/asyn.py @@ -13,7 +13,9 @@ async def async_add(a: int, b: int) -> int: import atexit import os import threading +from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager +from contextvars import copy_context from typing import Any, Awaitable, Callable, TypeVar from flytekit.loggers import logger @@ -99,3 +101,13 @@ def wrapped(*args: Any, **kwargs: Any) -> T: loop_manager = _AsyncLoopManager() run_sync = loop_manager.run_sync + + +class ContextExecutor(ThreadPoolExecutor): + def __init__(self): + self.context = copy_context() + super().__init__(initializer=self._set_child_context) + + def _set_child_context(self): + for var, value in self.context.items(): + var.set(value) diff --git a/flytekit/utils/async_utils.py b/flytekit/utils/async_utils.py deleted file mode 100644 index 1c3977e723..0000000000 --- a/flytekit/utils/async_utils.py +++ /dev/null @@ -1,130 +0,0 @@ -import asyncio -import atexit -import sys -import threading -from concurrent.futures import ThreadPoolExecutor -from contextvars import copy_context -from types import CoroutineType -from typing import Any, Awaitable, Callable, Optional, TypeVar - -from flytekit.loggers import logger - -AsyncFuncType = Callable[[Any], CoroutineType] -Synced = Callable[[Any], Any] -T = TypeVar("T") - - -def get_running_loop_if_exists() -> Optional[asyncio.AbstractEventLoop]: - try: - loop = asyncio.get_running_loop() - return loop - except RuntimeError as e: - if "no running event loop" not in str(e): - logger.error(f"Unknown RuntimeError {str(e)}") - raise - return None - - -def get_or_create_loop(use_windows: bool = False) -> asyncio.AbstractEventLoop: - # todo: what happens if we remove this? never rely on the running loop - # something to test. import flytekit, inside an async function, what happens? - # import flytekit, inside a jupyter notebook (which sets its own loop) - try: - running_loop = asyncio.get_running_loop() - return running_loop - except RuntimeError as e: - if "no running event loop" not in str(e): - logger.error(f"Unknown RuntimeError when getting loop {str(e)}") - raise - - if sys.platform == "win32" and use_windows: - loop = asyncio.WindowsSelectorEventLoopPolicy().new_event_loop() - else: - loop = asyncio.new_event_loop() - # Intentionally not calling asyncio.set_event_loop(loop) - # Unclear what the downside of this is. But should be better in the Jupyter case where it seems to - # co-opt set_event_loop somehow - - # maybe add signal handlers in the future - - return loop - - -class _CoroRunner: - """ - Runs a coroutine and a loop for it on a background thread, in a blocking manner - """ - - def __init__(self) -> None: - self.__io_loop: asyncio.AbstractEventLoop | None = None - self.__runner_thread: threading.Thread | None = None - self.__lock = threading.Lock() - atexit.register(self._close) - - def _close(self) -> None: - if self.__io_loop: - self.__io_loop.stop() - - def _runner(self) -> None: - loop = self.__io_loop - assert loop is not None - try: - loop.run_forever() - finally: - loop.close() - - def run(self, coro: Any) -> Any: - """ - This is a blocking function. - Synchronously runs the coroutine on a background thread. - """ - name = f"{threading.current_thread().name} - runner" - with self.__lock: - # remove before merging - if f"{threading.current_thread().name} - runner" != name: - raise AssertionError - if self.__io_loop is None: - self.__io_loop = asyncio.new_event_loop() - self.__runner_thread = threading.Thread(target=self._runner, daemon=True, name=name) - self.__runner_thread.start() - logger.debug(f"Runner thread started {name}") - fut = asyncio.run_coroutine_threadsafe(coro, self.__io_loop) - res = fut.result(None) - - return res - - -_runner_map: dict[str, _CoroRunner] = {} - - -def run_sync_new_thread(coro_function: Callable[..., Awaitable[T]]) -> Callable[..., T]: - """ - Decorator to run a coroutine function with a loop that runs in a different thread. - Always run in a new thread, even if no current thread is running. - - :param coro_function: A coroutine function - """ - - # if not inspect.iscoroutinefunction(coro_function): - # raise AssertionError - - def wrapped(*args: Any, **kwargs: Any) -> Any: - name = threading.current_thread().name - logger.debug(f"Invoking coro_f synchronously in thread: {threading.current_thread().name}") - inner = coro_function(*args, **kwargs) - if name not in _runner_map: - _runner_map[name] = _CoroRunner() - return _runner_map[name].run(inner) - - wrapped.__doc__ = coro_function.__doc__ - return wrapped - - -class ContextExecutor(ThreadPoolExecutor): - def __init__(self): - self.context = copy_context() - super().__init__(initializer=self._set_child_context) - - def _set_child_context(self): - for var, value in self.context.items(): - var.set(value) diff --git a/tests/flytekit/unit/core/test_data.py b/tests/flytekit/unit/core/test_data.py index 93ea739355..2de6e8c196 100644 --- a/tests/flytekit/unit/core/test_data.py +++ b/tests/flytekit/unit/core/test_data.py @@ -15,7 +15,7 @@ from flytekit.core.type_engine import TypeEngine from flytekit.types.directory.types import FlyteDirectory from flytekit.types.file import FlyteFile -from flytekit.utils.async_utils import run_sync_new_thread +from flytekit.utils.asyn import loop_manager from flytekit.models.literals import Literal local = fsspec.filesystem("file") @@ -482,8 +482,7 @@ def test_async_local_copy_to_s3(): start_process_time = time.process_time() with FlyteContextManager.with_context(ctx.with_file_access(provider)) as ctx: - synced = run_sync_new_thread(dummy_output_to_literal_map) - lit = synced(ctx, ff) + lit = loop_manager.run_sync(dummy_output_to_literal_map, ctx, ff) print(lit) end_time = datetime.datetime.now(datetime.timezone.utc) diff --git a/tests/flytekit/unit/extend/test_agent.py b/tests/flytekit/unit/extend/test_agent.py index 1638617575..f16369dae6 100644 --- a/tests/flytekit/unit/extend/test_agent.py +++ b/tests/flytekit/unit/extend/test_agent.py @@ -58,7 +58,7 @@ from flytekit.models.security import Identity from flytekit.models.task import TaskExecutionMetadata, TaskTemplate from flytekit.tools.translator import get_serializable -from flytekit.utils.async_utils import run_sync_new_thread +from flytekit.utils.asyn import loop_manager dummy_id = "dummy_id" @@ -440,8 +440,7 @@ def test_resource_type(): o = Resource( phase=TaskExecution.SUCCEEDED, ) - synced = run_sync_new_thread(o.to_flyte_idl) - v = synced() + v = loop_manager.run_sync(o.to_flyte_idl) assert v assert v.phase == TaskExecution.SUCCEEDED assert len(v.log_links) == 0 @@ -459,8 +458,7 @@ def test_resource_type(): outputs={"o0": 1}, custom_info={"custom": "info", "num": 1}, ) - synced = run_sync_new_thread(o.to_flyte_idl) - v = synced() + v = loop_manager.run_sync(o.to_flyte_idl) assert v assert v.phase == TaskExecution.SUCCEEDED assert v.log_links[0].name == "console" diff --git a/tests/flytekit/unit/tools/test_asyn.py b/tests/flytekit/unit/utils/test_asyn.py similarity index 98% rename from tests/flytekit/unit/tools/test_asyn.py rename to tests/flytekit/unit/utils/test_asyn.py index 0a3ffeb2d3..db74ac6f53 100644 --- a/tests/flytekit/unit/tools/test_asyn.py +++ b/tests/flytekit/unit/utils/test_asyn.py @@ -4,7 +4,7 @@ from typing import List, Dict, Optional from asyncio import get_running_loop from functools import partial -from flytekit.tools.asyn import run_sync, loop_manager +from flytekit.utils.asyn import run_sync, loop_manager from contextvars import ContextVar From 361730047df400d49ed3a431dd3dbcdf401d1a34 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Fri, 4 Oct 2024 16:02:19 -0700 Subject: [PATCH 17/26] update comments and make additional code actually run in parallel Signed-off-by: Yee Hing Tong --- flytekit/core/base_task.py | 25 ++++++++++------ flytekit/core/promise.py | 8 ++---- flytekit/core/type_engine.py | 56 ++++++++++++------------------------ 3 files changed, 36 insertions(+), 53 deletions(-) diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index 24bc9fc90a..1d80d6dcd1 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -637,21 +637,28 @@ async def _output_to_literal_map(self, native_outputs: Dict[int, Any], ctx: Flyt if isinstance(v, tuple): raise TypeError(f"Output({k}) in task '{self.name}' received a tuple {v}, instead of {py_type}") - try: - lit = await TypeEngine.async_to_literal(ctx, v, py_type, literal_type) - literals[k] = lit - except Exception as e: + literals[k] = asyncio.create_task(TypeEngine.async_to_literal(ctx, v, py_type, literal_type)) + + await asyncio.gather(*literals.values(), return_exceptions=True) + + for i, (k, v) in enumerate(literals.items()): + if v.exception() is not None: # only show the name of output key if it's user-defined (by default Flyte names these as "o") key = k if k != f"o{i}" else i + e = v.exception() + py_type = self.get_type_for_output_var(k, native_outputs_as_map[k]) e.args = ( f"Failed to convert outputs of task '{self.name}' at position {key}.\n" f"Failed to convert type {type(native_outputs_as_map[expected_output_names[i]])} to type {py_type}.\n" f"Error Message: {e.args[0]}.", ) - raise - # Now check if there is any output metadata associated with this output variable and attach it to the - # literal - if omt is not None: + raise e + literals[k] = v.result() + + if omt is not None: + for i, (k, v) in enumerate(native_outputs_as_map.items()): + # Now check if there is any output metadata associated with this output variable and attach it to the + # literal om = omt.get(v) if om: metadata = {} @@ -671,7 +678,7 @@ async def _output_to_literal_map(self, native_outputs: Dict[int, Any], ctx: Flyt encoded = b64encode(s).decode("utf-8") metadata[DYNAMIC_PARTITIONS] = encoded if metadata: - lit.set_metadata(metadata) + literals[k].set_metadata(metadata) return _literal_models.LiteralMap(literals=literals), native_outputs_as_map diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 50c5fd2a70..c475d5154a 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -243,13 +243,9 @@ def __init__(self, lhs: Union["Promise", Any], op: ComparisonOps, rhs: Union["Pr else: raise ValueError("Only primitive values can be used in comparison") if self._lhs is None: - lhs_lit = type_engine.TypeEngine.to_literal(FlyteContextManager.current_context(), lhs, type(lhs), None) - assert not isinstance(lhs_lit, asyncio.Future) - self._lhs = lhs_lit + self._lhs = type_engine.TypeEngine.to_literal(FlyteContextManager.current_context(), lhs, type(lhs), None) if self._rhs is None: - rhs_lit = type_engine.TypeEngine.to_literal(FlyteContextManager.current_context(), rhs, type(rhs), None) - assert not isinstance(rhs_lit, asyncio.Future) - self._rhs = rhs_lit + self._rhs = type_engine.TypeEngine.to_literal(FlyteContextManager.current_context(), rhs, type(rhs), None) @property def rhs(self) -> Union["Promise", _literals_models.Literal]: diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 93beac9162..c77447e64d 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1211,32 +1211,15 @@ def to_literal( return python_val.val transformer = cls.get_transformer(python_type) - # possible options are: - # a) in a loop (somewhere above, someone called async) - # 1) transformer is async - just await the async function. - # 2) transformer is not async - since the expectation is for async behavior - # run it in an executor. - # b) not in a loop (async has never been called) - # 1) transformer is async - create loop and run it. - # 2) transformer is not async - just invoke normally as a blocking function if transformer.type_assertions_enabled: transformer.assert_type(python_type, python_val) - # can't have a main loop running either, maybe this is one of the downsides - # of not calling set_event_loop. if isinstance(transformer, AsyncTypeTransformer): synced = loop_manager.synced(transformer.async_to_literal) lv = synced(ctx, python_val, python_type, expected) else: lv = transformer.to_literal(ctx, python_val, python_type, expected) - # if not ctx.loop.is_running() and not running_loop: - # if isinstance(transformer, AsyncTypeTransformer): - # coro = transformer.async_to_literal(ctx, python_val, python_type, expected) - # lv = ctx.loop.run_until_complete(coro) - # else: - # lv = transformer.to_literal(ctx, python_val, python_type, expected) - # else: modify_literal_uris(lv) lv.hash = cls.calculate_hash(python_val, python_type) @@ -1303,14 +1286,6 @@ def to_python_value(cls, ctx: FlyteContext, lv: Literal, expected_python_type: T else: res = transformer.to_python_value(ctx, lv, expected_python_type) return res - # if not ctx.loop.is_running() and not running_loop: - # if isinstance(transformer, AsyncTypeTransformer): - # coro = transformer.async_to_python_value(ctx, lv, expected_python_type) - # pv = ctx.loop.run_until_complete(coro) - # else: - # pv = transformer.to_python_value(ctx, lv, expected_python_type) - # return pv - # else: @classmethod async def async_to_python_value(cls, ctx: FlyteContext, lv: Literal, expected_python_type: Type) -> typing.Any: @@ -1430,20 +1405,25 @@ async def _dict_to_literal_map( # to account for the type erasure that happens in the case of built-in collection containers, such as # `list` and `dict`. python_type = type_hints.get(k, type(v)) - try: - literal_map[k] = asyncio.create_task( - TypeEngine.async_to_literal( - ctx=ctx, - python_val=v, - python_type=python_type, - expected=TypeEngine.to_literal_type(python_type), - ) + literal_map[k] = asyncio.create_task( + TypeEngine.async_to_literal( + ctx=ctx, + python_val=v, + python_type=python_type, + expected=TypeEngine.to_literal_type(python_type), ) - await asyncio.gather(*literal_map.values()) - except TypeError: - raise user_exceptions.FlyteTypeException(type(v), python_type, received_value=v) + ) + await asyncio.gather(*literal_map.values(), return_exceptions=True) + for idx, (k, v) in enumerate(literal_map.items()): + if literal_map[k].exception() is not None: + python_type = type_hints.get(k, type(d[k])) + e = literal_map[k].exception() + if isinstance(e, TypeError): + raise user_exceptions.FlyteTypeException(type(v), python_type, received_value=v) + else: + raise e + literal_map[k] = v.result() - literal_map = {k: v.result() for k, v in literal_map.items()} return LiteralMap(literal_map) @classmethod @@ -1832,7 +1812,7 @@ async def async_to_python_value( cur_transformer = "" res = None res_tag = None - # This is serial not really async but should be okay. + # This is serial, not actually async, but should be okay since it's more reasonable for Unions. for v in get_args(expected_python_type): try: trans: TypeTransformer[T] = TypeEngine.get_transformer(v) From f30f4d54abc672cbab5a6082a7e8ddd72a955392 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Fri, 4 Oct 2024 17:54:01 -0700 Subject: [PATCH 18/26] lint Signed-off-by: Yee Hing Tong --- flytekit/core/base_task.py | 14 +++++++------- flytekit/core/promise.py | 4 +++- flytekit/core/type_engine.py | 2 +- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index 1d80d6dcd1..861152b51f 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -641,19 +641,19 @@ async def _output_to_literal_map(self, native_outputs: Dict[int, Any], ctx: Flyt await asyncio.gather(*literals.values(), return_exceptions=True) - for i, (k, v) in enumerate(literals.items()): - if v.exception() is not None: + for i, (k2, v2) in enumerate(literals.items()): + if v2.exception() is not None: # only show the name of output key if it's user-defined (by default Flyte names these as "o") - key = k if k != f"o{i}" else i - e = v.exception() - py_type = self.get_type_for_output_var(k, native_outputs_as_map[k]) + key = k2 if k2 != f"o{i}" else i + e: BaseException = v2.exception() # type: ignore # we know this is not optional + py_type = self.get_type_for_output_var(k2, native_outputs_as_map[k2]) e.args = ( f"Failed to convert outputs of task '{self.name}' at position {key}.\n" f"Failed to convert type {type(native_outputs_as_map[expected_output_names[i]])} to type {py_type}.\n" f"Error Message: {e.args[0]}.", ) raise e - literals[k] = v.result() + literals[k2] = v2.result() if omt is not None: for i, (k, v) in enumerate(native_outputs_as_map.items()): @@ -678,7 +678,7 @@ async def _output_to_literal_map(self, native_outputs: Dict[int, Any], ctx: Flyt encoded = b64encode(s).decode("utf-8") metadata[DYNAMIC_PARTITIONS] = encoded if metadata: - literals[k].set_metadata(metadata) + literals[k].set_metadata(metadata) # type: ignore # we know these have been resolved return _literal_models.LiteralMap(literals=literals), native_outputs_as_map diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 79e00c898c..ddb14dccfe 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -150,7 +150,9 @@ async def resolve_attr_path_in_promise(p: Promise) -> Promise: new_st = resolve_attr_path_in_pb_struct(st, attr_path=p.attr_path[used:]) literal_type = TypeEngine.to_literal_type(type(new_st)) # Reconstruct the resolved result to flyte literal (because the resolved result might not be struct) - curr_val = await TypeEngine.async_to_literal(FlyteContextManager.current_context(), new_st, type(new_st), literal_type) + curr_val = await TypeEngine.async_to_literal( + FlyteContextManager.current_context(), new_st, type(new_st), literal_type + ) elif type(curr_val.value.value) is Binary: binary_idl_obj = curr_val.value.value if binary_idl_obj.tag == _common_constants.MESSAGEPACK: diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 72cca662a1..652c965330 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -1459,7 +1459,7 @@ async def _dict_to_literal_map( for idx, (k, v) in enumerate(literal_map.items()): if literal_map[k].exception() is not None: python_type = type_hints.get(k, type(d[k])) - e = literal_map[k].exception() + e: BaseException = literal_map[k].exception() # type: ignore if isinstance(e, TypeError): raise user_exceptions.FlyteTypeException(type(v), python_type, received_value=v) else: From efa2aa2c1418b575f0d5851c7536552528f3f615 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Mon, 7 Oct 2024 14:28:31 -0700 Subject: [PATCH 19/26] uncomment test Signed-off-by: Yee Hing Tong --- flytekit/core/promise.py | 15 ++- tests/flytekit/unit/core/test_type_engine.py | 110 +++++++++---------- 2 files changed, 62 insertions(+), 63 deletions(-) diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index ddb14dccfe..d12f8ae69b 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -913,14 +913,13 @@ def binding_from_python_std( if "no running event loop" not in str(e): logger.error(f"Unknown RuntimeError {str(e)}") raise - binding_data = asyncio.run( - binding_data_from_python_std( - ctx, - expected_literal_type, - t_value, - t_value_type, - nodes, - ) + synced = loop_manager.synced(binding_data_from_python_std) + binding_data = synced( + ctx, + expected_literal_type, + t_value, + t_value_type, + nodes, ) return _literals_models.Binding(var=var_name, binding=binding_data), nodes diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index 3485aa180c..a354869c30 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -2734,61 +2734,61 @@ def test_is_batchable(): ) -# @pytest.mark.parametrize( -# "python_val, python_type, expected_list_length", -# [ -# # Case 1: List of FlytePickle objects with default batch size. -# # (By default, the batch_size is set to the length of the whole list.) -# # After converting to literal, the result will be [batched_FlytePickle(5 items)]. -# # Therefore, the expected list length is [1]. -# ([{"foo"}] * 5, typing.List[FlytePickle], [1]), -# # Case 2: List of FlytePickle objects with batch size 2. -# # After converting to literal, the result will be -# # [batched_FlytePickle(2 items), batched_FlytePickle(2 items), batched_FlytePickle(1 item)]. -# # Therefore, the expected list length is [3]. -# ( -# ["foo"] * 5, -# Annotated[typing.List[FlytePickle], HashMethod(function=str), BatchSize(2)], -# [3], -# ), -# # Case 3: Nested list of FlytePickle objects with batch size 2. -# # After converting to literal, the result will be -# # [[batched_FlytePickle(3 items)], [batched_FlytePickle(3 items)]] -# # Therefore, the expected list length is [2, 1] (the length of the outer list remains the same, the inner list is batched). -# ( -# [["foo", "foo", "foo"]] * 2, -# typing.List[Annotated[typing.List[FlytePickle], BatchSize(3)]], -# [2, 1], -# ), -# # Case 4: Empty list -# ([[], typing.List[FlytePickle], []]), -# ], -# ) -# def test_batch_pickle_list(python_val, python_type, expected_list_length): -# ctx = FlyteContext.current_context() -# expected = TypeEngine.to_literal_type(python_type) -# lv = TypeEngine.to_literal(ctx, python_val, python_type, expected) -# -# tmp_lv = lv -# for length in expected_list_length: -# # Check that after converting to literal, the length of the literal list is equal to: -# # - the length of the original list divided by the batch size if not nested -# # - the length of the original list if it contains a nested list -# assert len(tmp_lv.collection.literals) == length -# tmp_lv = tmp_lv.collection.literals[0] -# -# pv = TypeEngine.to_python_value(ctx, lv, python_type) -# # Check that after converting literal to Python value, the result is equal to the original python values. -# assert pv == python_val -# if get_origin(python_type) is Annotated: -# pv = TypeEngine.to_python_value(ctx, lv, get_args(python_type)[0]) -# # Remove the annotation and check that after converting to Python value, the result is equal -# # to the original input values. This is used to simulate the following case: -# # @workflow -# # def wf(): -# # data = task0() # task0() -> Annotated[typing.List[FlytePickle], BatchSize(2)] -# # task1(data=data) # task1(data: typing.List[FlytePickle]) -# assert pv == python_val +@pytest.mark.parametrize( + "python_val, python_type, expected_list_length", + [ + # Case 1: List of FlytePickle objects with default batch size. + # (By default, the batch_size is set to the length of the whole list.) + # After converting to literal, the result will be [batched_FlytePickle(5 items)]. + # Therefore, the expected list length is [1]. + ([{"foo"}] * 5, typing.List[FlytePickle], [1]), + # Case 2: List of FlytePickle objects with batch size 2. + # After converting to literal, the result will be + # [batched_FlytePickle(2 items), batched_FlytePickle(2 items), batched_FlytePickle(1 item)]. + # Therefore, the expected list length is [3]. + ( + ["foo"] * 5, + Annotated[typing.List[FlytePickle], HashMethod(function=str), BatchSize(2)], + [3], + ), + # Case 3: Nested list of FlytePickle objects with batch size 2. + # After converting to literal, the result will be + # [[batched_FlytePickle(3 items)], [batched_FlytePickle(3 items)]] + # Therefore, the expected list length is [2, 1] (the length of the outer list remains the same, the inner list is batched). + ( + [["foo", "foo", "foo"]] * 2, + typing.List[Annotated[typing.List[FlytePickle], BatchSize(3)]], + [2, 1], + ), + # Case 4: Empty list + ([[], typing.List[FlytePickle], []]), + ], +) +def test_batch_pickle_list(python_val, python_type, expected_list_length): + ctx = FlyteContext.current_context() + expected = TypeEngine.to_literal_type(python_type) + lv = TypeEngine.to_literal(ctx, python_val, python_type, expected) + + tmp_lv = lv + for length in expected_list_length: + # Check that after converting to literal, the length of the literal list is equal to: + # - the length of the original list divided by the batch size if not nested + # - the length of the original list if it contains a nested list + assert len(tmp_lv.collection.literals) == length + tmp_lv = tmp_lv.collection.literals[0] + + pv = TypeEngine.to_python_value(ctx, lv, python_type) + # Check that after converting literal to Python value, the result is equal to the original python values. + assert pv == python_val + if get_origin(python_type) is Annotated: + pv = TypeEngine.to_python_value(ctx, lv, get_args(python_type)[0]) + # Remove the annotation and check that after converting to Python value, the result is equal + # to the original input values. This is used to simulate the following case: + # @workflow + # def wf(): + # data = task0() # task0() -> Annotated[typing.List[FlytePickle], BatchSize(2)] + # task1(data=data) # task1(data: typing.List[FlytePickle]) + assert pv == python_val @pytest.mark.parametrize( From fb021e85a6bda18ea68dee9ff25dd8eae54c336c Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Mon, 7 Oct 2024 16:35:28 -0700 Subject: [PATCH 20/26] add a paramspec and change a couple things to run_sync Signed-off-by: Yee Hing Tong --- flytekit/core/base_task.py | 6 +++--- flytekit/core/promise.py | 6 +++--- flytekit/utils/asyn.py | 9 ++++++--- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index 861152b51f..c79d505c0b 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -86,7 +86,6 @@ from flytekit.models.documentation import Description, Documentation from flytekit.models.interface import Variable from flytekit.models.security import SecurityContext -from flytekit.utils.asyn import loop_manager DYNAMIC_PARTITIONS = "_uap" MODEL_CARD = "_ucm" @@ -797,8 +796,9 @@ def dispatch_execute( try: with timeit("dispatch execute"): - synced = loop_manager.synced(self._output_to_literal_map) - literals_map, native_outputs_as_map = synced(native_outputs, exec_ctx) + literals_map, native_outputs_as_map = run_sync( + self._output_to_literal_map, native_outputs, exec_ctx + ) self._write_decks(native_inputs, native_outputs_as_map, ctx, new_user_params) except (FlyteUploadDataException, FlyteDownloadDataException): raise diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index d12f8ae69b..bd99b7b600 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -45,7 +45,7 @@ from flytekit.models.literals import Binary, Literal, Primitive, Scalar from flytekit.models.task import Resources from flytekit.models.types import SimpleType -from flytekit.utils.asyn import loop_manager +from flytekit.utils.asyn import loop_manager, run_sync async def _translate_inputs_to_literals( @@ -913,8 +913,8 @@ def binding_from_python_std( if "no running event loop" not in str(e): logger.error(f"Unknown RuntimeError {str(e)}") raise - synced = loop_manager.synced(binding_data_from_python_std) - binding_data = synced( + binding_data = run_sync( + binding_data_from_python_std, ctx, expected_literal_type, t_value, diff --git a/flytekit/utils/asyn.py b/flytekit/utils/asyn.py index 0853c1670c..4524ca1bb7 100644 --- a/flytekit/utils/asyn.py +++ b/flytekit/utils/asyn.py @@ -11,17 +11,20 @@ async def async_add(a: int, b: int) -> int: import asyncio import atexit +import functools import os import threading from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager from contextvars import copy_context -from typing import Any, Awaitable, Callable, TypeVar +from typing import Any, Awaitable, Callable, ParamSpec, TypeVar from flytekit.loggers import logger T = TypeVar("T") +P = ParamSpec("P") + @contextmanager def _selector_policy(): @@ -89,13 +92,13 @@ def run_sync(self, coro_func: Callable[..., Awaitable[T]], *args, **kwargs) -> T self._runner_map[name] = _TaskRunner() return self._runner_map[name].run(coro) - def synced(self, coro_func: Callable[..., Awaitable[T]]) -> Callable[..., T]: + def synced(self, coro_func: Callable[P, Awaitable[T]]) -> Callable[P, T]: """Make loop run coroutine until it returns. Runs in other thread""" + @functools.wraps(coro_func) def wrapped(*args: Any, **kwargs: Any) -> T: return self.run_sync(coro_func, *args, **kwargs) - wrapped.__doc__ = coro_func.__doc__ return wrapped From f0476edc0653d007d676a03ca1351564eb93c97f Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Mon, 7 Oct 2024 16:57:05 -0700 Subject: [PATCH 21/26] run sync missing Signed-off-by: Yee Hing Tong --- flytekit/core/base_task.py | 1 + 1 file changed, 1 insertion(+) diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index c79d505c0b..b284a4c6e4 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -86,6 +86,7 @@ from flytekit.models.documentation import Description, Documentation from flytekit.models.interface import Variable from flytekit.models.security import SecurityContext +from flytekit.utils.asyn import run_sync DYNAMIC_PARTITIONS = "_uap" MODEL_CARD = "_ucm" From 3c3b5b4adce24eed3d5fb4a53970f70508688a18 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Tue, 8 Oct 2024 10:30:17 -0700 Subject: [PATCH 22/26] typing extensions Signed-off-by: Yee Hing Tong --- flytekit/utils/asyn.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/flytekit/utils/asyn.py b/flytekit/utils/asyn.py index 4524ca1bb7..aac46981d2 100644 --- a/flytekit/utils/asyn.py +++ b/flytekit/utils/asyn.py @@ -17,7 +17,9 @@ async def async_add(a: int, b: int) -> int: from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager from contextvars import copy_context -from typing import Any, Awaitable, Callable, ParamSpec, TypeVar +from typing import Any, Awaitable, Callable, TypeVar + +from typing_extensions import ParamSpec from flytekit.loggers import logger From 686e2c7a0badd7f0976b641267eb91a96ce94d50 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Tue, 8 Oct 2024 17:19:34 -0700 Subject: [PATCH 23/26] debugging (#2794) Signed-off-by: Yee Hing Tong --- flytekit/core/type_engine.py | 13 +++---------- flytekit/types/file/file.py | 13 +++++++++---- flytekit/utils/asyn.py | 12 ------------ 3 files changed, 12 insertions(+), 26 deletions(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 652c965330..94e824a30b 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -48,7 +48,7 @@ from flytekit.models.core import types as _core_types from flytekit.models.literals import Binary, Literal, LiteralCollection, LiteralMap, Primitive, Scalar, Union, Void from flytekit.models.types import LiteralType, SimpleType, TypeStructure, UnionType -from flytekit.utils.asyn import ContextExecutor, loop_manager +from flytekit.utils.asyn import loop_manager T = typing.TypeVar("T") DEFINITIONS = "definitions" @@ -1289,11 +1289,7 @@ async def async_to_literal( if isinstance(transformer, AsyncTypeTransformer): lv = await transformer.async_to_literal(ctx, python_val, python_type, expected) else: - # Testing just blocking call - loop = asyncio.get_running_loop() - executor = ContextExecutor() - fut = loop.run_in_executor(executor, transformer.to_literal, ctx, python_val, python_type, expected) - lv = await fut + lv = transformer.to_literal(ctx, python_val, python_type, expected) modify_literal_uris(lv) lv.hash = cls.calculate_hash(python_val, python_type) @@ -1337,10 +1333,7 @@ async def async_to_python_value(cls, ctx: FlyteContext, lv: Literal, expected_py if isinstance(transformer, AsyncTypeTransformer): pv = await transformer.async_to_python_value(ctx, lv, expected_python_type) else: - loop = asyncio.get_running_loop() - executor = ContextExecutor() - fut = loop.run_in_executor(executor, transformer.to_python_value, ctx, lv, expected_python_type) - pv = await fut + pv = transformer.to_python_value(ctx, lv, expected_python_type) return pv @classmethod diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py index bacde7089f..602f5bc12e 100644 --- a/flytekit/types/file/file.py +++ b/flytekit/types/file/file.py @@ -17,7 +17,12 @@ from flytekit.core.constants import MESSAGEPACK from flytekit.core.context_manager import FlyteContext, FlyteContextManager -from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError, get_underlying_type +from flytekit.core.type_engine import ( + AsyncTypeTransformer, + TypeEngine, + TypeTransformerFailedError, + get_underlying_type, +) from flytekit.exceptions.user import FlyteAssertion from flytekit.loggers import logger from flytekit.models.core import types as _core_types @@ -350,7 +355,7 @@ def __str__(self): return self.path -class FlyteFilePathTransformer(TypeTransformer[FlyteFile]): +class FlyteFilePathTransformer(AsyncTypeTransformer[FlyteFile]): def __init__(self): super().__init__(name="FlyteFilePath", t=FlyteFile) @@ -428,7 +433,7 @@ def validate_file_type( if real_type not in expected_type: raise ValueError(f"Incorrect file type, expected {expected_type}, got {real_type}") - def to_literal( + async def async_to_literal( self, ctx: FlyteContext, python_val: typing.Union[FlyteFile, os.PathLike, str], @@ -549,7 +554,7 @@ def from_binary_idl( else: raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`") - def to_python_value( + async def async_to_python_value( self, ctx: FlyteContext, lv: Literal, expected_python_type: typing.Union[typing.Type[FlyteFile], os.PathLike] ) -> FlyteFile: # Handle dataclass attribute access diff --git a/flytekit/utils/asyn.py b/flytekit/utils/asyn.py index aac46981d2..c447db052f 100644 --- a/flytekit/utils/asyn.py +++ b/flytekit/utils/asyn.py @@ -14,9 +14,7 @@ async def async_add(a: int, b: int) -> int: import functools import os import threading -from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager -from contextvars import copy_context from typing import Any, Awaitable, Callable, TypeVar from typing_extensions import ParamSpec @@ -106,13 +104,3 @@ def wrapped(*args: Any, **kwargs: Any) -> T: loop_manager = _AsyncLoopManager() run_sync = loop_manager.run_sync - - -class ContextExecutor(ThreadPoolExecutor): - def __init__(self): - self.context = copy_context() - super().__init__(initializer=self._set_child_context) - - def _set_child_context(self): - for var, value in self.context.items(): - var.set(value) From b952f8db6431d51dac8df95c42f6da3ac39f6611 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Tue, 8 Oct 2024 17:41:43 -0700 Subject: [PATCH 24/26] sort the file list Signed-off-by: Yee Hing Tong --- flytekit/tools/script_mode.py | 1 + tests/flytekit/unit/cli/pyflyte/test_script_mode.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/flytekit/tools/script_mode.py b/flytekit/tools/script_mode.py index 7188b5b90d..fa8634361d 100644 --- a/flytekit/tools/script_mode.py +++ b/flytekit/tools/script_mode.py @@ -118,6 +118,7 @@ def ls_files( else: all_files = list_all_files(source_path, deref_symlinks, ignore_group) + all_files.sort() hasher = hashlib.md5() for abspath in all_files: relpath = os.path.relpath(abspath, source_path) diff --git a/tests/flytekit/unit/cli/pyflyte/test_script_mode.py b/tests/flytekit/unit/cli/pyflyte/test_script_mode.py index 74d8aeab73..b063091075 100644 --- a/tests/flytekit/unit/cli/pyflyte/test_script_mode.py +++ b/tests/flytekit/unit/cli/pyflyte/test_script_mode.py @@ -6,6 +6,7 @@ from flytekit.tools.script_mode import ls_files from flytekit.constants import CopyFileDetection + # a pytest fixture that creates a tmp directory and creates # a small file structure in it @pytest.fixture @@ -39,7 +40,7 @@ def test_list_dir(dummy_dir_structure): files, d = ls_files(str(dummy_dir_structure), CopyFileDetection.ALL) assert len(files) == 5 if os.name != "nt": - assert d == "c092f1b85f7c6b2a71881a946c00a855" + assert d == "b6907fd823a45e26c780a4ba62111243" def test_list_filtered_on_modules(dummy_dir_structure): From dce9bf6214dc71f86cd29330a37cf9586142bb74 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Tue, 8 Oct 2024 19:01:35 -0700 Subject: [PATCH 25/26] remove unneeded exception Signed-off-by: Yee Hing Tong --- flytekit/core/promise.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index bd99b7b600..15d4881ac6 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -897,7 +897,6 @@ async def binding_data_from_python_std( return _literals_models.BindingData(scalar=lit.scalar) -# This function cannot be called from an async call stack def binding_from_python_std( ctx: _flyte_context.FlyteContext, var_name: str, @@ -906,13 +905,6 @@ def binding_from_python_std( t_value_type: type, ) -> Tuple[_literals_models.Binding, List[Node]]: nodes: List[Node] = [] - try: - asyncio.get_running_loop() - raise AssertionError("binding_from_python_std cannot be run from within an async call stack") - except RuntimeError as e: - if "no running event loop" not in str(e): - logger.error(f"Unknown RuntimeError {str(e)}") - raise binding_data = run_sync( binding_data_from_python_std, ctx, From 592c64b04dfcb5f378f696225cf97751a39d44a5 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Wed, 9 Oct 2024 09:22:32 -0700 Subject: [PATCH 26/26] lint Signed-off-by: Yee Hing Tong --- flytekit/core/promise.py | 1 - 1 file changed, 1 deletion(-) diff --git a/flytekit/core/promise.py b/flytekit/core/promise.py index 15d4881ac6..afd51cd7cc 100644 --- a/flytekit/core/promise.py +++ b/flytekit/core/promise.py @@ -1,6 +1,5 @@ from __future__ import annotations -import asyncio import collections import datetime import inspect