diff --git a/parlai/scripts/distributed_train.py b/parlai/scripts/distributed_train.py
index 314862045bf..d8868be5cb6 100644
--- a/parlai/scripts/distributed_train.py
+++ b/parlai/scripts/distributed_train.py
@@ -31,10 +31,7 @@
   -m seq2seq -t convai2 --dict-file /path/to/dict-file
 """
 
-import os
-
 import parlai.scripts.train_model as single_train
-import parlai.utils.logging as logging
 from parlai.scripts.script import ParlaiScript
 import parlai.utils.distributed as distributed_utils
 
@@ -52,7 +49,7 @@ def setup_args(cls):
         return setup_args()
 
     def run(self):
-        with distributed_utils.slurm_distributed_context(opt) as opt:
+        with distributed_utils.slurm_distributed_context(self.opt) as opt:
             return single_train.TrainLoop(opt).train_model()
 
 
diff --git a/parlai/scripts/eval_model.py b/parlai/scripts/eval_model.py
index 0bdbd1902cc..36733d0354c 100644
--- a/parlai/scripts/eval_model.py
+++ b/parlai/scripts/eval_model.py
@@ -38,7 +38,7 @@
     is_primary_worker,
     all_gather_list,
     is_distributed,
-    sync_object,
+    get_rank,
 )
 
 
@@ -162,7 +162,11 @@ def _eval_single_world(opt, agent, task):
         # dump world acts to file
         world_logger.reset()  # add final acts to logs
         base_outfile = opt['report_filename'].split('.')[0]
-        outfile = base_outfile + f'_{task}_replies.jsonl'
+        if is_distributed():
+            rank = get_rank()
+            outfile = base_outfile + f'_{task}_{rank}_replies.jsonl'
+        else:
+            outfile = base_outfile + f'_{task}_replies.jsonl'
         world_logger.write(outfile, world, file_format=opt['save_format'])
 
     return report
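
Note for reviewers: with this change, a distributed eval run writes one replies file per rank rather than having every worker clobber the same path. A minimal sketch (not part of this PR; the helper name and signature are hypothetical) of how the per-rank files could be concatenated back into a single replies file on the primary worker:

# Hypothetical post-processing helper, not part of this PR: concatenate the
# per-rank `{base}_{task}_{rank}_replies.jsonl` files produced by distributed
# eval into the single `{base}_{task}_replies.jsonl` that a non-distributed
# run would have written. Assumes all `world_size` rank files exist.
def merge_rank_replies(base_outfile: str, task: str, world_size: int) -> str:
    merged = f'{base_outfile}_{task}_replies.jsonl'
    with open(merged, 'w') as out:
        for rank in range(world_size):
            with open(f'{base_outfile}_{task}_{rank}_replies.jsonl') as part:
                for line in part:
                    out.write(line)
    return merged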