Skip to content

Commit

Permalink
Support e2e run for if_else node in pipeline (#26780)
Browse files Browse the repository at this point in the history
* add e2e test

* update test

* fix test

* fix pipeline tests

* add comment

* fix test
  • Loading branch information
D-W- authored Oct 18, 2022
1 parent 7549cb1 commit 7011ee5
Show file tree
Hide file tree
Showing 12 changed files with 1,463 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ def PipelineJobsField():
],
}

# Note: the private preview node types are only available when the private preview flag is enabled before the
# pipeline job schema class is initialized.
if is_private_preview_enabled():
pipeline_enable_job_type[ControlFlowType.DO_WHILE] = [NestedField(DoWhileSchema, unknown=INCLUDE)]
pipeline_enable_job_type[ControlFlowType.IF_ELSE] = [NestedField(ConditionNodeSchema, unknown=INCLUDE)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ def _create_schema_for_validation(

return ConditionNodeSchema(context=context)

@classmethod
def _from_rest_object(cls, obj: dict) -> "ConditionNode":
    """Create an instance by expanding the REST dictionary into constructor keyword arguments.

    :param obj: Dictionary whose keys are passed to the constructor as keyword names.
    :return: The instance produced by calling ``cls`` with those keyword arguments.
    """
    node = cls(**obj)
    return node

def _to_dict(self) -> Dict:
return self._dump_for_validation()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,6 @@ def _get_validation_error_target(cls) -> ErrorTarget:
"""
return ErrorTarget.PIPELINE

@classmethod
def _from_rest_object(cls, obj: dict, reference_node_list: list) -> "ControlFlowNode":
    """Deserialize a REST node dictionary via the loader registered for its type.

    :param obj: REST dictionary of the node; its type field selects the loader.
    :param reference_node_list: Forwarded unchanged to the selected loader.
    :return: Whatever the selected loader returns for ``obj``.
    """
    # Kept as a function-scope import, exactly as the original does
    # (presumably to avoid a circular import -- TODO confirm).
    from azure.ai.ml.entities._job.pipeline._load_component import pipeline_node_factory

    type_key = obj.get(CommonYamlFields.TYPE, None)
    rest_loader = pipeline_node_factory.get_load_from_rest_object_func(_type=type_key)
    return rest_loader(obj, reference_node_list)


class LoopNode(ControlFlowNode, ABC):
"""
Expand Down Expand Up @@ -132,3 +124,11 @@ def _get_data_binding_expression_value(expression, regex):
@staticmethod
def _is_loop_node_dict(obj):
return obj.get(CommonYamlFields.TYPE, None) in [ControlFlowType.DO_WHILE]

@classmethod
def _from_rest_object(cls, obj: dict, reference_node_list: list) -> "ControlFlowNode":
    """Deserialize a REST node dictionary via the loader registered for its type.

    :param obj: REST dictionary of the node; its type field selects the loader.
    :param reference_node_list: Forwarded unchanged to the selected loader.
    :return: Whatever the selected loader returns for ``obj``.
    """
    # Kept as a function-scope import, exactly as the original does
    # (presumably to avoid a circular import -- TODO confirm).
    from azure.ai.ml.entities._job.pipeline._load_component import pipeline_node_factory

    type_key = obj.get(CommonYamlFields.TYPE, None)
    rest_loader = pipeline_node_factory.get_load_from_rest_object_func(_type=type_key)
    return rest_loader(obj, reference_node_list)
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from azure.ai.ml.dsl._component_func import to_component_func
from azure.ai.ml.dsl._overrides_definition import OverrideDefinition
from azure.ai.ml.entities._builders import BaseNode, Command, Import, Parallel, Spark, Sweep
from azure.ai.ml.entities._builders.condition_node import ConditionNode
from azure.ai.ml.entities._builders.do_while import DoWhile
from azure.ai.ml.entities._builders.pipeline import Pipeline
from azure.ai.ml.entities._component.component import Component
Expand Down Expand Up @@ -81,6 +82,12 @@ def __init__(self):
load_from_rest_object_func=DoWhile._from_rest_object,
nested_schema=None,
)
self.register_type(
_type=ControlFlowType.IF_ELSE,
create_instance_func=None,
load_from_rest_object_func=ConditionNode._from_rest_object,
nested_schema=None,
)

@classmethod
def _get_func(cls, _type: str, funcs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from azure.ai.ml.constants._component import ComponentSource
from azure.ai.ml.constants._job.pipeline import ValidationErrorCode
from azure.ai.ml.entities._builders import BaseNode
from azure.ai.ml.entities._builders.condition_node import ConditionNode
from azure.ai.ml.entities._builders.control_flow_node import LoopNode
from azure.ai.ml.entities._builders.import_node import Import
from azure.ai.ml.entities._builders.parallel import Parallel
Expand Down Expand Up @@ -266,11 +267,13 @@ def _customized_validate(self) -> MutableValidationResult:

def _validate_input(self):
validation_result = self._create_empty_validation_result()
# TODO(1979547): refine this logic: not all nodes have `_get_input_binding_dict` method
used_pipeline_inputs = set(
itertools.chain(
*[
self.component._get_input_binding_dict(node if not isinstance(node, LoopNode) else node.body)[0]
for node in self.jobs.values()
for node in self.jobs.values() if not isinstance(node, ConditionNode)
# condition node has no inputs
]
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

from .._utils._experimental import experimental
from .._utils.utils import is_data_binding_expression
from ..entities._builders.condition_node import ConditionNode
from ..entities._component.automl_component import AutoMLComponent
from ..entities._component.pipeline_component import PipelineComponent
from ._code_operations import CodeOperations
Expand Down Expand Up @@ -531,6 +532,8 @@ def resolve_base_node(name, node: BaseNode):
self._job_operations._resolve_arm_id_for_automl_job(job_instance, resolver, inside_pipeline=True)
elif isinstance(job_instance, BaseNode):
resolve_base_node(key, job_instance)
elif isinstance(job_instance, ConditionNode):
pass
else:
msg = f"Non supported job type in Pipeline: {type(job_instance)}"
raise ComponentException(
Expand Down
4 changes: 3 additions & 1 deletion sdk/ml/azure-ai-ml/azure/ai/ml/operations/_job_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@

from .._utils._experimental import experimental
from ..constants._component import ComponentSource
from ..entities._builders.condition_node import ConditionNode
from ..entities._job.pipeline._io import InputOutputBase, _GroupAttrDict, PipelineInput
from ._component_operations import ComponentOperations
from ._compute_operations import ComputeOperations
Expand Down Expand Up @@ -419,7 +420,8 @@ def _validate(

for node_name, node in job.jobs.items():
try:
if not isinstance(node, DoWhile):
# TODO(1979547): refactor, not all nodes have compute
if not isinstance(node, (DoWhile, ConditionNode)):
node.compute = self._try_get_compute_arm_id(node.compute)
except Exception as e: # pylint: disable=broad-except
validation_result.append_error(yaml_path=f"jobs.{node_name}.compute", message=str(e))
Expand Down
3 changes: 1 addition & 2 deletions sdk/ml/azure-ai-ml/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@
from azure.ai.ml import MLClient, load_component, load_job
from azure.ai.ml._restclient.registry_discovery import AzureMachineLearningWorkspaces as ServiceClientRegistryDiscovery
from azure.ai.ml._scope_dependent_operations import OperationConfig, OperationScope
from azure.ai.ml._utils._asset_utils import get_object_hash
from azure.ai.ml._utils.utils import hash_dict
from azure.ai.ml.constants._common import GitProperties
from azure.ai.ml.entities import AzureBlobDatastore, Component
from azure.ai.ml.entities._assets import Data, Model
from azure.ai.ml.entities._component.parallel_component import ParallelComponent
Expand Down Expand Up @@ -526,6 +524,7 @@ def credentialless_datastore(client: MLClient, storage_account_name: str) -> Azu
def enable_pipeline_private_preview_features(mocker: MockFixture):
    """Patch ``is_private_preview_enabled`` to return True in each module listed below.

    :param mocker: Object whose ``patch(target, return_value=...)`` is called once per target, in order.
    """
    patch_targets = (
        "azure.ai.ml.entities._job.pipeline.pipeline_job.is_private_preview_enabled",
        "azure.ai.ml.dsl._pipeline_component_builder.is_private_preview_enabled",
        "azure.ai.ml._schema.pipeline.pipeline_component.is_private_preview_enabled",
    )
    for target in patch_targets:
        mocker.patch(target, return_value=True)


@pytest.fixture()
Expand Down
143 changes: 143 additions & 0 deletions sdk/ml/azure-ai-ml/tests/dsl/e2etests/test_dynamic_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
import contextlib
import pytest

from azure.ai.ml._schema.pipeline import PipelineJobSchema
from .._util import _DSL_TIMEOUT_SECOND
from test_utilities.utils import _PYTEST_TIMEOUT_METHOD, omit_with_wildcard
from azure.ai.ml._schema.pipeline.pipeline_component import PipelineJobsField
from devtools_testutils import AzureRecordedTestCase

from azure.ai.ml import MLClient, load_component
from azure.ai.ml.dsl import pipeline
from azure.ai.ml.dsl._condition import condition


@contextlib.contextmanager
def include_private_preview_nodes_in_pipeline():
    """Temporarily replace the declared ``jobs`` field of ``PipelineJobSchema``.

    On entry the field is swapped for a freshly built ``PipelineJobsField()``; on exit
    the previously declared field is put back, even if the ``with`` body raised.
    """

    def _install_jobs_field(field):
        PipelineJobSchema._declared_fields["jobs"] = field

    previous_field = PipelineJobSchema._declared_fields["jobs"]
    _install_jobs_field(PipelineJobsField())

    try:
        yield
    finally:
        _install_jobs_field(previous_field)


@pytest.mark.usefixtures(
    "enable_environment_id_arm_expansion",
    "enable_pipeline_private_preview_features",
    "mock_code_hash",
    "mock_component_hash",
    "recorded_test",
)
@pytest.mark.timeout(timeout=_DSL_TIMEOUT_SECOND, method=_PYTEST_TIMEOUT_METHOD)
@pytest.mark.e2etest
class TestDynamicPipeline(AzureRecordedTestCase):
    """End-to-end tests for DSL pipelines that contain a ``condition`` (``if_else``) node."""

    def test_dsl_condition_pipeline(self, client: MLClient):
        """Submit a pipeline whose condition is bound to a component output and check its REST jobs.

        The pipeline holds one component producing the condition output, two hello-world
        nodes used as the true/false blocks, and the condition node itself.
        """
        # NOTE: the schema's jobs field is switched to include private preview nodes
        # only around the submission call further down.

        hello_world_component_no_paths = load_component(
            source="./tests/test_configs/components/helloworld_component_no_paths.yml"
        )
        basic_component = load_component(
            source="./tests/test_configs/components/component_with_conditional_output/spec.yaml"
        )

        @pipeline(
            name="test_mldesigner_component_with_conditional_output",
            compute="cpu-cluster",
        )
        def condition_pipeline():
            result = basic_component(str_param="abc", int_param=1)

            node1 = hello_world_component_no_paths(component_in_number=1)
            node2 = hello_world_component_no_paths(component_in_number=2)
            # node2 runs as the true block, node1 as the false block.
            condition(condition=result.outputs.output, false_block=node1, true_block=node2)

        pipeline_job = condition_pipeline()

        # include private preview nodes
        with include_private_preview_nodes_in_pipeline():
            pipeline_job = client.jobs.create_or_update(pipeline_job)

        # These fields are dropped from the REST dict before it is compared below.
        omit_fields = [
            "name",
            "properties.display_name",
            "properties.jobs.*.componentId",
            "properties.settings",
        ]
        dsl_pipeline_job_dict = omit_with_wildcard(pipeline_job._to_rest_object().as_dict(), *omit_fields)
        # The condition node is expected to reference the other nodes through
        # ``${{parent.jobs.<name>}}`` binding expressions.
        assert dsl_pipeline_job_dict["properties"]["jobs"] == {
            "conditionnode": {
                "condition": "${{parent.jobs.result.outputs.output}}",
                "false_block": "${{parent.jobs.node1}}",
                "true_block": "${{parent.jobs.node2}}",
                "type": "if_else",
            },
            "node1": {
                "_source": "REMOTE.WORKSPACE.COMPONENT",
                "computeId": None,
                "display_name": None,
                "distribution": None,
                "environment_variables": {},
                "inputs": {"component_in_number": {"job_input_type": "literal", "value": "1"}},
                "limits": None,
                "name": "node1",
                "outputs": {},
                "resources": None,
                "tags": {},
                "type": "command",
                "properties": {},
            },
            "node2": {
                "_source": "REMOTE.WORKSPACE.COMPONENT",
                "computeId": None,
                "display_name": None,
                "distribution": None,
                "environment_variables": {},
                "inputs": {"component_in_number": {"job_input_type": "literal", "value": "2"}},
                "limits": None,
                "name": "node2",
                "outputs": {},
                "resources": None,
                "tags": {},
                "type": "command",
                "properties": {},
            },
            "result": {
                "_source": "REMOTE.WORKSPACE.COMPONENT",
                "computeId": None,
                "display_name": None,
                "distribution": None,
                "environment_variables": {},
                "inputs": {
                    "int_param": {"job_input_type": "literal", "value": "1"},
                    "str_param": {"job_input_type": "literal", "value": "abc"},
                },
                "limits": None,
                "name": "result",
                "outputs": {},
                "resources": None,
                "tags": {},
                "type": "command",
                "properties": {},
            },
        }

    @pytest.mark.skip(reason="TODO(2027778): Verify after primitive condition is supported.")
    def test_dsl_condition_pipeline_with_primitive_input(self, client: MLClient):
        """Submit a pipeline whose condition is the literal ``True`` (currently skipped)."""
        hello_world_component_no_paths = load_component(
            source="./tests/test_configs/components/helloworld_component_no_paths.yml"
        )

        @pipeline(
            name="test_mldesigner_component_with_conditional_output",
            compute="cpu-cluster",
        )
        def condition_pipeline():
            node1 = hello_world_component_no_paths(component_in_number=1)
            node2 = hello_world_component_no_paths(component_in_number=2)
            # A primitive boolean is passed here instead of a node output binding.
            condition(condition=True, false_block=node1, true_block=node2)

        pipeline_job = condition_pipeline()
        # NOTE(review): unlike the test above, this submission is not wrapped in
        # include_private_preview_nodes_in_pipeline() -- confirm when un-skipping.
        client.jobs.create_or_update(pipeline_job)
Loading

0 comments on commit 7011ee5

Please sign in to comment.