Pipe rpc #156

Merged 8 commits on Nov 10, 2020

125 changes: 104 additions & 21 deletions benchmarks/pipe.py
@@ -1,6 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import argparse
import logging
import math
import os
import time
@@ -11,14 +12,17 @@
from torch.distributed import rpc
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
import torchtext
from torchtext.data.utils import get_tokenizer

from fairscale.nn import Pipe
from fairscale.nn.model_parallel import initialize_model_parallel
from fairscale.nn.pipe import pipe
from fairscale.nn.model_parallel.initialize import get_data_parallel_group, get_pipeline_parallel_group
from fairscale.nn.pipe import LazyModule, pipe
from fairscale.optim import GradScaler
from fairscale.optim.oss import OSS
from tests.nn.model_parallel.commons import dist_init, get_worker_map

try:
@@ -164,13 +168,13 @@ def make_model(args, device, ntokens):

if args.lazy_construction:
layers = [
lambda: EmbeddingLayer(ntokens, ninp, initrange),
lambda: PositionalEncodingLayer(ninp, dropout),
LazyModule(lambda: EmbeddingLayer(ntokens, ninp, initrange)),
LazyModule(lambda: PositionalEncodingLayer(ninp, dropout)),
]
for _ in range(ndecoder):
layers.append(lambda: TransformerDecoderLayer(ninp, nhead, nhid, dropout))
layers.append(LazyModule(lambda: TransformerDecoderLayer(ninp, nhead, nhid, dropout)))

layers.append(lambda: LinearLayer(ninp, ntokens, initrange))
layers.append(LazyModule(lambda: LinearLayer(ninp, ntokens, initrange)))
model = layers
else:
model = TransformerLMSequntial(ntokens, ninp, nhead, nhid, dropout, initrange, ndecoder).to(device)
@@ -179,7 +183,10 @@ def make_model(args, device, ntokens):
lr = 0.01 # learning rate

def make_adam(model):
return Adam(model.parameters(), lr=lr)
if args.ddp_zero:
return OSS(params=model.parameters(), optim=Adam, group=get_data_parallel_group(), lr=lr)
else:
return Adam(model.parameters(), lr=lr)

optimizer = make_adam
scaler = GradScaler()
@@ -276,9 +283,17 @@ def train(lm_dataloader, model, criterion, optimizer, vocab_size, args):

num_params = reduce(operator.add, (reduce(operator.mul, x.size()) for x in model.parameters()))
if model.group:
print(f"training model, #prams = {num_params}, group: {model.group.rank()}, sizes {model.group.size()}")
total = torch.Tensor([num_params]).cuda()
torch.distributed.all_reduce(total, group=model.group)
logging.info(
f"training model, #prams = {num_params}, group: {model.group.rank()}, grank:"
f" {torch.distributed.get_rank()}, sizes {model.group.size()}"
)
torch.distributed.barrier()
if model.group.rank() == 0:
logging.info(f"total #prams = {total.item()}")
else:
print(f"training model, #prams = {num_params}")
logging.info(f"training model, #prams = {num_params}")
vocab_size = 10000 # FIXME
total_loss = 0.0
start_time = time.time()
@@ -287,37 +302,81 @@ def train(lm_dataloader, model, criterion, optimizer, vocab_size, args):
optimizer = optimizer(model)

def get_first_device(model):
if isinstance(model, DDP):
model = model.module

if model.devices:
return model.devices[0]
else:
return torch.cuda.current_device()

def get_last_device(model):
if isinstance(model, DDP):
model = model.module
if model.devices:
return model.devices[-1]
else:
return torch.cuda.current_device()

pipe_group = model.group

if args.ddp_zero:
model = DDP(
model,
device_ids=[torch.cuda.current_device()],
process_group=get_data_parallel_group(),
find_unused_parameters=False,
)

if pipe_group and pipe_group.rank() != 0 and pipe_group.rank() != (pipe_group.size() - 1):
thing = {"input": torch.zeros(args.batch_size)}

class FakeDataset:
def __getitem__(self, index):
return thing

def __len__(self):
return len(lm_dataloader)

lm_dataloader = FakeDataset()

for i, batch in enumerate(lm_dataloader):
bi = batch["input"]
if args.max_batch and i > args.max_batch:
break
optimizer.zero_grad()
output = model(batch["input"].to(get_first_device(model)))

if model.group is None or model.group.rank() == model.group.size() - 1:
try:
if (pipe_group is None or pipe_group.rank() == 0) and not args.ddp_zero:
tmp = batch["input"].to(get_first_device(model))
output = model(tmp)
else:
output = model(batch["input"])
except Exception as e:
raise RuntimeError(f"training failed on {torch.distributed.get_rank()}") from e

if pipe_group is None or pipe_group.rank() == pipe_group.size() - 1:
target = batch["target"].to(get_last_device(model))
output = output.to(target.device)

loss = criterion(output.view(-1, vocab_size), target.view(-1))
if args.ddp_zero:
ddp_group = get_data_parallel_group()
torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM, group=ddp_group)
loss /= ddp_group.size()
loss.backward()
del target
else:
model.back_helper(output)
if args.ddp_zero:
model.module.back_helper(output)
else:
model.back_helper(output)

del output

torch.nn.utils.clip_grad_value_(model.parameters(), 0.05)
optimizer.step()

if model.group is None or model.group.rank() == model.group.size() - 1:
if pipe_group is None or pipe_group.rank() == pipe_group.size() - 1:
total_loss += loss.item()
log_interval = 1
word_counter += batch["ntokens"]
@@ -406,6 +465,17 @@ def benchmark_language_model(train_data, val_data, test_data, model, criterion,
print("No regression detected")


def generate_balance_weighted(num_devices, num_layers, fraction=0.5):
balance = []
layers_assigned = 0
average_count = num_layers / num_devices
last_layers = int(average_count * fraction)

balance = generate_balance(num_devices - 1, num_layers - last_layers)
balance.append(last_layers)
return balance


def generate_balance(num_devices, num_layers):
balance = []
layers_assigned = 0
@@ -460,7 +530,7 @@ def bench_single_process(args):
blob = make_model_and_data(args, None, new_data=new_data)
model = blob["model"]

balance = generate_balance(min(num_devices, 8), len(model))
balance = generate_balance(min(num_devices, 4), len(model))
p = pipe.Pipe(
model, balance, chunks=args.chunks, pipelined_backward=args.pipelined_backward, checkpoint=args.checkpoint
)
@@ -480,16 +550,17 @@ def run_mp_worker(args, available_workers):
blob = make_model_and_data(args, None, new_data=new_data)
model = blob["model"]

balance = generate_balance(min(available_workers, 8), len(model))
balance = generate_balance_weighted(get_pipeline_parallel_group().size(), len(model), 0.8)
p = pipe.Pipe(
model,
balance,
style=Pipe.MultiProcess,
style=Pipe.AsyncSchedule,
chunks=args.chunks,
worker_map=get_worker_map(),
input_device=torch.cuda.current_device(),
pipelined_backward=args.pipelined_backward,
checkpoint=args.checkpoint,
# loss_fn=blob["criterion"],
).cuda()

if args.all_at_once and p.pipeline:
@@ -537,18 +608,24 @@ def bench_multi_process(args, all_at_once=False):

def bench_mpi(args):
guess_rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
os.environ["UCX_NET_DEVICES"] = best_device_map[guess_rank]
world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
os.environ["UCX_NET_DEVICES"] = best_device_map[local_rank]

torch.distributed.init_process_group(backend="mpi")
os.environ["MASTER_ADDR"] = args.host
os.environ["MASTER_PORT"] = "10639"
os.environ["MASTER_PORT"] = "10638"
if args.socket_name:
os.environ["GLOO_SOCKET_IFNAME"] = args.socket_name
os.environ["TP_SOCKET_IFNAME"] = args.socket_name

torch.distributed.init_process_group(backend="gloo", rank=guess_rank, world_size=world_size)

os.environ["MASTER_ADDR"] = args.host
os.environ["MASTER_PORT"] = "10639"
init_method = f"tcp://{os.environ['MASTER_ADDR']}:{os.environ['MASTER_PORT']}"
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
torch.cuda.set_device(rank % torch.cuda.device_count())
torch.cuda.set_device(local_rank % torch.cuda.device_count())

rpc.init_rpc(
f"Test{rank}",
@@ -558,7 +635,12 @@ def bench_mpi(args):
rpc_backend_options=rpc.ProcessGroupRpcBackendOptions(rpc_timeout=20, init_method=init_method),
)

initialize_model_parallel(1, world_size)
backends = {"model_parallel_backend": "nccl", "pipeline_backend": "mpi", "ddp_backend": "nccl"}

if args.ddp_zero:
initialize_model_parallel(1, 4, **backends)
else:
initialize_model_parallel(1, world_size, **backends)
init_random_seed(0)

run_mp_worker(args, world_size)
@@ -579,6 +661,7 @@ def bench_mpi(args):
parser.add_argument("--max-batch", type=int, default=4, help="Max number of batches")
parser.add_argument("--socket-name", type=str, default=None, help="socket ifname for gloo/tp")
parser.add_argument("--num-decoder-layers", type=int, default=10, help="Number of decoder layers in the model")
parser.add_argument("--ddp-zero", action="store_true", default=False, help="enable ddp")
parser.add_argument(
"--lazy-construction", action="store_true", default=False, help="Number of decoder layers in the model"
)
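Editor's note: the lazy-construction branch of make_model() above now wraps each layer factory in LazyModule instead of passing bare lambdas, so a layer is only materialized on the pipeline rank that owns it. A minimal sketch of that pattern follows; the small nn.Linear/nn.ReLU layers are illustrative placeholders, not the benchmark's real layers.

import torch.nn as nn
from fairscale.nn.pipe import LazyModule

# Each entry wraps a zero-argument factory; nothing is allocated here.
# The layers are constructed later, only on the rank that receives them.
layers = [
    LazyModule(lambda: nn.Linear(10, 10)),
    LazyModule(lambda: nn.ReLU()),
    LazyModule(lambda: nn.Linear(10, 5)),
]

# As in make_model() above, this list is handed to pipe.Pipe in place of a
# fully constructed nn.Sequential when --lazy-construction is set.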
5 changes: 4 additions & 1 deletion docs/Makefile
@@ -12,7 +12,10 @@ BUILDDIR = build
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile
setup:
pip install -r requirements.txt

.PHONY: help Makefile setup

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
62 changes: 62 additions & 0 deletions examples/tutorial_pipe_multiprocess.py
@@ -0,0 +1,62 @@
import os

import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import fairscale
from fairscale.nn.model_parallel import initialize_model_parallel


def run(rank, world_size):
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "10638"
torch.distributed.init_process_group(backend="nccl", rank=rank, world_size=world_size)
os.environ["MASTER_PORT"] = "10639"
torch.distributed.rpc.init_rpc(f"worker{rank}", rank=rank, world_size=world_size)
initialize_model_parallel(1, world_size)

model = nn.Sequential(torch.nn.Linear(10, 10), torch.nn.ReLU(), torch.nn.Linear(10, 5))
target = torch.randint(0, 2, size=(20, 1)).squeeze()
data = torch.randn(20, 10)
loss_fn = F.nll_loss

device = torch.device("cuda", rank)

model = fairscale.nn.Pipe(
model,
balance=[2, 1],
style=fairscale.nn.Pipe.MultiProcess,
worker_map={0: "worker0", 1: "worker1"}, # Needed to convert ranks to RPC worker names
input_device=device,
).to(device)

# define optimizer and loss function
optimizer = optim.SGD(model.parameters(), lr=0.001)

# zero the parameter gradients
optimizer.zero_grad()

# outputs and target need to be on the same device
# forward step
outputs = model(data.to(device))
# compute loss
if rank == 1:
loss = loss_fn(outputs.to(device), target.to(device))

# backward + optimize
loss.backward()
optimizer.step()
else:
model.back_helper(outputs)

print(f"Finished Training Step on {rank}")

del model


if __name__ == "__main__":
world_size = 2
mp.spawn(run, args=(world_size,), nprocs=world_size, join=True)
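
Editor's note: a hedged sketch (the helper name and signature are assumptions, not part of this PR) that generalizes the step above: only the stage holding the final partition sees the real output and computes the loss, while earlier stages hand their activations back through back_helper(); stepping the optimizer on every stage mirrors the training loop in benchmarks/pipe.py.

def train_step(model, batch, target, loss_fn, optimizer, is_last_stage):
    # Hypothetical helper for one multiprocess-Pipe training step.
    optimizer.zero_grad()
    outputs = model(batch)
    if is_last_stage:
        # Only the last pipeline stage holds the full model output.
        loss = loss_fn(outputs, target)
        loss.backward()
    else:
        # Earlier stages join the backward pass via back_helper().
        model.back_helper(outputs)
    optimizer.step()
    return loss.item() if is_last_stage else None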