diff --git a/benchmarks/pipe.py b/benchmarks/pipe.py index b67cb8169..4dcf00868 100644 --- a/benchmarks/pipe.py +++ b/benchmarks/pipe.py @@ -1,6 +1,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. import argparse +import logging import math import os import time @@ -11,14 +12,17 @@ from torch.distributed import rpc import torch.multiprocessing as mp import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data import DataLoader import torchtext from torchtext.data.utils import get_tokenizer from fairscale.nn import Pipe from fairscale.nn.model_parallel import initialize_model_parallel -from fairscale.nn.pipe import pipe +from fairscale.nn.model_parallel.initialize import get_data_parallel_group, get_pipeline_parallel_group +from fairscale.nn.pipe import LazyModule, pipe from fairscale.optim import GradScaler +from fairscale.optim.oss import OSS from tests.nn.model_parallel.commons import dist_init, get_worker_map try: @@ -164,13 +168,13 @@ def make_model(args, device, ntokens): if args.lazy_construction: layers = [ - lambda: EmbeddingLayer(ntokens, ninp, initrange), - lambda: PositionalEncodingLayer(ninp, dropout), + LazyModule(lambda: EmbeddingLayer(ntokens, ninp, initrange)), + LazyModule(lambda: PositionalEncodingLayer(ninp, dropout)), ] for _ in range(ndecoder): - layers.append(lambda: TransformerDecoderLayer(ninp, nhead, nhid, dropout)) + layers.append(LazyModule(lambda: TransformerDecoderLayer(ninp, nhead, nhid, dropout))) - layers.append(lambda: LinearLayer(ninp, ntokens, initrange)) + layers.append(LazyModule(lambda: LinearLayer(ninp, ntokens, initrange))) model = layers else: model = TransformerLMSequntial(ntokens, ninp, nhead, nhid, dropout, initrange, ndecoder).to(device) @@ -179,7 +183,10 @@ def make_model(args, device, ntokens): lr = 0.01 # learning rate def make_adam(model): - return Adam(model.parameters(), lr=lr) + if args.ddp_zero: + return OSS(params=model.parameters(), optim=Adam, group=get_data_parallel_group(), lr=lr) + else: + return Adam(model.parameters(), lr=lr) optimizer = make_adam scaler = GradScaler() @@ -276,9 +283,17 @@ def train(lm_dataloader, model, criterion, optimizer, vocab_size, args): num_params = reduce(operator.add, (reduce(operator.mul, x.size()) for x in model.parameters())) if model.group: - print(f"training model, #prams = {num_params}, group: {model.group.rank()}, sizes {model.group.size()}") + total = torch.Tensor([num_params]).cuda() + torch.distributed.all_reduce(total, group=model.group) + logging.info( + f"training model, #prams = {num_params}, group: {model.group.rank()}, grank:" + f" {torch.distributed.get_rank()}, sizes {model.group.size()}" + ) + torch.distributed.barrier() + if model.group.rank() == 0: + logging.info(f"total #prams = {total.item()}") else: - print(f"training model, #prams = {num_params}") + logging.info(f"training model, #prams = {num_params}") vocab_size = 10000 # FIXME total_loss = 0.0 start_time = time.time() @@ -287,37 +302,81 @@ def train(lm_dataloader, model, criterion, optimizer, vocab_size, args): optimizer = optimizer(model) def get_first_device(model): + if isinstance(model, DDP): + model = model.module + if model.devices: return model.devices[0] else: return torch.cuda.current_device() def get_last_device(model): + if isinstance(model, DDP): + model = model.module if model.devices: return model.devices[-1] else: return torch.cuda.current_device() + pipe_group = model.group + + if args.ddp_zero: + model = DDP( + model, + 
device_ids=[torch.cuda.current_device()], + process_group=get_data_parallel_group(), + find_unused_parameters=False, + ) + + if pipe_group and pipe_group.rank() != 0 and pipe_group.rank() != (pipe_group.size() - 1): + thing = {"input": torch.zeros(args.batch_size)} + + class FakeDataset: + def __getitem__(self, index): + return thing + + def __len__(self): + return len(lm_dataloader) + + lm_dataloader = FakeDataset() + for i, batch in enumerate(lm_dataloader): + bi = batch["input"] if args.max_batch and i > args.max_batch: break optimizer.zero_grad() - output = model(batch["input"].to(get_first_device(model))) - - if model.group is None or model.group.rank() == model.group.size() - 1: + try: + if (pipe_group is None or pipe_group.rank() == 0) and not args.ddp_zero: + tmp = batch["input"].to(get_first_device(model)) + output = model(tmp) + else: + output = model(batch["input"]) + except Exception as e: + raise RuntimeError(f"training failed on {torch.distributed.get_rank()}") from e + + if pipe_group is None or pipe_group.rank() == pipe_group.size() - 1: target = batch["target"].to(get_last_device(model)) output = output.to(target.device) + loss = criterion(output.view(-1, vocab_size), target.view(-1)) + if args.ddp_zero: + ddp_group = get_data_parallel_group() + torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM, group=ddp_group) + loss /= ddp_group.size() loss.backward() + del target else: - model.back_helper(output) + if args.ddp_zero: + model.module.back_helper(output) + else: + model.back_helper(output) del output torch.nn.utils.clip_grad_value_(model.parameters(), 0.05) optimizer.step() - if model.group is None or model.group.rank() == model.group.size() - 1: + if pipe_group is None or pipe_group.rank() == pipe_group.size() - 1: total_loss += loss.item() log_interval = 1 word_counter += batch["ntokens"] @@ -406,6 +465,17 @@ def benchmark_language_model(train_data, val_data, test_data, model, criterion, print("No regression detected") +def generate_balance_weighted(num_devices, num_layers, fraction=0.5): + balance = [] + layers_assigned = 0 + average_count = num_layers / num_devices + last_layers = int(average_count * fraction) + + balance = generate_balance(num_devices - 1, num_layers - last_layers) + balance.append(last_layers) + return balance + + def generate_balance(num_devices, num_layers): balance = [] layers_assigned = 0 @@ -460,7 +530,7 @@ def bench_single_process(args): blob = make_model_and_data(args, None, new_data=new_data) model = blob["model"] - balance = generate_balance(min(num_devices, 8), len(model)) + balance = generate_balance(min(num_devices, 4), len(model)) p = pipe.Pipe( model, balance, chunks=args.chunks, pipelined_backward=args.pipelined_backward, checkpoint=args.checkpoint ) @@ -480,16 +550,17 @@ def run_mp_worker(args, available_workers): blob = make_model_and_data(args, None, new_data=new_data) model = blob["model"] - balance = generate_balance(min(available_workers, 8), len(model)) + balance = generate_balance_weighted(get_pipeline_parallel_group().size(), len(model), 0.8) p = pipe.Pipe( model, balance, - style=Pipe.MultiProcess, + style=Pipe.AsyncSchedule, chunks=args.chunks, worker_map=get_worker_map(), input_device=torch.cuda.current_device(), pipelined_backward=args.pipelined_backward, checkpoint=args.checkpoint, + # loss_fn=blob["criterion"], ).cuda() if args.all_at_once and p.pipeline: @@ -537,18 +608,24 @@ def bench_multi_process(args, all_at_once=False): def bench_mpi(args): guess_rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) - 
os.environ["UCX_NET_DEVICES"] = best_device_map[guess_rank] + world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) + local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) + os.environ["UCX_NET_DEVICES"] = best_device_map[local_rank] - torch.distributed.init_process_group(backend="mpi") os.environ["MASTER_ADDR"] = args.host - os.environ["MASTER_PORT"] = "10639" + os.environ["MASTER_PORT"] = "10638" if args.socket_name: os.environ["GLOO_SOCKET_IFNAME"] = args.socket_name os.environ["TP_SOCKET_IFNAME"] = args.socket_name + + torch.distributed.init_process_group(backend="gloo", rank=guess_rank, world_size=world_size) + + os.environ["MASTER_ADDR"] = args.host + os.environ["MASTER_PORT"] = "10639" init_method = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}" rank = torch.distributed.get_rank() world_size = torch.distributed.get_world_size() - torch.cuda.set_device(rank % torch.cuda.device_count()) + torch.cuda.set_device(local_rank % torch.cuda.device_count()) rpc.init_rpc( f"Test{rank}", @@ -558,7 +635,12 @@ def bench_mpi(args): rpc_backend_options=rpc.ProcessGroupRpcBackendOptions(rpc_timeout=20, init_method=init_method), ) - initialize_model_parallel(1, world_size) + backends = {"model_parallel_backend": "nccl", "pipeline_backend": "mpi", "ddp_backend": "nccl"} + + if args.ddp_zero: + initialize_model_parallel(1, 4, **backends) + else: + initialize_model_parallel(1, world_size, **backends) init_random_seed(0) run_mp_worker(args, world_size) @@ -579,6 +661,7 @@ def bench_mpi(args): parser.add_argument("--max-batch", type=int, default=4, help="Max number of batches") parser.add_argument("--socket-name", type=str, default=None, help="socket ifname for gloo/tp") parser.add_argument("--num-decoder-layers", type=int, default=10, help="Number of decoder layers in the model") +parser.add_argument("--ddp-zero", action="store_true", default=False, help="enable ddp") parser.add_argument( "--lazy-construction", action="store_true", default=False, help="Number of decoder layers in the model" ) diff --git a/docs/Makefile b/docs/Makefile index d0c3cbf10..008578dde 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -12,7 +12,10 @@ BUILDDIR = build help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -.PHONY: help Makefile +setup: + pip install -r requirements.txt + +.PHONY: help Makefile setup # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
diff --git a/examples/tutorial_pipe_multiprocess.py b/examples/tutorial_pipe_multiprocess.py new file mode 100644 index 000000000..f57f1ec92 --- /dev/null +++ b/examples/tutorial_pipe_multiprocess.py @@ -0,0 +1,62 @@ +import os + +import torch +import torch.multiprocessing as mp +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim + +import fairscale +from fairscale.nn.model_parallel import initialize_model_parallel + + +def run(rank, world_size): + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "10638" + torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size) + os.environ["MASTER_PORT"] = "10639" + torch.distributed.rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size) + initialize_model_parallel(1, world_size) + + model = nn.Sequential(torch.nn.Linear(10, 10), torch.nn.ReLU(), torch.nn.Linear(10, 5)) + target = torch.randint(0, 2, size=(20, 1)).squeeze() + data = torch.randn(20, 10) + loss_fn = F.nll_loss + + device = torch.device("cuda", rank) + + model = fairscale.nn.Pipe( + model, + balance=[2, 1], + style=fairscale.nn.Pipe.MultiProcess, + worker_map={0: "worker0", 1: "worker1"}, # Needed to convert ranks to RPC worker names + input_device=device, + ).to(device) + + # define optimizer and loss function + optimizer = optim.SGD(model.parameters(), lr=0.001) + + # zero the parameter gradients + optimizer.zero_grad() + + # outputs and target need to be on the same device + # forward step + outputs = model(data.to(device)) + # compute loss + if rank == 1: + loss = loss_fn(outputs.to(device), target.to(device)) + + # backward + optimize + loss.backward() + optimizer.step() + else: + model.back_helper(outputs) + + print(f"Finished Training Step on {rank}") + + del model + + +if __name__ == "__main__": + world_size = 2 + mp.spawn(run, args=(world_size,), nprocs=world_size, join=True) diff --git a/examples/tutorial_pipe_rpc.py b/examples/tutorial_pipe_rpc.py new file mode 100644 index 000000000..fc22725cb --- /dev/null +++ b/examples/tutorial_pipe_rpc.py @@ -0,0 +1,76 @@ +# run with: +# mpirun -np 2 --host localhost:2 -x PYTHONPATH=$PWD python # examples/tutorial_pipe_rpc.py + +import os + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import torch_pg + +import fairscale +from fairscale.nn.model_parallel import initialize_model_parallel + + +def register_optimizer(ctx, model): + # Set the optimizer as an attribute on the model so we can access it later + model.optimizer = optim.SGD(model.parameters(), **ctx) + # zero the parameter gradients + model.optimizer.zero_grad() + + +def run_optimizer(ctx, model): + model.optimizer.step() + + +def run(rank, world_size): + torch_pg.init_mpi() + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "10638" + torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size) + os.environ["MASTER_PORT"] = "10639" + torch.distributed.rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size) + initialize_model_parallel(1, world_size, pipeline_backend="mpi") + + if rank == 1: + # For RPC, all ranks other than 0 just need to call rpc.shutdown() + torch.distributed.rpc.shutdown() + return + + model = nn.Sequential(torch.nn.Linear(10, 10), torch.nn.ReLU(), torch.nn.Linear(10, 5)) + target = torch.randint(0, 2, size=(20, 1)).squeeze() + data = torch.randn(20, 10) + loss_fn = F.nll_loss + + device = torch.device("cuda", rank) + + model = fairscale.nn.PipeRPCWrapper( + model, + 
balance=[2, 1], + worker_map={0: "worker0", 1: "worker1"}, # Needed to convert ranks to RPC worker names + input_device=device, + ).to(device) + + # We can't directly access the model on each worker, so we need to call + # foreach_worker with a callback to setup the optimizer + model.foreach_worker(register_optimizer, {"lr": 0.001}, include_self=True) + + outputs = model(data.to(device)) + loss = loss_fn(outputs.to(device), target.to(device)) + loss.backward() + + # Same as earlier, use foreach_worker to step the optimizer on each rank + model.foreach_worker(run_optimizer, include_self=True) + + print(f"Finished Training Step on {rank}") + + torch.distributed.rpc.shutdown() + + del model + + +if __name__ == "__main__": + rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) + world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) + run(rank, world_size) diff --git a/fairscale/nn/__init__.py b/fairscale/nn/__init__.py index 39e347057..45690576f 100644 --- a/fairscale/nn/__init__.py +++ b/fairscale/nn/__init__.py @@ -4,6 +4,6 @@ # LICENSE file in the root directory of this source tree. from .moe import MOELayer, Top2Gate -from .pipe import Pipe +from .pipe import LazyModule, Pipe, PipeRPCWrapper -__all__ = ["Pipe", "Top2Gate"] +__all__ = ["Pipe", "PipeRPCWrapper", "Top2Gate", "LazyModule"] diff --git a/fairscale/nn/model_parallel/initialize.py b/fairscale/nn/model_parallel/initialize.py index 3828b2425..c42a61dcc 100644 --- a/fairscale/nn/model_parallel/initialize.py +++ b/fairscale/nn/model_parallel/initialize.py @@ -22,7 +22,7 @@ """Model and data parallel groups.""" -from typing import List +from typing import List, Optional import torch @@ -38,7 +38,14 @@ _PIPELINE_PARALLEL_RANKS = None -def initialize_model_parallel(model_parallel_size_: int, pipeline_length: int = 1) -> None: +def initialize_model_parallel( + model_parallel_size_: int, + pipeline_length: int = 1, + *, + model_parallel_backend: Optional[str] = None, + pipeline_backend: Optional[str] = None, + ddp_backend: Optional[str] = None +) -> None: """ Initialize model data parallel groups. @@ -57,8 +64,6 @@ def initialize_model_parallel(model_parallel_size_: int, pipeline_length: int = with a total of 16 GPUs, rank 0 to 7 belong to the first box and ranks 8 to 15 belong to the second box. """ - if torch.distributed.get_rank() == 0: - print("> initializing model parallel with size {}".format(model_parallel_size_)) # Get world size and rank. Ensure some consistencies. 
assert torch.distributed.is_initialized() world_size = torch.distributed.get_world_size() @@ -69,6 +74,11 @@ def initialize_model_parallel(model_parallel_size_: int, pipeline_length: int = data_parallel_size = int(world_size / (model_parallel_size * pipeline_length)) + if torch.distributed.get_rank() == 0: + print("> initializing model parallel with size {}".format(model_parallel_size_)) + print("> initializing ddp with size {}".format(data_parallel_size)) + print("> initializing pipeline with size {}".format(pipeline_length)) + groups = torch.LongTensor(range(world_size)).reshape(data_parallel_size, pipeline_length, model_parallel_size) found = torch.where(groups == rank) @@ -80,7 +90,7 @@ def initialize_model_parallel(model_parallel_size_: int, pipeline_length: int = assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized" for j in range(pipeline_length): for k in range(model_parallel_size): - group = torch.distributed.new_group(groups[:, j, k].tolist()) + group = torch.distributed.new_group(groups[:, j, k].tolist(), backend=ddp_backend) if j == found[1] and k == found[2]: _DATA_PARALLEL_GROUP = group @@ -89,7 +99,7 @@ def initialize_model_parallel(model_parallel_size_: int, pipeline_length: int = assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized" for i in range(data_parallel_size): for j in range(pipeline_length): - group = torch.distributed.new_group(groups[i, j, :].tolist()) + group = torch.distributed.new_group(groups[i, j, :].tolist(), backend=model_parallel_backend) if i == found[0] and j == found[1]: _MODEL_PARALLEL_GROUP = group @@ -100,7 +110,7 @@ def initialize_model_parallel(model_parallel_size_: int, pipeline_length: int = for i in range(data_parallel_size): for k in range(model_parallel_size): ranks = groups[i, :, k].tolist() - group = torch.distributed.new_group(ranks) + group = torch.distributed.new_group(ranks, backend=pipeline_backend) if i == found[0] and k == found[2]: _PIPELINE_PARALLEL_GROUP = group _PIPELINE_PARALLEL_RANKS = ranks diff --git a/fairscale/nn/model_parallel/mappings.py b/fairscale/nn/model_parallel/mappings.py index 4afabe686..78d0961c5 100644 --- a/fairscale/nn/model_parallel/mappings.py +++ b/fairscale/nn/model_parallel/mappings.py @@ -39,7 +39,6 @@ def _reduce(ctx: Any, input_: torch.Tensor) -> torch.Tensor: return input_ # All-reduce. 
- print(f"doing all_reduce on {torch.distributed.get_rank()}") torch.distributed.all_reduce(input_, group=group) return input_ @@ -93,12 +92,10 @@ class _CopyToModelParallelRegion(torch.autograd.Function): @staticmethod def forward(ctx, input_): # type: ignore - print(f"{torch.distributed.get_rank()}: _CopyToModelParallelRegion Forward") return input_ @staticmethod def backward(ctx, grad_output): # type: ignore - print(f"{torch.distributed.get_rank()}: _CopyToModelParallelRegion Backward") return _reduce(None, grad_output) @@ -107,12 +104,10 @@ class _ReduceFromModelParallelRegion(torch.autograd.Function): @staticmethod def forward(ctx, input_): # type: ignore - print(f"{torch.distributed.get_rank()}: _ReduceFromModelParallelRegion Forward") return _reduce(ctx, input_) @staticmethod def backward(ctx, grad_output): # type: ignore - print(f"{torch.distributed.get_rank()}: _ReduceFromModelParallelRegion Backward") return grad_output @@ -121,12 +116,10 @@ class _ScatterToModelParallelRegion(torch.autograd.Function): @staticmethod def forward(ctx, input_): # type: ignore - print(f"{torch.distributed.get_rank()}: _ScatterToModelParallelRegion Forward") return _split(input_) @staticmethod def backward(ctx, grad_output): # type: ignore - print(f"{torch.distributed.get_rank()}: _ScatterToModelParallelRegion Backward") return _gather(grad_output) @@ -135,12 +128,10 @@ class _GatherFromModelParallelRegion(torch.autograd.Function): @staticmethod def forward(ctx, input_): # type: ignore - print(f"{torch.distributed.get_rank()}: _GatherFromModelParallelRegion Forward") return _gather(input_) @staticmethod def backward(ctx, grad_output): # type: ignore - print(f"{torch.distributed.get_rank()}: _GatherFromModelParallelRegion Backward") return _split(grad_output) diff --git a/fairscale/nn/model_parallel/random.py b/fairscale/nn/model_parallel/random.py index 6cc2ca1b2..e0775ed8a 100644 --- a/fairscale/nn/model_parallel/random.py +++ b/fairscale/nn/model_parallel/random.py @@ -182,11 +182,12 @@ def model_parallel_cuda_manual_seed(seed: int) -> None: ), flush=True, ) - _CUDA_RNG_STATE_TRACKER.reset() - # Set the default state. - torch.cuda.manual_seed(data_parallel_seed) - # and model parallel state. - _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, model_parallel_seed) + if torch.cuda.is_available(): + _CUDA_RNG_STATE_TRACKER.reset() + # Set the default state. + torch.cuda.manual_seed(data_parallel_seed) + # and model parallel state. + _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, model_parallel_seed) class CheckpointFunction(torch.autograd.Function): diff --git a/fairscale/nn/pipe/__init__.py b/fairscale/nn/pipe/__init__.py index 03e779303..75591bf09 100644 --- a/fairscale/nn/pipe/__init__.py +++ b/fairscale/nn/pipe/__init__.py @@ -19,6 +19,7 @@ """A Pipe implementation in PyTorch.""" from .checkpoint import is_checkpointing, is_recomputing -from .pipe import Pipe +from .pipe import LazyModule, Pipe +from .rpc import PipeRPCWrapper -__all__ = ["Pipe", "is_checkpointing", "is_recomputing"] +__all__ = ["Pipe", "is_checkpointing", "is_recomputing", "LazyModule"] diff --git a/fairscale/nn/pipe/async_schedule.py b/fairscale/nn/pipe/async_schedule.py new file mode 100644 index 000000000..52448f6f2 --- /dev/null +++ b/fairscale/nn/pipe/async_schedule.py @@ -0,0 +1,461 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ +from collections import OrderedDict +from enum import Enum, auto +from threading import Event +from typing import Dict, Iterable, List, Optional, Tuple + +from dataclasses import dataclass +import torch +from torch import Tensor, nn +from torch.distributed import ProcessGroup + +from fairscale.nn.model_parallel import get_pipeline_parallel_ranks + +from .messages import Transport +from .microbatch import Batch +from .skip.tracker import SkipTrackerThroughPotals +from .types import EVENT_LOOP_QUEUE, PipelineStyle, PipeMessage, Tensors + + +@dataclass(frozen=True) +class Location: + stage: int + index: int + + def __repr__(self) -> str: + return f"{self.stage}@{self.index}" + + +@dataclass(frozen=True) +class Invocation: + order: int + this: Location + source: Optional[Location] + dest: Optional[Location] + + +Activations = Dict[int, Dict[int, Dict[int, Batch]]] +Invocations = Dict[int, Invocation] + + +@dataclass(frozen=True) +class TailBackwardContext: + activations: Activations + invocations: Invocations + count_per_order: Dict[int, int] + expected_gradients: int + + +class ModuleWrapper: + def __init__(self, module: nn.Sequential, location: Location, invocations: Optional[List[Invocation]] = None): + self.module: nn.Sequential = module + self.location: Location = location + self.invocations: List[Invocation] = invocations or [] + + def __repr__(self) -> str: + return f"{self.location}:\n" + "\n".join(map(str, self.invocations)) + "\n\t" + str(self.module) + + def __len__(self) -> int: + return len(self.module) + + def __iter__(self) -> Iterable: + yield from self.module + + +class AsyncMessageType(Enum): + Activations = auto() + Gradients = auto() + + +@dataclass(frozen=True) +class AsyncMessageBody: + message_type: AsyncMessageType + microbatch_index: int + source: Location + dest: Location + order: int + + +class AutogradWithoutActivations(torch.autograd.Function): + """A helper class to add another edge in the autograd graph which allows us + to delete the potentially large activations and still perform a backward + pass. 
Returns return a phony tensor which is connected to the graph.""" + + @staticmethod + # type: ignore + def forward(ctx, *x): + return torch.tensor(1.0) + + @staticmethod + # type: ignore + def backward(ctx, grad): + assert ctx.grad_from_pipeline is not None + return ctx.grad_from_pipeline + + +class AsyncRecvOperator(torch.autograd.Function): + """Receive activations to the previous pipeline stage""" + + @staticmethod + # type: ignore + def forward(ctx, phony: Tensor, transport: Transport, message: PipeMessage) -> Tensors: + ctx.transport = transport + ctx.index = message.args.microbatch_index + + result = transport.recv_message_tensors(message) + + ctx.args = result.args + + def maybe_requires_grad(t: Tensor) -> Tensor: + if t.dtype.is_floating_point: + return t.requires_grad_() + return t + + return tuple(maybe_requires_grad(r) for r in result.tensors) + + @staticmethod + # type: ignore + def backward(ctx, *grad: Tensor,) -> Tuple[Optional[Tensor], ...]: + ranks = get_pipeline_parallel_ranks() + this_rank = torch.distributed.get_rank() + body = AsyncMessageBody( + AsyncMessageType.Gradients, ctx.index, source=ctx.args.dest, dest=ctx.args.source, order=ctx.args.order - 1 + ) + ctx.transport.send_message( + PipeMessage( + this_rank, ranks[ctx.args.source.stage], queue_name=EVENT_LOOP_QUEUE, args=body, tensors=tuple(grad), + ), + sync=True, + ) + + tail_ctx = getattr(ctx, "tail_ctx", None) + if tail_ctx: + expected_gradients = tail_ctx.expected_gradients + while expected_gradients > 0: + message = ctx.transport.recv_message_header(EVENT_LOOP_QUEUE) + + args: AsyncMessageBody = message.args + assert args.message_type is AsyncMessageType.Gradients + + invocation = tail_ctx.invocations[args.order] + expected_gradients -= tail_ctx.count_per_order[invocation.order] + AsyncEventLoop.perform_backward_for_invocation(ctx.transport, message, tail_ctx.activations, invocation) + + return (None, None, None, None, None) + + +class AsyncEventLoop: + def __init__( + self, + partitions: List[ModuleWrapper], + group: ProcessGroup, + transport: Transport, + training: bool, + checkpoint_stop: int, + ): + self.training = training + self.checkpoint_stop = checkpoint_stop + self.transport = transport + self.group = group + self.partitions: List[ModuleWrapper] = partitions + + def send_async_message(self, dst_rank: int, result: Batch, invocation: Invocation) -> Batch: + """Send batch to dst_rank, and use AutogradWithoutActivations to delete + the activations since we no longer need them""" + + assert invocation.dest + src_rank = torch.distributed.get_rank() + + body = AsyncMessageBody( + AsyncMessageType.Activations, result.index, invocation.this, invocation.dest, invocation.order + 1 + ) + self.transport.send_message( + PipeMessage(src_rank, dst_rank, queue_name=EVENT_LOOP_QUEUE, args=body, tensors=tuple([*result])), + sync=True, + ) + + phony = AutogradWithoutActivations.apply(*result) + return Batch(phony, result.index) + + def run_invocation( + self, + batch: Batch, + partition: ModuleWrapper, + skip_trackers: List[SkipTrackerThroughPotals], + invocation: Invocation, + ) -> Batch: + """Actually run the forward pass for a given module, and send the result + to the next stage in the pipeline if needed.""" + assert self.group + from .pipeline import create_task + + task = create_task( + PipelineStyle.AsyncSchedule, + self.checkpoint_stop, + batch.index, + self.group.rank(), + batch, + partition.module, + skip_trackers, + [], + ) + result = task.compute() + task.finalize(result) + + if invocation.dest and 
invocation.dest.stage != invocation.this.stage: + ranks = get_pipeline_parallel_ranks() + dst_rank = ranks[invocation.dest.stage] + result = self.send_async_message(dst_rank, result, invocation) + return result + + @staticmethod + def perform_backward_for_invocation( + transport: Transport, message: PipeMessage, activations: Activations, invocation: Invocation + ) -> None: + """Perform the backward pass by looking up the appropriate `Batch` and + then calling `backward` on the tensor""" + + recvd_grads = transport.recv_message_tensors(message) + + batch: Batch = activations[invocation.this.index][invocation.order][message.args.microbatch_index] + + # All batches saved in `activations` are generated by AutogradWithoutActivations, + # so we store the gradients in `grad_from_pipeline` so it will be used + # during the backward pass + batch.tensor.grad_fn.grad_from_pipeline = tuple(recvd_grads.tensors) # type: ignore + batch.tensor.backward(retain_graph=True) + + def run_invocations_on_batch( + self, + batch: Batch, + invocations: Invocations, + order: int, + skip_trackers: List[SkipTrackerThroughPotals], + activations: Activations, + ) -> Tuple[int, int]: + """Run invocations on the batch until we hit one that receives its input + from a different stage (i.e. another process)""" + + invocations_handled = 0 + last_order = 0 + for invocation in invocations.values(): + if invocation.order < order: + continue + pi = invocation.this.index + partition = self.partitions[pi] + + if invocation.order == order: + invocations_handled += 1 + last_order = invocation.order + activations[pi][invocation.order][batch.index] = self.run_invocation( + batch, partition, skip_trackers, invocation + ) + elif invocation.source and invocation.source.stage == self.group.rank(): + invocations_handled += 1 + last_order = invocation.order + batch = activations[invocation.source.index][invocation.order - 1][batch.index] + activations[pi][invocation.order][batch.index] = self.run_invocation( + batch, partition, skip_trackers, invocation + ) + del activations[invocation.source.index][invocation.order - 1][batch.index] + + elif invocation.source and invocation.source.stage != self.group.rank(): + break + + return (invocations_handled, last_order) + + def event_loop_head( + self, batches: List[Batch], skip_trackers: List[SkipTrackerThroughPotals], event: Optional[Event] + ) -> None: + """The event loop for the "head", which first performs the forward pass + on any applicable layers for this stage, and then enters the common + `event_loop_inner`""" + + invocations, activations = self.get_invocations_and_activations() + + expected_invocations = len(invocations) * len(batches) + actual_invocations = 0 + + count_per_order = dict() + + for batch in batches: + inv_count, last_order = self.run_invocations_on_batch(batch, invocations, 0, skip_trackers, activations) + actual_invocations += inv_count + count_per_order[last_order] = inv_count + + if actual_invocations < expected_invocations or self.training: + self.event_loop_inner( + expected_invocations, + skip_trackers, + activations, + invocations, + count_per_order, + already_received=actual_invocations, + event=event, + ) + + def get_batch_from_message(self, message: PipeMessage) -> Batch: + """Get the tensor(s) wrapped in a `Batch` from a `PipeMessage`, applying + AsyncRecvOperator so we can intercept the backward pass""" + + microbatch_index = message.args.microbatch_index + phony = torch.empty(0, device=self.transport.input_device, requires_grad=True) + result = 
AsyncRecvOperator.apply(phony, self.transport, message) + if len(result) == 1: + batch = Batch(result[0], microbatch_index) + else: + batch = Batch(result, microbatch_index) + return batch + + def event_loop_tail(self, batches: List[Batch], skip_trackers: List[SkipTrackerThroughPotals]) -> None: + """The event loop for the "tail", or final stage which only processes + activations and then returns to the caller so that the loss can be + calculated. This also handles the first/only stage for the special + case of a 1-stage pipeline.""" + + assert self.group + + invocations, activations = self.get_invocations_and_activations() + expected_invocations = len(invocations) * len(batches) + actual_invocations = 0 + + rank = self.group.rank() + count_per_order = dict() + + for batch in batches: + if rank == 0: + order = 0 + else: + message = self.transport.recv_message_header(EVENT_LOOP_QUEUE) + args: AsyncMessageBody = message.args + + batch = self.get_batch_from_message(message) + order = args.order + + inv_count, last_order = self.run_invocations_on_batch(batch, invocations, order, skip_trackers, activations) + actual_invocations += inv_count + count_per_order[last_order] = inv_count + + if invocations[last_order].dest is None: + self.prepare_tail_backward( + batch, activations, invocations, count_per_order, len(invocations) - inv_count + ) + + if actual_invocations < expected_invocations: + expected_gradients = 0 # (len(invocations) - 1) * len(batches) + + self.event_loop_inner( + expected_invocations, + skip_trackers, + activations, + invocations, + count_per_order, + already_received=actual_invocations, + ignore_gradients=True, + tail=True, + ) + + _, last_invocation = invocations.popitem() + + for index, batch in activations[len(self.partitions) - 1][last_invocation.order].items(): + batches[index] = batch + + def get_invocations_and_activations(self) -> Tuple[Invocations, Activations]: + activations: Activations = dict() + invocations: Invocations = OrderedDict() + + for pi, partition in enumerate(self.partitions): + activations[pi] = dict() + for invocation in partition.invocations: + activations[pi][invocation.order] = dict() + invocations[invocation.order] = invocation + + invocations = OrderedDict(sorted(invocations.items(), key=lambda entry: entry[0])) + + return (invocations, activations) + + def event_loop(self, num_microbatch: int, skip_trackers: List[SkipTrackerThroughPotals]) -> None: + """The event loop for the "middle", i.e. neither the head nor the tail""" + assert self.group + + invocations, activations = self.get_invocations_and_activations() + + expected_invocations = len(invocations) * num_microbatch + + self.event_loop_inner(expected_invocations, skip_trackers, activations, invocations, dict()) + + def event_loop_inner( + self, + expected_invocations: int, + skip_trackers: List[SkipTrackerThroughPotals], + activations: Activations, + invocations: Invocations, + count_per_order: Dict[int, int], + *, + already_received: int = 0, + ignore_gradients: bool = False, + event: Optional[Event] = None, + tail: bool = False, + ) -> None: + """The common event loop shared by all stages. 
This processses + activations for the forward pass, and if `self.training` is true, + processes gradients for the backward pass.""" + + num_activations = already_received + if self.training and not ignore_gradients: + num_gradients = 0 + else: + num_gradients = expected_invocations + + while num_activations < expected_invocations or num_gradients < expected_invocations: + if num_activations == expected_invocations and num_gradients == 0 and event is not None: + # We are ready to do the backward pass, but must wait for + # PipeRPCWrapper to signal that it is safe to proceed, otherwise + # deadlock + event.wait() + + message = self.transport.recv_message_header(EVENT_LOOP_QUEUE) + args: AsyncMessageBody = message.args + + invocation = invocations[args.order] + + # FIXME(tom) for combining pipeline with megatron, I currently don't + # control the order of received activations or gradients, so it is + # possible for a reused ColumnParallelLinear for example to receive + # a different order of activations w.r.t. the sending stage, which + # would result in incorrect values being used for the all_gather + if args.message_type is AsyncMessageType.Activations: + batch = self.get_batch_from_message(message) + + inv_count, last_order = self.run_invocations_on_batch( + batch, invocations, args.order, skip_trackers, activations + ) + count_per_order[last_order] = inv_count + num_activations += inv_count + if tail and invocations[last_order].dest is None: + self.prepare_tail_backward( + batch, activations, invocations, count_per_order, len(invocations) - inv_count + ) + + assert num_activations <= expected_invocations + + elif args.message_type is AsyncMessageType.Gradients: + num_gradients += count_per_order[invocation.order] + self.perform_backward_for_invocation(self.transport, message, activations, invocation) + + @staticmethod + def prepare_tail_backward( + batch: Batch, + activations: Activations, + invocations: Invocations, + count_per_order: Dict[int, int], + expected_gradients: int, + ) -> None: + if expected_gradients > 0: + grad_fn = next(b.grad_fn for b in batch if b.requires_grad) + assert grad_fn + grad_fn.tail_ctx = TailBackwardContext(activations, invocations, count_per_order, expected_gradients) diff --git a/fairscale/nn/pipe/messages.py b/fairscale/nn/pipe/messages.py new file mode 100644 index 000000000..0613274d7 --- /dev/null +++ b/fairscale/nn/pipe/messages.py @@ -0,0 +1,159 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ +from abc import ABC +from queue import Empty as QueueEmpty +from queue import Queue +from typing import Dict, List, Optional + +from dataclasses import dataclass +import torch + +from fairscale.nn.model_parallel import get_pipeline_parallel_group +from fairscale.utils.object import pyobject_to_tensor, tensor_to_pyobject + +from .types import MESSAGE_GENERATION_START, InputDevice, PipeMessage, Tensors + +MESSAGE_TENSOR_SIZE = 1024 + +MessageQueues: List[Queue] = [Queue() for _ in range(MESSAGE_GENERATION_START)] + + +def to_input_device(tensors: Tensors, input_device: InputDevice) -> Tensors: + if input_device is None: + return tensors + else: + return tuple(t.to(input_device) for t in tensors) + + +def rpc_push_queue(message: PipeMessage) -> None: + globals()["MessageQueues"][message.queue_name].put(message) + + +@dataclass(frozen=True) +class Transport(ABC): + worker_map: Optional[Dict[int, str]] + input_device: InputDevice + + def recv_message(self, queue_name: int, *, nowait: bool = False) -> PipeMessage: + message = self.recv_message_header(queue_name, nowait) + return self.recv_message_tensors(message) + + def recv_message_header(self, queue_name: int, nowait: bool = False) -> PipeMessage: + ... + + def recv_message_tensors(self, message: PipeMessage) -> PipeMessage: + ... + + def send_message(self, message: PipeMessage, sync: bool = False, skip_header: bool = False) -> None: + ... + + def get_out_of_order(self, queue_name: int, index: int) -> Tensors: + ... + + +def MakeTransport(use_rpc: bool, worker_map: Optional[Dict[int, str]], input_device: InputDevice) -> Transport: + if use_rpc: + if worker_map is None: + raise ValueError("'RpcTransport' requires 'worker_map' to be set") + return RpcTransport(worker_map, input_device) + else: + return SendRecvTransport(worker_map, input_device) + + +class RpcTransport(Transport): + def send_message(self, message: PipeMessage, sync: bool = False, skip_header: bool = False) -> None: + message.tensors = tuple(t.cpu() for t in message.tensors) + assert self.worker_map + name = self.worker_map[message.dest] + if sync: + torch.distributed.rpc.rpc_sync(name, rpc_push_queue, args=(message,)) + else: + torch.distributed.rpc.rpc_async(name, rpc_push_queue, args=(message,)) + + def recv_message_header(self, queue_name: int, nowait: bool = False) -> PipeMessage: + queue = MessageQueues[queue_name] + if nowait: + result = queue.get_nowait() + else: + result = queue.get() + result.tensors = to_input_device(result.tensors, self.input_device) + return result + + def recv_message_tensors(self, message: PipeMessage) -> PipeMessage: + # Tensors already contained within message + message.tensors = to_input_device(message.tensors, self.input_device) + return message + + def get_out_of_order(self, queue_name: int, index: int) -> Tensors: + """Receive a message with a known microbatch index, and handle out-of-order + messages by placing them back on the queue""" + + queue = globals()["MessageQueues"][queue_name] + out_of_order: List[PipeMessage] = [] + while True: + message = self.recv_message(queue_name) + got_index = message.args + value = message.tensors + if got_index == index: + for b in out_of_order: + queue.put(b) + return value + else: + out_of_order.append(message) + + +class SendRecvTransport(Transport): + def send_message(self, message: PipeMessage, sync: bool = False, skip_header: bool = False) -> None: + tensors = message.tensors + message.tensors = tuple() + torch.cuda.current_stream().synchronize() + if not skip_header: + message.tensor_shapes = 
[t.size() for t in tensors] + message.tensor_dtypes = [t.dtype for t in tensors] + torch.distributed.send( + pyobject_to_tensor(message, MESSAGE_TENSOR_SIZE).cuda(), + message.dest, + tag=message.queue_name, + group=get_pipeline_parallel_group(), + ) + for index, t in enumerate(tensors): + if t.device.type == "cpu": + t = t.cuda() + torch.distributed.send( + t.contiguous(), message.dest, tag=message.tag + index, group=get_pipeline_parallel_group() + ) + + def recv_message_header(self, queue_name: int, nowait: bool = False) -> PipeMessage: + # FIXME(handle nowait) + if nowait: + raise QueueEmpty + tensor = torch.empty(MESSAGE_TENSOR_SIZE, dtype=torch.uint8, device=self.input_device) + torch.cuda.current_stream().synchronize() + torch.distributed.recv(tensor, src=None, tag=queue_name, group=get_pipeline_parallel_group()) + torch.cuda.current_stream().synchronize() + return tensor_to_pyobject(tensor) + + def recv_message_tensors(self, message: PipeMessage) -> PipeMessage: + torch.cuda.current_stream().synchronize() + + message_tensors = [] + for index, (shape, dtype) in enumerate(zip(message.tensor_shapes, message.tensor_dtypes)): + t = torch.empty(*shape, dtype=dtype, device=self.input_device) + torch.distributed.recv(t, message.src, tag=message.tag + index, group=get_pipeline_parallel_group()) + message_tensors.append(t) + + message.tensors = tuple(message_tensors) + + torch.cuda.current_stream().synchronize() + return message + + def get_out_of_order(self, queue_name: int, index: int) -> Tensors: + """Receive a message with a known microbatch index, and handle out-of-order + messages by placing them back on the queue""" + + message = self.recv_message(queue_name) + assert message.args == index + return message.tensors diff --git a/fairscale/nn/pipe/pipe.py b/fairscale/nn/pipe/pipe.py index 8a6e6fdcf..d543c64fc 100644 --- a/fairscale/nn/pipe/pipe.py +++ b/fairscale/nn/pipe/pipe.py @@ -19,24 +19,29 @@ """The Pipe interface.""" from collections import OrderedDict -from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Tuple, Union, cast +import itertools +import threading +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union, cast import warnings +from dataclasses import dataclass, field import torch from torch import Tensor, nn import torch.autograd import torch.cuda -from fairscale.nn.model_parallel import get_model_parallel_group, get_pipeline_parallel_group +from fairscale.nn.model_parallel import get_model_parallel_world_size, get_pipeline_parallel_group from . import microbatch +from .async_schedule import Invocation, Location, ModuleWrapper from .batchnorm import DeferredBatchNorm -from .pipeline import Pipeline, PipelineStyle +from .pipeline import Pipeline from .skip.layout import SkipLayout, inspect_skip_layout from .skip.skippable import Skippable, verify_skippables from .stream import AbstractStream, new_stream +from .types import LazyModule, PipelineStyle -__all__ = ["Pipe"] +__all__ = ["Pipe", "LazyModule"] Device = Union[torch.device, int, str] @@ -45,7 +50,7 @@ Tensors = Tuple[Tensor, ...] 
TensorOrTensors = Union[Tensor, Tensors] -ListOfLazyModules = List[Callable[[], nn.Module]] +ListOfLazyModules = List[LazyModule] if TYPE_CHECKING: Module = nn.Module[TensorOrTensors] @@ -79,10 +84,10 @@ def verify_list_of_callable(module: Union[nn.Sequential, list]) -> None: for layer in module: if isinstance(layer, nn.Module): pass - elif callable(layer): + elif isinstance(layer, LazyModule): pass else: - raise TypeError(f"layer {type(layer)} must be nn.Module or callable to be partitioned") + raise TypeError(f"layer {type(layer)} must be nn.Module or LazyModule to be partitioned") def verify_module(module: Union[nn.Sequential, ListOfLazyModules]) -> None: @@ -124,8 +129,14 @@ class BalanceError(ValueError): pass -def check_balance(module: Any, balance: Iterable[int]) -> None: - if len(module) != sum(balance): +def check_balance(module: Any, balance: Iterable[int], filter_unique: bool = False) -> None: + + if filter_unique: + module_len = len(set(map(id, module))) + else: + module_len = len(module) + + if module_len != sum(balance): raise BalanceError( f"module and sum of balance have different length (module: {len(module)}, sum of balance: {sum(balance)})" ) @@ -134,16 +145,27 @@ def check_balance(module: Any, balance: Iterable[int]) -> None: raise BalanceError(f"all balance numbers must be positive integer (balance: {balance})") +@dataclass +class PartitionInfo: + location: Location + modules: "OrderedDict[str, nn.Module]" + invocations: List[Invocation] = field(default_factory=list) + + def __len__(self) -> int: + return len(self.modules) + + def instantiate_partition( - module: Union[nn.Sequential, ListOfLazyModules], balance: Iterable[int], group: torch.distributed.ProcessGroup -) -> nn.Sequential: + module: Union[nn.Sequential, ListOfLazyModules], + balance: Iterable[int], + group: torch.distributed.ProcessGroup, + style: PipelineStyle, +) -> List[ModuleWrapper]: balance = list(balance) - check_balance(module, balance) + check_balance(module, balance, True) layers: NamedModules = OrderedDict() - j = 0 - def maybe_realize(layer: Any) -> nn.Module: if isinstance(layer, nn.Module): return layer @@ -156,7 +178,85 @@ def iterate_module(module: Union[nn.Sequential, list]) -> Iterable[Tuple[Any, nn if isinstance(module, nn.Sequential): yield from module.named_children() else: - yield from enumerate(module) + yield from ((str(k), v) for k, v in enumerate(module)) + + if style == PipelineStyle.AsyncSchedule: + module_ids = list(map(id, module)) + index_of_first_use = [module_ids.index(x) for x in module_ids] + locations: List[Location] = [] + module_iter = enumerate(iterate_module(module)) + + partitions: List[List[PartitionInfo]] = [] + for bi, b in enumerate(balance): + modules_for_rank: List[PartitionInfo] = [] + current_module: OrderedDict[str, nn.Module] = OrderedDict() + + def current_location() -> Location: + return Location(bi, len(modules_for_rank)) + + def append_module(mod: "OrderedDict[str, nn.Module]") -> None: + modules_for_rank.append(PartitionInfo(current_location(), mod)) + + while sum(map(len, modules_for_rank)) + len(current_module) < b: + module_index, (name, layer) = next(module_iter) + + if index_of_first_use[module_index] != module_index: + # Subsequent reuse of a module + locations.append(locations[index_of_first_use[module_index]]) + continue + + is_reused = index_of_first_use.count(index_of_first_use[module_index]) > 1 + + if is_reused and len(current_module) > 0: + append_module(current_module) + current_module = OrderedDict() + + current_module[str(name)] = layer 
+ locations.append(current_location()) + + if is_reused: + append_module(current_module) + current_module = OrderedDict() + + if len(current_module) > 0: + append_module(current_module) + + partitions.append(modules_for_rank) + + filtered_locations: List[Optional[Location]] = [loc for loc, _ in itertools.groupby(locations)] + filtered_locations.append(None) + + for i in range(len(filtered_locations) - 1): + loc = filtered_locations[i] + assert loc + if i == 0: + inv = Invocation(i, loc, None, filtered_locations[i + 1]) + else: + inv = Invocation(i, loc, filtered_locations[i - 1], filtered_locations[i + 1]) + + partitions[loc.stage][loc.index].invocations.append(inv) + + invocations = enumerate(iterate_module(module)) + + partition = partitions[group.rank()] + result: List[ModuleWrapper] = [] + for partition_info in partition: + wrapper = ModuleWrapper( + nn.Sequential(OrderedDict((k, maybe_realize(m)) for k, m in partition_info.modules.items())), + partition_info.location, + partition_info.invocations, + ) + + if not isinstance(module, nn.Sequential): + for layer in wrapper.module: + if isinstance(layer, Skippable): + raise ValueError("Can't use Skippable layers with multi-process pipe and lazy construction") + + result.append(wrapper) + + return result + + j = 0 for name, layer in iterate_module(module): layers[name] = layer @@ -170,8 +270,7 @@ def iterate_module(module: Union[nn.Sequential, list]) -> Iterable[Tuple[Any, nn if isinstance(layer, Skippable): raise ValueError("Can't use Skippable layers with multi-process pipe and lazy construction") - partition = nn.Sequential(*layers.values()) - return partition + return [ModuleWrapper(nn.Sequential(layers), Location(j, 0))] # Prepare for the next partition. layers.clear() @@ -297,7 +396,7 @@ class Pipe(Module): backward pass (instead of once for the whole batch). This works around a potential deadlock in pytorch when using tensor parallelism at the same time. Defaults to `True` if - `get_model_parallel_group.size() > 1` + `get_model_parallel_world_size() > 1` (default: `None`) retain_graph (bool): The value passed to `torch.autograd.backwards(..., retain_graph=) @@ -315,6 +414,7 @@ class Pipe(Module): SingleProcess: PipelineStyle = PipelineStyle.SingleProcess MultiProcess: PipelineStyle = PipelineStyle.MultiProcess + AsyncSchedule: PipelineStyle = PipelineStyle.AsyncSchedule #: The number of layers in each partition. balance: List[int] = [] @@ -359,6 +459,7 @@ def __init__( deferred_batch_norm: bool = False, pipelined_backward: bool = None, retain_graph: bool = False, + loss_fn: Optional[nn.Module] = None, ) -> None: super().__init__() @@ -384,6 +485,17 @@ def __init__( self.pipelined_backward = pipelined_backward self.retain_graph = retain_graph self.pipeline: Optional[Pipeline] + self.loss_fn = loss_fn + self.lock = threading.Lock() + + self.group = group + self.worker_map = worker_map + self.input_device = input_device + + self._copy_streams: List[List[AbstractStream]] = [] + + # The micro-batch index where the checkpointing stops. + checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint] if style is PipelineStyle.SingleProcess: module = cast(nn.Sequential, module) @@ -407,29 +519,42 @@ def __init__( self._skip_layout = inspect_skip_layout(self.partitions) - elif style is PipelineStyle.MultiProcess: - if group is None: - group = get_pipeline_parallel_group() + # Separate CUDA streams for copy. 
+ copy_streams = self._ensure_copy_streams() + if self.pipelined_backward is None: + self.pipelined_backward = False + self.pipeline = Pipeline( + self.partitions, self.devices, copy_streams, self._skip_layout, checkpoint_stop, style=style, + ) + + elif style in [PipelineStyle.MultiProcess, PipelineStyle.AsyncSchedule]: + + if self.group is None: + self.group = get_pipeline_parallel_group() + assert self.group if devices is not None: raise ValueError("'devices' argument only applies to 'PipelineStyle.SingleProcess'") self.balance = list(balance) - if group.size() < len(self.balance): + if self.group.size() < len(self.balance): raise IndexError( - f"too few ranks to hold given partitions (ranks: {group.size()}, partitions: {len(self.balance)})" + f"too few ranks to hold given partitions (ranks: {self.group.size()}, partitions:" + f" {len(self.balance)})" ) try: - rank = torch.distributed.get_rank(group) + rank = self.group.rank() if rank >= len(self.balance): warnings.warn("More ranks than partitions, some ranks unused") - self.partitions = cast(List[nn.Sequential], nn.ModuleList([nn.Sequential()])) + self.mp_partitions: List[ModuleWrapper] = [] else: - partition = instantiate_partition(module, balance, group) + self.mp_partitions = instantiate_partition(module, balance, self.group, style) if deferred_batch_norm: - partition = DeferredBatchNorm.convert_deferred_batch_norm(partition, chunks) - self.partitions = cast(List[nn.Sequential], nn.ModuleList([partition])) + for part in self.mp_partitions: + part.module = DeferredBatchNorm.convert_deferred_batch_norm(part.module, chunks) + for name, part in enumerate(self.mp_partitions): + self.add_module(str(name), part.module) self.devices = None if isinstance(module, nn.Sequential): local_partitions, _, _ = split_module(module, balance, None) @@ -440,31 +565,16 @@ def __init__( except BalanceError as exc: raise ValueError(recommend_auto_balance(str(exc))) - self.group = group - self.worker_map = worker_map - self.input_device = input_device - - self._copy_streams: List[List[AbstractStream]] = [] - - # The micro-batch index where the checkpointing stops. - checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint] - - if style is PipelineStyle.SingleProcess: - # Separate CUDA streams for copy. 
- copy_streams = self._ensure_copy_streams() - if self.pipelined_backward is None: - self.pipelined_backward = False - self.pipeline = Pipeline( - self.partitions, self.devices, copy_streams, self._skip_layout, checkpoint_stop, style=style - ) - elif style is PipelineStyle.MultiProcess: - rank = torch.distributed.get_rank(group) + rank = self.group.rank() if rank >= len(self.balance): self.pipeline = None + self.final_stage = False else: self.final_stage = rank == len(self.balance) - 1 + assert loss_fn is None or self.final_stage + self.pipeline = Pipeline( - self.partitions, + cast(List[nn.Sequential], self.mp_partitions), None, None, self._skip_layout, @@ -473,27 +583,39 @@ def __init__( group=self.group, worker_map=self.worker_map, input_device=self.input_device, + final_stage=self.final_stage, ) del module if self.pipelined_backward is None: - if get_model_parallel_group().size() > 1: + if get_model_parallel_world_size() > 1: self.pipelined_backward = True else: self.pipelined_backward = False def __len__(self) -> int: """Counts the length of the underlying sequential module.""" - return sum(len(p) for p in self.partitions) + if hasattr(self, "partitions"): + return sum(len(p) for p in self.partitions) + else: + return sum(len(p) for p in self.mp_partitions) def __getitem__(self, index: int) -> nn.Module: """Gets a layer in the underlying sequential module.""" - partitions = self.partitions + partitions: List[Any] + if hasattr(self, "partitions"): + partitions = self.partitions + else: + partitions = self.mp_partitions + if index < 0: partitions = partitions[::-1] for partition in partitions: try: - return partition[index] + if isinstance(partition, ModuleWrapper): + return partition.module[index] + else: + return partition[index] except IndexError: pass @@ -508,8 +630,12 @@ def __getitem__(self, index: int) -> nn.Module: def __iter__(self) -> Iterable[nn.Module]: """Iterates over children of the underlying sequential module.""" - for partition in self.partitions: - yield from partition + if hasattr(self, "partitions"): + for partition in self.partitions: + yield from partition + else: + for mp_partition in self.mp_partitions: + yield from mp_partition.module # Pipe should manage the device of each partition. # Deny cuda(), cpu(), and to() with device, by TypeError. @@ -527,7 +653,7 @@ def cpu(self) -> "Pipe": return super().cpu() def to(self, *args: Any, **kwargs: Any) -> "Pipe": - """ Restrict .to() options. + """Restrict .to() options. Deny these usages: - to(device[, dtype, non_blocking]) @@ -563,7 +689,7 @@ def _ensure_copy_streams(self) -> List[List[AbstractStream]]: return self._copy_streams - def forward(self, input: TensorOrTensors) -> TensorOrTensors: # type: ignore + def forward(self, input: TensorOrTensors, *, event=None) -> TensorOrTensors: # type: ignore """:class:`Pipe` is a fairly transparent module wrapper. It doesn't modify the input and output signature of the underlying module. But there's type restriction. Input and output have to be a @@ -594,25 +720,26 @@ def forward(self, input: TensorOrTensors) -> TensorOrTensors: # type: ignore batches = microbatch.scatter(input, self.chunks) # Run pipeline parallelism. - self.pipeline.run(batches) - - if self.group and not self.final_stage: - # Don't merge micro-batches to avoid unnecessary edges in autograd - # graph - # FIXME(tom) should figure out a proper type here - return batches # type: ignore - else: - # Merge the micro-batches into one mini-batch. 
- if self.pipelined_backward: - with torch.no_grad(): - output = microbatch.gather(batches) + with self.lock: + self.pipeline.run(self.training, batches, event) + + if self.group and not self.final_stage: + # Don't merge micro-batches to avoid unnecessary edges in autograd + # graph + # FIXME(tom) should figure out a proper type here + return batches # type: ignore + else: + # Merge the micro-batches into one mini-batch. + if self.pipelined_backward: + with torch.no_grad(): + output = microbatch.gather(batches) - from .phony import get_phony + from .phony import get_phony - phony = get_phony(torch.device(torch.cuda.current_device()), requires_grad=True) - output = PipelinedBackwardPass.apply(output, batches, phony, True) # self.retain_graph) - else: - output = microbatch.gather(batches) + phony = get_phony(torch.device(torch.cuda.current_device()), requires_grad=True) + output = PipelinedBackwardPass.apply(output, batches, phony, True) # self.retain_graph) + else: + output = microbatch.gather(batches) return output diff --git a/fairscale/nn/pipe/pipeline.py b/fairscale/nn/pipe/pipeline.py index 4f395a7b1..2c50de019 100644 --- a/fairscale/nn/pipe/pipeline.py +++ b/fairscale/nn/pipe/pipeline.py @@ -17,214 +17,56 @@ # limitations under the License. """The pipeline parallelism of Pipe.""" -from enum import Enum, auto +import logging import os -import pickle from queue import Empty as QueueEmpty from queue import Queue +from threading import Event from types import TracebackType -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Type, Union, cast +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Type, Union, cast -from dataclasses import dataclass -import numpy as np import torch from torch import Tensor, nn from torch.autograd.profiler import record_function from fairscale.nn.model_parallel import get_pipeline_parallel_ranks +from .async_schedule import AsyncEventLoop, ModuleWrapper from .checkpoint import Checkpointing from .copy import Copy, Wait from .dependency import fork, join +from .messages import MakeTransport, Transport from .microbatch import Batch from .skip import Namespace from .skip.layout import SkipLayout from .skip.tracker import SkipTrackerThroughPotals, use_skip_tracker from .stream import AbstractStream, current_stream, use_device +from .types import ( + ACTIVATIONS_GRADS_QUEUE, + PORTAL_QUEUE, + SKIP_TENSOR_QUEUE, + PipelineStyle, + PipeMessage, + Schedule, + TensorOrTensors, + Tensors, +) from .worker import Task, create_workers, join_workers __all__: List[str] = [] - -Tensors = Tuple[Tensor, ...] -TensorOrTensors = Union[Tensor, Tensors] - -InputDevice = Union[None, int, str, torch.device] -Schedule = List[Tuple[int, int]] - ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType] -MessageQueues: List[Queue] = [Queue(), Queue(), Queue()] -ACTIVATIONS_GRADS_QUEUE = 0 -SKIP_TENSOR_QUEUE = 1 -PORTAL_QUEUE = 2 -MESSAGE_GENERATION_START = 3 - - -# FIXME Why is 256 ok for training but not for tests? 
-MESSAGE_TENSOR_SIZE = 512 # 256 - -MessageGeneration = MESSAGE_GENERATION_START - - -class PipelineStyle(Enum): - SingleProcess = auto() - MultiProcess = auto() - - -@dataclass(frozen=True) -class TransportConfig: - use_rpc: bool - worker_map: Optional[Dict[int, str]] - - -@dataclass -class PipeMessage: - src: int - dest: int - queue_name: int - args: Any - tensors: Tensors - tensor_shapes: List[torch.Size] - tensor_dtypes: List[torch.dtype] - tag: int = 0 - - def __init__(self, src: int, dest: int, queue_name: int, args: Any, tensors: Tensors): - self.src = src - self.dest = dest - self.queue_name = queue_name - self.args = args - self.tensors = tensors - - global MessageGeneration - self.tag = MessageGeneration - MessageGeneration += len(tensors) - - -def rpc_push_queue(message: PipeMessage) -> None: - globals()["MessageQueues"][message.queue_name].put(message) - - -def pyobject_to_tensor(obj: Any) -> Tensor: - pickled = pickle.dumps(obj) - nparray = np.frombuffer(pickled, dtype=np.uint8).copy() - nparray.setflags(write=True) - result = torch.from_numpy(nparray) - delta = MESSAGE_TENSOR_SIZE - len(result) - if delta < 0: - raise ValueError( - f"message too big to send, increase MESSAGE_TENSOR_SIZE? - {len(result)} > {MESSAGE_TENSOR_SIZE}" - ) - elif delta > 0: - result = torch.cat((result, torch.zeros(delta, dtype=torch.uint8))) - - return result.cuda() - - -def tensor_to_pyobject(tensor: Tensor) -> Any: - nparray = tensor.numpy() - return pickle.loads(nparray.tobytes()) - - -def send_message(config: TransportConfig, message: PipeMessage, sync: bool = False) -> None: - if config.use_rpc: - message.tensors = tuple(t.cpu() for t in message.tensors) - assert config.worker_map - name = config.worker_map[message.dest] - if sync: - torch.distributed.rpc.rpc_sync(name, rpc_push_queue, args=(message,)) - else: - torch.distributed.rpc.rpc_async(name, rpc_push_queue, args=(message,)) - else: - tensors = message.tensors - message.tensors = tuple() - message.tensor_shapes = [t.size() for t in tensors] - message.tensor_dtypes = [t.dtype for t in tensors] - torch.cuda.current_stream().synchronize() - torch.distributed.send(pyobject_to_tensor(message), message.dest, tag=0) - for index, t in enumerate(tensors): - if t.device.type == "cpu": - t = t.cuda() - torch.distributed.send(t, message.dest, tag=message.tag + index) - - -def recv_message( - config: TransportConfig, queue_name: int, *, nowait: bool = False, input_device: InputDevice = None -) -> PipeMessage: - if config.use_rpc: - queue = globals()["MessageQueues"][queue_name] - if nowait: - result = queue.get_nowait() - else: - result = queue.get() - result.tensors = to_input_device(result.tensors, input_device) - return result - else: - # FIXME(handle nowait) - if nowait: - raise QueueEmpty - - tensor = torch.empty(MESSAGE_TENSOR_SIZE, dtype=torch.uint8, device=input_device) - torch.distributed.recv(tensor, src=-1, tag=queue_name) - message = tensor_to_pyobject(tensor.cpu()) - - torch.cuda.current_stream().synchronize() - - message_tensors = [] - for index, (shape, dtype) in enumerate(zip(message.tensor_shapes, message.tensor_dtypes)): - t = torch.empty(*shape, dtype=dtype, device=input_device) - torch.distributed.recv(t, message.src, tag=message.tag + index) - message_tensors.append(t) - - message.tensors = tuple(message_tensors) - - torch.cuda.current_stream().synchronize() - return message - - -def get_out_of_order(config: TransportConfig, queue_name: int, index: int, *, input_device: InputDevice) -> Tensors: - """Receive a message with a known 
microbatch index, and handle out-of-order - messages by placing them back on the queue""" - - if config.use_rpc: - queue = globals()["MessageQueues"][queue_name] - out_of_order: List[PipeMessage] = [] - while True: - message = recv_message(config, queue_name, input_device=input_device) - got_index = message.args - value = message.tensors - if got_index == index: - for b in out_of_order: - queue.put(b) - return value - else: - out_of_order.append(message) - else: - message = recv_message(config, queue_name, input_device=input_device) - assert message.args == index - return message.tensors - - -def to_input_device(tensors: TensorOrTensors, input_device: InputDevice) -> TensorOrTensors: - if input_device is None: - return tensors - else: - if isinstance(tensors, Tensor): - return tensors.to(input_device) - else: - return tuple(t.to(input_device) for t in tensors) - class SendOperator(torch.autograd.Function): """Send activations to the next pipeline stage""" @staticmethod # type: ignore - def forward(ctx, src_rank, dst_rank, config: TransportConfig, input: List[Tensor], index: int) -> Tensors: + def forward(ctx, src_rank, dst_rank, transport: Transport, input: List[Tensor], index: int) -> Tensors: assert src_rank == torch.distributed.get_rank() - send_message( - config, + transport.send_message( PipeMessage(src_rank, dst_rank, queue_name=ACTIVATIONS_GRADS_QUEUE, args=index, tensors=tuple(input)), ) return () @@ -240,12 +82,12 @@ class RecvOperator(torch.autograd.Function): @staticmethod # type: ignore - def forward(ctx, dst_rank: int, tensor: Tensor, input_device, config: TransportConfig, index: int) -> Tensors: + def forward(ctx, dst_rank: int, tensor: Tensor, input_device, transport: Transport, index: int) -> Tensors: assert dst_rank == torch.distributed.get_rank() - ctx.config = config + ctx.transport = transport ctx.index = index - result = get_out_of_order(config, ACTIVATIONS_GRADS_QUEUE, index, input_device=input_device) + result = transport.get_out_of_order(ACTIVATIONS_GRADS_QUEUE, index) def maybe_requires_grad(t: Tensor) -> Tensor: if t.dtype.is_floating_point: @@ -259,8 +101,7 @@ def maybe_requires_grad(t: Tensor) -> Tensor: def backward(ctx, *grad: Tensor,) -> Tuple[Optional[Tensor], ...]: ranks = get_pipeline_parallel_ranks() this_rank = torch.distributed.get_rank() - send_message( - ctx.config, + ctx.transport.send_message( PipeMessage( this_rank, ranks[ranks.index(this_rank) - 1], @@ -318,6 +159,57 @@ def clock_cycles(m: int, n: int) -> Iterable[Schedule]: yield [(k - j, j) for j in range(max(1 + k - m, 0), min(1 + k, n))] +def create_task( + style: PipelineStyle, + checkpoint_stop: int, + i: int, + j: int, + batch: Batch, + partition: nn.Sequential, + skip_trackers: List[SkipTrackerThroughPotals], + streams: List[AbstractStream], +) -> Task: + # Determine whether checkpointing or not. 
+ if i < checkpoint_stop: + + def function( + input: TensorOrTensors, + partition: nn.Sequential = partition, + skip_tracker: SkipTrackerThroughPotals = skip_trackers[i], + chunk_id: int = i, + part_id: int = j, + ) -> TensorOrTensors: + with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)): + return partition(input) + + chk = Checkpointing(function, batch) + if style is PipelineStyle.SingleProcess: + task = Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute) + elif style in [PipelineStyle.MultiProcess, PipelineStyle.AsyncSchedule]: + task = Task(None, compute=chk.checkpoint, finalize=chk.recompute) + del function, chk # TODO(tom) maybe remove + + else: + + def compute( + batch: Batch = batch, + partition: nn.Sequential = partition, + skip_tracker: SkipTrackerThroughPotals = skip_trackers[i], + chunk_id: int = i, + part_id: int = j, + ) -> Batch: + with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)): + return batch.call(partition) + + if style is PipelineStyle.SingleProcess: + task = Task(streams[j], compute=compute, finalize=None) + elif style in [PipelineStyle.MultiProcess, PipelineStyle.AsyncSchedule]: + task = Task(None, compute=compute, finalize=None) + del compute # TODO(tom) maybe remove + + return task + + class Pipeline: """The pipeline parallelism for Pipe.""" @@ -332,52 +224,63 @@ def __init__( group: Optional[torch.distributed.ProcessGroup] = None, worker_map: Optional[Dict[int, str]] = None, input_device: Union[None, int, str, torch.device] = None, + final_stage: bool = False, ) -> None: - self.partitions = partitions + if style == PipelineStyle.SingleProcess: + self.partitions = partitions + else: + self.mp_partitions: List[ModuleWrapper] = cast(List[ModuleWrapper], partitions) self.devices = devices self.copy_streams = copy_streams self.skip_layout = skip_layout - self.checkpoint_stop = checkpoint_stop + self.__checkpoint_stop = checkpoint_stop self.style = style self.group = group - self.transport_config = TransportConfig( - use_rpc=("OMPI_COMM_WORLD_RANK" not in os.environ), worker_map=worker_map - ) - + self.training: bool + if style in [PipelineStyle.MultiProcess, PipelineStyle.AsyncSchedule]: + self.transport = MakeTransport( + use_rpc=("OMPI_COMM_WORLD_RANK" not in os.environ) or ("FORCE_RPC" in os.environ), + worker_map=worker_map, + input_device=input_device, + ) self.input_device = input_device self.all_at_once = False self.callcount = 0 + self.final_stage = final_stage if self.style is PipelineStyle.SingleProcess: assert self.devices is not None (self.in_queues, self.out_queues) = create_workers(self.devices) - if ( - self.style is PipelineStyle.MultiProcess - and self.transport_config.worker_map is None - and self.transport_config.use_rpc is True - ): - raise ValueError("'PipelineStyle.MultiProcess' requires 'worker_map' to be set") + @property + def checkpoint_stop(self) -> int: + # Disable checkpointing if in eval mode. + if self.style == PipelineStyle.SingleProcess: + training = self.partitions[0].training + else: + training = self.mp_partitions[0].module.training + if not training: + return 0 + return self.__checkpoint_stop def __del__(self) -> None: if self.style is PipelineStyle.SingleProcess: join_workers(self.in_queues, self.out_queues) - def run(self, batches: List[Batch]) -> None: + def run(self, training: bool, batches: List[Batch], event: Optional[Event]) -> None: """Runs pipeline parallelism. It modifies the given batches in place. 
""" - partitions = self.partitions - devices = self.devices + self.training = training m = len(batches) - n = len(partitions) skip_trackers = [SkipTrackerThroughPotals(self.skip_layout, i) for i in range(len(batches))] if self.style is PipelineStyle.SingleProcess: + n = len(self.partitions) for schedule in clock_cycles(m, n): self.fence(batches, schedule, skip_trackers) self.compute(batches, schedule, skip_trackers) @@ -385,6 +288,24 @@ def run(self, batches: List[Batch]) -> None: assert self.group schedule = [(i, self.group.rank()) for i in range(m)] self.compute(batches, schedule, skip_trackers) + elif self.style is PipelineStyle.AsyncSchedule: + assert self.group + rank = self.group.rank() + event_loop = AsyncEventLoop( + self.mp_partitions, self.group, self.transport, self.training, self.checkpoint_stop, + ) + if rank == 0 and not self.final_stage: + logging.debug(f"{torch.distributed.get_rank()}: entered event head") + event_loop.event_loop_head(batches, skip_trackers, event) + logging.debug(f"{torch.distributed.get_rank()}: exited event head") + elif self.final_stage: + logging.debug(f"{torch.distributed.get_rank()}: entered event tail") + event_loop.event_loop_tail(batches, skip_trackers) + logging.debug(f"{torch.distributed.get_rank()}: exited event tail") + else: + logging.debug(f"{torch.distributed.get_rank()}: entered event loop") + event_loop.event_loop(len(batches), skip_trackers) + logging.debug(f"{torch.distributed.get_rank()}: exited event loop") self.callcount += 1 @@ -421,7 +342,7 @@ def get_batch_from_previous_stage( ) -> Batch: phony = torch.empty(0, device=self.input_device, requires_grad=True) - result = RecvOperator.apply(torch.distributed.get_rank(), phony, self.input_device, self.transport_config, i) + result = RecvOperator.apply(torch.distributed.get_rank(), phony, self.input_device, self.transport, i) if len(result) == 1: batch = Batch(result[0], i) else: @@ -443,8 +364,7 @@ def send_skip_tensors( else: tensors = tuple() - send_message( - self.transport_config, + self.transport.send_message( PipeMessage( this_rank, ranks[next_j], queue_name=SKIP_TENSOR_QUEUE, args=(i, ns, name, life), tensors=tensors, ), @@ -454,9 +374,7 @@ def send_skip_tensors( def recv_skip_tensors(self, skip_trackers: List[SkipTrackerThroughPotals], batches: List[Batch]) -> None: while True: try: - message = recv_message( - self.transport_config, SKIP_TENSOR_QUEUE, nowait=True, input_device=self.input_device - ) + message = self.transport.recv_message(SKIP_TENSOR_QUEUE, nowait=True) (si, ns, name, life) = message.args value: Optional[TensorOrTensors] = message.tensors @@ -481,12 +399,12 @@ def execute_task(self, task: Task, i: int, skip_trackers: List[SkipTrackerThroug assert self.group rank = self.group.rank() - if rank != self.group.size() - 1: + if self.style is PipelineStyle.MultiProcess and not self.final_stage: ranks = get_pipeline_parallel_ranks() this_rank = torch.distributed.get_rank() self.send_skip_tensors(this_rank, ranks, batch, i, skip_trackers) - SendOperator.apply(this_rank, ranks[ranks.index(this_rank) + 1], self.transport_config, [*batch], i) + SendOperator.apply(this_rank, ranks[ranks.index(this_rank) + 1], self.transport, [*batch], i) for portal in skip_trackers[i].portals.values(): portal.pipeline = self @@ -534,72 +452,16 @@ def finalize_tasks( if exc_info is not None: raise exc_info[0].with_traceback(exc_info[1], exc_info[2]) - def create_task( - self, - i: int, - j: int, - batch: Batch, - checkpoint_stop: int, - partition: nn.Sequential, - skip_trackers: 
List[SkipTrackerThroughPotals], - streams: List[AbstractStream], - ) -> Task: - # Determine whether checkpointing or not. - if i < checkpoint_stop: - - def function( - input: TensorOrTensors, - partition: nn.Sequential = partition, - skip_tracker: SkipTrackerThroughPotals = skip_trackers[i], - chunk_id: int = i, - part_id: int = j, - ) -> TensorOrTensors: - with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)): - return partition(input) - - chk = Checkpointing(function, batch) - if self.style is PipelineStyle.SingleProcess: - task = Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute) - elif self.style is PipelineStyle.MultiProcess: - task = Task(None, compute=chk.checkpoint, finalize=chk.recompute) - del function, chk # TODO(tom) maybe remove - - else: - - def compute( - batch: Batch = batch, - partition: nn.Sequential = partition, - skip_tracker: SkipTrackerThroughPotals = skip_trackers[i], - chunk_id: int = i, - part_id: int = j, - ) -> Batch: - with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)): - return batch.call(partition) - - if self.style is PipelineStyle.SingleProcess: - task = Task(streams[j], compute=compute, finalize=None) - elif self.style is PipelineStyle.MultiProcess: - task = Task(None, compute=compute, finalize=None) - del compute # TODO(tom) maybe remove - - return task - def compute( - self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals], + self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals] ) -> None: """Runs tasks with synchronization to copy streams.""" - partitions = self.partitions devices = self.devices copy_streams = self.copy_streams - checkpoint_stop = self.checkpoint_stop - - # Disable checkpointing if in eval mode. - if not self.partitions[0].training: - checkpoint_stop = 0 if self.style is PipelineStyle.SingleProcess: assert devices is not None - n = len(partitions) + n = len(self.partitions) streams = [current_stream(d) for d in devices] elif self.style is PipelineStyle.MultiProcess: assert self.group @@ -635,25 +497,28 @@ def compute( batch = batches[i] if self.style is PipelineStyle.SingleProcess: - partition = partitions[j] + partition = self.partitions[j] # Synchronize with the copied input. ([1] in the diagram) assert copy_streams if j != 0: wait(batch, copy_streams[j][i], streams[j]) + + task = create_task(self.style, self.checkpoint_stop, i, j, batch, partition, skip_trackers, streams) + + # Compute tasks in parallel. ([2] in the diagram) + self.in_queues[j].put(task) elif self.style is PipelineStyle.MultiProcess: - assert len(self.partitions) == 1 - partition = self.partitions[0] + assert len(self.mp_partitions) == 1 + mp_partition = self.mp_partitions[0] assert self.group if self.group.rank() != 0: batch = self.get_batch_from_previous_stage(i, skip_trackers, batches) - task = self.create_task(i, j, batch, checkpoint_stop, partition, skip_trackers, streams) + task = create_task( + self.style, self.checkpoint_stop, i, j, batch, mp_partition.module, skip_trackers, streams + ) - if self.style is PipelineStyle.SingleProcess: - # Compute tasks in parallel. 
([2] in the diagram) - self.in_queues[j].put(task) - elif self.style is PipelineStyle.MultiProcess: batches[i] = self.execute_task(task, i, skip_trackers) if self.style is PipelineStyle.SingleProcess: @@ -671,14 +536,12 @@ def send_portal_grad(self, ns_name: Tuple[Namespace, str], index: int, grad: Ten if isinstance(grad, Tensor): grad = tuple([grad]) - send_message( - self.transport_config, - PipeMessage(ranks[src], dst_rank, queue_name=PORTAL_QUEUE, args=(ns_name, index), tensors=grad), - sync=True, + self.transport.send_message( + PipeMessage(ranks[src], dst_rank, queue_name=PORTAL_QUEUE, args=(ns_name, index), tensors=grad), sync=True, ) def recv_portal_grad(self, expected_ns_name: Tuple[Namespace, str], expected_index: int) -> Tensor: - message = recv_message(self.transport_config, PORTAL_QUEUE, input_device=self.input_device) + message = self.transport.recv_message(PORTAL_QUEUE) (ns_name, index) = message.args grad = message.tensors @@ -689,6 +552,9 @@ def recv_portal_grad(self, expected_ns_name: Tuple[Namespace, str], expected_ind return result def back_helper(self, output: List[Batch]) -> None: + if self.style == PipelineStyle.AsyncSchedule: + return + o = list(output) tensors: Tensors @@ -698,9 +564,7 @@ def back_helper(self, output: List[Batch]) -> None: grads = [] for i, batch in enumerate(o): rank = torch.distributed.get_rank() - found = get_out_of_order( - self.transport_config, ACTIVATIONS_GRADS_QUEUE, i, input_device=self.input_device - ) + found = self.transport.get_out_of_order(ACTIVATIONS_GRADS_QUEUE, i) assert len(found) == 1 grads.append(found[0]) tensors = tuple(x.tensor_or_tensors for x in o) # type: ignore @@ -711,9 +575,7 @@ def back_helper(self, output: List[Batch]) -> None: else: rank = torch.distributed.get_rank() for batch in o: - found = get_out_of_order( - self.transport_config, ACTIVATIONS_GRADS_QUEUE, batch.index, input_device=self.input_device - ) + found = self.transport.get_out_of_order(ACTIVATIONS_GRADS_QUEUE, batch.index) if batch.atomic: tensors = tuple([batch.tensor]) else: @@ -732,4 +594,4 @@ def back_helper(self, output: List[Batch]) -> None: try: torch.autograd.backward(final_tensors, grad_tensors=grads, retain_graph=True) except Exception as e: - raise RuntimeError("Autograd failed") from e + raise RuntimeError(f"Autograd failed on {torch.distributed.get_rank()}") from e diff --git a/fairscale/nn/pipe/rpc.py b/fairscale/nn/pipe/rpc.py new file mode 100644 index 000000000..dbea851ab --- /dev/null +++ b/fairscale/nn/pipe/rpc.py @@ -0,0 +1,284 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +from threading import Event, Lock, Thread +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast + +import torch +from torch import nn +from torch.distributed import ProcessGroup, rpc +from torch.distributed.distributed_c10d import _get_global_rank + +from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group + +from . 
import Pipe
+from .types import EVENT_LOOP_QUEUE, PipeMessage, TensorOrTensors
+
+DEFAULT_MAX_SOURCE_POSITIONS = 1024
+DEFAULT_MAX_TARGET_POSITIONS = 1024
+
+PipeModel: Pipe
+PipeResult: TensorOrTensors
+
+
+SizeOrSizes = Union[torch.Size, List[torch.Size]]
+DtypeOrDtypes = Union[torch.dtype, List[torch.dtype]]
+
+
+def set_device_based_on_group(group: ProcessGroup) -> None:
+    # torch.cuda.set_device(group.rank() % torch.cuda.device_count())
+    torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count())
+
+
+def get_shapes(tensor: TensorOrTensors) -> SizeOrSizes:
+    if isinstance(tensor, torch.Tensor):
+        return tensor.shape
+    else:
+        return [t.shape for t in tensor]
+
+
+def get_dtype(tensor: TensorOrTensors) -> DtypeOrDtypes:
+    if isinstance(tensor, torch.Tensor):
+        return tensor.dtype
+    else:
+        return [t.dtype for t in tensor]
+
+
+def get_global_ranks_from_group(group: ProcessGroup) -> List[int]:
+    return [_get_global_rank(group, r) for r in range(group.size())]
+
+
+class PipeBackRedirect(torch.autograd.Function):
+    @staticmethod
+    # type: ignore
+    def forward(ctx, inputs, dest, event, message, transport, futures):
+        ctx.dest = dest
+        ctx.event = event
+        ctx.message = message
+        ctx.transport = transport
+        ctx.futures = futures
+        return inputs
+
+    @staticmethod
+    # type: ignore
+    def backward(ctx, *grad):
+        ctx.message.tensors = tuple(grad)
+        ctx.transport.send_message(ctx.message, sync=False, skip_header=True)
+        ctx.event.set()
+        # torch.futures.wait_all(ctx.futures)
+        return (None, None, None, None, None, None)
+
+
+def callback_with_model(callback: Callable[[Any, Pipe], None], ctx: Any) -> None:
+    try:
+        group = get_pipeline_parallel_group()  # FIXME(tom) handle dynamic group
+        set_device_based_on_group(group)
+
+        with PipeModel.lock:
+            callback(ctx, PipeModel)
+    except Exception as e:
+        print(f"callback_with_model got {e}")
+
+
+class PipeRPCWrapper(nn.Module):
+    """A wrapper for Pipe to control the entire pipeline from a single process.
+    The typical use case is for rank 0 to construct `PipeRPCWrapper` and run the
+    training loop as normal, while all other ranks call
+    `torch.distributed.rpc.shutdown()`.
+
+    To run code on each worker, e.g. to run the optimizer, use `foreach_worker`.
+    """
+
+    def __init__(self, *args: Any, **kwargs: Any):
+        super().__init__()
+        self.group = cast(ProcessGroup, kwargs.get("group")) or get_pipeline_parallel_group()
+        assert self.group.rank() == 0
+        self.lock = Lock()
+
+        if True:
+            assert (
+                self.group == get_pipeline_parallel_group()
+            ), "Can't pickle groups, so group must be `get_pipeline_parallel_group()`"
+            kwargs["group"] = None
+        else:
+            kwargs["group"] = self.group
+
+        kwargs["style"] = Pipe.AsyncSchedule
+        kwargs["input_device"] = torch.device("cuda", torch.cuda.current_device())
+
+        self.model = Pipe(*args, **kwargs)
+        self.worker_map = kwargs["worker_map"]
+        self._foreach_worker(self._register_remote_model, args=(args, kwargs))
+        self.model.cuda()
+
+    def _get_rpc_name(self, rank: int) -> str:
+        return self.worker_map[_get_global_rank(self.group, rank)]
+
+    def _foreach_worker(self, callback: Callable, args: Any = None) -> None:
+        futures = [rpc.rpc_async(self._get_rpc_name(rank), callback, args=args) for rank in range(1, self.group.size())]
+        futures = [f.wait() for f in futures]
+
+    def foreach_worker(
+        self, callback: Callable[[Any, Pipe], None], ctx: Any = None, *, include_self: bool = False
+    ) -> None:
+        """Call `callback` on each worker with the `ctx` and model local to that
+        worker. e.g.
+ def register_optimizer(ctx, model): + args, kwargs = ctx + model.optimizer = torch.optim.SGD(model.parameters(), *args, **kwargs) + + pipe_model = PipeRPCWrapper( ... ) + + pipe_model.foreach_worker( + register_optimizer, + ([], {"lr" : 0.01, "momentum" : 0.9}) + ) + """ + + self._foreach_worker(callback_with_model, args=(callback, ctx)) + + if include_self: + with self.model.lock: + callback(ctx, self.model) + + def forward(self, tensor: TensorOrTensors) -> TensorOrTensors: # type: ignore + shape = get_shapes(tensor) + dtype = get_dtype(tensor) + + if isinstance(tensor, torch.Tensor): + num_tensors = 1 + else: + num_tensors = len(tensor) + + futures = [ + rpc.rpc_async(self._get_rpc_name(rank), self._model_forward, args=(self.model.training, shape, dtype)) + for rank in range(1, self.group.size()) + ] + + if self.model.final_stage: + return self.model(tensor) + else: + event = Event() + t = Thread(target=self._model_forward_first_stage, args=(tensor, event)) + t.start() + + shape, dtype = futures.pop().wait() + dest_rank = self.group.size() - 1 + dest = self._get_rpc_name(dest_rank) + dest_global_rank = _get_global_rank(self.group, dest_rank) + src_global_rank = torch.distributed.get_rank() + queue = EVENT_LOOP_QUEUE + + activations = PipeMessage(dest_global_rank, src_global_rank, queue_name=queue, tensor_count=num_tensors) + grads = PipeMessage(src_global_rank, dest_global_rank, queue_name=queue, tensor_count=num_tensors) + + back_fut = rpc.rpc_async( + dest, self._send_result_and_do_backwards, args=(self.model.training, activations, grads) + ) + futures.append(back_fut) + + result = self._recv_result(self.model, shape, dtype, activations) + if isinstance(result, torch.Tensor): + result.requires_grad_() + else: + for r in result: + r.requires_grad_() + + assert self.model.pipeline + return PipeBackRedirect.apply( + result, dest_global_rank, event, grads, self.model.pipeline.transport, futures + ) + + @property + def final_stage(self) -> bool: + return self.model.final_stage + + @staticmethod + def _recv_result(model: Pipe, shapes: SizeOrSizes, dtypes: DtypeOrDtypes, message: PipeMessage) -> TensorOrTensors: + group = get_pipeline_parallel_group() + set_device_based_on_group(group) + + assert model.pipeline + transport = model.pipeline.transport + + if isinstance(shapes, torch.Size): + message.tensor_shapes = [cast(torch.Size, shapes)] + message.tensor_dtypes = [cast(torch.dtype, dtypes)] + message = transport.recv_message_tensors(message) + return message.tensors[0] + else: + message.tensor_shapes = cast(List[torch.Size], shapes) + message.tensor_dtypes = cast(List[torch.dtype], dtypes) + message = transport.recv_message_tensors(message) + return message.tensors + + @staticmethod + def _send_result_and_do_backwards(training: bool, message: PipeMessage, grads_message: PipeMessage) -> None: + group = get_pipeline_parallel_group() + set_device_based_on_group(group) + result = PipeResult + model = PipeModel + + if isinstance(result, torch.Tensor): + result = tuple([result]) + + message.tensors = tuple(result) + assert model.pipeline + transport = model.pipeline.transport + transport.send_message(message, sync=False, skip_header=True) + + if training: + grads_message.tensor_shapes = [r.shape for r in result] + grads_message.tensor_dtypes = [r.dtype for r in result] + grads_message = transport.recv_message_tensors(grads_message) + + with model.lock: + torch.autograd.backward(result, grads_message.tensors, retain_graph=True) + + @staticmethod + def _register_remote_model(args: List[Any], 
kwargs: Dict[str, Any]) -> None: + group = get_pipeline_parallel_group() # FIXME(tom) handle dynamic group + set_device_based_on_group(group) + kwargs["group"] = group + kwargs["input_device"] = torch.device("cuda", torch.cuda.current_device()) + model = Pipe(*args, **kwargs) + model.cuda() + global PipeModel + PipeModel = model + + @staticmethod + def _model_forward( + training: bool, shape: torch.Size, dtype: torch.dtype + ) -> Optional[Tuple[SizeOrSizes, DtypeOrDtypes]]: + try: + if isinstance(shape, torch.Size): + tensor = torch.empty(shape, dtype=dtype) + else: + tensor = tuple([torch.empty(s, dtype=d) for s, d in zip(shape, dtype)]) + + model = PipeModel + assert model.group + set_device_based_on_group(model.group) + + model.train(training) + result = model(tensor) + if model.final_stage: + global PipeResult + PipeResult = result + return (get_shapes(result), get_dtype(result)) + + return None + except Exception as e: + print(f"_model_forward got {e}") + raise e + + def _model_forward_first_stage(self, tensor: TensorOrTensors, event: Event) -> None: + try: + assert self.model.group + set_device_based_on_group(self.model.group) + self.model(tensor, event=event) + except Exception as e: + print(f"_model_forward got {e}") + raise e diff --git a/fairscale/nn/pipe/types.py b/fairscale/nn/pipe/types.py new file mode 100644 index 000000000..eec479748 --- /dev/null +++ b/fairscale/nn/pipe/types.py @@ -0,0 +1,75 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +from enum import Enum, auto +from typing import Any, Callable, List, Optional, Tuple, Union + +from dataclasses import dataclass +import torch +from torch import Tensor, nn + +ACTIVATIONS_GRADS_QUEUE = 0 +SKIP_TENSOR_QUEUE = 1 +PORTAL_QUEUE = 2 +EVENT_LOOP_QUEUE = 3 +MESSAGE_GENERATION_START = 4 + +MessageGeneration = MESSAGE_GENERATION_START + +Tensors = Tuple[Tensor, ...] 
+TensorOrTensors = Union[Tensor, Tensors] + +InputDevice = Union[None, int, str, torch.device] +Schedule = List[Tuple[int, int]] + + +class LazyModule: + def __init__(self, function: Callable[[], nn.Module]): + self.function = function + + def __call__(self) -> nn.Module: + return self.function() + + +class PipelineStyle(Enum): + SingleProcess = auto() + MultiProcess = auto() + AsyncSchedule = auto() + + +@dataclass(init=False) +class PipeMessage: + src: int + dest: int + queue_name: int + args: Any + tensors: Tensors + tensor_shapes: List[torch.Size] + tensor_dtypes: List[torch.dtype] + tag: int = 0 + + def __init__( + self, + src: int, + dest: int, + queue_name: int, + args: Any = None, + tensors: Optional[Tensors] = None, + tensor_count: int = 0, + ): + self.src = src + self.dest = dest + self.queue_name = queue_name + self.args = args + self.tensors = tensors or tuple() + self.tensor_shapes = [] + self.tensor_dtypes = [] + + global MessageGeneration + self.tag = MessageGeneration + if tensors is None: + MessageGeneration += tensor_count + else: + MessageGeneration += len(self.tensors) diff --git a/fairscale/optim/oss.py b/fairscale/optim/oss.py index d1bdff97a..57d94d740 100644 --- a/fairscale/optim/oss.py +++ b/fairscale/optim/oss.py @@ -422,7 +422,7 @@ def get_global_rank(group: Any, rank: int) -> int: if group is dist.group.WORLD: return rank else: - global_rank = dist.distributed_c10d._get_global_rank(group, rank) # type: ignore + global_rank = dist.distributed_c10d._get_global_rank(group, rank) return global_rank def _broadcast_params(self, buffers: List[torch.Tensor], per_rank_params: List[List[Parameter]]) -> None: diff --git a/fairscale/utils/object.py b/fairscale/utils/object.py new file mode 100644 index 000000000..fbde70e1c --- /dev/null +++ b/fairscale/utils/object.py @@ -0,0 +1,29 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import pickle +from typing import Any + +import torch + + +def pyobject_to_tensor(obj: Any, fixed_buffer_size: int = 0) -> torch.Tensor: + pickled = pickle.dumps(obj) + result: torch.Tensor = torch.ByteTensor(bytearray(pickled)) + if fixed_buffer_size: + delta = fixed_buffer_size - len(result) + if delta < 0: + raise ValueError( + f"message too big to send, increase `fixed_buffer_size`? - {len(result)} > {fixed_buffer_size}" + ) + elif delta > 0: + result = torch.cat((result, torch.zeros(delta, dtype=torch.uint8))) + + return result + + +def tensor_to_pyobject(tensor: torch.Tensor) -> Any: + nparray = tensor.cpu().numpy() + return pickle.loads(nparray.tobytes()) diff --git a/pyproject.toml b/pyproject.toml index 8e9afb654..8010c95e1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,4 +28,4 @@ use_parentheses = true skip_glob = ["build/*", "stubs/*"] # Don't split "import" and "from". 
force_sort_within_sections = true -known_third_party = ["benchmark_dataset", "dataclasses", "numpy", "packaging", "pytest", "recommonmark", "setuptools", "torch", "torchtext", "torchvision"] +known_third_party = ["benchmark_dataset", "dataclasses", "numpy", "packaging", "pytest", "recommonmark", "setuptools", "torch", "torch_pg", "torchtext", "torchvision"] diff --git a/run_mpi_tests.sh b/run_mpi_tests.sh index d3672996c..a5abdbaed 100755 --- a/run_mpi_tests.sh +++ b/run_mpi_tests.sh @@ -1,6 +1,10 @@ #!/bin/bash - set -e -for WORKERS in {1..5}; do - mpirun -n $WORKERS python -m pytest tests/nn/pipe_process +rpc_tests=$(pytest --collect-only | grep 'Function.*rpc' | cut -d' ' -f 6 | tr -d '>') + +for WORKERS in {1..6}; do + mpirun -n $WORKERS -mca orte_base_help_aggregate 0 python -m pytest tests/nn/pipe_process -k "not rpc" + for test_name in $rpc_tests; do + mpirun -n $WORKERS -mca orte_base_help_aggregate 0 python -m pytest tests/nn/pipe_process -k $test_name + done done diff --git a/stubs/torch/__init__.pyi b/stubs/torch/__init__.pyi index 904780526..b999dccbf 100644 --- a/stubs/torch/__init__.pyi +++ b/stubs/torch/__init__.pyi @@ -35,7 +35,7 @@ from . import version #END class dtype: - is_floating_point: bool + is_floating_point: builtins.bool class layout: ... @@ -277,7 +277,7 @@ class Tensor: def atan2(self, other: Tensor) -> Tensor: ... def atan2_(self, other: Tensor) -> Tensor: ... def atan_(self) -> Tensor: ... - def backward(self, gradient: Optional[Tensor]=None, keep_graph: _bool=False, create_graph: _bool=False) -> None: ... + def backward(self, gradient: Optional[Tensor]=None, retain_graph: _bool=False, create_graph: _bool=False) -> None: ... def baddbmm(self, batch1: Tensor, batch2: Tensor, *, beta: Number=1, alpha: Number=1) -> Tensor: ... def baddbmm_(self, batch1: Tensor, batch2: Tensor, *, beta: Number=1, alpha: Number=1) -> Tensor: ... @overload diff --git a/stubs/torch/cuda/__init__.pyi b/stubs/torch/cuda/__init__.pyi index 940eeced4..ed437d685 100644 --- a/stubs/torch/cuda/__init__.pyi +++ b/stubs/torch/cuda/__init__.pyi @@ -29,7 +29,7 @@ _device_t = Union[_device, int, str] def check_error(res: int) -> None: ... def device_count() -> int: ... def empty_cache() -> None: ... -def synchronize(device: _device_t) -> None: ... +def synchronize(device: Optional[_device_t]=None) -> None: ... def set_device(device: _device_t) -> None: ... def get_device_capability(device: Optional[_device_t]=...) -> Tuple[int, int]: ... def get_device_name(device: Optional[_device_t]=...) -> str: ... diff --git a/stubs/torch/distributed/__init__.pyi b/stubs/torch/distributed/__init__.pyi index 160ee999c..71a23ea01 100644 --- a/stubs/torch/distributed/__init__.pyi +++ b/stubs/torch/distributed/__init__.pyi @@ -5,6 +5,7 @@ from torch import Tensor import datetime from . import rpc as rpc +from . import distributed_c10d as distributed_c10d class Backend: GLOO: str diff --git a/stubs/torch/distributed/distributed_c10d.pyi b/stubs/torch/distributed/distributed_c10d.pyi new file mode 100644 index 000000000..b8543cbb8 --- /dev/null +++ b/stubs/torch/distributed/distributed_c10d.pyi @@ -0,0 +1,7 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +from typing import Any, List, Union, Optional + +from . import ProcessGroup + +def _get_global_rank(group: ProcessGroup, rank: int) -> int: ... 
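For context on the `_get_global_rank` stub just above: elsewhere in this patch (`fairscale/optim/oss.py` and `fairscale/nn/pipe/rpc.py`) it is used to translate a group-local rank into the global rank that keys `worker_map` when addressing RPC peers. A minimal sketch of that lookup, not part of the patch itself; the function name and `worker_map` contents are hypothetical:

from typing import Dict

from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import _get_global_rank


def rpc_name_for(group: ProcessGroup, worker_map: Dict[int, str], local_rank: int) -> str:
    # worker_map is keyed by *global* rank (e.g. {0: "worker0", 1: "worker1"}, names hypothetical),
    # so the group-local rank is translated first, mirroring PipeRPCWrapper._get_rpc_name.
    return worker_map[_get_global_rank(group, local_rank)]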
diff --git a/stubs/torch/distributed/rpc/__init__.pyi b/stubs/torch/distributed/rpc/__init__.pyi
index 5267fedd2..d278c5520 100644
--- a/stubs/torch/distributed/rpc/__init__.pyi
+++ b/stubs/torch/distributed/rpc/__init__.pyi
@@ -1,6 +1,7 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 
 from typing import Union, Callable, Optional
+from torch.futures import Future
 
 class RRef:
 
@@ -17,7 +18,7 @@ def rpc_async(
     args: Optional[tuple] = None,
     kwargs: Optional[dict] = None,
     timeout=-1.0,
-) -> None:
+) -> Future:
     ...
 
diff --git a/stubs/torch/futures.pyi b/stubs/torch/futures.pyi
new file mode 100644
index 000000000..86b47606c
--- /dev/null
+++ b/stubs/torch/futures.pyi
@@ -0,0 +1,6 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+
+from typing import Any
+
+class Future:
+    def wait(self) -> Any: ...
diff --git a/stubs/torch/nn/__init__.pyi b/stubs/torch/nn/__init__.pyi
index e186119ef..9dc050501 100644
--- a/stubs/torch/nn/__init__.pyi
+++ b/stubs/torch/nn/__init__.pyi
@@ -1,6 +1,6 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 
-from .modules import * 
+from .modules import *
 from .parameter import Parameter as Parameter
 from .parallel import DataParallel as DataParallel
 from . import functional as functional
diff --git a/tests/nn/model_parallel/commons.py b/tests/nn/model_parallel/commons.py
index 23c66de5d..26da89753 100644
--- a/tests/nn/model_parallel/commons.py
+++ b/tests/nn/model_parallel/commons.py
@@ -21,6 +21,7 @@
 
 import functools
 import inspect
+import multiprocessing
 import os
 import random
 
@@ -100,17 +101,32 @@ def spawn_for_all_world_sizes(test_func, world_sizes=get_world_sizes(), args=[])
         mp.spawn(test_func, args=(world_size, *args), nprocs=world_size, join=True)
 
 
-def helper(rank, world_size, func, args):
+def worker_process(rank, world_size, func, args, error_queue):
+    """Main function for unit tests launched with torch_spawn"""
+
     dist_init(rank, world_size)
-    initialize_model_parallel(1, world_size)
-    func(*args)
+    kwargs = {}
+    if "OMPI_COMM_WORLD_RANK" not in os.environ:
+        kwargs["pipeline_backend"] = "gloo"
+    initialize_model_parallel(1, world_size, **kwargs)
+    try:
+        func(*args)
+    except BaseException as e:
+        # If the function raises 'Skipped', this indicates pytest.skip(), so
+        # forward it to parent so we can call pytest.skip() there
+        if e.__class__.__name__ == "Skipped":
+            error_queue.put(str(e))
+            return
+        raise e
 
 
 def torch_spawn(world_sizes=None):
     if world_sizes is None:
         world_sizes = get_world_sizes()
 
-    def fixer(func):
+    def prepare_test(func):
+        """Function called with the test function as the argument.
Generates a + replacement which serves as the actual test function.""" name = func.__name__ parameters = inspect.signature(func).parameters @@ -128,21 +144,39 @@ def replacement(*args, **kwargs): kwargs[p] for p in parameters if p != "rank" ) # converting named parameters to positional parameters to pass to `spawn` + error_queue = multiprocessing.get_context("spawn").SimpleQueue() if "OMPI_COMM_WORLD_RANK" in os.environ: + os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"] + os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"] + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "10638" torch.distributed.init_process_group("mpi") world_size = torch.distributed.get_world_size() initialize_model_parallel(1, world_size) torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count()) if world_size in world_sizes: - func(*args) + try: + func(*args) + except BaseException as e: + print(f"got exception {e} from test") + import traceback + + print(f"{traceback.format_exc()}") + raise e else: pytest.skip(f"requested world size doesn't match current world size") else: - spawn_for_all_world_sizes(helper, world_sizes, (func, args)) + spawn_for_all_world_sizes(worker_process, world_sizes, (func, args, error_queue)) + + if not error_queue.empty(): + msg = error_queue.get() + pytest.skip(msg) + # Register a function with the same name, prefixed with "test_" in the + # calling module, so it will be picked up by pytest caller_module = inspect.getmodule(inspect.currentframe().f_back) setattr(caller_module, f"test_{name}", replacement) return func - return fixer + return prepare_test diff --git a/tests/nn/model_parallel/test_initialize.py b/tests/nn/model_parallel/test_initialize.py index f1df1b598..bc7477771 100644 --- a/tests/nn/model_parallel/test_initialize.py +++ b/tests/nn/model_parallel/test_initialize.py @@ -110,7 +110,7 @@ def is_initialized(self): def get_world_size(self): return data_parallel_size * pipeline_length * model_parallel_size - def new_group(self, args): + def new_group(self, args, backend=None): new_groups.append(args.copy()) return () diff --git a/tests/nn/model_parallel/test_layers.py b/tests/nn/model_parallel/test_layers.py index 2c97c5a7c..835d731ee 100644 --- a/tests/nn/model_parallel/test_layers.py +++ b/tests/nn/model_parallel/test_layers.py @@ -436,6 +436,7 @@ def forward_model(model_, target, step=False): model[2].weight.data = saved_weight_2 worker_map = {i: f"Test{i}" for i in range(torch.distributed.get_world_size())} + style = Pipe.MultiProcess # Pipe.AsyncSchedule if pipe_world_size == 2: print(f"actually doing pipe stuff now") @@ -444,7 +445,7 @@ def forward_model(model_, target, step=False): pipe_model = Pipe( model, [2, 1], - style=Pipe.MultiProcess, + style=style, group=pipeline_devices, worker_map=worker_map, input_device=torch.cuda.current_device(), @@ -511,7 +512,8 @@ def forward_model(model_, target, step=False): failed = False with torch.autograd.profiler.profile() as prof: try: - pipe_model.back_helper(pipe_output) + if style == Pipe.MultiProcess: + pipe_model.back_helper(pipe_output) except Exception as e: failed = True print(f"got {e} while doing backward, deadlock?") @@ -527,6 +529,7 @@ def forward_model(model_, target, step=False): pipe_model.zero_grad() torch.distributed.barrier() + pipe_model.eval() pipe_output = pipe_model(identity()) updated_ref_output = forward_model(reference, target) if torch.distributed.get_rank(mpu.get_pipeline_parallel_group()) == 1: diff --git a/tests/nn/moe/test_moe_layer.py 
b/tests/nn/moe/test_moe_layer.py index bc569eebb..4dc95b4cd 100644 --- a/tests/nn/moe/test_moe_layer.py +++ b/tests/nn/moe/test_moe_layer.py @@ -23,17 +23,18 @@ os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = "29501" if "OMPI_COMM_WORLD_SIZE" in os.environ: - dist.init_process_group(backend=dist.Backend.MPI) + pass # dist.init_process_group(backend=dist.Backend.MPI) def setup_module(module): if "OMPI_COMM_WORLD_SIZE" not in os.environ: dist.init_process_group(backend=BACKEND, rank=0, world_size=1) + else: + dist.init_process_group(backend=dist.Backend.MPI) def teardown_module(module): - if "OMPI_COMM_WORLD_SIZE" not in os.environ: - torch.distributed.destroy_process_group() + torch.distributed.destroy_process_group() @pytest.mark.parametrize("device", devices) diff --git a/tests/nn/pipe_process/conftest.py b/tests/nn/pipe_process/conftest.py index ef91fbaf0..82af2e3a0 100644 --- a/tests/nn/pipe_process/conftest.py +++ b/tests/nn/pipe_process/conftest.py @@ -65,3 +65,7 @@ def pytest_runtest_teardown(item): destroy_model_parallel() if torch.distributed.is_initialized(): torch.distributed.destroy_process_group() + try: + torch.distributed.rpc.shutdown() + except Exception: + pass diff --git a/tests/nn/pipe_process/skip/test_gpipe.py b/tests/nn/pipe_process/skip/test_gpipe.py index 4a4510749..dc10c0a19 100644 --- a/tests/nn/pipe_process/skip/test_gpipe.py +++ b/tests/nn/pipe_process/skip/test_gpipe.py @@ -23,7 +23,7 @@ import torch from torch import nn -from fairscale.nn.pipe import Pipe +from fairscale.nn.pipe import LazyModule, Pipe from fairscale.nn.pipe.skip import pop, skippable, stash from fairscale.nn.pipe.skip.portal import PortalBlue, PortalCopy, PortalOrange from tests.nn.model_parallel.commons import get_worker_map, torch_spawn @@ -33,10 +33,15 @@ @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") @pytest.mark.parametrize("balance", [[3], [1, 2], [2, 1], [1, 1, 1]], ids=["3", "1:2", "2:1", "1:1:1"]) @pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) @pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi") -def x1to3(balance, checkpoint): +def x1to3(balance, checkpoint, pipeline_style): torch.manual_seed(0) + if pipeline_style == Pipe.AsyncSchedule and len(balance) > 1: + print(f"skipping yarg") + pytest.skip("Skip tensors NYI for AsyncSchedule") + @skippable(stash=["1to3"]) class Layer1(nn.Module): def __init__(self): @@ -75,7 +80,7 @@ def forward(self, input): chunks=3, checkpoint=checkpoint, input_device=torch.cuda.current_device(), - style=Pipe.MultiProcess, + style=pipeline_style, worker_map=get_worker_map(), pipelined_backward=False, ).cuda() @@ -101,7 +106,11 @@ def forward(self, input): @torch_spawn([2]) @pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi") @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") -def none_skip(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def none_skip(pipeline_style): + if pipeline_style == Pipe.AsyncSchedule: + pytest.skip("Skip tensors NYI for AsyncSchedule") + @skippable(stash=["none"]) class Stash(nn.Module): def forward(self, input): @@ -119,7 +128,7 @@ def forward(self, input): model = Pipe( model, [1, 1], - style=Pipe.MultiProcess, + style=pipeline_style, worker_map=get_worker_map(), input_device=torch.cuda.current_device(), chunks=5, @@ -151,7 +160,8 @@ def 
assert_grad_fn_is_not_portal(grad_fn, visited=set()): @torch_spawn([2]) -def lazy_skippable_error(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def lazy_skippable_error(pipeline_style): """Using skippable layers in combination with lazy construction is currently not supported, check that it raises an Exception""" @@ -163,9 +173,13 @@ class Layer1(nn.Linear): class Layer3(nn.Linear): pass - model = [lambda: Layer1(10, 10), lambda: nn.Linear(10, 10), lambda: Layer3(10, 10)] + model = [ + LazyModule(lambda: Layer1(10, 10)), + LazyModule(lambda: nn.Linear(10, 10)), + LazyModule(lambda: Layer3(10, 10)), + ] with pytest.raises(ValueError, match="Can't use Skippable layers with multi-process pipe and lazy construction"): Pipe( - model, [2, 1], style=Pipe.MultiProcess, worker_map=get_worker_map(), + model, [2, 1], style=pipeline_style, worker_map=get_worker_map(), ) diff --git a/tests/nn/pipe_process/skip/test_leak.py b/tests/nn/pipe_process/skip/test_leak.py index 78501b531..67fccb009 100644 --- a/tests/nn/pipe_process/skip/test_leak.py +++ b/tests/nn/pipe_process/skip/test_leak.py @@ -46,9 +46,10 @@ def forward(self, input): @torch_spawn([2]) @pytest.mark.parametrize("train", [True, False], ids=["train", "eval"]) @pytest.mark.parametrize("checkpoint", ["always", "except_last", "never"]) +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) @pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi") @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") -def delete_portal_tensor(train, checkpoint): +def delete_portal_tensor(train, checkpoint, pipeline_style): # Without checkpointing: # +- Stash --+ +--- Pop ----+ - - - layers # | 2,blue,1 |--| 1,orange,0 | - - - tensor_life and portal function @@ -59,6 +60,9 @@ def delete_portal_tensor(train, checkpoint): # | 3,blue,2 |--| 2,orange,1 |--| 1,orange,0 |--| 1,blue,0 | # +----------+ +------------+ +------------+ +----------+ + if pipeline_style == Pipe.AsyncSchedule: + pytest.skip("Skip tensors NYI for AsyncSchedule") + def portal_tensor_life_is(tensor_life, skip_tracker=None): if skip_tracker is None: skip_tracker = current_skip_tracker() @@ -111,7 +115,7 @@ def forward(self, input): model = nn.Sequential(NoPortalTensorAtBackward(), stash_, pop_) model = Pipe( - model, balance=[2, 1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=2, checkpoint=checkpoint, + model, balance=[2, 1], style=pipeline_style, worker_map=get_worker_map(), chunks=2, checkpoint=checkpoint, ) input = torch.rand(10, requires_grad=True) diff --git a/tests/nn/pipe_process/test_bugs.py b/tests/nn/pipe_process/test_bugs.py index a79e4bd6a..b192704dc 100644 --- a/tests/nn/pipe_process/test_bugs.py +++ b/tests/nn/pipe_process/test_bugs.py @@ -28,7 +28,9 @@ @torch_spawn([2]) @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") -def python_autograd_function(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def python_autograd_function(pipeline_style): + # FIXME deadlock with Pipe.AsyncSchedule? 
# A Python autograd function might fail with this error: # # RuntimeError: Returning Variables sharing storage with other Variables @@ -55,7 +57,8 @@ def forward(self, input): return Identity.apply(input) model = nn.Sequential(M(), M()) - model = Pipe(model, [1, 1], style=Pipe.MultiProcess, worker_map=get_worker_map(), checkpoint="always").cuda() + model = Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always").cuda() + model.eval() x = torch.rand(42) y = model(x) @@ -67,7 +70,8 @@ def forward(self, input): @torch_spawn([3]) @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") -def exception_no_hang(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def exception_no_hang(pipeline_style): # In v0.0.2, once a failed partition receives a normal message # (non-closing) for the next micro-batch, a hang occured. The reason was # that a failed partition didn't call in_queue.task_done() on a normal @@ -85,7 +89,8 @@ def forward(self, x): raise ExpectedException() model = nn.Sequential(Pass(), Pass(), Raise()) - model = Pipe(model, [1, 1, 1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=3) + model = Pipe(model, [1, 1, 1], style=pipeline_style, worker_map=get_worker_map(), chunks=3) + model.eval() if model.group.rank() == 2: with pytest.raises(ExpectedException): @@ -98,7 +103,8 @@ def forward(self, x): @torch_spawn([2]) @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="2 cuda devices required") -def tuple_wait(cuda_sleep): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def tuple_wait(cuda_sleep, pipeline_style): # In v0.0.3, Wait is applied to only the first tensor on a micro-batch. # Under this behavior, if checkpointing was disabled, there's a possibility # that gradient accumulations on other tensors are not synchronized @@ -129,7 +135,7 @@ def forward(self, triple): model = Pipe( model, [1, 1], - style=Pipe.MultiProcess, + style=pipeline_style, worker_map=get_worker_map(), input_device=torch.cuda.current_device(), chunks=32, @@ -151,7 +157,8 @@ def forward(self, triple): @torch_spawn([2]) @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") -def parallel_randoms(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def parallel_randoms(pipeline_style): class Dropouts(nn.Module): def forward(self, x): for _ in range(100): @@ -165,7 +172,7 @@ def forward(self, x): model = Pipe( model, [1, 1], - style=Pipe.MultiProcess, + style=pipeline_style, input_device=torch.cuda.current_device(), worker_map=get_worker_map(), chunks=10, diff --git a/tests/nn/pipe_process/test_inplace.py b/tests/nn/pipe_process/test_inplace.py index 881c6cefb..7cf24558e 100644 --- a/tests/nn/pipe_process/test_inplace.py +++ b/tests/nn/pipe_process/test_inplace.py @@ -27,11 +27,17 @@ @torch_spawn([2]) @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") -def inplace_on_requires_grad(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def inplace_on_requires_grad(pipeline_style): model = nn.Sequential(nn.Linear(1, 1), nn.ReLU(inplace=True)) - model = Pipe(model, [1, 1], style=Pipe.MultiProcess, worker_map=get_worker_map(), checkpoint="always") + model = Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always") x = torch.rand(1) + + if pipeline_style == Pipe.AsyncSchedule and model.group.rank() == 0: + # With AsyncSchedule, 
model will wait forever for gradients if not eval + model.eval() + y = model(x) message = r"a leaf Variable that requires grad .* used in an in-place operation." @@ -44,11 +50,12 @@ def inplace_on_requires_grad(): @torch_spawn([1]) @pytest.mark.xfail(strict=True) -def inplace_on_not_requires_grad(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def inplace_on_not_requires_grad(pipeline_style): # In-place operation on a tensor not requiring grad doesn't cause a # RuntimeError. Currently, we cannot detect this case. model = nn.Sequential(nn.ReLU(inplace=True)) - model = Pipe(model, [1], style=Pipe.MultiProcess, worker_map=get_worker_map(), checkpoint="always") + model = Pipe(model, [1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always") x = torch.rand(1) y = model(x) @@ -63,7 +70,8 @@ def inplace_on_not_requires_grad(): @torch_spawn([1]) @pytest.mark.xfail(strict=True) -def inplace_incorrect_grad(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def inplace_incorrect_grad(pipeline_style): class M(nn.Module): def forward(self, foo_bar): # 'foo' requires grad but 'bar' does not. In-place operation on @@ -80,7 +88,7 @@ def forward(self, foo_bar): return foo * bar model = nn.Sequential(M()) - model = Pipe(model, [1], style=Pipe.MultiProcess, worker_map=get_worker_map(), checkpoint="always") + model = Pipe(model, [1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always") foo = torch.tensor([1.0], requires_grad=True) bar = torch.tensor([1.0]) diff --git a/tests/nn/pipe_process/test_pipe.py b/tests/nn/pipe_process/test_pipe.py index 59ce5d534..e42da9fef 100644 --- a/tests/nn/pipe_process/test_pipe.py +++ b/tests/nn/pipe_process/test_pipe.py @@ -21,21 +21,27 @@ from copy import deepcopy import os import time +from typing import Tuple from packaging import version import pytest import torch from torch import nn -from fairscale.nn.model_parallel.initialize import destroy_model_parallel, initialize_model_parallel -from fairscale.nn.pipe import Pipe -from tests.nn.model_parallel.commons import get_worker_map, torch_spawn +from fairscale.nn.model_parallel.initialize import ( + destroy_model_parallel, + get_pipeline_parallel_group, + initialize_model_parallel, +) +from fairscale.nn.pipe import LazyModule, Pipe +from tests.nn.model_parallel.commons import get_worker_map, set_random_seed, torch_spawn @torch_spawn([2]) -def parameters(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def parameters(pipeline_style): model = nn.Sequential(nn.Linear(1, 1)) - pipe = Pipe(model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=1) + pipe = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1) if torch.distributed.get_rank() == 0: assert list(pipe.parameters()) != [] else: @@ -62,10 +68,10 @@ def infiniband(): def infiniband2(): if torch.distributed.get_rank() == 0: t = torch.Tensor(range(100)).cuda() - torch.distributed.send(t, 1) + torch.distributed.send(t, 1, group=get_pipeline_parallel_group()) else: t = torch.empty(100).cuda() - torch.distributed.recv(t, 0) + torch.distributed.recv(t, 0, group=get_pipeline_parallel_group()) assert torch.equal(t, torch.Tensor(range(100)).cuda()) print(f"t on {torch.distributed.get_rank()} is {t}") @@ -87,7 +93,6 @@ def mpi(): torch.cuda.manual_seed(seed) torch.distributed.barrier() - group = torch.distributed.new_group([0, 1]) tensor_size = (1024, 1024, 10) 
torch.cuda.set_device(torch.distributed.get_rank()) # need to pin device or ucx gets unhappy @@ -104,7 +109,8 @@ def mpi(): @torch_spawn([1]) -def public_attrs(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def public_attrs(pipeline_style): class MyString: def __init__(self, value): self.value = value @@ -117,7 +123,7 @@ def __str__(self): pipe = Pipe( model, balance=(1,), - style=Pipe.MultiProcess, + style=pipeline_style, worker_map=get_worker_map(), chunks=42.000, checkpoint=MyString("always"), @@ -134,12 +140,13 @@ def __str__(self): @torch_spawn([2]) @pytest.mark.parametrize("balance", [[2], [1, 1]]) -def sequential_like(balance): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def sequential_like(balance, pipeline_style): a = nn.Linear(1, 1) b = nn.Linear(1, 1) model = nn.Sequential(a, b) - model = Pipe(model, balance, style=Pipe.MultiProcess, worker_map=get_worker_map()) + model = Pipe(model, balance, style=pipeline_style, worker_map=get_worker_map()) if balance == [2]: if torch.distributed.get_rank() == 0: @@ -172,57 +179,62 @@ def sequential_like(balance): @torch_spawn([1]) -def balance_wrong_length(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def balance_wrong_length(pipeline_style): a = nn.Linear(1, 1) b = nn.Linear(1, 1) model = nn.Sequential(a, b) with pytest.raises(ValueError): - Pipe(model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map()) + Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map()) with pytest.raises(ValueError): - Pipe(model, balance=[3], style=Pipe.MultiProcess, worker_map=get_worker_map()) + Pipe(model, balance=[3], style=pipeline_style, worker_map=get_worker_map()) @torch_spawn([2]) -def balance_less_than_1(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def balance_less_than_1(pipeline_style): a = nn.Linear(1, 1) b = nn.Linear(1, 1) model = nn.Sequential(a, b) with pytest.raises(ValueError): - Pipe(model, balance=[0, 2], style=Pipe.MultiProcess, worker_map=get_worker_map()) + Pipe(model, balance=[0, 2], style=pipeline_style, worker_map=get_worker_map()) with pytest.raises(ValueError): - Pipe(model, balance=[-1, 3], style=Pipe.MultiProcess, worker_map=get_worker_map()) + Pipe(model, balance=[-1, 3], style=pipeline_style, worker_map=get_worker_map()) @torch_spawn([1]) -def chunks_less_than_1(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def chunks_less_than_1(pipeline_style): model = nn.Sequential(nn.Linear(1, 1)) with pytest.raises(ValueError): - Pipe(model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=0) + Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=0) with pytest.raises(ValueError): - Pipe(model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=-1) + Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=-1) @torch_spawn([1]) -def too_few_devices(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def too_few_devices(pipeline_style): model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1)) with pytest.raises(IndexError): # len(balance) > len(group.size()) - model = Pipe(model, balance=[1, 1, 1, 1], style=Pipe.MultiProcess, worker_map=get_worker_map()) + model = Pipe(model, balance=[1, 1, 1, 1], style=pipeline_style, 
worker_map=get_worker_map()) @torch_spawn([1]) -def batch_size_indivisible(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def batch_size_indivisible(pipeline_style): model = nn.Sequential(nn.Linear(1, 1)) - model = Pipe(model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=4) + model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=4) with pytest.warns(None) as record: model(torch.rand(7, 1)) @@ -232,9 +244,10 @@ def batch_size_indivisible(): @torch_spawn([1]) -def batch_size_small(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def batch_size_small(pipeline_style): model = nn.Sequential(nn.Linear(1, 1)) - model = Pipe(model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=4) + model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=4) with pytest.warns(None) as record: model(torch.rand(2, 1)) @@ -244,7 +257,8 @@ def batch_size_small(): @torch_spawn([1]) -def checkpoint_mode(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def checkpoint_mode(pipeline_style): def count_grad_fn(grad_fn, name, visited=set()): if grad_fn in visited: return 0 @@ -266,7 +280,7 @@ def count_grad_fn(grad_fn, name, visited=set()): always = Pipe( model, balance=[1], - style=Pipe.MultiProcess, + style=pipeline_style, worker_map=get_worker_map(), chunks=2, checkpoint="always", @@ -275,7 +289,7 @@ def count_grad_fn(grad_fn, name, visited=set()): except_last = Pipe( model, balance=[1], - style=Pipe.MultiProcess, + style=pipeline_style, worker_map=get_worker_map(), chunks=2, checkpoint="except_last", @@ -284,7 +298,7 @@ def count_grad_fn(grad_fn, name, visited=set()): never = Pipe( model, balance=[1], - style=Pipe.MultiProcess, + style=pipeline_style, worker_map=get_worker_map(), chunks=2, checkpoint="never", @@ -301,14 +315,15 @@ def count_grad_fn(grad_fn, name, visited=set()): @torch_spawn([1]) -def checkpoint_mode_invalid(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def checkpoint_mode_invalid(pipeline_style): model = nn.Sequential(nn.Linear(1, 1)) with pytest.raises(ValueError, match="checkpoint is not one of 'always', 'except_last', or 'never'"): Pipe( model, balance=[1], - style=Pipe.MultiProcess, + style=pipeline_style, worker_map=get_worker_map(), chunks=2, checkpoint="INVALID_CHECKPOINT", @@ -316,22 +331,24 @@ def checkpoint_mode_invalid(): @torch_spawn([1]) -def checkpoint_mode_when_chunks_1(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def checkpoint_mode_when_chunks_1(pipeline_style): model = nn.Sequential(nn.Linear(1, 1)) # All checkpoint modes are fine. 
Pipe( - model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=1, checkpoint="except_last", + model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1, checkpoint="except_last", ) - Pipe(model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=1, checkpoint="always") - Pipe(model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=1, checkpoint="never") + Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1, checkpoint="always") + Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1, checkpoint="never") @torch_spawn([1]) -def checkpoint_eval(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def checkpoint_eval(pipeline_style): model = nn.Sequential(nn.Linear(1, 1)) model = Pipe( - model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=2, pipelined_backward=False, + model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=2, pipelined_backward=False, ) input = torch.rand(2, 1) @@ -356,11 +373,16 @@ def find_grad_fn(grad_fn, name): assert not find_grad_fn(eval_output.grad_fn, "RecomputeBackward") +def torch_version() -> Tuple[int, ...]: + result = version.parse(torch.__version__).release + assert result + return result + + @torch_spawn([2]) -@pytest.mark.xfail( - version.parse(torch.__version__) < version.parse("1.6.0"), reason="Doesn't work on torch < 1.6.0", strict=True -) -def checkpoint_non_float_input(): +@pytest.mark.xfail(torch_version() < (1, 6, 0), reason="Doesn't work on torch < 1.6.0", strict=True) +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def checkpoint_non_float_input(pipeline_style): class ForkNonFloat(nn.Module): def forward(self, input): return (input * 2, torch.tensor([False])) @@ -373,7 +395,7 @@ def forward(self, input): model = Pipe( model, balance=[1, 1], - style=Pipe.MultiProcess, + style=pipeline_style, worker_map=get_worker_map(), chunks=1, checkpoint="always", @@ -385,14 +407,17 @@ def forward(self, input): if model.group.rank() == 1: # with torch.autograd.detect_anomaly(): output.backward() - else: + elif pipeline_style == Pipe.MultiProcess: model.back_helper(output) + torch.distributed.barrier() + @torch_spawn([1]) -def no_grad(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def no_grad(pipeline_style): model = nn.Sequential(nn.Linear(1, 1)) - model = Pipe(model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=2) + model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=2) input = torch.rand(2, 1) latent = None @@ -404,8 +429,8 @@ def hook(module, input, output): nonlocal latent latent = output - partition = model.partitions[0] - partition.register_forward_hook(hook) + partition = model.mp_partitions[0] + partition.module.register_forward_hook(hook) with torch.no_grad(): model(input) @@ -414,7 +439,8 @@ def hook(module, input, output): @torch_spawn([1]) -def exception(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def exception(pipeline_style): class ExpectedException(Exception): pass @@ -423,7 +449,7 @@ def forward(self, *_): raise ExpectedException() model = nn.Sequential(Raise()) - model = Pipe(model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=1) + model = Pipe(model, balance=[1], style=pipeline_style, 
worker_map=get_worker_map(), chunks=1)

    with pytest.raises(ExpectedException):
        model(torch.rand(1))

@@ -432,7 +458,8 @@ def forward(self, *_):
 # FIXME(tom) should probably signal to all hosts in group to stop
 @torch_spawn([4])
 @pytest.mark.xfail(strict=True)
-def exception_early_stop_asap():
+@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+def exception_early_stop_asap(pipeline_style):
    """Even if the first partitions have finished processing, the partition
    before the failed partition should be killed as soon as possible.
    """
@@ -460,7 +487,7 @@ def forward(self, x):
            raise ExpectedException()

    model = nn.Sequential(Pass(), Pass(), Counter(), Raise())
-    model = Pipe(model, [1, 1, 1, 1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=3)
+    model = Pipe(model, [1, 1, 1, 1], style=pipeline_style, worker_map=get_worker_map(), chunks=3)

    with pytest.raises(ExpectedException):
        model(torch.rand(3))
@@ -470,7 +497,8 @@ def forward(self, x):

 @torch_spawn([1])
-def input_pair():
+@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+def input_pair(pipeline_style):
    class Two(nn.Module):
        def __init__(self):
            super().__init__()
@@ -483,7 +511,7 @@ def forward(self, a_and_b):
    model = nn.Sequential(Two())
    model = Pipe(
-        model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=2, pipelined_backward=False,
+        model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=2, pipelined_backward=False,
    )

    a = torch.rand(10, 1, requires_grad=True)
@@ -498,7 +526,8 @@ def forward(self, a_and_b):

 @torch_spawn([1])
-def input_singleton():
+@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+def input_singleton(pipeline_style):
    class One(nn.Module):
        def __init__(self):
            super().__init__()
@@ -510,7 +539,7 @@ def forward(self, only_a):
    model = nn.Sequential(One())
    model = Pipe(
-        model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=2, pipelined_backward=False,
+        model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=2, pipelined_backward=False,
    )

    a = torch.rand(10, 1, requires_grad=True)
@@ -524,9 +553,10 @@ def forward(self, only_a):

 @torch_spawn([1])
-def input_varargs():
+@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+def input_varargs(pipeline_style):
    model = nn.Sequential(nn.Linear(1, 1))
-    model = Pipe(model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map())
+    model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map())

    a = torch.rand(1)
    b = torch.rand(1)
@@ -537,13 +567,14 @@ def input_varargs():

 @torch_spawn([1])
-def non_tensor():
+@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+def non_tensor(pipeline_style):
    class NonTensor(nn.Module):
        def forward(self, _):
            return "hello"

    model = nn.Sequential(NonTensor())
-    model = Pipe(model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map())
+    model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map())
    x = torch.rand(1)

    # TypeError: expected Tensor as element 0 in argument 0, but got str
@@ -556,13 +587,14 @@ def forward(self, _):

 @torch_spawn([1])
-def non_tensor_tuple():
+@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+def non_tensor_tuple(pipeline_style):
    class NonTensorTuple(nn.Module):
        def forward(self, x):
            return (x, "hello")

    model = nn.Sequential(NonTensorTuple())
-    model =
Pipe(model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map()) + model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map()) x = torch.rand(1) # TypeError: CheckpointBackward.forward: expected Variable (got str) for return value 1 @@ -577,18 +609,19 @@ def forward(self, x): @torch_spawn([1]) @pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) @pytest.mark.parametrize("lazy", [True, False]) -def deferred_batch_norm(checkpoint, lazy): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def deferred_batch_norm(checkpoint, lazy, pipeline_style): bn = nn.BatchNorm2d(3) pipe_bn = deepcopy(bn) pipe_fn = lambda: pipe_bn # noqa: E731 if lazy: - model = [pipe_fn] + model = [LazyModule(pipe_fn)] else: model = nn.Sequential(pipe_bn) pipe = Pipe( model, balance=[1], - style=Pipe.MultiProcess, + style=pipeline_style, worker_map=get_worker_map(), chunks=2, checkpoint=checkpoint, @@ -606,18 +639,19 @@ def deferred_batch_norm(checkpoint, lazy): @torch_spawn([1]) @pytest.mark.parametrize("checkpoint", ["never", "always"]) @pytest.mark.parametrize("lazy", [True, False]) -def deferred_batch_norm_params(checkpoint, lazy): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def deferred_batch_norm_params(checkpoint, lazy, pipeline_style): bn = nn.BatchNorm2d(3) pipe_bn = deepcopy(bn) pipe_fn = lambda: pipe_bn # noqa: E731 if lazy: - model = [pipe_fn] + model = [LazyModule(pipe_fn)] else: model = nn.Sequential(pipe_bn) pipe = Pipe( model, balance=[1], - style=Pipe.MultiProcess, + style=pipeline_style, worker_map=get_worker_map(), chunks=1, checkpoint=checkpoint, @@ -636,14 +670,15 @@ def deferred_batch_norm_params(checkpoint, lazy): @torch_spawn([4]) -def devices(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def devices(pipeline_style): a = nn.Linear(1, 1) b = nn.Linear(1, 1) c = nn.Linear(1, 1) # There are extra two ranks. model = nn.Sequential(a, b, c) - model = Pipe(model, [1, 1, 1], style=Pipe.MultiProcess, worker_map=get_worker_map()) + model = Pipe(model, [1, 1, 1], style=pipeline_style, worker_map=get_worker_map()) # Extra devices must be discarded. 
if model.group.rank() == 3: @@ -651,28 +686,33 @@ def devices(): @torch_spawn([2]) -def partitions(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def partitions(pipeline_style): a = nn.Linear(1, 1) b = nn.Linear(1, 1) model = nn.Sequential(a, b) - model = Pipe(model, [1, 1], style=Pipe.MultiProcess, worker_map=get_worker_map()) + model = Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map()) - assert isinstance(model.partitions, nn.ModuleList) + assert isinstance(model.mp_partitions, list) assert len(model) == 1 - assert isinstance(model.partitions[0], nn.Sequential) + assert isinstance(model.mp_partitions[0].module, nn.Sequential) - assert "partitions.0.0.weight" in model.state_dict() + if model.group.rank() == 0: + assert "0.0.weight" in model.state_dict() + else: + assert "0.1.weight" in model.state_dict() @torch_spawn([2]) @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") -def deny_moving(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def deny_moving(pipeline_style): a = nn.Linear(1, 1) b = nn.Linear(1, 1) model = nn.Sequential(a, b) - model = Pipe(model, [1, 1], style=Pipe.MultiProcess, worker_map=get_worker_map()) + model = Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map()) model.cuda() model.cpu() @@ -690,10 +730,11 @@ def deny_moving(): @torch_spawn([1]) -def empty_module(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def empty_module(pipeline_style): # Empty sequential module is not illegal. model = nn.Sequential() - model = Pipe(model, [], style=Pipe.MultiProcess, worker_map=get_worker_map()) + model = Pipe(model, [], style=pipeline_style, worker_map=get_worker_map()) assert model(torch.tensor([42])) == torch.tensor([42]) assert model((torch.tensor([42]),)) == (torch.tensor([42]),) @@ -705,16 +746,19 @@ def empty_module(): @torch_spawn([2]) -def named_children(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def named_children(pipeline_style): a = nn.Linear(1, 1) b = nn.Linear(1, 1) model = nn.Sequential(OrderedDict([("a", a), ("b", b)])) - model = Pipe(model, [1, 1], devices=["cpu", "cpu"]) + model = Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map()) names = set(n for n, _ in model.named_modules()) - assert "partitions.0.a" in names - assert "partitions.1.b" in names + if model.group.rank() == 0: + assert "0.a" in names + else: + assert "0.b" in names # Pipe doesn't support __getattr__. Unlike nn.Sequential, Pipe requires # several methods in its namespace. 
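The partition assertions above rely on the new per-rank layout: each process now holds only its own `mp_partitions`, a plain list whose entries expose the wrapped `nn.Sequential` as `.module`, and `state_dict()` keys drop the old `partitions.` prefix. A minimal sketch of inspecting that layout from inside one of these spawned two-rank tests (illustrative only, not part of the patch; it reuses the `Pipe`, `nn`, and `get_worker_map` names imported by the test module):

    model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1))
    pipe = Pipe(model, [1, 1], style=Pipe.MultiProcess, worker_map=get_worker_map())

    local = pipe.mp_partitions[0].module  # the nn.Sequential owned by this rank
    rank = pipe.group.rank()
    # rank 0 holds the first Linear ("0.0.weight"), rank 1 the second ("0.1.weight")
    print(f"rank {rank}: {len(local)} local layer(s), keys {sorted(pipe.state_dict().keys())}")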
@@ -723,7 +767,8 @@ def named_children(): @torch_spawn([1]) -def recommend_auto_balance(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def recommend_auto_balance(pipeline_style): with pytest.raises(ValueError, match="fairscale.nn.pipe.balance"): # balance is required Pipe(nn.Sequential()) @@ -737,23 +782,9 @@ def recommend_auto_balance(): Pipe(nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1)), [1]) -@torch_spawn([1]) -def verify_module_non_sequential(): - with pytest.raises(TypeError, match="module must be nn.Sequential to be partitioned"): - Pipe(nn.Module(), [1]) - - -@torch_spawn([1]) -def verify_module_duplicate_children(): - conv = nn.Conv2d(3, 3, 1) - model = nn.Sequential(conv, conv) - - with pytest.raises(ValueError, match="module with duplicate children is not supported"): - Pipe(model, [1, 1]) - - @torch_spawn([2]) -def lazy_construction(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def lazy_construction(pipeline_style): init_count = 0 class Custom(nn.Module): @@ -766,13 +797,13 @@ def forward(self, x): return x model = [ - lambda: Custom(), - lambda: Custom(), - lambda: Custom(), - lambda: Custom(), + LazyModule(lambda: Custom()), + LazyModule(lambda: Custom()), + LazyModule(lambda: Custom()), + LazyModule(lambda: Custom()), ] - pipe = Pipe(model, balance=[2, 2], style=Pipe.MultiProcess, worker_map=get_worker_map()) + pipe = Pipe(model, balance=[2, 2], style=pipeline_style, worker_map=get_worker_map()) assert isinstance(pipe[0], Custom) assert isinstance(pipe[1], Custom) @@ -780,18 +811,20 @@ def forward(self, x): assert init_count == 2 -@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="doesn't apply to mpi") @torch_spawn([2]) -def missing_worker_map(): +@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="doesn't apply to mpi") +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def missing_worker_map(pipeline_style): model = nn.Sequential(nn.ReLU(), nn.ReLU()) - with pytest.raises(ValueError, match="'PipelineStyle.MultiProcess' requires 'worker_map' to be set"): - Pipe(model, [1, 1], style=Pipe.MultiProcess) + with pytest.raises(ValueError, match="'RpcTransport' requires 'worker_map' to be set"): + Pipe(model, [1, 1], style=pipeline_style) @torch_spawn([2]) @pytest.mark.skip(reason="currently broken") -def verify_module_duplicate_parameters_on_distinct_partitions(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def verify_module_duplicate_parameters_on_distinct_partitions(pipeline_style): class Surrogate(nn.Module): def __init__(self, module): super().__init__() @@ -802,21 +835,205 @@ def __init__(self, module): # FIXME(tom) can't have duplicate params with separate processes with pytest.raises(ValueError, match="module with duplicate parameters on distinct devices is not supported"): - Pipe(model, [1, 1], style=Pipe.MultiProcess, worker_map=get_worker_map()) + Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map()) @torch_spawn([4]) -def pipelined_backward(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def pipelined_backward(pipeline_style): model = nn.Sequential(nn.ReLU(), nn.ReLU()) destroy_model_parallel() initialize_model_parallel(1, 4) - pipe = Pipe(model, [1, 1], style=Pipe.MultiProcess, worker_map=get_worker_map()) + pipe = Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map()) assert pipe.pipelined_backward is 
False destroy_model_parallel() initialize_model_parallel(2, 2) - pipe = Pipe(model, [1, 1], style=Pipe.MultiProcess, worker_map=get_worker_map()) + pipe = Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map()) assert pipe.pipelined_backward is True + + +@torch_spawn([4]) +def async_event_loop(): + + model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10), nn.ReLU()) + pipe = Pipe(model, [1, 1, 1, 1], style=Pipe.AsyncSchedule, worker_map=get_worker_map(), chunks=10) + + inputs = torch.rand(100, 10) + + output = pipe(inputs) + if pipe.final_stage: + loss = output.mean() + loss.backward() + + +@torch_spawn([4]) +def reuse_lazy(): + if False: # speed + reused = LazyModule(lambda: nn.Linear(10, 10)) + model = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU()] + # model = [reused, reused, nn.Linear(10, 10), nn.ReLU(), reused, reused, nn.ReLU(), reused, reused, nn.ReLU()] + pipe = Pipe(model, [3, 1, 1], style=Pipe.AsyncSchedule, worker_map=get_worker_map()) + pipe.eval() + output = pipe(torch.rand(10)) + + print(f"output on {pipe.group.rank()}, {output}") + torch.distributed.barrier() + + set_random_seed(1234) + # test both foward + reused = nn.Linear(10, 10) + layers = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU()] + model = nn.Sequential(*layers) + model.eval() + + set_random_seed(1234) + # ensure identical weights but no sharing between model and pipe + reused = nn.Linear(10, 10) + layers = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU()] + pipe = Pipe(layers, [3, 1, 1], style=Pipe.AsyncSchedule, worker_map=get_worker_map()) + pipe.eval() + model_optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) + pipe_optimizer = torch.optim.SGD(pipe.parameters(), lr=0.01, momentum=0.9) if len(list(pipe.parameters())) else None + inputs = torch.rand(10) + if False: # speed + model_out = model(inputs) + pipe_out = pipe(inputs) + + torch.distributed.barrier() + + if pipe.final_stage: + assert torch.equal(model_out, pipe_out) + + model.train() + pipe.train() + model_out = model(inputs) + pipe_out = pipe(inputs) + if pipe.final_stage: + pipe_loss = pipe_out.mean() + pipe_loss.backward() + + model_loss = model_out.mean() + model_loss.backward() + + model_optimizer.step() + if pipe_optimizer: + pipe_optimizer.step() + + model.eval() + pipe.eval() + model_out = model(inputs) + pipe_out = pipe(inputs) + + print(f"before barrier on {torch.distributed.get_rank()}") + torch.distributed.barrier() + print(f"after barrier on {torch.distributed.get_rank()}") + + if pipe.final_stage: + assert torch.equal(model_out, pipe_out) + + +def test_instantiate_partition(): + from fairscale.nn.pipe.async_schedule import Location + from fairscale.nn.pipe.pipe import instantiate_partition + + class FakeGroup: + def __init__(self, rank, size): + self._rank = rank + self._size = size + + def rank(self): + return self._rank + + def size(self): + return self._size + + def check_partitions(model, balance, expected_order, expected_ranks): + """Check the instantiated model matches expectation of order and rank + + model: a list of modules or an nn.Sequential + balance: the balance argument to Pipe + expected_order: the index of modules in `model` in the order they will + be executed, grouped by nn.Sequential + expected_rank: the rank that each module will be executed on + """ + + invocations = [] + invocation_wrapper = dict() + + # Collect `Invocation` and `Invocation` -> `ModuleWrapper` mapping from + # 
instantiated model + for rank in range(len(balance)): + instantiated = instantiate_partition(model, balance, FakeGroup(rank, len(balance)), Pipe.AsyncSchedule) + for part in instantiated: + assert isinstance(part.module, nn.Sequential) + for inv in part.invocations: + invocations.append(inv) + invocation_wrapper[inv] = part + + modules = [] + prev = None + current = Location(0, 0) + ranks = [] + + for order, inv in enumerate(sorted(invocations, key=lambda x: x.order)): + # Check integrity of Location chain + assert inv.order == order + assert inv.source == prev + assert inv.this == current + prev = inv.this + current = inv.dest + modules.append(list(invocation_wrapper[inv].module.children())) + ranks.append(inv.this.stage) + + # assert len(modules) == len(expected_order) + for left, right in zip(modules, expected_order): + assert len(left) == len(right), f"{right}" + assert list(map(id, left)) == list(map(id, (model[e] for e in right))), f"{right}" + + assert ranks == expected_ranks + + reused = nn.Linear(20, 20) + model = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU()] + balance = [3, 1, 1] + + check_partitions( + model, balance, expected_order=[[0], [1, 2], [0], [4], [0], [6]], expected_ranks=[0, 0, 0, 1, 0, 2] + ) + + reused2 = nn.Linear(5, 5) + model = [reused, reused2, nn.Linear(10, 10), nn.ReLU(), reused, reused2, nn.ReLU(), reused, reused2, nn.ReLU()] + balance = [4, 1, 1] + + check_partitions( + model, + balance, + expected_order=[[0], [1], [2, 3], [0], [1], [6], [0], [1], [9]], + expected_ranks=[0, 0, 0, 0, 0, 1, 0, 0, 2], + ) + + reused2 = nn.Linear(5, 5) + model = [ + nn.Linear(10, 10), + reused, + nn.Linear(10, 10), + nn.ReLU(), + reused, + reused2, + nn.ReLU(), + reused, + reused2, + nn.ReLU(), + ] + # 0 1 2 3 1 5 6 1 5 9 + balance = [4, 2, 1] + + check_partitions( + model, + balance, + expected_order=[[0], [1], [2, 3], [1], [5], [6], [1], [5], [9]], + expected_ranks=[0, 0, 0, 0, 1, 1, 0, 1, 2], + ) diff --git a/tests/nn/pipe_process/test_rpc.py b/tests/nn/pipe_process/test_rpc.py new file mode 100644 index 000000000..0f135dd7c --- /dev/null +++ b/tests/nn/pipe_process/test_rpc.py @@ -0,0 +1,254 @@ +import copy +import os + +import pytest +import torch +from torch import nn +from torch.distributed import rpc + +from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group +from fairscale.nn.pipe import PipeRPCWrapper +from tests.nn.model_parallel.commons import get_worker_map, torch_spawn + + +def init_rpc(): + os.environ["MASTER_PORT"] = "10639" + init_method = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}" + rpc.init_rpc( + f"Test{torch.distributed.get_rank()}", + rank=torch.distributed.get_rank(), + world_size=torch.distributed.get_world_size(), + backend=rpc.BackendType.TENSORPIPE, + rpc_backend_options=rpc.TensorPipeRpcBackendOptions(init_method=init_method), + ) + + +@torch_spawn([2]) +@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" not in os.environ, reason="mpi required") +def basic_rpc(): + init_rpc() + if torch.distributed.get_rank() != 0: + rpc.shutdown() + torch.distributed.barrier() + return + + model = [nn.Linear(10, 10), nn.ReLU()] + pipe = PipeRPCWrapper(model, [1, 1], input_device=torch.cuda.current_device(), worker_map=get_worker_map()) + + pipe.foreach_worker(register_optimizer, include_self=True) + + inputs = torch.rand(10).cuda() + output = pipe(inputs) + loss = output.mean() + loss.backward() + + pipe.foreach_worker(step_optimizer, include_self=True) + + pipe.eval() + + rpc.shutdown() + 
torch.distributed.barrier() + + +def register_optimizer(ctx, model): + if len(list(model.parameters())) > 0: + model.optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) + else: + model.optimizer = None + + +def step_optimizer(ctx, model): + if model.optimizer: + model.optimizer.step() + + +def check_pipe_against_reference(balance, model_constructor, checkpoint="except_last", custom_inputs=None): + model = model_constructor() + reference_model = model_constructor() + for src, dst in zip(model, reference_model): + dst.load_state_dict(copy.deepcopy(src.state_dict())) + + reference_model = nn.Sequential(*reference_model).cuda() + + pipe = PipeRPCWrapper( + model, balance, input_device=torch.cuda.current_device(), worker_map=get_worker_map(), checkpoint=checkpoint, + ) + + pipe.foreach_worker(register_optimizer, include_self=True) + register_optimizer(None, reference_model) + + inputs = torch.rand(10).cuda() + target = torch.rand(10).cuda() + cloned = inputs.clone() + output = pipe(inputs) + ref_out = reference_model(inputs) + + assert torch.equal(ref_out.cpu(), output.cpu()) + + for out in output, ref_out: + target = target.to(out.device) + loss = nn.MSELoss()(out, target) + loss.backward() + + pipe.foreach_worker(step_optimizer, include_self=True) + step_optimizer(None, reference_model.cuda()) + + pipe.eval() + reference_model.eval() + + final_output = pipe(inputs) + final_ref = reference_model(inputs.cuda()) + + assert torch.equal(final_output.cpu(), final_ref.cpu()) + + +@torch_spawn([3]) +@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" not in os.environ, reason="mpi required") +def rpc_optimizer(): + + init_rpc() + if torch.distributed.get_rank() != 0: + rpc.shutdown() + torch.distributed.barrier() + return + + def model_with_reuse(): + reused_1 = nn.Linear(10, 10) + return [reused_1, nn.ReLU(), reused_1, nn.ReLU(), reused_1, nn.ReLU()] + + check_pipe_against_reference( + [2, 2, 2], lambda: [nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10), nn.ReLU()], + ) + check_pipe_against_reference([2, 1, 1], model_with_reuse) + + rpc.shutdown() + torch.distributed.barrier() + + +@torch_spawn([6]) +@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" not in os.environ, reason="mpi required") +def rpc_megatron_reuse(): + + from fairscale.nn.model_parallel import layers + from fairscale.nn.model_parallel.initialize import destroy_model_parallel, initialize_model_parallel + + def make_model_simple(): + return [ + layers.ColumnParallelLinear(10, 10), + nn.ReLU(), + layers.RowParallelLinear(10, 10), + nn.ReLU(), + layers.ColumnParallelLinear(10, 10), + nn.ReLU(), + layers.RowParallelLinear(10, 10), + nn.ReLU(), + nn.Linear(10, 10), + nn.ReLU(), + ] + + def make_model_with_reuse(): + column = layers.ColumnParallelLinear(10, 10) + row = layers.RowParallelLinear(10, 10) + return [ + column, + nn.ReLU(), + row, + nn.ReLU(), + column, + nn.ReLU(), + row, + nn.ReLU(), + nn.Linear(10, 10), + nn.ReLU(), + ] + + destroy_model_parallel() + torch.distributed.destroy_process_group() + torch.distributed.init_process_group("gloo", rank=int(os.environ["RANK"]), world_size=int(os.environ["WORLD_SIZE"])) + initialize_model_parallel(2, 3, model_parallel_backend="nccl", pipeline_backend="mpi") + + init_rpc() + if get_pipeline_parallel_group().rank() != 0: + rpc.shutdown() + torch.distributed.barrier() + return + + check_pipe_against_reference([4, 4, 2], make_model_simple, "always") + check_pipe_against_reference([4, 2, 2], make_model_with_reuse) + + rpc.shutdown() + 
torch.distributed.barrier()
+
+
+@torch_spawn([3])
+@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" not in os.environ, reason="mpi required")
+def rpc_reuse_in_final_stage():
+
+    # 'reused' and 'reused2' are located on stage 2, so the backward pass for
+    # the final stage will need to first send gradients to stage 2, then receive
+    # gradients from stage 2. This tests custom logic to handle reuse of layers
+    # in the final stage of the pipeline.
+
+    reused = nn.Linear(10, 10)
+    reused2 = nn.Linear(10, 10)
+    model = [
+        nn.Linear(10, 10),
+        nn.ReLU(),
+        nn.Linear(10, 10),
+        reused2,
+        nn.ReLU(),
+        reused,
+        nn.ReLU(),
+        reused,
+        reused2,
+        nn.ReLU(),
+        reused,
+        nn.ReLU(),
+    ]
+    balance = [2, 3, 4]
+
+    init_rpc()
+
+    if torch.distributed.get_rank() != 0:
+        rpc.shutdown()
+        torch.distributed.barrier()
+        return
+
+    pipe = PipeRPCWrapper(model, balance, worker_map=get_worker_map())
+
+    inputs = torch.rand(10).cuda()
+    target = torch.rand(10).cuda()
+    output = pipe(inputs)
+    nn.MSELoss()(output, target).backward()
+    output = pipe(inputs)
+    nn.MSELoss()(output, target).backward()
+    rpc.shutdown()
+    torch.distributed.barrier()
+
+
+@torch_spawn([3])
+@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" not in os.environ, reason="mpi required")
+def rpc_multiple_tensors():
+    class FuseTwo(nn.Module):
+        def forward(self, left, right):
+            return left + right
+
+    class SplitTwo(nn.Module):
+        def forward(self, inputs):
+            return (inputs, 2 * inputs)
+
+
+@torch_spawn([2])
+@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="no mpi")
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
+def construct_only_rank_zero():
+    model = [nn.Linear(10, 10), nn.ReLU()]
+    if torch.distributed.get_rank() == 0:
+        PipeRPCWrapper(model, [1, 1], worker_map=get_worker_map())
+        rpc.shutdown()
+    else:
+        # Must enter rpc loop to complete PipeRPCWrapper constructor above
+        rpc.shutdown()
+
+        with pytest.raises(AssertionError):
+            PipeRPCWrapper(model, [1, 1], worker_map=get_worker_map())
diff --git a/tests/nn/pipe_process/test_transparency.py b/tests/nn/pipe_process/test_transparency.py
index ad11ac4bb..a1e509924 100644
--- a/tests/nn/pipe_process/test_transparency.py
+++ b/tests/nn/pipe_process/test_transparency.py
@@ -27,7 +27,8 @@
 @torch_spawn([2])
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
-def simple_linears():
+@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
+def simple_linears(pipeline_style):
    def sum_grad(parameters):
        return sum([p.grad.sum() for p in parameters if p.grad is not None])
@@ -54,19 +55,19 @@ def zero_grad(parameters):
    zero_grad(model.parameters())

    # With Pipe
-    model = Pipe(model, [2, 2], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=4)
+    model = Pipe(model, [2, 2], style=pipeline_style, worker_map=get_worker_map(), chunks=4)
    outputs = model(inputs)
    if model.group.rank() == 1:
        loss = outputs.mean()
        loss.backward()
-        grad_with_pipe = sum_grad(model.pipeline.partitions[0].parameters())
+        grad_with_pipe = sum_grad(model.pipeline.mp_partitions[0].module.parameters())

        # Both grads should be identical.
        assert torch.allclose(grad_with_pipe, grad_without_pipe[1])
    else:
        model.back_helper(outputs)
-        grad_with_pipe = sum_grad(model.pipeline.partitions[0].parameters())
+        grad_with_pipe = sum_grad(model.pipeline.mp_partitions[0].module.parameters())

        # Both grads should be identical.
        assert torch.allclose(grad_with_pipe, grad_without_pipe[0])
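The `simple_linears` change above still checks the same invariant as before: partitioning the model must not change its gradients, only where they live (each rank now reads them from `model.pipeline.mp_partitions[0].module.parameters()`). The invariant itself can be demonstrated without Pipe or multiple processes; a small self-contained sketch in plain PyTorch (illustrative only, not part of the patch):

    import torch
    from torch import nn

    def sum_grad(parameters):
        # same shape as the helper in simple_linears: total gradient mass
        return sum(p.grad.sum() for p in parameters if p.grad is not None)

    torch.manual_seed(0)
    inputs = torch.rand(8, 4)
    full = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4), nn.Linear(4, 4), nn.Linear(4, 4))

    # Reference gradients from the unpartitioned model.
    full(inputs).mean().backward()
    reference = sum_grad(full.parameters())

    # "Partition" the same layers into two stages and run them back to back.
    for p in full.parameters():
        p.grad = None
    stage0, stage1 = full[:2], full[2:]
    stage1(stage0(inputs)).mean().backward()

    # Splitting the Sequential changes nothing about the math.
    assert torch.allclose(sum_grad(full.parameters()), reference)

The pipe test asserts the same thing per rank, comparing each stage's gradient sum against the matching entry of `grad_without_pipe`.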