This repository was archived by the owner on Nov 3, 2023. It is now read-only.

Add WorldLogger to train_model script. #4369

Merged: 2 commits, Feb 25, 2022
47 changes: 45 additions & 2 deletions parlai/scripts/train_model.py
@@ -25,6 +25,7 @@
# * More logging (e.g. to files), make things prettier.
import copy
import json
import os
import numpy as np
import signal
from typing import Tuple
@@ -42,14 +43,17 @@
from parlai.core.params import ParlaiParser, print_announcements
from parlai.core.worlds import create_task, World
from parlai.scripts.build_dict import build_dict, setup_args as setup_dict_args
from parlai.scripts.eval_model import get_task_world_logs
from parlai.utils.distributed import (
sync_object,
is_primary_worker,
all_gather_list,
is_distributed,
get_rank,
num_workers,
)
from parlai.utils.misc import Timer, nice_report
from parlai.utils.world_logging import WorldLogger
from parlai.core.script import ParlaiScript, register_script
import parlai.utils.logging as logging
from parlai.utils.io import PathManager
@@ -257,6 +261,20 @@ def setup_args(parser=None) -> ParlaiParser:
help='Report micro-averaged metrics instead of macro averaged metrics.',
recommended=False,
)
train.add_argument(
'--world-logs',
type=str,
default='',
help='Saves a jsonl file of the world logs. '
'Set to the empty string to not save at all.',
)
train.add_argument(
'--save-format',
type=str,
default='conversations',
choices=['conversations', 'parlai'],
)
WorldLogger.add_cmdline_args(parser, partial_opt=None)
TensorboardLogger.add_cmdline_args(parser, partial_opt=None)
WandbLogger.add_cmdline_args(parser, partial_opt=None)
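As a point of reference, a minimal sketch of driving the new options from Python, assuming the standard ParlaiScript.main entry point; the task, paths, and values below are hypothetical, not taken from this PR:

from parlai.scripts.train_model import TrainModel

# Hypothetical invocation: train briefly, then dump the final "test" fold
# world logs to a jsonl file (an empty world_logs string disables logging).
TrainModel.main(
    task='integration_tests',            # assumed toy task
    model='repeat_label',
    model_file='/tmp/model',             # hypothetical path
    num_epochs=1.0,
    world_logs='/tmp/world_logs.jsonl',  # hypothetical output path
    save_format='conversations',         # or 'parlai'
)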

@@ -598,20 +616,42 @@ def validate(self):
return True
return False

def _run_single_eval(self, opt, valid_world, max_exs):
def _run_single_eval(self, opt, valid_world, max_exs, datatype, is_multitask):

# run evaluation on a single world
valid_world.reset()

world_logger = None
task_opt = opt.copy()
# set up world logger for the "test" fold
if opt['world_logs'] and datatype == 'test':
task_opt['world_logs'] = get_task_world_logs(
valid_world.getID(), opt['world_logs'], is_multitask
)
world_logger = WorldLogger(task_opt)

cnt = 0
max_cnt = max_exs if max_exs > 0 else float('inf')
while not valid_world.epoch_done() and cnt < max_cnt:
valid_world.parley()
if world_logger is not None:
world_logger.log(valid_world)
if cnt == 0 and opt['display_examples']:
print(valid_world.display() + '\n~~')
print(valid_world.report())
cnt = valid_world.report().get('exs') or 0

if world_logger is not None:
# dump world acts to file
world_logger.reset() # add final acts to logs
if is_distributed():
rank = get_rank()
base_outfile, extension = os.path.splitext(task_opt['world_logs'])
outfile = base_outfile + f'_{rank}' + extension
else:
outfile = task_opt['world_logs']
world_logger.write(outfile, valid_world, file_format=opt['save_format'])

valid_report = valid_world.report()
if opt.get('validation_share_agent', False):
valid_world.reset() # make sure world doesn't remember valid data
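To illustrate the rank-suffixed naming used above for distributed runs, a tiny standalone example with hypothetical values (not part of the diff):

import os

world_logs = '/tmp/world_logs.jsonl'   # hypothetical task_opt['world_logs']
rank = 1                               # hypothetical get_rank() result
base_outfile, extension = os.path.splitext(world_logs)
print(base_outfile + f'_{rank}' + extension)  # -> /tmp/world_logs_1.jsonl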
@@ -647,8 +687,11 @@ def _run_eval(
reports = []

max_exs_per_worker = max_exs / (len(valid_worlds) * num_workers())
Contributor Author:

@jxmsML This is why we have only 5 lines instead of 10: during multitask, len(valid_worlds) > 1.
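To make the division concrete, here is the arithmetic with the numbers used in the multitask test below (a purely illustrative single-worker case):

max_exs = 10          # --validation-max-exs 10
num_valid_worlds = 2  # len(valid_worlds) for two tasks
workers = 1           # num_workers() in a non-distributed run
max_exs_per_worker = max_exs / (num_valid_worlds * workers)
print(max_exs_per_worker)  # 5.0 -> each task's world log ends up with 5 episodes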

is_multitask = len(valid_worlds) > 1
for v_world in valid_worlds:
task_report = self._run_single_eval(opt, v_world, max_exs_per_worker)
task_report = self._run_single_eval(
opt, v_world, max_exs_per_worker, datatype, is_multitask
)
reports.append(task_report)

tasks = [world.getID() for world in valid_worlds]
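The helper get_task_world_logs imported from parlai.scripts.eval_model is not part of this diff. A rough sketch of the behavior the new code relies on, written as an assumption rather than the actual implementation: return the path unchanged for a single task, otherwise derive a per-task filename so multitask logs do not collide.

import os

def get_task_world_logs(task, world_logs, is_multitask=False):
    """Sketch only: build a distinct log path per task when multitasking."""
    if not is_multitask:
        return world_logs
    base_outfile, extension = os.path.splitext(world_logs)
    return f'{base_outfile}_{task}{extension}'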
48 changes: 48 additions & 0 deletions tests/test_train_model.py
@@ -11,11 +11,13 @@
import unittest
import json
import parlai.utils.testing as testing_utils
from parlai.utils.io import PathManager
from parlai.core.metrics import AverageMetric
from parlai.core.worlds import create_task
from parlai.core.opt import Opt
from parlai.core.params import ParlaiParser
from parlai.core.agents import register_agent, Agent
from parlai.scripts.eval_model import get_task_world_logs


class TestTrainModel(unittest.TestCase):
@@ -223,6 +225,52 @@ def test_opt_step(self):
def test_opt_step_update_freq_2(self):
self._test_opt_step_opts(2)

def test_save_world_logs(self):
"""
Test that we can save world logs from train model.
"""
with testing_utils.tempdir() as tmpdir:
log_report = os.path.join(tmpdir, 'world_logs.jsonl')
valid, test = testing_utils.train_model(
{
'task': 'integration_tests',
'validation_max_exs': 10,
'model': 'repeat_label',
'short_final_eval': True,
'num_epochs': 1.0,
'world_logs': log_report,
}
)
with PathManager.open(log_report) as f:
json_lines = f.readlines()
assert len(json_lines) == 10

def test_save_multiple_world_logs(self):
"""
Test that we can save multiple world_logs from train model on multiple tasks.
"""
with testing_utils.tempdir() as tmpdir:
log_report = os.path.join(tmpdir, 'world_logs.jsonl')
multitask = 'integration_tests,integration_tests:ReverseTeacher'
valid, test = testing_utils.train_model(
{
'task': multitask,
'validation_max_exs': 10,
'model': 'repeat_label',
'short_final_eval': True,
'num_epochs': 1.0,
'world_logs': log_report,
}
)

for task in multitask.split(','):
task_log_report = get_task_world_logs(
task, log_report, is_multitask=True
)
with PathManager.open(task_log_report) as f:
json_lines = f.readlines()
assert len(json_lines) == 5
Contributor:

Why are there only 5 lines instead of 10?



@register_agent("fake_report")
class FakeReportAgent(Agent):