diff --git a/docs/en_US/Tuner/HyperbandAdvisor.md b/docs/en_US/Tuner/HyperbandAdvisor.md index a367b06b13..b7787af199 100644 --- a/docs/en_US/Tuner/HyperbandAdvisor.md +++ b/docs/en_US/Tuner/HyperbandAdvisor.md @@ -5,7 +5,7 @@ Hyperband on NNI [Hyperband][1] is a popular automl algorithm. The basic idea of Hyperband is that it creates several buckets, each bucket has `n` randomly generated hyperparameter configurations, each configuration uses `r` resource (e.g., epoch number, batch number). After the `n` configurations is finished, it chooses top `n/eta` configurations and runs them using increased `r*eta` resource. At last, it chooses the best configuration it has found so far. ## 2. Implementation with fully parallelism -Frist, this is an example of how to write an automl algorithm based on MsgDispatcherBase, rather than Tuner and Assessor. Hyperband is implemented in this way because it integrates the functions of both Tuner and Assessor, thus, we call it advisor. +First, this is an example of how to write an automl algorithm based on MsgDispatcherBase, rather than Tuner and Assessor. Hyperband is implemented in this way because it integrates the functions of both Tuner and Assessor, thus, we call it advisor. Second, this implementation fully leverages Hyperband's internal parallelism. More specifically, the next bucket is not started strictly after the current bucket, instead, it starts when there is available resource. diff --git a/src/nni_manager/training_service/pai/paiK8S/paiK8STrainingService.ts b/src/nni_manager/training_service/pai/paiK8S/paiK8STrainingService.ts index fc64d4dbdc..b6b6b4c823 100644 --- a/src/nni_manager/training_service/pai/paiK8S/paiK8STrainingService.ts +++ b/src/nni_manager/training_service/pai/paiK8S/paiK8STrainingService.ts @@ -59,6 +59,10 @@ class PAIK8STrainingService extends PAITrainingService { public async setClusterMetadata(key: string, value: string): Promise { switch (key) { + case TrialConfigMetadataKey.NNI_MANAGER_IP: + this.nniManagerIpConfig = JSON.parse(value); + break; + case TrialConfigMetadataKey.PAI_CLUSTER_CONFIG: this.paiJobRestServer = new PAIJobRestServer(component.get(PAIK8STrainingService)); this.paiClusterConfig = JSON.parse(value); diff --git a/tools/nni_cmd/config_schema.py b/tools/nni_cmd/config_schema.py index 8017946ce9..30db148200 100644 --- a/tools/nni_cmd/config_schema.py +++ b/tools/nni_cmd/config_schema.py @@ -407,15 +407,8 @@ def setPathCheck(key): } machine_list_schema = { - Optional('machineList'):[Or({ - 'ip': setType('ip', str), - Optional('port'): setNumberRange('port', int, 1, 65535), - 'username': setType('username', str), - 'passwd': setType('passwd', str), - Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'), - Optional('maxTrialNumPerGpu'): setType('maxTrialNumPerGpu', int), - Optional('useActiveGpu'): setType('useActiveGpu', bool) - }, { + Optional('machineList'):[Or( + { 'ip': setType('ip', str), Optional('port'): setNumberRange('port', int, 1, 65535), 'username': setType('username', str), @@ -424,6 +417,15 @@ def setPathCheck(key): Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'), Optional('maxTrialNumPerGpu'): setType('maxTrialNumPerGpu', int), Optional('useActiveGpu'): setType('useActiveGpu', bool) + }, + { + 'ip': setType('ip', str), + Optional('port'): setNumberRange('port', int, 1, 65535), + 'username': setType('username', str), + 'passwd': setType('passwd', str), + Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'), + Optional('maxTrialNumPerGpu'): setType('maxTrialNumPerGpu', int), + Optional('useActiveGpu'): setType('useActiveGpu', bool) })] } diff --git a/tools/nni_cmd/nnictl_utils.py b/tools/nni_cmd/nnictl_utils.py index a66197fac9..4866bcdce4 100644 --- a/tools/nni_cmd/nnictl_utils.py +++ b/tools/nni_cmd/nnictl_utils.py @@ -403,11 +403,13 @@ def remote_clean(machine_list, experiment_id=None): userName = machine.get('username') host = machine.get('ip') port = machine.get('port') + sshKeyPath = machine.get('sshKeyPath') + passphrase = machine.get('passphrase') if experiment_id: remote_dir = '/' + '/'.join(['tmp', 'nni', 'experiments', experiment_id]) else: remote_dir = '/' + '/'.join(['tmp', 'nni', 'experiments']) - sftp = create_ssh_sftp_client(host, port, userName, passwd) + sftp = create_ssh_sftp_client(host, port, userName, passwd, sshKeyPath, passphrase) print_normal('removing folder {0}'.format(host + ':' + str(port) + remote_dir)) remove_remote_directory(sftp, remote_dir) diff --git a/tools/nni_cmd/ssh_utils.py b/tools/nni_cmd/ssh_utils.py index 2e68611206..e3f26a8e24 100644 --- a/tools/nni_cmd/ssh_utils.py +++ b/tools/nni_cmd/ssh_utils.py @@ -30,12 +30,16 @@ def copy_remote_directory_to_local(sftp, remote_path, local_path): except Exception: pass -def create_ssh_sftp_client(host_ip, port, username, password): +def create_ssh_sftp_client(host_ip, port, username, password, ssh_key_path, passphrase): '''create ssh client''' try: paramiko = check_environment() conn = paramiko.Transport(host_ip, port) - conn.connect(username=username, password=password) + if ssh_key_path is not None: + ssh_key = paramiko.RSAKey.from_private_key_file(ssh_key_path, password=passphrase) + conn.connect(username=username, pkey=ssh_key) + else: + conn.connect(username=username, password=password) sftp = paramiko.SFTPClient.from_transport(conn) return sftp except Exception as exception: diff --git a/tools/nni_cmd/tensorboard_utils.py b/tools/nni_cmd/tensorboard_utils.py index 8cb0bbfc17..60d589083a 100644 --- a/tools/nni_cmd/tensorboard_utils.py +++ b/tools/nni_cmd/tensorboard_utils.py @@ -37,12 +37,14 @@ def copy_data_from_remote(args, nni_config, trial_content, path_list, host_list, machine_dict = {} local_path_list = [] for machine in machine_list: - machine_dict[machine['ip']] = {'port': machine['port'], 'passwd': machine['passwd'], 'username': machine['username']} + machine_dict[machine['ip']] = {'port': machine['port'], 'passwd': machine['passwd'], 'username': machine['username'], + 'sshKeyPath': machine.get('sshKeyPath'), 'passphrase': machine.get('passphrase')} for index, host in enumerate(host_list): local_path = os.path.join(temp_nni_path, trial_content[index].get('id')) local_path_list.append(local_path) print_normal('Copying log data from %s to %s' % (host + ':' + path_list[index], local_path)) - sftp = create_ssh_sftp_client(host, machine_dict[host]['port'], machine_dict[host]['username'], machine_dict[host]['passwd']) + sftp = create_ssh_sftp_client(host, machine_dict[host]['port'], machine_dict[host]['username'], machine_dict[host]['passwd'], + machine_dict[host]['sshKeyPath'], machine_dict[host]['passphrase']) copy_remote_directory_to_local(sftp, path_list[index], local_path) print_normal('Copy done!') return local_path_list