Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
Add TB logging to parlai eval_model script (#4497)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kirill Syomin authored Apr 29, 2022
1 parent 4291c8a commit a48c0db
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions parlai/scripts/eval_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,32 @@ def get_task_world_logs(task, world_logs, is_multitask=False):
return f'{base_outfile}_{task}{extension}'


def prepare_tb_logger(opt):
if opt['tensorboard_log'] and is_primary_worker():
tb_logger = TensorboardLogger(opt)
else:
tb_logger = None

if 'train' in opt['datatype']:
setting = 'train'
elif 'valid' in opt['datatype']:
setting = 'valid'
else:
setting = 'test'
return tb_logger, setting


def get_n_parleys(opt):
trainstats_suffix = '.trainstats'
if opt.get('model_file') and PathManager.exists(opt['model_file'] + trainstats_suffix):
with PathManager.open(opt['model_file'] + trainstats_suffix) as ts:
obj = json.load(ts)
parleys = obj.get('parleys', 0)
else:
parleys = 0
return parleys


def _eval_single_world(opt, agent, task):
logging.info(f'Evaluating task {task} using datatype {opt.get("datatype")}.')
# set up world logger
Expand Down Expand Up @@ -233,6 +259,11 @@ def eval_model(opt):
agent = create_agent(opt, requireModelExists=True)
agent.opt.log()

tb_logger, setting = prepare_tb_logger(opt)

if tb_logger:
n_parleys = get_n_parleys(opt)

tasks = opt['task'].split(',')
reports = []
for task in tasks:
Expand All @@ -252,6 +283,9 @@ def eval_model(opt):

print(nice_report(report))
_save_eval_stats(opt, report)
if tb_logger:
tb_logger.log_metrics(setting, n_parleys, report)
tb_logger.flush()
return report


Expand Down

0 comments on commit a48c0db

Please sign in to comment.