Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
Fixes train_model world logging for multitask with mutators. (#4414)
Browse files Browse the repository at this point in the history
* Fixes train_model world logging for multitask with mutators.

* Fix bug in train_model when evaltask doesn't match task.
  • Loading branch information
kauterry authored Mar 17, 2022
1 parent 573c76c commit d6773a0
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 4 deletions.
12 changes: 8 additions & 4 deletions parlai/scripts/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,7 @@ def validate(self):
return True
return False

def _run_single_eval(self, opt, valid_world, max_exs, datatype, is_multitask):
def _run_single_eval(self, opt, valid_world, max_exs, datatype, is_multitask, task):

# run evaluation on a single world
valid_world.reset()
Expand All @@ -629,7 +629,7 @@ def _run_single_eval(self, opt, valid_world, max_exs, datatype, is_multitask):
# set up world logger for the "test" fold
if opt['world_logs'] and datatype == 'test':
task_opt['world_logs'] = get_task_world_logs(
valid_world.getID(), opt['world_logs'], is_multitask
task, opt['world_logs'], is_multitask
)
world_logger = WorldLogger(task_opt)

Expand Down Expand Up @@ -691,9 +691,13 @@ def _run_eval(

max_exs_per_worker = max_exs / (len(valid_worlds) * num_workers())
is_multitask = len(valid_worlds) > 1
for v_world in valid_worlds:
for index, v_world in enumerate(valid_worlds):
if opt.get('evaltask'):
task = opt['evaltask'].split(',')[index]
else:
task = opt['task'].split(',')[index]
task_report = self._run_single_eval(
opt, v_world, max_exs_per_worker, datatype, is_multitask
opt, v_world, max_exs_per_worker, datatype, is_multitask, task
)
reports.append(task_report)

Expand Down
56 changes: 56 additions & 0 deletions tests/test_train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,62 @@ def test_save_multiple_world_logs(self):
json_lines = f.readlines()
assert len(json_lines) == 5

def test_save_multiple_world_logs_evaltask(self):
    """
    Verify per-task world-log files are written when evaluating.

    Trains on two tasks but evaluates on three evaltasks (including
    mutator-annotated ones); each evaltask must get its own world_logs
    file, resolved via ``get_task_world_logs``.
    """
    with testing_utils.tempdir() as tmpdir:
        base_log_path = os.path.join(tmpdir, 'world_logs.jsonl')
        train_tasks = 'integration_tests,integration_tests:ReverseTeacher'
        eval_tasks = 'integration_tests,integration_tests:mutators=flatten,integration_tests:ReverseTeacher:mutator=reverse'
        train_opt = {
            'task': train_tasks,
            'evaltask': eval_tasks,
            'validation_max_exs': 10,
            'model': 'repeat_label',
            'short_final_eval': True,
            'num_epochs': 1.0,
            'world_logs': base_log_path,
        }
        valid, test = testing_utils.train_model(train_opt)

        # Every evaltask (not every train task) should have produced a log.
        for eval_task in eval_tasks.split(','):
            per_task_path = get_task_world_logs(
                eval_task, base_log_path, is_multitask=True
            )
            with PathManager.open(per_task_path) as log_file:
                logged_lines = log_file.readlines()
            assert len(logged_lines) == 4

def test_save_multiple_world_logs_mutator(self):
    """
    Verify per-task world-log files are written for mutator tasks.

    Trains on two tasks whose specs include mutators; each task should
    still map to its own world_logs file via ``get_task_world_logs``.
    """
    with testing_utils.tempdir() as tmpdir:
        base_log_path = os.path.join(tmpdir, 'world_logs.jsonl')
        mutator_tasks = 'integration_tests:mutators=flatten,integration_tests:ReverseTeacher:mutator=reverse'
        valid, test = testing_utils.train_model(
            {
                'task': mutator_tasks,
                'validation_max_exs': 10,
                'model': 'repeat_label',
                'short_final_eval': True,
                'num_epochs': 1.0,
                'world_logs': base_log_path,
            }
        )

        # Each mutator-annotated task spec gets its own log file.
        for task_spec in mutator_tasks.split(','):
            per_task_path = get_task_world_logs(
                task_spec, base_log_path, is_multitask=True
            )
            with PathManager.open(per_task_path) as log_file:
                logged_lines = log_file.readlines()
            assert len(logged_lines) == 5


@register_agent("fake_report")
class FakeReportAgent(Agent):
Expand Down

0 comments on commit d6773a0

Please sign in to comment.