Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement validations for darts suggestion service #1926

Merged
63 changes: 63 additions & 0 deletions pkg/suggestion/v1beta1/nas/common/validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright 2022 The Kubeflow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pkg.apis.manager.v1beta1.python import api_pb2


def validate_operations(operations: list[api_pb2.Operation]) -> tuple[bool, str]:
    """Validate the NAS operations of an experiment.

    Operations and their parameters are checked in order; validation stops
    at the first failed check.

    Args:
        operations: Operations to check. Each one is read through its
            ``operation_type`` and ``parameter_specs.parameters`` fields.

    Returns:
        ``(True, "")`` when every operation passes, otherwise
        ``(False, message)`` with ``message`` describing the first failure.
    """
    for operation in operations:
        # Check OperationType
        if not operation.operation_type:
            return False, "Missing operationType in Operation:\n{}".format(operation)

        # Check ParameterConfigs
        if not operation.parameter_specs.parameters:
            return False, "Missing ParameterConfigs in Operation:\n{}".format(operation)

        # Validate each ParameterConfig in Operation
        for parameter in operation.parameter_specs.parameters:
            is_valid, message = _validate_parameter_config(parameter)
            if not is_valid:
                return False, message

    return True, ""


def _validate_parameter_config(parameter) -> tuple[bool, str]:
    """Validate one ParameterConfig; return ``(is_valid, error_message)``."""
    # Check Name
    if not parameter.name:
        return False, "Missing Name in ParameterConfig:\n{}".format(parameter)

    # Check ParameterType
    if not parameter.parameter_type:
        return False, "Missing ParameterType in ParameterConfig:\n{}".format(parameter)

    feasible_space = parameter.feasible_space

    # Check List in Categorical or Discrete Type
    if parameter.parameter_type in (api_pb2.CATEGORICAL, api_pb2.DISCRETE):
        if not feasible_space.list:
            return False, "Missing List in ParameterConfig.feasibleSpace:\n{}".format(parameter)

    # Check Max, Min, Step in Int or Double Type
    elif parameter.parameter_type in (api_pb2.INT, api_pb2.DOUBLE):
        # Rejected only when BOTH bounds are absent, as the message states.
        # NOTE(review): a single missing bound is accepted here - confirm
        # that this is intended.
        if not feasible_space.min and not feasible_space.max:
            return False, "Missing Max and Min in ParameterConfig.feasibleSpace:\n{}".format(parameter)

        if parameter.parameter_type == api_pb2.DOUBLE:
            try:
                # Written as "> 0" rather than "not <= 0" so that a NaN step
                # (for which every comparison is False) is rejected as well.
                step_is_positive = bool(feasible_space.step) and float(feasible_space.step) > 0
            except (TypeError, ValueError) as e:
                return False, \
                    "failed to validate ParameterConfig.feasibleSpace \n{parameter}):\n{exception}".format(
                        parameter=parameter, exception=e)

            if not step_is_positive:
                return False, \
                    "Step parameter should be > 0 in ParameterConfig.feasibleSpace:\n{}".format(parameter)

    return True, ""
2 changes: 0 additions & 2 deletions pkg/suggestion/v1beta1/nas/darts/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,6 @@ Currently, it supports running only on single GPU and second-order approximation

- Integrate E2E test in CI. Create simple example, which can run on CPU.

- Add validation to Suggestion service.

- Support multi GPU training. Add functionality to select GPU for training.

- Support DARTS in Katib UI.
Expand Down
69 changes: 67 additions & 2 deletions pkg/suggestion/v1beta1/nas/darts/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
import logging
from logging import getLogger, StreamHandler, INFO
import json
import grpc

from pkg.suggestion.v1beta1.internal.base_health_service import HealthServicer
from pkg.apis.manager.v1beta1.python import api_pb2
from pkg.apis.manager.v1beta1.python import api_pb2_grpc
from pkg.suggestion.v1beta1.nas.common.validation import validate_operations


class DartsService(api_pb2_grpc.SuggestionServicer, HealthServicer):
Expand All @@ -36,8 +38,12 @@ def __init__(self):
self.logger.addHandler(handler)
self.logger.propagate = False

# TODO: Add validation
tenzen-y marked this conversation as resolved.
Show resolved Hide resolved
def ValidateAlgorithmSettings(self, request, context):
    """Validate the experiment spec carried by ``request``.

    On failure the gRPC context is marked INVALID_ARGUMENT, the validation
    message is attached as details and logged as an error. A
    ``ValidateAlgorithmSettingsReply`` is returned in both cases.
    """
    spec_ok, reason = validate_algorithm_spec(request.experiment.spec)
    if spec_ok:
        return api_pb2.ValidateAlgorithmSettingsReply()

    # Report the failure through the gRPC status, then log it.
    context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
    context.set_details(reason)
    self.logger.error(reason)
    return api_pb2.ValidateAlgorithmSettingsReply()

