add deepspeed examples (#39)
msftsw authored Nov 14, 2021
1 parent 9a1e076 commit ef8bcc8
Showing 6 changed files with 169 additions and 18 deletions.
24 changes: 12 additions & 12 deletions README.md
@@ -48,28 +48,28 @@ Full Examples & Usage:
```
* Single-GPU Test:
$ python3 -m tutel.examples.helloworld
$ python3 -m tutel.examples.helloworld --batch_size=32 # To Test Tutel-optimized MoE + manual distribution
$ python3 -m tutel.examples.helloworld_ddp --batch_size=32 # To Test Tutel-optimized MoE + Pytorch DDP distribution (requires: Pytorch >= 1.8.0)
$ python3 -m tutel.examples.helloworld_megatron --batch_size=32 # To Test Tutel using Megatron Gating (Tensor Parallel on Experts) + manual distribution
$ python3 -m tutel.examples.helloworld_deepspeed --batch_size=32 # To Test Deepspeed MoE + manual distribution
(If full source code exists:)
$ python3 ./tutel/examples/helloworld.py
(If full source code exists, the following also works:)
$ python3 ./tutel/examples/helloworld.py --batch_size=32
..
* Running MoE Hello World Model by torch.distributed.all_reduce:
$ python3 -m torch.distributed.launch --nproc_per_node=2 -m tutel.examples.helloworld
$ python3 -m torch.distributed.launch --nproc_per_node=2 -m tutel.examples.helloworld --batch_size=32
$ python3 -m torch.distributed.launch --nproc_per_node=2 -m tutel.examples.helloworld_ddp --batch_size=32
..
(For New Pytorch:)
$ python3 -m torch.distributed.run --nproc_per_node=2 -m tutel.examples.helloworld
* Running MoE Hello World Model by torch.nn.parallel.DistributedDataParallel (requires torch >= 1.8.0):
$ python3 -m torch.distributed.launch --nproc_per_node=2 -m tutel.examples.helloworld_ddp
(For New Pytorch:)
$ python3 -m torch.distributed.run --nproc_per_node=2 -m tutel.examples.helloworld_ddp
..
* Usage of MOELayer Args:
-gate_type : dict-type gate description, e.g. {'type': 'top', 'k': 2, ..}
+gate_type : dict-type gate description, e.g. {'type': 'top', 'k': 2, ..}, or {'type': 'megatron'}
model_dim : the number of channels for MOE's input tensor
experts : a dict-type config for builtin expert network, or a torch.nn.Module-type custom expert network
scan_expert_func : allow users to specify a lambda function to iterate each experts param, e.g. `scan_expert_func = lambda name, param: setattr(param, 'expert', True)`
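A minimal construction sketch, using only the argument names documented above (the 'ffn' expert keys follow the helloworld_megatron diff below; expert-count options are omitted here since they vary across Tutel versions, so treat this as an assumption-laden sketch rather than the exact API):

```python
import torch
import torch.nn.functional as F
from tutel import moe as tutel_moe

# Sketch: a top-2 gated Tutel MoE layer over 2048-channel tokens.
device = torch.device('cuda')
moe = tutel_moe.moe_layer(
    gate_type = {'type': 'top', 'k': 2},   # or {'type': 'megatron'} for tensor-parallel experts
    model_dim = 2048,
    experts = {'type': 'ffn', 'hidden_size_per_expert': 2048, 'activation_fn': lambda x: F.relu(x)},
    scan_expert_func = lambda name, param: setattr(param, 'expert', True),
).to(device)

x = torch.randn([16, 1024, 2048], device=device)  # [batch_size, num_tokens, model_dim]
y = moe(x)                                        # output keeps the input shape
```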
2 changes: 1 addition & 1 deletion tutel/examples/helloworld.py
@@ -25,7 +25,7 @@
parser = argparse.ArgumentParser()

parser.add_argument('--local_rank', type=int, default=-1)
-parser.add_argument('--batch_size', type=int, default=8)
+parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--num_tokens', type=int, default=1024)
parser.add_argument('--model_dim', type=int, default=2048)
parser.add_argument('--hidden_size', type=int, default=2048)
2 changes: 1 addition & 1 deletion tutel/examples/helloworld_ddp.py
@@ -27,7 +27,7 @@
parser = argparse.ArgumentParser()

parser.add_argument('--local_rank', type=int, default=-1)
-parser.add_argument('--batch_size', type=int, default=8)
+parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--num_tokens', type=int, default=1024)
parser.add_argument('--model_dim', type=int, default=2048)
parser.add_argument('--hidden_size', type=int, default=2048)
149 changes: 149 additions & 0 deletions tutel/examples/helloworld_deepspeed.py
@@ -0,0 +1,149 @@
#!/usr/bin/env python3

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import time
import torch
import torch.optim as optim
import torch.nn.functional as F
import torch.distributed as dist
from torch import nn
import argparse
import deepspeed

import logging
logging.basicConfig(level=logging.INFO)

parser = argparse.ArgumentParser()

parser.add_argument('--local_rank', type=int, default=-1)
parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--num_tokens', type=int, default=1024)
parser.add_argument('--model_dim', type=int, default=2048)
parser.add_argument('--hidden_size', type=int, default=2048)
parser.add_argument('--num_local_experts', type=int, default=2)
parser.add_argument('--dtype', type=str, default='float32')
parser.add_argument('--fp32_gate', default=False, action='store_true')
parser.add_argument('--top', type=int, default=2)
args = parser.parse_args()

if args.local_rank < 0:
    args.local_rank = int(os.environ.get('LOCAL_RANK', 0))

torch.cuda.set_device(args.local_rank)

try:
    if dist.is_available():
        dist.init_process_group('nccl')
    dist_rank = dist.get_rank()
    dist_world_size = dist.get_world_size()

    def dist_print(*args):
        if dist_rank == 0:
            print(*args)
except:  # fall back to single-process mode when torch.distributed is unavailable
    dist_rank = 0
    dist_world_size = 1
    dist_print = print

batch_size = args.batch_size
num_tokens = args.num_tokens
model_dim = args.model_dim
hidden_size = args.hidden_size
num_local_experts = args.num_local_experts
top_value = args.top
local_rank = args.local_rank


device = torch.device('cuda', args.local_rank)

if args.dtype == 'float32':
    torch.set_default_dtype(torch.float32)
elif args.dtype == 'float16':
    torch.set_default_dtype(torch.float16)
elif args.dtype == 'bfloat16':
    torch.set_default_dtype(torch.bfloat16)
else:
    raise Exception('Unrecognized data type specified: %s' % args.dtype)

# Initialize torch.distributed through DeepSpeed, then create expert-parallel groups spanning all ranks
deepspeed.init_distributed()
deepspeed.utils.groups.initialize(ep_size=dist_world_size)

class ExpertModel(torch.nn.Module):
    def __init__(self, model_dim, hidden_size, activation_fn):
        super().__init__()
        self.fc1 = torch.nn.Linear(model_dim, hidden_size, bias=True)
        self.fc2 = torch.nn.Linear(hidden_size, model_dim, bias=True)
        self.activation_fn = activation_fn
    def forward(self, x):
        x = self.fc1(x)
        x = self.activation_fn(x)
        x = self.fc2(x)
        return x

class ExampleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self._moe_layer = deepspeed.moe.layer.MoE(
            hidden_size = hidden_size,
            expert = ExpertModel(model_dim, hidden_size, lambda x: F.relu(x)),
            num_experts = num_local_experts * dist_world_size,
            k = top_value
        ).to(device)

        # Expert parameters are sharded across ranks; tag them so the training loop skips all-reduce on them
        for name, param in self._moe_layer.named_parameters():
            if '.experts.' in name:
                setattr(param, 'skip_allreduce', True)

        # Distinguish different parameter types: gate, local_experts
        local_count = sum([torch.numel(param) for name, param in self._moe_layer.named_parameters() if '.experts.' in name])
        shared_count = sum([torch.numel(param) for name, param in self._moe_layer.named_parameters() if '.gate.' in name])
        dist_print('[Statistics] param count for MoE local_experts = %s, param count for MoE gate = %s.\n' % (local_count, shared_count))

    def forward(self, input):
        result, _, _ = self._moe_layer(input)
        result = F.log_softmax(torch.sum(result, dim=2), dim=1)
        return result

model = ExampleModel()
dist_print(model)

optimizer = torch.optim.SGD(model.parameters(), lr=1e-5)

x = torch.randn([batch_size, num_tokens, model_dim], device=device, requires_grad=True)
y = torch.LongTensor(batch_size).random_(1).to(device)

tuples = (dist_world_size, args.dtype, model_dim, hidden_size, batch_size * num_tokens, num_local_experts, top_value, device)
dist_print('[Benchmark] world_size = %s, dtype = %s, model_dim = %s, hidden_size = %s, samples = %s, num_local_experts = %s, topK = %s, device = `%s`' % tuples)

average_time, num_steps = 0, 100

# Shared (non-expert) parameters are synchronized manually below; expert parameters were tagged skip_allreduce
params_for_all_reduce = [p for p in model.parameters() if not hasattr(p, 'skip_allreduce') and getattr(p, 'requires_grad', False)]

for i in range(num_steps):

    torch.cuda.synchronize()
    t_start = time.time()
    optimizer.zero_grad()

    output = model(x)
    loss = F.nll_loss(output, y)
    loss.backward()
    if dist_world_size > 1:
        for p in params_for_all_reduce:
            p.grad /= dist_world_size
            dist.all_reduce(p.grad)
    optimizer.step()

    torch.cuda.synchronize()
    t_stop = time.time()
    dist_print('STEP-%s: DONE, loss = %s, step_time = %s sec.' % (i, float(loss.data), t_stop - t_start))

    if i + 10 >= num_steps:
        average_time += t_stop - t_start

average_time /= 10
dist_print('\n[Summary] Average synchronized step_time = %s sec.' % average_time)
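For reference, the new example launches the same way as the other hello-world scripts; a usage sketch based on the README commands above (DeepSpeed must be installed, and the two-GPU form mirrors the torch.distributed.launch pattern already shown):

```
$ python3 -m tutel.examples.helloworld_deepspeed --batch_size=32
$ python3 -m torch.distributed.launch --nproc_per_node=2 -m tutel.examples.helloworld_deepspeed --batch_size=32
```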
8 changes: 5 additions & 3 deletions tutel/examples/helloworld_megatron.py
@@ -25,10 +25,11 @@
parser = argparse.ArgumentParser()

parser.add_argument('--local_rank', type=int, default=-1)
-parser.add_argument('--batch_size', type=int, default=8)
+parser.add_argument('--batch_size', type=int, default=16)
parser.add_argument('--num_tokens', type=int, default=1024)
parser.add_argument('--model_dim', type=int, default=2048)
parser.add_argument('--hidden_size', type=int, default=2048)
parser.add_argument('--num_local_experts', type=int, default=2)
parser.add_argument('--dtype', type=str, default='float32')
parser.add_argument('--l_aux_wt', type=float, default=0.0)
args = parser.parse_args()
@@ -56,6 +57,7 @@ def dist_print(*args):
num_tokens = args.num_tokens
model_dim = args.model_dim
hidden_size = args.hidden_size
num_local_experts = args.num_local_experts
local_rank = args.local_rank


@@ -77,7 +79,7 @@ def __init__(self):

        self._moe_layer = tutel_moe.moe_layer(
            gate_type = {'type': 'megatron'},
-           experts = {'type': 'ffn', 'hidden_size_per_expert': hidden_size, 'activation_fn': lambda x: F.relu(x)},
+           experts = {'type': 'ffn', 'hidden_size_per_expert': hidden_size * num_local_experts, 'activation_fn': lambda x: F.relu(x)},
            model_dim = model_dim,
            scan_expert_func = lambda name, param: setattr(param, 'skip_allreduce', True),
            seeds = (1, dist_rank + 1, 1),
@@ -101,7 +103,7 @@ def forward(self, input):
x = torch.randn([batch_size, num_tokens, model_dim], device=device, requires_grad=True)
y = torch.LongTensor(batch_size).random_(1).to(device)

-tuples = (dist_world_size, args.dtype, model_dim, hidden_size, batch_size * num_tokens, 1, device)
+tuples = (dist_world_size, args.dtype, model_dim, hidden_size, batch_size * num_tokens, num_local_experts, device)
dist_print('[Benchmark] world_size = %s, dtype = %s, model_dim = %s, hidden_size = %s, samples = %s, num_local_experts = %s, gate = megatron, device = `%s`' % tuples)

average_time, num_steps = 0, 100
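A note on the `hidden_size * num_local_experts` change above: with the megatron-style gate (tensor parallelism on experts, per the README), each device hosts one fused, widened FFN instead of `num_local_experts` separate experts, so widening preserves the per-device parameter count. A quick arithmetic check of that accounting (hypothetical values; bias terms ignored):

```python
# One fused FFN of hidden width H*E has the same weight count
# as E separate FFN experts of hidden width H.
model_dim, hidden_size, num_local_experts = 2048, 2048, 2

fused = 2 * model_dim * (hidden_size * num_local_experts)     # fc1 + fc2 of the single widened expert
separate = num_local_experts * (2 * model_dim * hidden_size)  # E experts of width H
assert fused == separate
```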
2 changes: 1 addition & 1 deletion tutel/impls/moe_layer.py
@@ -166,7 +166,7 @@ class MOELayer(torch.nn.Module):
"""Tutel optimized MOELayer
Args:
-gate_type : dict-type gate description, e.g. {'type': 'top', 'k': 2, ..}
+gate_type : dict-type gate description, e.g. {'type': 'top', 'k': 2, ..}, or {'type': 'megatron'}
model_dim : the number of channels for MOE's input tensor
experts : a dict-type config for builtin expert network, or a torch.nn.Module-type custom expert network
scan_expert_func : allow users to specify a lambda function to iterate each experts param, e.g. `scan_expert_func = lambda name, param: setattr(param, 'expert', True)`
