diff --git a/sdk/python/kfp/components/_structures.py b/sdk/python/kfp/components/_structures.py index aba11fe0bd6..2afc2be5624 100644 --- a/sdk/python/kfp/components/_structures.py +++ b/sdk/python/kfp/components/_structures.py @@ -607,6 +607,8 @@ def _init_outputs(self): task_outputs[output.name] = task_output_arg self.outputs = task_outputs + if len(task_outputs) == 1: + self.output = list(task_outputs.values())[0] class GraphSpec(ModelBase): diff --git a/sdk/python/tests/components/test_components.py b/sdk/python/tests/components/test_components.py index 0e967037438..76f642d303b 100644 --- a/sdk/python/tests/components/test_components.py +++ b/sdk/python/tests/components/test_components.py @@ -14,6 +14,7 @@ import os import sys +import textwrap import unittest from contextlib import contextmanager from pathlib import Path @@ -597,6 +598,53 @@ def test_check_task_spec_outputs_dictionary(self): self.assertEqual(list(task.outputs.keys()), ['out 1', 'out 2']) + def test_check_task_object_no_output_attribute_when_0_outputs(self): + component_text = textwrap.dedent('''\ + implementation: + container: + image: busybox + command: [] + ''', + ) + + op = comp.load_component_from_text(component_text) + task = op() + + self.assertFalse(hasattr(task, 'output')) + + def test_check_task_object_has_output_attribute_when_1_output(self): + component_text = textwrap.dedent('''\ + outputs: + - {name: out 1} + implementation: + container: + image: busybox + command: [touch, {outputPath: out 1}] + ''', + ) + + op = comp.load_component_from_text(component_text) + task = op() + + self.assertEqual(task.output.task_output.output_name, 'out 1') + + def test_check_task_object_no_output_attribute_when_multiple_outputs(self): + component_text = textwrap.dedent('''\ + outputs: + - {name: out 1} + - {name: out 2} + implementation: + container: + image: busybox + command: [touch, {outputPath: out 1}, {outputPath: out 2}] + ''', + ) + + op = comp.load_component_from_text(component_text) + task = op() + + self.assertFalse(hasattr(task, 'output')) + def test_check_type_validation_of_task_spec_outputs(self): producer_component_text = '''\ outputs: