diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/dsl/_pipeline_component_builder.py b/sdk/ml/azure-ai-ml/azure/ai/ml/dsl/_pipeline_component_builder.py
index ea9aafd2cffe..256dd63cd301 100644
--- a/sdk/ml/azure-ai-ml/azure/ai/ml/dsl/_pipeline_component_builder.py
+++ b/sdk/ml/azure-ai-ml/azure/ai/ml/dsl/_pipeline_component_builder.py
@@ -145,6 +145,7 @@ def __init__(
         default_datastore=None,
         tags=None,
         source_path=None,
+        non_pipeline_inputs=None
     ):
         self.func = func
         name = name if name else func.__name__
@@ -155,6 +156,7 @@
             name = func.__name__
         # List of nodes, order by it's creation order in pipeline.
         self.nodes = []
+        self.non_pipeline_parameter_names = non_pipeline_inputs or []
         # A dict of inputs name to InputDefinition.
         # TODO: infer pipeline component input meta from assignment
         self.inputs = self._build_inputs(func)
@@ -181,10 +183,10 @@ def add_node(self, node: Union[BaseNode, AutoMLJob]):
         """
         self.nodes.append(node)
 
-    def build(self) -> PipelineComponent:
+    def build(self, non_pipeline_params_dict=None) -> PipelineComponent:
         # Clear nodes as we may call build multiple times.
         self.nodes = []
-        kwargs = _build_pipeline_parameter(self.func, self._get_group_parameter_defaults())
+        kwargs = _build_pipeline_parameter(self.func, self._get_group_parameter_defaults(), non_pipeline_params_dict)
         # We use this stack to store the dsl pipeline definition hierarchy
         _definition_builder_stack.push(self)
 
@@ -218,7 +220,7 @@ def build(self) -> PipelineComponent:
         return pipeline_component
 
     def _build_inputs(self, func):
-        inputs = _get_param_with_standard_annotation(func, is_func=True)
+        inputs = _get_param_with_standard_annotation(func, is_func=True, skip_params=self.non_pipeline_parameter_names)
         for k, v in inputs.items():
             # add arg description
             if k in self._args_description:
@@ -379,7 +381,7 @@ def _get_name_or_component_name(node: Union[BaseNode, AutoMLJob]):
     return result
 
 
-def _build_pipeline_parameter(func, kwargs=None):
+def _build_pipeline_parameter(func, kwargs=None, non_pipeline_parameter_dict=None):
     # Pass group defaults into kwargs to support group.item can be used even if no default on function.
     # example:
     # @parameter_group
@@ -391,9 +393,14 @@
     # component_func(input=param.key) <--- param.key should be val.
 
     # transform kwargs
-    transformed_kwargs = {}
+    transformed_kwargs = non_pipeline_parameter_dict or {}
     if kwargs:
-        transformed_kwargs.update({key: _wrap_pipeline_parameter(key, value) for key, value in kwargs.items()})
+        transformed_kwargs.update(
+            {
+                key: _wrap_pipeline_parameter(key, value) for key, value in kwargs.items()
+                if key not in transformed_kwargs
+            }
+        )
 
     def all_params(parameters):
         for value in parameters.values():
diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/dsl/_pipeline_decorator.py b/sdk/ml/azure-ai-ml/azure/ai/ml/dsl/_pipeline_decorator.py
index d43292178706..6babfa111166 100644
--- a/sdk/ml/azure-ai-ml/azure/ai/ml/dsl/_pipeline_decorator.py
+++ b/sdk/ml/azure-ai-ml/azure/ai/ml/dsl/_pipeline_decorator.py
@@ -10,7 +10,7 @@
 from functools import wraps
 from inspect import Parameter, signature
 from pathlib import Path
-from typing import Any, Callable, Dict, TypeVar
+from typing import Any, Callable, Dict, List, TypeVar
 
 from azure.ai.ml.entities import Data, PipelineJob, PipelineJobSettings
 from azure.ai.ml.entities._builders.pipeline import Pipeline
@@ -24,6 +24,8 @@
     UnexpectedKeywordError,
     UnsupportedParameterKindError,
     UserErrorException,
+    NonExistParamValueError,
+    UnExpectedNonPipelineParameterTypeError,
 )
 
 from ._pipeline_component_builder import PipelineComponentBuilder, _is_inside_dsl_pipeline_func
@@ -112,6 +114,7 @@ def pipeline_decorator(func: _TFunc) -> _TFunc:
         if not isinstance(func, Callable):  # pylint: disable=isinstance-second-argument-not-valid-type
             raise UserErrorException(f"Dsl pipeline decorator accept only function type, got {type(func)}.")
 
+        non_pipeline_inputs = kwargs.get("non_pipeline_inputs", []) or kwargs.get("non_pipeline_parameters", [])
         # compute variable names changed from default_compute_targe -> compute -> default_compute -> none
         # to support legacy usage, we support them with priority.
         compute = kwargs.get("compute", None)
@@ -149,6 +152,7 @@ def pipeline_decorator(func: _TFunc) -> _TFunc:
             default_datastore=default_datastore,
             tags=tags,
             source_path=str(func_entry_path),
+            non_pipeline_inputs=non_pipeline_inputs,
         )
 
         @wraps(func)
@@ -159,12 +163,13 @@ def wrapper(*args, **kwargs) -> PipelineJob:
             # Because we only want to enable dsl settings on top level pipeline
             _dsl_settings_stack.push()  # use this stack to track on_init/on_finalize settings
             try:
-                provided_positional_args = _validate_args(func, args, kwargs)
+                provided_positional_args = _validate_args(func, args, kwargs, non_pipeline_inputs)
                 # Convert args to kwargs
                 kwargs.update(provided_positional_args)
+                non_pipeline_params_dict = {k: v for k, v in kwargs.items() if k in non_pipeline_inputs}
 
                 # TODO: cache built pipeline component
-                pipeline_component = pipeline_builder.build()
+                pipeline_component = pipeline_builder.build(non_pipeline_params_dict=non_pipeline_params_dict)
             finally:
                 # use `finally` to ensure pop operation from the stack
                 dsl_settings = _dsl_settings_stack.pop()
@@ -215,8 +220,11 @@ def wrapper(*args, **kwargs) -> PipelineJob:
     return pipeline_decorator
 
 
-def _validate_args(func, args, kwargs):
+def _validate_args(func, args, kwargs, non_pipeline_inputs):
     """Validate customer function args and convert them to kwargs."""
+    if not isinstance(non_pipeline_inputs, List) or \
+            any(not isinstance(param, str) for param in non_pipeline_inputs):
+        raise UnExpectedNonPipelineParameterTypeError()
     # Positional arguments validate
     all_parameters = [param for _, param in signature(func).parameters.items()]
     # Implicit parameter are *args and **kwargs
@@ -224,6 +232,11 @@
         raise UnsupportedParameterKindError(func.__name__)
 
     all_parameter_keys = [param.name for param in all_parameters]
+    non_pipeline_inputs = non_pipeline_inputs or []
+    unexpected_non_pipeline_inputs = [param for param in non_pipeline_inputs if param not in all_parameter_keys]
+    if unexpected_non_pipeline_inputs:
+        raise NonExistParamValueError(func.__name__, unexpected_non_pipeline_inputs)
+
     empty_parameters = {param.name: param for param in all_parameters if param.default is Parameter.empty}
     min_num = len(empty_parameters)
     max_num = len(all_parameters)
@@ -250,7 +263,8 @@ def _is_supported_data_type(_data):
 
     for pipeline_input_name in provided_args:
         data = provided_args[pipeline_input_name]
-        if data is not None and not _is_supported_data_type(data):
+        if data is not None and not _is_supported_data_type(data) and \
+                pipeline_input_name not in non_pipeline_inputs:
             msg = (
                 "Pipeline input expected an azure.ai.ml.Input or primitive types (str, bool, int or float), "
                 "but got type {}."
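Reviewer note: to make the intended behavior concrete, here is a minimal usage sketch of the feature this diff adds (the component YAML path and the `training_data`/`replicas` names are hypothetical, not taken from this change). A parameter listed in `non_pipeline_inputs` is dropped from the pipeline component's input schema via `skip_params` in `_build_inputs`, and `_build_pipeline_parameter` passes it into the function body as a raw Python value instead of a `PipelineInput`, so it can drive ordinary control flow:

```python
from azure.ai.ml import Input, dsl, load_component

# Hypothetical command component; any loadable component behaves the same.
train_func = load_component(source="./components/train.yml")

# "replicas" is a non-pipeline input: it never becomes a PipelineInput and
# is excluded from the pipeline component's inputs.
@dsl.pipeline(non_pipeline_inputs=["replicas"])
def train_pipeline(training_data: Input, replicas: int):
    for _ in range(replicas):  # plain int, so normal Python control flow works
        train_func(training_data=training_data)

pipeline_job = train_pipeline(training_data=Input(path="./data"), replicas=3)
assert len(pipeline_job.jobs) == 3            # one node per replica
assert "replicas" not in pipeline_job.inputs  # stripped from the input schema
```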
diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_inputs_outputs/utils.py b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_inputs_outputs/utils.py
index 6e1def679c85..1d51a818f513 100644
--- a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_inputs_outputs/utils.py
+++ b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_inputs_outputs/utils.py
@@ -59,7 +59,7 @@ def _get_annotation_cls_by_type(t: type, raise_error=False, optional=None):
 
 
 # pylint: disable=too-many-statements
-def _get_param_with_standard_annotation(cls_or_func, is_func=False):
+def _get_param_with_standard_annotation(cls_or_func, is_func=False, skip_params=None):
     """Standardize function parameters or class fields with dsl.types annotation."""
 
     # TODO: we'd better remove this potential recursive import
@@ -207,6 +207,7 @@ def _split(_fields):
             }
         )
 
+    skip_params = skip_params or []
    inherited_fields = _get_inherited_fields()
     # From annotations get field with type
     annotations = getattr(cls_or_func, "__annotations__", {})
@@ -215,10 +216,16 @@ def _split(_fields):
     # Update fields use class field with defaults from class dict or signature(func).paramters
     if not is_func:
         # Only consider public fields in class dict
-        defaults_dict = {key: val for key, val in cls_or_func.__dict__.items() if not key.startswith("_")}
+        defaults_dict = {
+            key: val for key, val in cls_or_func.__dict__.items()
+            if not key.startswith("_") and key not in skip_params
+        }
     else:
         # Infer parameter type from value if is function
-        defaults_dict = {key: val.default for key, val in signature(cls_or_func).parameters.items()}
+        defaults_dict = {
+            key: val.default for key, val in signature(cls_or_func).parameters.items()
+            if key not in skip_params
+        }
     fields = _update_fields_with_default(annotation_fields, defaults_dict)
     all_fields = _merge_and_reorder(inherited_fields, fields)
     return all_fields
diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/exceptions.py b/sdk/ml/azure-ai-ml/azure/ai/ml/exceptions.py
index b01a3a3a537a..5631168bc98b 100644
--- a/sdk/ml/azure-ai-ml/azure/ai/ml/exceptions.py
+++ b/sdk/ml/azure-ai-ml/azure/ai/ml/exceptions.py
@@ -571,6 +571,23 @@ def __init__(self, func_name, keyword):
         super().__init__(message=message, no_personal_data_message=message)
 
 
+class NonExistParamValueError(KeywordError):
+    """Exception raised when items in non_pipeline_inputs are not among the keyword
+    parameters of the dynamic function."""
+
+    def __init__(self, func_name, keywords):
+        message = "%s() got unexpected params in non_pipeline_inputs %r." % (func_name, keywords)
+        super().__init__(message=message, no_personal_data_message=message)
+
+
+class UnExpectedNonPipelineParameterTypeError(UserErrorException):
+    """Exception raised when the non_pipeline_parameter type is not List[str]."""
+
+    def __init__(self):
+        message = "Type of 'non_pipeline_parameter' in dsl.pipeline should be a list of strings"
+        super().__init__(message=message, no_personal_data_message=message)
+
+
 class UnsupportedOperationError(UserErrorException):
     """Exception raised when specified operation is not supported."""
diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_component_operations.py b/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_component_operations.py
index d527365f3072..2d8fdff3a43c 100644
--- a/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_component_operations.py
+++ b/sdk/ml/azure-ai-ml/azure/ai/ml/operations/_component_operations.py
@@ -577,9 +577,21 @@ def check_parameter_type(f):
             error_category=ErrorCategory.USER_ERROR,
         )
 
+    def check_non_pipeline_inputs(f):
+        """Check whether non_pipeline_inputs exist in pipeline builder."""
+        if f._pipeline_builder.non_pipeline_parameter_names:
+            msg = "Cannot register pipeline component {!r} with non_pipeline_inputs."
+            raise ValidationException(
+                message=msg.format(f.__name__),
+                no_personal_data_message=msg.format(""),
+                target=ErrorTarget.COMPONENT,
+                error_category=ErrorCategory.USER_ERROR,
+            )
+
     if hasattr(component_func, "_is_mldesigner_component") and component_func._is_mldesigner_component:
         return component_func.component
     if hasattr(component_func, "_is_dsl_func") and component_func._is_dsl_func:
+        check_non_pipeline_inputs(component_func)
         check_parameter_type(component_func)
         if component_func._job_settings:
             module_logger.warning(
diff --git a/sdk/ml/azure-ai-ml/tests/dsl/e2etests/test_dsl_pipeline.py b/sdk/ml/azure-ai-ml/tests/dsl/e2etests/test_dsl_pipeline.py
index e2e8c0107c30..24afdeaa0df4 100644
--- a/sdk/ml/azure-ai-ml/tests/dsl/e2etests/test_dsl_pipeline.py
+++ b/sdk/ml/azure-ai-ml/tests/dsl/e2etests/test_dsl_pipeline.py
@@ -1061,6 +1061,21 @@
+        @dsl.pipeline(non_pipeline_inputs=['param'])
+        def pipeline_with_non_pipeline_inputs(
+            required_input: Input,
+            required_param: str,
+            param: str,
+        ):
+            default_optional_func(
+                required_input=required_input,
+                required_param=required_param,
+            )
+
+        with pytest.raises(ValidationException) as e:
+            client.components.create_or_update(pipeline_with_non_pipeline_inputs)
+        assert "Cannot register pipeline component 'pipeline_with_non_pipeline_inputs' with non_pipeline_inputs." in e.value.message
+
     def test_create_pipeline_component_by_dsl(self, caplog, client: MLClient):
         default_optional_func = load_component(source=str(components_dir / "default_optional_component.yml"))
diff --git a/sdk/ml/azure-ai-ml/tests/dsl/unittests/test_dsl_pipeline.py b/sdk/ml/azure-ai-ml/tests/dsl/unittests/test_dsl_pipeline.py
index ccd54924d550..f3e1c8d149e7 100644
--- a/sdk/ml/azure-ai-ml/tests/dsl/unittests/test_dsl_pipeline.py
+++ b/sdk/ml/azure-ai-ml/tests/dsl/unittests/test_dsl_pipeline.py
@@ -28,7 +28,7 @@
 from azure.ai.ml.entities._builders import Command
 from azure.ai.ml.entities._job.pipeline._io import PipelineInput
 from azure.ai.ml.entities._job.pipeline._load_component import _generate_component_function
-from azure.ai.ml.exceptions import UserErrorException, ValidationException
+from azure.ai.ml.exceptions import NonExistParamValueError, UnExpectedNonPipelineParameterTypeError, UserErrorException, ValidationException
 
 from .._util import _DSL_TIMEOUT_SECOND
 
@@ -1940,3 +1940,69 @@ def pipeline_func(component_in_path):
         pipeline_job = pipeline_func(component_in_path=Data(name="test", version="1", type=AssetTypes.MLTABLE))
         result = pipeline_job._validate()
         assert result._to_dict() == {"result": "Succeeded"}
+
+    def test_pipeline_with_non_pipeline_inputs(self):
+        component_yaml = components_dir / "helloworld_component.yml"
+        component_func1 = load_component(source=component_yaml, params_override=[{"name": "component_name_1"}])
+        component_func2 = load_component(source=component_yaml, params_override=[{"name": "component_name_2"}])
+
+        @dsl.pipeline(non_pipeline_inputs=["other_params", "is_add_component"])
+        def pipeline_func(job_in_number, job_in_path, other_params, is_add_component):
+            component_func1(component_in_number=job_in_number, component_in_path=job_in_path)
+            component_func2(component_in_number=other_params, component_in_path=job_in_path)
+            if is_add_component:
+                component_func2(component_in_number=other_params, component_in_path=job_in_path)
+
+        pipeline = pipeline_func(10, Input(path="/a/path/on/ds"), 15, False)
+        assert len(pipeline.jobs) == 2
+        assert "other_params" not in pipeline.inputs
+        assert isinstance(pipeline.jobs[component_func1.name].inputs["component_in_number"]._data, PipelineInput)
+        assert pipeline.jobs[component_func2.name].inputs["component_in_number"]._data == 15
+
+        pipeline = pipeline_func(10, Input(path="/a/path/on/ds"), 15, True)
+        assert len(pipeline.jobs) == 3
+
+        @dsl.pipeline(non_pipeline_parameters=["other_params", "is_add_component"])
+        def pipeline_func(job_in_number, job_in_path, other_params, is_add_component):
+            component_func1(component_in_number=job_in_number, component_in_path=job_in_path)
+            component_func2(component_in_number=other_params, component_in_path=job_in_path)
+            if is_add_component:
+                component_func2(component_in_number=other_params, component_in_path=job_in_path)
+
+        pipeline = pipeline_func(10, Input(path="/a/path/on/ds"), 15, True)
+        assert len(pipeline.jobs) == 3
+
+    def test_pipeline_with_invalid_non_pipeline_inputs(self):
+
+        @dsl.pipeline(non_pipeline_inputs=[123])
+        def pipeline_func():
+            pass
+
+        with pytest.raises(UnExpectedNonPipelineParameterTypeError) as error_info:
+            pipeline_func()
+        assert "Type of 'non_pipeline_parameter' in dsl.pipeline should be a list of strings" in str(error_info)
+
+        @dsl.pipeline(non_pipeline_inputs=["non_exist_param1", "non_exist_param2"])
+        def pipeline_func():
+            pass
+
+        with pytest.raises(NonExistParamValueError) as error_info:
+            pipeline_func()
+        assert "pipeline_func() got unexpected params in non_pipeline_inputs ['non_exist_param1', 'non_exist_param2']" in str(error_info)
+
+    def test_component_func_as_non_pipeline_inputs(self):
+        component_yaml = components_dir / "helloworld_component.yml"
+        component_func1 = load_component(source=component_yaml, params_override=[{"name": "component_name_1"}])
+        component_func2 = load_component(source=component_yaml, params_override=[{"name": "component_name_2"}])
+
+        @dsl.pipeline(non_pipeline_inputs=["component_func"])
+        def pipeline_func(job_in_number, job_in_path, component_func):
+            component_func1(component_in_number=job_in_number, component_in_path=job_in_path)
+            component_func(component_in_number=job_in_number, component_in_path=job_in_path)
+
+        pipeline = pipeline_func(
+            job_in_number=10,
+            job_in_path=Input(path="/a/path/on/ds"),
+            component_func=component_func2)
+        assert len(pipeline.jobs) == 2
+        assert component_func2.name in pipeline.jobs
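Reviewer note: the two new exceptions surface when the decorated function is called, not at decoration time, because both checks live in `_validate_args`. A small sketch of the failure modes, mirroring the unit tests above (the function names are illustrative):

```python
from azure.ai.ml import dsl
from azure.ai.ml.exceptions import (
    NonExistParamValueError,
    UnExpectedNonPipelineParameterTypeError,
)

@dsl.pipeline(non_pipeline_inputs=[123])  # not a list of strings
def bad_type_pipeline():
    pass

@dsl.pipeline(non_pipeline_inputs=["no_such_param"])  # name missing from the signature
def bad_name_pipeline():
    pass

# Decoration itself succeeds; validation runs on invocation.
try:
    bad_type_pipeline()
except UnExpectedNonPipelineParameterTypeError as e:
    print(e)  # Type of 'non_pipeline_parameter' in dsl.pipeline should be a list of strings

try:
    bad_name_pipeline()
except NonExistParamValueError as e:
    print(e)  # bad_name_pipeline() got unexpected params in non_pipeline_inputs ['no_such_param']
```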