Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: combo fix for internal components I #26718

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions sdk/ml/azure-ai-ml/azure/ai/ml/_internal/_schema/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

from marshmallow import fields, post_dump
from marshmallow import fields, post_dump, INCLUDE, EXCLUDE

from azure.ai.ml._schema import NestedField, StringTransformedEnum, UnionField
from azure.ai.ml._schema.component.component import ComponentSchema
Expand All @@ -15,6 +15,7 @@
InternalInputPortSchema,
InternalOutputPortSchema,
InternalParameterSchema,
InternalPrimitiveOutputSchema,
)


Expand Down Expand Up @@ -42,6 +43,8 @@ def all_values(cls):


class InternalBaseComponentSchema(ComponentSchema):
class Meta:
unknown = INCLUDE
# override name as 1p components allow . in name, which is not allowed in v2 components
name = fields.Str()

Expand All @@ -60,7 +63,16 @@ class InternalBaseComponentSchema(ComponentSchema):
]
),
)
outputs = fields.Dict(keys=fields.Str(), values=NestedField(InternalOutputPortSchema))
# support primitive output for all internal components for now
outputs = fields.Dict(
keys=fields.Str(),
values=UnionField(
[
NestedField(InternalPrimitiveOutputSchema, unknown=EXCLUDE),
NestedField(InternalOutputPortSchema, unknown=EXCLUDE),
]
),
)

