This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Add "final_extra_valid_opt_filepath" to train_model.py #3883

Merged · 7 commits · Aug 5, 2021
83 changes: 74 additions & 9 deletions parlai/scripts/train_model.py
@@ -23,7 +23,7 @@

# TODO List:
# * More logging (e.g. to files), make things prettier.

import copy
import json
import numpy as np
import signal
@@ -87,6 +87,12 @@ def setup_args(parser=None) -> ParlaiParser:
        '--evaltask',
        help='task to use for valid/test (defaults to the one used for training)',
    )
    train.add_argument(
        '--final-extra-opt',
        type=str,
        default='',
        help="A '.opt' file that is used for final eval. Useful for setting skip-generation to false. 'datatype' must be included as part of the opt.",
    )
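The new flag expects a ParlAI opt file on disk. A minimal sketch of producing one, mirroring the unit test added below; the path and the skip_generation override are illustrative assumptions:

```python
# Illustrative sketch: build and save the '.opt' file that '--final-extra-opt'
# points at. 'datatype' is required; any other keys simply override the
# training opt for the final eval.
from parlai.core.opt import Opt

extra_eval_opt = Opt(
    {
        'datatype': 'valid',       # required, per the help string above
        'skip_generation': False,  # e.g. re-enable generation for the final eval
    }
)
extra_eval_opt.save('/tmp/final_eval.opt')  # pass this path to --final-extra-opt
```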
    train.add_argument(
        '--eval-batchsize',
        type=int,
@@ -391,6 +397,9 @@ def __init__(self, opt):
        self.valid_optim = 1 if opt['validation_metric_mode'] == 'max' else -1
        self.train_reports = []
        self.valid_reports = []
        self.final_valid_report = {}
        self.final_test_report = {}
        self.final_extra_valid_report = {}
        self.best_valid = None

        self.impatience = 0
@@ -466,7 +475,9 @@ def save_model(self, suffix=None):
                pass

    def _save_train_stats(self, suffix=None):
        fn = self.opt['model_file']
        fn = self.opt.get('model_file', None)
        if not fn:
            return
        if suffix:
            fn += suffix
        fn += '.trainstats'
@@ -481,6 +492,11 @@ def _save_train_stats(self, suffix=None):
                    'valid_reports': self.valid_reports,
                    'best_valid': self.best_valid,
                    'impatience': self.impatience,
                    'final_valid_report': dict_report(self.final_valid_report),
                    'final_test_report': dict_report(self.final_test_report),
                    'final_extra_valid_report': dict_report(
                        self.final_extra_valid_report
                    ),
                },
                f,
                indent=4,
@@ -602,7 +618,15 @@ def _run_single_eval(self, opt, valid_world, max_exs):

        return valid_report

    def _run_eval(self, valid_worlds, opt, datatype, max_exs=-1, write_log=False):
    def _run_eval(
        self,
        valid_worlds,
        opt,
        datatype,
        max_exs=-1,
        write_log=False,
        extra_log_suffix="",
    ):
"""
Eval on validation/test data.

@@ -642,11 +666,40 @@ def _run_eval(self, valid_worlds, opt, datatype, max_exs=-1, write_log=False):
        # write to file
        if write_log and opt.get('model_file') and is_primary_worker():
            # Write out metrics
            with PathManager.open(opt['model_file'] + '.' + datatype, 'a') as f:
            with PathManager.open(
                opt['model_file'] + extra_log_suffix + '.' + datatype, 'a'
            ) as f:
                f.write(f'{metrics}\n')

        return report

    def _run_final_extra_eval(self, opt):
        final_valid_opt = copy.deepcopy(opt)
        final_valid_opt_raw = Opt.load_init(opt['final_extra_opt'])
        final_datatype = final_valid_opt_raw["datatype"]
        for k, v in final_valid_opt_raw.items():
            final_valid_opt[k] = v
        final_max_exs = (
            final_valid_opt['validation_max_exs']
            if final_valid_opt.get('short_final_eval')
            else -1
        )
        final_valid_world = load_eval_worlds(
            self.agent, final_valid_opt, final_datatype
        )
        final_valid_report = self._run_eval(
            final_valid_world,
            final_valid_opt,
            final_datatype,
            final_max_exs,
            write_log=True,
            extra_log_suffix="_extra",
        )
        if opt['wandb_log'] and is_primary_worker():
            self.wb_logger.log_final(final_datatype, final_valid_report)

        return final_valid_report
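For reference, a hedged end-to-end sketch of exercising the new flag programmatically, patterned on the unit test added below; the task, model, and paths are placeholders:

```python
# Illustrative sketch: short training run that requests an extra final eval
# via '--final-extra-opt'. All argument values are placeholders.
import parlai.scripts.train_model as tms

parser = tms.setup_args()
opt = parser.parse_args(
    [
        '--task', 'integration_tests',
        '--model', 'repeat_label',
        '--model-file', '/tmp/model',
        '--num-epochs', '1.0',
        '--final-extra-opt', '/tmp/final_eval.opt',
    ]
)
valid_report, test_report = tms.TrainLoop(opt).train()
```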

    def _sync_metrics(self, metrics):
        """
        Sync training metrics across workers.
@@ -901,13 +954,17 @@ def train(self):
        # perform final validation/testing
        valid_worlds = load_eval_worlds(self.agent, opt, 'valid')
        max_exs = opt['validation_max_exs'] if opt.get('short_final_eval') else -1
        v_report = self._run_eval(valid_worlds, opt, 'valid', max_exs, write_log=True)
        self.final_valid_report = self._run_eval(
            valid_worlds, opt, 'valid', max_exs, write_log=True
        )
        test_worlds = load_eval_worlds(self.agent, opt, 'test')
        t_report = self._run_eval(test_worlds, opt, 'test', max_exs, write_log=True)
        self.final_test_report = self._run_eval(
            test_worlds, opt, 'test', max_exs, write_log=True
        )

        if opt['wandb_log'] and is_primary_worker():
            self.wb_logger.log_final('valid', v_report)
            self.wb_logger.log_final('test', t_report)
            self.wb_logger.log_final('valid', self.final_valid_report)
            self.wb_logger.log_final('test', self.final_test_report)
            self.wb_logger.finish()

        if valid_worlds:
@@ -919,7 +976,15 @@

        print_announcements(opt)

        return v_report, t_report
        if opt['final_extra_opt'] != '':
            self.final_extra_valid_report = self._run_final_extra_eval(opt)

        if opt['wandb_log'] and is_primary_worker():
            self.wb_logger.finish()

        self._save_train_stats()

        return self.final_valid_report, self.final_test_report


@register_script('train_model', aliases=['tm', 'train'])
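With these changes, the final reports also land in the '<model_file>.trainstats' JSON written by _save_train_stats, so they can be read back after training, as the new test below verifies. A small readback sketch (the path is illustrative):

```python
# Illustrative sketch: read the final eval reports back from the
# '.trainstats' file written by _save_train_stats.
import json

with open('/tmp/model.trainstats') as f:
    stats = json.load(f)

print(stats['final_valid_report'])        # final 'valid' eval
print(stats['final_test_report'])         # final 'test' eval
print(stats['final_extra_valid_report'])  # eval driven by --final-extra-opt
```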
59 changes: 58 additions & 1 deletion tests/test_train_model.py
@@ -13,12 +13,69 @@
import parlai.utils.testing as testing_utils
from parlai.core.metrics import AverageMetric
from parlai.core.worlds import create_task
from parlai.core.opt import Opt
from parlai.core.params import ParlaiParser
from parlai.core.agents import register_agent, Agent
from parlai.utils.data import DatatypeHelper


class TestTrainModel(unittest.TestCase):
    def test_final_extra_eval_and_save_json(self):
"""
Test "final_extra_valid_opt_filepath". Happens to test that saving reports as
json works too.

We copy train_model from testing_utils to directly access train loop.
"""
import parlai.scripts.train_model as tms

        def get_tl(tmpdir):
            final_opt = Opt(
                {
                    'task': 'integration_tests',
                    'datatype': 'valid',
                    'validation_max_exs': 30,
                    'short_final_eval': True,
                }
            )
            final_opt.save(os.path.join(tmpdir, "final_opt.opt"))

            opt = Opt(
                {
                    'task': 'integration_tests',
                    'validation_max_exs': 10,
                    'model': 'repeat_label',
                    'model_file': os.path.join(tmpdir, 'model'),
                    'short_final_eval': True,
                    'num_epochs': 1.0,
                    'final_extra_opt': str(os.path.join(tmpdir, "final_opt.opt")),
                }
            )
            parser = tms.setup_args()
            parser.set_params(**opt)
            popt = parser.parse_args([])
            for k, v in opt.items():
                popt[k] = v
            return tms.TrainLoop(popt)

        with testing_utils.capture_output(), testing_utils.tempdir() as tmpdir:
            tl = get_tl(tmpdir)
            _, _ = tl.train()

            with open(os.path.join(tmpdir, 'model.trainstats')) as f:
                data = json.load(f)
                print(data)
                self.assertEqual(
                    data["final_valid_report"]["exs"],
                    10,
                    "Validation exs saved incorrectly",
                )

                self.assertEqual(
                    data["final_extra_valid_report"]["exs"],
                    30,
                    "Final validation exs saved incorrectly",
                )

    def test_fast_final_eval(self):
        valid, test = testing_utils.train_model(
            {