diff --git a/docs/en_US/Tutorial/Nnictl.md b/docs/en_US/Tutorial/Nnictl.md index 2c92514b4b..3c71ccf8a5 100644 --- a/docs/en_US/Tutorial/Nnictl.md +++ b/docs/en_US/Tutorial/Nnictl.md @@ -305,12 +305,14 @@ Debug mode will disable version check function in Trialkeeper. * Description - You can use this command to show trial's information. + You can use this command to show trial's information. Note that if `head` or `tail` is set, only complete trials will be listed. * Usage ```bash nnictl trial ls + nnictl trial ls --head 10 + nnictl trial ls --tail 10 ``` * Options @@ -318,6 +320,8 @@ Debug mode will disable version check function in Trialkeeper. |Name, shorthand|Required|Default|Description| |------|------|------ |------| |id| False| |ID of the experiment you want to set| + |--head|False||the number of items to be listed with the highest default metric| + |--tail|False||the number of items to be listed with the lowest default metric| * __nnictl trial kill__ diff --git a/tools/nni_cmd/nnictl.py b/tools/nni_cmd/nnictl.py index 21fdf13f6d..213554d5e8 100644 --- a/tools/nni_cmd/nnictl.py +++ b/tools/nni_cmd/nnictl.py @@ -103,6 +103,8 @@ def parse_args(): parser_trial_subparsers = parser_trial.add_subparsers() parser_trial_ls = parser_trial_subparsers.add_parser('ls', help='list trial jobs') parser_trial_ls.add_argument('id', nargs='?', help='the id of experiment') + parser_trial_ls.add_argument('--head', type=int, help='list the highest experiments on the default metric') + parser_trial_ls.add_argument('--tail', type=int, help='list the lowest experiments on the default metric') parser_trial_ls.set_defaults(func=trial_ls) parser_trial_kill = parser_trial_subparsers.add_parser('kill', help='kill trial jobs') parser_trial_kill.add_argument('id', nargs='?', help='the id of experiment') diff --git a/tools/nni_cmd/nnictl_utils.py b/tools/nni_cmd/nnictl_utils.py index 3aa5a6629a..3fad64b1fc 100644 --- a/tools/nni_cmd/nnictl_utils.py +++ b/tools/nni_cmd/nnictl_utils.py @@ -9,6 +9,7 @@ import re import shutil import subprocess +from functools import cmp_to_key from datetime import datetime, timezone from pathlib import Path from subprocess import Popen @@ -248,6 +249,20 @@ def stop_experiment(args): def trial_ls(args): '''List trial''' + def final_metric_data_cmp(lhs, rhs): + metric_l = json.loads(json.loads(lhs['finalMetricData'][0]['data'])) + metric_r = json.loads(json.loads(rhs['finalMetricData'][0]['data'])) + if isinstance(metric_l, float): + return metric_l - metric_r + elif isinstance(metric_l, dict): + return metric_l['default'] - metric_r['default'] + else: + print_error('Unexpected data format. Please check your data.') + raise ValueError + + if args.head and args.tail: + print_error('Head and tail cannot be set at the same time.') + return nni_config = Config(get_config_filename(args)) rest_port = nni_config.get_config('restServerPort') rest_pid = nni_config.get_config('restServerPid') @@ -259,6 +274,14 @@ def trial_ls(args): response = rest_get(trial_jobs_url(rest_port), REST_TIME_OUT) if response and check_response(response): content = json.loads(response.text) + if args.head: + assert args.head > 0, 'The number of requested data must be greater than 0.' + content = sorted(filter(lambda x: 'finalMetricData' in x, content), + key=cmp_to_key(final_metric_data_cmp), reverse=True)[:args.head] + elif args.tail: + assert args.tail > 0, 'The number of requested data must be greater than 0.' + content = sorted(filter(lambda x: 'finalMetricData' in x, content), + key=cmp_to_key(final_metric_data_cmp))[:args.tail] for index, value in enumerate(content): content[index] = convert_time_stamp_to_date(value) print(json.dumps(content, indent=4, sort_keys=True, separators=(',', ':'))) diff --git a/tools/nni_cmd/url_utils.py b/tools/nni_cmd/url_utils.py index 6d1f7694e1..59a28837a6 100644 --- a/tools/nni_cmd/url_utils.py +++ b/tools/nni_cmd/url_utils.py @@ -28,7 +28,6 @@ def metric_data_url(port): '''get metric_data url''' return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, METRIC_DATA_API) - def check_status_url(port): '''get check_status url''' return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, CHECK_STATUS_API)