# type field is required for registration
type = StringTransformedEnum(
Expand Down
43 changes: 27 additions & 16 deletions sdk/ml/azure-ai-ml/azure/ai/ml/_internal/_schema/input_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,22 @@

from marshmallow import fields, post_dump, post_load

from azure.ai.ml._schema import StringTransformedEnum, UnionField
from azure.ai.ml._schema.component.input_output import InputPortSchema, OutputPortSchema, ParameterSchema
from azure.ai.ml._schema.core.fields import DumpableFloatField, DumpableIntegerField
from azure.ai.ml._schema import StringTransformedEnum, UnionField, PatchedSchemaMeta
from azure.ai.ml._schema.component.input_output import InputPortSchema, ParameterSchema
from azure.ai.ml._schema.core.fields import DumpableFloatField, DumpableIntegerField, DumpableEnumField


SUPPORTED_INTERNAL_PARAM_TYPES = [
"integer",
"Integer",
"boolean",
"Boolean",
"string",
"String",
"float",
"Float",
]

class InternalInputPortSchema(InputPortSchema):
# skip client-side validate for type enum & support list
type = UnionField(
Expand All @@ -29,29 +40,29 @@ def resolve_list_type(self, data, original_data, **kwargs): # pylint: disable=u
return data


class InternalOutputPortSchema(OutputPortSchema):
class InternalOutputPortSchema(metaclass=PatchedSchemaMeta):
# skip client-side validate for type enum
type = fields.Str(
required=True,
data_key="type",
)
description = fields.Str()
is_link_mode = fields.Bool()
datastore_mode = fields.Str()


class InternalPrimitiveOutputSchema(metaclass=PatchedSchemaMeta):
type = DumpableEnumField(
allowed_values=SUPPORTED_INTERNAL_PARAM_TYPES,
required=True,
)
description = fields.Str()
is_control = fields.Bool()
elliotzh marked this conversation as resolved.
Show resolved Hide resolved


class InternalParameterSchema(ParameterSchema):
type = StringTransformedEnum(
allowed_values=[
"integer",
"Integer",
"boolean",
"Boolean",
"string",
"String",
"float",
"Float",
],
casing_transform=lambda x: x,
type = DumpableEnumField(
allowed_values=SUPPORTED_INTERNAL_PARAM_TYPES,
required=True,
data_key="type",
)
Expand Down
3 changes: 3 additions & 0 deletions sdk/ml/azure-ai-ml/azure/ai/ml/_internal/_schema/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@


class InternalBaseNodeSchema(BaseNodeSchema):
class Meta:
unknown = INCLUDE
component = UnionField(
[
# for registry type assets
Expand Down Expand Up @@ -56,6 +58,7 @@ class ScopeSchema(InternalBaseNodeSchema):
class HDInsightSchema(InternalBaseNodeSchema):
type = StringTransformedEnum(allowed_values=[NodeType.HDI], casing_transform=lambda x: x)

compute_name = fields.Str()
queue = fields.Str()
driver_memory = fields.Str()
driver_cores = fields.Int()
Expand Down
6 changes: 6 additions & 0 deletions sdk/ml/azure-ai-ml/azure/ai/ml/_internal/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
InternalComponent,
Parallel,
Scope,
DataTransfer,
Hemera,
Starlite,
)
from azure.ai.ml._schema import NestedField
from azure.ai.ml.entities._component.component_factory import component_factory
Expand Down Expand Up @@ -57,6 +60,9 @@ def enable_internal_components_in_pipeline():
_register_node(_type, InternalBaseNode, InternalBaseNodeSchema)

# redo the registration for those with specific runsettings
_register_node(NodeType.DATA_TRANSFER, DataTransfer, InternalBaseNodeSchema)
_register_node(NodeType.HEMERA, Hemera, InternalBaseNodeSchema)
_register_node(NodeType.STARLITE, Starlite, InternalBaseNodeSchema)
_register_node(NodeType.COMMAND, Command, CommandSchema)
_register_node(NodeType.DISTRIBUTED, Distributed, DistributedSchema)
_register_node(NodeType.SCOPE, Scope, ScopeSchema)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,34 @@

class _AdditionalIncludes:
def __init__(self, code_path: Union[None, str], yaml_path: str):
self._yaml_path = Path(yaml_path)
self._yaml_name = self._yaml_path.name
self._code_path = self._yaml_path.parent
if code_path is not None:
self._code_path = (self._code_path / code_path).resolve()
self.__yaml_path = Path(yaml_path)
self.__code_path = code_path

self._tmp_code_path = None
self._additional_includes_file_path = self._yaml_path.with_suffix(f".{ADDITIONAL_INCLUDES_SUFFIX}")
self._includes = None
if self._additional_includes_file_path.is_file():
with open(self._additional_includes_file_path, "r") as f:
lines = f.readlines()
self._includes = [line.strip() for line in lines if len(line.strip()) > 0]

@property
def _yaml_path(self) -> Path:
return self.__yaml_path

@property
def _code_path(self) -> Path:
if self.__code_path is not None:
return (self._yaml_path.parent / self.__code_path).resolve()
return self._yaml_path.parent

@property
def _yaml_name(self) -> str:
return self._yaml_path.name

@property
def _additional_includes_file_path(self) -> Path:
return self._yaml_path.with_suffix(f".{ADDITIONAL_INCLUDES_SUFFIX}")

@property
def code(self) -> Path:
return self._tmp_code_path if self._tmp_code_path else self._code_path
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@ def compute(self) -> str:
@compute.setter
def compute(self, value: str):
"""Set the compute definition for the command."""
if value is not None and not isinstance(value, str):
raise ValueError(f"Failed in setting compute: only string is supported in DPv2 but got {type(value)}")
self._compute = value

@property
Expand All @@ -54,8 +52,6 @@ def environment(self) -> str:
@environment.setter
def environment(self, value: str):
"""Set the environment definition for the command."""
if value is not None and not isinstance(value, str):
raise ValueError(f"Failed in setting environment: only string is supported in DPv2 but got {type(value)}")
self._environment = value

@property
Expand Down
27 changes: 16 additions & 11 deletions sdk/ml/azure-ai-ml/azure/ai/ml/_internal/entities/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# disable redefined-builtin to use id/type as argument name
from contextlib import contextmanager
from typing import Dict, Union
import os

from marshmallow import INCLUDE, Schema

Expand Down Expand Up @@ -176,6 +177,21 @@ def _additional_includes(self):
def _create_schema_for_validation(cls, context) -> Union[PathAwareSchema, Schema]:
return InternalBaseComponentSchema(context=context)

def _validate(self, raise_error=False) -> MutableValidationResult:
if self._additional_includes is not None and self._additional_includes._validate().passed:
# update source path in case dependency file is in additional_includes
with self._resolve_local_code() as tmp_base_path:
origin_base_path, origin_source_path = self._base_path, self._source_path

try:
self._base_path, self._source_path = \
elliotzh marked this conversation as resolved.
Show resolved Hide resolved
tmp_base_path, tmp_base_path / os.path.basename(self._source_path)
return super()._validate(raise_error=raise_error)
finally:
self._base_path, self._source_path = origin_base_path, origin_source_path

return super()._validate(raise_error=raise_error)

def _customized_validate(self) -> MutableValidationResult:
validation_result = super(InternalComponent, self)._customized_validate()
if isinstance(self.environment, InternalEnvironment):
Expand Down Expand Up @@ -228,14 +244,3 @@ def _resolve_local_code(self):

def __call__(self, *args, **kwargs) -> InternalBaseNode: # pylint: disable=useless-super-delegation
return super(InternalComponent, self).__call__(*args, **kwargs)

def _schema_validate(self) -> MutableValidationResult:
"""Validate the resource with the schema.

return type: ValidationResult
"""
result = super(InternalComponent, self)._schema_validate()
# skip unknown field warnings for internal components
# TODO: move this logic into base class
result._warnings = list(filter(lambda x: x.message != "Unknown field.", result._warnings))
return result
23 changes: 11 additions & 12 deletions sdk/ml/azure-ai-ml/azure/ai/ml/_internal/entities/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from azure.ai.ml.entities._job.pipeline._io import NodeInput, NodeOutput, PipelineInput
from azure.ai.ml.entities._util import convert_ordered_dict_to_dict

from ...entities._validation import MutableValidationResult
from .._schema.component import NodeType
from ._input_outputs import InternalInput

Expand Down Expand Up @@ -95,17 +94,6 @@ def _to_job(self) -> Job:
def _load_from_dict(cls, data: Dict, context: Dict, additional_message: str, **kwargs) -> "Job":
raise RuntimeError("Internal components doesn't support load from dict")

def _schema_validate(self) -> MutableValidationResult:
"""Validate the resource with the schema.

return type: ValidationResult
"""
result = super(InternalBaseNode, self)._schema_validate()
# skip unknown field warnings for internal components
# TODO: move this logic into base class?
result._warnings = list(filter(lambda x: x.message != "Unknown field.", result._warnings))
return result

@classmethod
def _create_schema_for_validation(cls, context) -> Union[PathAwareSchema, Schema]:
from .._schema.node import InternalBaseNodeSchema
Expand Down Expand Up @@ -176,6 +164,7 @@ def __init__(self, **kwargs):
kwargs.pop("type", None)
super(HDInsight, self).__init__(type=NodeType.HDI, **kwargs)
self._init = True
self._compute_name: str = kwargs.pop("compute_name", None)
self._queue: str = kwargs.pop("queue", None)
self._driver_memory: str = kwargs.pop("driver_memory", None)
self._driver_cores: int = kwargs.pop("driver_cores", None)
Expand All @@ -186,6 +175,15 @@ def __init__(self, **kwargs):
self._hdinsight_spark_job_name: str = kwargs.pop("hdinsight_spark_job_name", None)
self._init = False

@property
def compute_name(self) -> str:
"""Name of the compute to be used."""
return self._compute_name

@compute_name.setter
def compute_name(self, value: str):
self._compute_name = value
elliotzh marked this conversation as resolved.
Show resolved Hide resolved

@property
def queue(self) -> str:
"""The name of the YARN queue to which submitted."""
Expand Down Expand Up @@ -267,6 +265,7 @@ def hdinsight_spark_job_name(self, value: str):
@classmethod
def _picked_fields_from_dict_to_rest_object(cls) -> List[str]:
return [
"compute_name",
"queue",
"driver_cores",
"executor_memory",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class Meta:
],
metadata={"description": "Provides the configuration for a distributed run."},
)
# primitive output is only supported for command component
# primitive output is only supported for command component & pipeline component
outputs = fields.Dict(
keys=fields.Str(),
values=UnionField(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,12 @@ class OutputPortSchema(metaclass=PatchedSchemaMeta):
)


class PrimitiveOutputSchema(metaclass=PatchedSchemaMeta):
class PrimitiveOutputSchema(OutputPortSchema):
type = DumpableEnumField(
allowed_values=SUPPORTED_PARAM_TYPES,
required=True,
)
description = fields.Str()
is_control = fields.Bool()
mode = DumpableEnumField(
allowed_values=SUPPORTED_INPUT_OUTPUT_MODES,
)


class ParameterSchema(metaclass=PatchedSchemaMeta):
Expand Down
6 changes: 4 additions & 2 deletions sdk/ml/azure-ai-ml/azure/ai/ml/_schema/core/schema_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,10 @@ def __new__(cls, name, bases, dct):
if meta is None:
dct["Meta"] = PatchedMeta
else:
dct["Meta"].unknown = RAISE
dct["Meta"].ordered = True
if not hasattr(meta, "unknown"):
dct["Meta"].unknown = RAISE
if not hasattr(meta, "ordered"):
dct["Meta"].ordered = True

bases = bases + (PatchedBaseSchema,)
klass = super().__new__(cls, name, bases, dct)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ class PipelineComponentSchema(ComponentSchema):
type = StringTransformedEnum(allowed_values=[NodeType.PIPELINE])
jobs = PipelineJobsField()

# primitive output is only supported for command component
# primitive output is only supported for command component & pipeline component
outputs = fields.Dict(
keys=fields.Str(),
values=UnionField(
Expand Down
5 changes: 2 additions & 3 deletions sdk/ml/azure-ai-ml/azure/ai/ml/entities/_load_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def load_common(
return _load_common_raising_marshmallow_error(cls, yaml_dict, relative_origin, params_override, **kwargs)
except ValidationError as e:
if issubclass(cls, SchemaValidatableMixin):
validation_result = _ValidationResultBuilder.from_validation_error(e, relative_origin)
validation_result = _ValidationResultBuilder.from_validation_error(e, source_path=relative_origin)
validation_result.try_raise(
# pylint: disable=protected-access
error_target=cls._get_validation_error_target(),
Expand All @@ -102,8 +102,7 @@ def load_common(
f"of type {type_str}, please specify the correct "
f"type in the 'type' property.",
)
else:
raise e
raise e


def _try_load_yaml_dict(source: Union[str, PathLike, IO[AnyStr]]) -> dict:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def validate_common(cls, path, validate_func, params_override=None) -> Validatio
except ValidationException as err:
return _ValidationResultBuilder.from_single_message(err.message)
except ValidationError as err:
return _ValidationResultBuilder.from_validation_error(err, path)
return _ValidationResultBuilder.from_validation_error(err, source_path=path)


def validate_component(path, ml_client=None, params_override=None) -> ValidationResult:
Expand Down
Loading