Single-process control via PipeRPCWrapper (#156)
Adds support for:
* Reused layers (e.g. for weight sharing)
* Lazily-constructed layers
* Single-process control via PipeRPCWrapper
* PipelineStyle.AsyncSchedule, which lays the foundation for asynchronous pipeline work by introducing an event loop on each rank/worker that processes activations or gradients as they arrive

Also adds examples for multi-process pipelines and PipeRPCWrapper
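
For context, lazy construction means handing Pipe a list of LazyModule wrappers instead of already-built modules, so a layer is only instantiated on the rank that ends up owning it. Below is a minimal sketch mirroring the make_model changes in benchmarks/pipe.py further down; the Linear layers, balance, and chunks values are illustrative stand-ins, not part of this commit:

import torch.nn as nn

from fairscale.nn import Pipe
from fairscale.nn.pipe import LazyModule

# Describe the model as deferred constructors; nothing is allocated yet.
layers = [
    LazyModule(lambda: nn.Linear(32, 32)),
    LazyModule(lambda: nn.Linear(32, 32)),
    LazyModule(lambda: nn.Linear(32, 10)),
]

# Pipe partitions the list according to `balance`; each LazyModule is only
# called when its partition is materialized.
model = Pipe(layers, balance=[2, 1], chunks=4)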
froody authored Nov 10, 2020
1 parent 543d569 commit 5d4f50f
Showing 38 changed files with 2,358 additions and 568 deletions.
125 changes: 104 additions & 21 deletions benchmarks/pipe.py
@@ -1,6 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import argparse
import logging
import math
import os
import time
@@ -11,14 +12,17 @@
from torch.distributed import rpc
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
import torchtext
from torchtext.data.utils import get_tokenizer

from fairscale.nn import Pipe
from fairscale.nn.model_parallel import initialize_model_parallel
from fairscale.nn.pipe import pipe
from fairscale.nn.model_parallel.initialize import get_data_parallel_group, get_pipeline_parallel_group
from fairscale.nn.pipe import LazyModule, pipe
from fairscale.optim import GradScaler
from fairscale.optim.oss import OSS
from tests.nn.model_parallel.commons import dist_init, get_worker_map

try:
@@ -164,13 +168,13 @@ def make_model(args, device, ntokens):

if args.lazy_construction:
layers = [
lambda: EmbeddingLayer(ntokens, ninp, initrange),
lambda: PositionalEncodingLayer(ninp, dropout),
LazyModule(lambda: EmbeddingLayer(ntokens, ninp, initrange)),
LazyModule(lambda: PositionalEncodingLayer(ninp, dropout)),
]
for _ in range(ndecoder):
layers.append(lambda: TransformerDecoderLayer(ninp, nhead, nhid, dropout))
layers.append(LazyModule(lambda: TransformerDecoderLayer(ninp, nhead, nhid, dropout)))

layers.append(lambda: LinearLayer(ninp, ntokens, initrange))
layers.append(LazyModule(lambda: LinearLayer(ninp, ntokens, initrange)))
model = layers
else:
model = TransformerLMSequntial(ntokens, ninp, nhead, nhid, dropout, initrange, ndecoder).to(device)
@@ -179,7 +183,10 @@ def make_model(args, device, ntokens):
lr = 0.01 # learning rate

def make_adam(model):
return Adam(model.parameters(), lr=lr)
if args.ddp_zero:
return OSS(params=model.parameters(), optim=Adam, group=get_data_parallel_group(), lr=lr)
else:
return Adam(model.parameters(), lr=lr)

optimizer = make_adam
scaler = GradScaler()
@@ -276,9 +283,17 @@ def train(lm_dataloader, model, criterion, optimizer, vocab_size, args):

num_params = reduce(operator.add, (reduce(operator.mul, x.size()) for x in model.parameters()))
if model.group:
print(f"training model, #prams = {num_params}, group: {model.group.rank()}, sizes {model.group.size()}")
total = torch.Tensor([num_params]).cuda()
torch.distributed.all_reduce(total, group=model.group)
logging.info(
f"training model, #prams = {num_params}, group: {model.group.rank()}, grank:"
f" {torch.distributed.get_rank()}, sizes {model.group.size()}"
)
torch.distributed.barrier()
if model.group.rank() == 0:
logging.info(f"total #prams = {total.item()}")
else:
print(f"training model, #prams = {num_params}")
logging.info(f"training model, #prams = {num_params}")
vocab_size = 10000 # FIXME
total_loss = 0.0
start_time = time.time()
@@ -287,37 +302,81 @@ def train(lm_dataloader, model, criterion, optimizer, vocab_size, args):
optimizer = optimizer(model)

def get_first_device(model):
if isinstance(model, DDP):
model = model.module

if model.devices:
return model.devices[0]
else:
return torch.cuda.current_device()

def get_last_device(model):
if isinstance(model, DDP):
model = model.module
if model.devices:
return model.devices[-1]
else:
return torch.cuda.current_device()

pipe_group = model.group

if args.ddp_zero:
model = DDP(
model,
device_ids=[torch.cuda.current_device()],
process_group=get_data_parallel_group(),
find_unused_parameters=False,
)

if pipe_group and pipe_group.rank() != 0 and pipe_group.rank() != (pipe_group.size() - 1):
thing = {"input": torch.zeros(args.batch_size)}

class FakeDataset:
def __getitem__(self, index):
return thing

def __len__(self):
return len(lm_dataloader)

lm_dataloader = FakeDataset()

for i, batch in enumerate(lm_dataloader):
bi = batch["input"]
if args.max_batch and i > args.max_batch:
break
optimizer.zero_grad()
output = model(batch["input"].to(get_first_device(model)))

if model.group is None or model.group.rank() == model.group.size() - 1:
try:
if (pipe_group is None or pipe_group.rank() == 0) and not args.ddp_zero:
tmp = batch["input"].to(get_first_device(model))
output = model(tmp)
else:
output = model(batch["input"])
except Exception as e:
raise RuntimeError(f"training failed on {torch.distributed.get_rank()}") from e

if pipe_group is None or pipe_group.rank() == pipe_group.size() - 1:
target = batch["target"].to(get_last_device(model))
output = output.to(target.device)

loss = criterion(output.view(-1, vocab_size), target.view(-1))
if args.ddp_zero:
ddp_group = get_data_parallel_group()
torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM, group=ddp_group)
loss /= ddp_group.size()
loss.backward()
del target
else:
model.back_helper(output)
if args.ddp_zero:
model.module.back_helper(output)
else:
model.back_helper(output)

del output

torch.nn.utils.clip_grad_value_(model.parameters(), 0.05)
optimizer.step()

if model.group is None or model.group.rank() == model.group.size() - 1:
if pipe_group is None or pipe_group.rank() == pipe_group.size() - 1:
total_loss += loss.item()
log_interval = 1
word_counter += batch["ntokens"]
@@ -406,6 +465,17 @@ def benchmark_language_model(train_data, val_data, test_data, model, criterion,
print("No regression detected")


def generate_balance_weighted(num_devices, num_layers, fraction=0.5):
balance = []
layers_assigned = 0
average_count = num_layers / num_devices
last_layers = int(average_count * fraction)

balance = generate_balance(num_devices - 1, num_layers - last_layers)
balance.append(last_layers)
return balance


def generate_balance(num_devices, num_layers):
balance = []
layers_assigned = 0
@@ -460,7 +530,7 @@ def bench_single_process(args):
blob = make_model_and_data(args, None, new_data=new_data)
model = blob["model"]

balance = generate_balance(min(num_devices, 8), len(model))
balance = generate_balance(min(num_devices, 4), len(model))
p = pipe.Pipe(
model, balance, chunks=args.chunks, pipelined_backward=args.pipelined_backward, checkpoint=args.checkpoint
)
@@ -480,16 +550,17 @@ def run_mp_worker(args, available_workers):
blob = make_model_and_data(args, None, new_data=new_data)
model = blob["model"]

balance = generate_balance(min(available_workers, 8), len(model))
balance = generate_balance_weighted(get_pipeline_parallel_group().size(), len(model), 0.8)
p = pipe.Pipe(
model,
balance,
style=Pipe.MultiProcess,
style=Pipe.AsyncSchedule,
chunks=args.chunks,
worker_map=get_worker_map(),
input_device=torch.cuda.current_device(),
pipelined_backward=args.pipelined_backward,
checkpoint=args.checkpoint,
# loss_fn=blob["criterion"],
).cuda()

if args.all_at_once and p.pipeline:
@@ -537,18 +608,24 @@ def bench_multi_process(args, all_at_once=False):

def bench_mpi(args):
guess_rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
os.environ["UCX_NET_DEVICES"] = best_device_map[guess_rank]
world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
os.environ["UCX_NET_DEVICES"] = best_device_map[local_rank]

torch.distributed.init_process_group(backend="mpi")
os.environ["MASTER_ADDR"] = args.host
os.environ["MASTER_PORT"] = "10639"
os.environ["MASTER_PORT"] = "10638"
if args.socket_name:
os.environ["GLOO_SOCKET_IFNAME"] = args.socket_name
os.environ["TP_SOCKET_IFNAME"] = args.socket_name

torch.distributed.init_process_group(backend="gloo", rank=guess_rank, world_size=world_size)

os.environ["MASTER_ADDR"] = args.host
os.environ["MASTER_PORT"] = "10639"
init_method = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}"
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
torch.cuda.set_device(rank % torch.cuda.device_count())
torch.cuda.set_device(local_rank % torch.cuda.device_count())

rpc.init_rpc(
f"Test{rank}",
Expand All @@ -558,7 +635,12 @@ def bench_mpi(args):
rpc_backend_options=rpc.ProcessGroupRpcBackendOptions(rpc_timeout=20, init_method=init_method),
)

initialize_model_parallel(1, world_size)
backends = {"model_parallel_backend": "nccl", "pipeline_backend": "mpi", "ddp_backend": "nccl"}

if args.ddp_zero:
initialize_model_parallel(1, 4, **backends)
else:
initialize_model_parallel(1, world_size, **backends)
init_random_seed(0)

run_mp_worker(args, world_size)
@@ -579,6 +661,7 @@
parser.add_argument("--max-batch", type=int, default=4, help="Max number of batches")
parser.add_argument("--socket-name", type=str, default=None, help="socket ifname for gloo/tp")
parser.add_argument("--num-decoder-layers", type=int, default=10, help="Number of decoder layers in the model")
parser.add_argument("--ddp-zero", action="store_true", default=False, help="enable ddp")
parser.add_argument(
"--lazy-construction", action="store_true", default=False, help="Number of decoder layers in the model"
)
5 changes: 4 additions & 1 deletion docs/Makefile
@@ -12,7 +12,10 @@ BUILDDIR = build
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile
setup:
pip install -r requirements.txt

.PHONY: help Makefile setup

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
62 changes: 62 additions & 0 deletions examples/tutorial_pipe_multiprocess.py
@@ -0,0 +1,62 @@
import os

import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import fairscale
from fairscale.nn.model_parallel import initialize_model_parallel


def run(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "10638"
    torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    os.environ["MASTER_PORT"] = "10639"
    torch.distributed.rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)
    initialize_model_parallel(1, world_size)

    model = nn.Sequential(torch.nn.Linear(10, 10), torch.nn.ReLU(), torch.nn.Linear(10, 5))
    target = torch.randint(0, 2, size=(20, 1)).squeeze()
    data = torch.randn(20, 10)
    loss_fn = F.nll_loss

    device = torch.device("cuda", rank)

    model = fairscale.nn.Pipe(
        model,
        balance=[2, 1],
        style=fairscale.nn.Pipe.MultiProcess,
        worker_map={0: "worker0", 1: "worker1"},  # Needed to convert ranks to RPC worker names
        input_device=device,
    ).to(device)

    # define optimizer and loss function
    optimizer = optim.SGD(model.parameters(), lr=0.001)

    # zero the parameter gradients
    optimizer.zero_grad()

    # outputs and target need to be on the same device
    # forward step
    outputs = model(data.to(device))
    # compute loss
    if rank == 1:
        loss = loss_fn(outputs.to(device), target.to(device))

        # backward + optimize
        loss.backward()
        optimizer.step()
    else:
        model.back_helper(outputs)

    print(f"Finished Training Step on {rank}")

    del model


if __name__ == "__main__":
    world_size = 2
    mp.spawn(run, args=(world_size,), nprocs=world_size, join=True)
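
For reference, the example above drives both pipeline stages from a single script via mp.spawn, so it can be launched with a plain python invocation (assuming at least two visible CUDA devices, since the NCCL process group and balance=[2, 1] place one stage on each GPU):

python examples/tutorial_pipe_multiprocess.py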
