Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

add nnictl command to list trial results with highest/lowest metric #2747

Merged
merged 15 commits into from
Aug 12, 2020
2 changes: 2 additions & 0 deletions docs/en_US/Tutorial/Nnictl.md
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,8 @@ Debug mode will disable version check function in Trialkeeper.
|Name, shorthand|Required|Default|Description|
|------|------|------ |------|
|id| False| |ID of the experiment you want to set|
|--head|False||the number of items to be listed with the highest default metric|
|--tail|False||the number of items to be listed with the lowest default metric|

* __nnictl trial kill__

Expand Down
2 changes: 2 additions & 0 deletions tools/nni_cmd/nnictl.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ def parse_args():
parser_trial_subparsers = parser_trial.add_subparsers()
parser_trial_ls = parser_trial_subparsers.add_parser('ls', help='list trial jobs')
parser_trial_ls.add_argument('id', nargs='?', help='the id of experiment')
parser_trial_ls.add_argument('--head', type=int)
parser_trial_ls.add_argument('--tail', type=int)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please follow other commands' coding style

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

parser_trial_ls.set_defaults(func=trial_ls)
parser_trial_kill = parser_trial_subparsers.add_parser('kill', help='kill trial jobs')
parser_trial_kill.add_argument('id', nargs='?', help='the id of experiment')
Expand Down
25 changes: 25 additions & 0 deletions tools/nni_cmd/nnictl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import re
import shutil
import subprocess
from functools import cmp_to_key
from datetime import datetime, timezone
from pathlib import Path
from subprocess import Popen
Expand Down Expand Up @@ -248,6 +249,22 @@ def stop_experiment(args):

def trial_ls(args):
'''List trial'''
def final_metric_data_cmp(lhs, rhs):
# The first json.loads handles the serialized data and the second
# reconstructs data structure to handle dict metric.
metric_l = json.loads(json.loads(lhs['finalMetricData'][0]['data']))
metric_r = json.loads(json.loads(rhs['finalMetricData'][0]['data']))
if isinstance(metric_l, float):
return metric_l - metric_r
elif isinstance(metric_l, dict):
return metric_l['default'] - metric_r['default']
else:
print_error('Unexpected data format. Please check your data.')
raise ValueError

if args.head and args.tail:
print_error('Head and tail cannot be set at the same time.')
return
nni_config = Config(get_config_filename(args))
rest_port = nni_config.get_config('restServerPort')
rest_pid = nni_config.get_config('restServerPid')
Expand All @@ -259,6 +276,14 @@ def trial_ls(args):
response = rest_get(trial_jobs_url(rest_port), REST_TIME_OUT)
if response and check_response(response):
content = json.loads(response.text)
if args.head:
assert int(args.head) > 0, 'The number of requested data must be greater than 0.'
args.head = min(int(args.head), len(content))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why use int(args.head)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I don't know if specify the arg's type as int, it will be converted to int.

content = sorted(content, key=cmp_to_key(final_metric_data_cmp), reverse=True)[:args.head]
elif args.tail:
assert int(args.tail) > 0, 'The number of requested data must be greater than 0.'
args.tail = min(int(args.tail), len(content))
content = sorted(content, key=cmp_to_key(final_metric_data_cmp))[:args.tail]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no filter here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

for index, value in enumerate(content):
content[index] = convert_time_stamp_to_date(value)
print(json.dumps(content, indent=4, sort_keys=True, separators=(',', ':')))
Expand Down
1 change: 0 additions & 1 deletion tools/nni_cmd/url_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

TENSORBOARD_API = '/tensorboard'


def check_status_url(port):
'''get check_status url'''
return '{0}:{1}{2}{3}'.format(BASE_URL, port, API_ROOT_URL, CHECK_STATUS_API)
Expand Down