Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Annotations o0 n0 #307

Merged
merged 2 commits into from
Dec 31, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flytekit/annotated/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def create_branch_node_promise_var(node_id: str, var: str) -> str:

When building bindings for the branch node, the inputs to the conditions (e.g. (x==5)) need to have variable names
(e.g. x). Because it's currently infeasible to get the name (e.g. x), we resolve to using the referenced node's
output name (e.g. out_0, my_out,... etc.). In order to avoid naming collisions (in cases when, for example, the
output name (e.g. o0, my_out,... etc.). In order to avoid naming collisions (in cases when, for example, the
conditions reference two outputs of two different nodes named the same), we build a variable name composed of the
referenced node name + '.' + the referenced output name. Ideally we use something like
(https://github.com/pwwang/python-varname) to retrieve the assigned variable name (e.g. x). However, because of
Expand Down
2 changes: 1 addition & 1 deletion flytekit/annotated/context_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ class CompilationState(object):
def __init__(self, prefix: str):
"""
:param prefix: This is because we may one day want to be able to have subworkflows inside other workflows. If
users choose to not specify their node names, then we can end up with multiple "node-0"s. This prefix allows
users choose to not specify their node names, then we can end up with multiple "n0"s. This prefix allows
us to give those nested nodes a distinct name, as well as properly identify them in the workflow.
# TODO: Ketan to revisit this whole concept when we re-organize the new structure
"""
Expand Down
2 changes: 1 addition & 1 deletion flytekit/annotated/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def transform_type(x: type, description: str = None) -> _interface_models.Variab


def default_output_name(index: int = 0) -> str:
return f"out_{index}"
return f"o{index}"


def output_name_generator(length: int) -> Generator[str, None, None]:
Expand Down
2 changes: 1 addition & 1 deletion flytekit/annotated/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def create_and_link_node(
# TODO: Clean up NodeOutput dependency on SdkNode, then rename variable
non_sdk_node = Node(
# TODO: Better naming, probably a derivative of the function name.
id=f"{ctx.compilation_state.prefix}node-{len(ctx.compilation_state.nodes)}",
id=f"{ctx.compilation_state.prefix}n{len(ctx.compilation_state.nodes)}",
metadata=node_metadata,
bindings=sorted(bindings, key=lambda b: b.var),
upstream_nodes=upstream_nodes, # type: ignore
Expand Down
2 changes: 1 addition & 1 deletion flytekit/annotated/node_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def sub_wf():
needs to be dereferenced by the output name.

t1_node = create_node(t1)
t2(t1_node.out_0)
t2(t1_node.o0)

"""
if len(args) > 0:
Expand Down
36 changes: 18 additions & 18 deletions tests/flytekit/unit/annotated/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,39 +28,39 @@ def t() -> List[int]:

return_type = extract_return_annotation(inspect.signature(t).return_annotation)
assert len(return_type) == 1
assert return_type["out_0"]._name == "List"
assert return_type["out_0"].__origin__ == list
assert return_type["o0"]._name == "List"
assert return_type["o0"].__origin__ == list

def t() -> Dict[str, int]:
...

return_type = extract_return_annotation(inspect.signature(t).return_annotation)
assert len(return_type) == 1
assert return_type["out_0"]._name == "Dict"
assert return_type["out_0"].__origin__ == dict
assert return_type["o0"]._name == "Dict"
assert return_type["o0"].__origin__ == dict

def t(a: int, b: str) -> typing.Tuple[int, str]:
...

return_type = extract_return_annotation(inspect.signature(t).return_annotation)
assert len(return_type) == 2
assert return_type["out_0"] == int
assert return_type["out_1"] == str
assert return_type["o0"] == int
assert return_type["o1"] == str

def t(a: int, b: str) -> (int, str):
...

return_type = extract_return_annotation(inspect.signature(t).return_annotation)
assert len(return_type) == 2
assert return_type["out_0"] == int
assert return_type["out_1"] == str
assert return_type["o0"] == int
assert return_type["o1"] == str

def t(a: int, b: str) -> str:
...

return_type = extract_return_annotation(inspect.signature(t).return_annotation)
assert len(return_type) == 1
assert return_type["out_0"] == str
assert return_type["o0"] == str

def t(a: int, b: str) -> None:
...
Expand All @@ -73,14 +73,14 @@ def t(a: int, b: str) -> List[int]:

return_type = extract_return_annotation(inspect.signature(t).return_annotation)
assert len(return_type) == 1
assert return_type["out_0"] == List[int]
assert return_type["o0"] == List[int]

def t(a: int, b: str) -> Dict[str, int]:
...

return_type = extract_return_annotation(inspect.signature(t).return_annotation)
assert len(return_type) == 1
assert return_type["out_0"] == Dict[str, int]
assert return_type["o0"] == Dict[str, int]


def test_named_tuples():
Expand All @@ -106,41 +106,41 @@ def z(a: int, b: str) -> typing.Tuple[int, str]:
return 5, "hello world"

result = transform_variable_map(extract_return_annotation(inspect.signature(z).return_annotation))
assert result["out_0"].type.simple == 1
assert result["out_1"].type.simple == 3
assert result["o0"].type.simple == 1
assert result["o1"].type.simple == 3


def test_regular_tuple():
def q(a: int, b: str) -> (int, str):
return 5, "hello world"

result = transform_variable_map(extract_return_annotation(inspect.signature(q).return_annotation))
assert result["out_0"].type.simple == 1
assert result["out_1"].type.simple == 3
assert result["o0"].type.simple == 1
assert result["o1"].type.simple == 3


def test_single_output_new_decorator():
def q(a: int, b: str) -> int:
return a + len(b)

result = transform_variable_map(extract_return_annotation(inspect.signature(q).return_annotation))
assert result["out_0"].type.simple == 1
assert result["o0"].type.simple == 1


def test_sig_files():
def q() -> os.PathLike:
...

result = transform_variable_map(extract_return_annotation(inspect.signature(q).return_annotation))
assert isinstance(result["out_0"].type.blob, _core_types.BlobType)
assert isinstance(result["o0"].type.blob, _core_types.BlobType)


def test_file_types():
def t1() -> FlyteFile["svg"]:
...

return_type = extract_return_annotation(inspect.signature(t1).return_annotation)
assert return_type["out_0"].extension() == FlyteFile["svg"].extension()
assert return_type["o0"].extension() == FlyteFile["svg"].extension()


def test_parameters_and_defaults():
Expand Down
12 changes: 6 additions & 6 deletions tests/flytekit/unit/annotated/test_node_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def t1(a: str) -> str:
@workflow
def my_wf(a: str) -> str:
t1_node = create_node(t1, a=a)
return t1_node.out_0
return t1_node.o0

r = my_wf(a="hello")
assert r == "hello world"
Expand Down Expand Up @@ -66,12 +66,12 @@ def empty_wf2():
)
):
sdk_wf = empty_wf.get_registerable_entity()
assert sdk_wf.nodes[0].upstream_node_ids[0] == "node-1"
assert sdk_wf.nodes[0].id == "node-0"
assert sdk_wf.nodes[0].upstream_node_ids[0] == "n1"
assert sdk_wf.nodes[0].id == "n0"

sdk_wf = empty_wf2.get_registerable_entity()
assert sdk_wf.nodes[0].upstream_node_ids[0] == "node-1"
assert sdk_wf.nodes[0].id == "node-0"
assert sdk_wf.nodes[0].upstream_node_ids[0] == "n1"
assert sdk_wf.nodes[0].id == "n0"


def test_more_normal_task():
Expand All @@ -96,7 +96,7 @@ def my_wf(a: int, b: str) -> (str, str):
t1_node = create_node(t1, a=a)
t1_nt_node = create_node(t1_nt, a=a)
t2_node = create_node(t2, a=[t1_node.t1_str_output, t1_nt_node.t1_str_output, b])
return t1_node.t1_str_output, t2_node.out_0
return t1_node.t1_str_output, t2_node.o0

x = my_wf(a=5, b="hello")
assert x == ("7", "7 7 hello")
4 changes: 2 additions & 2 deletions tests/flytekit/unit/annotated/test_references.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def ref_t1(a: typing.List[str]) -> str:
ss = ref_t1.get_registerable_entity()
assert ss.id == ref_t1.id
assert ss.interface.inputs["a"] is not None
assert ss.interface.outputs["out_0"] is not None
assert ss.interface.outputs["o0"] is not None

registration_settings = context_manager.RegistrationSettings(
project="proj",
Expand Down Expand Up @@ -181,7 +181,7 @@ def test_ref_plain_two_outputs():
assert xx.ref.sdk_node is yy.ref.sdk_node
assert xx.var == "x"
assert yy.var == "y"
assert xx.ref.node_id == "node-0"
assert xx.ref.node_id == "n0"
assert len(xx.ref.sdk_node.bindings) == 2

@task
Expand Down
2 changes: 1 addition & 1 deletion tests/flytekit/unit/annotated/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def my_wf(a: int, b: str) -> (int, str):
with ctx.current_context().new_registration_settings(registration_settings=registration_settings):
wf = my_wf.get_registerable_entity()
assert wf is not None
assert wf.nodes[1].inputs[0].var == "node-0.t1_int_output"
assert wf.nodes[1].inputs[0].var == "n0.t1_int_output"


def test_serialization_branch():
Expand Down
4 changes: 2 additions & 2 deletions tests/flytekit/unit/annotated/test_type_hints.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,11 +115,11 @@ def my_wf(a: int, b: str) -> (int, str):
return x, d

assert len(my_wf._nodes) == 2
assert my_wf._nodes[0].id == "node-0"
assert my_wf._nodes[0].id == "n0"
assert my_wf._nodes[1]._upstream_nodes[0] is my_wf._nodes[0]

assert len(my_wf._output_bindings) == 2
assert my_wf._output_bindings[0].var == "out_0"
assert my_wf._output_bindings[0].var == "o0"
assert my_wf._output_bindings[0].binding.promise.var == "t1_int_output"

nt = typing.NamedTuple("SingleNT", t1_int_output=float)
Expand Down
6 changes: 3 additions & 3 deletions tests/flytekit/unit/taskplugins/hive/test_hive_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ def my_wf(in_schema: FlyteSchema, ds: str) -> FlyteSchema:
assert len(sdk_task.interface.outputs) == 1

sdk_wf = my_wf.get_registerable_entity()
assert sdk_wf.interface.outputs["out_0"].type.schema is not None
assert sdk_wf.outputs[0].var == "out_0"
assert sdk_wf.outputs[0].binding.promise.node_id == "node-0"
assert sdk_wf.interface.outputs["o0"].type.schema is not None
assert sdk_wf.outputs[0].var == "o0"
assert sdk_wf.outputs[0].binding.promise.node_id == "n0"
assert sdk_wf.outputs[0].binding.promise.var == "results"


Expand Down
4 changes: 2 additions & 2 deletions tests/flytekit/unit/taskplugins/sagemaker/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def my_custom_trainer(x: int) -> int:
return x

assert my_custom_trainer.python_interface.inputs == {"x": int}
assert my_custom_trainer.python_interface.outputs == {"out_0": int}
assert my_custom_trainer.python_interface.outputs == {"o0": int}

assert my_custom_trainer(x=10) == 10

Expand Down Expand Up @@ -100,7 +100,7 @@ def my_custom_trainer(x: int) -> int:
return x

assert my_custom_trainer.python_interface.inputs == {"x": int}
assert my_custom_trainer.python_interface.outputs == {"out_0": int}
assert my_custom_trainer.python_interface.outputs == {"o0": int}

assert my_custom_trainer(x=10) == 10

Expand Down