From 0ffd3a665f81567674fee03db2790b65bf9b7d44 Mon Sep 17 00:00:00 2001 From: Stephen Roller Date: Tue, 22 Sep 2020 15:26:38 -0400 Subject: [PATCH] [dist] Allow non-tcp based distributed setup. (#3095) * [dist] Allow non-tcp based distributed setup. * Also catch mp_eval. * Lint. * Docstring correction. * Importing chunks from other patch. --- parlai/core/metrics.py | 4 ++++ parlai/scripts/multiprocessing_eval.py | 3 ++- parlai/scripts/multiprocessing_train.py | 3 ++- parlai/scripts/train_model.py | 14 +++++++------- parlai/utils/distributed.py | 12 +++++------- parlai/utils/logging.py | 12 ++++++++++-- parlai/utils/misc.py | 23 ++++++++++++----------- 7 files changed, 42 insertions(+), 29 deletions(-) diff --git a/parlai/core/metrics.py b/parlai/core/metrics.py index 56354470071..ec9f9fb11aa 100644 --- a/parlai/core/metrics.py +++ b/parlai/core/metrics.py @@ -592,6 +592,10 @@ def aggregate_unnamed_reports(reports: List[Dict[str, Metric]]) -> Dict[str, Met return m +def dict_report(report: Dict[str, Metric]): + return {k: v.value() if isinstance(v, Metric) else v for k, v in report.items()} + + class Metrics(object): """ Metrics aggregator. diff --git a/parlai/scripts/multiprocessing_eval.py b/parlai/scripts/multiprocessing_eval.py index a32c1e75810..c31627bcd23 100644 --- a/parlai/scripts/multiprocessing_eval.py +++ b/parlai/scripts/multiprocessing_eval.py @@ -39,8 +39,9 @@ def multiprocess_eval( Invoked by launch_and_eval, not instantiated directly. """ + init_method = f'tcp://{hostname}:{port}' with distributed_utils.distributed_context( - rank, opt, port, rank_offset, gpu, hostname + rank, opt, rank_offset, gpu, init_method=init_method ) as opt: opt['multiprocessing'] = True return eval_model.eval_model(opt) diff --git a/parlai/scripts/multiprocessing_train.py b/parlai/scripts/multiprocessing_train.py index b27fcbf4d76..85adefd6ff6 100644 --- a/parlai/scripts/multiprocessing_train.py +++ b/parlai/scripts/multiprocessing_train.py @@ -35,8 +35,9 @@ def multiprocess_train( rank, opt, port=61337, rank_offset=0, gpu=None, hostname='localhost' ): + init_method = f"tcp://{hostname}:{port}" with distributed_utils.distributed_context( - rank, opt, port, rank_offset, gpu, hostname + rank, opt, rank_offset, gpu, init_method=init_method ) as opt: # Run the actual training opt['multiprocessing'] = True diff --git a/parlai/scripts/train_model.py b/parlai/scripts/train_model.py index 861682cafa1..1ccd8347611 100644 --- a/parlai/scripts/train_model.py +++ b/parlai/scripts/train_model.py @@ -28,13 +28,16 @@ import json import numpy as np import signal -from typing import Dict from parlai.core.metrics import Metric from parlai.core.agents import create_agent, create_agent_from_shared from parlai.core.exceptions import StopTrainException from parlai.core.logs import TensorboardLogger -from parlai.core.metrics import aggregate_named_reports, aggregate_unnamed_reports +from parlai.core.metrics import ( + aggregate_named_reports, + aggregate_unnamed_reports, + dict_report, +) from parlai.core.params import ParlaiParser, print_announcements from parlai.core.worlds import create_task from parlai.scripts.build_dict import build_dict, setup_args as setup_dict_args @@ -384,9 +387,6 @@ def save_model(self, suffix=None): except KeyboardInterrupt: pass - def _safe_report(self, report: Dict[str, Metric]): - return {k: v.value() if isinstance(v, Metric) else v for k, v in report.items()} - def _save_train_stats(self, suffix=None): fn = self.opt['model_file'] if suffix: @@ -423,7 +423,7 @@ def validate(self): valid_report = self._run_eval( self.valid_worlds, opt, 'valid', opt['validation_max_exs'] ) - v = self._safe_report(valid_report.copy()) + v = dict_report(valid_report) v['train_time'] = self.train_time.time() v['parleys'] = self.parleys v['total_exs'] = self._total_exs @@ -616,7 +616,7 @@ def log(self): train_report = self._sync_metrics(train_report) self.world.reset_metrics() - train_report_trainstats = self._safe_report(train_report) + train_report_trainstats = dict_report(train_report) train_report_trainstats['total_epochs'] = self._total_epochs train_report_trainstats['total_exs'] = self._total_exs train_report_trainstats['parleys'] = self.parleys diff --git a/parlai/utils/distributed.py b/parlai/utils/distributed.py index 81e372d6139..97004110046 100644 --- a/parlai/utils/distributed.py +++ b/parlai/utils/distributed.py @@ -269,7 +269,7 @@ def sync_parameters(model: torch.nn.Module) -> bool: @contextlib.contextmanager def distributed_context( - rank, opt, port=61337, rank_offset=0, gpu=None, hostname='localhost' + rank, opt, rank_offset=0, gpu=None, init_method="tcp://localhost:61337" ): """ A context which wraps initialization of a distributed/multiprocessing run. @@ -285,14 +285,12 @@ def distributed_context( non-primary workers. :param opt: command line options - :param int port: - A TCP port to use. This will need to be changed to run multiple distributed training setups on the same machine. :param int gpu: Which GPU to use. Defaults to using rank and local devices, but must be manually specified when using many-hosts. - :param str hostname: - Hostname of the main server. + :param str init method: + Init method, such as ``tcp://localhost:61337``. See torch.distributed docs. """ # Set per-host options opt = copy.deepcopy(opt) @@ -322,7 +320,7 @@ def distributed_context( torch.cuda.set_device(opt['gpu']) dist.init_process_group( backend="nccl", - init_method="tcp://{}:{}".format(hostname, port), + init_method=init_method, world_size=opt['distributed_world_size'], rank=rank, ) @@ -379,7 +377,7 @@ def slurm_distributed_context(opt): ) # Begin distributed training with distributed_context( - distributed_rank, opt, port, 0, device_id, main_host + distributed_rank, opt, 0, device_id, init_method=f"tcp://{main_host}:{port}" ) as opt: yield opt except subprocess.CalledProcessError as e: diff --git a/parlai/utils/logging.py b/parlai/utils/logging.py index 53e2af698ed..9ec511b154d 100644 --- a/parlai/utils/logging.py +++ b/parlai/utils/logging.py @@ -4,6 +4,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import os import sys import logging @@ -48,6 +49,8 @@ def _is_interactive(): + if os.environ.get('PARLAI_FORCE_COLOR'): + return True try: __IPYTHON__ return True @@ -72,19 +75,20 @@ def __init__(self, name, console_level=INFO): self.streamHandler = logging.StreamHandler(sys.stdout) # Log to stdout levels: console_level and above self.prefix = None + self.interactive = _is_interactive() self.streamHandler.setFormatter(self._build_formatter()) super().addHandler(self.streamHandler) def _build_formatter(self): prefix_format = f'{self.prefix} ' if self.prefix else '' - if COLORED_LOGS and _is_interactive(): + if COLORED_LOGS and self.interactive: return coloredlogs.ColoredFormatter( prefix_format + COLORED_FORMAT, datefmt=CONSOLE_DATE_FORMAT, level_styles=COLORED_LEVEL_STYLES, field_styles={}, ) - elif _is_interactive(): + elif self.interactive: return logging.Formatter( prefix_format + CONSOLE_FORMAT, datefmt=CONSOLE_DATE_FORMAT ) @@ -93,6 +97,10 @@ def _build_formatter(self): prefix_format + LOGFILE_FORMAT, datefmt=LOGFILE_DATE_FORMAT ) + def force_interactive(self): + self.interactive = True + self.streamHandler.setFormatter(self._build_formatter()) + def log(self, msg, level=INFO): """ Default Logging function. diff --git a/parlai/utils/misc.py b/parlai/utils/misc.py index f0b33b05f6c..80b40f6f29d 100644 --- a/parlai/utils/misc.py +++ b/parlai/utils/misc.py @@ -16,6 +16,7 @@ import re import shutil import json +import os from parlai.core.message import Message from parlai.utils.strings import colorize @@ -357,6 +358,11 @@ def float_formatter(f: Union[float, int]) -> str: def _line_width(): + if os.environ.get('PARLAI_FORCE_WIDTH'): + try: + return int(os.environ['PARLAI_FORCE_WIDTH']) + except ValueError: + pass try: # if we're in an interactive ipython notebook, hardcode a longer width __IPYTHON__ @@ -410,17 +416,12 @@ def nice_report(report) -> str: df = pd.DataFrame([output]) df.columns = pd.MultiIndex.from_tuples(df.columns) df = df.stack().transpose().droplevel(0, axis=1) - result = ( - " " - + df.to_string( - na_rep="", - line_width=line_width - 3, # -3 for the extra spaces we add - float_format=float_formatter, - index=df.shape[0] > 1, - ) - .replace("\n\n", "\n") - .replace("\n", "\n ") - ) + result = " " + df.to_string( + na_rep="", + line_width=line_width - 3, # -3 for the extra spaces we add + float_format=float_formatter, + index=df.shape[0] > 1, + ).replace("\n\n", "\n").replace("\n", "\n ") result = re.sub(r"\s+$", "", result) return result else: