Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
More work on distributed_eval.
Browse files Browse the repository at this point in the history
  • Loading branch information
stephenroller committed Jul 15, 2020
1 parent 98153e5 commit 1aef411
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 2 deletions.
51 changes: 51 additions & 0 deletions parlai/scripts/distributed_eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Distributed evaluation script. NOT MEANT TO BE CALLED DIRECTLY BY USER.
This script is meant to be in conjunction with
`SLURM <https://slurm.schedmd.com/>`, which provides environmental variables
describing the environment.
An example sbatch script is below, for a 2-host, 8-GPU setup (16 total gpus):
.. code-block:: bash\n\n
#!/bin/sh
#SBATCH --job-name=distributed_example
#SBATCH --output=/path/to/savepoint/stdout.%j
#SBATCH --error=/path/to/savepoint/stderr.%j
#SBATCH --partition=priority
#SBATCH --nodes=2
#SBATCH --time=0:10:00
#SBATCH --signal=SIGINT
#SBATCH --gres=gpu:8
#SBATCH --ntasks-per-node=8
#SBATCH --mem=64G
#SBATCH --cpus-per-task=10
srun python -u -m parlai.scripts.distributed_eval \
-m seq2seq -t convai2 --dict-file /path/to/dict-file
"""

import os

import parlai.scripts.eval_model as eval_model
import parlai.utils.distributed as distributed_utils


def main():
parser = eval_model.setup_args()
parser.add_distributed_training_args()
parser.add_argument('--port', type=int, default=61337, help='TCP port number')
opt = parser.parse_args(print_args=(os.environ['SLURM_PROCID'] == '0'))

with distributed_utils.slurm_distributed_context(opt) as opt:
return eval_model.eval_model(opt)


if __name__ == '__main__':
main()
4 changes: 2 additions & 2 deletions parlai/scripts/eval_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,10 +135,10 @@ def _eval_single_world(opt, agent, task):
# max number of examples to evaluate
max_cnt = opt['num_examples'] if opt['num_examples'] > 0 else float('inf')
cnt = 0
total_cnt = sum(all_gather_list(world.num_examples()))
total_cnt = world.num_examples()

if is_distributed():
logging.warning('Progress bar is approximate in distributed mode.')
logging.warn('Progress bar is approximate in distributed mode.')

while not world.epoch_done() and cnt < max_cnt:
cnt += opt.get('batchsize', 1)
Expand Down
76 changes: 76 additions & 0 deletions parlai/scripts/multiprocessing_eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


"""
Main launch script for single-host, multi-GPU training.
This is a drop-in replacement for train_model.py. This script will launch N
subprocess, each which runs the full training loop independently.
Uses torch.nn.parallel.DistributedDataParallel for its main uses. Agents must
specifically implement the wrapper of DistributedDatParallel, but all
TorchRankerAgents and TorchGeneratorAgents support this.
"""

import torch
import random
import os
import signal
import parlai.utils.distributed as distributed_utils
import parlai.scripts.eval_model as eval_model


def multiprocess_eval(
rank, opt, port=61337, rank_offset=0, gpu=None, hostname='localhost'
):
with distributed_utils.distributed_context(
rank, opt, port, rank_offset, gpu, hostname
) as opt:
return eval_model.eval_model(opt)


def launch_and_eval(opt, port):
"""
Perform a fork() to many processes.
"""
# Launch multiple subprocesses
spawncontext = torch.multiprocessing.spawn(
multiprocess_eval,
# need to give rank offset as 1 to cover the fact that the main
# process is rank 0, but that spawn() doesn't let you control rank
(opt, port, 1),
nprocs=opt['distributed_world_size'] - 1, # main proc will also run loop
join=False,
)

try:
retval = multiprocess_eval(0, opt, port)
spawncontext.join()
return retval
except KeyboardInterrupt:
# tell the subprocesses to stop too
for p in spawncontext.processes:
if p.is_alive():
os.kill(p.pid, signal.SIGINT)
raise


def setup_args():
parser = eval_model.setup_args()
parser.add_distributed_training_args()
parser.set_defaults(distributed_world_size=torch.cuda.device_count())
return parser


def main():
opt = setup_args().parse_args()
port = random.randint(32000, 48000)
return launch_and_eval(opt, port)


if __name__ == '__main__':
main()

0 comments on commit 1aef411

Please sign in to comment.