Skip to content

Commit

Permalink
feat: Add NAS in suggestionclient
Browse files Browse the repository at this point in the history
Signed-off-by: Ce Gao <gaoce@caicloud.io>
  • Loading branch information
gaocegege committed Sep 24, 2019
1 parent 87bde83 commit 5eebcd9
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 2 deletions.
45 changes: 45 additions & 0 deletions pkg/controller.v1alpha3/suggestion/suggestionclient/nas.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package suggestionclient

import (
experimentsv1alpha3 "github.com/kubeflow/katib/pkg/apis/controller/experiments/v1alpha3"
suggestionapi "github.com/kubeflow/katib/pkg/apis/manager/v1alpha3"
)

func convertNasConfig(nasConfig *experimentsv1alpha3.NasConfig) *suggestionapi.NasConfig {
res := &suggestionapi.NasConfig{
GraphConfig: convertGraphConfig(nasConfig.GraphConfig),
Operations: convertOperations(nasConfig.Operations),
}
return res
}

func convertGraphConfig(graphConfig experimentsv1alpha3.GraphConfig) *suggestionapi.GraphConfig {
gc := &suggestionapi.GraphConfig{}
if graphConfig.NumLayers != nil {
gc.NumLayers = *graphConfig.NumLayers
}
gc.InputSizes = graphConfig.InputSizes
gc.OutputSizes = graphConfig.OutputSizes
return gc
}

func convertOperations(operations []experimentsv1alpha3.Operation) *suggestionapi.NasConfig_Operations {
ops := &suggestionapi.NasConfig_Operations{
Operation: make([]*suggestionapi.Operation, 0),
}
for _, operation := range operations {
op := &suggestionapi.Operation{
OperationType: operation.OperationType,
ParameterSpecs: convertNasParameterSpecs(operation.Parameters),
}
ops.Operation = append(ops.Operation, op)
}
return ops
}

func convertNasParameterSpecs(parameters []experimentsv1alpha3.ParameterSpec) *suggestionapi.Operation_ParameterSpecs {
ps := &suggestionapi.Operation_ParameterSpecs{
Parameters: convertParameters(parameters),
}
return ps
}
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ func (g *General) ConvertExperiment(e *experimentsv1alpha3.Experiment) *suggesti
ParameterSpecs: &suggestionapi.ExperimentSpec_ParameterSpecs{
Parameters: convertParameters(e.Spec.Parameters),
},
NasConfig: convertNasConfig(e.Spec.NasConfig),
}
return res
}
Expand Down
5 changes: 3 additions & 2 deletions pkg/suggestion/v1alpha3/nasrl_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def print_algorithm_settings(self):

class NasrlService(api_pb2_grpc.SuggestionServicer, HealthServicer):
def __init__(self, logger=None):
super(NasrlService, self).__init__()
if logger == None:
self.logger = getLogger(__name__)
FORMAT = '%(asctime)-15s Experiment %(experiment_name)s %(message)s'
Expand All @@ -155,7 +156,7 @@ def __init__(self, logger=None):

def ValidateAlgorithmSettings(self, request, context):
self.logger.info("Validate Algorithm Settings start")
graph_config = request.experiment_spec.nas_config.graph_config
graph_config = request.experiment.spec.nas_config.graph_config

# Validate GraphConfig
# Check InputSize
Expand All @@ -172,7 +173,7 @@ def ValidateAlgorithmSettings(self, request, context):

# Validate each operation
operations_list = list(
request.experiment_spec.nas_config.operations.operation)
request.experiment.spec.nas_config.operations.operation)
for operation in operations_list:

# Check OperationType
Expand Down

0 comments on commit 5eebcd9

Please sign in to comment.