This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

[eval model] store world logs per task #3718

Merged (3 commits) on Jul 7, 2021
22 changes: 18 additions & 4 deletions parlai/scripts/eval_model.py
@@ -118,13 +118,27 @@ def _save_eval_stats(opt, report):
        f.write("\n")  # for jq


+def get_task_world_logs(task, world_logs, is_multitask=False):
+    if not is_multitask:
+        return world_logs
+    else:
+        base_outfile, extension = os.path.splitext(world_logs)
+        return f'{base_outfile}_{task}{extension}'
kauterry (Contributor) commented on Jul 8, 2021:

Shouldn't it be this?

return f'{base_outfile}_{task}.{extension}'

Contributor replied:

No, splitext includes the '.' in the ext.

Contributor replied:

My bad, sorry!

Author (Contributor) replied:

the extension contains "."
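For reference (not part of the diff), a quick check of the stdlib behavior the thread settles on: os.path.splitext keeps the dot in the extension, so no extra '.' is needed in the f-string.

import os

base, ext = os.path.splitext('world_logs.jsonl')
assert (base, ext) == ('world_logs', '.jsonl')  # the '.' stays in ext
# hence get_task_world_logs('taskA', 'world_logs.jsonl', is_multitask=True)
# returns 'world_logs_taskA.jsonl', with exactly one dot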



def _eval_single_world(opt, agent, task):
    logging.info(f'Evaluating task {task} using datatype {opt.get("datatype")}.')
-    # set up world logger
-    world_logger = WorldLogger(opt) if opt['world_logs'] else None

    task_opt = opt.copy()  # copy opt since we're editing the task
    task_opt['task'] = task
+    # add task suffix in case of multi-tasking
+    if opt['world_logs']:
+        task_opt['world_logs'] = get_task_world_logs(
+            task, task_opt['world_logs'], is_multitask=len(opt['task'].split(',')) > 1
+        )
+
+    world_logger = WorldLogger(task_opt) if task_opt['world_logs'] else None
+
    world = create_task(task_opt, agent)  # create worlds for tasks

    # set up logging
@@ -161,10 +175,10 @@ def _eval_single_world(opt, agent, task):
        world_logger.reset()  # add final acts to logs
        if is_distributed():
            rank = get_rank()
-            base_outfile, extension = os.path.splitext(opt['world_logs'])
+            base_outfile, extension = os.path.splitext(task_opt['world_logs'])
            outfile = base_outfile + f'_{rank}' + extension
        else:
-            outfile = opt['world_logs']
+            outfile = task_opt['world_logs']
        world_logger.write(outfile, world, file_format=opt['save_format'])

    report = aggregate_unnamed_reports(all_gather_list(world.report()))
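To illustrate the resulting filenames (an illustrative sketch, not part of the diff; task names and paths are hypothetical), the per-task suffix from get_task_world_logs composes with the rank suffix applied in the distributed branch above:

import os
from parlai.scripts.eval_model import get_task_world_logs

for task in ['taskA', 'taskB']:
    # per-task suffix for a multitask run
    outfile = get_task_world_logs(task, 'logs/world_logs.jsonl', is_multitask=True)
    # mirror the rank suffix added in the distributed branch, for rank 0
    base, ext = os.path.splitext(outfile)
    print(outfile, '->', f'{base}_0{ext}')
# logs/world_logs_taskA.jsonl -> logs/world_logs_taskA_0.jsonl
# logs/world_logs_taskB.jsonl -> logs/world_logs_taskB_0.jsonl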
29 changes: 29 additions & 0 deletions tests/test_eval_model.py
@@ -8,6 +8,7 @@
import pytest
import unittest
import parlai.utils.testing as testing_utils
+from parlai.scripts.eval_model import get_task_world_logs


class TestEvalModel(unittest.TestCase):
@@ -227,6 +228,34 @@ def test_save_report(self):
                json_lines = f.readlines()
            assert len(json_lines) == 100

+    def test_save_multiple_logs(self):
+        """
+        Test that we can save multiple world_logs from eval model on multiple tasks.
+        """
+        with testing_utils.tempdir() as tmpdir:
+            log_report = os.path.join(tmpdir, 'world_logs.jsonl')
+            save_report = os.path.join(tmpdir, 'report')
+            multitask = 'integration_tests,blended_skill_talk'
+            opt = dict(
+                task=multitask,
+                model='repeat_label',
+                datatype='valid',
+                batchsize=97,
+                num_examples=100,
+                display_examples=False,
+                world_logs=log_report,
+                report_filename=save_report,
+            )
+            valid, test = testing_utils.eval_model(opt)
+
+            for task in multitask.split(','):
+                task_log_report = get_task_world_logs(
+                    task, log_report, is_multitask=True
+                )
+                with PathManager.open(task_log_report) as f:
+                    json_lines = f.readlines()
+                assert len(json_lines) == 100


if __name__ == '__main__':
    unittest.main()
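For downstream inspection, a minimal sketch of reading the per-task logs that a run like the test above would produce (the path is hypothetical, and it assumes the jsonl save format writes one JSON object per line, as the line-count assertion implies):

import json
from parlai.scripts.eval_model import get_task_world_logs

# Hypothetical path; the files must come from a prior eval_model run.
for task in ['integration_tests', 'blended_skill_talk']:
    path = get_task_world_logs(task, '/tmp/world_logs.jsonl', is_multitask=True)
    with open(path) as f:
        episodes = [json.loads(line) for line in f]
    print(f'{task}: {len(episodes)} logged episodes')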