[dist] Allow non-tcp based distributed setup. (#3095)
* [dist] Allow non-tcp based distributed setup.

* Also catch mp_eval.

* Lint.

* Docstring correction.

* Importing chunks from other patch.
stephenroller authored Sep 22, 2020
1 parent ba7f3da commit 0ffd3a6
Showing 7 changed files with 42 additions and 29 deletions.
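What the patch enables, in brief: distributed_context now takes a ready-made init_method URL instead of assembling tcp://{hostname}:{port} itself, so any rendezvous scheme that torch.distributed.init_process_group supports can be used. A minimal sketch (not part of the patch; the worker entry point and the shared-file path are illustrative, and opt is assumed to carry distributed_world_size):

    import parlai.utils.distributed as distributed_utils

    def worker(rank, opt):
        # Shared-filesystem rendezvous instead of TCP; the URL is passed
        # straight through to torch.distributed.init_process_group.
        with distributed_utils.distributed_context(
            rank, opt, gpu=rank, init_method='file:///tmp/parlai_rendezvous'
        ) as opt:
            ...  # run training or evaluation with the configured opt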
parlai/core/metrics.py (4 additions, 0 deletions)

@@ -592,6 +592,10 @@ def aggregate_unnamed_reports(reports: List[Dict[str, Metric]]) -> Dict[str, Metric]:
     return m
 
 
+def dict_report(report: Dict[str, Metric]):
+    return {k: v.value() if isinstance(v, Metric) else v for k, v in report.items()}
+
+
 class Metrics(object):
     """
     Metrics aggregator.
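For reference, a usage sketch of the new helper. AverageMetric is ParlAI's fraction-style Metric subclass; the report values here are made up:

    from parlai.core.metrics import AverageMetric, dict_report

    report = {'token_acc': AverageMetric(3, 4), 'exs': 4}
    # Metric objects collapse to plain numbers; other values pass through.
    print(dict_report(report))  # {'token_acc': 0.75, 'exs': 4}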
parlai/scripts/multiprocessing_eval.py (2 additions, 1 deletion)

@@ -39,8 +39,9 @@ def multiprocess_eval(
     Invoked by launch_and_eval, not instantiated directly.
     """
+    init_method = f'tcp://{hostname}:{port}'
     with distributed_utils.distributed_context(
-        rank, opt, port, rank_offset, gpu, hostname
+        rank, opt, rank_offset, gpu, init_method=init_method
     ) as opt:
         opt['multiprocessing'] = True
         return eval_model.eval_model(opt)
parlai/scripts/multiprocessing_train.py (2 additions, 1 deletion)

@@ -35,8 +35,9 @@
 def multiprocess_train(
     rank, opt, port=61337, rank_offset=0, gpu=None, hostname='localhost'
 ):
+    init_method = f"tcp://{hostname}:{port}"
     with distributed_utils.distributed_context(
-        rank, opt, port, rank_offset, gpu, hostname
+        rank, opt, rank_offset, gpu, init_method=init_method
     ) as opt:
         # Run the actual training
         opt['multiprocessing'] = True
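Both wrappers keep their TCP-flavored signatures and now simply format the URL before delegating. A sketch of how multiprocess_train is typically driven, modeled loosely on ParlAI's launch_and_train; the launch helper here is hypothetical:

    import torch.multiprocessing as mp
    from parlai.scripts.multiprocessing_train import multiprocess_train

    def launch(opt, port=61337):
        # Spawn workers 1..N-1; mp.spawn supplies each one's rank, and every
        # worker formats the same tcp://localhost:{port} init method itself.
        spawncontext = mp.spawn(
            multiprocess_train,
            args=(opt, port, 1),  # (opt, port, rank_offset)
            nprocs=opt['distributed_world_size'] - 1,
            join=False,
        )
        try:
            # The parent participates as rank 0 through the same code path.
            return multiprocess_train(0, opt, port)
        finally:
            spawncontext.join()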
parlai/scripts/train_model.py (7 additions, 7 deletions)

@@ -28,13 +28,16 @@
 import json
 import numpy as np
 import signal
-from typing import Dict
 
 from parlai.core.metrics import Metric
 from parlai.core.agents import create_agent, create_agent_from_shared
 from parlai.core.exceptions import StopTrainException
 from parlai.core.logs import TensorboardLogger
-from parlai.core.metrics import aggregate_named_reports, aggregate_unnamed_reports
+from parlai.core.metrics import (
+    aggregate_named_reports,
+    aggregate_unnamed_reports,
+    dict_report,
+)
 from parlai.core.params import ParlaiParser, print_announcements
 from parlai.core.worlds import create_task
 from parlai.scripts.build_dict import build_dict, setup_args as setup_dict_args
@@ -384,9 +387,6 @@ def save_model(self, suffix=None):
         except KeyboardInterrupt:
             pass
 
-    def _safe_report(self, report: Dict[str, Metric]):
-        return {k: v.value() if isinstance(v, Metric) else v for k, v in report.items()}
-
     def _save_train_stats(self, suffix=None):
         fn = self.opt['model_file']
         if suffix:
@@ -423,7 +423,7 @@ def validate(self):
         valid_report = self._run_eval(
             self.valid_worlds, opt, 'valid', opt['validation_max_exs']
         )
-        v = self._safe_report(valid_report.copy())
+        v = dict_report(valid_report)
         v['train_time'] = self.train_time.time()
         v['parleys'] = self.parleys
         v['total_exs'] = self._total_exs
@@ -616,7 +616,7 @@ def log(self):
         train_report = self._sync_metrics(train_report)
         self.world.reset_metrics()
 
-        train_report_trainstats = self._safe_report(train_report)
+        train_report_trainstats = dict_report(train_report)
         train_report_trainstats['total_epochs'] = self._total_epochs
         train_report_trainstats['total_exs'] = self._total_exs
         train_report_trainstats['parleys'] = self.parleys
parlai/utils/distributed.py (5 additions, 7 deletions)

@@ -269,7 +269,7 @@ def sync_parameters(model: torch.nn.Module) -> bool:
 
 @contextlib.contextmanager
 def distributed_context(
-    rank, opt, port=61337, rank_offset=0, gpu=None, hostname='localhost'
+    rank, opt, rank_offset=0, gpu=None, init_method="tcp://localhost:61337"
 ):
     """
     A context which wraps initialization of a distributed/multiprocessing run.
@@ -285,14 +285,12 @@ def distributed_context(
         non-primary workers.
     :param opt:
         command line options
-    :param int port:
-        A TCP port to use. This will need to be changed to run multiple
-        distributed training setups on the same machine.
     :param int gpu:
         Which GPU to use. Defaults to using rank and local devices, but must be
         manually specified when using many-hosts.
-    :param str hostname:
-        Hostname of the main server.
+    :param str init_method:
+        Init method, such as ``tcp://localhost:61337``. See torch.distributed docs.
     """
     # Set per-host options
     opt = copy.deepcopy(opt)
@@ -322,7 +320,7 @@
             torch.cuda.set_device(opt['gpu'])
         dist.init_process_group(
             backend="nccl",
-            init_method="tcp://{}:{}".format(hostname, port),
+            init_method=init_method,
             world_size=opt['distributed_world_size'],
             rank=rank,
         )
@@ -379,7 +377,7 @@ def slurm_distributed_context(opt):
         )
         # Begin distributed training
         with distributed_context(
-            distributed_rank, opt, port, 0, device_id, main_host
+            distributed_rank, opt, 0, device_id, init_method=f"tcp://{main_host}:{port}"
         ) as opt:
             yield opt
     except subprocess.CalledProcessError as e:
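Because init_process_group is still called with explicit world_size and rank, an env:// rendezvous only needs MASTER_ADDR and MASTER_PORT in the environment. A speculative sketch of what the generalized parameter allows (not code from this patch; the worker function and addresses are illustrative):

    import os
    import parlai.utils.distributed as distributed_utils

    # Rendezvous through environment variables, e.g. under an external launcher.
    os.environ.setdefault('MASTER_ADDR', '10.0.0.1')
    os.environ.setdefault('MASTER_PORT', '61337')

    def worker(rank, opt, device_id):
        with distributed_utils.distributed_context(
            rank, opt, gpu=device_id, init_method='env://'
        ) as opt:
            ...  # opt is now configured for this worker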
parlai/utils/logging.py (10 additions, 2 deletions)

@@ -4,6 +4,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+import os
 import sys
 import logging
 
@@ -48,6 +49,8 @@
 
 
 def _is_interactive():
+    if os.environ.get('PARLAI_FORCE_COLOR'):
+        return True
     try:
         __IPYTHON__
         return True
@@ -72,19 +75,20 @@ def __init__(self, name, console_level=INFO):
         self.streamHandler = logging.StreamHandler(sys.stdout)
         # Log to stdout levels: console_level and above
         self.prefix = None
+        self.interactive = _is_interactive()
         self.streamHandler.setFormatter(self._build_formatter())
         super().addHandler(self.streamHandler)
 
     def _build_formatter(self):
         prefix_format = f'{self.prefix} ' if self.prefix else ''
-        if COLORED_LOGS and _is_interactive():
+        if COLORED_LOGS and self.interactive:
             return coloredlogs.ColoredFormatter(
                 prefix_format + COLORED_FORMAT,
                 datefmt=CONSOLE_DATE_FORMAT,
                 level_styles=COLORED_LEVEL_STYLES,
                 field_styles={},
             )
-        elif _is_interactive():
+        elif self.interactive:
             return logging.Formatter(
                 prefix_format + CONSOLE_FORMAT, datefmt=CONSOLE_DATE_FORMAT
             )
@@ -93,6 +97,10 @@ def _build_formatter(self):
                 prefix_format + LOGFILE_FORMAT, datefmt=LOGFILE_DATE_FORMAT
             )
 
+    def force_interactive(self):
+        self.interactive = True
+        self.streamHandler.setFormatter(self._build_formatter())
+
     def log(self, msg, level=INFO):
         """
         Default Logging function.
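A usage sketch for the two new hooks. PARLAI_FORCE_COLOR is read in _is_interactive(), which the logger now consults once at construction, so the variable has to be set before ParlAI is imported; force_interactive() flips an already-built logger at runtime. This assumes the module-level logger and info() helpers that parlai.utils.logging exposes:

    import os
    os.environ['PARLAI_FORCE_COLOR'] = '1'  # must be set before parlai is imported

    import parlai.utils.logging as logging

    logging.info('colored output even when stdout is piped')
    logging.logger.force_interactive()  # or flip the shared logger at runtime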
parlai/utils/misc.py (12 additions, 11 deletions)

@@ -16,6 +16,7 @@
 import re
 import shutil
 import json
+import os
 
 from parlai.core.message import Message
 from parlai.utils.strings import colorize
@@ -357,6 +358,11 @@ def float_formatter(f: Union[float, int]) -> str:
 
 
 def _line_width():
+    if os.environ.get('PARLAI_FORCE_WIDTH'):
+        try:
+            return int(os.environ['PARLAI_FORCE_WIDTH'])
+        except ValueError:
+            pass
     try:
         # if we're in an interactive ipython notebook, hardcode a longer width
         __IPYTHON__
@@ -410,17 +416,12 @@ def nice_report(report) -> str:
         df = pd.DataFrame([output])
         df.columns = pd.MultiIndex.from_tuples(df.columns)
         df = df.stack().transpose().droplevel(0, axis=1)
-        result = (
-            "   "
-            + df.to_string(
-                na_rep="",
-                line_width=line_width - 3,  # -3 for the extra spaces we add
-                float_format=float_formatter,
-                index=df.shape[0] > 1,
-            )
-            .replace("\n\n", "\n")
-            .replace("\n", "\n   ")
-        )
+        result = "   " + df.to_string(
+            na_rep="",
+            line_width=line_width - 3,  # -3 for the extra spaces we add
+            float_format=float_formatter,
+            index=df.shape[0] > 1,
+        ).replace("\n\n", "\n").replace("\n", "\n   ")
         result = re.sub(r"\s+$", "", result)
         return result
     else:
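Finally, a sketch of the width override. _line_width() is evaluated on each nice_report call, so the variable can be set at any point; non-integer values are ignored and the usual tty/notebook detection applies instead. The report contents here are made up:

    import os
    from parlai.utils.misc import nice_report

    os.environ['PARLAI_FORCE_WIDTH'] = '120'
    print(nice_report({'accuracy': 0.75, 'exs': 4}))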
