Skip to content

Commit

Permalink
SDK - Components - Fixed handling collection return values (kubeflow#…
Browse files Browse the repository at this point in the history
…3263)

* SDK - Components - Fixed handling collection return values

Fixes kubeflow#3262

* Fixed the tests
  • Loading branch information
Ark-kun authored and Jeffwan committed Dec 9, 2020
1 parent 24a933c commit e28b239
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 32 deletions.
15 changes: 13 additions & 2 deletions sdk/python/kfp/components/_python_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,17 @@ def get_serializer_and_register_definitions(type_name) -> str:
'_output_files = _parsed_args.pop("_output_paths", [])',
])

# Putting singular return values in a list to be "zipped" with the serializers and output paths
outputs_to_list_code = ''
return_ann = inspect.signature(func).return_annotation
if ( # The return type is singular, not sequence
return_ann is not None
and return_ann != inspect.Parameter.empty
and not isinstance(return_ann, dict)
and not hasattr(return_ann, '_fields') # namedtuple
):
outputs_to_list_code = '_outputs = [_outputs]'

output_serialization_code = ''.join(' {},\n'.format(s) for s in output_serialization_expression_strings)

full_source = \
Expand All @@ -589,8 +600,7 @@ def get_serializer_and_register_definitions(type_name) -> str:
_outputs = {func_name}(**_parsed_args)
if not hasattr(_outputs, '__getitem__') or isinstance(_outputs, str):
_outputs = [_outputs]
{outputs_to_list_code}
_output_serializers = [
{output_serialization_code}
Expand All @@ -611,6 +621,7 @@ def get_serializer_and_register_definitions(type_name) -> str:
extra_code=extra_code,
arg_parse_code='\n'.join(arg_parse_code_lines),
output_serialization_code=output_serialization_code,
outputs_to_list_code=outputs_to_list_code,
)

#Removing consecutive blank lines
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,6 @@ spec:
_outputs = consume(**_parsed_args)
if not hasattr(_outputs, '__getitem__') or isinstance(_outputs, str):
_outputs = [_outputs]
_output_serializers = [
]
Expand Down Expand Up @@ -76,9 +73,6 @@ spec:
_outputs = consume(**_parsed_args)
if not hasattr(_outputs, '__getitem__') or isinstance(_outputs, str):
_outputs = [_outputs]
_output_serializers = [
]
Expand Down Expand Up @@ -121,9 +115,6 @@ spec:
_outputs = consume(**_parsed_args)
if not hasattr(_outputs, '__getitem__') or isinstance(_outputs, str):
_outputs = [_outputs]
_output_serializers = [
]
Expand Down Expand Up @@ -166,9 +157,6 @@ spec:
_outputs = consume(**_parsed_args)
if not hasattr(_outputs, '__getitem__') or isinstance(_outputs, str):
_outputs = [_outputs]
_output_serializers = [
]
Expand Down Expand Up @@ -211,9 +199,6 @@ spec:
_outputs = consume(**_parsed_args)
if not hasattr(_outputs, '__getitem__') or isinstance(_outputs, str):
_outputs = [_outputs]
_output_serializers = [
]
Expand Down Expand Up @@ -256,9 +241,6 @@ spec:
_outputs = consume(**_parsed_args)
if not hasattr(_outputs, '__getitem__') or isinstance(_outputs, str):
_outputs = [_outputs]
_output_serializers = [
]
Expand Down Expand Up @@ -301,9 +283,6 @@ spec:
_outputs = consume(**_parsed_args)
if not hasattr(_outputs, '__getitem__') or isinstance(_outputs, str):
_outputs = [_outputs]
_output_serializers = [
]
Expand Down Expand Up @@ -507,8 +486,7 @@ spec:
_outputs = produce_list_of_dicts(**_parsed_args)
if not hasattr(_outputs, '__getitem__') or isinstance(_outputs, str):
_outputs = [_outputs]
_outputs = [_outputs]
_output_serializers = [
_serialize_json,
Expand Down Expand Up @@ -570,8 +548,7 @@ spec:
_outputs = produce_list_of_ints(**_parsed_args)
if not hasattr(_outputs, '__getitem__') or isinstance(_outputs, str):
_outputs = [_outputs]
_outputs = [_outputs]
_output_serializers = [
_serialize_json,
Expand Down Expand Up @@ -633,8 +610,7 @@ spec:
_outputs = produce_list_of_strings(**_parsed_args)
if not hasattr(_outputs, '__getitem__') or isinstance(_outputs, str):
_outputs = [_outputs]
_outputs = [_outputs]
_output_serializers = [
_serialize_json,
Expand Down Expand Up @@ -690,8 +666,7 @@ spec:
_outputs = produce_str(**_parsed_args)
if not hasattr(_outputs, '__getitem__') or isinstance(_outputs, str):
_outputs = [_outputs]
_outputs = [_outputs]
_output_serializers = [
_serialize_str,
Expand Down
2 changes: 1 addition & 1 deletion sdk/python/tests/components/test_python_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ def assert_values_are_same(

def test_handling_list_dict_output_values(self):
def produce_list() -> list:
return (["string", 1, 2.2, True, False, None, [3, 4], {'s': 5}], )
return ["string", 1, 2.2, True, False, None, [3, 4], {'s': 5}]

# ! JSON map keys are always strings. Python converts all keys to strings without warnings
task_factory = comp.func_to_container_op(produce_list)
Expand Down

0 comments on commit e28b239

Please sign in to comment.