Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support e2e run for if_else node in pipeline #26780

Merged
merged 7 commits into from
Oct 18, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ def PipelineJobsField():
],
}

# Note: the private node types only available when private preview flag opened before init of pipeline job
# schema class.
if is_private_preview_enabled():
pipeline_enable_job_type[ControlFlowType.DO_WHILE] = [NestedField(DoWhileSchema, unknown=INCLUDE)]
pipeline_enable_job_type[ControlFlowType.IF_ELSE] = [NestedField(ConditionNodeSchema, unknown=INCLUDE)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ def _create_schema_for_validation(

return ConditionNodeSchema(context=context)

@classmethod
def _from_rest_object(cls, obj: dict) -> "ConditionNode":
return cls(**obj)

def _to_dict(self) -> Dict:
return self._dump_for_validation()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,6 @@ def _get_validation_error_target(cls) -> ErrorTarget:
"""
return ErrorTarget.PIPELINE

@classmethod
def _from_rest_object(cls, obj: dict, reference_node_list: list) -> "ControlFlowNode":
from azure.ai.ml.entities._job.pipeline._load_component import pipeline_node_factory

node_type = obj.get(CommonYamlFields.TYPE, None)
load_from_rest_obj_func = pipeline_node_factory.get_load_from_rest_object_func(_type=node_type)
return load_from_rest_obj_func(obj, reference_node_list)


class LoopNode(ControlFlowNode, ABC):
"""
Expand Down Expand Up @@ -132,3 +124,11 @@ def _get_data_binding_expression_value(expression, regex):
@staticmethod
def _is_loop_node_dict(obj):
return obj.get(CommonYamlFields.TYPE, None) in [ControlFlowType.DO_WHILE]

@classmethod
def _from_rest_object(cls, obj: dict, reference_node_list: list) -> "ControlFlowNode":
from azure.ai.ml.entities._job.pipeline._load_component import pipeline_node_factory

node_type = obj.get(CommonYamlFields.TYPE, None)
load_from_rest_obj_func = pipeline_node_factory.get_load_from_rest_object_func(_type=node_type)
return load_from_rest_obj_func(obj, reference_node_list)
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from azure.ai.ml.dsl._component_func import to_component_func
from azure.ai.ml.dsl._overrides_definition import OverrideDefinition
from azure.ai.ml.entities._builders import BaseNode, Command, Import, Parallel, Spark, Sweep
from azure.ai.ml.entities._builders.condition_node import ConditionNode
from azure.ai.ml.entities._builders.do_while import DoWhile
from azure.ai.ml.entities._builders.pipeline import Pipeline
from azure.ai.ml.entities._component.component import Component
Expand Down Expand Up @@ -81,6 +82,12 @@ def __init__(self):
load_from_rest_object_func=DoWhile._from_rest_object,
nested_schema=None,
)
self.register_type(
_type=ControlFlowType.IF_ELSE,
create_instance_func=None,
load_from_rest_object_func=ConditionNode._from_rest_object,
nested_schema=None,
)

@classmethod
def _get_func(cls, _type: str, funcs):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from azure.ai.ml.constants._component import ComponentSource
from azure.ai.ml.constants._job.pipeline import ValidationErrorCode
from azure.ai.ml.entities._builders import BaseNode
from azure.ai.ml.entities._builders.condition_node import ConditionNode
from azure.ai.ml.entities._builders.control_flow_node import LoopNode
from azure.ai.ml.entities._builders.import_node import Import
from azure.ai.ml.entities._builders.parallel import Parallel
Expand Down Expand Up @@ -266,11 +267,13 @@ def _customized_validate(self) -> MutableValidationResult:

def _validate_input(self):
validation_result = self._create_empty_validation_result()
# TODO(1979547): refine this logic: not all nodes have `_get_input_binding_dict` method
used_pipeline_inputs = set(
itertools.chain(
*[
self.component._get_input_binding_dict(node if not isinstance(node, LoopNode) else node.body)[0]
for node in self.jobs.values()
for node in self.jobs.values() if not isinstance(node, ConditionNode)
# condition node has no inputs
]
)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

from .._utils._experimental import experimental
from .._utils.utils import is_data_binding_expression
from ..entities._builders.condition_node import ConditionNode
from ..entities._component.automl_component import AutoMLComponent
from ..entities._component.pipeline_component import PipelineComponent
from ._code_operations import CodeOperations
Expand Down Expand Up @@ -531,6 +532,8 @@ def resolve_base_node(name, node: BaseNode):
self._job_operations._resolve_arm_id_for_automl_job(job_instance, resolver, inside_pipeline=True)
elif isinstance(job_instance, BaseNode):
resolve_base_node(key, job_instance)
elif isinstance(job_instance, ConditionNode):
pass
else:
msg = f"Non supported job type in Pipeline: {type(job_instance)}"
raise ComponentException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@

from .._utils._experimental import experimental
from ..constants._component import ComponentSource
from ..entities._builders.condition_node import ConditionNode
from ..entities._job.pipeline._io import InputOutputBase, _GroupAttrDict, PipelineInput
from ._component_operations import ComponentOperations
from ._compute_operations import ComputeOperations
Expand Down Expand Up @@ -419,7 +420,7 @@ def _validate(

for node_name, node in job.jobs.items():
try:
if not isinstance(node, DoWhile):
if not isinstance(node, (DoWhile, ConditionNode)):
wangchao1230 marked this conversation as resolved.
Show resolved Hide resolved
node.compute = self._try_get_compute_arm_id(node.compute)
except Exception as e: # pylint: disable=broad-except
validation_result.append_error(yaml_path=f"jobs.{node_name}.compute", message=str(e))
Expand Down
3 changes: 1 addition & 2 deletions sdk/ml/azure-ai-ml/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@
from azure.ai.ml import MLClient, load_component, load_job
from azure.ai.ml._restclient.registry_discovery import AzureMachineLearningWorkspaces as ServiceClientRegistryDiscovery
from azure.ai.ml._scope_dependent_operations import OperationConfig, OperationScope
from azure.ai.ml._utils._asset_utils import get_object_hash
from azure.ai.ml._utils.utils import hash_dict
from azure.ai.ml.constants._common import GitProperties
from azure.ai.ml.entities import AzureBlobDatastore, Component
from azure.ai.ml.entities._assets import Data, Model
from azure.ai.ml.entities._component.parallel_component import ParallelComponent
Expand Down Expand Up @@ -526,6 +524,7 @@ def credentialless_datastore(client: MLClient, storage_account_name: str) -> Azu
def enable_pipeline_private_preview_features(mocker: MockFixture):
mocker.patch("azure.ai.ml.entities._job.pipeline.pipeline_job.is_private_preview_enabled", return_value=True)
mocker.patch("azure.ai.ml.dsl._pipeline_component_builder.is_private_preview_enabled", return_value=True)
mocker.patch("azure.ai.ml._schema.pipeline.pipeline_component.is_private_preview_enabled", return_value=True)


@pytest.fixture()
Expand Down
145 changes: 145 additions & 0 deletions sdk/ml/azure-ai-ml/tests/dsl/e2etests/test_dynamic_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import contextlib
import pytest

from azure.ai.ml._schema.pipeline import PipelineJobSchema
from azure.ai.ml.constants._common import AZUREML_PRIVATE_FEATURES_ENV_VAR
from azure.ai.ml.dsl._utils import environment_variable_overwrite
from .._util import _DSL_TIMEOUT_SECOND
from test_utilities.utils import _PYTEST_TIMEOUT_METHOD, omit_with_wildcard
from azure.ai.ml._schema.pipeline.pipeline_component import PipelineJobsField
from devtools_testutils import AzureRecordedTestCase

from azure.ai.ml import MLClient, load_component
from azure.ai.ml.dsl import pipeline
from azure.ai.ml.dsl._condition import condition


@contextlib.contextmanager
def include_private_preview_nodes_in_pipeline():
with environment_variable_overwrite(AZUREML_PRIVATE_FEATURES_ENV_VAR, "True"):
PipelineJobSchema._declared_fields["jobs"] = PipelineJobsField()

try:
yield
finally:
PipelineJobSchema._declared_fields["jobs"] = PipelineJobsField()


@pytest.mark.usefixtures(
"enable_environment_id_arm_expansion",
"enable_pipeline_private_preview_features",
"mock_code_hash",
"mock_component_hash",
"recorded_test",
)
@pytest.mark.timeout(timeout=_DSL_TIMEOUT_SECOND, method=_PYTEST_TIMEOUT_METHOD)
@pytest.mark.e2etest
class TestDynamicPipeline(AzureRecordedTestCase):
def test_dsl_condition_pipeline(self, client: MLClient):
# update jobs field to include private preview nodes

hello_world_component_no_paths = load_component(
source="./tests/test_configs/components/helloworld_component_no_paths.yml"
)
basic_component = load_component(
source="./tests/test_configs/components/component_with_conditional_output/spec.yaml"
)

@pipeline(
name="test_mldesigner_component_with_conditional_output",
compute="cpu-cluster",
wangchao1230 marked this conversation as resolved.
Show resolved Hide resolved
)
def condition_pipeline():
result = basic_component(str_param="abc", int_param=1)

node1 = hello_world_component_no_paths(component_in_number=1)
node2 = hello_world_component_no_paths(component_in_number=2)
condition(condition=result.outputs.output, false_block=node1, true_block=node2)

pipeline_job = condition_pipeline()

# include private preview nodes
with include_private_preview_nodes_in_pipeline():
pipeline_job = client.jobs.create_or_update(pipeline_job)

omit_fields = [
"name",
"properties.display_name",
"properties.jobs.*.componentId",
"properties.settings",
]
dsl_pipeline_job_dict = omit_with_wildcard(pipeline_job._to_rest_object().as_dict(), *omit_fields)
assert dsl_pipeline_job_dict["properties"]["jobs"] == {
"conditionnode": {
"condition": "${{parent.jobs.result.outputs.output}}",
"false_block": "${{parent.jobs.node1}}",
"true_block": "${{parent.jobs.node2}}",
"type": "if_else",
},
"node1": {
"_source": "REMOTE.WORKSPACE.COMPONENT",
"computeId": None,
"display_name": None,
"distribution": None,
"environment_variables": {},
"inputs": {"component_in_number": {"job_input_type": "literal", "value": "1"}},
"limits": None,
"name": "node1",
"outputs": {},
"resources": None,
"tags": {},
"type": "command",
"properties": {},
},
"node2": {
"_source": "REMOTE.WORKSPACE.COMPONENT",
"computeId": None,
"display_name": None,
"distribution": None,
"environment_variables": {},
"inputs": {"component_in_number": {"job_input_type": "literal", "value": "2"}},
"limits": None,
"name": "node2",
"outputs": {},
"resources": None,
"tags": {},
"type": "command",
"properties": {},
},
"result": {
"_source": "REMOTE.WORKSPACE.COMPONENT",
"computeId": None,
"display_name": None,
"distribution": None,
"environment_variables": {},
"inputs": {
"int_param": {"job_input_type": "literal", "value": "1"},
"str_param": {"job_input_type": "literal", "value": "abc"},
},
"limits": None,
"name": "result",
"outputs": {},
"resources": None,
"tags": {},
"type": "command",
"properties": {},
},
}

@pytest.mark.skip(reason="TODO(2027778): Verify after primitive condition is supported.")
def test_dsl_condition_pipeline_with_primitive_input(self, client: MLClient):
hello_world_component_no_paths = load_component(
source="./tests/test_configs/components/helloworld_component_no_paths.yml"
)

@pipeline(
name="test_mldesigner_component_with_conditional_output",
compute="cpu-cluster",
)
def condition_pipeline():
node1 = hello_world_component_no_paths(component_in_number=1)
node2 = hello_world_component_no_paths(component_in_number=2)
condition(condition=True, false_block=node1, true_block=node2)

pipeline_job = condition_pipeline()
client.jobs.create_or_update(pipeline_job)
Loading