Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Add check for dependent packages of tuners before starting restful server #570

Merged
merged 7 commits into from
Jan 11, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tools/nni_cmd/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,4 @@

COLOR_GREEN_FORMAT = '\033[1;32;32m%s\033[0m'

COLOR_YELLOW_FORMAT = '\033[1;33;33m%s\033[0m'
COLOR_YELLOW_FORMAT = '\033[1;33;33m%s\033[0m'
15 changes: 14 additions & 1 deletion tools/nni_cmd/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,12 @@

import json
import os
import sys
import shutil
import string
from subprocess import Popen, PIPE, call, check_output
from subprocess import Popen, PIPE, call, check_output, check_call
import tempfile
from nni.constants import ModuleName
from nni_annotation import *
from .launcher_utils import validate_all_content
from .rest_utils import rest_put, rest_post, check_rest_server, check_rest_server_quick, check_response
Expand Down Expand Up @@ -272,6 +274,17 @@ def set_experiment(experiment_config, mode, port, config_file_name):
def launch_experiment(args, experiment_config, mode, config_file_name, experiment_id=None):
'''follow steps to start rest server and start experiment'''
nni_config = Config(config_file_name)

# check packages for tuner
if experiment_config.get('tuner') and experiment_config['tuner'].get('builtinTunerName'):
tuner_name = experiment_config['tuner']['builtinTunerName']
module_name = ModuleName[tuner_name]
try:
check_call([sys.executable, '-c', 'import %s'%(module_name)])
except ModuleNotFoundError as e:
print_error('The tuner %s should be installed through nnictl'%(tuner_name))
exit(1)
xuehui1991 marked this conversation as resolved.
Show resolved Hide resolved

# start rest server
rest_process, start_time = start_rest_server(args.port, experiment_config['trainingServicePlatform'], mode, config_file_name, experiment_id)
nni_config.set_config('restServerPid', rest_process.pid)
Expand Down