From bccb2f7a5681867ada97c89a47ece1b57e766cb2 Mon Sep 17 00:00:00 2001
From: Tom Birch
Date: Mon, 5 Oct 2020 11:47:54 -0700
Subject: [PATCH 1/8] initialize-fixes

---
 fairscale/nn/model_parallel/initialize.py | 26 +++++++++++++++++------
 1 file changed, 19 insertions(+), 7 deletions(-)

diff --git a/fairscale/nn/model_parallel/initialize.py b/fairscale/nn/model_parallel/initialize.py
index 3828b2425..fb488fa1f 100644
--- a/fairscale/nn/model_parallel/initialize.py
+++ b/fairscale/nn/model_parallel/initialize.py
@@ -22,7 +22,7 @@
 """Model and data parallel groups."""

-from typing import List
+from typing import List, Optional

 import torch

@@ -38,7 +38,14 @@
 _PIPELINE_PARALLEL_RANKS = None


-def initialize_model_parallel(model_parallel_size_: int, pipeline_length: int = 1) -> None:
+def initialize_model_parallel(
+    model_parallel_size_: int,
+    pipeline_length: int = 1,
+    *,
+    model_parallel_backend: Optional[str] = None,
+    pipeline_backend: Optional[str] = None,
+    ddp_backend: Optional[str] = None
+) -> None:
     """
     Initialize model data parallel groups.

@@ -57,8 +64,8 @@ def initialize_model_parallel(model_parallel_size_: int, pipeline_length: int =
     with a total of 16 GPUs, rank 0 to 7 belong to the first box and
     ranks 8 to 15 belong to the second box.
     """
-    if torch.distributed.get_rank() == 0:
-        print("> initializing model parallel with size {}".format(model_parallel_size_))
+    # if torch.distributed.get_rank() == 0:
+    #     print("> initializing model parallel with size {}".format(model_parallel_size_))
     # Get world size and rank. Ensure some consistencies.
     assert torch.distributed.is_initialized()
     world_size = torch.distributed.get_world_size()
@@ -69,6 +76,11 @@ def initialize_model_parallel(model_parallel_size_: int, pipeline_length: int =

     data_parallel_size = int(world_size / (model_parallel_size * pipeline_length))

+    if torch.distributed.get_rank() == 0:
+        print("> initializing model parallel with size {}".format(model_parallel_size_))
+        print("> initializing ddp with size {}".format(data_parallel_size))
+        print("> initializing pipeline with size {}".format(pipeline_length))
+
     groups = torch.LongTensor(range(world_size)).reshape(data_parallel_size, pipeline_length, model_parallel_size)

     found = torch.where(groups == rank)
@@ -80,7 +92,7 @@ def initialize_model_parallel(model_parallel_size_: int, pipeline_length: int =
     assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized"
     for j in range(pipeline_length):
         for k in range(model_parallel_size):
-            group = torch.distributed.new_group(groups[:, j, k].tolist())
+            group = torch.distributed.new_group(groups[:, j, k].tolist(), backend=ddp_backend)
             if j == found[1] and k == found[2]:
                 _DATA_PARALLEL_GROUP = group

@@ -89,7 +101,7 @@ def initialize_model_parallel(model_parallel_size_: int, pipeline_length: int =
     assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized"
     for i in range(data_parallel_size):
         for j in range(pipeline_length):
-            group = torch.distributed.new_group(groups[i, j, :].tolist())
+            group = torch.distributed.new_group(groups[i, j, :].tolist(), backend=model_parallel_backend)
             if i == found[0] and j == found[1]:
                 _MODEL_PARALLEL_GROUP = group

@@ -100,7 +112,7 @@ def initialize_model_parallel(model_parallel_size_: int, pipeline_length: int =
     for i in range(data_parallel_size):
         for k in range(model_parallel_size):
             ranks = groups[i, :, k].tolist()
-            group = torch.distributed.new_group(ranks)
+            group = torch.distributed.new_group(ranks, backend=pipeline_backend)
             if i == found[0] and k ==
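
For reference, the reworked entry point above can be driven like the following minimal sketch. It assumes an 8-process launch whose RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT come from the launcher and that the named backends exist in the local PyTorch build; leaving any backend keyword as None keeps the default group backend.

import os

import torch
from fairscale.nn.model_parallel import initialize_model_parallel

# 2-way model parallelism x 2 pipeline stages x 2-way data parallelism over 8 ranks.
torch.distributed.init_process_group(
    backend="gloo",
    rank=int(os.environ["RANK"]),
    world_size=int(os.environ["WORLD_SIZE"]),
)
initialize_model_parallel(
    2,  # model_parallel_size_
    pipeline_length=2,
    model_parallel_backend="nccl",
    pipeline_backend="gloo",
    ddp_backend="nccl",
)
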
found[2]: _PIPELINE_PARALLEL_GROUP = group _PIPELINE_PARALLEL_RANKS = ranks From e9bbbf9e89d5c120b74a856aadb37a907d57169b Mon Sep 17 00:00:00 2001 From: Tom Birch Date: Fri, 18 Sep 2020 09:29:23 -0700 Subject: [PATCH 2/8] Add AsyncSchedule method --- benchmarks/pipe.py | 156 +++++- docs/Makefile | 5 +- fairscale/nn/__init__.py | 4 +- fairscale/nn/model_parallel/mappings.py | 3 +- fairscale/nn/model_parallel/random.py | 11 +- fairscale/nn/pipe/__init__.py | 5 +- fairscale/nn/pipe/async_schedule.py | 492 +++++++++++++++++++ fairscale/nn/pipe/messages.py | 137 ++++++ fairscale/nn/pipe/pipe.py | 286 ++++++++--- fairscale/nn/pipe/pipeline.py | 383 ++++++--------- fairscale/nn/pipe/rpc.py | 313 ++++++++++++ fairscale/nn/pipe/types.py | 65 +++ fairscale/optim/oss.py | 2 +- run_mpi_tests.sh | 10 +- stubs/torch/__init__.pyi | 4 +- stubs/torch/cuda/__init__.pyi | 2 +- stubs/torch/distributed/__init__.pyi | 1 + stubs/torch/distributed/distributed_c10d.pyi | 7 + stubs/torch/distributed/rpc/__init__.pyi | 3 +- stubs/torch/futures.pyi | 6 + stubs/torch/nn/__init__.pyi | 2 +- tests/nn/model_parallel/commons.py | 34 +- tests/nn/model_parallel/test_initialize.py | 2 +- tests/nn/model_parallel/test_layers.py | 27 +- tests/nn/moe/test_moe_layer.py | 4 +- tests/nn/pipe_process/conftest.py | 4 + tests/nn/pipe_process/skip/test_gpipe.py | 30 +- tests/nn/pipe_process/skip/test_leak.py | 8 +- tests/nn/pipe_process/test_bugs.py | 23 +- tests/nn/pipe_process/test_inplace.py | 20 +- tests/nn/pipe_process/test_pipe.py | 445 ++++++++++++----- tests/nn/pipe_process/test_rpc.py | 278 +++++++++++ tests/nn/pipe_process/test_transparency.py | 9 +- 33 files changed, 2263 insertions(+), 518 deletions(-) create mode 100644 fairscale/nn/pipe/async_schedule.py create mode 100644 fairscale/nn/pipe/messages.py create mode 100644 fairscale/nn/pipe/rpc.py create mode 100644 fairscale/nn/pipe/types.py create mode 100644 stubs/torch/distributed/distributed_c10d.pyi create mode 100644 stubs/torch/futures.pyi create mode 100644 tests/nn/pipe_process/test_rpc.py diff --git a/benchmarks/pipe.py b/benchmarks/pipe.py index b67cb8169..21b9b1439 100644 --- a/benchmarks/pipe.py +++ b/benchmarks/pipe.py @@ -11,16 +11,26 @@ from torch.distributed import rpc import torch.multiprocessing as mp import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data import DataLoader import torchtext from torchtext.data.utils import get_tokenizer +# from deepspeed.pipe import PipelineModule from fairscale.nn import Pipe from fairscale.nn.model_parallel import initialize_model_parallel -from fairscale.nn.pipe import pipe +from fairscale.nn.model_parallel.initialize import get_data_parallel_group, get_pipeline_parallel_group +from fairscale.nn.pipe import LazyModule, pipe from fairscale.optim import GradScaler +from fairscale.optim.oss import OSS from tests.nn.model_parallel.commons import dist_init, get_worker_map +try: + import torch_ucc # noqa: F401 +except ImportError as e: + print(f"can't import torch_ucc: {e}") + pass + try: from fairscale.optim import Adam # type: ignore @@ -164,13 +174,13 @@ def make_model(args, device, ntokens): if args.lazy_construction: layers = [ - lambda: EmbeddingLayer(ntokens, ninp, initrange), - lambda: PositionalEncodingLayer(ninp, dropout), + LazyModule(lambda: EmbeddingLayer(ntokens, ninp, initrange)), + LazyModule(lambda: PositionalEncodingLayer(ninp, dropout)), ] for _ in range(ndecoder): - layers.append(lambda: TransformerDecoderLayer(ninp, nhead, nhid, dropout)) + 
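
A minimal sketch of the LazyModule pattern used in this hunk (layer sizes invented for illustration): each entry wraps a constructor, so a layer is only materialized on the pipeline rank that ends up owning it rather than on every rank.

from torch import nn

from fairscale.nn.pipe import LazyModule

layers = [
    LazyModule(lambda: nn.Embedding(10000, 512)),
    LazyModule(lambda: nn.Linear(512, 512)),
    LazyModule(lambda: nn.Linear(512, 10000)),
]
# The list is handed to Pipe in place of an nn.Sequential; construction happens later,
# inside instantiate_partition(), and only for the partition assigned to this rank.
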
layers.append(LazyModule(lambda: TransformerDecoderLayer(ninp, nhead, nhid, dropout))) - layers.append(lambda: LinearLayer(ninp, ntokens, initrange)) + layers.append(LazyModule(lambda: LinearLayer(ninp, ntokens, initrange))) model = layers else: model = TransformerLMSequntial(ntokens, ninp, nhead, nhid, dropout, initrange, ndecoder).to(device) @@ -179,7 +189,10 @@ def make_model(args, device, ntokens): lr = 0.01 # learning rate def make_adam(model): - return Adam(model.parameters(), lr=lr) + if args.ddp_zero: + return OSS(params=model.parameters(), optim=Adam, group=get_data_parallel_group(), lr=lr) + else: + return Adam(model.parameters(), lr=lr) optimizer = make_adam scaler = GradScaler() @@ -276,7 +289,15 @@ def train(lm_dataloader, model, criterion, optimizer, vocab_size, args): num_params = reduce(operator.add, (reduce(operator.mul, x.size()) for x in model.parameters())) if model.group: - print(f"training model, #prams = {num_params}, group: {model.group.rank()}, sizes {model.group.size()}") + total = torch.Tensor([num_params]).cuda() + torch.distributed.all_reduce(total, group=model.group) + print( + f"training model, #prams = {num_params}, group: {model.group.rank()}, grank:" + f" {torch.distributed.get_rank()}, sizes {model.group.size()}" + ) + torch.distributed.barrier() + if model.group.rank() == 0: + print(f"total #prams = {total.item()}") else: print(f"training model, #prams = {num_params}") vocab_size = 10000 # FIXME @@ -287,37 +308,94 @@ def train(lm_dataloader, model, criterion, optimizer, vocab_size, args): optimizer = optimizer(model) def get_first_device(model): + if isinstance(model, DDP): + model = model.module + if model.devices: return model.devices[0] else: return torch.cuda.current_device() def get_last_device(model): + if isinstance(model, DDP): + model = model.module if model.devices: return model.devices[-1] else: return torch.cuda.current_device() + pipe_group = model.group + + if pipe_group is None or pipe_group.rank() == 0: + print(f">> Init DDP") + + if args.ddp_zero: + model = DDP( + model, + device_ids=[torch.cuda.current_device()], + process_group=get_data_parallel_group(), + find_unused_parameters=False, + ) + + if pipe_group is None or pipe_group.rank() == 0: + print(f"<< Init DDP") + + if pipe_group and pipe_group.rank() != 0 and pipe_group.rank() != (pipe_group.size() - 1): + thing = {"input": torch.zeros(args.batch_size)} + + class FakeDataset: + def __getitem__(self, index): + return thing + + def __len__(self): + return len(lm_dataloader) + + lm_dataloader = FakeDataset() + for i, batch in enumerate(lm_dataloader): + bi = batch["input"] + # print(f"batch size: {torch.numel(bi)}, {bi.size()}, {bi.device}") if args.max_batch and i > args.max_batch: break optimizer.zero_grad() - output = model(batch["input"].to(get_first_device(model))) - - if model.group is None or model.group.rank() == model.group.size() - 1: - target = batch["target"].to(get_last_device(model)) - output = output.to(target.device) + try: + if (pipe_group is None or pipe_group.rank() == 0) and not args.ddp_zero: + tmp = batch["input"].to(get_first_device(model)) + output = model(tmp) + else: + output = model(batch["input"]) + except Exception as e: + print(f"exception while training on rank {torch.distributed.get_rank()}") + raise RuntimeError(f"training failed on {torch.distributed.get_rank()}") from e + + if pipe_group is None or pipe_group.rank() == pipe_group.size() - 1: + if True: + target = batch["target"].to(get_last_device(model)) + output = output.to(target.device) + else: + 
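
The --ddp-zero path above pairs each pipeline stage with the OSS sharded optimizer over the data-parallel group. A sketch with a hypothetical make_optimizer helper; OSS forwards extra keyword arguments such as lr to the wrapped optimizer class.

from torch.optim import Adam

from fairscale.nn.model_parallel.initialize import get_data_parallel_group
from fairscale.optim.oss import OSS

def make_optimizer(model, lr=0.01, ddp_zero=True):
    if ddp_zero:
        # Shard optimizer state across the data-parallel replicas only.
        return OSS(params=model.parameters(), optim=Adam, group=get_data_parallel_group(), lr=lr)
    return Adam(model.parameters(), lr=lr)
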
target = batch["target"].cpu() + output = output.cpu() + + print(f"output size is {output.size()}") loss = criterion(output.view(-1, vocab_size), target.view(-1)) + if args.ddp_zero: + ddp_group = get_data_parallel_group() + torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM, group=ddp_group) + loss /= ddp_group.size() loss.backward() + del target else: - model.back_helper(output) + if args.ddp_zero: + model.module.back_helper(output) + else: + model.back_helper(output) del output torch.nn.utils.clip_grad_value_(model.parameters(), 0.05) optimizer.step() - if model.group is None or model.group.rank() == model.group.size() - 1: + if pipe_group is None or pipe_group.rank() == pipe_group.size() - 1: total_loss += loss.item() log_interval = 1 word_counter += batch["ntokens"] @@ -406,6 +484,17 @@ def benchmark_language_model(train_data, val_data, test_data, model, criterion, print("No regression detected") +def generate_balance_weighted(num_devices, num_layers, fraction=0.5): + balance = [] + layers_assigned = 0 + average_count = num_layers / num_devices + last_layers = int(average_count * fraction) + + balance = generate_balance(num_devices - 1, num_layers - last_layers) + balance.append(last_layers) + return balance + + def generate_balance(num_devices, num_layers): balance = [] layers_assigned = 0 @@ -460,10 +549,13 @@ def bench_single_process(args): blob = make_model_and_data(args, None, new_data=new_data) model = blob["model"] - balance = generate_balance(min(num_devices, 8), len(model)) - p = pipe.Pipe( - model, balance, chunks=args.chunks, pipelined_backward=args.pipelined_backward, checkpoint=args.checkpoint - ) + if args.deepspeed: + p = PipelineModule(layers=model, num_stages=min(num_devices, 4)) + else: + balance = generate_balance(min(num_devices, 4), len(model)) + p = pipe.Pipe( + model, balance, chunks=args.chunks, pipelined_backward=args.pipelined_backward, checkpoint=args.checkpoint + ) del model del blob["model"] @@ -480,16 +572,17 @@ def run_mp_worker(args, available_workers): blob = make_model_and_data(args, None, new_data=new_data) model = blob["model"] - balance = generate_balance(min(available_workers, 8), len(model)) + balance = generate_balance_weighted(get_pipeline_parallel_group().size(), len(model), 0.8) p = pipe.Pipe( model, balance, - style=Pipe.MultiProcess, + style=Pipe.AsyncSchedule, chunks=args.chunks, worker_map=get_worker_map(), input_device=torch.cuda.current_device(), pipelined_backward=args.pipelined_backward, checkpoint=args.checkpoint, + # loss_fn=blob["criterion"], ).cuda() if args.all_at_once and p.pipeline: @@ -537,18 +630,24 @@ def bench_multi_process(args, all_at_once=False): def bench_mpi(args): guess_rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) - os.environ["UCX_NET_DEVICES"] = best_device_map[guess_rank] + world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) + local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) + os.environ["UCX_NET_DEVICES"] = best_device_map[local_rank] - torch.distributed.init_process_group(backend="mpi") os.environ["MASTER_ADDR"] = args.host - os.environ["MASTER_PORT"] = "10639" + os.environ["MASTER_PORT"] = "10638" if args.socket_name: os.environ["GLOO_SOCKET_IFNAME"] = args.socket_name os.environ["TP_SOCKET_IFNAME"] = args.socket_name + + torch.distributed.init_process_group(backend="gloo", rank=guess_rank, world_size=world_size) + + os.environ["MASTER_ADDR"] = args.host + os.environ["MASTER_PORT"] = "10639" init_method = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}" rank = 
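
Pulling the run_mp_worker changes together, a multi-process pipe using the new AsyncSchedule style is constructed roughly as below. This is a sketch: layers, balance and worker_map are assumed to exist already (worker_map maps each global rank to its RPC worker name, which is what get_worker_map() builds in the tests).

import torch

from fairscale.nn import Pipe

pipe_model = Pipe(
    layers,                          # nn.Sequential or a list of LazyModule
    balance,                         # number of layers per pipeline stage
    style=Pipe.AsyncSchedule,
    chunks=4,                        # micro-batches per mini-batch
    worker_map=worker_map,           # {global_rank: rpc_worker_name}
    input_device=torch.cuda.current_device(),
).cuda()
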
torch.distributed.get_rank() world_size = torch.distributed.get_world_size() - torch.cuda.set_device(rank % torch.cuda.device_count()) + torch.cuda.set_device(local_rank % torch.cuda.device_count()) rpc.init_rpc( f"Test{rank}", @@ -558,7 +657,12 @@ def bench_mpi(args): rpc_backend_options=rpc.ProcessGroupRpcBackendOptions(rpc_timeout=20, init_method=init_method), ) - initialize_model_parallel(1, world_size) + backends = {"model_parallel_backend": "nccl", "pipeline_backend": "mpi", "ddp_backend": "nccl"} + + if args.ddp_zero: + initialize_model_parallel(1, 4, **backends) + else: + initialize_model_parallel(1, world_size, **backends) init_random_seed(0) run_mp_worker(args, world_size) @@ -579,6 +683,8 @@ def bench_mpi(args): parser.add_argument("--max-batch", type=int, default=4, help="Max number of batches") parser.add_argument("--socket-name", type=str, default=None, help="socket ifname for gloo/tp") parser.add_argument("--num-decoder-layers", type=int, default=10, help="Number of decoder layers in the model") +parser.add_argument("--ddp-zero", action="store_true", default=False, help="enable ddp") +parser.add_argument("--deepspeed", action="store_true", default=False, help="use eepspeed instead of fairscale pipe") parser.add_argument( "--lazy-construction", action="store_true", default=False, help="Number of decoder layers in the model" ) diff --git a/docs/Makefile b/docs/Makefile index d0c3cbf10..008578dde 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -12,7 +12,10 @@ BUILDDIR = build help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -.PHONY: help Makefile +setup: + pip install -r requirements.txt + +.PHONY: help Makefile setup # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). diff --git a/fairscale/nn/__init__.py b/fairscale/nn/__init__.py index 39e347057..690e26fc5 100644 --- a/fairscale/nn/__init__.py +++ b/fairscale/nn/__init__.py @@ -4,6 +4,6 @@ # LICENSE file in the root directory of this source tree. from .moe import MOELayer, Top2Gate -from .pipe import Pipe +from .pipe import LazyModule, Pipe -__all__ = ["Pipe", "Top2Gate"] +__all__ = ["Pipe", "Top2Gate", "LazyModule"] diff --git a/fairscale/nn/model_parallel/mappings.py b/fairscale/nn/model_parallel/mappings.py index 4afabe686..c91c01cf5 100644 --- a/fairscale/nn/model_parallel/mappings.py +++ b/fairscale/nn/model_parallel/mappings.py @@ -39,8 +39,9 @@ def _reduce(ctx: Any, input_: torch.Tensor) -> torch.Tensor: return input_ # All-reduce. - print(f"doing all_reduce on {torch.distributed.get_rank()}") + print(f">> doing all_reduce on {torch.distributed.get_rank()}") torch.distributed.all_reduce(input_, group=group) + print(f"<< doing all_reduce on {torch.distributed.get_rank()}") return input_ diff --git a/fairscale/nn/model_parallel/random.py b/fairscale/nn/model_parallel/random.py index 6cc2ca1b2..e0775ed8a 100644 --- a/fairscale/nn/model_parallel/random.py +++ b/fairscale/nn/model_parallel/random.py @@ -182,11 +182,12 @@ def model_parallel_cuda_manual_seed(seed: int) -> None: ), flush=True, ) - _CUDA_RNG_STATE_TRACKER.reset() - # Set the default state. - torch.cuda.manual_seed(data_parallel_seed) - # and model parallel state. - _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, model_parallel_seed) + if torch.cuda.is_available(): + _CUDA_RNG_STATE_TRACKER.reset() + # Set the default state. + torch.cuda.manual_seed(data_parallel_seed) + # and model parallel state. 
+ _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, model_parallel_seed) class CheckpointFunction(torch.autograd.Function): diff --git a/fairscale/nn/pipe/__init__.py b/fairscale/nn/pipe/__init__.py index 03e779303..75591bf09 100644 --- a/fairscale/nn/pipe/__init__.py +++ b/fairscale/nn/pipe/__init__.py @@ -19,6 +19,7 @@ """A Pipe implementation in PyTorch.""" from .checkpoint import is_checkpointing, is_recomputing -from .pipe import Pipe +from .pipe import LazyModule, Pipe +from .rpc import PipeRPCWrapper -__all__ = ["Pipe", "is_checkpointing", "is_recomputing"] +__all__ = ["Pipe", "is_checkpointing", "is_recomputing", "LazyModule"] diff --git a/fairscale/nn/pipe/async_schedule.py b/fairscale/nn/pipe/async_schedule.py new file mode 100644 index 000000000..5d3e72869 --- /dev/null +++ b/fairscale/nn/pipe/async_schedule.py @@ -0,0 +1,492 @@ +from enum import Enum, auto +from threading import Event +from typing import Dict, Iterable, List, Optional, Tuple + +from dataclasses import dataclass +import torch +from torch import Tensor, nn +from torch.distributed import ProcessGroup + +from fairscale.nn.model_parallel import get_pipeline_parallel_group, get_pipeline_parallel_ranks + +from .messages import MESSAGE_TENSOR_SIZE, MessageQueues, send_message, tensor_to_pyobject, to_input_device +from .microbatch import Batch +from .skip.tracker import SkipTrackerThroughPotals +from .types import EVENT_LOOP_QUEUE, InputDevice, PipelineStyle, PipeMessage, Tensors, TransportConfig + +Activations = Dict[int, Dict[int, Dict[int, Batch]]] + + +def dprint(x: str) -> None: + pass + + +@dataclass(frozen=True) +class Location: + stage: int + index: int + + def __repr__(self) -> str: + return f"{self.stage}@{self.index}" + + +@dataclass(frozen=True) +class Invocation: + order: int + this: Location + source: Optional[Location] + dest: Optional[Location] + + +class ModuleWrapper: + def __init__(self, module: nn.Sequential, location: Location, invocations: Optional[List[Invocation]] = None): + self.module: nn.Sequential = module + self.location: Location = location + self.invocations: List[Invocation] = invocations or [] + + def __repr__(self) -> str: + return f"{self.location}:\n" + "\n".join(map(str, self.invocations)) + "\n\t" + str(self.module) + + def __len__(self) -> int: + return len(self.module) + + def __iter__(self) -> Iterable: + yield from self.module + + +class AsyncMessageType(Enum): + Activations = auto() + Gradients = auto() + + +@dataclass(frozen=True) +class AsyncMessageBody: + message_type: AsyncMessageType + microbatch_index: int + source: Location + dest: Location + order: int + + +class Hackity(torch.autograd.Function): + @staticmethod + # type: ignore + def forward(ctx, *x): + return torch.tensor(1.0) + + @staticmethod + # type: ignore + def backward(ctx, grad): + assert ctx.grad_from_pipeline is not None + return ctx.grad_from_pipeline + + +def recv_async_tensors( + rank: int, input_device: InputDevice, config: TransportConfig, message: PipeMessage +) -> PipeMessage: + if config.use_rpc: + # Tensors already contained within message + message.tensors = to_input_device(message.tensors, input_device) + dprint(f"recv_async_tensors {torch.distributed.get_rank()}, {len(message.tensors)}") + return message + else: + torch.cuda.current_stream().synchronize() + + message_tensors = [] + for index, (shape, dtype) in enumerate(zip(message.tensor_shapes, message.tensor_dtypes)): + t = torch.empty(*shape, dtype=dtype, device=input_device) + torch.distributed.recv(t, message.src, 
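
To make the Location/Invocation bookkeeping concrete: for three partitions spread over three stages with no weight sharing, instantiate_partition (later in this patch) produces one invocation per partition, chained by order. The values below are illustrative only.

inv0 = Invocation(order=0, this=Location(0, 0), source=None, dest=Location(1, 0))
inv1 = Invocation(order=1, this=Location(1, 0), source=Location(0, 0), dest=Location(2, 0))
inv2 = Invocation(order=2, this=Location(2, 0), source=Location(1, 0), dest=None)
# A reused (weight-shared) module instead shows up as an additional Invocation with a
# later order but the same Location, which is what the event loops key their counts on.
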
tag=message.tag + index, group=get_pipeline_parallel_group()) + message_tensors.append(t) + + message.tensors = tuple(message_tensors) + + torch.cuda.current_stream().synchronize() + return message + + +class AsyncRecvOperator(torch.autograd.Function): + """Receive activations to the previous pipeline stage""" + + @staticmethod + # type: ignore + def forward( + ctx, dst_rank: int, phony: Tensor, input_device, config: TransportConfig, message: PipeMessage + ) -> Tensors: + assert dst_rank == torch.distributed.get_rank() + ctx.config = config + ctx.index = message.args.microbatch_index + + result = recv_async_tensors(dst_rank, input_device, config, message) + + ctx.args = result.args + + def maybe_requires_grad(t: Tensor) -> Tensor: + if t.dtype.is_floating_point: + return t.requires_grad_() + return t + + return tuple(maybe_requires_grad(r) for r in result.tensors) + + @staticmethod + # type: ignore + def backward(ctx, *grad: Tensor,) -> Tuple[Optional[Tensor], ...]: + ranks = get_pipeline_parallel_ranks() + this_rank = torch.distributed.get_rank() + dprint(f"AsyncRecvOperator back {this_rank} {len(grad)}, {ctx.args}") + # Note that dst/source are swaped coz in backward pass, maybe abstract + # this out? + body = AsyncMessageBody( + AsyncMessageType.Gradients, ctx.index, source=ctx.args.dest, dest=ctx.args.source, order=ctx.args.order - 1 + ) + dprint(f"AsyncRecvOperator 2 back {this_rank} {len(grad)}") + send_message( + ctx.config, + PipeMessage( + this_rank, ranks[ctx.args.source.stage], queue_name=EVENT_LOOP_QUEUE, args=body, tensors=tuple(grad), + ), + sync=True, + ) + dprint(f"AsyncRecvOperator 3 back {this_rank} {len(grad)}") + return (None, None, None, None, None) + + +def recv_async_header(transport_config: TransportConfig, input_device: InputDevice) -> PipeMessage: + if transport_config.use_rpc: + queue = MessageQueues[EVENT_LOOP_QUEUE] + result = queue.get() + result.tensors = to_input_device(result.tensors, input_device) + return result + else: + dprint(f"cactus") + tensor = torch.empty(MESSAGE_TENSOR_SIZE, dtype=torch.uint8, device=input_device) + torch.cuda.current_stream().synchronize() + torch.distributed.recv(tensor, src=None, tag=EVENT_LOOP_QUEUE, group=get_pipeline_parallel_group()) + torch.cuda.current_stream().synchronize() + dprint(f"cactus2") + return tensor_to_pyobject(tensor.cpu()) + + +class AsyncEventLoop: + def __init__( + self, + partitions: List[ModuleWrapper], + group: ProcessGroup, + transport_config: TransportConfig, + training: bool, + input_device: InputDevice, + checkpoint_stop: int, + ): + self.training = training + self.input_device = input_device + self.checkpoint_stop = checkpoint_stop + self.transport_config = transport_config + self.group = group + self.partitions: List[ModuleWrapper] = partitions + + def send_async_message( + self, src_rank: int, dst_rank: int, input: List[Tensor], index: int, invocation: Invocation + ) -> None: + assert src_rank == torch.distributed.get_rank() + assert invocation.dest + + body = AsyncMessageBody( + AsyncMessageType.Activations, index, invocation.this, invocation.dest, invocation.order + 1 + ) + dprint(f">>> send batch {src_rank} {dst_rank} {len(input)} {invocation.order}") + send_message( + self.transport_config, + PipeMessage(src_rank, dst_rank, queue_name=EVENT_LOOP_QUEUE, args=body, tensors=tuple(input)), + sync=True, + ) + dprint(f"<<< send batch {src_rank} {dst_rank} {len(input)} {invocation.order}") + + def async_send_inner( + self, + batch: Batch, + partition: ModuleWrapper, + index: int, + skip_trackers: 
List[SkipTrackerThroughPotals], + invocation: Invocation, + ) -> Batch: + assert self.group + from .pipeline import create_task + + task = create_task( + PipelineStyle.AsyncSchedule, + self.checkpoint_stop, + index, + self.group.rank(), + batch, + partition.module, + skip_trackers, + [], + ) + result = task.compute() + task.finalize(result) + + if invocation.dest and invocation.dest.stage != invocation.this.stage: + ranks = get_pipeline_parallel_ranks() + this_rank = torch.distributed.get_rank() + + # self.send_skip_tensors(this_rank, ranks, batch, i, skip_trackers) + dprint(f"sending to next stage from {this_rank}...{invocation}, {index}") + self.send_async_message(this_rank, ranks[invocation.dest.stage], [*result], index, invocation) + z = Hackity.apply(*result) + result = Batch(z, result.index) + dprint(f"empty yay!") + else: + dprint(f"not sending to next stage...{invocation.this}, {invocation.dest}") + return result + + def async_grad_inner(self, message: PipeMessage, activations: Activations, invocation: Invocation) -> None: + args: AsyncMessageBody = message.args + if self.transport_config.use_rpc: + recvd_grads = message + else: + recvd_grads = recv_async_tensors( + torch.distributed.get_rank(), self.input_device, self.transport_config, message + ) + + # FIXME tom + + batch: Batch = activations[invocation.this.index][invocation.order][args.microbatch_index] + + try: + batch.tensor.grad_fn.grad_from_pipeline = tuple(recvd_grads.tensors) # type: ignore + batch.tensor.backward(retain_graph=True) + return + except Exception as e: + print(f"hackity fail {e}") + raise e + + def process_batch_forward( + self, + batch: Batch, + i: int, + invocations: List[Invocation], + order: int, + skip_trackers: List[SkipTrackerThroughPotals], + activations: Activations, + ) -> Tuple[int, int]: + invocations_handled = 0 + last_order = 0 + for invocation in invocations: + if invocation.order < order: + continue + pi = invocation.this.index + partition = self.partitions[pi] + + dprint(f"{self.group.rank()}: pbb {invocation}, {order}, {self.group.rank()}") + if invocation.order == order: + dprint(f"{self.group.rank()}: assigning {pi}, {invocation.order}, {i}") + invocations_handled += 1 + last_order = invocation.order + activations[pi][invocation.order][i] = self.async_send_inner( + batch, partition, i, skip_trackers, invocation + ) + elif invocation.source and invocation.source.stage == self.group.rank(): + dprint( + f"{self.group.rank()}: reading {invocation}, {invocation.source.index}, {invocation.order-1}, {i}" + ) + invocations_handled += 1 + last_order = invocation.order + batch = activations[invocation.source.index][invocation.order - 1][i] + dprint(f"{self.group.rank()}: assigning {pi}, {invocation.order}, {i}") + activations[pi][invocation.order][i] = self.async_send_inner( + batch, partition, i, skip_trackers, invocation + ) + del activations[invocation.source.index][invocation.order - 1][i] + + elif invocation.source and invocation.source.stage != self.group.rank(): + break + + dprint(f"pbb {self.group.rank()} {invocations_handled}") + return (invocations_handled, last_order) + + def event_loop_head( + self, batches: List[Batch], skip_trackers: List[SkipTrackerThroughPotals], event: Optional[Event] + ) -> None: + + invocations, activations = self.get_sorted_invocations_and_activations() + + expected_invocations = len(invocations) * len(batches) + actual_invocations = 0 + + count_per_order = dict() + + dprint(f"head loop start {torch.distributed.get_rank()}") + for i, batch in enumerate(batches): 
+ dprint(f"head loop iter {torch.distributed.get_rank()}, {i}") + inv_count, last_order = self.process_batch_forward(batch, i, invocations, 0, skip_trackers, activations) + actual_invocations += inv_count + count_per_order[last_order] = inv_count + + dprint(f"head wat {actual_invocations}, {expected_invocations}") + if actual_invocations < expected_invocations or self.training: + dprint(f"head extra {actual_invocations}, {expected_invocations}") + self.event_loop_inner( + expected_invocations, + skip_trackers, + activations, + invocations, + count_per_order, + already_received=actual_invocations, + event=event, + ) + + # if self.pipeline.training: + # for _ in range(len(batches)): + # message = self.recv_async_header() + # args: AsyncMessageBody = message.args + # assert args.message_type is AsyncMessageType.Gradients + # self.async_grad_inner(message, activations) + + def event_loop_tail(self, batches: List[Batch], skip_trackers: List[SkipTrackerThroughPotals]) -> None: + assert self.group + + invocations, activations = self.get_sorted_invocations_and_activations() + expected_invocations = len(invocations) * len(batches) + actual_invocations = 0 + + rank = self.group.rank() + count_per_order = dict() + + for i, batch in enumerate(batches): + if rank == 0: + batch_index = i + order = 0 + else: + message = recv_async_header(self.transport_config, self.input_device) + args: AsyncMessageBody = message.args + + phony = torch.empty(0, device=self.input_device, requires_grad=True) + result = AsyncRecvOperator.apply( + torch.distributed.get_rank(), phony, self.input_device, self.transport_config, message, + ) + if len(result) == 1: + batch = Batch(result[0], args.microbatch_index) + else: + batch = Batch(result, args.microbatch_index) + batch_index = args.microbatch_index + order = args.order + + inv_count, last_order = self.process_batch_forward( + batch, batch_index, invocations, order, skip_trackers, activations + ) + actual_invocations += inv_count + count_per_order[last_order] = inv_count + + if actual_invocations < expected_invocations: + expected_gradients = 0 # (len(invocations) - 1) * len(batches) + dprint(f"tail expect {expected_invocations}, {len(invocations)}, {len(batches)}") + + self.event_loop_inner( + expected_invocations, + skip_trackers, + activations, + invocations, + count_per_order, + already_received=actual_invocations, + ignore_gradients=True, + ) + + for index, batch in activations[len(self.partitions) - 1][invocations[-1].order].items(): + batches[index] = batch + + def get_sorted_invocations_and_activations(self) -> Tuple[List[Invocation], Activations]: + activations: Activations = dict() + invocations: List[Invocation] = [] + + for pi, partition in enumerate(self.partitions): + activations[pi] = dict() + for invocation in partition.invocations: + activations[pi][invocation.order] = dict() + invocations.append(invocation) + + invocations.sort(key=lambda inv: inv.order) + + return (invocations, activations) + + def event_loop(self, num_microbatch: int, skip_trackers: List[SkipTrackerThroughPotals]) -> None: + assert self.group + + invocations, activations = self.get_sorted_invocations_and_activations() + + expected_invocations = len(invocations) * num_microbatch + + dprint(f"event_loop {expected_invocations}, {num_microbatch}, {len(invocations)}") + self.event_loop_inner(expected_invocations, skip_trackers, activations, invocations, dict()) + + def event_loop_inner( + self, + expected_invocations: int, + skip_trackers: List[SkipTrackerThroughPotals], + activations: 
Activations, + invocations: List[Invocation], + count_per_order: Dict[int, int], + *, + already_received: int = 0, + ignore_gradients: bool = False, + event: Optional[Event] = None, + ) -> None: + + num_activations = already_received + if self.training and not ignore_gradients: + num_gradients = 0 + else: + num_gradients = expected_invocations + + while num_activations < expected_invocations or num_gradients < expected_invocations: + dprint( + f">> recv_async_header {self.group.rank()}, {torch.distributed.get_rank()} {expected_invocations}," + f" {num_activations}, {num_gradients}, {ignore_gradients}" + ) + if num_activations == expected_invocations and num_gradients == 0 and event is not None: + print(f">>> wait on event") + event.wait() + print(f"<<< wait on event") + + message = recv_async_header(self.transport_config, self.input_device) + dprint(f"<< recv_async_header {torch.distributed.get_rank()}") + args: AsyncMessageBody = message.args + + filtered = [inv for inv in invocations if inv.order == args.order] + if len(filtered) == 0: + dprint(f"no invocation on {self.group.rank()} for {args.order}, {invocations}") + invocation = filtered[0] + + # FIXME(tom) for combining pipeline with megatron, I currently don't + # control the order of received activations or gradients, so it is + # possible for a reused ColumnParallelLinear for example to receive + # a different order of activations w.r.t. the sending stage, which + # would result in incorrect values being used for the all_gather + if args.message_type is AsyncMessageType.Activations: + + phony = torch.empty(0, device=self.input_device, requires_grad=True) + result = AsyncRecvOperator.apply( + torch.distributed.get_rank(), phony, self.input_device, self.transport_config, message, + ) + + dprint( + f"got batch {torch.distributed.get_rank()}|{self.group.rank()} i:{args.microbatch_index}" + f" len:{len(result)}, {invocation}" + ) + + if len(result) == 1: + batch = Batch(result[0], args.microbatch_index) + else: + batch = Batch(result, args.microbatch_index) + + dprint(f"calling pbb? {self.group.rank()}, {expected_invocations}, {num_activations}, {num_gradients}") + inv_count, last_order = self.process_batch_forward( + batch, args.microbatch_index, invocations, args.order, skip_trackers, activations + ) + count_per_order[last_order] = inv_count + num_activations += inv_count + assert num_activations <= expected_invocations + + elif args.message_type is AsyncMessageType.Gradients: + dprint(f">> try {self.group.rank()}, {invocation.order}, {count_per_order}, {num_gradients}") + num_gradients += count_per_order[invocation.order] + self.async_grad_inner(message, activations, invocation) + dprint(f"<< try {self.group.rank()}, {invocation.order}, {count_per_order}, {num_gradients}") diff --git a/fairscale/nn/pipe/messages.py b/fairscale/nn/pipe/messages.py new file mode 100644 index 000000000..b9fa63969 --- /dev/null +++ b/fairscale/nn/pipe/messages.py @@ -0,0 +1,137 @@ +import pickle +from queue import Empty as QueueEmpty +from queue import Queue +from typing import Any, List + +import numpy as np +import torch +from torch import Tensor + +from fairscale.nn.model_parallel import get_pipeline_parallel_group + +from .types import MESSAGE_GENERATION_START, InputDevice, PipeMessage, Tensors, TransportConfig + +# FIXME Why is 256 ok for training but not for tests? 
+MESSAGE_TENSOR_SIZE = 1024 # 256 + +MessageQueues: List[Queue] = [Queue() for _ in range(MESSAGE_GENERATION_START)] + + +def pyobject_to_tensor(obj: Any) -> Tensor: + pickled = pickle.dumps(obj) + nparray = np.frombuffer(pickled, dtype=np.uint8).copy() + nparray.setflags(write=True) + result = torch.from_numpy(nparray) + delta = MESSAGE_TENSOR_SIZE - len(result) + if delta < 0: + raise ValueError( + f"message too big to send, increase MESSAGE_TENSOR_SIZE? - {len(result)} > {MESSAGE_TENSOR_SIZE}" + ) + elif delta > 0: + result = torch.cat((result, torch.zeros(delta, dtype=torch.uint8))) + + return result.cuda() + + +def tensor_to_pyobject(tensor: Tensor) -> Any: + try: + nparray = tensor.numpy() + return pickle.loads(nparray.tobytes()) + except Exception as e: + print(f"pickle fail {e}") + raise e + + +def to_input_device(tensors: Tensors, input_device: InputDevice) -> Tensors: + if input_device is None: + return tensors + else: + return tuple(t.to(input_device) for t in tensors) + + +def rpc_push_queue(message: PipeMessage) -> None: + globals()["MessageQueues"][message.queue_name].put(message) + + +def send_message(config: TransportConfig, message: PipeMessage, sync: bool = False) -> None: + if config.use_rpc: + message.tensors = tuple(t.cpu() for t in message.tensors) + assert config.worker_map + name = config.worker_map[message.dest] + if sync: + torch.distributed.rpc.rpc_sync(name, rpc_push_queue, args=(message,)) + else: + torch.distributed.rpc.rpc_async(name, rpc_push_queue, args=(message,)) + else: + tensors = message.tensors + message.tensors = tuple() + message.tensor_shapes = [t.size() for t in tensors] + message.tensor_dtypes = [t.dtype for t in tensors] + torch.cuda.current_stream().synchronize() + torch.distributed.send( + pyobject_to_tensor(message), message.dest, tag=message.queue_name, group=get_pipeline_parallel_group() + ) + for index, t in enumerate(tensors): + if t.device.type == "cpu": + t = t.cuda() + torch.distributed.send( + t.contiguous(), message.dest, tag=message.tag + index, group=get_pipeline_parallel_group() + ) + + +def recv_message( + config: TransportConfig, queue_name: int, *, nowait: bool = False, input_device: InputDevice = None +) -> PipeMessage: + if config.use_rpc: + queue = globals()["MessageQueues"][queue_name] + if nowait: + result = queue.get_nowait() + else: + result = queue.get() + result.tensors = to_input_device(result.tensors, input_device) + return result + else: + # FIXME(handle nowait) + if nowait: + raise QueueEmpty + + torch.cuda.current_stream().synchronize() + tensor = torch.empty(MESSAGE_TENSOR_SIZE, dtype=torch.uint8, device=input_device) + torch.distributed.recv(tensor, src=None, tag=queue_name, group=get_pipeline_parallel_group()) + torch.cuda.current_stream().synchronize() + message = tensor_to_pyobject(tensor.cpu()) + + message_tensors = [] + for index, (shape, dtype) in enumerate(zip(message.tensor_shapes, message.tensor_dtypes)): + t = torch.empty(*shape, dtype=dtype, device=input_device) + torch.distributed.recv(t, message.src, tag=message.tag + index, group=get_pipeline_parallel_group()) + message_tensors.append(t) + + message.tensors = tuple(message_tensors) + # print(f"<<< recv:{torch.distributed.get_rank()}") + + torch.cuda.current_stream().synchronize() + return message + + +def get_out_of_order(config: TransportConfig, queue_name: int, index: int, *, input_device: InputDevice) -> Tensors: + """Receive a message with a known microbatch index, and handle out-of-order + messages by placing them back on the queue""" + + if 
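
The header encoding above pickles into a fixed-size byte tensor so the receiving rank can post a recv of known size before it knows what is arriving. A quick round-trip sketch using the functions from this file (needs a CUDA device, since pyobject_to_tensor moves the result to the GPU):

payload = ("activations", 3)                 # anything picklable that fits in the budget
header = pyobject_to_tensor(payload)         # uint8 tensor padded to MESSAGE_TENSOR_SIZE
assert header.numel() == MESSAGE_TENSOR_SIZE
decoded = tensor_to_pyobject(header.cpu())   # pickle stops at its STOP opcode, so padding is ignored
assert decoded == payload
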
config.use_rpc: + queue = globals()["MessageQueues"][queue_name] + out_of_order: List[PipeMessage] = [] + while True: + message = recv_message(config, queue_name, input_device=input_device) + got_index = message.args + value = message.tensors + if got_index == index: + for b in out_of_order: + queue.put(b) + return value + else: + out_of_order.append(message) + else: + message = recv_message(config, queue_name, input_device=input_device) + assert message.args == index + return message.tensors diff --git a/fairscale/nn/pipe/pipe.py b/fairscale/nn/pipe/pipe.py index 8a6e6fdcf..8d9034c12 100644 --- a/fairscale/nn/pipe/pipe.py +++ b/fairscale/nn/pipe/pipe.py @@ -19,24 +19,29 @@ """The Pipe interface.""" from collections import OrderedDict -from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Tuple, Union, cast +import itertools +import threading +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union, cast import warnings +from dataclasses import dataclass, field import torch from torch import Tensor, nn import torch.autograd import torch.cuda -from fairscale.nn.model_parallel import get_model_parallel_group, get_pipeline_parallel_group +from fairscale.nn.model_parallel import get_model_parallel_world_size, get_pipeline_parallel_group from . import microbatch +from .async_schedule import Invocation, Location, ModuleWrapper from .batchnorm import DeferredBatchNorm -from .pipeline import Pipeline, PipelineStyle +from .pipeline import Pipeline from .skip.layout import SkipLayout, inspect_skip_layout from .skip.skippable import Skippable, verify_skippables from .stream import AbstractStream, new_stream +from .types import LazyModule, PipelineStyle -__all__ = ["Pipe"] +__all__ = ["Pipe", "LazyModule"] Device = Union[torch.device, int, str] @@ -45,7 +50,7 @@ Tensors = Tuple[Tensor, ...] 
TensorOrTensors = Union[Tensor, Tensors] -ListOfLazyModules = List[Callable[[], nn.Module]] +ListOfLazyModules = List[LazyModule] if TYPE_CHECKING: Module = nn.Module[TensorOrTensors] @@ -79,10 +84,10 @@ def verify_list_of_callable(module: Union[nn.Sequential, list]) -> None: for layer in module: if isinstance(layer, nn.Module): pass - elif callable(layer): + elif isinstance(layer, LazyModule): pass else: - raise TypeError(f"layer {type(layer)} must be nn.Module or callable to be partitioned") + raise TypeError(f"layer {type(layer)} must be nn.Module or LazyModule to be partitioned") def verify_module(module: Union[nn.Sequential, ListOfLazyModules]) -> None: @@ -124,8 +129,16 @@ class BalanceError(ValueError): pass -def check_balance(module: Any, balance: Iterable[int]) -> None: - if len(module) != sum(balance): +def check_balance(module: Any, balance: Iterable[int], filter_unique: bool = False) -> None: + + if filter_unique: + module_len = len(set(map(id, module))) + print(f"unique layer {module_len}, {balance}") + else: + module_len = len(module) + print(f"non-unique layer {module_len}, {balance}") + + if module_len != sum(balance): raise BalanceError( f"module and sum of balance have different length (module: {len(module)}, sum of balance: {sum(balance)})" ) @@ -134,16 +147,27 @@ def check_balance(module: Any, balance: Iterable[int]) -> None: raise BalanceError(f"all balance numbers must be positive integer (balance: {balance})") +@dataclass +class PartitionInfo: + location: Location + modules: "OrderedDict[str, nn.Module]" + invocations: List[Invocation] = field(default_factory=list) + + def __len__(self) -> int: + return len(self.modules) + + def instantiate_partition( - module: Union[nn.Sequential, ListOfLazyModules], balance: Iterable[int], group: torch.distributed.ProcessGroup -) -> nn.Sequential: + module: Union[nn.Sequential, ListOfLazyModules], + balance: Iterable[int], + group: torch.distributed.ProcessGroup, + style: PipelineStyle, +) -> List[ModuleWrapper]: balance = list(balance) - check_balance(module, balance) + check_balance(module, balance, True) layers: NamedModules = OrderedDict() - j = 0 - def maybe_realize(layer: Any) -> nn.Module: if isinstance(layer, nn.Module): return layer @@ -156,22 +180,110 @@ def iterate_module(module: Union[nn.Sequential, list]) -> Iterable[Tuple[Any, nn if isinstance(module, nn.Sequential): yield from module.named_children() else: - yield from enumerate(module) + yield from ((str(k), v) for k, v in enumerate(module)) + + if style == PipelineStyle.AsyncSchedule: + module_ids = list(map(id, module)) + index_of_first_use = [module_ids.index(x) for x in module_ids] + locations: List[Location] = [] + module_iter = enumerate(iterate_module(module)) + + partitions: List[List[PartitionInfo]] = [] + for bi, b in enumerate(balance): + modules_for_rank: List[PartitionInfo] = [] + current_module: OrderedDict[str, nn.Module] = OrderedDict() + + def current_location() -> Location: + return Location(bi, len(modules_for_rank)) + + def append_module(mod: "OrderedDict[str, nn.Module]") -> None: + modules_for_rank.append(PartitionInfo(current_location(), mod)) + + while sum(map(len, modules_for_rank)) + len(current_module) < b: + module_index, (name, layer) = next(module_iter) + + if index_of_first_use[module_index] != module_index: + # Subsequent reuse of a module + locations.append(locations[index_of_first_use[module_index]]) + continue + + is_reused = index_of_first_use.count(index_of_first_use[module_index]) > 1 + + if is_reused and len(current_module) > 
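
A small sketch of what the filter_unique path measures: a weight-shared layer appearing twice in the layer list is counted once against the balance, which is why the AsyncSchedule partitioning checks unique module ids.

from torch import nn

shared = nn.Linear(16, 16)              # e.g. a tied input/output projection
module = [shared, nn.ReLU(), shared]    # 3 entries in the pipeline, 2 unique modules
assert len(set(map(id, module))) == 2   # the count used by check_balance(..., filter_unique=True)
# A balance for this model therefore sums to 2 (unique layers), e.g. [1, 1], while the
# reused layer still contributes an extra Invocation on the stage that owns it.
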
0: + append_module(current_module) + current_module = OrderedDict() + + current_module[str(name)] = layer + locations.append(current_location()) + if is_reused: + append_module(current_module) + current_module = OrderedDict() + + if len(current_module) > 0: + append_module(current_module) + + partitions.append(modules_for_rank) + + filtered_locations: List[Optional[Location]] = [loc for loc, _ in itertools.groupby(locations)] + filtered_locations.append(None) + + for i in range(len(filtered_locations) - 1): + loc = filtered_locations[i] + assert loc + if i == 0: + inv = Invocation(i, loc, None, filtered_locations[i + 1]) + else: + inv = Invocation(i, loc, filtered_locations[i - 1], filtered_locations[i + 1]) + + partitions[loc.stage][loc.index].invocations.append(inv) + + invocations = enumerate(iterate_module(module)) + + partition = partitions[group.rank()] + result: List[ModuleWrapper] = [] + for partition_info in partition: + wrapper = ModuleWrapper( + nn.Sequential(OrderedDict((k, maybe_realize(m)) for k, m in partition_info.modules.items())), + partition_info.location, + partition_info.invocations, + ) + + if not isinstance(module, nn.Sequential): + for layer in wrapper.module: + if isinstance(layer, Skippable): + raise ValueError("Can't use Skippable layers with multi-process pipe and lazy construction") + + result.append(wrapper) + + # print(f"partitions = {partitions}") + + return result + + j = 0 + + # print(f"{torch.distributed.get_rank()}: assigned = {assigned}") + + # print(f"yay {list(map(id, module))}") + # print(f"first_index = {first_index}") + # print(f"duplicates = {duplicates}") + # print(f"duplicates2 = {duplicates2}") + # print(f"locations = {locations}") + # for name, layer in iterate_module(module): layers[name] = layer if len(layers) == balance[j]: if j == group.rank(): for key in layers: + print(f"key = {type(key)}-{key}") layers[key] = maybe_realize(layers[key]) if not isinstance(module, nn.Sequential): for layer in layers.values(): if isinstance(layer, Skippable): raise ValueError("Can't use Skippable layers with multi-process pipe and lazy construction") - partition = nn.Sequential(*layers.values()) - return partition + return [ModuleWrapper(nn.Sequential(layers), Location(j, 0))] # Prepare for the next partition. layers.clear() @@ -297,7 +409,7 @@ class Pipe(Module): backward pass (instead of once for the whole batch). This works around a potential deadlock in pytorch when using tensor parallelism at the same time. Defaults to `True` if - `get_model_parallel_group.size() > 1` + `get_model_parallel_world_size() > 1` (default: `None`) retain_graph (bool): The value passed to `torch.autograd.backwards(..., retain_graph=) @@ -315,6 +427,7 @@ class Pipe(Module): SingleProcess: PipelineStyle = PipelineStyle.SingleProcess MultiProcess: PipelineStyle = PipelineStyle.MultiProcess + AsyncSchedule: PipelineStyle = PipelineStyle.AsyncSchedule #: The number of layers in each partition. 
balance: List[int] = [] @@ -359,6 +472,7 @@ def __init__( deferred_batch_norm: bool = False, pipelined_backward: bool = None, retain_graph: bool = False, + loss_fn: Optional[nn.Module] = None, ) -> None: super().__init__() @@ -384,6 +498,17 @@ def __init__( self.pipelined_backward = pipelined_backward self.retain_graph = retain_graph self.pipeline: Optional[Pipeline] + self.loss_fn = loss_fn + self.lock = threading.Lock() + + self.group = group + self.worker_map = worker_map + self.input_device = input_device + + self._copy_streams: List[List[AbstractStream]] = [] + + # The micro-batch index where the checkpointing stops. + checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint] if style is PipelineStyle.SingleProcess: module = cast(nn.Sequential, module) @@ -407,29 +532,42 @@ def __init__( self._skip_layout = inspect_skip_layout(self.partitions) - elif style is PipelineStyle.MultiProcess: - if group is None: - group = get_pipeline_parallel_group() + # Separate CUDA streams for copy. + copy_streams = self._ensure_copy_streams() + if self.pipelined_backward is None: + self.pipelined_backward = False + self.pipeline = Pipeline( + self.partitions, self.devices, copy_streams, self._skip_layout, checkpoint_stop, style=style, + ) + + elif style in [PipelineStyle.MultiProcess, PipelineStyle.AsyncSchedule]: + + if self.group is None: + self.group = get_pipeline_parallel_group() + assert self.group if devices is not None: raise ValueError("'devices' argument only applies to 'PipelineStyle.SingleProcess'") self.balance = list(balance) - if group.size() < len(self.balance): + if self.group.size() < len(self.balance): raise IndexError( - f"too few ranks to hold given partitions (ranks: {group.size()}, partitions: {len(self.balance)})" + f"too few ranks to hold given partitions (ranks: {self.group.size()}, partitions:" + f" {len(self.balance)})" ) try: - rank = torch.distributed.get_rank(group) + rank = self.group.rank() if rank >= len(self.balance): warnings.warn("More ranks than partitions, some ranks unused") - self.partitions = cast(List[nn.Sequential], nn.ModuleList([nn.Sequential()])) + self.mp_partitions: List[ModuleWrapper] = [] else: - partition = instantiate_partition(module, balance, group) + self.mp_partitions = instantiate_partition(module, balance, self.group, style) if deferred_batch_norm: - partition = DeferredBatchNorm.convert_deferred_batch_norm(partition, chunks) - self.partitions = cast(List[nn.Sequential], nn.ModuleList([partition])) + for part in self.mp_partitions: + part.module = DeferredBatchNorm.convert_deferred_batch_norm(part.module, chunks) + for name, part in enumerate(self.mp_partitions): + self.add_module(str(name), part.module) self.devices = None if isinstance(module, nn.Sequential): local_partitions, _, _ = split_module(module, balance, None) @@ -440,31 +578,16 @@ def __init__( except BalanceError as exc: raise ValueError(recommend_auto_balance(str(exc))) - self.group = group - self.worker_map = worker_map - self.input_device = input_device - - self._copy_streams: List[List[AbstractStream]] = [] - - # The micro-batch index where the checkpointing stops. - checkpoint_stop = {"always": self.chunks, "except_last": self.chunks - 1, "never": 0}[self.checkpoint] - - if style is PipelineStyle.SingleProcess: - # Separate CUDA streams for copy. 
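
The checkpoint_stop mapping decides how many leading micro-batches run under activation checkpointing; a worked example with chunks=4:

chunks = 4
for mode, expected in [("always", 4), ("except_last", 3), ("never", 0)]:
    checkpoint_stop = {"always": chunks, "except_last": chunks - 1, "never": 0}[mode]
    assert checkpoint_stop == expected
# create_task() (in pipeline.py below) checkpoints micro-batch i only when i < checkpoint_stop.
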
- copy_streams = self._ensure_copy_streams() - if self.pipelined_backward is None: - self.pipelined_backward = False - self.pipeline = Pipeline( - self.partitions, self.devices, copy_streams, self._skip_layout, checkpoint_stop, style=style - ) - elif style is PipelineStyle.MultiProcess: - rank = torch.distributed.get_rank(group) + rank = self.group.rank() if rank >= len(self.balance): self.pipeline = None + self.final_stage = False else: self.final_stage = rank == len(self.balance) - 1 + assert loss_fn is None or self.final_stage + self.pipeline = Pipeline( - self.partitions, + cast(List[nn.Sequential], self.mp_partitions), None, None, self._skip_layout, @@ -473,27 +596,39 @@ def __init__( group=self.group, worker_map=self.worker_map, input_device=self.input_device, + final_stage=self.final_stage, ) del module if self.pipelined_backward is None: - if get_model_parallel_group().size() > 1: + if get_model_parallel_world_size() > 1: self.pipelined_backward = True else: self.pipelined_backward = False def __len__(self) -> int: """Counts the length of the underlying sequential module.""" - return sum(len(p) for p in self.partitions) + if hasattr(self, "partitions"): + return sum(len(p) for p in self.partitions) + else: + return sum(len(p) for p in self.mp_partitions) def __getitem__(self, index: int) -> nn.Module: """Gets a layer in the underlying sequential module.""" - partitions = self.partitions + partitions: List[Any] + if hasattr(self, "partitions"): + partitions = self.partitions + else: + partitions = self.mp_partitions + if index < 0: partitions = partitions[::-1] for partition in partitions: try: - return partition[index] + if isinstance(partition, ModuleWrapper): + return partition.module[index] + else: + return partition[index] except IndexError: pass @@ -508,8 +643,12 @@ def __getitem__(self, index: int) -> nn.Module: def __iter__(self) -> Iterable[nn.Module]: """Iterates over children of the underlying sequential module.""" - for partition in self.partitions: - yield from partition + if hasattr(self, "partitions"): + for partition in self.partitions: + yield from partition + else: + for mp_partition in self.mp_partitions: + yield from mp_partition.module # Pipe should manage the device of each partition. # Deny cuda(), cpu(), and to() with device, by TypeError. @@ -527,7 +666,7 @@ def cpu(self) -> "Pipe": return super().cpu() def to(self, *args: Any, **kwargs: Any) -> "Pipe": - """ Restrict .to() options. + """Restrict .to() options. Deny these usages: - to(device[, dtype, non_blocking]) @@ -563,7 +702,7 @@ def _ensure_copy_streams(self) -> List[List[AbstractStream]]: return self._copy_streams - def forward(self, input: TensorOrTensors) -> TensorOrTensors: # type: ignore + def forward(self, input: TensorOrTensors, *, event=None) -> TensorOrTensors: # type: ignore """:class:`Pipe` is a fairly transparent module wrapper. It doesn't modify the input and output signature of the underlying module. But there's type restriction. Input and output have to be a @@ -594,25 +733,26 @@ def forward(self, input: TensorOrTensors) -> TensorOrTensors: # type: ignore batches = microbatch.scatter(input, self.chunks) # Run pipeline parallelism. - self.pipeline.run(batches) - - if self.group and not self.final_stage: - # Don't merge micro-batches to avoid unnecessary edges in autograd - # graph - # FIXME(tom) should figure out a proper type here - return batches # type: ignore - else: - # Merge the micro-batches into one mini-batch. 
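
Putting the forward() contract together with the benchmark's loop: on non-final stages forward() returns the raw micro-batches and backward must be driven through back_helper. A per-rank sketch; pipe_model, criterion, optimizer, batch_input and target are assumed to exist.

output = pipe_model(batch_input)      # every pipeline rank calls forward
if pipe_model.final_stage:
    loss = criterion(output, target)  # only the last stage holds the real output
    loss.backward()
else:
    pipe_model.back_helper(output)    # other stages run their share of the backward pass
optimizer.step()
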
- if self.pipelined_backward: - with torch.no_grad(): - output = microbatch.gather(batches) + with self.lock: + self.pipeline.run(self.training, batches, event) + + if self.group and not self.final_stage: + # Don't merge micro-batches to avoid unnecessary edges in autograd + # graph + # FIXME(tom) should figure out a proper type here + return batches # type: ignore + else: + # Merge the micro-batches into one mini-batch. + if self.pipelined_backward: + with torch.no_grad(): + output = microbatch.gather(batches) - from .phony import get_phony + from .phony import get_phony - phony = get_phony(torch.device(torch.cuda.current_device()), requires_grad=True) - output = PipelinedBackwardPass.apply(output, batches, phony, True) # self.retain_graph) - else: - output = microbatch.gather(batches) + phony = get_phony(torch.device(torch.cuda.current_device()), requires_grad=True) + output = PipelinedBackwardPass.apply(output, batches, phony, True) # self.retain_graph) + else: + output = microbatch.gather(batches) return output diff --git a/fairscale/nn/pipe/pipeline.py b/fairscale/nn/pipe/pipeline.py index 4f395a7b1..b53c84b6b 100644 --- a/fairscale/nn/pipe/pipeline.py +++ b/fairscale/nn/pipe/pipeline.py @@ -17,202 +17,56 @@ # limitations under the License. """The pipeline parallelism of Pipe.""" -from enum import Enum, auto import os -import pickle from queue import Empty as QueueEmpty from queue import Queue +from threading import Event +import time from types import TracebackType -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Type, Union, cast +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Type, Union, cast -from dataclasses import dataclass -import numpy as np import torch from torch import Tensor, nn from torch.autograd.profiler import record_function from fairscale.nn.model_parallel import get_pipeline_parallel_ranks +from .async_schedule import AsyncEventLoop, ModuleWrapper from .checkpoint import Checkpointing from .copy import Copy, Wait from .dependency import fork, join +from .messages import get_out_of_order, recv_message, send_message from .microbatch import Batch from .skip import Namespace from .skip.layout import SkipLayout from .skip.tracker import SkipTrackerThroughPotals, use_skip_tracker from .stream import AbstractStream, current_stream, use_device +from .types import ( + ACTIVATIONS_GRADS_QUEUE, + PORTAL_QUEUE, + SKIP_TENSOR_QUEUE, + PipelineStyle, + PipeMessage, + Schedule, + TensorOrTensors, + Tensors, + TransportConfig, +) from .worker import Task, create_workers, join_workers __all__: List[str] = [] -Tensors = Tuple[Tensor, ...] -TensorOrTensors = Union[Tensor, Tensors] - -InputDevice = Union[None, int, str, torch.device] -Schedule = List[Tuple[int, int]] - ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType] -MessageQueues: List[Queue] = [Queue(), Queue(), Queue()] -ACTIVATIONS_GRADS_QUEUE = 0 -SKIP_TENSOR_QUEUE = 1 -PORTAL_QUEUE = 2 -MESSAGE_GENERATION_START = 3 - - -# FIXME Why is 256 ok for training but not for tests? 
-MESSAGE_TENSOR_SIZE = 512 # 256 - -MessageGeneration = MESSAGE_GENERATION_START - - -class PipelineStyle(Enum): - SingleProcess = auto() - MultiProcess = auto() - - -@dataclass(frozen=True) -class TransportConfig: - use_rpc: bool - worker_map: Optional[Dict[int, str]] - - -@dataclass -class PipeMessage: - src: int - dest: int - queue_name: int - args: Any - tensors: Tensors - tensor_shapes: List[torch.Size] - tensor_dtypes: List[torch.dtype] - tag: int = 0 - - def __init__(self, src: int, dest: int, queue_name: int, args: Any, tensors: Tensors): - self.src = src - self.dest = dest - self.queue_name = queue_name - self.args = args - self.tensors = tensors - - global MessageGeneration - self.tag = MessageGeneration - MessageGeneration += len(tensors) - - -def rpc_push_queue(message: PipeMessage) -> None: - globals()["MessageQueues"][message.queue_name].put(message) - - -def pyobject_to_tensor(obj: Any) -> Tensor: - pickled = pickle.dumps(obj) - nparray = np.frombuffer(pickled, dtype=np.uint8).copy() - nparray.setflags(write=True) - result = torch.from_numpy(nparray) - delta = MESSAGE_TENSOR_SIZE - len(result) - if delta < 0: - raise ValueError( - f"message too big to send, increase MESSAGE_TENSOR_SIZE? - {len(result)} > {MESSAGE_TENSOR_SIZE}" - ) - elif delta > 0: - result = torch.cat((result, torch.zeros(delta, dtype=torch.uint8))) - - return result.cuda() - - -def tensor_to_pyobject(tensor: Tensor) -> Any: - nparray = tensor.numpy() - return pickle.loads(nparray.tobytes()) +debug_fd = None -def send_message(config: TransportConfig, message: PipeMessage, sync: bool = False) -> None: - if config.use_rpc: - message.tensors = tuple(t.cpu() for t in message.tensors) - assert config.worker_map - name = config.worker_map[message.dest] - if sync: - torch.distributed.rpc.rpc_sync(name, rpc_push_queue, args=(message,)) - else: - torch.distributed.rpc.rpc_async(name, rpc_push_queue, args=(message,)) - else: - tensors = message.tensors - message.tensors = tuple() - message.tensor_shapes = [t.size() for t in tensors] - message.tensor_dtypes = [t.dtype for t in tensors] - torch.cuda.current_stream().synchronize() - torch.distributed.send(pyobject_to_tensor(message), message.dest, tag=0) - for index, t in enumerate(tensors): - if t.device.type == "cpu": - t = t.cuda() - torch.distributed.send(t, message.dest, tag=message.tag + index) - - -def recv_message( - config: TransportConfig, queue_name: int, *, nowait: bool = False, input_device: InputDevice = None -) -> PipeMessage: - if config.use_rpc: - queue = globals()["MessageQueues"][queue_name] - if nowait: - result = queue.get_nowait() - else: - result = queue.get() - result.tensors = to_input_device(result.tensors, input_device) - return result - else: - # FIXME(handle nowait) - if nowait: - raise QueueEmpty - - tensor = torch.empty(MESSAGE_TENSOR_SIZE, dtype=torch.uint8, device=input_device) - torch.distributed.recv(tensor, src=-1, tag=queue_name) - message = tensor_to_pyobject(tensor.cpu()) - torch.cuda.current_stream().synchronize() - - message_tensors = [] - for index, (shape, dtype) in enumerate(zip(message.tensor_shapes, message.tensor_dtypes)): - t = torch.empty(*shape, dtype=dtype, device=input_device) - torch.distributed.recv(t, message.src, tag=message.tag + index) - message_tensors.append(t) - - message.tensors = tuple(message_tensors) - - torch.cuda.current_stream().synchronize() - return message - - -def get_out_of_order(config: TransportConfig, queue_name: int, index: int, *, input_device: InputDevice) -> Tensors: - """Receive a message 
with a known microbatch index, and handle out-of-order - messages by placing them back on the queue""" - - if config.use_rpc: - queue = globals()["MessageQueues"][queue_name] - out_of_order: List[PipeMessage] = [] - while True: - message = recv_message(config, queue_name, input_device=input_device) - got_index = message.args - value = message.tensors - if got_index == index: - for b in out_of_order: - queue.put(b) - return value - else: - out_of_order.append(message) - else: - message = recv_message(config, queue_name, input_device=input_device) - assert message.args == index - return message.tensors - - -def to_input_device(tensors: TensorOrTensors, input_device: InputDevice) -> TensorOrTensors: - if input_device is None: - return tensors - else: - if isinstance(tensors, Tensor): - return tensors.to(input_device) - else: - return tuple(t.to(input_device) for t in tensors) +def dprint(s: str) -> None: + print(f"{time.monotonic()}: {s}") + # debug_fd.write(f"{time.monotonic()}: {s}\n") + # debug_fd.flush() class SendOperator(torch.autograd.Function): @@ -318,6 +172,57 @@ def clock_cycles(m: int, n: int) -> Iterable[Schedule]: yield [(k - j, j) for j in range(max(1 + k - m, 0), min(1 + k, n))] +def create_task( + style: PipelineStyle, + checkpoint_stop: int, + i: int, + j: int, + batch: Batch, + partition: nn.Sequential, + skip_trackers: List[SkipTrackerThroughPotals], + streams: List[AbstractStream], +) -> Task: + # Determine whether checkpointing or not. + if i < checkpoint_stop: + + def function( + input: TensorOrTensors, + partition: nn.Sequential = partition, + skip_tracker: SkipTrackerThroughPotals = skip_trackers[i], + chunk_id: int = i, + part_id: int = j, + ) -> TensorOrTensors: + with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)): + return partition(input) + + chk = Checkpointing(function, batch) + if style is PipelineStyle.SingleProcess: + task = Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute) + elif style in [PipelineStyle.MultiProcess, PipelineStyle.AsyncSchedule]: + task = Task(None, compute=chk.checkpoint, finalize=chk.recompute) + del function, chk # TODO(tom) maybe remove + + else: + + def compute( + batch: Batch = batch, + partition: nn.Sequential = partition, + skip_tracker: SkipTrackerThroughPotals = skip_trackers[i], + chunk_id: int = i, + part_id: int = j, + ) -> Batch: + with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)): + return batch.call(partition) + + if style is PipelineStyle.SingleProcess: + task = Task(streams[j], compute=compute, finalize=None) + elif style in [PipelineStyle.MultiProcess, PipelineStyle.AsyncSchedule]: + task = Task(None, compute=compute, finalize=None) + del compute # TODO(tom) maybe remove + + return task + + class Pipeline: """The pipeline parallelism for Pipe.""" @@ -332,52 +237,68 @@ def __init__( group: Optional[torch.distributed.ProcessGroup] = None, worker_map: Optional[Dict[int, str]] = None, input_device: Union[None, int, str, torch.device] = None, + final_stage: bool = False, ) -> None: - self.partitions = partitions + if style == PipelineStyle.SingleProcess: + self.partitions = partitions + else: + self.mp_partitions: List[ModuleWrapper] = cast(List[ModuleWrapper], partitions) self.devices = devices self.copy_streams = copy_streams self.skip_layout = skip_layout - self.checkpoint_stop = checkpoint_stop + self.__checkpoint_stop = checkpoint_stop self.style = style self.group = group + self.training: bool self.transport_config = 
TransportConfig( - use_rpc=("OMPI_COMM_WORLD_RANK" not in os.environ), worker_map=worker_map + use_rpc=("OMPI_COMM_WORLD_RANK" not in os.environ) or ("FORCE_RPC" in os.environ), worker_map=worker_map ) self.input_device = input_device self.all_at_once = False self.callcount = 0 + self.final_stage = final_stage if self.style is PipelineStyle.SingleProcess: assert self.devices is not None (self.in_queues, self.out_queues) = create_workers(self.devices) if ( - self.style is PipelineStyle.MultiProcess + self.style in [PipelineStyle.MultiProcess, PipelineStyle.AsyncSchedule] and self.transport_config.worker_map is None and self.transport_config.use_rpc is True ): raise ValueError("'PipelineStyle.MultiProcess' requires 'worker_map' to be set") + @property + def checkpoint_stop(self) -> int: + # Disable checkpointing if in eval mode. + if self.style == PipelineStyle.SingleProcess: + training = self.partitions[0].training + else: + training = self.mp_partitions[0].module.training + if not training: + return 0 + return self.__checkpoint_stop + def __del__(self) -> None: if self.style is PipelineStyle.SingleProcess: join_workers(self.in_queues, self.out_queues) - def run(self, batches: List[Batch]) -> None: + def run(self, training: bool, batches: List[Batch], event: Optional[Event]) -> None: """Runs pipeline parallelism. It modifies the given batches in place. """ - partitions = self.partitions - devices = self.devices + self.training = training m = len(batches) - n = len(partitions) skip_trackers = [SkipTrackerThroughPotals(self.skip_layout, i) for i in range(len(batches))] if self.style is PipelineStyle.SingleProcess: + n = len(self.partitions) for schedule in clock_cycles(m, n): self.fence(batches, schedule, skip_trackers) self.compute(batches, schedule, skip_trackers) @@ -385,6 +306,32 @@ def run(self, batches: List[Batch]) -> None: assert self.group schedule = [(i, self.group.rank()) for i in range(m)] self.compute(batches, schedule, skip_trackers) + elif self.style is PipelineStyle.AsyncSchedule: + assert self.group + rank = self.group.rank() + dprint(f">>> start barrier") + # torch.distributed.barrier(group=self.group) + event_loop = AsyncEventLoop( + self.mp_partitions, + self.group, + self.transport_config, + self.training, + self.input_device, + self.checkpoint_stop, + ) + dprint(f"<<< start barrier") + if rank == 0 and not self.final_stage: + dprint(f"{torch.distributed.get_rank()}: entered event head") + event_loop.event_loop_head(batches, skip_trackers, event) + dprint(f"{torch.distributed.get_rank()}: exited event head") + elif self.final_stage: + dprint(f"{torch.distributed.get_rank()}: entered event tail") + event_loop.event_loop_tail(batches, skip_trackers) + dprint(f"{torch.distributed.get_rank()}: exited event tail") + else: + dprint(f"{torch.distributed.get_rank()}: entered event loop") + event_loop.event_loop(len(batches), skip_trackers) + dprint(f"{torch.distributed.get_rank()}: exited event loop") self.callcount += 1 @@ -481,7 +428,7 @@ def execute_task(self, task: Task, i: int, skip_trackers: List[SkipTrackerThroug assert self.group rank = self.group.rank() - if rank != self.group.size() - 1: + if self.style is PipelineStyle.MultiProcess and not self.final_stage: ranks = get_pipeline_parallel_ranks() this_rank = torch.distributed.get_rank() @@ -534,72 +481,16 @@ def finalize_tasks( if exc_info is not None: raise exc_info[0].with_traceback(exc_info[1], exc_info[2]) - def create_task( - self, - i: int, - j: int, - batch: Batch, - checkpoint_stop: int, - partition: 
nn.Sequential, - skip_trackers: List[SkipTrackerThroughPotals], - streams: List[AbstractStream], - ) -> Task: - # Determine whether checkpointing or not. - if i < checkpoint_stop: - - def function( - input: TensorOrTensors, - partition: nn.Sequential = partition, - skip_tracker: SkipTrackerThroughPotals = skip_trackers[i], - chunk_id: int = i, - part_id: int = j, - ) -> TensorOrTensors: - with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)): - return partition(input) - - chk = Checkpointing(function, batch) - if self.style is PipelineStyle.SingleProcess: - task = Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute) - elif self.style is PipelineStyle.MultiProcess: - task = Task(None, compute=chk.checkpoint, finalize=chk.recompute) - del function, chk # TODO(tom) maybe remove - - else: - - def compute( - batch: Batch = batch, - partition: nn.Sequential = partition, - skip_tracker: SkipTrackerThroughPotals = skip_trackers[i], - chunk_id: int = i, - part_id: int = j, - ) -> Batch: - with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)): - return batch.call(partition) - - if self.style is PipelineStyle.SingleProcess: - task = Task(streams[j], compute=compute, finalize=None) - elif self.style is PipelineStyle.MultiProcess: - task = Task(None, compute=compute, finalize=None) - del compute # TODO(tom) maybe remove - - return task - def compute( - self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals], + self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals] ) -> None: """Runs tasks with synchronization to copy streams.""" - partitions = self.partitions devices = self.devices copy_streams = self.copy_streams - checkpoint_stop = self.checkpoint_stop - - # Disable checkpointing if in eval mode. - if not self.partitions[0].training: - checkpoint_stop = 0 if self.style is PipelineStyle.SingleProcess: assert devices is not None - n = len(partitions) + n = len(self.partitions) streams = [current_stream(d) for d in devices] elif self.style is PipelineStyle.MultiProcess: assert self.group @@ -635,25 +526,28 @@ def compute( batch = batches[i] if self.style is PipelineStyle.SingleProcess: - partition = partitions[j] + partition = self.partitions[j] # Synchronize with the copied input. ([1] in the diagram) assert copy_streams if j != 0: wait(batch, copy_streams[j][i], streams[j]) + + task = create_task(self.style, self.checkpoint_stop, i, j, batch, partition, skip_trackers, streams) + + # Compute tasks in parallel. ([2] in the diagram) + self.in_queues[j].put(task) elif self.style is PipelineStyle.MultiProcess: - assert len(self.partitions) == 1 - partition = self.partitions[0] + assert len(self.mp_partitions) == 1 + mp_partition = self.mp_partitions[0] assert self.group if self.group.rank() != 0: batch = self.get_batch_from_previous_stage(i, skip_trackers, batches) - task = self.create_task(i, j, batch, checkpoint_stop, partition, skip_trackers, streams) + task = create_task( + self.style, self.checkpoint_stop, i, j, batch, mp_partition.module, skip_trackers, streams + ) - if self.style is PipelineStyle.SingleProcess: - # Compute tasks in parallel. 
([2] in the diagram) - self.in_queues[j].put(task) - elif self.style is PipelineStyle.MultiProcess: batches[i] = self.execute_task(task, i, skip_trackers) if self.style is PipelineStyle.SingleProcess: @@ -689,6 +583,9 @@ def recv_portal_grad(self, expected_ns_name: Tuple[Namespace, str], expected_ind return result def back_helper(self, output: List[Batch]) -> None: + if self.style == PipelineStyle.AsyncSchedule: + return + o = list(output) tensors: Tensors @@ -732,4 +629,4 @@ def back_helper(self, output: List[Batch]) -> None: try: torch.autograd.backward(final_tensors, grad_tensors=grads, retain_graph=True) except Exception as e: - raise RuntimeError("Autograd failed") from e + raise RuntimeError(f"Autograd failed on {torch.distributed.get_rank()}") from e diff --git a/fairscale/nn/pipe/rpc.py b/fairscale/nn/pipe/rpc.py new file mode 100644 index 000000000..bfc16cfe5 --- /dev/null +++ b/fairscale/nn/pipe/rpc.py @@ -0,0 +1,313 @@ +import os +from threading import Event, Lock, Thread +from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast + +import torch +from torch import nn +from torch.distributed import ProcessGroup, rpc +from torch.distributed.distributed_c10d import _get_global_rank + +from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group + +from . import Pipe +from .types import TensorOrTensors + +DEFAULT_MAX_SOURCE_POSITIONS = 1024 +DEFAULT_MAX_TARGET_POSITIONS = 1024 + +PipeModel: Pipe + + +SizeOrSizes = Union[torch.Size, List[torch.Size]] +DtypeOrDtypes = Union[torch.dtype, List[torch.dtype]] + + +def dprint(s: str) -> None: + print(str(torch.distributed.get_rank()) + ": " + s) + + +def set_device_based_on_group(group: ProcessGroup) -> None: + # torch.cuda.set_device(group.rank() % torch.cuda.device_count()) + torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count()) + + +def register_remote_model(args: List[Any], kwargs: Dict[str, Any]) -> None: + group = get_pipeline_parallel_group() # FIXME(tom) handle dynamic group + set_device_based_on_group(group) + dprint(f"model registered {torch.cuda.current_device()}") + kwargs["group"] = group + kwargs["input_device"] = torch.device("cuda", torch.cuda.current_device()) + model = Pipe(*args, **kwargs) + model.cuda() + globals()["PipeModel"] = model + + +def get_shapes(tensor: TensorOrTensors) -> SizeOrSizes: + if isinstance(tensor, torch.Tensor): + return tensor.shape + else: + return [t.shape for t in tensor] + + +def get_dtype(tensor: TensorOrTensors) -> DtypeOrDtypes: + if isinstance(tensor, torch.Tensor): + return tensor.dtype + else: + return [t.dtype for t in tensor] + + +def model_forward(training: bool, shape: torch.Size, dtype: torch.dtype) -> Optional[Tuple[SizeOrSizes, DtypeOrDtypes]]: + try: + dprint(f"mf: train stage {torch.distributed.get_rank()}") + if isinstance(shape, torch.Size): + tensor = torch.empty(shape, dtype=dtype) + else: + tensor = tuple([torch.empty(s, dtype=d) for s, d in zip(shape, dtype)]) + + model = globals()["PipeModel"] + set_device_based_on_group(model.group) + + dprint(f"mf: train stage {model.group.rank()}, {os.getpid()}") + model.train(training) + result = model(tensor) + torch.cuda.current_stream().synchronize() + if model.final_stage: + globals()["PipeResult"] = result + return (get_shapes(result), get_dtype(result)) + except Exception as e: + print(f"failboat {e} {type(e)}") + import traceback + + print(f"format {traceback.format_exc()}") + raise e + + return None + + +def send_result(training: bool) -> None: + dprint(f"send result 
{training}") + group = get_pipeline_parallel_group() + set_device_based_on_group(group) + try: + dprint(f"send result {torch.distributed.get_rank()}, {torch.cuda.current_device()}") + result = globals()["PipeResult"] + model = globals()["PipeModel"] + + if isinstance(result, torch.Tensor): + result = [result] + + dest = _get_global_rank(group, 0) + print( + f"ho har {torch.distributed.get_rank()} " + str([_get_global_rank(group, i) for i in range(group.size())]) + ) + torch.cuda.current_stream().synchronize() + for r in result: + dprint(f">>> send {torch.distributed.get_rank()}, {dest} {r.shape}, {r.dtype}, {r.device}") + if "Gloo" in group.__class__.__name__: + r = r.cpu() + torch.distributed.send(r.contiguous(), dest, group=group) + dprint(f"<<< send {torch.distributed.get_rank()}, {dest}") + torch.cuda.current_stream().synchronize() + + if training: + grads = [] + for r in result: + g = torch.empty(r.shape).cuda() + dprint(f">>> recv grads {g.shape}") + torch.cuda.current_stream().synchronize() + torch.distributed.recv(g, dest, group=group) + torch.cuda.current_stream().synchronize() + if "Gloo" in group.__class__.__name__: + g = g.cuda() + dprint(f"<<< recv grads {g.shape}") + grads.append(g) + + with model.lock: + print(f" >>> autograd-backward tail") + torch.autograd.backward(result, tuple(grads), retain_graph=True) + print(f" <<< autograd-backward tail") + torch.cuda.synchronize() + + except Exception as e: + print(f"got {e}") + + +def recv_result(shapes: SizeOrSizes, dtypes: DtypeOrDtypes) -> TensorOrTensors: + group = get_pipeline_parallel_group() + set_device_based_on_group(group) + src = torch.distributed.distributed_c10d._get_global_rank(group, group.size() - 1) + dprint(f"recv_result... {src}, {torch.cuda.current_device()}") + + if isinstance(shapes, torch.Size): + shape = cast(torch.Size, shapes) + dtype = cast(torch.dtype, dtypes) + t = torch.empty(shape, dtype=dtype).cuda() + dprint(f">>> recv {torch.distributed.get_rank()}, {src} {t.shape}, {t.dtype}") + torch.cuda.current_stream().synchronize() + torch.distributed.recv(t, src, group=group) + torch.cuda.current_stream().synchronize() + if "Gloo" in group.__class__.__name__: + t = t.cuda() + dprint(f"<<< recv {torch.distributed.get_rank()}, {src}") + dprint(f"recvd solo") + return t + else: + result = [] + torch.cuda.current_stream().synchronize() + shapes = cast(List[torch.Size], shapes) + dtypes = cast(List[torch.dtype], dtypes) + for s, d in zip(shapes, dtypes): + t = torch.empty(s, dtype=d).cuda() + dprint(f">>> recv {torch.distributed.get_rank()}, {src} {t.shape}, {t.dtype}") + torch.distributed.recv(t, src, group=group) + if "Gloo" in group.__class__.__name__: + t = t.cuda() + dprint(f"<<< recv {torch.distributed.get_rank()}, {src}") + dprint(f"recvd multi / {len(shapes)}") + result.append(t) + torch.cuda.current_stream().synchronize() + return tuple(result) + + +def get_global_ranks_from_group(group: ProcessGroup) -> List[int]: + return [torch.distributed.distributed_c10d._get_global_rank(group, r) for r in range(group.size())] + + +def run_model(model: Pipe, tensor: TensorOrTensors, event: Event, lock: Lock) -> None: + t = model.training + with lock: + print(f">> run_model thread {t}") + assert model.group + set_device_based_on_group(model.group) + torch.cuda.current_stream().synchronize() + torch.cuda.synchronize() + model(tensor, event=event) + torch.cuda.synchronize() + torch.cuda.current_stream().synchronize() + print(f"<< run_model thread {t}") + + +class PipeBackRedirect(torch.autograd.Function): + 
@staticmethod + # type: ignore + def forward(ctx, inputs, dest, event): + ctx.dest = dest + ctx.event = event + return inputs + + @staticmethod + # type: ignore + def backward(ctx, *grad): + dprint(f">>> back hook yay") + group = get_pipeline_parallel_group() + torch.cuda.current_stream().synchronize() + for g in grad: + dprint(f">>> back send {g.shape}") + if "Gloo" in group.__class__.__name__: + g = g.cpu() + torch.distributed.send(g, ctx.dest, group=group) + dprint(f"<<< back send") + torch.cuda.current_stream().synchronize() + ctx.event.set() + dprint(f"<<< back hook yay") + return (None, None, None, None) + + +def callback_with_model(callback: Callable, ctx: Any) -> None: + group = get_pipeline_parallel_group() # FIXME(tom) handle dynamic group + set_device_based_on_group(group) + + global PipeModel + + torch.cuda.current_stream().synchronize() + with PipeModel.lock: + callback(ctx, PipeModel) + torch.cuda.current_stream().synchronize() + + +class PipeRPCWrapper(nn.Module): + def __init__(self, *args: Any, **kwargs: Any): + super().__init__() + self.group = cast(ProcessGroup, kwargs.get("group")) or get_pipeline_parallel_group() + assert self.group.rank() == 0 + self.lock = Lock() + + if True: + assert ( + self.group == get_pipeline_parallel_group() + ), "Can't pickle groups, so group must be `get_pipeline_parallel_group()`" + kwargs["group"] = None + else: + kwargs["group"] = self.group + + kwargs["style"] = Pipe.AsyncSchedule + kwargs["input_device"] = torch.device("cuda", torch.cuda.current_device()) + + self.model = Pipe(*args, **kwargs) + self.worker_map = kwargs["worker_map"] + print(f"calling rpc {args}, {kwargs}") + futures = [ + # FIXME get global rank + rpc.rpc_async(self.get_rpc_name(rank), register_remote_model, args=(args, kwargs)) + for rank in range(1, self.group.size()) + ] + futures = [f.wait() for f in futures] + self.model.cuda() + + def get_rpc_name(self, rank: int) -> str: + return self.worker_map[_get_global_rank(self.group, rank)] + + def foreach_worker(self, callback: Callable, ctx: Any = None, *, include_self: bool = False) -> None: + futures = [ + rpc.rpc_async(self.get_rpc_name(rank), callback_with_model, args=(callback, ctx)) + for rank in range(1, self.group.size()) + ] + futures = [f.wait() for f in futures] + if include_self: + with self.model.lock: + callback(ctx, self.model) + + def forward(self, tensor: TensorOrTensors) -> TensorOrTensors: # type: ignore + shape = get_shapes(tensor) + dtype = get_dtype(tensor) + + futures = [ + rpc.rpc_async(self.get_rpc_name(rank), model_forward, args=(self.model.training, shape, dtype)) + for rank in range(1, self.group.size()) + ] + + if self.model.final_stage: + return self.model(tensor) + else: + event = Event() + t = Thread(target=run_model, args=(self.model, tensor, event, self.lock)) + t.start() + + dprint("forward before wait recv") + shape, dtype = futures[-1].wait() + dprint("forward after wait recv") + dest_rank = self.group.size() - 1 + dest = self.get_rpc_name(dest_rank) + dprint(f"async to {dest}") + rpc.rpc_async(dest, send_result, args=(self.model.training,)) + dprint(">>> recv_result") + result = recv_result(shape, dtype) + dprint("<<< recv_result") + # event.set() + dprint("not set event") + try: + if isinstance(result, torch.Tensor): + result.requires_grad_() + else: + for r in result: + r.requires_grad_() + + applied = PipeBackRedirect.apply(result, _get_global_rank(self.group, dest_rank), event) + except Exception as e: + dprint(f"failed got {e}") + dprint("return applied") + return applied + + 
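
A minimal usage sketch for the PipeRPCWrapper defined above, assuming torch.distributed and the RPC framework are already initialized on every rank and that `layers`, `balance`, `worker_map`, `batch` and `target` are supplied by the caller; the optimizer and loss choices are illustrative assumptions, not part of fairscale's API.

# Minimal sketch: driving PipeRPCWrapper from the rank-0 process. Assumes the
# process group and RPC workers are already up; loss/optimizer are illustrative.
import torch
import torch.nn as nn

from fairscale.nn.pipe.rpc import PipeRPCWrapper


def train_step(layers: nn.Sequential, balance, worker_map, batch, target):
    # Rank 0 constructs the wrapper; the constructor registers remote copies of
    # the model on the other pipeline ranks via rpc_async(register_remote_model).
    model = PipeRPCWrapper(layers, balance, worker_map=worker_map, chunks=4)

    # Only the locally held partition's parameters are visible here; per-rank
    # optimizers could instead be created with model.foreach_worker(...).
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    output = model(batch)  # forward() returns the final-stage output on rank 0
    loss = nn.functional.mse_loss(output, target)

    # PipeBackRedirect.backward ships these gradients to the tail stage, which
    # then runs its own torch.autograd.backward under model.lock.
    loss.backward()
    optimizer.step()
    return loss.item()

The design point this illustrates is that only rank 0 drives the schedule: remote stages execute model_forward under RPC, while the final-stage output and its gradients travel over plain torch.distributed.send/recv.
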
@property + def final_stage(self) -> bool: + return self.model.final_stage diff --git a/fairscale/nn/pipe/types.py b/fairscale/nn/pipe/types.py new file mode 100644 index 000000000..ae5d09d92 --- /dev/null +++ b/fairscale/nn/pipe/types.py @@ -0,0 +1,65 @@ +from enum import Enum, auto +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +from dataclasses import dataclass +import torch +from torch import Tensor, nn + +ACTIVATIONS_GRADS_QUEUE = 0 +SKIP_TENSOR_QUEUE = 1 +PORTAL_QUEUE = 2 +EVENT_LOOP_QUEUE = 3 +MESSAGE_GENERATION_START = 4 + +MessageGeneration = MESSAGE_GENERATION_START + +Tensors = Tuple[Tensor, ...] +TensorOrTensors = Union[Tensor, Tensors] + +InputDevice = Union[None, int, str, torch.device] +Schedule = List[Tuple[int, int]] + + +class LazyModule: + def __init__(self, function: Callable[[], nn.Module]): + self.function = function + + def __call__(self) -> nn.Module: + return self.function() + + +class PipelineStyle(Enum): + SingleProcess = auto() + MultiProcess = auto() + AsyncSchedule = auto() + + +@dataclass(frozen=True) +class TransportConfig: + use_rpc: bool + worker_map: Optional[Dict[int, str]] + + +@dataclass(init=False) +class PipeMessage: + src: int + dest: int + queue_name: int + args: Any + tensors: Tensors + tensor_shapes: List[torch.Size] + tensor_dtypes: List[torch.dtype] + tag: int = 0 + + def __init__(self, src: int, dest: int, queue_name: int, args: Any, tensors: Tensors): + self.src = src + self.dest = dest + self.queue_name = queue_name + self.args = args + self.tensors = tensors + self.tensor_shapes = [] + self.tensor_dtypes = [] + + global MessageGeneration + self.tag = MessageGeneration + MessageGeneration += len(tensors) diff --git a/fairscale/optim/oss.py b/fairscale/optim/oss.py index d1bdff97a..57d94d740 100644 --- a/fairscale/optim/oss.py +++ b/fairscale/optim/oss.py @@ -422,7 +422,7 @@ def get_global_rank(group: Any, rank: int) -> int: if group is dist.group.WORLD: return rank else: - global_rank = dist.distributed_c10d._get_global_rank(group, rank) # type: ignore + global_rank = dist.distributed_c10d._get_global_rank(group, rank) return global_rank def _broadcast_params(self, buffers: List[torch.Tensor], per_rank_params: List[List[Parameter]]) -> None: diff --git a/run_mpi_tests.sh b/run_mpi_tests.sh index d3672996c..a5abdbaed 100755 --- a/run_mpi_tests.sh +++ b/run_mpi_tests.sh @@ -1,6 +1,10 @@ #!/bin/bash - set -e -for WORKERS in {1..5}; do - mpirun -n $WORKERS python -m pytest tests/nn/pipe_process +rpc_tests=$(pytest --collect-only | grep 'Function.*rpc' | cut -d' ' -f 6 | tr -d '>') + +for WORKERS in {1..6}; do + mpirun -n $WORKERS -mca orte_base_help_aggregate 0 python -m pytest tests/nn/pipe_process -k "not rpc" + for test_name in $rpc_tests; do + mpirun -n $WORKERS -mca orte_base_help_aggregate 0 python -m pytest tests/nn/pipe_process -k $test_name + done done diff --git a/stubs/torch/__init__.pyi b/stubs/torch/__init__.pyi index 904780526..b999dccbf 100644 --- a/stubs/torch/__init__.pyi +++ b/stubs/torch/__init__.pyi @@ -35,7 +35,7 @@ from . import version #END class dtype: - is_floating_point: bool + is_floating_point: builtins.bool class layout: ... @@ -277,7 +277,7 @@ class Tensor: def atan2(self, other: Tensor) -> Tensor: ... def atan2_(self, other: Tensor) -> Tensor: ... def atan_(self) -> Tensor: ... - def backward(self, gradient: Optional[Tensor]=None, keep_graph: _bool=False, create_graph: _bool=False) -> None: ... 
+ def backward(self, gradient: Optional[Tensor]=None, retain_graph: _bool=False, create_graph: _bool=False) -> None: ... def baddbmm(self, batch1: Tensor, batch2: Tensor, *, beta: Number=1, alpha: Number=1) -> Tensor: ... def baddbmm_(self, batch1: Tensor, batch2: Tensor, *, beta: Number=1, alpha: Number=1) -> Tensor: ... @overload diff --git a/stubs/torch/cuda/__init__.pyi b/stubs/torch/cuda/__init__.pyi index 940eeced4..ed437d685 100644 --- a/stubs/torch/cuda/__init__.pyi +++ b/stubs/torch/cuda/__init__.pyi @@ -29,7 +29,7 @@ _device_t = Union[_device, int, str] def check_error(res: int) -> None: ... def device_count() -> int: ... def empty_cache() -> None: ... -def synchronize(device: _device_t) -> None: ... +def synchronize(device: Optional[_device_t]=None) -> None: ... def set_device(device: _device_t) -> None: ... def get_device_capability(device: Optional[_device_t]=...) -> Tuple[int, int]: ... def get_device_name(device: Optional[_device_t]=...) -> str: ... diff --git a/stubs/torch/distributed/__init__.pyi b/stubs/torch/distributed/__init__.pyi index 160ee999c..71a23ea01 100644 --- a/stubs/torch/distributed/__init__.pyi +++ b/stubs/torch/distributed/__init__.pyi @@ -5,6 +5,7 @@ from torch import Tensor import datetime from . import rpc as rpc +from . import distributed_c10d as distributed_c10d class Backend: GLOO: str diff --git a/stubs/torch/distributed/distributed_c10d.pyi b/stubs/torch/distributed/distributed_c10d.pyi new file mode 100644 index 000000000..b8543cbb8 --- /dev/null +++ b/stubs/torch/distributed/distributed_c10d.pyi @@ -0,0 +1,7 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +from typing import Any, List, Union, Optional + +from . import ProcessGroup + +def _get_global_rank(group: ProcessGroup, rank: int) -> int: ... diff --git a/stubs/torch/distributed/rpc/__init__.pyi b/stubs/torch/distributed/rpc/__init__.pyi index 5267fedd2..d278c5520 100644 --- a/stubs/torch/distributed/rpc/__init__.pyi +++ b/stubs/torch/distributed/rpc/__init__.pyi @@ -1,6 +1,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. from typing import Union, Callable, Optional +from torch.futures import Future class RRef: @@ -17,7 +18,7 @@ def rpc_async( args: Optional[tuple] = None, kwargs: Optional[dict] = None, timeout=-1.0, -) -> None: +) -> Future: ... diff --git a/stubs/torch/futures.pyi b/stubs/torch/futures.pyi new file mode 100644 index 000000000..86b47606c --- /dev/null +++ b/stubs/torch/futures.pyi @@ -0,0 +1,6 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. + +from typing import Any + +class Future: + def wait(self) -> Any: ... diff --git a/stubs/torch/nn/__init__.pyi b/stubs/torch/nn/__init__.pyi index e186119ef..9dc050501 100644 --- a/stubs/torch/nn/__init__.pyi +++ b/stubs/torch/nn/__init__.pyi @@ -1,6 +1,6 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. -from .modules import * +from .modules import * from .parameter import Parameter as Parameter from .parallel import DataParallel as DataParallel from . 
import functional as functional diff --git a/tests/nn/model_parallel/commons.py b/tests/nn/model_parallel/commons.py index 23c66de5d..f4bc34250 100644 --- a/tests/nn/model_parallel/commons.py +++ b/tests/nn/model_parallel/commons.py @@ -21,6 +21,7 @@ import functools import inspect +import multiprocessing import os import random @@ -35,6 +36,12 @@ from fairscale.nn.model_parallel import initialize_model_parallel from fairscale.nn.model_parallel.random import model_parallel_cuda_manual_seed +try: + import torch_ucc # noqa: F401 +except ImportError as e: + print(f"can't import torch_ucc: {e}") + pass + class IdentityLayer(torch.nn.Module): def __init__(self, size, scale=1.0): @@ -100,10 +107,20 @@ def spawn_for_all_world_sizes(test_func, world_sizes=get_world_sizes(), args=[]) mp.spawn(test_func, args=(world_size, *args), nprocs=world_size, join=True) -def helper(rank, world_size, func, args): +def helper(rank, world_size, func, args, error_queue): dist_init(rank, world_size) - initialize_model_parallel(1, world_size) - func(*args) + kwargs = {} + if "OMPI_COMM_WORLD_RANK" not in os.environ: + kwargs["pipeline_backend"] = "gloo" + print(f"init glooo backend") + initialize_model_parallel(1, world_size, **kwargs) + try: + func(*args) + except BaseException as e: + if e.__class__.__name__ == "Skipped": + error_queue.put(str(e)) + return + raise e def torch_spawn(world_sizes=None): @@ -128,7 +145,12 @@ def replacement(*args, **kwargs): kwargs[p] for p in parameters if p != "rank" ) # converting named parameters to positional parameters to pass to `spawn` + error_queue = multiprocessing.get_context("spawn").SimpleQueue() if "OMPI_COMM_WORLD_RANK" in os.environ: + os.environ["RANK"] = os.environ["OMPI_COMM_WORLD_RANK"] + os.environ["WORLD_SIZE"] = os.environ["OMPI_COMM_WORLD_SIZE"] + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "10638" torch.distributed.init_process_group("mpi") world_size = torch.distributed.get_world_size() initialize_model_parallel(1, world_size) @@ -138,7 +160,11 @@ def replacement(*args, **kwargs): else: pytest.skip(f"requested world size doesn't match current world size") else: - spawn_for_all_world_sizes(helper, world_sizes, (func, args)) + spawn_for_all_world_sizes(helper, world_sizes, (func, args, error_queue)) + + if not error_queue.empty(): + msg = error_queue.get() + pytest.skip(msg) caller_module = inspect.getmodule(inspect.currentframe().f_back) setattr(caller_module, f"test_{name}", replacement) diff --git a/tests/nn/model_parallel/test_initialize.py b/tests/nn/model_parallel/test_initialize.py index f1df1b598..bc7477771 100644 --- a/tests/nn/model_parallel/test_initialize.py +++ b/tests/nn/model_parallel/test_initialize.py @@ -110,7 +110,7 @@ def is_initialized(self): def get_world_size(self): return data_parallel_size * pipeline_length * model_parallel_size - def new_group(self, args): + def new_group(self, args, backend=None): new_groups.append(args.copy()) return () diff --git a/tests/nn/model_parallel/test_layers.py b/tests/nn/model_parallel/test_layers.py index 2c97c5a7c..314321a7d 100644 --- a/tests/nn/model_parallel/test_layers.py +++ b/tests/nn/model_parallel/test_layers.py @@ -307,6 +307,10 @@ def run_test_row_parallel_linear(rank, model_parallel_size): print(" >> passed the test :-)") +def a_barrier(): + torch.distributed.barrier() + + def run_test_pipe(rank, world_size, skip_dist_init=False): pipe_world_size = 2 @@ -413,17 +417,17 @@ def forward_model(model_, target, step=False): reference_output = forward_model(reference, 
target) error = reference_output.sub(output).max() - torch.distributed.barrier() + a_barrier() assert error < 1.0e-6 output = forward_model(model, target) error = reference_output.sub(output).max() - torch.distributed.barrier() + a_barrier() assert error < 1.0e-6 output = forward_model(model, target) error = reference_output.sub(output).max() - torch.distributed.barrier() + a_barrier() assert error < 1.0e-6 check_weights(model, reference, "before") @@ -436,6 +440,7 @@ def forward_model(model_, target, step=False): model[2].weight.data = saved_weight_2 worker_map = {i: f"Test{i}" for i in range(torch.distributed.get_world_size())} + style = Pipe.MultiProcess # Pipe.AsyncSchedule if pipe_world_size == 2: print(f"actually doing pipe stuff now") @@ -444,14 +449,14 @@ def forward_model(model_, target, step=False): pipe_model = Pipe( model, [2, 1], - style=Pipe.MultiProcess, + style=style, group=pipeline_devices, worker_map=worker_map, input_device=torch.cuda.current_device(), chunks=chunk_size, pipelined_backward=True, ).cuda() - torch.distributed.barrier() + a_barrier() pipe_rank = torch.distributed.get_rank(group=mpu.get_pipeline_parallel_group()) print(f"pipe rank is {pipe_rank}") if pipe_rank == 0: @@ -474,6 +479,7 @@ def forward_model(model_, target, step=False): else: check_weights(model, reference, "pre-pipe", index=0) + # pipe_mode.eval() pipe_output = pipe_model(identity()) print(f"exited pipe for {rank}") forward_model(reference, target, step=True) @@ -481,7 +487,8 @@ def forward_model(model_, target, step=False): print(f"pipe_output {rank} = {pipe_output}") print(f"reference_output {rank} = {reference_output}") - torch.distributed.barrier() + if style == Pipe.MultiProcess: + a_barrier() if torch.distributed.get_rank(mpu.get_pipeline_parallel_group()) == 1: error = reference_output.sub(pipe_output.cuda()).max() @@ -511,7 +518,8 @@ def forward_model(model_, target, step=False): failed = False with torch.autograd.profiler.profile() as prof: try: - pipe_model.back_helper(pipe_output) + if style == Pipe.MultiProcess: + pipe_model.back_helper(pipe_output) except Exception as e: failed = True print(f"got {e} while doing backward, deadlock?") @@ -525,15 +533,16 @@ def forward_model(model_, target, step=False): print(f"waiting for barrier on slave") pipe_model.zero_grad() - torch.distributed.barrier() + a_barrier() + pipe_model.eval() pipe_output = pipe_model(identity()) updated_ref_output = forward_model(reference, target) if torch.distributed.get_rank(mpu.get_pipeline_parallel_group()) == 1: error = updated_ref_output.sub(pipe_output.cuda()).max() print(f"outputs are ref:\n{updated_ref_output}\npipe:\n{pipe_output}") assert error < 1.0e-6 - torch.distributed.barrier() + a_barrier() print(f"finished waiting for barrier on, pid={os.getpid()}") diff --git a/tests/nn/moe/test_moe_layer.py b/tests/nn/moe/test_moe_layer.py index bc569eebb..77b4b1c3d 100644 --- a/tests/nn/moe/test_moe_layer.py +++ b/tests/nn/moe/test_moe_layer.py @@ -23,12 +23,14 @@ os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = "29501" if "OMPI_COMM_WORLD_SIZE" in os.environ: - dist.init_process_group(backend=dist.Backend.MPI) + pass # dist.init_process_group(backend=dist.Backend.MPI) def setup_module(module): if "OMPI_COMM_WORLD_SIZE" not in os.environ: dist.init_process_group(backend=BACKEND, rank=0, world_size=1) + else: + dist.init_process_group(backend=dist.Backend.MPI) def teardown_module(module): diff --git a/tests/nn/pipe_process/conftest.py b/tests/nn/pipe_process/conftest.py index 
ef91fbaf0..82af2e3a0 100644 --- a/tests/nn/pipe_process/conftest.py +++ b/tests/nn/pipe_process/conftest.py @@ -65,3 +65,7 @@ def pytest_runtest_teardown(item): destroy_model_parallel() if torch.distributed.is_initialized(): torch.distributed.destroy_process_group() + try: + torch.distributed.rpc.shutdown() + except Exception: + pass diff --git a/tests/nn/pipe_process/skip/test_gpipe.py b/tests/nn/pipe_process/skip/test_gpipe.py index 4a4510749..dc10c0a19 100644 --- a/tests/nn/pipe_process/skip/test_gpipe.py +++ b/tests/nn/pipe_process/skip/test_gpipe.py @@ -23,7 +23,7 @@ import torch from torch import nn -from fairscale.nn.pipe import Pipe +from fairscale.nn.pipe import LazyModule, Pipe from fairscale.nn.pipe.skip import pop, skippable, stash from fairscale.nn.pipe.skip.portal import PortalBlue, PortalCopy, PortalOrange from tests.nn.model_parallel.commons import get_worker_map, torch_spawn @@ -33,10 +33,15 @@ @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") @pytest.mark.parametrize("balance", [[3], [1, 2], [2, 1], [1, 1, 1]], ids=["3", "1:2", "2:1", "1:1:1"]) @pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) @pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi") -def x1to3(balance, checkpoint): +def x1to3(balance, checkpoint, pipeline_style): torch.manual_seed(0) + if pipeline_style == Pipe.AsyncSchedule and len(balance) > 1: + print(f"skipping yarg") + pytest.skip("Skip tensors NYI for AsyncSchedule") + @skippable(stash=["1to3"]) class Layer1(nn.Module): def __init__(self): @@ -75,7 +80,7 @@ def forward(self, input): chunks=3, checkpoint=checkpoint, input_device=torch.cuda.current_device(), - style=Pipe.MultiProcess, + style=pipeline_style, worker_map=get_worker_map(), pipelined_backward=False, ).cuda() @@ -101,7 +106,11 @@ def forward(self, input): @torch_spawn([2]) @pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi") @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") -def none_skip(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def none_skip(pipeline_style): + if pipeline_style == Pipe.AsyncSchedule: + pytest.skip("Skip tensors NYI for AsyncSchedule") + @skippable(stash=["none"]) class Stash(nn.Module): def forward(self, input): @@ -119,7 +128,7 @@ def forward(self, input): model = Pipe( model, [1, 1], - style=Pipe.MultiProcess, + style=pipeline_style, worker_map=get_worker_map(), input_device=torch.cuda.current_device(), chunks=5, @@ -151,7 +160,8 @@ def assert_grad_fn_is_not_portal(grad_fn, visited=set()): @torch_spawn([2]) -def lazy_skippable_error(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def lazy_skippable_error(pipeline_style): """Using skippable layers in combination with lazy construction is currently not supported, check that it raises an Exception""" @@ -163,9 +173,13 @@ class Layer1(nn.Linear): class Layer3(nn.Linear): pass - model = [lambda: Layer1(10, 10), lambda: nn.Linear(10, 10), lambda: Layer3(10, 10)] + model = [ + LazyModule(lambda: Layer1(10, 10)), + LazyModule(lambda: nn.Linear(10, 10)), + LazyModule(lambda: Layer3(10, 10)), + ] with pytest.raises(ValueError, match="Can't use Skippable layers with multi-process pipe and lazy construction"): Pipe( - model, [2, 1], style=Pipe.MultiProcess, worker_map=get_worker_map(), + model, [2, 1], style=pipeline_style, 
worker_map=get_worker_map(), ) diff --git a/tests/nn/pipe_process/skip/test_leak.py b/tests/nn/pipe_process/skip/test_leak.py index 78501b531..67fccb009 100644 --- a/tests/nn/pipe_process/skip/test_leak.py +++ b/tests/nn/pipe_process/skip/test_leak.py @@ -46,9 +46,10 @@ def forward(self, input): @torch_spawn([2]) @pytest.mark.parametrize("train", [True, False], ids=["train", "eval"]) @pytest.mark.parametrize("checkpoint", ["always", "except_last", "never"]) +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) @pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="broken on mpi") @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") -def delete_portal_tensor(train, checkpoint): +def delete_portal_tensor(train, checkpoint, pipeline_style): # Without checkpointing: # +- Stash --+ +--- Pop ----+ - - - layers # | 2,blue,1 |--| 1,orange,0 | - - - tensor_life and portal function @@ -59,6 +60,9 @@ def delete_portal_tensor(train, checkpoint): # | 3,blue,2 |--| 2,orange,1 |--| 1,orange,0 |--| 1,blue,0 | # +----------+ +------------+ +------------+ +----------+ + if pipeline_style == Pipe.AsyncSchedule: + pytest.skip("Skip tensors NYI for AsyncSchedule") + def portal_tensor_life_is(tensor_life, skip_tracker=None): if skip_tracker is None: skip_tracker = current_skip_tracker() @@ -111,7 +115,7 @@ def forward(self, input): model = nn.Sequential(NoPortalTensorAtBackward(), stash_, pop_) model = Pipe( - model, balance=[2, 1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=2, checkpoint=checkpoint, + model, balance=[2, 1], style=pipeline_style, worker_map=get_worker_map(), chunks=2, checkpoint=checkpoint, ) input = torch.rand(10, requires_grad=True) diff --git a/tests/nn/pipe_process/test_bugs.py b/tests/nn/pipe_process/test_bugs.py index a79e4bd6a..b192704dc 100644 --- a/tests/nn/pipe_process/test_bugs.py +++ b/tests/nn/pipe_process/test_bugs.py @@ -28,7 +28,9 @@ @torch_spawn([2]) @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") -def python_autograd_function(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def python_autograd_function(pipeline_style): + # FIXME deadlock with Pipe.AsyncSchedule? # A Python autograd function might fail with this error: # # RuntimeError: Returning Variables sharing storage with other Variables @@ -55,7 +57,8 @@ def forward(self, input): return Identity.apply(input) model = nn.Sequential(M(), M()) - model = Pipe(model, [1, 1], style=Pipe.MultiProcess, worker_map=get_worker_map(), checkpoint="always").cuda() + model = Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always").cuda() + model.eval() x = torch.rand(42) y = model(x) @@ -67,7 +70,8 @@ def forward(self, input): @torch_spawn([3]) @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") -def exception_no_hang(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def exception_no_hang(pipeline_style): # In v0.0.2, once a failed partition receives a normal message # (non-closing) for the next micro-batch, a hang occured. 
The reason was # that a failed partition didn't call in_queue.task_done() on a normal @@ -85,7 +89,8 @@ def forward(self, x): raise ExpectedException() model = nn.Sequential(Pass(), Pass(), Raise()) - model = Pipe(model, [1, 1, 1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=3) + model = Pipe(model, [1, 1, 1], style=pipeline_style, worker_map=get_worker_map(), chunks=3) + model.eval() if model.group.rank() == 2: with pytest.raises(ExpectedException): @@ -98,7 +103,8 @@ def forward(self, x): @torch_spawn([2]) @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="2 cuda devices required") -def tuple_wait(cuda_sleep): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def tuple_wait(cuda_sleep, pipeline_style): # In v0.0.3, Wait is applied to only the first tensor on a micro-batch. # Under this behavior, if checkpointing was disabled, there's a possibility # that gradient accumulations on other tensors are not synchronized @@ -129,7 +135,7 @@ def forward(self, triple): model = Pipe( model, [1, 1], - style=Pipe.MultiProcess, + style=pipeline_style, worker_map=get_worker_map(), input_device=torch.cuda.current_device(), chunks=32, @@ -151,7 +157,8 @@ def forward(self, triple): @torch_spawn([2]) @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") -def parallel_randoms(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def parallel_randoms(pipeline_style): class Dropouts(nn.Module): def forward(self, x): for _ in range(100): @@ -165,7 +172,7 @@ def forward(self, x): model = Pipe( model, [1, 1], - style=Pipe.MultiProcess, + style=pipeline_style, input_device=torch.cuda.current_device(), worker_map=get_worker_map(), chunks=10, diff --git a/tests/nn/pipe_process/test_inplace.py b/tests/nn/pipe_process/test_inplace.py index 881c6cefb..7cf24558e 100644 --- a/tests/nn/pipe_process/test_inplace.py +++ b/tests/nn/pipe_process/test_inplace.py @@ -27,11 +27,17 @@ @torch_spawn([2]) @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") -def inplace_on_requires_grad(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def inplace_on_requires_grad(pipeline_style): model = nn.Sequential(nn.Linear(1, 1), nn.ReLU(inplace=True)) - model = Pipe(model, [1, 1], style=Pipe.MultiProcess, worker_map=get_worker_map(), checkpoint="always") + model = Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always") x = torch.rand(1) + + if pipeline_style == Pipe.AsyncSchedule and model.group.rank() == 0: + # With AsyncSchedule, model will wait forever for gradients if not eval + model.eval() + y = model(x) message = r"a leaf Variable that requires grad .* used in an in-place operation." @@ -44,11 +50,12 @@ def inplace_on_requires_grad(): @torch_spawn([1]) @pytest.mark.xfail(strict=True) -def inplace_on_not_requires_grad(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def inplace_on_not_requires_grad(pipeline_style): # In-place operation on a tensor not requiring grad doesn't cause a # RuntimeError. Currently, we cannot detect this case. 
model = nn.Sequential(nn.ReLU(inplace=True)) - model = Pipe(model, [1], style=Pipe.MultiProcess, worker_map=get_worker_map(), checkpoint="always") + model = Pipe(model, [1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always") x = torch.rand(1) y = model(x) @@ -63,7 +70,8 @@ def inplace_on_not_requires_grad(): @torch_spawn([1]) @pytest.mark.xfail(strict=True) -def inplace_incorrect_grad(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def inplace_incorrect_grad(pipeline_style): class M(nn.Module): def forward(self, foo_bar): # 'foo' requires grad but 'bar' does not. In-place operation on @@ -80,7 +88,7 @@ def forward(self, foo_bar): return foo * bar model = nn.Sequential(M()) - model = Pipe(model, [1], style=Pipe.MultiProcess, worker_map=get_worker_map(), checkpoint="always") + model = Pipe(model, [1], style=pipeline_style, worker_map=get_worker_map(), checkpoint="always") foo = torch.tensor([1.0], requires_grad=True) bar = torch.tensor([1.0]) diff --git a/tests/nn/pipe_process/test_pipe.py b/tests/nn/pipe_process/test_pipe.py index 59ce5d534..24cfceb22 100644 --- a/tests/nn/pipe_process/test_pipe.py +++ b/tests/nn/pipe_process/test_pipe.py @@ -21,21 +21,27 @@ from copy import deepcopy import os import time +from typing import Tuple from packaging import version import pytest import torch from torch import nn -from fairscale.nn.model_parallel.initialize import destroy_model_parallel, initialize_model_parallel -from fairscale.nn.pipe import Pipe -from tests.nn.model_parallel.commons import get_worker_map, torch_spawn +from fairscale.nn.model_parallel.initialize import ( + destroy_model_parallel, + get_pipeline_parallel_group, + initialize_model_parallel, +) +from fairscale.nn.pipe import LazyModule, Pipe +from tests.nn.model_parallel.commons import get_worker_map, set_random_seed, torch_spawn @torch_spawn([2]) -def parameters(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def parameters(pipeline_style): model = nn.Sequential(nn.Linear(1, 1)) - pipe = Pipe(model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=1) + pipe = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1) if torch.distributed.get_rank() == 0: assert list(pipe.parameters()) != [] else: @@ -62,10 +68,10 @@ def infiniband(): def infiniband2(): if torch.distributed.get_rank() == 0: t = torch.Tensor(range(100)).cuda() - torch.distributed.send(t, 1) + torch.distributed.send(t, 1, group=get_pipeline_parallel_group()) else: t = torch.empty(100).cuda() - torch.distributed.recv(t, 0) + torch.distributed.recv(t, 0, group=get_pipeline_parallel_group()) assert torch.equal(t, torch.Tensor(range(100)).cuda()) print(f"t on {torch.distributed.get_rank()} is {t}") @@ -87,7 +93,6 @@ def mpi(): torch.cuda.manual_seed(seed) torch.distributed.barrier() - group = torch.distributed.new_group([0, 1]) tensor_size = (1024, 1024, 10) torch.cuda.set_device(torch.distributed.get_rank()) # need to pin device or ucx gets unhappy @@ -104,7 +109,8 @@ def mpi(): @torch_spawn([1]) -def public_attrs(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def public_attrs(pipeline_style): class MyString: def __init__(self, value): self.value = value @@ -117,7 +123,7 @@ def __str__(self): pipe = Pipe( model, balance=(1,), - style=Pipe.MultiProcess, + style=pipeline_style, worker_map=get_worker_map(), chunks=42.000, checkpoint=MyString("always"), @@ -134,12 +140,13 @@ def 
__str__(self): @torch_spawn([2]) @pytest.mark.parametrize("balance", [[2], [1, 1]]) -def sequential_like(balance): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def sequential_like(balance, pipeline_style): a = nn.Linear(1, 1) b = nn.Linear(1, 1) model = nn.Sequential(a, b) - model = Pipe(model, balance, style=Pipe.MultiProcess, worker_map=get_worker_map()) + model = Pipe(model, balance, style=pipeline_style, worker_map=get_worker_map()) if balance == [2]: if torch.distributed.get_rank() == 0: @@ -172,57 +179,62 @@ def sequential_like(balance): @torch_spawn([1]) -def balance_wrong_length(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def balance_wrong_length(pipeline_style): a = nn.Linear(1, 1) b = nn.Linear(1, 1) model = nn.Sequential(a, b) with pytest.raises(ValueError): - Pipe(model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map()) + Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map()) with pytest.raises(ValueError): - Pipe(model, balance=[3], style=Pipe.MultiProcess, worker_map=get_worker_map()) + Pipe(model, balance=[3], style=pipeline_style, worker_map=get_worker_map()) @torch_spawn([2]) -def balance_less_than_1(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def balance_less_than_1(pipeline_style): a = nn.Linear(1, 1) b = nn.Linear(1, 1) model = nn.Sequential(a, b) with pytest.raises(ValueError): - Pipe(model, balance=[0, 2], style=Pipe.MultiProcess, worker_map=get_worker_map()) + Pipe(model, balance=[0, 2], style=pipeline_style, worker_map=get_worker_map()) with pytest.raises(ValueError): - Pipe(model, balance=[-1, 3], style=Pipe.MultiProcess, worker_map=get_worker_map()) + Pipe(model, balance=[-1, 3], style=pipeline_style, worker_map=get_worker_map()) @torch_spawn([1]) -def chunks_less_than_1(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def chunks_less_than_1(pipeline_style): model = nn.Sequential(nn.Linear(1, 1)) with pytest.raises(ValueError): - Pipe(model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=0) + Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=0) with pytest.raises(ValueError): - Pipe(model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=-1) + Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=-1) @torch_spawn([1]) -def too_few_devices(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def too_few_devices(pipeline_style): model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1), nn.Linear(1, 1)) with pytest.raises(IndexError): # len(balance) > len(group.size()) - model = Pipe(model, balance=[1, 1, 1, 1], style=Pipe.MultiProcess, worker_map=get_worker_map()) + model = Pipe(model, balance=[1, 1, 1, 1], style=pipeline_style, worker_map=get_worker_map()) @torch_spawn([1]) -def batch_size_indivisible(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def batch_size_indivisible(pipeline_style): model = nn.Sequential(nn.Linear(1, 1)) - model = Pipe(model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=4) + model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=4) with pytest.warns(None) as record: model(torch.rand(7, 1)) @@ -232,9 +244,10 @@ def batch_size_indivisible(): @torch_spawn([1]) -def 
batch_size_small(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def batch_size_small(pipeline_style): model = nn.Sequential(nn.Linear(1, 1)) - model = Pipe(model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=4) + model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=4) with pytest.warns(None) as record: model(torch.rand(2, 1)) @@ -244,7 +257,8 @@ def batch_size_small(): @torch_spawn([1]) -def checkpoint_mode(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def checkpoint_mode(pipeline_style): def count_grad_fn(grad_fn, name, visited=set()): if grad_fn in visited: return 0 @@ -266,7 +280,7 @@ def count_grad_fn(grad_fn, name, visited=set()): always = Pipe( model, balance=[1], - style=Pipe.MultiProcess, + style=pipeline_style, worker_map=get_worker_map(), chunks=2, checkpoint="always", @@ -275,7 +289,7 @@ def count_grad_fn(grad_fn, name, visited=set()): except_last = Pipe( model, balance=[1], - style=Pipe.MultiProcess, + style=pipeline_style, worker_map=get_worker_map(), chunks=2, checkpoint="except_last", @@ -284,7 +298,7 @@ def count_grad_fn(grad_fn, name, visited=set()): never = Pipe( model, balance=[1], - style=Pipe.MultiProcess, + style=pipeline_style, worker_map=get_worker_map(), chunks=2, checkpoint="never", @@ -301,14 +315,15 @@ def count_grad_fn(grad_fn, name, visited=set()): @torch_spawn([1]) -def checkpoint_mode_invalid(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def checkpoint_mode_invalid(pipeline_style): model = nn.Sequential(nn.Linear(1, 1)) with pytest.raises(ValueError, match="checkpoint is not one of 'always', 'except_last', or 'never'"): Pipe( model, balance=[1], - style=Pipe.MultiProcess, + style=pipeline_style, worker_map=get_worker_map(), chunks=2, checkpoint="INVALID_CHECKPOINT", @@ -316,22 +331,24 @@ def checkpoint_mode_invalid(): @torch_spawn([1]) -def checkpoint_mode_when_chunks_1(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def checkpoint_mode_when_chunks_1(pipeline_style): model = nn.Sequential(nn.Linear(1, 1)) # All checkpoint modes are fine. 
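    # With chunks=1, the usual mode-to-checkpoint_stop mapping (assumed here:
    # "always" -> chunks, "except_last" -> chunks - 1, "never" -> 0) means
    # "always" checkpoints the single micro-batch while the other two modes
    # checkpoint nothing; each construction below should succeed without error.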
Pipe( - model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=1, checkpoint="except_last", + model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1, checkpoint="except_last", ) - Pipe(model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=1, checkpoint="always") - Pipe(model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=1, checkpoint="never") + Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1, checkpoint="always") + Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=1, checkpoint="never") @torch_spawn([1]) -def checkpoint_eval(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def checkpoint_eval(pipeline_style): model = nn.Sequential(nn.Linear(1, 1)) model = Pipe( - model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=2, pipelined_backward=False, + model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=2, pipelined_backward=False, ) input = torch.rand(2, 1) @@ -356,11 +373,16 @@ def find_grad_fn(grad_fn, name): assert not find_grad_fn(eval_output.grad_fn, "RecomputeBackward") +def torch_version() -> Tuple[int, ...]: + result = version.parse(torch.__version__).release + assert result + return result + + @torch_spawn([2]) -@pytest.mark.xfail( - version.parse(torch.__version__) < version.parse("1.6.0"), reason="Doesn't work on torch < 1.6.0", strict=True -) -def checkpoint_non_float_input(): +@pytest.mark.xfail(torch_version() < (1, 6, 0), reason="Doesn't work on torch < 1.6.0", strict=True) +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def checkpoint_non_float_input(pipeline_style): class ForkNonFloat(nn.Module): def forward(self, input): return (input * 2, torch.tensor([False])) @@ -373,7 +395,7 @@ def forward(self, input): model = Pipe( model, balance=[1, 1], - style=Pipe.MultiProcess, + style=pipeline_style, worker_map=get_worker_map(), chunks=1, checkpoint="always", @@ -385,14 +407,17 @@ def forward(self, input): if model.group.rank() == 1: # with torch.autograd.detect_anomaly(): output.backward() - else: + elif pipeline_style == Pipe.MultiProcess: model.back_helper(output) + torch.distributed.barrier() + @torch_spawn([1]) -def no_grad(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def no_grad(pipeline_style): model = nn.Sequential(nn.Linear(1, 1)) - model = Pipe(model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=2) + model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=2) input = torch.rand(2, 1) latent = None @@ -404,8 +429,8 @@ def hook(module, input, output): nonlocal latent latent = output - partition = model.partitions[0] - partition.register_forward_hook(hook) + partition = model.mp_partitions[0] + partition.module.register_forward_hook(hook) with torch.no_grad(): model(input) @@ -414,7 +439,8 @@ def hook(module, input, output): @torch_spawn([1]) -def exception(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def exception(pipeline_style): class ExpectedException(Exception): pass @@ -423,7 +449,7 @@ def forward(self, *_): raise ExpectedException() model = nn.Sequential(Raise()) - model = Pipe(model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=1) + model = Pipe(model, balance=[1], style=pipeline_style, 
worker_map=get_worker_map(), chunks=1) with pytest.raises(ExpectedException): model(torch.rand(1)) @@ -432,7 +458,8 @@ def forward(self, *_): # FIXME(tom) should probably signal to all hosts in group to stop @torch_spawn([4]) @pytest.mark.xfail(strict=True) -def exception_early_stop_asap(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def exception_early_stop_asap(pipeline_style): """Even the first partitions have finished to process, the partition before the failed partition hould be killed as soon as possible. """ @@ -460,7 +487,7 @@ def forward(self, x): raise ExpectedException() model = nn.Sequential(Pass(), Pass(), Counter(), Raise()) - model = Pipe(model, [1, 1, 1, 1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=3) + model = Pipe(model, [1, 1, 1, 1], style=pipeline_style, worker_map=get_worker_map(), chunks=3) with pytest.raises(ExpectedException): model(torch.rand(3)) @@ -470,7 +497,8 @@ def forward(self, x): @torch_spawn([1]) -def input_pair(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def input_pair(pipeline_style): class Two(nn.Module): def __init__(self): super().__init__() @@ -483,7 +511,7 @@ def forward(self, a_and_b): model = nn.Sequential(Two()) model = Pipe( - model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=2, pipelined_backward=False, + model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=2, pipelined_backward=False, ) a = torch.rand(10, 1, requires_grad=True) @@ -498,7 +526,8 @@ def forward(self, a_and_b): @torch_spawn([1]) -def input_singleton(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def input_singleton(pipeline_style): class One(nn.Module): def __init__(self): super().__init__() @@ -510,7 +539,7 @@ def forward(self, only_a): model = nn.Sequential(One()) model = Pipe( - model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=2, pipelined_backward=False, + model, balance=[1], style=pipeline_style, worker_map=get_worker_map(), chunks=2, pipelined_backward=False, ) a = torch.rand(10, 1, requires_grad=True) @@ -524,9 +553,10 @@ def forward(self, only_a): @torch_spawn([1]) -def input_varargs(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def input_varargs(pipeline_style): model = nn.Sequential(nn.Linear(1, 1)) - model = Pipe(model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map()) + model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map()) a = torch.rand(1) b = torch.rand(1) @@ -537,13 +567,14 @@ def input_varargs(): @torch_spawn([1]) -def non_tensor(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def non_tensor(pipeline_style): class NonTensor(nn.Module): def forward(self, _): return "hello" model = nn.Sequential(NonTensor()) - model = Pipe(model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map()) + model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map()) x = torch.rand(1) # TypeError: expected Tensor as element 0 in argument 0, but got str @@ -556,13 +587,14 @@ def forward(self, _): @torch_spawn([1]) -def non_tensor_tuple(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def non_tensor_tuple(pipeline_style): class NonTensorTuple(nn.Module): def forward(self, x): return (x, "hello") model = nn.Sequential(NonTensorTuple()) - model = 
Pipe(model, balance=[1], style=Pipe.MultiProcess, worker_map=get_worker_map()) + model = Pipe(model, balance=[1], style=pipeline_style, worker_map=get_worker_map()) x = torch.rand(1) # TypeError: CheckpointBackward.forward: expected Variable (got str) for return value 1 @@ -577,18 +609,19 @@ def forward(self, x): @torch_spawn([1]) @pytest.mark.parametrize("checkpoint", ["never", "always", "except_last"]) @pytest.mark.parametrize("lazy", [True, False]) -def deferred_batch_norm(checkpoint, lazy): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def deferred_batch_norm(checkpoint, lazy, pipeline_style): bn = nn.BatchNorm2d(3) pipe_bn = deepcopy(bn) pipe_fn = lambda: pipe_bn # noqa: E731 if lazy: - model = [pipe_fn] + model = [LazyModule(pipe_fn)] else: model = nn.Sequential(pipe_bn) pipe = Pipe( model, balance=[1], - style=Pipe.MultiProcess, + style=pipeline_style, worker_map=get_worker_map(), chunks=2, checkpoint=checkpoint, @@ -606,18 +639,19 @@ def deferred_batch_norm(checkpoint, lazy): @torch_spawn([1]) @pytest.mark.parametrize("checkpoint", ["never", "always"]) @pytest.mark.parametrize("lazy", [True, False]) -def deferred_batch_norm_params(checkpoint, lazy): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def deferred_batch_norm_params(checkpoint, lazy, pipeline_style): bn = nn.BatchNorm2d(3) pipe_bn = deepcopy(bn) pipe_fn = lambda: pipe_bn # noqa: E731 if lazy: - model = [pipe_fn] + model = [LazyModule(pipe_fn)] else: model = nn.Sequential(pipe_bn) pipe = Pipe( model, balance=[1], - style=Pipe.MultiProcess, + style=pipeline_style, worker_map=get_worker_map(), chunks=1, checkpoint=checkpoint, @@ -636,14 +670,15 @@ def deferred_batch_norm_params(checkpoint, lazy): @torch_spawn([4]) -def devices(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def devices(pipeline_style): a = nn.Linear(1, 1) b = nn.Linear(1, 1) c = nn.Linear(1, 1) # There are extra two ranks. model = nn.Sequential(a, b, c) - model = Pipe(model, [1, 1, 1], style=Pipe.MultiProcess, worker_map=get_worker_map()) + model = Pipe(model, [1, 1, 1], style=pipeline_style, worker_map=get_worker_map()) # Extra devices must be discarded. 
if model.group.rank() == 3: @@ -651,28 +686,33 @@ def devices(): @torch_spawn([2]) -def partitions(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def partitions(pipeline_style): a = nn.Linear(1, 1) b = nn.Linear(1, 1) model = nn.Sequential(a, b) - model = Pipe(model, [1, 1], style=Pipe.MultiProcess, worker_map=get_worker_map()) + model = Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map()) - assert isinstance(model.partitions, nn.ModuleList) + assert isinstance(model.mp_partitions, list) assert len(model) == 1 - assert isinstance(model.partitions[0], nn.Sequential) + assert isinstance(model.mp_partitions[0].module, nn.Sequential) - assert "partitions.0.0.weight" in model.state_dict() + if model.group.rank() == 0: + assert "0.0.weight" in model.state_dict() + else: + assert "0.1.weight" in model.state_dict() @torch_spawn([2]) @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") -def deny_moving(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def deny_moving(pipeline_style): a = nn.Linear(1, 1) b = nn.Linear(1, 1) model = nn.Sequential(a, b) - model = Pipe(model, [1, 1], style=Pipe.MultiProcess, worker_map=get_worker_map()) + model = Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map()) model.cuda() model.cpu() @@ -690,10 +730,11 @@ def deny_moving(): @torch_spawn([1]) -def empty_module(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def empty_module(pipeline_style): # Empty sequential module is not illegal. model = nn.Sequential() - model = Pipe(model, [], style=Pipe.MultiProcess, worker_map=get_worker_map()) + model = Pipe(model, [], style=pipeline_style, worker_map=get_worker_map()) assert model(torch.tensor([42])) == torch.tensor([42]) assert model((torch.tensor([42]),)) == (torch.tensor([42]),) @@ -705,16 +746,19 @@ def empty_module(): @torch_spawn([2]) -def named_children(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def named_children(pipeline_style): a = nn.Linear(1, 1) b = nn.Linear(1, 1) model = nn.Sequential(OrderedDict([("a", a), ("b", b)])) - model = Pipe(model, [1, 1], devices=["cpu", "cpu"]) + model = Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map()) names = set(n for n, _ in model.named_modules()) - assert "partitions.0.a" in names - assert "partitions.1.b" in names + if model.group.rank() == 0: + assert "0.a" in names + else: + assert "0.b" in names # Pipe doesn't support __getattr__. Unlike nn.Sequential, Pipe requires # several methods in its namespace. 
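
Editor's note: every hunk in this test file applies the same conversion — the test function gains a pipeline_style argument, is parametrized over Pipe.MultiProcess and Pipe.AsyncSchedule, and forwards style=pipeline_style to the Pipe constructor. A minimal sketch of a test written in that style is below. It is not part of the patch; the test name forward_backward is illustrative only, while torch_spawn, get_worker_map, back_helper, and the two style constants are taken from the surrounding diff.

import pytest
import torch
import torch.nn as nn

from fairscale.nn import Pipe
from tests.nn.model_parallel.commons import get_worker_map, torch_spawn


@torch_spawn([2])
@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule])
def forward_backward(pipeline_style):
    # Two ranks, one nn.Linear per partition; the same body runs once per schedule style.
    model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1))
    model = Pipe(model, balance=[1, 1], style=pipeline_style, worker_map=get_worker_map(), chunks=2)

    output = model(torch.rand(4, 1))
    if model.group.rank() == 1:
        # Final stage owns the loss and drives backward.
        output.mean().backward()
    elif pipeline_style == Pipe.MultiProcess:
        # Non-final stages only need back_helper under the MultiProcess schedule,
        # mirroring checkpoint_non_float_input above.
        model.back_helper(output)

    torch.distributed.barrier()
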
@@ -723,7 +767,8 @@ def named_children(): @torch_spawn([1]) -def recommend_auto_balance(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def recommend_auto_balance(pipeline_style): with pytest.raises(ValueError, match="fairscale.nn.pipe.balance"): # balance is required Pipe(nn.Sequential()) @@ -737,23 +782,9 @@ def recommend_auto_balance(): Pipe(nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1)), [1]) -@torch_spawn([1]) -def verify_module_non_sequential(): - with pytest.raises(TypeError, match="module must be nn.Sequential to be partitioned"): - Pipe(nn.Module(), [1]) - - -@torch_spawn([1]) -def verify_module_duplicate_children(): - conv = nn.Conv2d(3, 3, 1) - model = nn.Sequential(conv, conv) - - with pytest.raises(ValueError, match="module with duplicate children is not supported"): - Pipe(model, [1, 1]) - - @torch_spawn([2]) -def lazy_construction(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def lazy_construction(pipeline_style): init_count = 0 class Custom(nn.Module): @@ -766,13 +797,13 @@ def forward(self, x): return x model = [ - lambda: Custom(), - lambda: Custom(), - lambda: Custom(), - lambda: Custom(), + LazyModule(lambda: Custom()), + LazyModule(lambda: Custom()), + LazyModule(lambda: Custom()), + LazyModule(lambda: Custom()), ] - pipe = Pipe(model, balance=[2, 2], style=Pipe.MultiProcess, worker_map=get_worker_map()) + pipe = Pipe(model, balance=[2, 2], style=pipeline_style, worker_map=get_worker_map()) assert isinstance(pipe[0], Custom) assert isinstance(pipe[1], Custom) @@ -780,18 +811,20 @@ def forward(self, x): assert init_count == 2 -@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="doesn't apply to mpi") @torch_spawn([2]) -def missing_worker_map(): +@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="doesn't apply to mpi") +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def missing_worker_map(pipeline_style): model = nn.Sequential(nn.ReLU(), nn.ReLU()) with pytest.raises(ValueError, match="'PipelineStyle.MultiProcess' requires 'worker_map' to be set"): - Pipe(model, [1, 1], style=Pipe.MultiProcess) + Pipe(model, [1, 1], style=pipeline_style) @torch_spawn([2]) @pytest.mark.skip(reason="currently broken") -def verify_module_duplicate_parameters_on_distinct_partitions(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def verify_module_duplicate_parameters_on_distinct_partitions(pipeline_style): class Surrogate(nn.Module): def __init__(self, module): super().__init__() @@ -802,21 +835,205 @@ def __init__(self, module): # FIXME(tom) can't have duplicate params with separate processes with pytest.raises(ValueError, match="module with duplicate parameters on distinct devices is not supported"): - Pipe(model, [1, 1], style=Pipe.MultiProcess, worker_map=get_worker_map()) + Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map()) @torch_spawn([4]) -def pipelined_backward(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def pipelined_backward(pipeline_style): model = nn.Sequential(nn.ReLU(), nn.ReLU()) destroy_model_parallel() initialize_model_parallel(1, 4) - pipe = Pipe(model, [1, 1], style=Pipe.MultiProcess, worker_map=get_worker_map()) + pipe = Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map()) assert pipe.pipelined_backward is False destroy_model_parallel() initialize_model_parallel(2, 2) - pipe = Pipe(model, [1, 1], 
style=Pipe.MultiProcess, worker_map=get_worker_map()) + pipe = Pipe(model, [1, 1], style=pipeline_style, worker_map=get_worker_map()) assert pipe.pipelined_backward is True + + +@torch_spawn([4]) +def async_event_loop(): + + model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10), nn.ReLU()) + pipe = Pipe(model, [1, 1, 1, 1], style=Pipe.AsyncSchedule, worker_map=get_worker_map(), chunks=10) + + inputs = torch.rand(100, 10) + + output = pipe(inputs) + if pipe.final_stage: + loss = output.mean() + loss.backward() + + +@torch_spawn([4]) +def reuse_lazy(): + if False: # speed + reused = LazyModule(lambda: nn.Linear(10, 10)) + model = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU()] + # model = [reused, reused, nn.Linear(10, 10), nn.ReLU(), reused, reused, nn.ReLU(), reused, reused, nn.ReLU()] + pipe = Pipe(model, [3, 1, 1], style=Pipe.AsyncSchedule, worker_map=get_worker_map()) + pipe.eval() + output = pipe(torch.rand(10)) + + print(f"output on {pipe.group.rank()}, {output}") + torch.distributed.barrier() + + set_random_seed(1234) + # test both foward + reused = nn.Linear(10, 10) + layers = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU()] + model = nn.Sequential(*layers) + model.eval() + + set_random_seed(1234) + # ensure identical weights but no sharing between model and pipe + reused = nn.Linear(10, 10) + layers = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU()] + pipe = Pipe(layers, [3, 1, 1], style=Pipe.AsyncSchedule, worker_map=get_worker_map()) + pipe.eval() + model_optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) + pipe_optimizer = torch.optim.SGD(pipe.parameters(), lr=0.01, momentum=0.9) if len(list(pipe.parameters())) else None + inputs = torch.rand(10) + if False: # speed + model_out = model(inputs) + pipe_out = pipe(inputs) + + torch.distributed.barrier() + + if pipe.final_stage: + assert torch.equal(model_out, pipe_out) + + model.train() + pipe.train() + model_out = model(inputs) + pipe_out = pipe(inputs) + if pipe.final_stage: + pipe_loss = pipe_out.mean() + pipe_loss.backward() + + model_loss = model_out.mean() + model_loss.backward() + + model_optimizer.step() + if pipe_optimizer: + pipe_optimizer.step() + + model.eval() + pipe.eval() + model_out = model(inputs) + pipe_out = pipe(inputs) + + print(f"before barrier on {torch.distributed.get_rank()}") + torch.distributed.barrier() + print(f"after barrier on {torch.distributed.get_rank()}") + + if pipe.final_stage: + assert torch.equal(model_out, pipe_out) + + +def test_instantiate_partition(): + from fairscale.nn.pipe.pipe import instantiate_partition + from fairscale.nn.pipe.async_schedule import Location + + class FakeGroup: + def __init__(self, rank, size): + self._rank = rank + self._size = size + + def rank(self): + return self._rank + + def size(self): + return self._size + + def check_partitions(model, balance, expected_order, expected_ranks): + """Check the instantiated model matches expectation of order and rank + + model: a list of modules or an nn.Sequential + balance: the balance argument to Pipe + expected_order: the index of modules in `model` in the order they will + be executed, grouped by nn.Sequential + expected_rank: the rank that each module will be executed on + """ + + invocations = [] + invocation_wrapper = dict() + + # Collect `Invocation` and `Invocation` -> `ModuleWrapper` mapping from + # instantiated model + for rank in range(len(balance)): + instantiated = instantiate_partition(model, 
balance, FakeGroup(rank, len(balance)), Pipe.AsyncSchedule) + for part in instantiated: + assert isinstance(part.module, nn.Sequential) + for inv in part.invocations: + invocations.append(inv) + invocation_wrapper[inv] = part + + modules = [] + prev = None + current = Location(0, 0) + ranks = [] + + for order, inv in enumerate(sorted(invocations, key=lambda x: x.order)): + # Check integrity of Location chain + assert inv.order == order + assert inv.source == prev + assert inv.this == current + prev = inv.this + current = inv.dest + modules.append(list(invocation_wrapper[inv].module.children())) + ranks.append(inv.this.stage) + + # assert len(modules) == len(expected_order) + for left, right in zip(modules, expected_order): + assert len(left) == len(right), f"{right}" + assert list(map(id, left)) == list(map(id, (model[e] for e in right))), f"{right}" + + assert ranks == expected_ranks + + reused = nn.Linear(20, 20) + model = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU()] + balance = [3, 1, 1] + + check_partitions( + model, balance, expected_order=[[0], [1, 2], [0], [4], [0], [6]], expected_ranks=[0, 0, 0, 1, 0, 2] + ) + + reused2 = nn.Linear(5, 5) + model = [reused, reused2, nn.Linear(10, 10), nn.ReLU(), reused, reused2, nn.ReLU(), reused, reused2, nn.ReLU()] + balance = [4, 1, 1] + + check_partitions( + model, + balance, + expected_order=[[0], [1], [2, 3], [0], [1], [6], [0], [1], [9]], + expected_ranks=[0, 0, 0, 0, 0, 1, 0, 0, 2], + ) + + reused2 = nn.Linear(5, 5) + model = [ + nn.Linear(10, 10), + reused, + nn.Linear(10, 10), + nn.ReLU(), + reused, + reused2, + nn.ReLU(), + reused, + reused2, + nn.ReLU(), + ] + # 0 1 2 3 1 5 6 1 5 9 + balance = [4, 2, 1] + + check_partitions( + model, + balance, + expected_order=[[0], [1], [2, 3], [1], [5], [6], [1], [5], [9]], + expected_ranks=[0, 0, 0, 0, 1, 1, 0, 1, 2], + ) diff --git a/tests/nn/pipe_process/test_rpc.py b/tests/nn/pipe_process/test_rpc.py new file mode 100644 index 000000000..dfc1d9a68 --- /dev/null +++ b/tests/nn/pipe_process/test_rpc.py @@ -0,0 +1,278 @@ +import copy +import os + +import pytest +import torch +from torch import nn +from torch.distributed import rpc + +from fairscale.nn.model_parallel.initialize import get_model_parallel_group, get_pipeline_parallel_group +from fairscale.nn.pipe import PipeRPCWrapper +from tests.nn.model_parallel.commons import get_worker_map, torch_spawn + + +def init_rpc(): + os.environ["MASTER_PORT"] = "10639" + init_method = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}" + rpc.init_rpc( + f"Test{torch.distributed.get_rank()}", + rank=torch.distributed.get_rank(), + world_size=torch.distributed.get_world_size(), + backend=rpc.BackendType.TENSORPIPE, + rpc_backend_options=rpc.TensorPipeRpcBackendOptions(init_method=init_method), + ) + + +@torch_spawn([2]) +@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" not in os.environ, reason="mpi required") +def basic_rpc(): + init_rpc() + if torch.distributed.get_rank() != 0: + rpc.shutdown() + torch.distributed.barrier() + return + + model = [nn.Linear(10, 10), nn.ReLU()] + pipe = PipeRPCWrapper(model, [1, 1], input_device=torch.cuda.current_device(), worker_map=get_worker_map()) + + pipe.foreach_worker(register_optimizer, include_self=True) + + inputs = torch.rand(10).cuda() + output = pipe(inputs) + loss = output.mean() + loss.backward() + + pipe.foreach_worker(step_optimizer, include_self=True) + + pipe.eval() + + rpc.shutdown() + torch.distributed.barrier() + + +def register_optimizer(ctx, model): + if 
len(list(model.parameters())) > 0: + model.optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) + else: + model.optimizer = None + + +def step_optimizer(ctx, model): + if model.optimizer: + model.optimizer.step() + + +def check_pipe_against_reference(balance, model_constructor, checkpoint="except_last"): + model = model_constructor() + reference_model = model_constructor() + for src, dst in zip(model, reference_model): + dst.load_state_dict(copy.deepcopy(src.state_dict())) + + reference_model = nn.Sequential(*reference_model).cuda() + + pipe = PipeRPCWrapper( + model, balance, input_device=torch.cuda.current_device(), worker_map=get_worker_map(), checkpoint=checkpoint + ) + + pipe.foreach_worker(register_optimizer, include_self=True) + register_optimizer(None, reference_model) + + inputs = torch.rand(10).cuda() + target = torch.rand(10).cuda() + cloned = inputs.clone() + output = pipe(inputs) + print(f"out pipe on {torch.distributed.get_rank()}") + mp_group = get_model_parallel_group() + if mp_group.size() > 1: + pass # torch.distributed.barrier(mp_group) + ref_out = reference_model(inputs) + + print(f"out on {torch.distributed.get_rank()}") + print(f"{ref_out}, {output}, {inputs}, {cloned}") + assert torch.equal(ref_out.cpu(), output.cpu()) + + for out in output, ref_out: + if mp_group.size() > 1: + pass # torch.distributed.barrier(mp_group) + try: + target = target.to(out.device) + loss = nn.MSELoss()(out, target) + loss.backward() + except Exception as e: + print(f"loss failed {e}") + raise e + + print(f"{torch.distributed.get_rank()}: optimizer") + pipe.foreach_worker(step_optimizer, include_self=True) + print(f"{torch.distributed.get_rank()}: optimizer2") + step_optimizer(None, reference_model.cuda()) + print(f"{torch.distributed.get_rank()}: eval") + + pipe.eval() + reference_model.eval() + print(f"{torch.distributed.get_rank()}: pipe2") + + final_output = pipe(inputs) + print(f"{torch.distributed.get_rank()}: ref2") + if mp_group.size() > 1: + pass # torch.distributed.barrier(mp_group) + try: + final_ref = reference_model(inputs.cuda()) + except Exception as e: + print(f"ref got {e}") + raise e + + assert torch.equal(final_output.cpu(), final_ref.cpu()) + + +@torch_spawn([3]) +@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" not in os.environ, reason="mpi required") +def rpc_optimizer(): + + init_rpc() + if torch.distributed.get_rank() != 0: + rpc.shutdown() + torch.distributed.barrier() + return + + def model_with_reuse(): + reused_1 = nn.Linear(10, 10) + return [reused_1, nn.ReLU(), reused_1, nn.ReLU(), reused_1, nn.ReLU()] + + print(f"easy") + check_pipe_against_reference( + [2, 2, 2], lambda: [nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10), nn.ReLU()], + ) + print(f"hard") + check_pipe_against_reference([2, 1, 1], model_with_reuse) + + rpc.shutdown() + torch.distributed.barrier() + + +@torch_spawn([6]) +@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" not in os.environ, reason="mpi required") +def rpc_megatron_reuse(): + + from fairscale.nn.model_parallel import layers + from fairscale.nn.model_parallel.initialize import destroy_model_parallel, initialize_model_parallel + + def make_model_simple(): + return [ + layers.ColumnParallelLinear(10, 10), + nn.ReLU(), + layers.RowParallelLinear(10, 10), + nn.ReLU(), + layers.ColumnParallelLinear(10, 10), + nn.ReLU(), + layers.RowParallelLinear(10, 10), + nn.ReLU(), + nn.Linear(10, 10), + nn.ReLU(), + ] + + def make_model_with_reuse(): + column = layers.ColumnParallelLinear(10, 10) + row = 
layers.RowParallelLinear(10, 10) + return [ + column, + nn.ReLU(), + row, + nn.ReLU(), + column, + nn.ReLU(), + row, + nn.ReLU(), + nn.Linear(10, 10), + nn.ReLU(), + ] + + destroy_model_parallel() + torch.distributed.destroy_process_group() + torch.distributed.init_process_group("gloo", rank=int(os.environ["RANK"]), world_size=int(os.environ["WORLD_SIZE"])) + initialize_model_parallel(2, 3, model_parallel_backend="nccl", pipeline_backend="mpi") + + init_rpc() + if get_pipeline_parallel_group().rank() != 0: + rpc.shutdown() + torch.distributed.barrier() + return + + check_pipe_against_reference([4, 4, 2], make_model_simple, "always") + print(f"{torch.distributed.get_rank()} simple returned!") + check_pipe_against_reference([4, 2, 2], make_model_with_reuse) + print(f"{torch.distributed.get_rank()} returned!") + + rpc.shutdown() + torch.distributed.barrier() + + +@torch_spawn([3]) +@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" not in os.environ, reason="mpi required") +def rpc_deadlock(): + reused = nn.Linear(10, 10) + if False: + reused2 = nn.Linear(10, 10) + model = [ + nn.Linear(10, 10), + nn.ReLU(), + nn.Linear(10, 10), + reused2, + nn.ReLU(), + reused, + nn.ReLU(), + reused, + reused2, + nn.ReLU(), + reused, + nn.ReLU(), + ] + balance = [2, 3, 4] + else: + model = [ + nn.Linear(10, 10), + nn.ReLU(), + nn.Linear(10, 10), + nn.ReLU(), + reused, + nn.ReLU(), + reused, + nn.ReLU(), + reused, + nn.ReLU(), + ] + balance = [2, 2, 4] + + init_rpc() + + if torch.distributed.get_rank() != 0: + rpc.shutdown() + torch.distributed.barrier() + return + + pipe = PipeRPCWrapper(model, balance, worker_map=get_worker_map()) + + inputs = torch.rand(10).cuda() + target = torch.rand(10).cuda() + output = pipe(inputs) + nn.MSELoss()(output, target).backward() + output = pipe(inputs) + nn.MSELoss()(output, target).backward() + rpc.shutdown() + torch.distributed.barrier() + + +@torch_spawn([2]) +@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="no mpi") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") +def construct_only_rank_zero(): + model = [nn.Linear(10, 10), nn.ReLU()] + if torch.distributed.get_rank() == 0: + PipeRPCWrapper(model, [1, 1], worker_map=get_worker_map()) + rpc.shutdown() + else: + # Must enter rpc loop to complte PipeRPCWrapper constructor above + rpc.shutdown() + + with pytest.raises(AssertionError): + PipeRPCWrapper(model, [1, 1], worker_map=get_worker_map()) diff --git a/tests/nn/pipe_process/test_transparency.py b/tests/nn/pipe_process/test_transparency.py index ad11ac4bb..a1e509924 100644 --- a/tests/nn/pipe_process/test_transparency.py +++ b/tests/nn/pipe_process/test_transparency.py @@ -27,7 +27,8 @@ @torch_spawn([2]) @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") -def simple_linears(): +@pytest.mark.parametrize("pipeline_style", [Pipe.MultiProcess, Pipe.AsyncSchedule]) +def simple_linears(pipeline_style): def sum_grad(parameters): return sum([p.grad.sum() for p in parameters if p.grad is not None]) @@ -54,19 +55,19 @@ def zero_grad(parameters): zero_grad(model.parameters()) # With Pipe - model = Pipe(model, [2, 2], style=Pipe.MultiProcess, worker_map=get_worker_map(), chunks=4) + model = Pipe(model, [2, 2], style=pipeline_style, worker_map=get_worker_map(), chunks=4) outputs = model(inputs) if model.group.rank() == 1: loss = outputs.mean() loss.backward() - grad_with_pipe = sum_grad(model.pipeline.partitions[0].parameters()) + grad_with_pipe = sum_grad(model.pipeline.mp_partitions[0].module.parameters()) # Both grads 
should be identical. assert torch.allclose(grad_with_pipe, grad_without_pipe[1]) else: model.back_helper(outputs) - grad_with_pipe = sum_grad(model.pipeline.partitions[0].parameters()) + grad_with_pipe = sum_grad(model.pipeline.mp_partitions[0].module.parameters()) # Both grads should be identical. assert torch.allclose(grad_with_pipe, grad_without_pipe[0]) From 699a15a8d7889b4d4c0b2e30454804f554650976 Mon Sep 17 00:00:00 2001 From: Tom Birch Date: Tue, 20 Oct 2020 19:47:29 -0700 Subject: [PATCH 3/8] fixes per review comments --- benchmarks/pipe.py | 31 +++++--------------------- tests/nn/model_parallel/commons.py | 23 ++++++++++--------- tests/nn/model_parallel/test_layers.py | 20 ++++++----------- tests/nn/moe/test_moe_layer.py | 3 +-- 4 files changed, 26 insertions(+), 51 deletions(-) diff --git a/benchmarks/pipe.py b/benchmarks/pipe.py index 21b9b1439..6d3c469ef 100644 --- a/benchmarks/pipe.py +++ b/benchmarks/pipe.py @@ -1,6 +1,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. import argparse +import logging import math import os import time @@ -16,7 +17,6 @@ import torchtext from torchtext.data.utils import get_tokenizer -# from deepspeed.pipe import PipelineModule from fairscale.nn import Pipe from fairscale.nn.model_parallel import initialize_model_parallel from fairscale.nn.model_parallel.initialize import get_data_parallel_group, get_pipeline_parallel_group @@ -25,12 +25,6 @@ from fairscale.optim.oss import OSS from tests.nn.model_parallel.commons import dist_init, get_worker_map -try: - import torch_ucc # noqa: F401 -except ImportError as e: - print(f"can't import torch_ucc: {e}") - pass - try: from fairscale.optim import Adam # type: ignore @@ -291,15 +285,15 @@ def train(lm_dataloader, model, criterion, optimizer, vocab_size, args): if model.group: total = torch.Tensor([num_params]).cuda() torch.distributed.all_reduce(total, group=model.group) - print( + logging.info( f"training model, #prams = {num_params}, group: {model.group.rank()}, grank:" f" {torch.distributed.get_rank()}, sizes {model.group.size()}" ) torch.distributed.barrier() if model.group.rank() == 0: - print(f"total #prams = {total.item()}") + logging.info(f"total #prams = {total.item()}") else: - print(f"training model, #prams = {num_params}") + logging.info(f"training model, #prams = {num_params}") vocab_size = 10000 # FIXME total_loss = 0.0 start_time = time.time() @@ -326,9 +320,6 @@ def get_last_device(model): pipe_group = model.group - if pipe_group is None or pipe_group.rank() == 0: - print(f">> Init DDP") - if args.ddp_zero: model = DDP( model, @@ -337,9 +328,6 @@ def get_last_device(model): find_unused_parameters=False, ) - if pipe_group is None or pipe_group.rank() == 0: - print(f"<< Init DDP") - if pipe_group and pipe_group.rank() != 0 and pipe_group.rank() != (pipe_group.size() - 1): thing = {"input": torch.zeros(args.batch_size)} @@ -354,7 +342,6 @@ def __len__(self): for i, batch in enumerate(lm_dataloader): bi = batch["input"] - # print(f"batch size: {torch.numel(bi)}, {bi.size()}, {bi.device}") if args.max_batch and i > args.max_batch: break optimizer.zero_grad() @@ -365,18 +352,12 @@ def __len__(self): else: output = model(batch["input"]) except Exception as e: - print(f"exception while training on rank {torch.distributed.get_rank()}") raise RuntimeError(f"training failed on {torch.distributed.get_rank()}") from e if pipe_group is None or pipe_group.rank() == pipe_group.size() - 1: - if True: - target = batch["target"].to(get_last_device(model)) - output = 
output.to(target.device) - else: - target = batch["target"].cpu() - output = output.cpu() + target = batch["target"].to(get_last_device(model)) + output = output.to(target.device) - print(f"output size is {output.size()}") loss = criterion(output.view(-1, vocab_size), target.view(-1)) if args.ddp_zero: ddp_group = get_data_parallel_group() diff --git a/tests/nn/model_parallel/commons.py b/tests/nn/model_parallel/commons.py index f4bc34250..11428771c 100644 --- a/tests/nn/model_parallel/commons.py +++ b/tests/nn/model_parallel/commons.py @@ -36,12 +36,6 @@ from fairscale.nn.model_parallel import initialize_model_parallel from fairscale.nn.model_parallel.random import model_parallel_cuda_manual_seed -try: - import torch_ucc # noqa: F401 -except ImportError as e: - print(f"can't import torch_ucc: {e}") - pass - class IdentityLayer(torch.nn.Module): def __init__(self, size, scale=1.0): @@ -107,16 +101,19 @@ def spawn_for_all_world_sizes(test_func, world_sizes=get_world_sizes(), args=[]) mp.spawn(test_func, args=(world_size, *args), nprocs=world_size, join=True) -def helper(rank, world_size, func, args, error_queue): +def worker_process(rank, world_size, func, args, error_queue): + """Main function for unit tests launced with torch_spawn""" + dist_init(rank, world_size) kwargs = {} if "OMPI_COMM_WORLD_RANK" not in os.environ: kwargs["pipeline_backend"] = "gloo" - print(f"init glooo backend") initialize_model_parallel(1, world_size, **kwargs) try: func(*args) except BaseException as e: + # If the function raises 'Skipped', this indicates pytest.skip(), so + # forward it to parent so we can call pytest.skip() there if e.__class__.__name__ == "Skipped": error_queue.put(str(e)) return @@ -127,7 +124,9 @@ def torch_spawn(world_sizes=None): if world_sizes is None: world_sizes = get_world_sizes() - def fixer(func): + def prepare_test(func): + """Function called with the test function as the argument. 
Generates a + replacement which serves as the actual test function.""" name = func.__name__ parameters = inspect.signature(func).parameters @@ -160,15 +159,17 @@ def replacement(*args, **kwargs): else: pytest.skip(f"requested world size doesn't match current world size") else: - spawn_for_all_world_sizes(helper, world_sizes, (func, args, error_queue)) + spawn_for_all_world_sizes(worker_process, world_sizes, (func, args, error_queue)) if not error_queue.empty(): msg = error_queue.get() pytest.skip(msg) + # Register a function with the same name, prefixed with "test_" in the + # calling module, so it will be picked up by pytest caller_module = inspect.getmodule(inspect.currentframe().f_back) setattr(caller_module, f"test_{name}", replacement) return func - return fixer + return prepare_test diff --git a/tests/nn/model_parallel/test_layers.py b/tests/nn/model_parallel/test_layers.py index 314321a7d..835d731ee 100644 --- a/tests/nn/model_parallel/test_layers.py +++ b/tests/nn/model_parallel/test_layers.py @@ -307,10 +307,6 @@ def run_test_row_parallel_linear(rank, model_parallel_size): print(" >> passed the test :-)") -def a_barrier(): - torch.distributed.barrier() - - def run_test_pipe(rank, world_size, skip_dist_init=False): pipe_world_size = 2 @@ -417,17 +413,17 @@ def forward_model(model_, target, step=False): reference_output = forward_model(reference, target) error = reference_output.sub(output).max() - a_barrier() + torch.distributed.barrier() assert error < 1.0e-6 output = forward_model(model, target) error = reference_output.sub(output).max() - a_barrier() + torch.distributed.barrier() assert error < 1.0e-6 output = forward_model(model, target) error = reference_output.sub(output).max() - a_barrier() + torch.distributed.barrier() assert error < 1.0e-6 check_weights(model, reference, "before") @@ -456,7 +452,7 @@ def forward_model(model_, target, step=False): chunks=chunk_size, pipelined_backward=True, ).cuda() - a_barrier() + torch.distributed.barrier() pipe_rank = torch.distributed.get_rank(group=mpu.get_pipeline_parallel_group()) print(f"pipe rank is {pipe_rank}") if pipe_rank == 0: @@ -479,7 +475,6 @@ def forward_model(model_, target, step=False): else: check_weights(model, reference, "pre-pipe", index=0) - # pipe_mode.eval() pipe_output = pipe_model(identity()) print(f"exited pipe for {rank}") forward_model(reference, target, step=True) @@ -487,8 +482,7 @@ def forward_model(model_, target, step=False): print(f"pipe_output {rank} = {pipe_output}") print(f"reference_output {rank} = {reference_output}") - if style == Pipe.MultiProcess: - a_barrier() + torch.distributed.barrier() if torch.distributed.get_rank(mpu.get_pipeline_parallel_group()) == 1: error = reference_output.sub(pipe_output.cuda()).max() @@ -533,7 +527,7 @@ def forward_model(model_, target, step=False): print(f"waiting for barrier on slave") pipe_model.zero_grad() - a_barrier() + torch.distributed.barrier() pipe_model.eval() pipe_output = pipe_model(identity()) @@ -542,7 +536,7 @@ def forward_model(model_, target, step=False): error = updated_ref_output.sub(pipe_output.cuda()).max() print(f"outputs are ref:\n{updated_ref_output}\npipe:\n{pipe_output}") assert error < 1.0e-6 - a_barrier() + torch.distributed.barrier() print(f"finished waiting for barrier on, pid={os.getpid()}") diff --git a/tests/nn/moe/test_moe_layer.py b/tests/nn/moe/test_moe_layer.py index 77b4b1c3d..4dc95b4cd 100644 --- a/tests/nn/moe/test_moe_layer.py +++ b/tests/nn/moe/test_moe_layer.py @@ -34,8 +34,7 @@ def setup_module(module): def 
teardown_module(module): - if "OMPI_COMM_WORLD_SIZE" not in os.environ: - torch.distributed.destroy_process_group() + torch.distributed.destroy_process_group() @pytest.mark.parametrize("device", devices) From cedb218e6524086d5a578f316a986b037ba271c8 Mon Sep 17 00:00:00 2001 From: Tom Birch Date: Wed, 21 Oct 2020 13:38:43 -0700 Subject: [PATCH 4/8] refactor message --- fairscale/nn/pipe/async_schedule.py | 2 +- fairscale/nn/pipe/messages.py | 48 +++++++----- fairscale/nn/pipe/pipeline.py | 26 ++----- fairscale/nn/pipe/rpc.py | 111 +++++++++++----------------- fairscale/nn/pipe/types.py | 17 ++++- tests/nn/model_parallel/commons.py | 6 +- 6 files changed, 99 insertions(+), 111 deletions(-) diff --git a/fairscale/nn/pipe/async_schedule.py b/fairscale/nn/pipe/async_schedule.py index 5d3e72869..94557735d 100644 --- a/fairscale/nn/pipe/async_schedule.py +++ b/fairscale/nn/pipe/async_schedule.py @@ -133,7 +133,7 @@ def backward(ctx, *grad: Tensor,) -> Tuple[Optional[Tensor], ...]: ranks = get_pipeline_parallel_ranks() this_rank = torch.distributed.get_rank() dprint(f"AsyncRecvOperator back {this_rank} {len(grad)}, {ctx.args}") - # Note that dst/source are swaped coz in backward pass, maybe abstract + # Note that dst/source are swaped due to backward pass, maybe abstract # this out? body = AsyncMessageBody( AsyncMessageType.Gradients, ctx.index, source=ctx.args.dest, dest=ctx.args.source, order=ctx.args.order - 1 diff --git a/fairscale/nn/pipe/messages.py b/fairscale/nn/pipe/messages.py index b9fa63969..4d803a507 100644 --- a/fairscale/nn/pipe/messages.py +++ b/fairscale/nn/pipe/messages.py @@ -53,7 +53,7 @@ def rpc_push_queue(message: PipeMessage) -> None: globals()["MessageQueues"][message.queue_name].put(message) -def send_message(config: TransportConfig, message: PipeMessage, sync: bool = False) -> None: +def send_message(config: TransportConfig, message: PipeMessage, sync: bool = False, skip_header: bool = False) -> None: if config.use_rpc: message.tensors = tuple(t.cpu() for t in message.tensors) assert config.worker_map @@ -65,12 +65,13 @@ def send_message(config: TransportConfig, message: PipeMessage, sync: bool = Fal else: tensors = message.tensors message.tensors = tuple() - message.tensor_shapes = [t.size() for t in tensors] - message.tensor_dtypes = [t.dtype for t in tensors] torch.cuda.current_stream().synchronize() - torch.distributed.send( - pyobject_to_tensor(message), message.dest, tag=message.queue_name, group=get_pipeline_parallel_group() - ) + if not skip_header: + message.tensor_shapes = [t.size() for t in tensors] + message.tensor_dtypes = [t.dtype for t in tensors] + torch.distributed.send( + pyobject_to_tensor(message), message.dest, tag=message.queue_name, group=get_pipeline_parallel_group() + ) for index, t in enumerate(tensors): if t.device.type == "cpu": t = t.cuda() @@ -79,6 +80,26 @@ def send_message(config: TransportConfig, message: PipeMessage, sync: bool = Fal ) +def recv_message_tensors(input_device: InputDevice, config: TransportConfig, message: PipeMessage) -> PipeMessage: + if config.use_rpc: + # Tensors already contained within message + message.tensors = to_input_device(message.tensors, input_device) + return message + else: + torch.cuda.current_stream().synchronize() + + message_tensors = [] + for index, (shape, dtype) in enumerate(zip(message.tensor_shapes, message.tensor_dtypes)): + t = torch.empty(*shape, dtype=dtype, device=input_device) + torch.distributed.recv(t, message.src, tag=message.tag + index, group=get_pipeline_parallel_group()) + 
message_tensors.append(t) + + message.tensors = tuple(message_tensors) + + torch.cuda.current_stream().synchronize() + return message + + def recv_message( config: TransportConfig, queue_name: int, *, nowait: bool = False, input_device: InputDevice = None ) -> PipeMessage: @@ -88,8 +109,7 @@ def recv_message( result = queue.get_nowait() else: result = queue.get() - result.tensors = to_input_device(result.tensors, input_device) - return result + return recv_message_tensors(input_device, config, result) else: # FIXME(handle nowait) if nowait: @@ -101,17 +121,7 @@ def recv_message( torch.cuda.current_stream().synchronize() message = tensor_to_pyobject(tensor.cpu()) - message_tensors = [] - for index, (shape, dtype) in enumerate(zip(message.tensor_shapes, message.tensor_dtypes)): - t = torch.empty(*shape, dtype=dtype, device=input_device) - torch.distributed.recv(t, message.src, tag=message.tag + index, group=get_pipeline_parallel_group()) - message_tensors.append(t) - - message.tensors = tuple(message_tensors) - # print(f"<<< recv:{torch.distributed.get_rank()}") - - torch.cuda.current_stream().synchronize() - return message + return recv_message_tensors(input_device, config, message) def get_out_of_order(config: TransportConfig, queue_name: int, index: int, *, input_device: InputDevice) -> Tensors: diff --git a/fairscale/nn/pipe/pipeline.py b/fairscale/nn/pipe/pipeline.py index b53c84b6b..cc18be600 100644 --- a/fairscale/nn/pipe/pipeline.py +++ b/fairscale/nn/pipe/pipeline.py @@ -17,6 +17,7 @@ # limitations under the License. """The pipeline parallelism of Pipe.""" +import logging import os from queue import Empty as QueueEmpty from queue import Queue @@ -56,19 +57,9 @@ __all__: List[str] = [] - ExcInfo = Tuple[Type[BaseException], BaseException, TracebackType] -debug_fd = None - - -def dprint(s: str) -> None: - print(f"{time.monotonic()}: {s}") - # debug_fd.write(f"{time.monotonic()}: {s}\n") - # debug_fd.flush() - - class SendOperator(torch.autograd.Function): """Send activations to the next pipeline stage""" @@ -309,8 +300,6 @@ def run(self, training: bool, batches: List[Batch], event: Optional[Event]) -> N elif self.style is PipelineStyle.AsyncSchedule: assert self.group rank = self.group.rank() - dprint(f">>> start barrier") - # torch.distributed.barrier(group=self.group) event_loop = AsyncEventLoop( self.mp_partitions, self.group, @@ -319,19 +308,18 @@ def run(self, training: bool, batches: List[Batch], event: Optional[Event]) -> N self.input_device, self.checkpoint_stop, ) - dprint(f"<<< start barrier") if rank == 0 and not self.final_stage: - dprint(f"{torch.distributed.get_rank()}: entered event head") + logging.debug(f"{torch.distributed.get_rank()}: entered event head") event_loop.event_loop_head(batches, skip_trackers, event) - dprint(f"{torch.distributed.get_rank()}: exited event head") + logging.debug(f"{torch.distributed.get_rank()}: exited event head") elif self.final_stage: - dprint(f"{torch.distributed.get_rank()}: entered event tail") + logging.debug(f"{torch.distributed.get_rank()}: entered event tail") event_loop.event_loop_tail(batches, skip_trackers) - dprint(f"{torch.distributed.get_rank()}: exited event tail") + logging.debug(f"{torch.distributed.get_rank()}: exited event tail") else: - dprint(f"{torch.distributed.get_rank()}: entered event loop") + logging.debug(f"{torch.distributed.get_rank()}: entered event loop") event_loop.event_loop(len(batches), skip_trackers) - dprint(f"{torch.distributed.get_rank()}: exited event loop") + 
logging.debug(f"{torch.distributed.get_rank()}: exited event loop") self.callcount += 1 diff --git a/fairscale/nn/pipe/rpc.py b/fairscale/nn/pipe/rpc.py index bfc16cfe5..9fdd9ae89 100644 --- a/fairscale/nn/pipe/rpc.py +++ b/fairscale/nn/pipe/rpc.py @@ -10,7 +10,8 @@ from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group from . import Pipe -from .types import TensorOrTensors +from .messages import recv_message_tensors, send_message +from .types import EVENT_LOOP_QUEUE, PipeMessage, TensorOrTensors, TransportConfig DEFAULT_MAX_SOURCE_POSITIONS = 1024 DEFAULT_MAX_TARGET_POSITIONS = 1024 @@ -70,7 +71,6 @@ def model_forward(training: bool, shape: torch.Size, dtype: torch.dtype) -> Opti dprint(f"mf: train stage {model.group.rank()}, {os.getpid()}") model.train(training) result = model(tensor) - torch.cuda.current_stream().synchronize() if model.final_stage: globals()["PipeResult"] = result return (get_shapes(result), get_dtype(result)) @@ -84,7 +84,7 @@ def model_forward(training: bool, shape: torch.Size, dtype: torch.dtype) -> Opti return None -def send_result(training: bool) -> None: +def send_result(training: bool, message: PipeMessage, grads_message: PipeMessage) -> None: dprint(f"send result {training}") group = get_pipeline_parallel_group() set_device_based_on_group(group) @@ -97,76 +97,53 @@ def send_result(training: bool) -> None: result = [result] dest = _get_global_rank(group, 0) + print( f"ho har {torch.distributed.get_rank()} " + str([_get_global_rank(group, i) for i in range(group.size())]) ) - torch.cuda.current_stream().synchronize() - for r in result: - dprint(f">>> send {torch.distributed.get_rank()}, {dest} {r.shape}, {r.dtype}, {r.device}") - if "Gloo" in group.__class__.__name__: - r = r.cpu() - torch.distributed.send(r.contiguous(), dest, group=group) - dprint(f"<<< send {torch.distributed.get_rank()}, {dest}") - torch.cuda.current_stream().synchronize() + message.tensors = tuple(result) + config = TransportConfig(False, None) + send_message(config, message, sync=False, skip_header=True) if training: - grads = [] - for r in result: - g = torch.empty(r.shape).cuda() - dprint(f">>> recv grads {g.shape}") - torch.cuda.current_stream().synchronize() - torch.distributed.recv(g, dest, group=group) - torch.cuda.current_stream().synchronize() - if "Gloo" in group.__class__.__name__: - g = g.cuda() - dprint(f"<<< recv grads {g.shape}") - grads.append(g) + grads_message.tensor_shapes = [r.shape for r in result] + grads_message.tensor_dtypes = [r.dtype for r in result] + input_device = torch.device("cuda", torch.cuda.current_device()) + transport_config = TransportConfig(False, None) + grads_message = recv_message_tensors(input_device, transport_config, grads_message) with model.lock: print(f" >>> autograd-backward tail") - torch.autograd.backward(result, tuple(grads), retain_graph=True) + torch.autograd.backward(result, grads_message.tensors, retain_graph=True) print(f" <<< autograd-backward tail") - torch.cuda.synchronize() except Exception as e: print(f"got {e}") -def recv_result(shapes: SizeOrSizes, dtypes: DtypeOrDtypes) -> TensorOrTensors: +def recv_result(shapes: SizeOrSizes, dtypes: DtypeOrDtypes, message: PipeMessage) -> TensorOrTensors: group = get_pipeline_parallel_group() set_device_based_on_group(group) src = torch.distributed.distributed_c10d._get_global_rank(group, group.size() - 1) dprint(f"recv_result... 
{src}, {torch.cuda.current_device()}") + input_device = torch.device("cuda", torch.cuda.current_device()) + transport_config = TransportConfig(False, None) + if isinstance(shapes, torch.Size): shape = cast(torch.Size, shapes) dtype = cast(torch.dtype, dtypes) - t = torch.empty(shape, dtype=dtype).cuda() - dprint(f">>> recv {torch.distributed.get_rank()}, {src} {t.shape}, {t.dtype}") - torch.cuda.current_stream().synchronize() - torch.distributed.recv(t, src, group=group) - torch.cuda.current_stream().synchronize() - if "Gloo" in group.__class__.__name__: - t = t.cuda() - dprint(f"<<< recv {torch.distributed.get_rank()}, {src}") - dprint(f"recvd solo") - return t + message.tensor_shapes = [shape] + message.tensor_dtypes = [dtype] + message = recv_message_tensors(input_device, transport_config, message) + return message.tensors[0] else: - result = [] - torch.cuda.current_stream().synchronize() shapes = cast(List[torch.Size], shapes) dtypes = cast(List[torch.dtype], dtypes) - for s, d in zip(shapes, dtypes): - t = torch.empty(s, dtype=d).cuda() - dprint(f">>> recv {torch.distributed.get_rank()}, {src} {t.shape}, {t.dtype}") - torch.distributed.recv(t, src, group=group) - if "Gloo" in group.__class__.__name__: - t = t.cuda() - dprint(f"<<< recv {torch.distributed.get_rank()}, {src}") - dprint(f"recvd multi / {len(shapes)}") - result.append(t) - torch.cuda.current_stream().synchronize() - return tuple(result) + message.tensor_shapes = shapes + message.tensor_dtypes = dtypes + message = recv_message_tensors(input_device, transport_config, message) + return message.tensors def get_global_ranks_from_group(group: ProcessGroup) -> List[int]: @@ -179,35 +156,26 @@ def run_model(model: Pipe, tensor: TensorOrTensors, event: Event, lock: Lock) -> print(f">> run_model thread {t}") assert model.group set_device_based_on_group(model.group) - torch.cuda.current_stream().synchronize() - torch.cuda.synchronize() model(tensor, event=event) - torch.cuda.synchronize() - torch.cuda.current_stream().synchronize() print(f"<< run_model thread {t}") class PipeBackRedirect(torch.autograd.Function): @staticmethod # type: ignore - def forward(ctx, inputs, dest, event): + def forward(ctx, inputs, dest, event, message): ctx.dest = dest ctx.event = event + ctx.message = message return inputs @staticmethod # type: ignore def backward(ctx, *grad): dprint(f">>> back hook yay") - group = get_pipeline_parallel_group() - torch.cuda.current_stream().synchronize() - for g in grad: - dprint(f">>> back send {g.shape}") - if "Gloo" in group.__class__.__name__: - g = g.cpu() - torch.distributed.send(g, ctx.dest, group=group) - dprint(f"<<< back send") - torch.cuda.current_stream().synchronize() + config = TransportConfig(False, None) + ctx.message.tensors = tuple(grad) + send_message(config, ctx.message, sync=False, skip_header=True) ctx.event.set() dprint(f"<<< back hook yay") return (None, None, None, None) @@ -219,10 +187,8 @@ def callback_with_model(callback: Callable, ctx: Any) -> None: global PipeModel - torch.cuda.current_stream().synchronize() with PipeModel.lock: callback(ctx, PipeModel) - torch.cuda.current_stream().synchronize() class PipeRPCWrapper(nn.Module): @@ -245,7 +211,6 @@ def __init__(self, *args: Any, **kwargs: Any): self.model = Pipe(*args, **kwargs) self.worker_map = kwargs["worker_map"] - print(f"calling rpc {args}, {kwargs}") futures = [ # FIXME get global rank rpc.rpc_async(self.get_rpc_name(rank), register_remote_model, args=(args, kwargs)) @@ -271,6 +236,11 @@ def forward(self, tensor: TensorOrTensors) -> 
TensorOrTensors: # type: ignore shape = get_shapes(tensor) dtype = get_dtype(tensor) + if isinstance(tensor, torch.Tensor): + num_tensors = 1 + else: + num_tensors = len(tensor) + futures = [ rpc.rpc_async(self.get_rpc_name(rank), model_forward, args=(self.model.training, shape, dtype)) for rank in range(1, self.group.size()) @@ -288,10 +258,15 @@ def forward(self, tensor: TensorOrTensors) -> TensorOrTensors: # type: ignore dprint("forward after wait recv") dest_rank = self.group.size() - 1 dest = self.get_rpc_name(dest_rank) + dest_global_rank = _get_global_rank(self.group, dest_rank) + src_global_rank = torch.distributed.get_rank() dprint(f"async to {dest}") - rpc.rpc_async(dest, send_result, args=(self.model.training,)) + queue = EVENT_LOOP_QUEUE + message = PipeMessage(dest_global_rank, src_global_rank, queue_name=queue, tensor_count=num_tensors) + grads_message = PipeMessage(src_global_rank, dest_global_rank, queue_name=queue, tensor_count=num_tensors) + rpc.rpc_async(dest, send_result, args=(self.model.training, message, grads_message)) dprint(">>> recv_result") - result = recv_result(shape, dtype) + result = recv_result(shape, dtype, message) dprint("<<< recv_result") # event.set() dprint("not set event") @@ -302,7 +277,7 @@ def forward(self, tensor: TensorOrTensors) -> TensorOrTensors: # type: ignore for r in result: r.requires_grad_() - applied = PipeBackRedirect.apply(result, _get_global_rank(self.group, dest_rank), event) + applied = PipeBackRedirect.apply(result, _get_global_rank(self.group, dest_rank), event, grads_message) except Exception as e: dprint(f"failed got {e}") dprint("return applied") diff --git a/fairscale/nn/pipe/types.py b/fairscale/nn/pipe/types.py index ae5d09d92..d6859e345 100644 --- a/fairscale/nn/pipe/types.py +++ b/fairscale/nn/pipe/types.py @@ -51,15 +51,26 @@ class PipeMessage: tensor_dtypes: List[torch.dtype] tag: int = 0 - def __init__(self, src: int, dest: int, queue_name: int, args: Any, tensors: Tensors): + def __init__( + self, + src: int, + dest: int, + queue_name: int, + args: Any = None, + tensors: Optional[Tensors] = None, + tensor_count: int = 0, + ): self.src = src self.dest = dest self.queue_name = queue_name self.args = args - self.tensors = tensors + self.tensors = tensors or tuple() self.tensor_shapes = [] self.tensor_dtypes = [] global MessageGeneration self.tag = MessageGeneration - MessageGeneration += len(tensors) + if tensors is None: + MessageGeneration += tensor_count + else: + MessageGeneration += len(self.tensors) diff --git a/tests/nn/model_parallel/commons.py b/tests/nn/model_parallel/commons.py index 11428771c..19f70e31f 100644 --- a/tests/nn/model_parallel/commons.py +++ b/tests/nn/model_parallel/commons.py @@ -155,7 +155,11 @@ def replacement(*args, **kwargs): initialize_model_parallel(1, world_size) torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count()) if world_size in world_sizes: - func(*args) + try: + func(*args) + except BaseException as e: + print(f"got exception {e} from test") + raise e else: pytest.skip(f"requested world size doesn't match current world size") else: From 2d075912a3969e4423989bb185a8acfbece39f00 Mon Sep 17 00:00:00 2001 From: Tom Birch Date: Wed, 21 Oct 2020 14:43:04 -0700 Subject: [PATCH 5/8] cleanup --- fairscale/nn/pipe/async_schedule.py | 55 ++++---------- fairscale/nn/pipe/messages.py | 24 +++++-- fairscale/nn/pipe/pipeline.py | 5 +- fairscale/nn/pipe/rpc.py | 107 +++++++++++++--------------- fairscale/nn/pipe/types.py | 1 + tests/nn/model_parallel/commons.py | 3 + 6 
files changed, 88 insertions(+), 107 deletions(-) diff --git a/fairscale/nn/pipe/async_schedule.py b/fairscale/nn/pipe/async_schedule.py index 94557735d..93bd5e894 100644 --- a/fairscale/nn/pipe/async_schedule.py +++ b/fairscale/nn/pipe/async_schedule.py @@ -9,7 +9,14 @@ from fairscale.nn.model_parallel import get_pipeline_parallel_group, get_pipeline_parallel_ranks -from .messages import MESSAGE_TENSOR_SIZE, MessageQueues, send_message, tensor_to_pyobject, to_input_device +from .messages import ( + MESSAGE_TENSOR_SIZE, + MessageQueues, + recv_message_tensors, + send_message, + tensor_to_pyobject, + to_input_device, +) from .microbatch import Batch from .skip.tracker import SkipTrackerThroughPotals from .types import EVENT_LOOP_QUEUE, InputDevice, PipelineStyle, PipeMessage, Tensors, TransportConfig @@ -81,29 +88,6 @@ def backward(ctx, grad): return ctx.grad_from_pipeline -def recv_async_tensors( - rank: int, input_device: InputDevice, config: TransportConfig, message: PipeMessage -) -> PipeMessage: - if config.use_rpc: - # Tensors already contained within message - message.tensors = to_input_device(message.tensors, input_device) - dprint(f"recv_async_tensors {torch.distributed.get_rank()}, {len(message.tensors)}") - return message - else: - torch.cuda.current_stream().synchronize() - - message_tensors = [] - for index, (shape, dtype) in enumerate(zip(message.tensor_shapes, message.tensor_dtypes)): - t = torch.empty(*shape, dtype=dtype, device=input_device) - torch.distributed.recv(t, message.src, tag=message.tag + index, group=get_pipeline_parallel_group()) - message_tensors.append(t) - - message.tensors = tuple(message_tensors) - - torch.cuda.current_stream().synchronize() - return message - - class AsyncRecvOperator(torch.autograd.Function): """Receive activations to the previous pipeline stage""" @@ -116,7 +100,7 @@ def forward( ctx.config = config ctx.index = message.args.microbatch_index - result = recv_async_tensors(dst_rank, input_device, config, message) + result = recv_message_tensors(config, message) ctx.args = result.args @@ -239,25 +223,12 @@ def async_send_inner( return result def async_grad_inner(self, message: PipeMessage, activations: Activations, invocation: Invocation) -> None: - args: AsyncMessageBody = message.args - if self.transport_config.use_rpc: - recvd_grads = message - else: - recvd_grads = recv_async_tensors( - torch.distributed.get_rank(), self.input_device, self.transport_config, message - ) - - # FIXME tom + recvd_grads = recv_message_tensors(self.transport_config, message) - batch: Batch = activations[invocation.this.index][invocation.order][args.microbatch_index] + batch: Batch = activations[invocation.this.index][invocation.order][message.args.microbatch_index] - try: - batch.tensor.grad_fn.grad_from_pipeline = tuple(recvd_grads.tensors) # type: ignore - batch.tensor.backward(retain_graph=True) - return - except Exception as e: - print(f"hackity fail {e}") - raise e + batch.tensor.grad_fn.grad_from_pipeline = tuple(recvd_grads.tensors) # type: ignore + batch.tensor.backward(retain_graph=True) def process_batch_forward( self, diff --git a/fairscale/nn/pipe/messages.py b/fairscale/nn/pipe/messages.py index 4d803a507..d11b334e3 100644 --- a/fairscale/nn/pipe/messages.py +++ b/fairscale/nn/pipe/messages.py @@ -80,17 +80,31 @@ def send_message(config: TransportConfig, message: PipeMessage, sync: bool = Fal ) -def recv_message_tensors(input_device: InputDevice, config: TransportConfig, message: PipeMessage) -> PipeMessage: +def 
recv_message_header(transport_config: TransportConfig, input_device: InputDevice, queue_name: int) -> PipeMessage: + if transport_config.use_rpc: + queue = MessageQueues[queue_name] + result = queue.get() + result.tensors = to_input_device(result.tensors, input_device) + return result + else: + tensor = torch.empty(MESSAGE_TENSOR_SIZE, dtype=torch.uint8, device=input_device) + torch.cuda.current_stream().synchronize() + torch.distributed.recv(tensor, src=None, tag=queue_name, group=get_pipeline_parallel_group()) + torch.cuda.current_stream().synchronize() + return tensor_to_pyobject(tensor.cpu()) + + +def recv_message_tensors(config: TransportConfig, message: PipeMessage) -> PipeMessage: if config.use_rpc: # Tensors already contained within message - message.tensors = to_input_device(message.tensors, input_device) + message.tensors = to_input_device(message.tensors, config.input_device) return message else: torch.cuda.current_stream().synchronize() message_tensors = [] for index, (shape, dtype) in enumerate(zip(message.tensor_shapes, message.tensor_dtypes)): - t = torch.empty(*shape, dtype=dtype, device=input_device) + t = torch.empty(*shape, dtype=dtype, device=config.input_device) torch.distributed.recv(t, message.src, tag=message.tag + index, group=get_pipeline_parallel_group()) message_tensors.append(t) @@ -109,7 +123,7 @@ def recv_message( result = queue.get_nowait() else: result = queue.get() - return recv_message_tensors(input_device, config, result) + return recv_message_tensors(config, result) else: # FIXME(handle nowait) if nowait: @@ -121,7 +135,7 @@ def recv_message( torch.cuda.current_stream().synchronize() message = tensor_to_pyobject(tensor.cpu()) - return recv_message_tensors(input_device, config, message) + return recv_message_tensors(config, message) def get_out_of_order(config: TransportConfig, queue_name: int, index: int, *, input_device: InputDevice) -> Tensors: diff --git a/fairscale/nn/pipe/pipeline.py b/fairscale/nn/pipe/pipeline.py index cc18be600..936b481ed 100644 --- a/fairscale/nn/pipe/pipeline.py +++ b/fairscale/nn/pipe/pipeline.py @@ -22,7 +22,6 @@ from queue import Empty as QueueEmpty from queue import Queue from threading import Event -import time from types import TracebackType from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Type, Union, cast @@ -242,7 +241,9 @@ def __init__( self.group = group self.training: bool self.transport_config = TransportConfig( - use_rpc=("OMPI_COMM_WORLD_RANK" not in os.environ) or ("FORCE_RPC" in os.environ), worker_map=worker_map + use_rpc=("OMPI_COMM_WORLD_RANK" not in os.environ) or ("FORCE_RPC" in os.environ), + worker_map=worker_map, + input_device=input_device, ) self.input_device = input_device diff --git a/fairscale/nn/pipe/rpc.py b/fairscale/nn/pipe/rpc.py index 9fdd9ae89..a4408ef77 100644 --- a/fairscale/nn/pipe/rpc.py +++ b/fairscale/nn/pipe/rpc.py @@ -1,4 +1,4 @@ -import os +import logging from threading import Event, Lock, Thread from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast @@ -11,12 +11,13 @@ from . 
import Pipe from .messages import recv_message_tensors, send_message -from .types import EVENT_LOOP_QUEUE, PipeMessage, TensorOrTensors, TransportConfig +from .types import EVENT_LOOP_QUEUE, PipeMessage, TensorOrTensors DEFAULT_MAX_SOURCE_POSITIONS = 1024 DEFAULT_MAX_TARGET_POSITIONS = 1024 PipeModel: Pipe +PipeResult: TensorOrTensors SizeOrSizes = Union[torch.Size, List[torch.Size]] @@ -40,7 +41,8 @@ def register_remote_model(args: List[Any], kwargs: Dict[str, Any]) -> None: kwargs["input_device"] = torch.device("cuda", torch.cuda.current_device()) model = Pipe(*args, **kwargs) model.cuda() - globals()["PipeModel"] = model + global PipeModel + PipeModel = model def get_shapes(tensor: TensorOrTensors) -> SizeOrSizes: @@ -58,28 +60,22 @@ def get_dtype(tensor: TensorOrTensors) -> DtypeOrDtypes: def model_forward(training: bool, shape: torch.Size, dtype: torch.dtype) -> Optional[Tuple[SizeOrSizes, DtypeOrDtypes]]: - try: - dprint(f"mf: train stage {torch.distributed.get_rank()}") - if isinstance(shape, torch.Size): - tensor = torch.empty(shape, dtype=dtype) - else: - tensor = tuple([torch.empty(s, dtype=d) for s, d in zip(shape, dtype)]) - - model = globals()["PipeModel"] - set_device_based_on_group(model.group) + if isinstance(shape, torch.Size): + tensor = torch.empty(shape, dtype=dtype) + else: + tensor = tuple([torch.empty(s, dtype=d) for s, d in zip(shape, dtype)]) - dprint(f"mf: train stage {model.group.rank()}, {os.getpid()}") - model.train(training) - result = model(tensor) - if model.final_stage: - globals()["PipeResult"] = result - return (get_shapes(result), get_dtype(result)) - except Exception as e: - print(f"failboat {e} {type(e)}") - import traceback + global PipeModel + model = PipeModel + assert model.group + set_device_based_on_group(model.group) - print(f"format {traceback.format_exc()}") - raise e + model.train(training) + result = model(tensor) + if model.final_stage: + global PipeResult + PipeResult = result + return (get_shapes(result), get_dtype(result)) return None @@ -90,11 +86,12 @@ def send_result(training: bool, message: PipeMessage, grads_message: PipeMessage set_device_based_on_group(group) try: dprint(f"send result {torch.distributed.get_rank()}, {torch.cuda.current_device()}") - result = globals()["PipeResult"] - model = globals()["PipeModel"] + result = PipeResult + global PipeModel + model = PipeModel if isinstance(result, torch.Tensor): - result = [result] + result = tuple([result]) dest = _get_global_rank(group, 0) @@ -102,15 +99,14 @@ def send_result(training: bool, message: PipeMessage, grads_message: PipeMessage f"ho har {torch.distributed.get_rank()} " + str([_get_global_rank(group, i) for i in range(group.size())]) ) message.tensors = tuple(result) - config = TransportConfig(False, None) + assert model.pipeline + config = model.pipeline.transport_config send_message(config, message, sync=False, skip_header=True) if training: grads_message.tensor_shapes = [r.shape for r in result] grads_message.tensor_dtypes = [r.dtype for r in result] - input_device = torch.device("cuda", torch.cuda.current_device()) - transport_config = TransportConfig(False, None) - grads_message = recv_message_tensors(input_device, transport_config, grads_message) + grads_message = recv_message_tensors(config, grads_message) with model.lock: print(f" >>> autograd-backward tail") @@ -121,28 +117,28 @@ def send_result(training: bool, message: PipeMessage, grads_message: PipeMessage print(f"got {e}") -def recv_result(shapes: SizeOrSizes, dtypes: DtypeOrDtypes, message: PipeMessage) -> 
TensorOrTensors: +def recv_result(model: Pipe, shapes: SizeOrSizes, dtypes: DtypeOrDtypes, message: PipeMessage) -> TensorOrTensors: group = get_pipeline_parallel_group() set_device_based_on_group(group) src = torch.distributed.distributed_c10d._get_global_rank(group, group.size() - 1) dprint(f"recv_result... {src}, {torch.cuda.current_device()}") - input_device = torch.device("cuda", torch.cuda.current_device()) - transport_config = TransportConfig(False, None) + assert model.pipeline + transport_config = model.pipeline.transport_config if isinstance(shapes, torch.Size): shape = cast(torch.Size, shapes) dtype = cast(torch.dtype, dtypes) message.tensor_shapes = [shape] message.tensor_dtypes = [dtype] - message = recv_message_tensors(input_device, transport_config, message) + message = recv_message_tensors(transport_config, message) return message.tensors[0] else: shapes = cast(List[torch.Size], shapes) dtypes = cast(List[torch.dtype], dtypes) message.tensor_shapes = shapes message.tensor_dtypes = dtypes - message = recv_message_tensors(input_device, transport_config, message) + message = recv_message_tensors(transport_config, message) return message.tensors @@ -163,22 +159,23 @@ def run_model(model: Pipe, tensor: TensorOrTensors, event: Event, lock: Lock) -> class PipeBackRedirect(torch.autograd.Function): @staticmethod # type: ignore - def forward(ctx, inputs, dest, event, message): + def forward(ctx, inputs, dest, event, message, transport_config): ctx.dest = dest ctx.event = event ctx.message = message + ctx.transport_config = transport_config return inputs @staticmethod # type: ignore def backward(ctx, *grad): - dprint(f">>> back hook yay") - config = TransportConfig(False, None) + logging.debug(f">>> PipeBackRedirect.backward") + global PipeModel ctx.message.tensors = tuple(grad) - send_message(config, ctx.message, sync=False, skip_header=True) + send_message(ctx.transport_config, ctx.message, sync=False, skip_header=True) ctx.event.set() - dprint(f"<<< back hook yay") - return (None, None, None, None) + logging.debug(f"<<< PipeBackRedirect.backward") + return (None, None, None, None, None) def callback_with_model(callback: Callable, ctx: Any) -> None: @@ -186,7 +183,6 @@ def callback_with_model(callback: Callable, ctx: Any) -> None: set_device_based_on_group(group) global PipeModel - with PipeModel.lock: callback(ctx, PipeModel) @@ -253,7 +249,6 @@ def forward(self, tensor: TensorOrTensors) -> TensorOrTensors: # type: ignore t = Thread(target=run_model, args=(self.model, tensor, event, self.lock)) t.start() - dprint("forward before wait recv") shape, dtype = futures[-1].wait() dprint("forward after wait recv") dest_rank = self.group.size() - 1 @@ -266,22 +261,18 @@ def forward(self, tensor: TensorOrTensors) -> TensorOrTensors: # type: ignore grads_message = PipeMessage(src_global_rank, dest_global_rank, queue_name=queue, tensor_count=num_tensors) rpc.rpc_async(dest, send_result, args=(self.model.training, message, grads_message)) dprint(">>> recv_result") - result = recv_result(shape, dtype, message) + result = recv_result(self.model, shape, dtype, message) dprint("<<< recv_result") - # event.set() - dprint("not set event") - try: - if isinstance(result, torch.Tensor): - result.requires_grad_() - else: - for r in result: - r.requires_grad_() - - applied = PipeBackRedirect.apply(result, _get_global_rank(self.group, dest_rank), event, grads_message) - except Exception as e: - dprint(f"failed got {e}") - dprint("return applied") - return applied + if isinstance(result, torch.Tensor): + 
result.requires_grad_() + else: + for r in result: + r.requires_grad_() + + assert self.model.pipeline + return PipeBackRedirect.apply( + result, dest_global_rank, event, grads_message, self.model.pipeline.transport_config + ) @property def final_stage(self) -> bool: diff --git a/fairscale/nn/pipe/types.py b/fairscale/nn/pipe/types.py index d6859e345..b418df9ba 100644 --- a/fairscale/nn/pipe/types.py +++ b/fairscale/nn/pipe/types.py @@ -38,6 +38,7 @@ class PipelineStyle(Enum): class TransportConfig: use_rpc: bool worker_map: Optional[Dict[int, str]] + input_device: InputDevice @dataclass(init=False) diff --git a/tests/nn/model_parallel/commons.py b/tests/nn/model_parallel/commons.py index 19f70e31f..26da89753 100644 --- a/tests/nn/model_parallel/commons.py +++ b/tests/nn/model_parallel/commons.py @@ -159,6 +159,9 @@ def replacement(*args, **kwargs): func(*args) except BaseException as e: print(f"got exception {e} from test") + import traceback + + print(f"{traceback.format_exc()}") raise e else: pytest.skip(f"requested world size doesn't match current world size") From 151263ec3e17af1c288cd546d699b405c1529d59 Mon Sep 17 00:00:00 2001 From: Tom Birch Date: Thu, 22 Oct 2020 10:22:56 -0700 Subject: [PATCH 6/8] more review comments --- benchmarks/pipe.py | 12 +- fairscale/nn/model_parallel/initialize.py | 2 - fairscale/nn/model_parallel/mappings.py | 10 - fairscale/nn/pipe/async_schedule.py | 261 +++++++++------------- fairscale/nn/pipe/messages.py | 196 ++++++++-------- fairscale/nn/pipe/pipe.py | 13 -- fairscale/nn/pipe/pipeline.py | 70 ++---- fairscale/nn/pipe/rpc.py | 113 ++++------ fairscale/nn/pipe/types.py | 14 +- fairscale/utils/object.py | 29 +++ tests/nn/pipe_process/test_pipe.py | 4 +- tests/nn/pipe_process/test_rpc.py | 39 +--- 12 files changed, 307 insertions(+), 456 deletions(-) create mode 100644 fairscale/utils/object.py diff --git a/benchmarks/pipe.py b/benchmarks/pipe.py index 6d3c469ef..4dcf00868 100644 --- a/benchmarks/pipe.py +++ b/benchmarks/pipe.py @@ -530,13 +530,10 @@ def bench_single_process(args): blob = make_model_and_data(args, None, new_data=new_data) model = blob["model"] - if args.deepspeed: - p = PipelineModule(layers=model, num_stages=min(num_devices, 4)) - else: - balance = generate_balance(min(num_devices, 4), len(model)) - p = pipe.Pipe( - model, balance, chunks=args.chunks, pipelined_backward=args.pipelined_backward, checkpoint=args.checkpoint - ) + balance = generate_balance(min(num_devices, 4), len(model)) + p = pipe.Pipe( + model, balance, chunks=args.chunks, pipelined_backward=args.pipelined_backward, checkpoint=args.checkpoint + ) del model del blob["model"] @@ -665,7 +662,6 @@ def bench_mpi(args): parser.add_argument("--socket-name", type=str, default=None, help="socket ifname for gloo/tp") parser.add_argument("--num-decoder-layers", type=int, default=10, help="Number of decoder layers in the model") parser.add_argument("--ddp-zero", action="store_true", default=False, help="enable ddp") -parser.add_argument("--deepspeed", action="store_true", default=False, help="use eepspeed instead of fairscale pipe") parser.add_argument( "--lazy-construction", action="store_true", default=False, help="Number of decoder layers in the model" ) diff --git a/fairscale/nn/model_parallel/initialize.py b/fairscale/nn/model_parallel/initialize.py index fb488fa1f..c42a61dcc 100644 --- a/fairscale/nn/model_parallel/initialize.py +++ b/fairscale/nn/model_parallel/initialize.py @@ -64,8 +64,6 @@ def initialize_model_parallel( with a total of 16 GPUs, rank 0 to 7 belong to 
the first box and ranks 8 to 15 belong to the second box. """ - # if torch.distributed.get_rank() == 0: - # print("> initializing model parallel with size {}".format(model_parallel_size_)) # Get world size and rank. Ensure some consistencies. assert torch.distributed.is_initialized() world_size = torch.distributed.get_world_size() diff --git a/fairscale/nn/model_parallel/mappings.py b/fairscale/nn/model_parallel/mappings.py index c91c01cf5..78d0961c5 100644 --- a/fairscale/nn/model_parallel/mappings.py +++ b/fairscale/nn/model_parallel/mappings.py @@ -39,9 +39,7 @@ def _reduce(ctx: Any, input_: torch.Tensor) -> torch.Tensor: return input_ # All-reduce. - print(f">> doing all_reduce on {torch.distributed.get_rank()}") torch.distributed.all_reduce(input_, group=group) - print(f"<< doing all_reduce on {torch.distributed.get_rank()}") return input_ @@ -94,12 +92,10 @@ class _CopyToModelParallelRegion(torch.autograd.Function): @staticmethod def forward(ctx, input_): # type: ignore - print(f"{torch.distributed.get_rank()}: _CopyToModelParallelRegion Forward") return input_ @staticmethod def backward(ctx, grad_output): # type: ignore - print(f"{torch.distributed.get_rank()}: _CopyToModelParallelRegion Backward") return _reduce(None, grad_output) @@ -108,12 +104,10 @@ class _ReduceFromModelParallelRegion(torch.autograd.Function): @staticmethod def forward(ctx, input_): # type: ignore - print(f"{torch.distributed.get_rank()}: _ReduceFromModelParallelRegion Forward") return _reduce(ctx, input_) @staticmethod def backward(ctx, grad_output): # type: ignore - print(f"{torch.distributed.get_rank()}: _ReduceFromModelParallelRegion Backward") return grad_output @@ -122,12 +116,10 @@ class _ScatterToModelParallelRegion(torch.autograd.Function): @staticmethod def forward(ctx, input_): # type: ignore - print(f"{torch.distributed.get_rank()}: _ScatterToModelParallelRegion Forward") return _split(input_) @staticmethod def backward(ctx, grad_output): # type: ignore - print(f"{torch.distributed.get_rank()}: _ScatterToModelParallelRegion Backward") return _gather(grad_output) @@ -136,12 +128,10 @@ class _GatherFromModelParallelRegion(torch.autograd.Function): @staticmethod def forward(ctx, input_): # type: ignore - print(f"{torch.distributed.get_rank()}: _GatherFromModelParallelRegion Forward") return _gather(input_) @staticmethod def backward(ctx, grad_output): # type: ignore - print(f"{torch.distributed.get_rank()}: _GatherFromModelParallelRegion Backward") return _split(grad_output) diff --git a/fairscale/nn/pipe/async_schedule.py b/fairscale/nn/pipe/async_schedule.py index 93bd5e894..d999d1664 100644 --- a/fairscale/nn/pipe/async_schedule.py +++ b/fairscale/nn/pipe/async_schedule.py @@ -1,3 +1,9 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ +from collections import OrderedDict from enum import Enum, auto from threading import Event from typing import Dict, Iterable, List, Optional, Tuple @@ -7,25 +13,12 @@ from torch import Tensor, nn from torch.distributed import ProcessGroup -from fairscale.nn.model_parallel import get_pipeline_parallel_group, get_pipeline_parallel_ranks +from fairscale.nn.model_parallel import get_pipeline_parallel_ranks -from .messages import ( - MESSAGE_TENSOR_SIZE, - MessageQueues, - recv_message_tensors, - send_message, - tensor_to_pyobject, - to_input_device, -) +from .messages import Transport from .microbatch import Batch from .skip.tracker import SkipTrackerThroughPotals -from .types import EVENT_LOOP_QUEUE, InputDevice, PipelineStyle, PipeMessage, Tensors, TransportConfig - -Activations = Dict[int, Dict[int, Dict[int, Batch]]] - - -def dprint(x: str) -> None: - pass +from .types import EVENT_LOOP_QUEUE, PipelineStyle, PipeMessage, Tensors @dataclass(frozen=True) @@ -45,6 +38,10 @@ class Invocation: dest: Optional[Location] +Activations = Dict[int, Dict[int, Dict[int, Batch]]] +Invocations = Dict[int, Invocation] + + class ModuleWrapper: def __init__(self, module: nn.Sequential, location: Location, invocations: Optional[List[Invocation]] = None): self.module: nn.Sequential = module @@ -75,7 +72,11 @@ class AsyncMessageBody: order: int -class Hackity(torch.autograd.Function): +class AutogradWithoutActivations(torch.autograd.Function): + """A helper class to add another edge in the autograd graph which allows us + to delete the potentially large activations and still perform a backward + pass. Returns return a phony tensor which is connected to the graph.""" + @staticmethod # type: ignore def forward(ctx, *x): @@ -93,14 +94,11 @@ class AsyncRecvOperator(torch.autograd.Function): @staticmethod # type: ignore - def forward( - ctx, dst_rank: int, phony: Tensor, input_device, config: TransportConfig, message: PipeMessage - ) -> Tensors: - assert dst_rank == torch.distributed.get_rank() - ctx.config = config + def forward(ctx, phony: Tensor, transport: Transport, message: PipeMessage) -> Tensors: + ctx.transport = transport ctx.index = message.args.microbatch_index - result = recv_message_tensors(config, message) + result = transport.recv_message_tensors(message) ctx.args = result.args @@ -116,89 +114,67 @@ def maybe_requires_grad(t: Tensor) -> Tensor: def backward(ctx, *grad: Tensor,) -> Tuple[Optional[Tensor], ...]: ranks = get_pipeline_parallel_ranks() this_rank = torch.distributed.get_rank() - dprint(f"AsyncRecvOperator back {this_rank} {len(grad)}, {ctx.args}") - # Note that dst/source are swaped due to backward pass, maybe abstract - # this out? 
body = AsyncMessageBody( AsyncMessageType.Gradients, ctx.index, source=ctx.args.dest, dest=ctx.args.source, order=ctx.args.order - 1 ) - dprint(f"AsyncRecvOperator 2 back {this_rank} {len(grad)}") - send_message( - ctx.config, + ctx.transport.send_message( PipeMessage( this_rank, ranks[ctx.args.source.stage], queue_name=EVENT_LOOP_QUEUE, args=body, tensors=tuple(grad), ), sync=True, ) - dprint(f"AsyncRecvOperator 3 back {this_rank} {len(grad)}") return (None, None, None, None, None) -def recv_async_header(transport_config: TransportConfig, input_device: InputDevice) -> PipeMessage: - if transport_config.use_rpc: - queue = MessageQueues[EVENT_LOOP_QUEUE] - result = queue.get() - result.tensors = to_input_device(result.tensors, input_device) - return result - else: - dprint(f"cactus") - tensor = torch.empty(MESSAGE_TENSOR_SIZE, dtype=torch.uint8, device=input_device) - torch.cuda.current_stream().synchronize() - torch.distributed.recv(tensor, src=None, tag=EVENT_LOOP_QUEUE, group=get_pipeline_parallel_group()) - torch.cuda.current_stream().synchronize() - dprint(f"cactus2") - return tensor_to_pyobject(tensor.cpu()) - - class AsyncEventLoop: def __init__( self, partitions: List[ModuleWrapper], group: ProcessGroup, - transport_config: TransportConfig, + transport: Transport, training: bool, - input_device: InputDevice, checkpoint_stop: int, ): self.training = training - self.input_device = input_device self.checkpoint_stop = checkpoint_stop - self.transport_config = transport_config + self.transport = transport self.group = group self.partitions: List[ModuleWrapper] = partitions - def send_async_message( - self, src_rank: int, dst_rank: int, input: List[Tensor], index: int, invocation: Invocation - ) -> None: - assert src_rank == torch.distributed.get_rank() + def send_async_message(self, dst_rank: int, result: Batch, invocation: Invocation) -> Batch: + """Send batch to dst_rank, and use AutogradWithoutActivations to delete + the activations since we no longer need them""" + assert invocation.dest + src_rank = torch.distributed.get_rank() body = AsyncMessageBody( - AsyncMessageType.Activations, index, invocation.this, invocation.dest, invocation.order + 1 + AsyncMessageType.Activations, result.index, invocation.this, invocation.dest, invocation.order + 1 ) - dprint(f">>> send batch {src_rank} {dst_rank} {len(input)} {invocation.order}") - send_message( - self.transport_config, - PipeMessage(src_rank, dst_rank, queue_name=EVENT_LOOP_QUEUE, args=body, tensors=tuple(input)), + self.transport.send_message( + PipeMessage(src_rank, dst_rank, queue_name=EVENT_LOOP_QUEUE, args=body, tensors=tuple([*result])), sync=True, ) - dprint(f"<<< send batch {src_rank} {dst_rank} {len(input)} {invocation.order}") - def async_send_inner( + phony = AutogradWithoutActivations.apply(*result) + return Batch(phony, result.index) + + def run_invocation( self, batch: Batch, partition: ModuleWrapper, - index: int, skip_trackers: List[SkipTrackerThroughPotals], invocation: Invocation, ) -> Batch: + """Actually run the forward pass for a given module, and send the result + to the next stage in the pipeline if needed.""" assert self.group from .pipeline import create_task task = create_task( PipelineStyle.AsyncSchedule, self.checkpoint_stop, - index, + batch.index, self.group.rank(), batch, partition.module, @@ -210,91 +186,80 @@ def async_send_inner( if invocation.dest and invocation.dest.stage != invocation.this.stage: ranks = get_pipeline_parallel_ranks() - this_rank = torch.distributed.get_rank() - - # 
self.send_skip_tensors(this_rank, ranks, batch, i, skip_trackers) - dprint(f"sending to next stage from {this_rank}...{invocation}, {index}") - self.send_async_message(this_rank, ranks[invocation.dest.stage], [*result], index, invocation) - z = Hackity.apply(*result) - result = Batch(z, result.index) - dprint(f"empty yay!") - else: - dprint(f"not sending to next stage...{invocation.this}, {invocation.dest}") + dst_rank = ranks[invocation.dest.stage] + result = self.send_async_message(dst_rank, result, invocation) return result def async_grad_inner(self, message: PipeMessage, activations: Activations, invocation: Invocation) -> None: - recvd_grads = recv_message_tensors(self.transport_config, message) + recvd_grads = self.transport.recv_message_tensors(message) batch: Batch = activations[invocation.this.index][invocation.order][message.args.microbatch_index] + # All batches saved in `activations` are generated by AutogradWithoutActivations, + # so we store the gradients in `grad_from_pipeline` so it will be used + # during the backward pass batch.tensor.grad_fn.grad_from_pipeline = tuple(recvd_grads.tensors) # type: ignore batch.tensor.backward(retain_graph=True) - def process_batch_forward( + def run_invocations_on_batch( self, batch: Batch, - i: int, - invocations: List[Invocation], + invocations: Invocations, order: int, skip_trackers: List[SkipTrackerThroughPotals], activations: Activations, ) -> Tuple[int, int]: + """Run invocations on the batch until we hit one that receives its input + from a different stage (i.e. another process)""" + invocations_handled = 0 last_order = 0 - for invocation in invocations: + for invocation in invocations.values(): if invocation.order < order: continue pi = invocation.this.index partition = self.partitions[pi] - dprint(f"{self.group.rank()}: pbb {invocation}, {order}, {self.group.rank()}") if invocation.order == order: - dprint(f"{self.group.rank()}: assigning {pi}, {invocation.order}, {i}") invocations_handled += 1 last_order = invocation.order - activations[pi][invocation.order][i] = self.async_send_inner( - batch, partition, i, skip_trackers, invocation + activations[pi][invocation.order][batch.index] = self.run_invocation( + batch, partition, skip_trackers, invocation ) elif invocation.source and invocation.source.stage == self.group.rank(): - dprint( - f"{self.group.rank()}: reading {invocation}, {invocation.source.index}, {invocation.order-1}, {i}" - ) invocations_handled += 1 last_order = invocation.order - batch = activations[invocation.source.index][invocation.order - 1][i] - dprint(f"{self.group.rank()}: assigning {pi}, {invocation.order}, {i}") - activations[pi][invocation.order][i] = self.async_send_inner( - batch, partition, i, skip_trackers, invocation + batch = activations[invocation.source.index][invocation.order - 1][batch.index] + activations[pi][invocation.order][batch.index] = self.run_invocation( + batch, partition, skip_trackers, invocation ) - del activations[invocation.source.index][invocation.order - 1][i] + del activations[invocation.source.index][invocation.order - 1][batch.index] elif invocation.source and invocation.source.stage != self.group.rank(): break - dprint(f"pbb {self.group.rank()} {invocations_handled}") return (invocations_handled, last_order) def event_loop_head( self, batches: List[Batch], skip_trackers: List[SkipTrackerThroughPotals], event: Optional[Event] ) -> None: + """The event loop for the "head", which first performs the forward pass + on any applicable layers for this stage, and then enters the common 
+ `event_loop_inner`""" - invocations, activations = self.get_sorted_invocations_and_activations() + invocations, activations = self.get_invocations_and_activations() expected_invocations = len(invocations) * len(batches) actual_invocations = 0 count_per_order = dict() - dprint(f"head loop start {torch.distributed.get_rank()}") - for i, batch in enumerate(batches): - dprint(f"head loop iter {torch.distributed.get_rank()}, {i}") - inv_count, last_order = self.process_batch_forward(batch, i, invocations, 0, skip_trackers, activations) + for batch in batches: + inv_count, last_order = self.run_invocations_on_batch(batch, invocations, 0, skip_trackers, activations) actual_invocations += inv_count count_per_order[last_order] = inv_count - dprint(f"head wat {actual_invocations}, {expected_invocations}") if actual_invocations < expected_invocations or self.training: - dprint(f"head extra {actual_invocations}, {expected_invocations}") self.event_loop_inner( expected_invocations, skip_trackers, @@ -305,51 +270,47 @@ def event_loop_head( event=event, ) - # if self.pipeline.training: - # for _ in range(len(batches)): - # message = self.recv_async_header() - # args: AsyncMessageBody = message.args - # assert args.message_type is AsyncMessageType.Gradients - # self.async_grad_inner(message, activations) + def get_batch_from_message(self, message: PipeMessage) -> Batch: + microbatch_index = message.args.microbatch_index + phony = torch.empty(0, device=self.transport.input_device, requires_grad=True) + result = AsyncRecvOperator.apply(phony, self.transport, message) + if len(result) == 1: + batch = Batch(result[0], microbatch_index) + else: + batch = Batch(result, microbatch_index) + return batch def event_loop_tail(self, batches: List[Batch], skip_trackers: List[SkipTrackerThroughPotals]) -> None: + """The event loop for the "tail", or final stage which only processes + activations and then returns to the caller so that the loss can be + calculated. 
This also handles the first/only stage for the special + case of a 1-stage pipeline.""" + assert self.group - invocations, activations = self.get_sorted_invocations_and_activations() + invocations, activations = self.get_invocations_and_activations() expected_invocations = len(invocations) * len(batches) actual_invocations = 0 rank = self.group.rank() count_per_order = dict() - for i, batch in enumerate(batches): + for batch in batches: if rank == 0: - batch_index = i order = 0 else: - message = recv_async_header(self.transport_config, self.input_device) + message = self.transport.recv_message_header(EVENT_LOOP_QUEUE) args: AsyncMessageBody = message.args - phony = torch.empty(0, device=self.input_device, requires_grad=True) - result = AsyncRecvOperator.apply( - torch.distributed.get_rank(), phony, self.input_device, self.transport_config, message, - ) - if len(result) == 1: - batch = Batch(result[0], args.microbatch_index) - else: - batch = Batch(result, args.microbatch_index) - batch_index = args.microbatch_index + batch = self.get_batch_from_message(message) order = args.order - inv_count, last_order = self.process_batch_forward( - batch, batch_index, invocations, order, skip_trackers, activations - ) + inv_count, last_order = self.run_invocations_on_batch(batch, invocations, order, skip_trackers, activations) actual_invocations += inv_count count_per_order[last_order] = inv_count if actual_invocations < expected_invocations: expected_gradients = 0 # (len(invocations) - 1) * len(batches) - dprint(f"tail expect {expected_invocations}, {len(invocations)}, {len(batches)}") self.event_loop_inner( expected_invocations, @@ -361,31 +322,33 @@ def event_loop_tail(self, batches: List[Batch], skip_trackers: List[SkipTrackerT ignore_gradients=True, ) - for index, batch in activations[len(self.partitions) - 1][invocations[-1].order].items(): + _, last_invocation = invocations.popitem() + + for index, batch in activations[len(self.partitions) - 1][last_invocation.order].items(): batches[index] = batch - def get_sorted_invocations_and_activations(self) -> Tuple[List[Invocation], Activations]: + def get_invocations_and_activations(self) -> Tuple[Invocations, Activations]: activations: Activations = dict() - invocations: List[Invocation] = [] + invocations: Invocations = OrderedDict() for pi, partition in enumerate(self.partitions): activations[pi] = dict() for invocation in partition.invocations: activations[pi][invocation.order] = dict() - invocations.append(invocation) + invocations[invocation.order] = invocation - invocations.sort(key=lambda inv: inv.order) + invocations = OrderedDict(sorted(invocations.items(), key=lambda entry: entry[0])) return (invocations, activations) def event_loop(self, num_microbatch: int, skip_trackers: List[SkipTrackerThroughPotals]) -> None: + """The event loop for the "middle", i.e. 
neither the head nor the tail""" assert self.group - invocations, activations = self.get_sorted_invocations_and_activations() + invocations, activations = self.get_invocations_and_activations() expected_invocations = len(invocations) * num_microbatch - dprint(f"event_loop {expected_invocations}, {num_microbatch}, {len(invocations)}") self.event_loop_inner(expected_invocations, skip_trackers, activations, invocations, dict()) def event_loop_inner( @@ -393,13 +356,16 @@ def event_loop_inner( expected_invocations: int, skip_trackers: List[SkipTrackerThroughPotals], activations: Activations, - invocations: List[Invocation], + invocations: Invocations, count_per_order: Dict[int, int], *, already_received: int = 0, ignore_gradients: bool = False, event: Optional[Event] = None, ) -> None: + """The common event loop shared by all stages. This processses + activations for the forward pass, and if `self.training` is true, + handles gradients as well for the backward pass.""" num_activations = already_received if self.training and not ignore_gradients: @@ -408,23 +374,13 @@ def event_loop_inner( num_gradients = expected_invocations while num_activations < expected_invocations or num_gradients < expected_invocations: - dprint( - f">> recv_async_header {self.group.rank()}, {torch.distributed.get_rank()} {expected_invocations}," - f" {num_activations}, {num_gradients}, {ignore_gradients}" - ) if num_activations == expected_invocations and num_gradients == 0 and event is not None: - print(f">>> wait on event") event.wait() - print(f"<<< wait on event") - message = recv_async_header(self.transport_config, self.input_device) - dprint(f"<< recv_async_header {torch.distributed.get_rank()}") + message = self.transport.recv_message_header(EVENT_LOOP_QUEUE) args: AsyncMessageBody = message.args - filtered = [inv for inv in invocations if inv.order == args.order] - if len(filtered) == 0: - dprint(f"no invocation on {self.group.rank()} for {args.order}, {invocations}") - invocation = filtered[0] + invocation = invocations[args.order] # FIXME(tom) for combining pipeline with megatron, I currently don't # control the order of received activations or gradients, so it is @@ -432,32 +388,15 @@ def event_loop_inner( # a different order of activations w.r.t. the sending stage, which # would result in incorrect values being used for the all_gather if args.message_type is AsyncMessageType.Activations: + batch = self.get_batch_from_message(message) - phony = torch.empty(0, device=self.input_device, requires_grad=True) - result = AsyncRecvOperator.apply( - torch.distributed.get_rank(), phony, self.input_device, self.transport_config, message, - ) - - dprint( - f"got batch {torch.distributed.get_rank()}|{self.group.rank()} i:{args.microbatch_index}" - f" len:{len(result)}, {invocation}" - ) - - if len(result) == 1: - batch = Batch(result[0], args.microbatch_index) - else: - batch = Batch(result, args.microbatch_index) - - dprint(f"calling pbb? 
{self.group.rank()}, {expected_invocations}, {num_activations}, {num_gradients}") - inv_count, last_order = self.process_batch_forward( - batch, args.microbatch_index, invocations, args.order, skip_trackers, activations + inv_count, last_order = self.run_invocations_on_batch( + batch, invocations, args.order, skip_trackers, activations ) count_per_order[last_order] = inv_count num_activations += inv_count assert num_activations <= expected_invocations elif args.message_type is AsyncMessageType.Gradients: - dprint(f">> try {self.group.rank()}, {invocation.order}, {count_per_order}, {num_gradients}") num_gradients += count_per_order[invocation.order] self.async_grad_inner(message, activations, invocation) - dprint(f"<< try {self.group.rank()}, {invocation.order}, {count_per_order}, {num_gradients}") diff --git a/fairscale/nn/pipe/messages.py b/fairscale/nn/pipe/messages.py index d11b334e3..0613274d7 100644 --- a/fairscale/nn/pipe/messages.py +++ b/fairscale/nn/pipe/messages.py @@ -1,47 +1,26 @@ -import pickle +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +from abc import ABC from queue import Empty as QueueEmpty from queue import Queue -from typing import Any, List +from typing import Dict, List, Optional -import numpy as np +from dataclasses import dataclass import torch -from torch import Tensor from fairscale.nn.model_parallel import get_pipeline_parallel_group +from fairscale.utils.object import pyobject_to_tensor, tensor_to_pyobject -from .types import MESSAGE_GENERATION_START, InputDevice, PipeMessage, Tensors, TransportConfig +from .types import MESSAGE_GENERATION_START, InputDevice, PipeMessage, Tensors -# FIXME Why is 256 ok for training but not for tests? -MESSAGE_TENSOR_SIZE = 1024 # 256 +MESSAGE_TENSOR_SIZE = 1024 MessageQueues: List[Queue] = [Queue() for _ in range(MESSAGE_GENERATION_START)] -def pyobject_to_tensor(obj: Any) -> Tensor: - pickled = pickle.dumps(obj) - nparray = np.frombuffer(pickled, dtype=np.uint8).copy() - nparray.setflags(write=True) - result = torch.from_numpy(nparray) - delta = MESSAGE_TENSOR_SIZE - len(result) - if delta < 0: - raise ValueError( - f"message too big to send, increase MESSAGE_TENSOR_SIZE? - {len(result)} > {MESSAGE_TENSOR_SIZE}" - ) - elif delta > 0: - result = torch.cat((result, torch.zeros(delta, dtype=torch.uint8))) - - return result.cuda() - - -def tensor_to_pyobject(tensor: Tensor) -> Any: - try: - nparray = tensor.numpy() - return pickle.loads(nparray.tobytes()) - except Exception as e: - print(f"pickle fail {e}") - raise e - - def to_input_device(tensors: Tensors, input_device: InputDevice) -> Tensors: if input_device is None: return tensors @@ -53,16 +32,81 @@ def rpc_push_queue(message: PipeMessage) -> None: globals()["MessageQueues"][message.queue_name].put(message) -def send_message(config: TransportConfig, message: PipeMessage, sync: bool = False, skip_header: bool = False) -> None: - if config.use_rpc: +@dataclass(frozen=True) +class Transport(ABC): + worker_map: Optional[Dict[int, str]] + input_device: InputDevice + + def recv_message(self, queue_name: int, *, nowait: bool = False) -> PipeMessage: + message = self.recv_message_header(queue_name, nowait) + return self.recv_message_tensors(message) + + def recv_message_header(self, queue_name: int, nowait: bool = False) -> PipeMessage: + ... + + def recv_message_tensors(self, message: PipeMessage) -> PipeMessage: + ... 
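A note on the Transport interface introduced in this hunk: receiving is split into a header step (a fixed-size pickled PipeMessage carrying shapes, dtypes and args) followed by a tensor step, with recv_message composing the two; the RpcTransport / SendRecvTransport implementations and the MakeTransport factory appear further down in the hunk. The following is a rough usage sketch only, assuming torch.distributed and fairscale's pipeline parallel group are already initialized and that the two halves run on the first two pipeline ranks; the payload, tensor shapes and rank placement are invented for illustration.

    import torch

    from fairscale.nn.pipe.messages import MakeTransport
    from fairscale.nn.pipe.types import EVENT_LOOP_QUEUE, PipeMessage

    # Both stages build the same kind of transport; use_rpc=False selects the
    # send/recv implementation that ships a fixed-size pickled header followed
    # by the raw tensors over the pipeline process group.
    transport = MakeTransport(
        use_rpc=False,
        worker_map=None,  # only required by RpcTransport
        input_device=torch.device("cuda", torch.cuda.current_device()),
    )

    # On global rank 0: send a message whose tensors live on the GPU.
    message = PipeMessage(0, 1, queue_name=EVENT_LOOP_QUEUE, args="payload",
                          tensors=(torch.ones(2, 2, device="cuda"),))
    transport.send_message(message, sync=True)

    # On global rank 1: first the header (shapes, dtypes, args), then the
    # tensors themselves; transport.recv_message() composes the two steps.
    header = transport.recv_message_header(EVENT_LOOP_QUEUE)
    received = transport.recv_message_tensors(header)
    print(received.args, received.tensors[0].shape)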
+ + def send_message(self, message: PipeMessage, sync: bool = False, skip_header: bool = False) -> None: + ... + + def get_out_of_order(self, queue_name: int, index: int) -> Tensors: + ... + + +def MakeTransport(use_rpc: bool, worker_map: Optional[Dict[int, str]], input_device: InputDevice) -> Transport: + if use_rpc: + if worker_map is None: + raise ValueError("'RpcTransport' requires 'worker_map' to be set") + return RpcTransport(worker_map, input_device) + else: + return SendRecvTransport(worker_map, input_device) + + +class RpcTransport(Transport): + def send_message(self, message: PipeMessage, sync: bool = False, skip_header: bool = False) -> None: message.tensors = tuple(t.cpu() for t in message.tensors) - assert config.worker_map - name = config.worker_map[message.dest] + assert self.worker_map + name = self.worker_map[message.dest] if sync: torch.distributed.rpc.rpc_sync(name, rpc_push_queue, args=(message,)) else: torch.distributed.rpc.rpc_async(name, rpc_push_queue, args=(message,)) - else: + + def recv_message_header(self, queue_name: int, nowait: bool = False) -> PipeMessage: + queue = MessageQueues[queue_name] + if nowait: + result = queue.get_nowait() + else: + result = queue.get() + result.tensors = to_input_device(result.tensors, self.input_device) + return result + + def recv_message_tensors(self, message: PipeMessage) -> PipeMessage: + # Tensors already contained within message + message.tensors = to_input_device(message.tensors, self.input_device) + return message + + def get_out_of_order(self, queue_name: int, index: int) -> Tensors: + """Receive a message with a known microbatch index, and handle out-of-order + messages by placing them back on the queue""" + + queue = globals()["MessageQueues"][queue_name] + out_of_order: List[PipeMessage] = [] + while True: + message = self.recv_message(queue_name) + got_index = message.args + value = message.tensors + if got_index == index: + for b in out_of_order: + queue.put(b) + return value + else: + out_of_order.append(message) + + +class SendRecvTransport(Transport): + def send_message(self, message: PipeMessage, sync: bool = False, skip_header: bool = False) -> None: tensors = message.tensors message.tensors = tuple() torch.cuda.current_stream().synchronize() @@ -70,7 +114,10 @@ def send_message(config: TransportConfig, message: PipeMessage, sync: bool = Fal message.tensor_shapes = [t.size() for t in tensors] message.tensor_dtypes = [t.dtype for t in tensors] torch.distributed.send( - pyobject_to_tensor(message), message.dest, tag=message.queue_name, group=get_pipeline_parallel_group() + pyobject_to_tensor(message, MESSAGE_TENSOR_SIZE).cuda(), + message.dest, + tag=message.queue_name, + group=get_pipeline_parallel_group(), ) for index, t in enumerate(tensors): if t.device.type == "cpu": @@ -79,32 +126,22 @@ def send_message(config: TransportConfig, message: PipeMessage, sync: bool = Fal t.contiguous(), message.dest, tag=message.tag + index, group=get_pipeline_parallel_group() ) - -def recv_message_header(transport_config: TransportConfig, input_device: InputDevice, queue_name: int) -> PipeMessage: - if transport_config.use_rpc: - queue = MessageQueues[queue_name] - result = queue.get() - result.tensors = to_input_device(result.tensors, input_device) - return result - else: - tensor = torch.empty(MESSAGE_TENSOR_SIZE, dtype=torch.uint8, device=input_device) + def recv_message_header(self, queue_name: int, nowait: bool = False) -> PipeMessage: + # FIXME(handle nowait) + if nowait: + raise QueueEmpty + tensor = 
torch.empty(MESSAGE_TENSOR_SIZE, dtype=torch.uint8, device=self.input_device) torch.cuda.current_stream().synchronize() torch.distributed.recv(tensor, src=None, tag=queue_name, group=get_pipeline_parallel_group()) torch.cuda.current_stream().synchronize() - return tensor_to_pyobject(tensor.cpu()) - + return tensor_to_pyobject(tensor) -def recv_message_tensors(config: TransportConfig, message: PipeMessage) -> PipeMessage: - if config.use_rpc: - # Tensors already contained within message - message.tensors = to_input_device(message.tensors, config.input_device) - return message - else: + def recv_message_tensors(self, message: PipeMessage) -> PipeMessage: torch.cuda.current_stream().synchronize() message_tensors = [] for index, (shape, dtype) in enumerate(zip(message.tensor_shapes, message.tensor_dtypes)): - t = torch.empty(*shape, dtype=dtype, device=config.input_device) + t = torch.empty(*shape, dtype=dtype, device=self.input_device) torch.distributed.recv(t, message.src, tag=message.tag + index, group=get_pipeline_parallel_group()) message_tensors.append(t) @@ -113,49 +150,10 @@ def recv_message_tensors(config: TransportConfig, message: PipeMessage) -> PipeM torch.cuda.current_stream().synchronize() return message + def get_out_of_order(self, queue_name: int, index: int) -> Tensors: + """Receive a message with a known microbatch index, and handle out-of-order + messages by placing them back on the queue""" -def recv_message( - config: TransportConfig, queue_name: int, *, nowait: bool = False, input_device: InputDevice = None -) -> PipeMessage: - if config.use_rpc: - queue = globals()["MessageQueues"][queue_name] - if nowait: - result = queue.get_nowait() - else: - result = queue.get() - return recv_message_tensors(config, result) - else: - # FIXME(handle nowait) - if nowait: - raise QueueEmpty - - torch.cuda.current_stream().synchronize() - tensor = torch.empty(MESSAGE_TENSOR_SIZE, dtype=torch.uint8, device=input_device) - torch.distributed.recv(tensor, src=None, tag=queue_name, group=get_pipeline_parallel_group()) - torch.cuda.current_stream().synchronize() - message = tensor_to_pyobject(tensor.cpu()) - - return recv_message_tensors(config, message) - - -def get_out_of_order(config: TransportConfig, queue_name: int, index: int, *, input_device: InputDevice) -> Tensors: - """Receive a message with a known microbatch index, and handle out-of-order - messages by placing them back on the queue""" - - if config.use_rpc: - queue = globals()["MessageQueues"][queue_name] - out_of_order: List[PipeMessage] = [] - while True: - message = recv_message(config, queue_name, input_device=input_device) - got_index = message.args - value = message.tensors - if got_index == index: - for b in out_of_order: - queue.put(b) - return value - else: - out_of_order.append(message) - else: - message = recv_message(config, queue_name, input_device=input_device) + message = self.recv_message(queue_name) assert message.args == index return message.tensors diff --git a/fairscale/nn/pipe/pipe.py b/fairscale/nn/pipe/pipe.py index 8d9034c12..d543c64fc 100644 --- a/fairscale/nn/pipe/pipe.py +++ b/fairscale/nn/pipe/pipe.py @@ -133,10 +133,8 @@ def check_balance(module: Any, balance: Iterable[int], filter_unique: bool = Fal if filter_unique: module_len = len(set(map(id, module))) - print(f"unique layer {module_len}, {balance}") else: module_len = len(module) - print(f"non-unique layer {module_len}, {balance}") if module_len != sum(balance): raise BalanceError( @@ -256,27 +254,16 @@ def append_module(mod: "OrderedDict[str, 
nn.Module]") -> None: result.append(wrapper) - # print(f"partitions = {partitions}") - return result j = 0 - # print(f"{torch.distributed.get_rank()}: assigned = {assigned}") - - # print(f"yay {list(map(id, module))}") - # print(f"first_index = {first_index}") - # print(f"duplicates = {duplicates}") - # print(f"duplicates2 = {duplicates2}") - # print(f"locations = {locations}") - # for name, layer in iterate_module(module): layers[name] = layer if len(layers) == balance[j]: if j == group.rank(): for key in layers: - print(f"key = {type(key)}-{key}") layers[key] = maybe_realize(layers[key]) if not isinstance(module, nn.Sequential): for layer in layers.values(): diff --git a/fairscale/nn/pipe/pipeline.py b/fairscale/nn/pipe/pipeline.py index 936b481ed..2c50de019 100644 --- a/fairscale/nn/pipe/pipeline.py +++ b/fairscale/nn/pipe/pipeline.py @@ -35,7 +35,7 @@ from .checkpoint import Checkpointing from .copy import Copy, Wait from .dependency import fork, join -from .messages import get_out_of_order, recv_message, send_message +from .messages import MakeTransport, Transport from .microbatch import Batch from .skip import Namespace from .skip.layout import SkipLayout @@ -50,7 +50,6 @@ Schedule, TensorOrTensors, Tensors, - TransportConfig, ) from .worker import Task, create_workers, join_workers @@ -64,11 +63,10 @@ class SendOperator(torch.autograd.Function): @staticmethod # type: ignore - def forward(ctx, src_rank, dst_rank, config: TransportConfig, input: List[Tensor], index: int) -> Tensors: + def forward(ctx, src_rank, dst_rank, transport: Transport, input: List[Tensor], index: int) -> Tensors: assert src_rank == torch.distributed.get_rank() - send_message( - config, + transport.send_message( PipeMessage(src_rank, dst_rank, queue_name=ACTIVATIONS_GRADS_QUEUE, args=index, tensors=tuple(input)), ) return () @@ -84,12 +82,12 @@ class RecvOperator(torch.autograd.Function): @staticmethod # type: ignore - def forward(ctx, dst_rank: int, tensor: Tensor, input_device, config: TransportConfig, index: int) -> Tensors: + def forward(ctx, dst_rank: int, tensor: Tensor, input_device, transport: Transport, index: int) -> Tensors: assert dst_rank == torch.distributed.get_rank() - ctx.config = config + ctx.transport = transport ctx.index = index - result = get_out_of_order(config, ACTIVATIONS_GRADS_QUEUE, index, input_device=input_device) + result = transport.get_out_of_order(ACTIVATIONS_GRADS_QUEUE, index) def maybe_requires_grad(t: Tensor) -> Tensor: if t.dtype.is_floating_point: @@ -103,8 +101,7 @@ def maybe_requires_grad(t: Tensor) -> Tensor: def backward(ctx, *grad: Tensor,) -> Tuple[Optional[Tensor], ...]: ranks = get_pipeline_parallel_ranks() this_rank = torch.distributed.get_rank() - send_message( - ctx.config, + ctx.transport.send_message( PipeMessage( this_rank, ranks[ranks.index(this_rank) - 1], @@ -240,12 +237,12 @@ def __init__( self.style = style self.group = group self.training: bool - self.transport_config = TransportConfig( - use_rpc=("OMPI_COMM_WORLD_RANK" not in os.environ) or ("FORCE_RPC" in os.environ), - worker_map=worker_map, - input_device=input_device, - ) - + if style in [PipelineStyle.MultiProcess, PipelineStyle.AsyncSchedule]: + self.transport = MakeTransport( + use_rpc=("OMPI_COMM_WORLD_RANK" not in os.environ) or ("FORCE_RPC" in os.environ), + worker_map=worker_map, + input_device=input_device, + ) self.input_device = input_device self.all_at_once = False self.callcount = 0 @@ -254,13 +251,6 @@ def __init__( assert self.devices is not None (self.in_queues, self.out_queues) = 
create_workers(self.devices) - if ( - self.style in [PipelineStyle.MultiProcess, PipelineStyle.AsyncSchedule] - and self.transport_config.worker_map is None - and self.transport_config.use_rpc is True - ): - raise ValueError("'PipelineStyle.MultiProcess' requires 'worker_map' to be set") - @property def checkpoint_stop(self) -> int: # Disable checkpointing if in eval mode. @@ -302,12 +292,7 @@ def run(self, training: bool, batches: List[Batch], event: Optional[Event]) -> N assert self.group rank = self.group.rank() event_loop = AsyncEventLoop( - self.mp_partitions, - self.group, - self.transport_config, - self.training, - self.input_device, - self.checkpoint_stop, + self.mp_partitions, self.group, self.transport, self.training, self.checkpoint_stop, ) if rank == 0 and not self.final_stage: logging.debug(f"{torch.distributed.get_rank()}: entered event head") @@ -357,7 +342,7 @@ def get_batch_from_previous_stage( ) -> Batch: phony = torch.empty(0, device=self.input_device, requires_grad=True) - result = RecvOperator.apply(torch.distributed.get_rank(), phony, self.input_device, self.transport_config, i) + result = RecvOperator.apply(torch.distributed.get_rank(), phony, self.input_device, self.transport, i) if len(result) == 1: batch = Batch(result[0], i) else: @@ -379,8 +364,7 @@ def send_skip_tensors( else: tensors = tuple() - send_message( - self.transport_config, + self.transport.send_message( PipeMessage( this_rank, ranks[next_j], queue_name=SKIP_TENSOR_QUEUE, args=(i, ns, name, life), tensors=tensors, ), @@ -390,9 +374,7 @@ def send_skip_tensors( def recv_skip_tensors(self, skip_trackers: List[SkipTrackerThroughPotals], batches: List[Batch]) -> None: while True: try: - message = recv_message( - self.transport_config, SKIP_TENSOR_QUEUE, nowait=True, input_device=self.input_device - ) + message = self.transport.recv_message(SKIP_TENSOR_QUEUE, nowait=True) (si, ns, name, life) = message.args value: Optional[TensorOrTensors] = message.tensors @@ -422,7 +404,7 @@ def execute_task(self, task: Task, i: int, skip_trackers: List[SkipTrackerThroug this_rank = torch.distributed.get_rank() self.send_skip_tensors(this_rank, ranks, batch, i, skip_trackers) - SendOperator.apply(this_rank, ranks[ranks.index(this_rank) + 1], self.transport_config, [*batch], i) + SendOperator.apply(this_rank, ranks[ranks.index(this_rank) + 1], self.transport, [*batch], i) for portal in skip_trackers[i].portals.values(): portal.pipeline = self @@ -554,14 +536,12 @@ def send_portal_grad(self, ns_name: Tuple[Namespace, str], index: int, grad: Ten if isinstance(grad, Tensor): grad = tuple([grad]) - send_message( - self.transport_config, - PipeMessage(ranks[src], dst_rank, queue_name=PORTAL_QUEUE, args=(ns_name, index), tensors=grad), - sync=True, + self.transport.send_message( + PipeMessage(ranks[src], dst_rank, queue_name=PORTAL_QUEUE, args=(ns_name, index), tensors=grad), sync=True, ) def recv_portal_grad(self, expected_ns_name: Tuple[Namespace, str], expected_index: int) -> Tensor: - message = recv_message(self.transport_config, PORTAL_QUEUE, input_device=self.input_device) + message = self.transport.recv_message(PORTAL_QUEUE) (ns_name, index) = message.args grad = message.tensors @@ -584,9 +564,7 @@ def back_helper(self, output: List[Batch]) -> None: grads = [] for i, batch in enumerate(o): rank = torch.distributed.get_rank() - found = get_out_of_order( - self.transport_config, ACTIVATIONS_GRADS_QUEUE, i, input_device=self.input_device - ) + found = self.transport.get_out_of_order(ACTIVATIONS_GRADS_QUEUE, i) assert 
len(found) == 1 grads.append(found[0]) tensors = tuple(x.tensor_or_tensors for x in o) # type: ignore @@ -597,9 +575,7 @@ def back_helper(self, output: List[Batch]) -> None: else: rank = torch.distributed.get_rank() for batch in o: - found = get_out_of_order( - self.transport_config, ACTIVATIONS_GRADS_QUEUE, batch.index, input_device=self.input_device - ) + found = self.transport.get_out_of_order(ACTIVATIONS_GRADS_QUEUE, batch.index) if batch.atomic: tensors = tuple([batch.tensor]) else: diff --git a/fairscale/nn/pipe/rpc.py b/fairscale/nn/pipe/rpc.py index a4408ef77..58bc1e7e9 100644 --- a/fairscale/nn/pipe/rpc.py +++ b/fairscale/nn/pipe/rpc.py @@ -1,4 +1,8 @@ -import logging +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + from threading import Event, Lock, Thread from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast @@ -10,7 +14,6 @@ from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group from . import Pipe -from .messages import recv_message_tensors, send_message from .types import EVENT_LOOP_QUEUE, PipeMessage, TensorOrTensors DEFAULT_MAX_SOURCE_POSITIONS = 1024 @@ -24,10 +27,6 @@ DtypeOrDtypes = Union[torch.dtype, List[torch.dtype]] -def dprint(s: str) -> None: - print(str(torch.distributed.get_rank()) + ": " + s) - - def set_device_based_on_group(group: ProcessGroup) -> None: # torch.cuda.set_device(group.rank() % torch.cuda.device_count()) torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count()) @@ -36,7 +35,6 @@ def set_device_based_on_group(group: ProcessGroup) -> None: def register_remote_model(args: List[Any], kwargs: Dict[str, Any]) -> None: group = get_pipeline_parallel_group() # FIXME(tom) handle dynamic group set_device_based_on_group(group) - dprint(f"model registered {torch.cuda.current_device()}") kwargs["group"] = group kwargs["input_device"] = torch.device("cuda", torch.cuda.current_device()) model = Pipe(*args, **kwargs) @@ -65,7 +63,6 @@ def model_forward(training: bool, shape: torch.Size, dtype: torch.dtype) -> Opti else: tensor = tuple([torch.empty(s, dtype=d) for s, d in zip(shape, dtype)]) - global PipeModel model = PipeModel assert model.group set_device_based_on_group(model.group) @@ -80,101 +77,76 @@ def model_forward(training: bool, shape: torch.Size, dtype: torch.dtype) -> Opti return None -def send_result(training: bool, message: PipeMessage, grads_message: PipeMessage) -> None: - dprint(f"send result {training}") +def send_result_and_do_backwards(training: bool, message: PipeMessage, grads_message: PipeMessage) -> None: group = get_pipeline_parallel_group() set_device_based_on_group(group) - try: - dprint(f"send result {torch.distributed.get_rank()}, {torch.cuda.current_device()}") - result = PipeResult - global PipeModel - model = PipeModel - - if isinstance(result, torch.Tensor): - result = tuple([result]) - - dest = _get_global_rank(group, 0) + result = PipeResult + model = PipeModel - print( - f"ho har {torch.distributed.get_rank()} " + str([_get_global_rank(group, i) for i in range(group.size())]) - ) - message.tensors = tuple(result) - assert model.pipeline - config = model.pipeline.transport_config - send_message(config, message, sync=False, skip_header=True) + if isinstance(result, torch.Tensor): + result = tuple([result]) - if training: - grads_message.tensor_shapes = [r.shape for r in result] - grads_message.tensor_dtypes = 
[r.dtype for r in result] - grads_message = recv_message_tensors(config, grads_message) + message.tensors = tuple(result) + assert model.pipeline + transport = model.pipeline.transport + transport.send_message(message, sync=False, skip_header=True) - with model.lock: - print(f" >>> autograd-backward tail") - torch.autograd.backward(result, grads_message.tensors, retain_graph=True) - print(f" <<< autograd-backward tail") + if training: + grads_message.tensor_shapes = [r.shape for r in result] + grads_message.tensor_dtypes = [r.dtype for r in result] + grads_message = transport.recv_message_tensors(grads_message) - except Exception as e: - print(f"got {e}") + with model.lock: + torch.autograd.backward(result, grads_message.tensors, retain_graph=True) def recv_result(model: Pipe, shapes: SizeOrSizes, dtypes: DtypeOrDtypes, message: PipeMessage) -> TensorOrTensors: group = get_pipeline_parallel_group() set_device_based_on_group(group) - src = torch.distributed.distributed_c10d._get_global_rank(group, group.size() - 1) - dprint(f"recv_result... {src}, {torch.cuda.current_device()}") assert model.pipeline - transport_config = model.pipeline.transport_config + transport = model.pipeline.transport if isinstance(shapes, torch.Size): - shape = cast(torch.Size, shapes) - dtype = cast(torch.dtype, dtypes) - message.tensor_shapes = [shape] - message.tensor_dtypes = [dtype] - message = recv_message_tensors(transport_config, message) + message.tensor_shapes = [cast(torch.Size, shapes)] + message.tensor_dtypes = [cast(torch.dtype, dtypes)] + message = transport.recv_message_tensors(message) return message.tensors[0] else: - shapes = cast(List[torch.Size], shapes) - dtypes = cast(List[torch.dtype], dtypes) - message.tensor_shapes = shapes - message.tensor_dtypes = dtypes - message = recv_message_tensors(transport_config, message) + message.tensor_shapes = cast(List[torch.Size], shapes) + message.tensor_dtypes = cast(List[torch.dtype], dtypes) + message = transport.recv_message_tensors(message) return message.tensors def get_global_ranks_from_group(group: ProcessGroup) -> List[int]: - return [torch.distributed.distributed_c10d._get_global_rank(group, r) for r in range(group.size())] + return [_get_global_rank(group, r) for r in range(group.size())] -def run_model(model: Pipe, tensor: TensorOrTensors, event: Event, lock: Lock) -> None: +def model_forward_first_stage(model: Pipe, tensor: TensorOrTensors, event: Event, lock: Lock) -> None: t = model.training with lock: - print(f">> run_model thread {t}") assert model.group set_device_based_on_group(model.group) model(tensor, event=event) - print(f"<< run_model thread {t}") class PipeBackRedirect(torch.autograd.Function): @staticmethod # type: ignore - def forward(ctx, inputs, dest, event, message, transport_config): + def forward(ctx, inputs, dest, event, message, transport): ctx.dest = dest ctx.event = event ctx.message = message - ctx.transport_config = transport_config + ctx.transport = transport return inputs @staticmethod # type: ignore def backward(ctx, *grad): - logging.debug(f">>> PipeBackRedirect.backward") - global PipeModel ctx.message.tensors = tuple(grad) - send_message(ctx.transport_config, ctx.message, sync=False, skip_header=True) + ctx.transport.send_message(ctx.message, sync=False, skip_header=True) ctx.event.set() - logging.debug(f"<<< PipeBackRedirect.backward") return (None, None, None, None, None) @@ -182,7 +154,6 @@ def callback_with_model(callback: Callable, ctx: Any) -> None: group = get_pipeline_parallel_group() # FIXME(tom) handle 
dynamic group set_device_based_on_group(group) - global PipeModel with PipeModel.lock: callback(ctx, PipeModel) @@ -246,23 +217,21 @@ def forward(self, tensor: TensorOrTensors) -> TensorOrTensors: # type: ignore return self.model(tensor) else: event = Event() - t = Thread(target=run_model, args=(self.model, tensor, event, self.lock)) + t = Thread(target=model_forward_first_stage, args=(self.model, tensor, event, self.lock)) t.start() shape, dtype = futures[-1].wait() - dprint("forward after wait recv") dest_rank = self.group.size() - 1 dest = self.get_rpc_name(dest_rank) dest_global_rank = _get_global_rank(self.group, dest_rank) src_global_rank = torch.distributed.get_rank() - dprint(f"async to {dest}") queue = EVENT_LOOP_QUEUE - message = PipeMessage(dest_global_rank, src_global_rank, queue_name=queue, tensor_count=num_tensors) - grads_message = PipeMessage(src_global_rank, dest_global_rank, queue_name=queue, tensor_count=num_tensors) - rpc.rpc_async(dest, send_result, args=(self.model.training, message, grads_message)) - dprint(">>> recv_result") - result = recv_result(self.model, shape, dtype, message) - dprint("<<< recv_result") + + activations = PipeMessage(dest_global_rank, src_global_rank, queue_name=queue, tensor_count=num_tensors) + grads = PipeMessage(src_global_rank, dest_global_rank, queue_name=queue, tensor_count=num_tensors) + + rpc.rpc_async(dest, send_result_and_do_backwards, args=(self.model.training, activations, grads)) + result = recv_result(self.model, shape, dtype, activations) if isinstance(result, torch.Tensor): result.requires_grad_() else: @@ -270,9 +239,7 @@ def forward(self, tensor: TensorOrTensors) -> TensorOrTensors: # type: ignore r.requires_grad_() assert self.model.pipeline - return PipeBackRedirect.apply( - result, dest_global_rank, event, grads_message, self.model.pipeline.transport_config - ) + return PipeBackRedirect.apply(result, dest_global_rank, event, grads, self.model.pipeline.transport) @property def final_stage(self) -> bool: diff --git a/fairscale/nn/pipe/types.py b/fairscale/nn/pipe/types.py index b418df9ba..eec479748 100644 --- a/fairscale/nn/pipe/types.py +++ b/fairscale/nn/pipe/types.py @@ -1,5 +1,10 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + from enum import Enum, auto -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, List, Optional, Tuple, Union from dataclasses import dataclass import torch @@ -34,13 +39,6 @@ class PipelineStyle(Enum): AsyncSchedule = auto() -@dataclass(frozen=True) -class TransportConfig: - use_rpc: bool - worker_map: Optional[Dict[int, str]] - input_device: InputDevice - - @dataclass(init=False) class PipeMessage: src: int diff --git a/fairscale/utils/object.py b/fairscale/utils/object.py new file mode 100644 index 000000000..fbde70e1c --- /dev/null +++ b/fairscale/utils/object.py @@ -0,0 +1,29 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ +import pickle +from typing import Any + +import torch + + +def pyobject_to_tensor(obj: Any, fixed_buffer_size: int = 0) -> torch.Tensor: + pickled = pickle.dumps(obj) + result: torch.Tensor = torch.ByteTensor(bytearray(pickled)) + if fixed_buffer_size: + delta = fixed_buffer_size - len(result) + if delta < 0: + raise ValueError( + f"message too big to send, increase `fixed_buffer_size`? - {len(result)} > {fixed_buffer_size}" + ) + elif delta > 0: + result = torch.cat((result, torch.zeros(delta, dtype=torch.uint8))) + + return result + + +def tensor_to_pyobject(tensor: torch.Tensor) -> Any: + nparray = tensor.cpu().numpy() + return pickle.loads(nparray.tobytes()) diff --git a/tests/nn/pipe_process/test_pipe.py b/tests/nn/pipe_process/test_pipe.py index 24cfceb22..e42da9fef 100644 --- a/tests/nn/pipe_process/test_pipe.py +++ b/tests/nn/pipe_process/test_pipe.py @@ -817,7 +817,7 @@ def forward(self, x): def missing_worker_map(pipeline_style): model = nn.Sequential(nn.ReLU(), nn.ReLU()) - with pytest.raises(ValueError, match="'PipelineStyle.MultiProcess' requires 'worker_map' to be set"): + with pytest.raises(ValueError, match="'RpcTransport' requires 'worker_map' to be set"): Pipe(model, [1, 1], style=pipeline_style) @@ -937,8 +937,8 @@ def reuse_lazy(): def test_instantiate_partition(): - from fairscale.nn.pipe.pipe import instantiate_partition from fairscale.nn.pipe.async_schedule import Location + from fairscale.nn.pipe.pipe import instantiate_partition class FakeGroup: def __init__(self, rank, size): diff --git a/tests/nn/pipe_process/test_rpc.py b/tests/nn/pipe_process/test_rpc.py index dfc1d9a68..e9cc1be27 100644 --- a/tests/nn/pipe_process/test_rpc.py +++ b/tests/nn/pipe_process/test_rpc.py @@ -6,7 +6,7 @@ from torch import nn from torch.distributed import rpc -from fairscale.nn.model_parallel.initialize import get_model_parallel_group, get_pipeline_parallel_group +from fairscale.nn.model_parallel.initialize import get_pipeline_parallel_group from fairscale.nn.pipe import PipeRPCWrapper from tests.nn.model_parallel.commons import get_worker_map, torch_spawn @@ -81,46 +81,23 @@ def check_pipe_against_reference(balance, model_constructor, checkpoint="except_ target = torch.rand(10).cuda() cloned = inputs.clone() output = pipe(inputs) - print(f"out pipe on {torch.distributed.get_rank()}") - mp_group = get_model_parallel_group() - if mp_group.size() > 1: - pass # torch.distributed.barrier(mp_group) ref_out = reference_model(inputs) - print(f"out on {torch.distributed.get_rank()}") - print(f"{ref_out}, {output}, {inputs}, {cloned}") assert torch.equal(ref_out.cpu(), output.cpu()) for out in output, ref_out: - if mp_group.size() > 1: - pass # torch.distributed.barrier(mp_group) - try: - target = target.to(out.device) - loss = nn.MSELoss()(out, target) - loss.backward() - except Exception as e: - print(f"loss failed {e}") - raise e - - print(f"{torch.distributed.get_rank()}: optimizer") + target = target.to(out.device) + loss = nn.MSELoss()(out, target) + loss.backward() + pipe.foreach_worker(step_optimizer, include_self=True) - print(f"{torch.distributed.get_rank()}: optimizer2") step_optimizer(None, reference_model.cuda()) - print(f"{torch.distributed.get_rank()}: eval") pipe.eval() reference_model.eval() - print(f"{torch.distributed.get_rank()}: pipe2") final_output = pipe(inputs) - print(f"{torch.distributed.get_rank()}: ref2") - if mp_group.size() > 1: - pass # torch.distributed.barrier(mp_group) - try: - final_ref = reference_model(inputs.cuda()) - except Exception as e: - 
print(f"ref got {e}") - raise e + final_ref = reference_model(inputs.cuda()) assert torch.equal(final_output.cpu(), final_ref.cpu()) @@ -139,11 +116,9 @@ def model_with_reuse(): reused_1 = nn.Linear(10, 10) return [reused_1, nn.ReLU(), reused_1, nn.ReLU(), reused_1, nn.ReLU()] - print(f"easy") check_pipe_against_reference( [2, 2, 2], lambda: [nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 10), nn.ReLU()], ) - print(f"hard") check_pipe_against_reference([2, 1, 1], model_with_reuse) rpc.shutdown() @@ -199,9 +174,7 @@ def make_model_with_reuse(): return check_pipe_against_reference([4, 4, 2], make_model_simple, "always") - print(f"{torch.distributed.get_rank()} simple returned!") check_pipe_against_reference([4, 2, 2], make_model_with_reuse) - print(f"{torch.distributed.get_rank()} returned!") rpc.shutdown() torch.distributed.barrier() From e240d90f6e7f4350f875d41e738fa71aeba320ef Mon Sep 17 00:00:00 2001 From: Tom Birch Date: Mon, 26 Oct 2020 15:01:15 -0700 Subject: [PATCH 7/8] more cleanup/comments --- fairscale/nn/pipe/async_schedule.py | 67 +++++++- fairscale/nn/pipe/rpc.py | 252 ++++++++++++++++------------ tests/nn/pipe_process/test_rpc.py | 71 ++++---- 3 files changed, 245 insertions(+), 145 deletions(-) diff --git a/fairscale/nn/pipe/async_schedule.py b/fairscale/nn/pipe/async_schedule.py index d999d1664..52448f6f2 100644 --- a/fairscale/nn/pipe/async_schedule.py +++ b/fairscale/nn/pipe/async_schedule.py @@ -42,6 +42,14 @@ class Invocation: Invocations = Dict[int, Invocation] +@dataclass(frozen=True) +class TailBackwardContext: + activations: Activations + invocations: Invocations + count_per_order: Dict[int, int] + expected_gradients: int + + class ModuleWrapper: def __init__(self, module: nn.Sequential, location: Location, invocations: Optional[List[Invocation]] = None): self.module: nn.Sequential = module @@ -123,6 +131,20 @@ def backward(ctx, *grad: Tensor,) -> Tuple[Optional[Tensor], ...]: ), sync=True, ) + + tail_ctx = getattr(ctx, "tail_ctx", None) + if tail_ctx: + expected_gradients = tail_ctx.expected_gradients + while expected_gradients > 0: + message = ctx.transport.recv_message_header(EVENT_LOOP_QUEUE) + + args: AsyncMessageBody = message.args + assert args.message_type is AsyncMessageType.Gradients + + invocation = tail_ctx.invocations[args.order] + expected_gradients -= tail_ctx.count_per_order[invocation.order] + AsyncEventLoop.perform_backward_for_invocation(ctx.transport, message, tail_ctx.activations, invocation) + return (None, None, None, None, None) @@ -190,8 +212,14 @@ def run_invocation( result = self.send_async_message(dst_rank, result, invocation) return result - def async_grad_inner(self, message: PipeMessage, activations: Activations, invocation: Invocation) -> None: - recvd_grads = self.transport.recv_message_tensors(message) + @staticmethod + def perform_backward_for_invocation( + transport: Transport, message: PipeMessage, activations: Activations, invocation: Invocation + ) -> None: + """Perform the backward pass by looking up the appropriate `Batch` and + then calling `backward` on the tensor""" + + recvd_grads = transport.recv_message_tensors(message) batch: Batch = activations[invocation.this.index][invocation.order][message.args.microbatch_index] @@ -271,6 +299,9 @@ def event_loop_head( ) def get_batch_from_message(self, message: PipeMessage) -> Batch: + """Get the tensor(s) wrapped in a `Batch` from a `PipeMessage`, applying + AsyncRecvOperator so we can intercept the backward pass""" + microbatch_index = 
message.args.microbatch_index phony = torch.empty(0, device=self.transport.input_device, requires_grad=True) result = AsyncRecvOperator.apply(phony, self.transport, message) @@ -309,6 +340,11 @@ def event_loop_tail(self, batches: List[Batch], skip_trackers: List[SkipTrackerT actual_invocations += inv_count count_per_order[last_order] = inv_count + if invocations[last_order].dest is None: + self.prepare_tail_backward( + batch, activations, invocations, count_per_order, len(invocations) - inv_count + ) + if actual_invocations < expected_invocations: expected_gradients = 0 # (len(invocations) - 1) * len(batches) @@ -320,6 +356,7 @@ def event_loop_tail(self, batches: List[Batch], skip_trackers: List[SkipTrackerT count_per_order, already_received=actual_invocations, ignore_gradients=True, + tail=True, ) _, last_invocation = invocations.popitem() @@ -362,10 +399,11 @@ def event_loop_inner( already_received: int = 0, ignore_gradients: bool = False, event: Optional[Event] = None, + tail: bool = False, ) -> None: """The common event loop shared by all stages. This processses activations for the forward pass, and if `self.training` is true, - handles gradients as well for the backward pass.""" + processes gradients for the backward pass.""" num_activations = already_received if self.training and not ignore_gradients: @@ -375,6 +413,9 @@ def event_loop_inner( while num_activations < expected_invocations or num_gradients < expected_invocations: if num_activations == expected_invocations and num_gradients == 0 and event is not None: + # We are ready to do the backward pass, but must wait for + # PipeRPCWrapper to signal that it is safe to proceed, otherwise + # deadlock event.wait() message = self.transport.recv_message_header(EVENT_LOOP_QUEUE) @@ -395,8 +436,26 @@ def event_loop_inner( ) count_per_order[last_order] = inv_count num_activations += inv_count + if tail and invocations[last_order].dest is None: + self.prepare_tail_backward( + batch, activations, invocations, count_per_order, len(invocations) - inv_count + ) + assert num_activations <= expected_invocations elif args.message_type is AsyncMessageType.Gradients: num_gradients += count_per_order[invocation.order] - self.async_grad_inner(message, activations, invocation) + self.perform_backward_for_invocation(self.transport, message, activations, invocation) + + @staticmethod + def prepare_tail_backward( + batch: Batch, + activations: Activations, + invocations: Invocations, + count_per_order: Dict[int, int], + expected_gradients: int, + ) -> None: + if expected_gradients > 0: + grad_fn = next(b.grad_fn for b in batch if b.requires_grad) + assert grad_fn + grad_fn.tail_ctx = TailBackwardContext(activations, invocations, count_per_order, expected_gradients) diff --git a/fairscale/nn/pipe/rpc.py b/fairscale/nn/pipe/rpc.py index 58bc1e7e9..dbea851ab 100644 --- a/fairscale/nn/pipe/rpc.py +++ b/fairscale/nn/pipe/rpc.py @@ -32,17 +32,6 @@ def set_device_based_on_group(group: ProcessGroup) -> None: torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count()) -def register_remote_model(args: List[Any], kwargs: Dict[str, Any]) -> None: - group = get_pipeline_parallel_group() # FIXME(tom) handle dynamic group - set_device_based_on_group(group) - kwargs["group"] = group - kwargs["input_device"] = torch.device("cuda", torch.cuda.current_device()) - model = Pipe(*args, **kwargs) - model.cuda() - global PipeModel - PipeModel = model - - def get_shapes(tensor: TensorOrTensors) -> SizeOrSizes: if isinstance(tensor, torch.Tensor): return 
tensor.shape @@ -57,88 +46,19 @@ def get_dtype(tensor: TensorOrTensors) -> DtypeOrDtypes: return [t.dtype for t in tensor] -def model_forward(training: bool, shape: torch.Size, dtype: torch.dtype) -> Optional[Tuple[SizeOrSizes, DtypeOrDtypes]]: - if isinstance(shape, torch.Size): - tensor = torch.empty(shape, dtype=dtype) - else: - tensor = tuple([torch.empty(s, dtype=d) for s, d in zip(shape, dtype)]) - - model = PipeModel - assert model.group - set_device_based_on_group(model.group) - - model.train(training) - result = model(tensor) - if model.final_stage: - global PipeResult - PipeResult = result - return (get_shapes(result), get_dtype(result)) - - return None - - -def send_result_and_do_backwards(training: bool, message: PipeMessage, grads_message: PipeMessage) -> None: - group = get_pipeline_parallel_group() - set_device_based_on_group(group) - result = PipeResult - model = PipeModel - - if isinstance(result, torch.Tensor): - result = tuple([result]) - - message.tensors = tuple(result) - assert model.pipeline - transport = model.pipeline.transport - transport.send_message(message, sync=False, skip_header=True) - - if training: - grads_message.tensor_shapes = [r.shape for r in result] - grads_message.tensor_dtypes = [r.dtype for r in result] - grads_message = transport.recv_message_tensors(grads_message) - - with model.lock: - torch.autograd.backward(result, grads_message.tensors, retain_graph=True) - - -def recv_result(model: Pipe, shapes: SizeOrSizes, dtypes: DtypeOrDtypes, message: PipeMessage) -> TensorOrTensors: - group = get_pipeline_parallel_group() - set_device_based_on_group(group) - - assert model.pipeline - transport = model.pipeline.transport - - if isinstance(shapes, torch.Size): - message.tensor_shapes = [cast(torch.Size, shapes)] - message.tensor_dtypes = [cast(torch.dtype, dtypes)] - message = transport.recv_message_tensors(message) - return message.tensors[0] - else: - message.tensor_shapes = cast(List[torch.Size], shapes) - message.tensor_dtypes = cast(List[torch.dtype], dtypes) - message = transport.recv_message_tensors(message) - return message.tensors - - def get_global_ranks_from_group(group: ProcessGroup) -> List[int]: return [_get_global_rank(group, r) for r in range(group.size())] -def model_forward_first_stage(model: Pipe, tensor: TensorOrTensors, event: Event, lock: Lock) -> None: - t = model.training - with lock: - assert model.group - set_device_based_on_group(model.group) - model(tensor, event=event) - - class PipeBackRedirect(torch.autograd.Function): @staticmethod # type: ignore - def forward(ctx, inputs, dest, event, message, transport): + def forward(ctx, inputs, dest, event, message, transport, futures): ctx.dest = dest ctx.event = event ctx.message = message ctx.transport = transport + ctx.futures = futures return inputs @staticmethod @@ -147,18 +67,30 @@ def backward(ctx, *grad): ctx.message.tensors = tuple(grad) ctx.transport.send_message(ctx.message, sync=False, skip_header=True) ctx.event.set() - return (None, None, None, None, None) + # torch.futures.wait_all(ctx.futures) + return (None, None, None, None, None, None) -def callback_with_model(callback: Callable, ctx: Any) -> None: - group = get_pipeline_parallel_group() # FIXME(tom) handle dynamic group - set_device_based_on_group(group) +def callback_with_model(callback: Callable[[Any, Pipe], None], ctx: Any) -> None: + try: + group = get_pipeline_parallel_group() # FIXME(tom) handle dynamic group + set_device_based_on_group(group) - with PipeModel.lock: - callback(ctx, PipeModel) + with 
PipeModel.lock: + callback(ctx, PipeModel) + except Exception as e: + print(f"callback_with_model got {e}") class PipeRPCWrapper(nn.Module): + """A wrapper for Pipe to control the entire pipeline from a single process. + Typical usecase would have rank 0 construct `PipeRPCWrapper` and run the + training loop as normal, and all other ranks would call + `torch.distributed.rpc.shutdown()` + + To run code on each worker, e.g. to run the optimizer, use `foreach_worker` + """ + def __init__(self, *args: Any, **kwargs: Any): super().__init__() self.group = cast(ProcessGroup, kwargs.get("group")) or get_pipeline_parallel_group() @@ -178,23 +110,35 @@ def __init__(self, *args: Any, **kwargs: Any): self.model = Pipe(*args, **kwargs) self.worker_map = kwargs["worker_map"] - futures = [ - # FIXME get global rank - rpc.rpc_async(self.get_rpc_name(rank), register_remote_model, args=(args, kwargs)) - for rank in range(1, self.group.size()) - ] - futures = [f.wait() for f in futures] + self._foreach_worker(self._register_remote_model, args=(args, kwargs)) self.model.cuda() - def get_rpc_name(self, rank: int) -> str: + def _get_rpc_name(self, rank: int) -> str: return self.worker_map[_get_global_rank(self.group, rank)] - def foreach_worker(self, callback: Callable, ctx: Any = None, *, include_self: bool = False) -> None: - futures = [ - rpc.rpc_async(self.get_rpc_name(rank), callback_with_model, args=(callback, ctx)) - for rank in range(1, self.group.size()) - ] + def _foreach_worker(self, callback: Callable, args: Any = None) -> None: + futures = [rpc.rpc_async(self._get_rpc_name(rank), callback, args=args) for rank in range(1, self.group.size())] futures = [f.wait() for f in futures] + + def foreach_worker( + self, callback: Callable[[Any, Pipe], None], ctx: Any = None, *, include_self: bool = False + ) -> None: + """Call `callback` on each worker with the `ctx` and model local to that + worker. e.g. + def register_optimizer(ctx, model): + args, kwargs = ctx + model.optimizer = torch.optim.SGD(model.parameters(), *args, **kwargs) + + pipe_model = PipeRPCWrapper( ... 
) + + pipe_model.foreach_worker( + register_optimizer, + ([], {"lr" : 0.01, "momentum" : 0.9}) + ) + """ + + self._foreach_worker(callback_with_model, args=(callback, ctx)) + if include_self: with self.model.lock: callback(ctx, self.model) @@ -209,7 +153,7 @@ def forward(self, tensor: TensorOrTensors) -> TensorOrTensors: # type: ignore num_tensors = len(tensor) futures = [ - rpc.rpc_async(self.get_rpc_name(rank), model_forward, args=(self.model.training, shape, dtype)) + rpc.rpc_async(self._get_rpc_name(rank), self._model_forward, args=(self.model.training, shape, dtype)) for rank in range(1, self.group.size()) ] @@ -217,12 +161,12 @@ def forward(self, tensor: TensorOrTensors) -> TensorOrTensors: # type: ignore return self.model(tensor) else: event = Event() - t = Thread(target=model_forward_first_stage, args=(self.model, tensor, event, self.lock)) + t = Thread(target=self._model_forward_first_stage, args=(tensor, event)) t.start() - shape, dtype = futures[-1].wait() + shape, dtype = futures.pop().wait() dest_rank = self.group.size() - 1 - dest = self.get_rpc_name(dest_rank) + dest = self._get_rpc_name(dest_rank) dest_global_rank = _get_global_rank(self.group, dest_rank) src_global_rank = torch.distributed.get_rank() queue = EVENT_LOOP_QUEUE @@ -230,8 +174,12 @@ def forward(self, tensor: TensorOrTensors) -> TensorOrTensors: # type: ignore activations = PipeMessage(dest_global_rank, src_global_rank, queue_name=queue, tensor_count=num_tensors) grads = PipeMessage(src_global_rank, dest_global_rank, queue_name=queue, tensor_count=num_tensors) - rpc.rpc_async(dest, send_result_and_do_backwards, args=(self.model.training, activations, grads)) - result = recv_result(self.model, shape, dtype, activations) + back_fut = rpc.rpc_async( + dest, self._send_result_and_do_backwards, args=(self.model.training, activations, grads) + ) + futures.append(back_fut) + + result = self._recv_result(self.model, shape, dtype, activations) if isinstance(result, torch.Tensor): result.requires_grad_() else: @@ -239,8 +187,98 @@ def forward(self, tensor: TensorOrTensors) -> TensorOrTensors: # type: ignore r.requires_grad_() assert self.model.pipeline - return PipeBackRedirect.apply(result, dest_global_rank, event, grads, self.model.pipeline.transport) + return PipeBackRedirect.apply( + result, dest_global_rank, event, grads, self.model.pipeline.transport, futures + ) @property def final_stage(self) -> bool: return self.model.final_stage + + @staticmethod + def _recv_result(model: Pipe, shapes: SizeOrSizes, dtypes: DtypeOrDtypes, message: PipeMessage) -> TensorOrTensors: + group = get_pipeline_parallel_group() + set_device_based_on_group(group) + + assert model.pipeline + transport = model.pipeline.transport + + if isinstance(shapes, torch.Size): + message.tensor_shapes = [cast(torch.Size, shapes)] + message.tensor_dtypes = [cast(torch.dtype, dtypes)] + message = transport.recv_message_tensors(message) + return message.tensors[0] + else: + message.tensor_shapes = cast(List[torch.Size], shapes) + message.tensor_dtypes = cast(List[torch.dtype], dtypes) + message = transport.recv_message_tensors(message) + return message.tensors + + @staticmethod + def _send_result_and_do_backwards(training: bool, message: PipeMessage, grads_message: PipeMessage) -> None: + group = get_pipeline_parallel_group() + set_device_based_on_group(group) + result = PipeResult + model = PipeModel + + if isinstance(result, torch.Tensor): + result = tuple([result]) + + message.tensors = tuple(result) + assert model.pipeline + transport = 
model.pipeline.transport + transport.send_message(message, sync=False, skip_header=True) + + if training: + grads_message.tensor_shapes = [r.shape for r in result] + grads_message.tensor_dtypes = [r.dtype for r in result] + grads_message = transport.recv_message_tensors(grads_message) + + with model.lock: + torch.autograd.backward(result, grads_message.tensors, retain_graph=True) + + @staticmethod + def _register_remote_model(args: List[Any], kwargs: Dict[str, Any]) -> None: + group = get_pipeline_parallel_group() # FIXME(tom) handle dynamic group + set_device_based_on_group(group) + kwargs["group"] = group + kwargs["input_device"] = torch.device("cuda", torch.cuda.current_device()) + model = Pipe(*args, **kwargs) + model.cuda() + global PipeModel + PipeModel = model + + @staticmethod + def _model_forward( + training: bool, shape: torch.Size, dtype: torch.dtype + ) -> Optional[Tuple[SizeOrSizes, DtypeOrDtypes]]: + try: + if isinstance(shape, torch.Size): + tensor = torch.empty(shape, dtype=dtype) + else: + tensor = tuple([torch.empty(s, dtype=d) for s, d in zip(shape, dtype)]) + + model = PipeModel + assert model.group + set_device_based_on_group(model.group) + + model.train(training) + result = model(tensor) + if model.final_stage: + global PipeResult + PipeResult = result + return (get_shapes(result), get_dtype(result)) + + return None + except Exception as e: + print(f"_model_forward got {e}") + raise e + + def _model_forward_first_stage(self, tensor: TensorOrTensors, event: Event) -> None: + try: + assert self.model.group + set_device_based_on_group(self.model.group) + self.model(tensor, event=event) + except Exception as e: + print(f"_model_forward got {e}") + raise e diff --git a/tests/nn/pipe_process/test_rpc.py b/tests/nn/pipe_process/test_rpc.py index e9cc1be27..0f135dd7c 100644 --- a/tests/nn/pipe_process/test_rpc.py +++ b/tests/nn/pipe_process/test_rpc.py @@ -62,7 +62,7 @@ def step_optimizer(ctx, model): model.optimizer.step() -def check_pipe_against_reference(balance, model_constructor, checkpoint="except_last"): +def check_pipe_against_reference(balance, model_constructor, checkpoint="except_last", custom_inputs=None): model = model_constructor() reference_model = model_constructor() for src, dst in zip(model, reference_model): @@ -71,7 +71,7 @@ def check_pipe_against_reference(balance, model_constructor, checkpoint="except_ reference_model = nn.Sequential(*reference_model).cuda() pipe = PipeRPCWrapper( - model, balance, input_device=torch.cuda.current_device(), worker_map=get_worker_map(), checkpoint=checkpoint + model, balance, input_device=torch.cuda.current_device(), worker_map=get_worker_map(), checkpoint=checkpoint, ) pipe.foreach_worker(register_optimizer, include_self=True) @@ -182,39 +182,30 @@ def make_model_with_reuse(): @torch_spawn([3]) @pytest.mark.skipif("OMPI_COMM_WORLD_RANK" not in os.environ, reason="mpi required") -def rpc_deadlock(): +def rpc_reuse_in_final_stage(): + + # 'reused' and 'reused2' are located on stage 2, so the backward pass for + # the final stage will need to first send gradients to stage 2, then receive + # gradients from stage 2. This tests custom logic to handle reuse of layers + # in the final stage of the pipeline. 
+ reused = nn.Linear(10, 10) - if False: - reused2 = nn.Linear(10, 10) - model = [ - nn.Linear(10, 10), - nn.ReLU(), - nn.Linear(10, 10), - reused2, - nn.ReLU(), - reused, - nn.ReLU(), - reused, - reused2, - nn.ReLU(), - reused, - nn.ReLU(), - ] - balance = [2, 3, 4] - else: - model = [ - nn.Linear(10, 10), - nn.ReLU(), - nn.Linear(10, 10), - nn.ReLU(), - reused, - nn.ReLU(), - reused, - nn.ReLU(), - reused, - nn.ReLU(), - ] - balance = [2, 2, 4] + reused2 = nn.Linear(10, 10) + model = [ + nn.Linear(10, 10), + nn.ReLU(), + nn.Linear(10, 10), + reused2, + nn.ReLU(), + reused, + nn.ReLU(), + reused, + reused2, + nn.ReLU(), + reused, + nn.ReLU(), + ] + balance = [2, 3, 4] init_rpc() @@ -235,6 +226,18 @@ def rpc_deadlock(): torch.distributed.barrier() +@torch_spawn([3]) +@pytest.mark.skipif("OMPI_COMM_WORLD_RANK" not in os.environ, reason="mpi required") +def rpc_multiple_tensors(): + class FuseTwo(nn.Module): + def forward(self, left, right): + return left + right + + class SplitTwo(nn.Module): + def forward(self, inputs): + return (inputs, 2 * inputs) + + @torch_spawn([2]) @pytest.mark.skipif("OMPI_COMM_WORLD_RANK" in os.environ, reason="no mpi") @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required") From 94bb946fb6cf1148fa8703154af9a72bc9ac4dc8 Mon Sep 17 00:00:00 2001 From: Tom Birch Date: Mon, 9 Nov 2020 14:36:27 -0800 Subject: [PATCH 8/8] Add mp and rpc pipe tutorials --- examples/tutorial_pipe_multiprocess.py | 62 +++++++++++++++++++++ examples/tutorial_pipe_rpc.py | 76 ++++++++++++++++++++++++++ fairscale/nn/__init__.py | 4 +- pyproject.toml | 2 +- 4 files changed, 141 insertions(+), 3 deletions(-) create mode 100644 examples/tutorial_pipe_multiprocess.py create mode 100644 examples/tutorial_pipe_rpc.py diff --git a/examples/tutorial_pipe_multiprocess.py b/examples/tutorial_pipe_multiprocess.py new file mode 100644 index 000000000..f57f1ec92 --- /dev/null +++ b/examples/tutorial_pipe_multiprocess.py @@ -0,0 +1,62 @@ +import os + +import torch +import torch.multiprocessing as mp +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim + +import fairscale +from fairscale.nn.model_parallel import initialize_model_parallel + + +def run(rank, world_size): + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "10638" + torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size) + os.environ["MASTER_PORT"] = "10639" + torch.distributed.rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size) + initialize_model_parallel(1, world_size) + + model = nn.Sequential(torch.nn.Linear(10, 10), torch.nn.ReLU(), torch.nn.Linear(10, 5)) + target = torch.randint(0, 2, size=(20, 1)).squeeze() + data = torch.randn(20, 10) + loss_fn = F.nll_loss + + device = torch.device("cuda", rank) + + model = fairscale.nn.Pipe( + model, + balance=[2, 1], + style=fairscale.nn.Pipe.MultiProcess, + worker_map={0: "worker0", 1: "worker1"}, # Needed to convert ranks to RPC worker names + input_device=device, + ).to(device) + + # define optimizer and loss function + optimizer = optim.SGD(model.parameters(), lr=0.001) + + # zero the parameter gradients + optimizer.zero_grad() + + # outputs and target need to be on the same device + # forward step + outputs = model(data.to(device)) + # compute loss + if rank == 1: + loss = loss_fn(outputs.to(device), target.to(device)) + + # backward + optimize + loss.backward() + optimizer.step() + else: + model.back_helper(outputs) + + print(f"Finished Training Step on {rank}") + + del model + + 
+if __name__ == "__main__": + world_size = 2 + mp.spawn(run, args=(world_size,), nprocs=world_size, join=True) diff --git a/examples/tutorial_pipe_rpc.py b/examples/tutorial_pipe_rpc.py new file mode 100644 index 000000000..fc22725cb --- /dev/null +++ b/examples/tutorial_pipe_rpc.py @@ -0,0 +1,76 @@ +# run with: +# mpirun -np 2 --host localhost:2 -x PYTHONPATH=$PWD python # examples/tutorial_pipe_rpc.py + +import os + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import torch_pg + +import fairscale +from fairscale.nn.model_parallel import initialize_model_parallel + + +def register_optimizer(ctx, model): + # Set the optimizer as an attribute on the model so we can access it later + model.optimizer = optim.SGD(model.parameters(), **ctx) + # zero the parameter gradients + model.optimizer.zero_grad() + + +def run_optimizer(ctx, model): + model.optimizer.step() + + +def run(rank, world_size): + torch_pg.init_mpi() + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "10638" + torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size) + os.environ["MASTER_PORT"] = "10639" + torch.distributed.rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size) + initialize_model_parallel(1, world_size, pipeline_backend="mpi") + + if rank == 1: + # For RPC, all ranks other than 0 just need to call rpc.shutdown() + torch.distributed.rpc.shutdown() + return + + model = nn.Sequential(torch.nn.Linear(10, 10), torch.nn.ReLU(), torch.nn.Linear(10, 5)) + target = torch.randint(0, 2, size=(20, 1)).squeeze() + data = torch.randn(20, 10) + loss_fn = F.nll_loss + + device = torch.device("cuda", rank) + + model = fairscale.nn.PipeRPCWrapper( + model, + balance=[2, 1], + worker_map={0: "worker0", 1: "worker1"}, # Needed to convert ranks to RPC worker names + input_device=device, + ).to(device) + + # We can't directly access the model on each worker, so we need to call + # foreach_worker with a callback to setup the optimizer + model.foreach_worker(register_optimizer, {"lr": 0.001}, include_self=True) + + outputs = model(data.to(device)) + loss = loss_fn(outputs.to(device), target.to(device)) + loss.backward() + + # Same as earlier, use foreach_worker to step the optimizer on each rank + model.foreach_worker(run_optimizer, include_self=True) + + print(f"Finished Training Step on {rank}") + + torch.distributed.rpc.shutdown() + + del model + + +if __name__ == "__main__": + rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) + world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"]) + run(rank, world_size) diff --git a/fairscale/nn/__init__.py b/fairscale/nn/__init__.py index 690e26fc5..45690576f 100644 --- a/fairscale/nn/__init__.py +++ b/fairscale/nn/__init__.py @@ -4,6 +4,6 @@ # LICENSE file in the root directory of this source tree. from .moe import MOELayer, Top2Gate -from .pipe import LazyModule, Pipe +from .pipe import LazyModule, Pipe, PipeRPCWrapper -__all__ = ["Pipe", "Top2Gate", "LazyModule"] +__all__ = ["Pipe", "PipeRPCWrapper", "Top2Gate", "LazyModule"] diff --git a/pyproject.toml b/pyproject.toml index 8e9afb654..8010c95e1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,4 +28,4 @@ use_parentheses = true skip_glob = ["build/*", "stubs/*"] # Don't split "import" and "from". 
 force_sort_within_sections = true
-known_third_party = ["benchmark_dataset", "dataclasses", "numpy", "packaging", "pytest", "recommonmark", "setuptools", "torch", "torchtext", "torchvision"]
+known_third_party = ["benchmark_dataset", "dataclasses", "numpy", "packaging", "pytest", "recommonmark", "setuptools", "torch", "torch_pg", "torchtext", "torchvision"]
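
Note on the serialization helpers added in fairscale/utils/object.py earlier in this series: pyobject_to_tensor and tensor_to_pyobject pickle an arbitrary Python object into a uint8 tensor, optionally padded to a fixed size so it can be sent over a tensor-only transport with a known buffer length. A minimal standalone round-trip sketch follows; the payload contents and the 2048-byte buffer size are illustrative assumptions, not values taken from the patches.

# Round-trip sketch for pyobject_to_tensor / tensor_to_pyobject.
# Assumes a fairscale checkout that includes fairscale/utils/object.py from this series.
import torch

from fairscale.utils.object import pyobject_to_tensor, tensor_to_pyobject

# Hypothetical header-style payload; any picklable object works.
payload = {"queue": 3, "shapes": [torch.Size([4, 10])], "dtypes": [torch.float32]}

# Pad to a fixed size so the receiving side can post a buffer of known length.
buf = pyobject_to_tensor(payload, fixed_buffer_size=2048)
assert buf.dtype == torch.uint8 and len(buf) == 2048

# pickle stops at its STOP opcode, so the trailing zero padding is ignored on decode.
restored = tensor_to_pyobject(buf)
assert restored == payload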