From 5dde0d8d9e8a4012934d69129e0e965945946186 Mon Sep 17 00:00:00 2001 From: fishead Date: Mon, 3 Feb 2020 10:27:04 +0800 Subject: [PATCH 1/4] support create ssh connection by using sshkey (issue #1950) (#1957) --- tools/nni_cmd/nnictl_utils.py | 4 +++- tools/nni_cmd/ssh_utils.py | 8 ++++++-- tools/nni_cmd/tensorboard_utils.py | 6 ++++-- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/tools/nni_cmd/nnictl_utils.py b/tools/nni_cmd/nnictl_utils.py index a66197fac9..4866bcdce4 100644 --- a/tools/nni_cmd/nnictl_utils.py +++ b/tools/nni_cmd/nnictl_utils.py @@ -403,11 +403,13 @@ def remote_clean(machine_list, experiment_id=None): userName = machine.get('username') host = machine.get('ip') port = machine.get('port') + sshKeyPath = machine.get('sshKeyPath') + passphrase = machine.get('passphrase') if experiment_id: remote_dir = '/' + '/'.join(['tmp', 'nni', 'experiments', experiment_id]) else: remote_dir = '/' + '/'.join(['tmp', 'nni', 'experiments']) - sftp = create_ssh_sftp_client(host, port, userName, passwd) + sftp = create_ssh_sftp_client(host, port, userName, passwd, sshKeyPath, passphrase) print_normal('removing folder {0}'.format(host + ':' + str(port) + remote_dir)) remove_remote_directory(sftp, remote_dir) diff --git a/tools/nni_cmd/ssh_utils.py b/tools/nni_cmd/ssh_utils.py index 2e68611206..e3f26a8e24 100644 --- a/tools/nni_cmd/ssh_utils.py +++ b/tools/nni_cmd/ssh_utils.py @@ -30,12 +30,16 @@ def copy_remote_directory_to_local(sftp, remote_path, local_path): except Exception: pass -def create_ssh_sftp_client(host_ip, port, username, password): +def create_ssh_sftp_client(host_ip, port, username, password, ssh_key_path, passphrase): '''create ssh client''' try: paramiko = check_environment() conn = paramiko.Transport(host_ip, port) - conn.connect(username=username, password=password) + if ssh_key_path is not None: + ssh_key = paramiko.RSAKey.from_private_key_file(ssh_key_path, password=passphrase) + conn.connect(username=username, pkey=ssh_key) + else: + conn.connect(username=username, password=password) sftp = paramiko.SFTPClient.from_transport(conn) return sftp except Exception as exception: diff --git a/tools/nni_cmd/tensorboard_utils.py b/tools/nni_cmd/tensorboard_utils.py index 8cb0bbfc17..60d589083a 100644 --- a/tools/nni_cmd/tensorboard_utils.py +++ b/tools/nni_cmd/tensorboard_utils.py @@ -37,12 +37,14 @@ def copy_data_from_remote(args, nni_config, trial_content, path_list, host_list, machine_dict = {} local_path_list = [] for machine in machine_list: - machine_dict[machine['ip']] = {'port': machine['port'], 'passwd': machine['passwd'], 'username': machine['username']} + machine_dict[machine['ip']] = {'port': machine['port'], 'passwd': machine['passwd'], 'username': machine['username'], + 'sshKeyPath': machine.get('sshKeyPath'), 'passphrase': machine.get('passphrase')} for index, host in enumerate(host_list): local_path = os.path.join(temp_nni_path, trial_content[index].get('id')) local_path_list.append(local_path) print_normal('Copying log data from %s to %s' % (host + ':' + path_list[index], local_path)) - sftp = create_ssh_sftp_client(host, machine_dict[host]['port'], machine_dict[host]['username'], machine_dict[host]['passwd']) + sftp = create_ssh_sftp_client(host, machine_dict[host]['port'], machine_dict[host]['username'], machine_dict[host]['passwd'], + machine_dict[host]['sshKeyPath'], machine_dict[host]['passphrase']) copy_remote_directory_to_local(sftp, path_list[index], local_path) print_normal('Copy done!') return local_path_list From c55c5f4656e17cd107ce7fce042ab8659c55e9bc Mon Sep 17 00:00:00 2001 From: SparkSnail Date: Mon, 3 Feb 2020 16:03:56 +0800 Subject: [PATCH 2/4] fix support setting nniManagerIp in PAI (#1987) --- .../training_service/pai/paiK8S/paiK8STrainingService.ts | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/nni_manager/training_service/pai/paiK8S/paiK8STrainingService.ts b/src/nni_manager/training_service/pai/paiK8S/paiK8STrainingService.ts index fc64d4dbdc..b6b6b4c823 100644 --- a/src/nni_manager/training_service/pai/paiK8S/paiK8STrainingService.ts +++ b/src/nni_manager/training_service/pai/paiK8S/paiK8STrainingService.ts @@ -59,6 +59,10 @@ class PAIK8STrainingService extends PAITrainingService { public async setClusterMetadata(key: string, value: string): Promise { switch (key) { + case TrialConfigMetadataKey.NNI_MANAGER_IP: + this.nniManagerIpConfig = JSON.parse(value); + break; + case TrialConfigMetadataKey.PAI_CLUSTER_CONFIG: this.paiJobRestServer = new PAIJobRestServer(component.get(PAIK8STrainingService)); this.paiClusterConfig = JSON.parse(value); From 1c54b40ce56731f3ff6053a45a4f45af28c40bb9 Mon Sep 17 00:00:00 2001 From: SparkSnail Date: Tue, 4 Feb 2020 10:31:57 +0800 Subject: [PATCH 3/4] Change validation order in machineList (#1966) --- tools/nni_cmd/config_schema.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/tools/nni_cmd/config_schema.py b/tools/nni_cmd/config_schema.py index 8017946ce9..30db148200 100644 --- a/tools/nni_cmd/config_schema.py +++ b/tools/nni_cmd/config_schema.py @@ -407,15 +407,8 @@ def setPathCheck(key): } machine_list_schema = { - Optional('machineList'):[Or({ - 'ip': setType('ip', str), - Optional('port'): setNumberRange('port', int, 1, 65535), - 'username': setType('username', str), - 'passwd': setType('passwd', str), - Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'), - Optional('maxTrialNumPerGpu'): setType('maxTrialNumPerGpu', int), - Optional('useActiveGpu'): setType('useActiveGpu', bool) - }, { + Optional('machineList'):[Or( + { 'ip': setType('ip', str), Optional('port'): setNumberRange('port', int, 1, 65535), 'username': setType('username', str), @@ -424,6 +417,15 @@ def setPathCheck(key): Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'), Optional('maxTrialNumPerGpu'): setType('maxTrialNumPerGpu', int), Optional('useActiveGpu'): setType('useActiveGpu', bool) + }, + { + 'ip': setType('ip', str), + Optional('port'): setNumberRange('port', int, 1, 65535), + 'username': setType('username', str), + 'passwd': setType('passwd', str), + Optional('gpuIndices'): Or(int, And(str, lambda x: len([int(i) for i in x.split(',')]) > 0), error='gpuIndex format error!'), + Optional('maxTrialNumPerGpu'): setType('maxTrialNumPerGpu', int), + Optional('useActiveGpu'): setType('useActiveGpu', bool) })] } From 8e953fce608c26d112bd6c4220d730e9a6cdae4b Mon Sep 17 00:00:00 2001 From: Crissman Loomis Date: Tue, 4 Feb 2020 12:20:26 +0900 Subject: [PATCH 4/4] Typo (#1992) --- docs/en_US/Tuner/HyperbandAdvisor.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/en_US/Tuner/HyperbandAdvisor.md b/docs/en_US/Tuner/HyperbandAdvisor.md index a367b06b13..b7787af199 100644 --- a/docs/en_US/Tuner/HyperbandAdvisor.md +++ b/docs/en_US/Tuner/HyperbandAdvisor.md @@ -5,7 +5,7 @@ Hyperband on NNI [Hyperband][1] is a popular automl algorithm. The basic idea of Hyperband is that it creates several buckets, each bucket has `n` randomly generated hyperparameter configurations, each configuration uses `r` resource (e.g., epoch number, batch number). After the `n` configurations is finished, it chooses top `n/eta` configurations and runs them using increased `r*eta` resource. At last, it chooses the best configuration it has found so far. ## 2. Implementation with fully parallelism -Frist, this is an example of how to write an automl algorithm based on MsgDispatcherBase, rather than Tuner and Assessor. Hyperband is implemented in this way because it integrates the functions of both Tuner and Assessor, thus, we call it advisor. +First, this is an example of how to write an automl algorithm based on MsgDispatcherBase, rather than Tuner and Assessor. Hyperband is implemented in this way because it integrates the functions of both Tuner and Assessor, thus, we call it advisor. Second, this implementation fully leverages Hyperband's internal parallelism. More specifically, the next bucket is not started strictly after the current bucket, instead, it starts when there is available resource.