
[ML][Pipelines] dsl.pipeline support pass non_pipeline_parameters #26920

Merged · 12 commits · Oct 24, 2022
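For orientation before the diff: this PR lets a `dsl.pipeline` function take parameters that only steer how the pipeline graph is built and never become pipeline inputs. A minimal usage sketch, adapted from the unit tests added below (`helloworld_component.yml` and the parameter names come from those test fixtures):

```python
from azure.ai.ml import Input, dsl, load_component

component_func = load_component(source="helloworld_component.yml")

# "is_add_component" is consumed while the graph is built; it is excluded from
# the pipeline's inputs, so it can hold values (bools, callables, ...) that are
# not valid pipeline input types.
@dsl.pipeline(non_pipeline_inputs=["is_add_component"])
def pipeline_func(job_in_number, job_in_path, is_add_component):
    component_func(component_in_number=job_in_number, component_in_path=job_in_path)
    if is_add_component:
        component_func(component_in_number=job_in_number, component_in_path=job_in_path)

pipeline = pipeline_func(10, Input(path="/a/path/on/ds"), True)
assert "is_add_component" not in pipeline.inputs
assert len(pipeline.jobs) == 2
```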
sdk/ml/azure-ai-ml/azure/ai/ml/dsl/_pipeline_component_builder.py
@@ -145,6 +145,7 @@ def __init__(
default_datastore=None,
tags=None,
source_path=None,
non_pipeline_inputs=None
):
self.func = func
name = name if name else func.__name__
@@ -155,6 +156,7 @@ def __init__(
name = func.__name__
# List of nodes, in the order they were created in the pipeline.
self.nodes = []
self.non_pipeline_parameter_names = non_pipeline_inputs or []
# A dict of inputs name to InputDefinition.
# TODO: infer pipeline component input meta from assignment
self.inputs = self._build_inputs(func)
@@ -181,10 +183,10 @@ def add_node(self, node: Union[BaseNode, AutoMLJob]):
"""
self.nodes.append(node)

def build(self) -> PipelineComponent:
def build(self, non_pipeline_params_dict=None) -> PipelineComponent:
# Clear nodes as we may call build multiple times.
self.nodes = []
kwargs = _build_pipeline_parameter(self.func, self._get_group_parameter_defaults())
kwargs = _build_pipeline_parameter(self.func, self._get_group_parameter_defaults(), non_pipeline_params_dict)
# We use this stack to store the dsl pipeline definition hierarchy
_definition_builder_stack.push(self)

@@ -218,7 +220,7 @@ def build(self) -> PipelineComponent:
return pipeline_component

def _build_inputs(self, func):
inputs = _get_param_with_standard_annotation(func, is_func=True)
inputs = _get_param_with_standard_annotation(func, is_func=True, skip_params=self.non_pipeline_parameter_names)
for k, v in inputs.items():
# add arg description
if k in self._args_description:
@@ -379,7 +381,7 @@ def _get_name_or_component_name(node: Union[BaseNode, AutoMLJob]):
return result


def _build_pipeline_parameter(func, kwargs=None):
def _build_pipeline_parameter(func, kwargs=None, non_pipeline_parameter_dict=None):
# Pass group defaults into kwargs so group.item can be used even if the function defines no default.
# example:
# @parameter_group
@@ -391,9 +393,14 @@ def _build_pipeline_parameter(func, kwargs=None):
# component_func(input=param.key) <--- param.key should be val.

# transform kwargs
transformed_kwargs = {}
transformed_kwargs = non_pipeline_parameter_dict or {}
if kwargs:
transformed_kwargs.update({key: _wrap_pipeline_parameter(key, value) for key, value in kwargs.items()})
transformed_kwargs.update(
    {
        # Check membership against transformed_kwargs (rather than the
        # possibly-None non_pipeline_parameter_dict) so this stays safe when
        # no non-pipeline parameters are passed.
        key: _wrap_pipeline_parameter(key, value)
        for key, value in kwargs.items()
        if key not in transformed_kwargs
    }
)

def all_params(parameters):
for value in parameters.values():
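To make the kwargs-splitting above concrete, here is a self-contained sketch of the behavior under simplified assumptions (the tuple below stands in for the real `_wrap_pipeline_parameter`, which wraps values as `PipelineInput` objects):

```python
def build_call_kwargs(parameter_defaults, non_pipeline_values=None):
    """Mimic _build_pipeline_parameter: non-pipeline values pass through
    unchanged, every remaining parameter is wrapped as a pipeline input."""
    call_kwargs = dict(non_pipeline_values or {})
    # The comprehension is fully evaluated before update() runs, so membership
    # is checked against the non-pipeline values only.
    call_kwargs.update(
        {
            key: ("wrapped", key, value)  # stand-in for _wrap_pipeline_parameter
            for key, value in (parameter_defaults or {}).items()
            if key not in call_kwargs
        }
    )
    return call_kwargs

print(build_call_kwargs({"lr": 0.01, "debug": None}, {"debug": True}))
# -> {'debug': True, 'lr': ('wrapped', 'lr', 0.01)}
```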
24 changes: 19 additions & 5 deletions sdk/ml/azure-ai-ml/azure/ai/ml/dsl/_pipeline_decorator.py
@@ -10,7 +10,7 @@
from functools import wraps
from inspect import Parameter, signature
from pathlib import Path
from typing import Any, Callable, Dict, TypeVar
from typing import Any, Callable, Dict, TypeVar, List

from azure.ai.ml.entities import Data, PipelineJob, PipelineJobSettings
from azure.ai.ml.entities._builders.pipeline import Pipeline
@@ -24,6 +24,8 @@
UnexpectedKeywordError,
UnsupportedParameterKindError,
UserErrorException,
NonExistParamValueError,
UnExpectedNonPipelineParameterTypeError,
)

from ._pipeline_component_builder import PipelineComponentBuilder, _is_inside_dsl_pipeline_func
@@ -112,6 +114,7 @@ def pipeline_decorator(func: _TFunc) -> _TFunc:
if not isinstance(func, Callable): # pylint: disable=isinstance-second-argument-not-valid-type
raise UserErrorException(f"Dsl pipeline decorator accept only function type, got {type(func)}.")

non_pipeline_inputs = kwargs.get("non_pipeline_inputs", []) or kwargs.get("non_pipeline_parameters", [])
# compute variable name changed from default_compute_target -> compute -> default_compute -> none;
# to support legacy usage, they are resolved in priority order.
compute = kwargs.get("compute", None)
@@ -149,6 +152,7 @@ def pipeline_decorator(func: _TFunc) -> _TFunc:
default_datastore=default_datastore,
tags=tags,
source_path=str(func_entry_path),
non_pipeline_inputs=non_pipeline_inputs,
)

@wraps(func)
@@ -159,12 +163,13 @@ def wrapper(*args, **kwargs) -> PipelineJob:
# Because we only want to enable dsl settings on top level pipeline
_dsl_settings_stack.push() # use this stack to track on_init/on_finalize settings
try:
provided_positional_args = _validate_args(func, args, kwargs)
provided_positional_args = _validate_args(func, args, kwargs, non_pipeline_inputs)
# Convert args to kwargs
kwargs.update(provided_positional_args)
non_pipeline_params_dict = {k: v for k, v in kwargs.items() if k in non_pipeline_inputs}

# TODO: cache built pipeline component
pipeline_component = pipeline_builder.build()
pipeline_component = pipeline_builder.build(non_pipeline_params_dict=non_pipeline_params_dict)
finally:
# use `finally` to ensure pop operation from the stack
dsl_settings = _dsl_settings_stack.pop()
@@ -215,15 +220,23 @@ def wrapper(*args, **kwargs) -> PipelineJob:
return pipeline_decorator


def _validate_args(func, args, kwargs):
def _validate_args(func, args, kwargs, non_pipeline_inputs):
"""Validate customer function args and convert them to kwargs."""
if not isinstance(non_pipeline_inputs, List) or \
any(not isinstance(param, str) for param in non_pipeline_inputs):
raise UnExpectedNonPipelineParameterTypeError()
# Validate positional arguments
all_parameters = [param for _, param in signature(func).parameters.items()]
# Implicit parameters are *args and **kwargs
if any(param.kind in {param.VAR_KEYWORD, param.VAR_POSITIONAL} for param in all_parameters):
raise UnsupportedParameterKindError(func.__name__)

all_parameter_keys = [param.name for param in all_parameters]
non_pipeline_inputs = non_pipeline_inputs or []
unexpected_non_pipeline_inputs = [param for param in non_pipeline_inputs if param not in all_parameter_keys]
if unexpected_non_pipeline_inputs:
raise NonExistParamValueError(func.__name__, unexpected_non_pipeline_inputs)

empty_parameters = {param.name: param for param in all_parameters if param.default is Parameter.empty}
min_num = len(empty_parameters)
max_num = len(all_parameters)
@@ -250,7 +263,8 @@ def _is_supported_data_type(_data):

for pipeline_input_name in provided_args:
data = provided_args[pipeline_input_name]
if data is not None and not _is_supported_data_type(data):
if data is not None and not _is_supported_data_type(data) and \
pipeline_input_name not in non_pipeline_inputs:
msg = (
"Pipeline input expected an azure.ai.ml.Input or primitive types (str, bool, int or float), "
"but got type {}."
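The two checks added to `_validate_args` can also be read in isolation. A minimal sketch, with built-in exceptions standing in for `UnExpectedNonPipelineParameterTypeError` and `NonExistParamValueError`:

```python
from inspect import signature

def validate_non_pipeline_inputs(func, non_pipeline_inputs):
    # Check 1: non_pipeline_inputs must be a list of strings.
    if not isinstance(non_pipeline_inputs, list) or any(
        not isinstance(name, str) for name in non_pipeline_inputs
    ):
        raise TypeError("'non_pipeline_inputs' should be a list of strings")
    # Check 2: every listed name must be a real parameter of the function.
    parameter_names = set(signature(func).parameters)
    unexpected = [name for name in non_pipeline_inputs if name not in parameter_names]
    if unexpected:
        raise ValueError(
            "%s() got unexpected params in non_pipeline_inputs %r" % (func.__name__, unexpected)
        )

def sample_pipeline(job_in_number, other_params):
    ...

validate_non_pipeline_inputs(sample_pipeline, ["other_params"])  # passes
# validate_non_pipeline_inputs(sample_pipeline, ["missing"])     # ValueError
```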
@@ -59,7 +59,7 @@ def _get_annotation_cls_by_type(t: type, raise_error=False, optional=None):


# pylint: disable=too-many-statements
def _get_param_with_standard_annotation(cls_or_func, is_func=False):
def _get_param_with_standard_annotation(cls_or_func, is_func=False, skip_params=None):
"""Standardize function parameters or class fields with dsl.types
annotation."""
# TODO: we'd better remove this potential recursive import
@@ -207,6 +207,7 @@ def _split(_fields):
}
)

skip_params = skip_params or []
inherited_fields = _get_inherited_fields()
# From annotations get field with type
annotations = getattr(cls_or_func, "__annotations__", {})
@@ -215,10 +216,16 @@ def _split(_fields):
# Update fields using class fields, with defaults from the class dict or signature(func).parameters
if not is_func:
# Only consider public fields in class dict
defaults_dict = {key: val for key, val in cls_or_func.__dict__.items() if not key.startswith("_")}
defaults_dict = {
key: val for key, val in cls_or_func.__dict__.items()
if not key.startswith("_") and key not in skip_params
}
else:
# If it's a function, infer parameter types from the default values
defaults_dict = {key: val.default for key, val in signature(cls_or_func).parameters.items()}
defaults_dict = {
key: val.default for key, val in signature(cls_or_func).parameters.items()
if key not in skip_params
}
fields = _update_fields_with_default(annotation_fields, defaults_dict)
all_fields = _merge_and_reorder(inherited_fields, fields)
return all_fields
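The effect of `skip_params` in the function branch is easy to demonstrate standalone (a sketch; the real helper also merges annotation fields and inherited class fields):

```python
from inspect import signature

def build_defaults_dict(func, skip_params=None):
    skip_params = skip_params or []
    # Parameters named in skip_params never reach input-definition building.
    return {
        name: param.default
        for name, param in signature(func).parameters.items()
        if name not in skip_params
    }

def pipeline_func(job_in_number=1, other_params=None):
    ...

defaults = build_defaults_dict(pipeline_func, skip_params=["other_params"])
assert "other_params" not in defaults
assert "job_in_number" in defaults
```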
17 changes: 17 additions & 0 deletions sdk/ml/azure-ai-ml/azure/ai/ml/exceptions.py
@@ -571,6 +571,23 @@ def __init__(self, func_name, keyword):
super().__init__(message=message, no_personal_data_message=message)


class NonExistParamValueError(KeywordError):
"""Exception raised when items in non_pipeline_inputs not in keyword parameters in
dynamic functions."""

def __init__(self, func_name, keywords):
message = "%s() got unexpected params in non_pipeline_inputs %r." % (func_name, keywords)
super().__init__(message=message, no_personal_data_message=message)


class UnExpectedNonPipelineParameterTypeError(UserErrorException):
"""Exception raised when non_pipeline_parameter type is not List[str]."""

def __init__(self):
message = "Type of 'non_pipeline_parameter' in dsl.pipeline should be a list of string"
super().__init__(message=message, no_personal_data_message=message)


class UnsupportedOperationError(UserErrorException):
"""Exception raised when specified operation is not supported."""

@@ -577,9 +577,21 @@ def check_parameter_type(f):
error_category=ErrorCategory.USER_ERROR,
)

def check_non_pipeline_inputs(f):
"""Check whether non_pipeline_inputs exist in pipeline builder."""
if f._pipeline_builder.non_pipeline_parameter_names:
msg = "Cannot register pipeline component {!r} with non_pipeline_inputs."
raise ValidationException(
message=msg.format(f.__name__),
no_personal_data_message=msg.format(""),
target=ErrorTarget.COMPONENT,
error_category=ErrorCategory.USER_ERROR,
)

if hasattr(component_func, "_is_mldesigner_component") and component_func._is_mldesigner_component:
return component_func.component
if hasattr(component_func, "_is_dsl_func") and component_func._is_dsl_func:
check_non_pipeline_inputs(component_func)
check_parameter_type(component_func)
if component_func._job_settings:
module_logger.warning(
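The design choice here: non-pipeline inputs exist only while the graph is being built, so a registered, reusable pipeline component would have no way to represent them. Registration is rejected up front instead of silently dropping the parameters, as the e2e test below asserts.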
15 changes: 15 additions & 0 deletions sdk/ml/azure-ai-ml/tests/dsl/e2etests/test_dsl_pipeline.py
@@ -1061,6 +1061,21 @@ def pipeline_missing_type(
in e.value.message
)

@dsl.pipeline(non_pipeline_inputs=['param'])
def pipeline_with_non_pipeline_inputs(
required_input: Input,
required_param: str,
param: str,
):
default_optional_func(
required_input=required_input,
required_param=required_param,
)

with pytest.raises(ValidationException) as e:
client.components.create_or_update(pipeline_with_non_pipeline_inputs)
assert "Cannot register pipeline component 'pipeline_with_non_pipeline_inputs' with non_pipeline_inputs." in e.value.message

def test_create_pipeline_component_by_dsl(self, caplog, client: MLClient):
default_optional_func = load_component(source=str(components_dir / "default_optional_component.yml"))

68 changes: 67 additions & 1 deletion sdk/ml/azure-ai-ml/tests/dsl/unittests/test_dsl_pipeline.py
@@ -28,7 +28,7 @@
from azure.ai.ml.entities._builders import Command
from azure.ai.ml.entities._job.pipeline._io import PipelineInput
from azure.ai.ml.entities._job.pipeline._load_component import _generate_component_function
from azure.ai.ml.exceptions import UserErrorException, ValidationException
from azure.ai.ml.exceptions import UserErrorException, ValidationException, NonExistParamValueError, UnExpectedNonPipelineParameterTypeError

from .._util import _DSL_TIMEOUT_SECOND

Expand Down Expand Up @@ -1940,3 +1940,69 @@ def pipeline_func(component_in_path):
pipeline_job = pipeline_func(component_in_path=Data(name="test", version="1", type=AssetTypes.MLTABLE))
result = pipeline_job._validate()
assert result._to_dict() == {"result": "Succeeded"}

def test_pipeline_with_non_pipeline_inputs(self):
component_yaml = components_dir / "helloworld_component.yml"
component_func1 = load_component(source=component_yaml, params_override=[{"name": "component_name_1"}])
component_func2 = load_component(source=component_yaml, params_override=[{"name": "component_name_2"}])

@dsl.pipeline(non_pipeline_inputs=["other_params", "is_add_component"])
def pipeline_func(job_in_number, job_in_path, other_params, is_add_component):
component_func1(component_in_number=job_in_number, component_in_path=job_in_path)
component_func2(component_in_number=other_params, component_in_path=job_in_path)
if is_add_component:
component_func2(component_in_number=other_params, component_in_path=job_in_path)

pipeline = pipeline_func(10, Input(path="/a/path/on/ds"), 15, False)
assert len(pipeline.jobs) == 2
assert "other_params" not in pipeline.inputs
assert isinstance(pipeline.jobs[component_func1.name].inputs["component_in_number"]._data, PipelineInput)
assert pipeline.jobs[component_func2.name].inputs["component_in_number"]._data == 15

pipeline = pipeline_func(10, Input(path="/a/path/on/ds"), 15, True)
assert len(pipeline.jobs) == 3

@dsl.pipeline(non_pipeline_parameters=["other_params", "is_add_component"])
def pipeline_func(job_in_number, job_in_path, other_params, is_add_component):
component_func1(component_in_number=job_in_number, component_in_path=job_in_path)
component_func2(component_in_number=other_params, component_in_path=job_in_path)
if is_add_component:
component_func2(component_in_number=other_params, component_in_path=job_in_path)

pipeline = pipeline_func(10, Input(path="/a/path/on/ds"), 15, True)
assert len(pipeline.jobs) == 3

def test_pipeline_with_invalid_non_pipeline_inputs(self):

@dsl.pipeline(non_pipeline_inputs=[123])
def pipeline_func():
pass

with pytest.raises(UnExpectedNonPipelineParameterTypeError) as error_info:
pipeline_func()
assert "Type of 'non_pipeline_parameter' in dsl.pipeline should be a list of string" in str(error_info)

@dsl.pipeline(non_pipeline_inputs=["non_exist_param1", "non_exist_param2"])
def pipeline_func():
pass

with pytest.raises(NonExistParamValueError) as error_info:
pipeline_func()
assert "pipeline_func() got unexpected params in non_pipeline_inputs ['non_exist_param1', 'non_exist_param2']" in str(error_info)

def test_component_func_as_non_pipeline_inputs(self):
component_yaml = components_dir / "helloworld_component.yml"
component_func1 = load_component(source=component_yaml, params_override=[{"name": "component_name_1"}])
component_func2 = load_component(source=component_yaml, params_override=[{"name": "component_name_2"}])

@dsl.pipeline(non_pipeline_inputs=["component_func"])
def pipeline_func(job_in_number, job_in_path, component_func):
component_func1(component_in_number=job_in_number, component_in_path=job_in_path)
component_func(component_in_number=job_in_number, component_in_path=job_in_path)

pipeline = pipeline_func(
job_in_number=10,
job_in_path=Input(path="/a/path/on/ds"),
component_func=component_func2)
assert len(pipeline.jobs) == 2
assert component_func2.name in pipeline.jobs