def GetSuggestions(self, request, context):
Expand Down Expand Up @@ -130,7 +136,66 @@ def get_algorithm_settings(settings_raw):

for setting in settings_raw:
s_name = setting.name
s_value = setting.value
s_value = None if setting.value == "None" else setting.value
algorithm_settings_default[s_name] = s_value

return algorithm_settings_default


def validate_algorithm_spec(spec: api_pb2.ExperimentSpec) -> tuple[bool, str]:
    """Validate the NAS operations and the algorithm settings of ``spec``.

    Operations are checked first; algorithm settings are only checked when
    the operations are valid.

    Args:
        spec: Experiment spec; ``spec.nas_config.operations.operation`` and
            ``spec.algorithm.algorithm_settings`` are read.

    Returns:
        ``(True, "")`` when both checks pass, otherwise ``(False, message)``
        with the message of the first failing check.
    """
    # Validate Operations
    is_valid, message = validate_operations(spec.nas_config.operations.operation)
    if not is_valid:
        return False, message

    # Validate AlgorithmSettings
    is_valid, message = validate_algorithm_settings(spec.algorithm.algorithm_settings)
    if not is_valid:
        return False, message

    return True, ""


# validate_algorithm_settings is implemented based on quark0/darts and pt.darts.
# quark0/darts: https://github.com/quark0/darts
# pt.darts: https://github.com/khanrc/pt.darts
def validate_algorithm_settings(algorithm_settings: list[api_pb2.AlgorithmSetting]) -> tuple[bool, str]:
    """Validate the value of every known DARTS algorithm setting.

    Settings with a name that is not listed below are accepted unchecked.
    Validation stops at the first invalid setting.

    Args:
        algorithm_settings: Settings to check; each one is read through its
            ``name`` and ``value`` (string) fields.

    Returns:
        ``(True, "")`` when all settings pass, otherwise ``(False, message)``
        describing the first invalid setting.
    """
    # Learning rates, weight decays, momentum and gradient clipping share one rule.
    non_negative_floats = {
        "w_lr", "w_lr_min", "alpha_lr",
        "w_weight_decay", "alpha_weight_decay",
        "w_momentum", "w_grad_clip",
    }
    # "init_channels", "print_step", "num_nodes" and "stem_multiplier" share one rule.
    positive_ints = {"init_channels", "print_step", "num_nodes", "stem_multiplier"}

    for s in algorithm_settings:
        try:
            if s.name == "num_epochs":
                if not int(s.value) > 0:
                    return False, "{} should be greater than zero".format(s.name)

            elif s.name in non_negative_floats:
                # "not x >= 0" (rather than "x < 0") also rejects NaN.
                if not float(s.value) >= 0.0:
                    return False, "{} should be greater than or equal to zero".format(s.name)

            elif s.name == "batch_size":
                # The literal string "None" is an accepted batch_size value.
                if s.value != "None" and not int(s.value) >= 1:
                    return False, "batch_size should be greater than or equal to one"

            elif s.name == "num_workers":
                if not int(s.value) >= 0:
                    return False, "num_workers should be greater than or equal to zero"

            elif s.name in positive_ints:
                if not int(s.value) >= 1:
                    return False, "{} should be greater than or equal to one".format(s.name)

        except (TypeError, ValueError) as e:
            # Raised by int()/float() when the value is not a valid number.
            return False, "failed to validate {name}({value}): {exception}".format(name=s.name, value=s.value,
                                                                                   exception=e)

    return True, ""
81 changes: 23 additions & 58 deletions pkg/suggestion/v1beta1/nas/enas/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from pkg.suggestion.v1beta1.nas.enas.AlgorithmSettings import (
parseAlgorithmSettings, algorithmSettingsValidator, enableNoneSettingsList)
from pkg.suggestion.v1beta1.internal.base_health_service import HealthServicer
from pkg.suggestion.v1beta1.nas.common.validation import validate_operations


class EnasExperiment:
Expand Down Expand Up @@ -161,67 +162,29 @@ def __init__(self, logger=None):

def ValidateAlgorithmSettings(self, request, context):
self.logger.info("Validate Algorithm Settings start")
graph_config = request.experiment.spec.nas_config.graph_config
nas_config = request.experiment.spec.nas_config
graph_config = nas_config.graph_config

