Remove unnecessary endpoint logic, rename collaborative to hivemind #13392

Merged (7 commits) on Jun 28, 2022
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -44,6 +44,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Hivemind Strategy
* Added `CollaborativeStrategy` ([#12842](https://github.com/PyTorchLightning/pytorch-lightning/pull/12842))
* Renamed `CollaborativeStrategy` to `HivemindStrategy` ([#13388](https://github.com/PyTorchLightning/pytorch-lightning/pull/13388))
* Remove unnecessary endpoint logic, rename `collaborative` to `hivemind` ([#13392](https://github.com/PyTorchLightning/pytorch-lightning/pull/13392))

- Include a version suffix for new "last" checkpoints of later runs in the same directory ([#12902](https://github.com/PyTorchLightning/pytorch-lightning/pull/12902))

2 changes: 1 addition & 1 deletion src/pytorch_lightning/strategies/__init__.py
@@ -12,14 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.strategies.bagua import BaguaStrategy # noqa: F401
from pytorch_lightning.strategies.collaborative import HivemindStrategy # noqa: F401
from pytorch_lightning.strategies.ddp import DDPStrategy # noqa: F401
from pytorch_lightning.strategies.ddp2 import DDP2Strategy # noqa: F401
from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy # noqa: F401
from pytorch_lightning.strategies.deepspeed import DeepSpeedStrategy # noqa: F401
from pytorch_lightning.strategies.dp import DataParallelStrategy # noqa: F401
from pytorch_lightning.strategies.fully_sharded import DDPFullyShardedStrategy # noqa: F401
from pytorch_lightning.strategies.fully_sharded_native import DDPFullyShardedNativeStrategy # noqa: F401
from pytorch_lightning.strategies.hivemind import HivemindStrategy # noqa: F401
from pytorch_lightning.strategies.horovod import HorovodStrategy # noqa: F401
from pytorch_lightning.strategies.hpu_parallel import HPUParallelStrategy # noqa: F401
from pytorch_lightning.strategies.ipu import IPUStrategy # noqa: F401
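
For context, a minimal usage sketch of the renamed strategy (illustrative only; the `MyLightningModule` placeholder and the `Trainer` settings below are not part of this PR). The public import path from `pytorch_lightning.strategies` is unchanged by the module rename:

```python
import pytorch_lightning as pl
from pytorch_lightning.strategies import HivemindStrategy  # still re-exported here

# `target_batch_size` is the only required argument of the strategy.
trainer = pl.Trainer(
    strategy=HivemindStrategy(target_batch_size=8192),
    accelerator="gpu",
    devices=1,
)
# trainer.fit(MyLightningModule())  # placeholder LightningModule
```
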
src/pytorch_lightning/strategies/collaborative.py → src/pytorch_lightning/strategies/hivemind.py
@@ -1,22 +1,15 @@
import http
import ipaddress
import logging
import os
import platform
import re
import threading
import time
import warnings
from http.server import BaseHTTPRequestHandler
from typing import Any, Callable, Dict, List, Optional, Union

import requests
import torch
from torch import Tensor

import pytorch_lightning as pl
from pytorch_lightning.strategies.strategy import Strategy, TBroadcast
from pytorch_lightning.utilities import rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.data import extract_batch_size
from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -33,6 +26,8 @@


class HivemindStrategy(Strategy):
INITIAL_PEERS_ENV: str = "PL_INITIAL_PEERS"

def __init__(
self,
target_batch_size: int,
@@ -50,13 +45,6 @@ def __init__(
averager_opts: Optional[Dict] = None,
host_maddrs: Optional[List] = None,
initial_peers: Optional[Union[str, List]] = None,
endpoint: Optional[bool] = None,
peer_endpoint: Optional[str] = None,
persistent: bool = True,
host: Optional[str] = None,
port: Optional[int] = None,
retry_endpoint_attempts: int = 5,
retry_endpoint_sleep_duration: int = 5,
**optimizer_kwargs: Any,
):
"""Provides capabilities to train using the Hivemind Library, training collaboratively across the internet
@@ -81,11 +69,11 @@ def __init__(
corresponding :meth:`hivemind.Optimizer.step` call.

delay_optimizer_step: Run optimizer in background, apply results in future .step. requires
:paramref:`~pytorch_lightning.strategies.collaborative.HivemindStrategy.offload_optimizer`.
:paramref:`~pytorch_lightning.strategies.hivemind.HivemindStrategy.offload_optimizer`.

delay_grad_averaging: Average gradients in background; requires
:paramref:`~pytorch_lightning.strategies.collaborative.HivemindStrategy.offload_optimizer` and
:paramref:`~pytorch_lightning.strategies.collaborative.HivemindStrategy.delay_optimizer_step`.
:paramref:`~pytorch_lightning.strategies.hivemind.HivemindStrategy.offload_optimizer` and
:paramref:`~pytorch_lightning.strategies.hivemind.HivemindStrategy.delay_optimizer_step`.

offload_optimizer: Offload the optimizer to host memory, saving GPU memory for parameters and gradients.

@@ -118,26 +106,6 @@ def __init__(
initial_peers: If connecting to a running process, a list of initial peers needs to be passed in.
This can also be set via the env variable ``INITIAL_PEERS``.

endpoint: Enable if a side-car endpoint server is required on the process to server initial peers.
This is useful when using some form of orchestration such as torchelastic.

peer_endpoint: The endpoint to request initial peers from.

persistent: When using an endpoint, this controls whether other processes that are not the endpoint
server log/checkpoint. If ``persistent`` is True, we do not log/checkpoint from other processes.

host: When creating the endpoint, the host IP to use.

port: When creating the endpoint, the host port to use.

retry_endpoint_attempts: When connecting to the
:paramref:`~pytorch_lightning.strategies.collaborative.HivemindStrategy.peer_endpoint`,
how many time to retry before raising an exception.

retry_endpoint_sleep_duration: When connecting to the
:paramref:`~pytorch_lightning.strategies.collaborative.HivemindStrategy.peer_endpoint`,
how long to wait between retries.

**optimizer_kwargs: kwargs are passed to the :class:`hivemind.Optimizer` class.
"""
if not _HIVEMIND_AVAILABLE or platform.system() != "Linux":
@@ -147,17 +115,7 @@ def __init__(
)

super().__init__()
self.dht_manager = DHTManager(
persistent=persistent,
endpoint=endpoint,
peer_endpoint=peer_endpoint,
host=host,
port=port,
host_maddrs=host_maddrs,
initial_peers=initial_peers,
retry_endpoint_attempts=retry_endpoint_attempts,
retry_endpoint_sleep_duration=retry_endpoint_sleep_duration,
)
self._initial_peers = initial_peers
self._target_batch_size = target_batch_size
self._batch_size = batch_size
self._scheduler_fn = scheduler_fn
@@ -179,28 +137,38 @@ def __init__(
**optimizer_kwargs,
)

# a bit of a hack to only log from the stable server
if self.dht_manager.disable_logging_checkpointing:
warnings.warn(
"This machine is not a persistent machine. Checkpointing/Logging has been disabled.", UserWarning
self._parse_env_initial_peers()

self.dht = hivemind.DHT(
start=True,
initial_peers=initial_peers,
host_maddrs=host_maddrs if host_maddrs is not None else ["/ip4/0.0.0.0/tcp/0", "/ip4/0.0.0.0/udp/0/quic"],
)

visible_addresses = [
str(a) for a in self.dht.get_visible_maddrs() if not ipaddress.ip_address(a.values()[0]).is_loopback
]

if initial_peers is None:
log.info(
"\nOther machines can connect running the same command:\n"
f"INITIAL_PEERS={','.join(visible_addresses)} python ...\n"
"or passing the peers to the strategy:\n"
f"HivemindStrategy(initial_peers='{','.join(visible_addresses)}')"
)
rank_zero_only.rank = 1 if self.dht_manager.disable_logging_checkpointing else 0

self._hivemind_initialized = False

def _parse_env_initial_peers(self) -> None:
initial_peers = os.environ.get(self.INITIAL_PEERS_ENV, self._initial_peers)
self._initial_peers = initial_peers.split(",") if isinstance(initial_peers, str) else self._initial_peers

@property
def num_peers(self) -> int:
if self._opt:
return self._opt.tracker.global_progress.num_peers
return 1

@property
def dht(self) -> "hivemind.DHT":
"""Hivemind Distributed Hash Table which stores values across all peers.

See documentation for more details: `https://learning-at-home.readthedocs.io/en/latest/modules/dht.html`
"""
return self.dht_manager.dht

@property
def root_device(self) -> torch.device:
from pytorch_lightning.accelerators.cpu import CPUAccelerator
@@ -361,167 +329,3 @@ def load_state_dict(self, state_dict: Dict) -> None:

def state_dict(self) -> Dict:
return self.scheduler.state_dict()


class DHTManager:
ENDPOINT_ENV: str = "PL_ENDPOINT"
PEER_ENDPOINT_ENV: str = "PL_PEER_ENDPOINT"
INITIAL_PEERS_ENV: str = "PL_INITIAL_PEERS"
HOST_ENV: str = "PL_HOST"
PORT_ENV: str = "PL_PORT"
DEFAULT_HOST: str = "0.0.0.0"
DEFAULT_PORT: int = 1440

def __init__(
self,
host_maddrs: Optional[List],
initial_peers: Optional[Union[str, List]],
persistent: bool,
endpoint: Optional[bool],
peer_endpoint: Optional[str],
host: Optional[str],
port: Optional[int],
retry_endpoint_attempts: int = 5,
retry_endpoint_sleep_duration: int = 5,
) -> None:
"""Manages the `hivemind.DHT` connection and provides a side-car endpoint server for initial peer access.

Arguments:

host_maddrs: :paramref:`~pytorch_lightning.strategies.collaborative.HivemindStrategy.host_maddrs`

initial_peers: :paramref:`~pytorch_lightning.strategies.collaborative.HivemindStrategy.initial_peers`

persistent: :paramref:`~pytorch_lightning.strategies.collaborative.HivemindStrategy.persistent`

endpoint: :paramref:`~pytorch_lightning.strategies.collaborative.HivemindStrategy.endpoint`

peer_endpoint: :paramref:`~pytorch_lightning.strategies.collaborative.HivemindStrategy.peer_endpoint`

host: :paramref:`~pytorch_lightning.strategies.collaborative.HivemindStrategy.host`

port: :paramref:`~pytorch_lightning.strategies.collaborative.HivemindStrategy.port`

retry_endpoint_attempts:
:paramref:`~pytorch_lightning.strategies.collaborative.HivemindStrategy.retry_endpoint_attempts`

retry_endpoint_sleep_duration:
:paramref:
`~pytorch_lightning.strategies.collaborative.HivemindStrategy.retry_endpoint_sleep_duration`
"""
self._persistent = persistent
self._endpoint = endpoint
self._initial_peers = initial_peers
self._peer_endpoint = peer_endpoint
self._host = host
self._port = port

self._parse_env_vars()

if self._peer_endpoint and self._initial_peers is None:
self._initial_peers = self._get_initial_peers_from_endpoint(
retry_initial_peers=retry_endpoint_attempts, retry_peer_sleep_duration=retry_endpoint_sleep_duration
)

self.dht = hivemind.DHT(
start=True,
initial_peers=self._initial_peers,
host_maddrs=host_maddrs if host_maddrs is not None else ["/ip4/0.0.0.0/tcp/0", "/ip4/0.0.0.0/udp/0/quic"],
)

visible_addresses = [
str(a) for a in self.dht.get_visible_maddrs() if not ipaddress.ip_address(a.values()[0]).is_loopback
]

if self._endpoint:
self._host = self._host if self._host is not None else self.DEFAULT_HOST
self._port = self._port if self._port is not None else self.DEFAULT_PORT
self._start_server_process(self._host, self._port)
self._log_endpoint_helper_message(visible_addresses)
elif self._peer_endpoint:
log.info("Machine received initial peers from endpoint.")
elif self._initial_peers is None:
log.info(
"\nOther machines can connect running the same command:\n"
f"INITIAL_PEERS={','.join(visible_addresses)} python ...\n"
"or passing the peers to the strategy:\n"
f"HivemindStrategy(initial_peers='{','.join(visible_addresses)}')"
)

def _log_endpoint_helper_message(self, visible_addresses: List[str]) -> None:
assert self._host is not None
resolved_host = self._host
if "0.0.0.0" in self._host:
# use the visible multi-addresses to figure out the IP that has been exposed
# todo (sean): this is pretty hacky, worth investigating.
p = re.compile(r"[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+")
# todo (sean): we select one address from here, could we have multiple?
resolved_host = {p.findall(maddr)[0] for maddr in visible_addresses}.pop()
log.info(
"\nSidecar endpoint enabled to serve peers.\n"
"Other peers can connect via:\n"
f"PEER_ENDPOINT={resolved_host}:{self._port} python ...\n"
"or pass the peer endpoint address to the strategy:\n"
f"HivemindStrategy(peer_endpoint='{resolved_host}:{self._port}')"
)

def _start_server_process(self, host: str, port: int) -> None:
dht = self.dht

class DHTHandler(BaseHTTPRequestHandler):
def do_GET(self) -> None:
"""Respond to a GET request."""
self.send_response(200)
self.send_header("Content-type", "text/html")
self.end_headers()

visible_peers = [
str(a) for a in dht.get_visible_maddrs() if not ipaddress.ip_address(a.values()[0]).is_loopback
]

self.wfile.write("\n".join(visible_peers).encode())

server = http.server.ThreadingHTTPServer((host, int(port)), DHTHandler)
thread = threading.Thread(target=server.serve_forever)
thread.daemon = True
thread.start()

def _get_initial_peers_from_endpoint(self, retry_initial_peers: int, retry_peer_sleep_duration: int) -> List:
peers = None
for _ in range(retry_initial_peers):
try:
peers = self._get_peers()
break
except requests.exceptions.RequestException:
log.info(f"Failed to get peers, retrying in {retry_peer_sleep_duration} seconds...")
time.sleep(retry_peer_sleep_duration)
if peers is None:
raise MisconfigurationException(
f"Unable to get peers. Tried {retry_initial_peers} times waiting {retry_peer_sleep_duration}s."
f"These parameters can be extended by passing "
"to the strategy (HivemindStrategy(retry_connection=x, retry_sleep_duration=y))."
)
log.info(f"Received initial peers from collaborative server: {peers}")
return peers

def _get_peers(self) -> List[str]:
assert self._peer_endpoint is not None
url = f"http://{self._peer_endpoint}" if not self._peer_endpoint.startswith("http://") else self._peer_endpoint
r = requests.get(url)
return r.text.split(",")

def _parse_env_vars(self) -> None:
endpoint = os.environ.get(self.ENDPOINT_ENV, self._endpoint)
self._endpoint = endpoint == "1" if isinstance(endpoint, str) else endpoint
self._peer_endpoint = os.environ.get(self.PEER_ENDPOINT_ENV, self._peer_endpoint)
initial_peers = os.environ.get(self.INITIAL_PEERS_ENV, self._initial_peers)
self._initial_peers = initial_peers.split(",") if isinstance(initial_peers, str) else initial_peers

port = os.environ.get(self.PORT_ENV, self._port)
self._port = int(port) if isinstance(port, str) else port
self._host = os.environ.get(self.HOST_ENV, self._host)

@property
def disable_logging_checkpointing(self) -> bool:
# if this node is a peer, we do not log/checkpoint in persistent mode.
return self._persistent and (self._initial_peers is not None or self._peer_endpoint is not None)
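
With the `DHTManager` and its side-car HTTP endpoint removed, peer discovery is reduced to `initial_peers`, which can also come from the `PL_INITIAL_PEERS` environment variable parsed in `_parse_env_initial_peers`. A hedged sketch of that workflow, using placeholder script names and multiaddresses:

```python
# Machine A: start the first peer with no initial_peers; the strategy creates the
# DHT and logs its visible multiaddresses for other machines to join with.
#     python train.py
#
# Machine B: join either via the environment variable read by the strategy
# (comma-separated multiaddresses) ...
#     PL_INITIAL_PEERS=/ip4/203.0.113.1/tcp/46449/p2p/Qm... python train.py
#
# ... or by passing the peers explicitly:
from pytorch_lightning.strategies import HivemindStrategy

strategy = HivemindStrategy(
    target_batch_size=8192,
    initial_peers="/ip4/203.0.113.1/tcp/46449/p2p/Qm...",  # placeholder multiaddress
)
```

The addresses a joining peer needs are exactly the ones the first peer logs, so no endpoint, host/port, or retry configuration remains.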