
[dist] Allow non-tcp based distributed setup. #3095

Merged 5 commits on Sep 22, 2020

parlai/core/metrics.py (4 additions, 0 deletions)

@@ -592,6 +592,10 @@ def aggregate_unnamed_reports(reports: List[Dict[str, Metric]]) -> Dict[str, Metric]:
     return m
 
 
+def dict_report(report: Dict[str, Metric]):
+    return {k: v.value() if isinstance(v, Metric) else v for k, v in report.items()}
+
+
 class Metrics(object):
     """
     Metrics aggregator.
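
For context, a minimal usage sketch of the new helper (the report contents below are made up): it unwraps Metric objects via .value() and passes everything else through, so a report can be serialized or logged as plain numbers.

    from parlai.core.metrics import AverageMetric, dict_report

    # Metric values are unwrapped via .value(); non-Metric entries pass through.
    report = {'accuracy': AverageMetric(3, 4), 'total_exs': 4}
    print(dict_report(report))  # {'accuracy': 0.75, 'total_exs': 4}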

parlai/scripts/multiprocessing_eval.py (2 additions, 1 deletion)

@@ -39,8 +39,9 @@ def multiprocess_eval(
 
     Invoked by launch_and_eval, not instantiated directly.
     """
+    init_method = f'tcp://{hostname}:{port}'
     with distributed_utils.distributed_context(
-        rank, opt, port, rank_offset, gpu, hostname
+        rank, opt, rank_offset, gpu, init_method=init_method
     ) as opt:
         opt['multiprocessing'] = True
         return eval_model.eval_model(opt)

parlai/scripts/multiprocessing_train.py (2 additions, 1 deletion)

@@ -35,8 +35,9 @@
 def multiprocess_train(
     rank, opt, port=61337, rank_offset=0, gpu=None, hostname='localhost'
 ):
+    init_method = f"tcp://{hostname}:{port}"
     with distributed_utils.distributed_context(
-        rank, opt, rank_offset, gpu, init_method=init_method
+        rank, opt, rank_offset, gpu, init_method=init_method
     ) as opt:
         # Run the actual training
         opt['multiprocessing'] = True
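
For orientation, a hypothetical launcher sketch showing how such a worker might be spawned; this is not ParlAI's own launch_and_train helper, which uses torch.multiprocessing.spawn only for the non-primary workers. mp.spawn supplies each worker's rank, and the worker then builds the tcp:// init_method from the shared hostname and port before entering distributed_context.

    import torch.multiprocessing as mp

    from parlai.scripts.multiprocessing_train import multiprocess_train


    def launch(opt, port=61337, hostname='localhost'):
        # mp.spawn passes the worker index as the first positional argument (the
        # rank); the remaining args line up with multiprocess_train's signature.
        mp.spawn(
            multiprocess_train,
            args=(opt, port, 0, None, hostname),
            nprocs=opt['distributed_world_size'],
            join=True,
        )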

parlai/scripts/train_model.py (7 additions, 7 deletions)

@@ -28,13 +28,16 @@
 import json
 import numpy as np
 import signal
-from typing import Dict
 
 from parlai.core.metrics import Metric
 from parlai.core.agents import create_agent, create_agent_from_shared
 from parlai.core.exceptions import StopTrainException
 from parlai.core.logs import TensorboardLogger
-from parlai.core.metrics import aggregate_named_reports, aggregate_unnamed_reports
+from parlai.core.metrics import (
+    aggregate_named_reports,
+    aggregate_unnamed_reports,
+    dict_report,
+)
 from parlai.core.params import ParlaiParser, print_announcements
 from parlai.core.worlds import create_task
 from parlai.scripts.build_dict import build_dict, setup_args as setup_dict_args
@@ -384,9 +387,6 @@ def save_model(self, suffix=None):
             except KeyboardInterrupt:
                 pass
 
-    def _safe_report(self, report: Dict[str, Metric]):
-        return {k: v.value() if isinstance(v, Metric) else v for k, v in report.items()}
-
     def _save_train_stats(self, suffix=None):
         fn = self.opt['model_file']
         if suffix:
@@ -423,7 +423,7 @@ def validate(self):
         valid_report = self._run_eval(
             self.valid_worlds, opt, 'valid', opt['validation_max_exs']
         )
-        v = self._safe_report(valid_report.copy())
+        v = dict_report(valid_report)
         v['train_time'] = self.train_time.time()
         v['parleys'] = self.parleys
         v['total_exs'] = self._total_exs
@@ -613,7 +613,7 @@ def log(self):
         train_report = self._sync_metrics(train_report)
         self.world.reset_metrics()
 
-        train_report_trainstats = self._safe_report(train_report)
+        train_report_trainstats = dict_report(train_report)
         train_report_trainstats['total_epochs'] = self._total_epochs
         train_report_trainstats['total_exs'] = self._total_exs
         train_report_trainstats['parleys'] = self.parleys

parlai/utils/distributed.py (5 additions, 7 deletions)

@@ -269,7 +269,7 @@ def sync_parameters(model: torch.nn.Module) -> bool:
 
 @contextlib.contextmanager
 def distributed_context(
-    rank, opt, port=61337, rank_offset=0, gpu=None, hostname='localhost'
+    rank, opt, rank_offset=0, gpu=None, init_method="tcp://localhost:61337"
 ):
     """
     A context which wraps initialization of a distributed/multiprocessing run.
@@ -285,14 +285,12 @@
         non-primary workers.
     :param opt:
         command line options
-    :param int port:
-        A TCP port to use. This will need to be changed to run multiple
-        distributed training setups on the same machine.
     :param int gpu:
         Which GPU to use. Defaults to using rank and local devices, but must be
         manually specified when using many-hosts.
-    :param str hostname:
-        Hostname of the main server.
+    :param str init_method:
+        Init method, such as ``tcp://localhost:61337``. See torch.distributed docs.
     """
     # Set per-host options
     opt = copy.deepcopy(opt)
@@ -322,7 +320,7 @@
         torch.cuda.set_device(opt['gpu'])
     dist.init_process_group(
         backend="nccl",
-        init_method="tcp://{}:{}".format(hostname, port),
+        init_method=init_method,
         world_size=opt['distributed_world_size'],
         rank=rank,
     )
@@ -379,7 +377,7 @@ def slurm_distributed_context(opt):
         )
         # Begin distributed training
         with distributed_context(
-            distributed_rank, opt, port, 0, device_id, main_host
+            distributed_rank, opt, 0, device_id, init_method=f"tcp://{main_host}:{port}"
         ) as opt:
             yield opt
     except subprocess.CalledProcessError as e:
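
This signature change is what enables a non-TCP rendezvous: distributed_context now forwards init_method verbatim to torch.distributed.init_process_group. A hedged sketch, with a made-up shared-filesystem path, of what a caller could now do:

    import parlai.utils.distributed as distributed_utils


    def run_worker(rank, opt):
        # Hypothetical: rendezvous via a file on a shared filesystem instead of a
        # TCP host/port pair.
        with distributed_utils.distributed_context(
            rank,
            opt,
            gpu=rank,  # assumes one GPU per worker on a single host
            init_method='file:///checkpoint/my_job/rendezvous',
        ) as opt:
            ...  # per-worker training or evaluation goes here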

parlai/utils/logging.py (10 additions, 2 deletions)

@@ -4,6 +4,7 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+import os
 import sys
 import logging
 
@@ -48,6 +49,8 @@
 
 
 def _is_interactive():
+    if os.environ.get('PARLAI_FORCE_COLOR'):
+        return True
     try:
         __IPYTHON__
         return True
@@ -72,19 +75,20 @@ def __init__(self, name, console_level=INFO):
         self.streamHandler = logging.StreamHandler(sys.stdout)
         # Log to stdout levels: console_level and above
         self.prefix = None
+        self.interactive = _is_interactive()
         self.streamHandler.setFormatter(self._build_formatter())
         super().addHandler(self.streamHandler)
 
     def _build_formatter(self):
         prefix_format = f'{self.prefix} ' if self.prefix else ''
-        if COLORED_LOGS and _is_interactive():
+        if COLORED_LOGS and self.interactive:
             return coloredlogs.ColoredFormatter(
                 prefix_format + COLORED_FORMAT,
                 datefmt=CONSOLE_DATE_FORMAT,
                 level_styles=COLORED_LEVEL_STYLES,
                 field_styles={},
             )
-        elif _is_interactive():
+        elif self.interactive:
             return logging.Formatter(
                 prefix_format + CONSOLE_FORMAT, datefmt=CONSOLE_DATE_FORMAT
             )
@@ -93,6 +97,10 @@ def _build_formatter(self):
                 prefix_format + LOGFILE_FORMAT, datefmt=LOGFILE_DATE_FORMAT
             )
 
+    def force_interactive(self):
+        self.interactive = True
+        self.streamHandler.setFormatter(self._build_formatter())
+
     def log(self, msg, level=INFO):
         """
         Default Logging function.
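
A small usage sketch of the PARLAI_FORCE_COLOR hook (the newly added force_interactive() is the programmatic equivalent for a logger that already exists); the import-as-logging idiom mirrors how ParlAI modules use this logger, and the message text is arbitrary:

    import os

    # Hypothetical usage: any non-empty value makes _is_interactive() return True,
    # so the colored console formatter is used even when stdout is not a TTY
    # (e.g. when piping output through tee).  Set it before the logger is built,
    # i.e. before parlai.utils.logging is first imported.
    os.environ['PARLAI_FORCE_COLOR'] = '1'

    import parlai.utils.logging as logging

    logging.info('rendered with the interactive, colored formatter')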

parlai/utils/misc.py (12 additions, 11 deletions)

@@ -16,6 +16,7 @@
 import re
 import shutil
 import json
+import os
 
 from parlai.core.message import Message
 from parlai.utils.strings import colorize
@@ -357,6 +358,11 @@ def float_formatter(f: Union[float, int]) -> str:
 
 
 def _line_width():
+    if os.environ.get('PARLAI_FORCE_WIDTH'):
+        try:
+            return int(os.environ['PARLAI_FORCE_WIDTH'])
+        except ValueError:
+            pass
     try:
         # if we're in an interactive ipython notebook, hardcode a longer width
         __IPYTHON__
@@ -410,17 +416,12 @@ def nice_report(report) -> str:
         df = pd.DataFrame([output])
         df.columns = pd.MultiIndex.from_tuples(df.columns)
         df = df.stack().transpose().droplevel(0, axis=1)
-        result = (
-            "   "
-            + df.to_string(
-                na_rep="",
-                line_width=line_width - 3,  # -3 for the extra spaces we add
-                float_format=float_formatter,
-                index=df.shape[0] > 1,
-            )
-            .replace("\n\n", "\n")
-            .replace("\n", "\n   ")
-        )
+        result = "   " + df.to_string(
+            na_rep="",
+            line_width=line_width - 3,  # -3 for the extra spaces we add
+            float_format=float_formatter,
+            index=df.shape[0] > 1,
+        ).replace("\n\n", "\n").replace("\n", "\n   ")
         result = re.sub(r"\s+$", "", result)
         return result
     else:
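
A similar sketch for the width override, which can pin nice_report's rendering width in logs or tests; the report values here are made up:

    import os

    from parlai.utils.misc import nice_report

    # Hypothetical usage: force a fixed line width regardless of terminal size or
    # notebook detection; non-integer values are ignored and the old autodetection
    # applies.
    os.environ['PARLAI_FORCE_WIDTH'] = '120'
    print(nice_report({'exs': 10, 'accuracy': 0.75, 'f1': 0.8}))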