From cef1bc704e7a62736ba643a8df2b21577e7e4f24 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Thu, 29 Jul 2021 21:46:51 -0700 Subject: [PATCH 1/6] Add "final_extra_valid_opt_filepath" to train_model.py Also a way to configure saving outputs to json. Tested with a test. (Not super attached to names or to the default values I've got set. I tried using 'Optional[str]' for the cmdline arg, but the test was not happy with that...) --- parlai/scripts/train_model.py | 58 ++++++++++++++++++++++++++++++++--- tests/test_train_model.py | 57 ++++++++++++++++++++++++++++++++++ 2 files changed, 111 insertions(+), 4 deletions(-) diff --git a/parlai/scripts/train_model.py b/parlai/scripts/train_model.py index 64a3d8c92b3..117050eb65a 100644 --- a/parlai/scripts/train_model.py +++ b/parlai/scripts/train_model.py @@ -23,7 +23,7 @@ # TODO List: # * More logging (e.g. to files), make things prettier. - +import copy import json import numpy as np import signal @@ -87,6 +87,18 @@ def setup_args(parser=None) -> ParlaiParser: '--evaltask', help='task to use for valid/test (defaults to the one used for training)', ) + train.add_argument( + '--final-extra-valid-opt-filepath', + type=str, + default='', + help="A '.opt' file that is used for final eval. Useful for setting skip-generation to false. 'datatype' must be included as part of the opt.", + ) + train.add_argument( + '--write-log-as-json', + type=bool, + default=False, + help="Write metrics log as json rather than a pretty print of the report.", + ) train.add_argument( '--eval-batchsize', type=int, @@ -602,7 +614,9 @@ def _run_single_eval(self, opt, valid_world, max_exs): return valid_report - def _run_eval(self, valid_worlds, opt, datatype, max_exs=-1, write_log=False): + def _run_eval( + self, valid_worlds, opt, datatype, max_exs=-1, write_log=False, log_suffix="" + ): """ Eval on validation/test data. @@ -642,8 +656,16 @@ def _run_eval(self, valid_worlds, opt, datatype, max_exs=-1, write_log=False): # write to file if write_log and opt.get('model_file') and is_primary_worker(): # Write out metrics - with PathManager.open(opt['model_file'] + '.' + datatype, 'a') as f: - f.write(f'{metrics}\n') + if opt["write_log_as_json"]: + with PathManager.open( + opt['model_file'] + log_suffix + '.' + datatype + ".json", 'a' + ) as f: + json.dump(dict_report(report), f) + else: + with PathManager.open( + opt['model_file'] + log_suffix + '.' + datatype, 'a' + ) as f: + f.write(f'{metrics}\n') return report @@ -919,6 +941,34 @@ def train(self): print_announcements(opt) + if opt['final_extra_valid_opt_filepath'] is not '': + final_valid_opt = copy.deepcopy(opt) + final_valid_opt_raw = Opt.load_init(opt['final_extra_valid_opt_filepath']) + final_datatype = final_valid_opt_raw["datatype"] + for k, v in final_valid_opt_raw.items(): + final_valid_opt[k] = v + final_max_exs = ( + final_valid_opt['validation_max_exs'] + if final_valid_opt.get('short_final_eval') + else -1 + ) + final_valid_world = load_eval_worlds( + self.agent, final_valid_opt, final_datatype + ) + final_valid_report = self._run_eval( + final_valid_world, + final_valid_opt, + final_datatype, + final_max_exs, + write_log=True, + log_suffix="_final", + ) + if opt['wandb_log'] and is_primary_worker(): + self.wb_logger.log_final(final_datatype, final_valid_report) + + if opt['wandb_log'] and is_primary_worker(): + self.wb_logger.finish() + return v_report, t_report diff --git a/tests/test_train_model.py b/tests/test_train_model.py index b15c3b1b188..be11307c4a7 100644 --- a/tests/test_train_model.py +++ b/tests/test_train_model.py @@ -13,12 +13,69 @@ import parlai.utils.testing as testing_utils from parlai.core.metrics import AverageMetric from parlai.core.worlds import create_task +from parlai.core.opt import Opt from parlai.core.params import ParlaiParser from parlai.core.agents import register_agent, Agent from parlai.utils.data import DatatypeHelper class TestTrainModel(unittest.TestCase): + def test_final_extra_eval_and_save_json(self): + """ + Test "final_extra_valid_opt_filepath". Happens to test that saving reports as + json works too. + + We copy train_model from testing_utils to directly access train loop. + """ + import parlai.scripts.train_model as tms + + def get_tl(tmpdir): + final_opt = Opt( + { + 'task': 'integration_tests', + 'datatype': 'valid', + 'validation_max_exs': 30, + 'short_final_eval': True, + 'write_log_as_json': True, + } + ) + final_opt.save(os.path.join(tmpdir, "final_opt.opt")) + + opt = Opt( + { + 'task': 'integration_tests', + 'validation_max_exs': 10, + 'model': 'repeat_label', + 'model_file': os.path.join(tmpdir, 'model'), + 'short_final_eval': True, + 'num_epochs': 1.0, + 'write_log_as_json': True, + 'final_extra_valid_opt_filepath': str( + os.path.join(tmpdir, "final_opt.opt") + ), + } + ) + parser = tms.setup_args() + parser.set_params(**opt) + popt = parser.parse_args([]) + for k, v in opt.items(): + popt[k] = v + return tms.TrainLoop(popt) + + with testing_utils.capture_output(), testing_utils.tempdir() as tmpdir: + tl = get_tl(tmpdir) + _, _ = tl.train() + + with open(os.path.join(tmpdir, 'model.valid.json')) as f: + self.assertEqual( + json.load(f)['exs'], 10, "Validation exs saved incorrectly" + ) + + with open(os.path.join(tmpdir, 'model_final.valid.json')) as f: + self.assertEqual( + json.load(f)['exs'], 30, "Final validation exs saved incorrectly" + ) + def test_fast_final_eval(self): valid, test = testing_utils.train_model( { From 0589e687737741a3ba170096a9f554cca10463cc Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Tue, 3 Aug 2021 17:17:29 -0700 Subject: [PATCH 2/6] Make requested changes --- parlai/scripts/train_model.py | 78 +++++++++++++++++------------------ tests/test_train_model.py | 6 +-- 2 files changed, 38 insertions(+), 46 deletions(-) diff --git a/parlai/scripts/train_model.py b/parlai/scripts/train_model.py index 117050eb65a..ab90df1ab90 100644 --- a/parlai/scripts/train_model.py +++ b/parlai/scripts/train_model.py @@ -88,17 +88,11 @@ def setup_args(parser=None) -> ParlaiParser: help='task to use for valid/test (defaults to the one used for training)', ) train.add_argument( - '--final-extra-valid-opt-filepath', + '--final-extra-opt', type=str, default='', help="A '.opt' file that is used for final eval. Useful for setting skip-generation to false. 'datatype' must be included as part of the opt.", ) - train.add_argument( - '--write-log-as-json', - type=bool, - default=False, - help="Write metrics log as json rather than a pretty print of the report.", - ) train.add_argument( '--eval-batchsize', type=int, @@ -656,19 +650,43 @@ def _run_eval( # write to file if write_log and opt.get('model_file') and is_primary_worker(): # Write out metrics - if opt["write_log_as_json"]: - with PathManager.open( - opt['model_file'] + log_suffix + '.' + datatype + ".json", 'a' - ) as f: - json.dump(dict_report(report), f) - else: - with PathManager.open( - opt['model_file'] + log_suffix + '.' + datatype, 'a' - ) as f: - f.write(f'{metrics}\n') + with PathManager.open( + opt['model_file'] + log_suffix + '.' + datatype + ".json", 'a' + ) as f: + json.dump(dict_report(report), f) + + with PathManager.open( + opt['model_file'] + log_suffix + '.' + datatype, 'a' + ) as f: + f.write(f'{metrics}\n') return report + def _run_final_extra_eval(self, opt): + final_valid_opt = copy.deepcopy(opt) + final_valid_opt_raw = Opt.load_init(opt['final_extra_opt']) + final_datatype = final_valid_opt_raw["datatype"] + for k, v in final_valid_opt_raw.items(): + final_valid_opt[k] = v + final_max_exs = ( + final_valid_opt['validation_max_exs'] + if final_valid_opt.get('short_final_eval') + else -1 + ) + final_valid_world = load_eval_worlds( + self.agent, final_valid_opt, final_datatype + ) + final_valid_report = self._run_eval( + final_valid_world, + final_valid_opt, + final_datatype, + final_max_exs, + write_log=True, + log_suffix="_final", + ) + if opt['wandb_log'] and is_primary_worker(): + self.wb_logger.log_final(final_datatype, final_valid_report) + def _sync_metrics(self, metrics): """ Sync training metrics across workers. @@ -941,30 +959,8 @@ def train(self): print_announcements(opt) - if opt['final_extra_valid_opt_filepath'] is not '': - final_valid_opt = copy.deepcopy(opt) - final_valid_opt_raw = Opt.load_init(opt['final_extra_valid_opt_filepath']) - final_datatype = final_valid_opt_raw["datatype"] - for k, v in final_valid_opt_raw.items(): - final_valid_opt[k] = v - final_max_exs = ( - final_valid_opt['validation_max_exs'] - if final_valid_opt.get('short_final_eval') - else -1 - ) - final_valid_world = load_eval_worlds( - self.agent, final_valid_opt, final_datatype - ) - final_valid_report = self._run_eval( - final_valid_world, - final_valid_opt, - final_datatype, - final_max_exs, - write_log=True, - log_suffix="_final", - ) - if opt['wandb_log'] and is_primary_worker(): - self.wb_logger.log_final(final_datatype, final_valid_report) + if opt['final_extra_opt'] is not '': + self._run_final_extra_eval(opt) if opt['wandb_log'] and is_primary_worker(): self.wb_logger.finish() diff --git a/tests/test_train_model.py b/tests/test_train_model.py index be11307c4a7..699856ee265 100644 --- a/tests/test_train_model.py +++ b/tests/test_train_model.py @@ -36,7 +36,6 @@ def get_tl(tmpdir): 'datatype': 'valid', 'validation_max_exs': 30, 'short_final_eval': True, - 'write_log_as_json': True, } ) final_opt.save(os.path.join(tmpdir, "final_opt.opt")) @@ -49,10 +48,7 @@ def get_tl(tmpdir): 'model_file': os.path.join(tmpdir, 'model'), 'short_final_eval': True, 'num_epochs': 1.0, - 'write_log_as_json': True, - 'final_extra_valid_opt_filepath': str( - os.path.join(tmpdir, "final_opt.opt") - ), + 'final_extra_opt': str(os.path.join(tmpdir, "final_opt.opt")), } ) parser = tms.setup_args() From bd27b3ac572f34cfea5659b9dba19287df7bd46d Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Thu, 5 Aug 2021 10:18:10 -0700 Subject: [PATCH 3/6] save final reports to .trainstats --- parlai/scripts/train_model.py | 26 ++++++++++++++++---------- tests/test_train_model.py | 12 ++++++++---- 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/parlai/scripts/train_model.py b/parlai/scripts/train_model.py index ab90df1ab90..4ff473c9493 100644 --- a/parlai/scripts/train_model.py +++ b/parlai/scripts/train_model.py @@ -397,6 +397,9 @@ def __init__(self, opt): self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1 self.train_reports = [] self.valid_reports = [] + self.final_valid_report = {} + self.final_test_report = {} + self.final_extra_valid_report = {} self.best_valid = None self.impatience = 0 @@ -487,6 +490,9 @@ def _save_train_stats(self, suffix=None): 'valid_reports': self.valid_reports, 'best_valid': self.best_valid, 'impatience': self.impatience, + 'final_valid_report': self.final_valid_report, + 'final_test_report': self.final_test_report, + 'final_extra_valid_report': self.final_extra_valid_report, }, f, indent=4, @@ -650,11 +656,6 @@ def _run_eval( # write to file if write_log and opt.get('model_file') and is_primary_worker(): # Write out metrics - with PathManager.open( - opt['model_file'] + log_suffix + '.' + datatype + ".json", 'a' - ) as f: - json.dump(dict_report(report), f) - with PathManager.open( opt['model_file'] + log_suffix + '.' + datatype, 'a' ) as f: @@ -682,11 +683,12 @@ def _run_final_extra_eval(self, opt): final_datatype, final_max_exs, write_log=True, - log_suffix="_final", ) if opt['wandb_log'] and is_primary_worker(): self.wb_logger.log_final(final_datatype, final_valid_report) + return final_valid_report + def _sync_metrics(self, metrics): """ Sync training metrics across workers. @@ -941,9 +943,13 @@ def train(self): # perform final validation/testing valid_worlds = load_eval_worlds(self.agent, opt, 'valid') max_exs = opt['validation_max_exs'] if opt.get('short_final_eval') else -1 - v_report = self._run_eval(valid_worlds, opt, 'valid', max_exs, write_log=True) + self.final_valid_report = self._run_eval( + valid_worlds, opt, 'valid', max_exs, write_log=True + ) test_worlds = load_eval_worlds(self.agent, opt, 'test') - t_report = self._run_eval(test_worlds, opt, 'test', max_exs, write_log=True) + self.final_test_report = self._run_eval( + test_worlds, opt, 'test', max_exs, write_log=True + ) if opt['wandb_log'] and is_primary_worker(): self.wb_logger.log_final('valid', v_report) @@ -960,12 +966,12 @@ def train(self): print_announcements(opt) if opt['final_extra_opt'] is not '': - self._run_final_extra_eval(opt) + self.final_extra_valid_report = self._run_final_extra_eval(opt) if opt['wandb_log'] and is_primary_worker(): self.wb_logger.finish() - return v_report, t_report + return self.final_valid_report, self.final_test_report @register_script('train_model', aliases=['tm', 'train']) diff --git a/tests/test_train_model.py b/tests/test_train_model.py index 699856ee265..c37182b4acb 100644 --- a/tests/test_train_model.py +++ b/tests/test_train_model.py @@ -62,14 +62,18 @@ def get_tl(tmpdir): tl = get_tl(tmpdir) _, _ = tl.train() - with open(os.path.join(tmpdir, 'model.valid.json')) as f: + with open(os.path.join(tmpdir, 'model.trainstats')) as f: + data = json.load(f) self.assertEqual( - json.load(f)['exs'], 10, "Validation exs saved incorrectly" + data["final_valid_report"]["exs"], + 10, + "Validation exs saved incorrectly", ) - with open(os.path.join(tmpdir, 'model_final.valid.json')) as f: self.assertEqual( - json.load(f)['exs'], 30, "Final validation exs saved incorrectly" + data["final_extra_valid_report"]["exs"], + 30, + "Final validation exs saved incorrectly", ) def test_fast_final_eval(self): From b15624056b53347a36eaa6ad4ab249c013e5b432 Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Thu, 5 Aug 2021 10:22:23 -0700 Subject: [PATCH 4/6] delete unnecessary lines --- parlai/scripts/train_model.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/parlai/scripts/train_model.py b/parlai/scripts/train_model.py index 4ff473c9493..4092e72edbb 100644 --- a/parlai/scripts/train_model.py +++ b/parlai/scripts/train_model.py @@ -614,9 +614,7 @@ def _run_single_eval(self, opt, valid_world, max_exs): return valid_report - def _run_eval( - self, valid_worlds, opt, datatype, max_exs=-1, write_log=False, log_suffix="" - ): + def _run_eval(self, valid_worlds, opt, datatype, max_exs=-1, write_log=False): """ Eval on validation/test data. @@ -656,9 +654,7 @@ def _run_eval( # write to file if write_log and opt.get('model_file') and is_primary_worker(): # Write out metrics - with PathManager.open( - opt['model_file'] + log_suffix + '.' + datatype, 'a' - ) as f: + with PathManager.open(opt['model_file'] + '.' + datatype, 'a') as f: f.write(f'{metrics}\n') return report From a0368e960d59efffeacf00b5bd743bf48f75a57c Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Thu, 5 Aug 2021 10:53:37 -0700 Subject: [PATCH 5/6] run tests + linter --- parlai/scripts/train_model.py | 16 ++++++++++------ tests/test_train_model.py | 2 +- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/parlai/scripts/train_model.py b/parlai/scripts/train_model.py index 4092e72edbb..71f7c492e2b 100644 --- a/parlai/scripts/train_model.py +++ b/parlai/scripts/train_model.py @@ -490,9 +490,11 @@ def _save_train_stats(self, suffix=None): 'valid_reports': self.valid_reports, 'best_valid': self.best_valid, 'impatience': self.impatience, - 'final_valid_report': self.final_valid_report, - 'final_test_report': self.final_test_report, - 'final_extra_valid_report': self.final_extra_valid_report, + 'final_valid_report': dict_report(self.final_valid_report), + 'final_test_report': dict_report(self.final_test_report), + 'final_extra_valid_report': dict_report( + self.final_extra_valid_report + ), }, f, indent=4, @@ -948,8 +950,8 @@ def train(self): ) if opt['wandb_log'] and is_primary_worker(): - self.wb_logger.log_final('valid', v_report) - self.wb_logger.log_final('test', t_report) + self.wb_logger.log_final('valid', self.final_valid_report) + self.wb_logger.log_final('test', self.final_test_report) self.wb_logger.finish() if valid_worlds: @@ -961,12 +963,14 @@ def train(self): print_announcements(opt) - if opt['final_extra_opt'] is not '': + if opt['final_extra_opt'] != '': self.final_extra_valid_report = self._run_final_extra_eval(opt) if opt['wandb_log'] and is_primary_worker(): self.wb_logger.finish() + self._save_train_stats() + return self.final_valid_report, self.final_test_report diff --git a/tests/test_train_model.py b/tests/test_train_model.py index c37182b4acb..00fc1e77600 100644 --- a/tests/test_train_model.py +++ b/tests/test_train_model.py @@ -16,7 +16,6 @@ from parlai.core.opt import Opt from parlai.core.params import ParlaiParser from parlai.core.agents import register_agent, Agent -from parlai.utils.data import DatatypeHelper class TestTrainModel(unittest.TestCase): @@ -64,6 +63,7 @@ def get_tl(tmpdir): with open(os.path.join(tmpdir, 'model.trainstats')) as f: data = json.load(f) + print(data) self.assertEqual( data["final_valid_report"]["exs"], 10, From 8ce7b933accbe2745c4fbe6b243f69356556ba5a Mon Sep 17 00:00:00 2001 From: Moya Chen Date: Thu, 5 Aug 2021 13:25:01 -0700 Subject: [PATCH 6/6] Fix that test since it is actually relevant --- parlai/scripts/train_model.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/parlai/scripts/train_model.py b/parlai/scripts/train_model.py index 71f7c492e2b..ddcde9c8857 100644 --- a/parlai/scripts/train_model.py +++ b/parlai/scripts/train_model.py @@ -475,7 +475,9 @@ def save_model(self, suffix=None): pass def _save_train_stats(self, suffix=None): - fn = self.opt['model_file'] + fn = self.opt.get('model_file', None) + if not fn: + return if suffix: fn += suffix fn += '.trainstats' @@ -616,7 +618,15 @@ def _run_single_eval(self, opt, valid_world, max_exs): return valid_report - def _run_eval(self, valid_worlds, opt, datatype, max_exs=-1, write_log=False): + def _run_eval( + self, + valid_worlds, + opt, + datatype, + max_exs=-1, + write_log=False, + extra_log_suffix="", + ): """ Eval on validation/test data. @@ -656,7 +666,9 @@ def _run_eval(self, valid_worlds, opt, datatype, max_exs=-1, write_log=False): # write to file if write_log and opt.get('model_file') and is_primary_worker(): # Write out metrics - with PathManager.open(opt['model_file'] + '.' + datatype, 'a') as f: + with PathManager.open( + opt['model_file'] + extra_log_suffix + '.' + datatype, 'a' + ) as f: f.write(f'{metrics}\n') return report @@ -681,6 +693,7 @@ def _run_final_extra_eval(self, opt): final_datatype, final_max_exs, write_log=True, + extra_log_suffix="_extra", ) if opt['wandb_log'] and is_primary_worker(): self.wb_logger.log_final(final_datatype, final_valid_report)