diff --git a/test/nni_test/nnitest/run_tests.py b/test/nni_test/nnitest/run_tests.py index 054991c1c5..2b8c4c2e29 100644 --- a/test/nni_test/nnitest/run_tests.py +++ b/test/nni_test/nnitest/run_tests.py @@ -70,7 +70,7 @@ def run_test_case(test_case_config, it_config, args): try: launch_test(new_config_file, args.ts, test_case_config) - invoke_validator(test_case_config, args.nni_source_dir) + invoke_validator(test_case_config, args.nni_source_dir, args.ts) finally: stop_command = get_command(test_case_config, 'stopCommand') print('Stop command:', stop_command, flush=True) @@ -80,7 +80,7 @@ def run_test_case(test_case_config, it_config, args): if os.path.exists(new_config_file): os.remove(new_config_file) -def invoke_validator(test_case_config, nni_source_dir): +def invoke_validator(test_case_config, nni_source_dir, training_service): validator_config = test_case_config.get('validator') if validator_config is None or validator_config.get('class') is None: return @@ -88,7 +88,13 @@ def invoke_validator(test_case_config, nni_source_dir): validator = validators.__dict__[validator_config.get('class')]() kwargs = validator_config.get('kwargs', {}) print('kwargs:', kwargs) - validator(REST_ENDPOINT, get_experiment_dir(EXPERIMENT_URL), nni_source_dir, **kwargs) + experiment_id = get_experiment_id(EXPERIMENT_URL) + try: + validator(REST_ENDPOINT, get_experiment_dir(EXPERIMENT_URL), nni_source_dir, **kwargs) + except: + print_experiment_log(experiment_id=experiment_id) + print_trial_job_log(training_service, TRIAL_JOBS_URL) + raise def get_max_values(config_file): experiment_config = get_yml_content(config_file) @@ -117,7 +123,7 @@ def launch_test(config_file, training_service, test_case_config): proc = subprocess.run(shlex.split(launch_command)) - assert proc.returncode == 0, '`nnictl create` failed with code %d' % proc.returncode + assert proc.returncode == 0, 'launch command failed with code %d' % proc.returncode # set experiment ID into variable exp_var_name = test_case_config.get('setExperimentIdtoVar') @@ -134,24 +140,30 @@ def launch_test(config_file, training_service, test_case_config): bg_time = time.time() print(str(datetime.datetime.now()), ' waiting ...', flush=True) - while True: + try: + # wait restful server to be ready time.sleep(3) - waited_time = time.time() - bg_time - if waited_time > max_duration + 10: - print('waited: {}, max_duration: {}'.format(waited_time, max_duration)) - break - status = get_experiment_status(STATUS_URL) - if status in ['DONE', 'ERROR']: - print('experiment status:', status) - break - num_failed = len(get_failed_trial_jobs(TRIAL_JOBS_URL)) - if num_failed > 0: - print('failed jobs: ', num_failed) - break - + experiment_id = get_experiment_id(EXPERIMENT_URL) + while True: + waited_time = time.time() - bg_time + if waited_time > max_duration + 10: + print('waited: {}, max_duration: {}'.format(waited_time, max_duration)) + break + status = get_experiment_status(STATUS_URL) + if status in ['DONE', 'ERROR']: + print('experiment status:', status) + break + num_failed = len(get_failed_trial_jobs(TRIAL_JOBS_URL)) + if num_failed > 0: + print('failed jobs: ', num_failed) + break + time.sleep(3) + except: + print_experiment_log(experiment_id=experiment_id) + raise print(str(datetime.datetime.now()), ' waiting done', flush=True) if get_experiment_status(STATUS_URL) == 'ERROR': - print_experiment_log(EXPERIMENT_URL) + print_experiment_log(experiment_id=experiment_id) trial_stats = get_trial_stats(TRIAL_JOBS_URL) print(json.dumps(trial_stats, indent=4), flush=True) diff --git a/test/nni_test/nnitest/utils.py b/test/nni_test/nnitest/utils.py index c5a6c05ca2..e362b0de07 100644 --- a/test/nni_test/nnitest/utils.py +++ b/test/nni_test/nnitest/utils.py @@ -10,6 +10,7 @@ import requests import time import ruamel.yaml as yaml +import shlex EXPERIMENT_DONE_SIGNAL = 'Experiment done' @@ -65,14 +66,16 @@ def get_experiment_id(experiment_url): experiment_id = requests.get(experiment_url).json()['id'] return experiment_id -def get_experiment_dir(experiment_url): +def get_experiment_dir(experiment_url=None, experiment_id=None): '''get experiment root directory''' - experiment_id = get_experiment_id(experiment_url) + assert any([experiment_url, experiment_id]) + if experiment_id is None: + experiment_id = get_experiment_id(experiment_url) return os.path.join(os.path.expanduser('~'), 'nni', 'experiments', experiment_id) -def get_nni_log_dir(experiment_url): +def get_nni_log_dir(experiment_url=None, experiment_id=None): '''get nni's log directory from nni's experiment url''' - return os.path.join(get_experiment_dir(experiment_url), 'log') + return os.path.join(get_experiment_dir(experiment_url, experiment_id), 'log') def get_nni_log_path(experiment_url): '''get nni's log path from nni's experiment url''' @@ -125,12 +128,17 @@ def print_trial_job_log(training_service, trial_jobs_url): for log_file in log_files: print_file_content(os.path.join(trial_log_dir, log_file)) -def print_experiment_log(experiment_url): - log_dir = get_nni_log_dir(experiment_url) +def print_experiment_log(experiment_id): + log_dir = get_nni_log_dir(experiment_id=experiment_id) for log_file in ['dispatcher.log', 'nnimanager.log']: filepath = os.path.join(log_dir, log_file) print_file_content(filepath) + print('nnictl log stderr:') + subprocess.run(shlex.split('nnictl log stderr {}'.format(experiment_id))) + print('nnictl log stdout:') + subprocess.run(shlex.split('nnictl log stdout {}'.format(experiment_id))) + def parse_max_duration_time(max_exec_duration): unit = max_exec_duration[-1] time = max_exec_duration[:-1]