# Validate GraphConfig
# Check InputSize
if not graph_config.input_sizes:
return self.SetValidateContextError(context, "Missing InputSizes in GraphConfig:\n{}".format(graph_config))
return self.set_validate_context_error(context,
"Missing InputSizes in GraphConfig:\n{}".format(graph_config))

# Check OutputSize
if not graph_config.output_sizes:
return self.SetValidateContextError(context, "Missing OutputSizes in GraphConfig:\n{}".format(graph_config))
return self.set_validate_context_error(context,
"Missing OutputSizes in GraphConfig:\n{}".format(graph_config))

# Check NumLayers
if not graph_config.num_layers:
return self.SetValidateContextError(context, "Missing NumLayers in GraphConfig:\n{}".format(graph_config))

# Validate each operation
operations_list = list(
request.experiment.spec.nas_config.operations.operation)
for operation in operations_list:

# Check OperationType
if not operation.operation_type:
return self.SetValidateContextError(context, "Missing operationType in Operation:\n{}".format(
operation))

# Check ParameterConfigs
if not operation.parameter_specs.parameters:
return self.SetValidateContextError(context, "Missing ParameterConfigs in Operation:\n{}".format(
operation))

# Validate each ParameterConfig in Operation
parameters_list = list(operation.parameter_specs.parameters)
for parameter in parameters_list:

# Check Name
if not parameter.name:
return self.SetValidateContextError(context, "Missing Name in ParameterConfig:\n{}".format(
parameter))

# Check ParameterType
if not parameter.parameter_type:
return self.SetValidateContextError(context, "Missing ParameterType in ParameterConfig:\n{}".format(
parameter))

# Check List in Categorical or Discrete Type
if parameter.parameter_type == api_pb2.CATEGORICAL or parameter.parameter_type == api_pb2.DISCRETE:
if not parameter.feasible_space.list:
return self.SetValidateContextError(
context, "Missing List in ParameterConfig.feasibleSpace:\n{}".format(parameter))

# Check Max, Min, Step in Int or Double Type
elif parameter.parameter_type == api_pb2.INT or parameter.parameter_type == api_pb2.DOUBLE:
if not parameter.feasible_space.min and not parameter.feasible_space.max:
return self.SetValidateContextError(
context, "Missing Max and Min in ParameterConfig.feasibleSpace:\n{}".format(parameter))

if (parameter.parameter_type == api_pb2.DOUBLE and
(not parameter.feasible_space.step or float(parameter.feasible_space.step) <= 0)):
return self.SetValidateContextError(
context, "Step parameter should be > 0 in ParameterConfig.feasibleSpace:\n{}".format(
parameter))
return self.set_validate_context_error(context,
"Missing NumLayers in GraphConfig:\n{}".format(graph_config))

# Validate Operations
is_valid, message = validate_operations(nas_config.operations.operation)
if not is_valid:
return self.set_validate_context_error(context, message)

# Validate Algorithm Settings
settings_raw = request.experiment.spec.algorithm.algorithm_settings
Expand All @@ -233,14 +196,15 @@ def ValidateAlgorithmSettings(self, request, context):
setting_range = algorithmSettingsValidator[setting.name][1]
try:
converted_value = setting_type(setting.value)
except Exception:
return self.SetValidateContextError(context, "Algorithm Setting {} must be {} type".format(
setting.name, setting_type.__name__))
except Exception as e:
return self.set_validate_context_error(context,
"Algorithm Setting {} must be {} type: exception {}".format(
setting.name, setting_type.__name__, e))

if setting_type == float:
if (converted_value <= setting_range[0] or
(setting_range[1] != 'inf' and converted_value > setting_range[1])):
return self.SetValidateContextError(
return self.set_validate_context_error(
context, "Algorithm Setting {}: {} with {} type must be in range ({}, {}]".format(
setting.name,
converted_value,
Expand All @@ -250,7 +214,7 @@ def ValidateAlgorithmSettings(self, request, context):
)

elif converted_value < setting_range[0]:
return self.SetValidateContextError(
return self.set_validate_context_error(
context, "Algorithm Setting {}: {} with {} type must be in range [{}, {})".format(
setting.name,
converted_value,
Expand All @@ -259,12 +223,13 @@ def ValidateAlgorithmSettings(self, request, context):
setting_range[1])
)
else:
return self.SetValidateContextError(context, "Unknown Algorithm Setting name: {}".format(setting.name))
return self.set_validate_context_error(context,
"Unknown Algorithm Setting name: {}".format(setting.name))

self.logger.info("All Experiment Settings are Valid")
return api_pb2.ValidateAlgorithmSettingsReply()

def SetValidateContextError(self, context, error_message):
def set_validate_context_error(self, context, error_message):
context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
context.set_details(error_message)
self.logger.info(error_message)
Expand Down
Loading