-
Notifications
You must be signed in to change notification settings - Fork 301
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Async type engine #2752
Async type engine #2752
Changes from all commits
8c7767b
c51a26e
fc83997
3fbb3d5
ae48297
5b35dc8
ce839e8
aa00c97
1cbd92e
8c1c8d9
d7a9fe0
f173719
c5a5a38
6a48da2
9c5e6f9
faedf6e
41bc372
27e16ee
3617300
5b3a6e0
f30f4d5
e163a44
efa2aa2
fb021e8
f0476ed
3c3b5b4
686e2c7
b952f8d
dce9bf6
592c64b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -44,9 +44,10 @@ | |
from flytekit.models.literals import Binary, Literal, Primitive, Scalar | ||
from flytekit.models.task import Resources | ||
from flytekit.models.types import SimpleType | ||
from flytekit.utils.asyn import loop_manager, run_sync | ||
|
||
|
||
def translate_inputs_to_literals( | ||
async def _translate_inputs_to_literals( | ||
ctx: FlyteContext, | ||
incoming_values: Dict[str, Any], | ||
flyte_interface_types: Dict[str, _interface_models.Variable], | ||
|
@@ -94,16 +95,19 @@ def my_wf(in1: int, in2: int) -> int: | |
t = native_types[k] | ||
try: | ||
if type(v) is Promise: | ||
v = resolve_attr_path_in_promise(v) | ||
result[k] = TypeEngine.to_literal(ctx, v, t, var.type) | ||
v = await resolve_attr_path_in_promise(v) | ||
result[k] = await TypeEngine.async_to_literal(ctx, v, t, var.type) | ||
except TypeTransformerFailedError as exc: | ||
exc.args = (f"Failed argument '{k}': {exc.args[0]}",) | ||
raise | ||
|
||
return result | ||
|
||
|
||
def resolve_attr_path_in_promise(p: Promise) -> Promise: | ||
translate_inputs_to_literals = loop_manager.synced(_translate_inputs_to_literals) | ||
|
||
|
||
async def resolve_attr_path_in_promise(p: Promise) -> Promise: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we be concerned with backward compatibility with changing a function from sync to async? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this particular function is pretty limited in terms of callees so hopefully that's okay. |
||
""" | ||
resolve_attr_path_in_promise resolves the attribute path in a promise and returns a new promise with the resolved value | ||
This is for local execution only. The remote execution will be resolved in flytepropeller. | ||
|
@@ -145,7 +149,9 @@ def resolve_attr_path_in_promise(p: Promise) -> Promise: | |
new_st = resolve_attr_path_in_pb_struct(st, attr_path=p.attr_path[used:]) | ||
literal_type = TypeEngine.to_literal_type(type(new_st)) | ||
# Reconstruct the resolved result to flyte literal (because the resolved result might not be struct) | ||
curr_val = TypeEngine.to_literal(FlyteContextManager.current_context(), new_st, type(new_st), literal_type) | ||
curr_val = await TypeEngine.async_to_literal( | ||
FlyteContextManager.current_context(), new_st, type(new_st), literal_type | ||
) | ||
elif type(curr_val.value.value) is Binary: | ||
binary_idl_obj = curr_val.value.value | ||
if binary_idl_obj.tag == _common_constants.MESSAGEPACK: | ||
|
@@ -786,7 +792,7 @@ def __rshift__(self, other: Any): | |
return Output(*promises) # type: ignore | ||
|
||
|
||
def binding_data_from_python_std( | ||
async def binding_data_from_python_std( | ||
ctx: _flyte_context.FlyteContext, | ||
expected_literal_type: _type_models.LiteralType, | ||
t_value: Any, | ||
|
@@ -821,7 +827,8 @@ def binding_data_from_python_std( | |
# If the value is not a container type, then we can directly convert it to a scalar in the Union case. | ||
# This pushes the handling of the Union types to the type engine. | ||
if not isinstance(t_value, list) and not isinstance(t_value, dict): | ||
scalar = TypeEngine.to_literal(ctx, t_value, t_value_type or type(t_value), expected_literal_type).scalar | ||
lit = await TypeEngine.async_to_literal(ctx, t_value, t_value_type or type(t_value), expected_literal_type) | ||
scalar = lit.scalar | ||
return _literals_models.BindingData(scalar=scalar) | ||
|
||
# If it is a container type, then we need to iterate over the variants in the Union type, try each one. This is | ||
|
@@ -831,7 +838,7 @@ def binding_data_from_python_std( | |
try: | ||
lt_type = expected_literal_type.union_type.variants[i] | ||
python_type = get_args(t_value_type)[i] if t_value_type else None | ||
return binding_data_from_python_std(ctx, lt_type, t_value, python_type, nodes) | ||
return await binding_data_from_python_std(ctx, lt_type, t_value, python_type, nodes) | ||
except Exception: | ||
logger.debug( | ||
f"failed to bind data {t_value} with literal type {expected_literal_type.union_type.variants[i]}." | ||
|
@@ -844,7 +851,9 @@ def binding_data_from_python_std( | |
sub_type: Optional[type] = ListTransformer.get_sub_type_or_none(t_value_type) | ||
collection = _literals_models.BindingDataCollection( | ||
bindings=[ | ||
binding_data_from_python_std(ctx, expected_literal_type.collection_type, t, sub_type or type(t), nodes) | ||
await binding_data_from_python_std( | ||
ctx, expected_literal_type.collection_type, t, sub_type or type(t), nodes | ||
) | ||
for t in t_value | ||
] | ||
) | ||
|
@@ -860,13 +869,13 @@ def binding_data_from_python_std( | |
f"this should be a Dictionary type and it is not: {type(t_value)} vs {expected_literal_type}" | ||
) | ||
if expected_literal_type.simple == _type_models.SimpleType.STRUCT: | ||
lit = TypeEngine.to_literal(ctx, t_value, type(t_value), expected_literal_type) | ||
lit = await TypeEngine.async_to_literal(ctx, t_value, type(t_value), expected_literal_type) | ||
return _literals_models.BindingData(scalar=lit.scalar) | ||
else: | ||
_, v_type = DictTransformer.extract_types_or_metadata(t_value_type) | ||
m = _literals_models.BindingDataMap( | ||
bindings={ | ||
k: binding_data_from_python_std( | ||
k: await binding_data_from_python_std( | ||
ctx, expected_literal_type.map_value_type, v, v_type or type(v), nodes | ||
) | ||
for k, v in t_value.items() | ||
|
@@ -883,8 +892,8 @@ def binding_data_from_python_std( | |
) | ||
|
||
# This is the scalar case - e.g. my_task(in1=5) | ||
scalar = TypeEngine.to_literal(ctx, t_value, t_value_type or type(t_value), expected_literal_type).scalar | ||
return _literals_models.BindingData(scalar=scalar) | ||
lit = await TypeEngine.async_to_literal(ctx, t_value, t_value_type or type(t_value), expected_literal_type) | ||
return _literals_models.BindingData(scalar=lit.scalar) | ||
|
||
|
||
def binding_from_python_std( | ||
|
@@ -895,7 +904,8 @@ def binding_from_python_std( | |
t_value_type: type, | ||
) -> Tuple[_literals_models.Binding, List[Node]]: | ||
nodes: List[Node] = [] | ||
binding_data = binding_data_from_python_std( | ||
binding_data = run_sync( | ||
binding_data_from_python_std, | ||
ctx, | ||
expected_literal_type, | ||
t_value, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can this be imported as
run_sync
and then use that? (I prefer to hide the loop manager as much as possible)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i guess so. i was hoping to one day figure out type hinting, which doesn't work today in either case, but i think will never work with the run sync way. I tried briefly getting functool wraps to work but alas to no avail.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ohh it was missing a paramspec. it works now. and changed promise.py and base_task.py to just use the run_sync.