support tutel.checkpoint.* for issue #177 (#181)
ghostplant committed Aug 9, 2022
1 parent bffee7c commit a2242e7
Showing 12 changed files with 206 additions and 54 deletions.
48 changes: 26 additions & 22 deletions README.md
@@ -69,7 +69,30 @@ How to setup Tutel MoE for Pytorch and [run examples](tutel/examples), or [enabl
```

How to import Tutel-optimized MoE in Pytorch:
#### How to convert checkpoint files to adapt to a different distributed world size:
```sh
# First, train a model on 2 GPUs with 16 global experts (each GPU holds 8 local experts), saving checkpoint files at the end:
mpiexec -bind-to none -host localhost -x LOCAL_SIZE=2 python3 -m tutel.launcher.run -m tutel.examples.helloworld --num_local_experts=8 --checkpoint=./states/{rank}-of-{size}.ckpt --device=cuda

# Second, merge the checkpoint files written by the 2 GPUs into a single checkpoint file containing all parameters:
python3 -m tutel.checkpoint.gather --inputs=./states/{rank}-of-{size}.ckpt --input_size=2 --output ./model-synthesis.ckpt

# Optionally, test the merged checkpoint on a single CPU device; note that all 16 experts are local in this case:
python3 -m tutel.examples.helloworld --num_local_experts=16 --checkpoint=./model-synthesis.ckpt --device=cpu --eval

# Next, convert the merged checkpoint file for distributed training on 8 GPUs:
python3 -m tutel.checkpoint.scatter --input=./model-synthesis.ckpt --output_size=8 --outputs=./adapted-for-8-gpus/{rank}-of-{size}.ckpt

# Then, use the generated checkpoint files to train/evaluate on 8 GPUs; note that each GPU holds 2 local experts this time:
mpiexec -bind-to none -host localhost -x LOCAL_SIZE=8 python3 -m tutel.launcher.run -m tutel.examples.helloworld --num_local_experts=2 --checkpoint=./adapted-for-8-gpus/{rank}-of-{size}.ckpt --device=cuda

# Similarly, the conversion tool also supports adapting X global experts to Y GPUs where Y % X == 0, in which case num_local_experts is set to -Y / X (i.e., Y / X GPUs share one expert):
python3 -m tutel.checkpoint.scatter --input=./model-synthesis.ckpt --output_size=32 --outputs=./adapted-for-32-gpus/{rank}-of-{size}.ckpt
mpiexec -bind-to none -host localhost -x LOCAL_SIZE=32 python3 -m tutel.launcher.run -m tutel.examples.helloworld --num_local_experts=-2 --checkpoint=./adapted-for-32-gpus/{rank}-of-{size}.ckpt --device=cuda
```
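
The `{rank}` and `{size}` placeholders in the checkpoint paths above are filled in per worker (the examples rely on `tutel.system.apply_rank_size_from_pattern` for this). Below is a minimal sketch of that substitution, assuming the helper does nothing more than plain string replacement; the actual behavior may differ in details:

```python
# Hypothetical sketch of the {rank}/{size} expansion used in the paths above;
# expand_ckpt_pattern is an invented name, not part of Tutel.
def expand_ckpt_pattern(pattern: str, rank: int, size: int) -> str:
    return pattern.replace('{rank}', str(rank)).replace('{size}', str(size))

print(expand_ckpt_pattern('./states/{rank}-of-{size}.ckpt', rank=0, size=2))
# -> './states/0-of-2.ckpt'
```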

#### How to import Tutel-optimized MoE in Pytorch:
```
# Input Example:
import torch
@@ -104,7 +127,7 @@ y = moe_layer(x)
print(y)
```

Usage of MOELayer:
#### Usage of MOELayer:
```
* Usage of MOELayer Args:
@@ -130,7 +153,7 @@ Usage of MOELayer:
activation_fn : the custom-defined activation function between two linear layers (used for type == 'ffn' only)
```

For Deepspeed MoE Acceleration (Deepspeed MoE Top-1 Gate has integrated Tutel acceleration):
#### For Deepspeed MoE Acceleration (Deepspeed MoE Top-1 Gate has integrated Tutel acceleration):
```sh
# Without Tutel optimization:
python3 -m tutel.examples.helloworld_deepspeed --top=1
@@ -139,25 +162,6 @@ python3 -m tutel.examples.helloworld_deepspeed --top=1
python3 -m tutel.examples.helloworld_deepspeed --top=1 --use_tutel
```


### Single-GPU Throughput (batches/sec) with default settings on NVIDIA A100 (40GB):
| batch-size | helloworld (top2) | helloworld_ddp (top2) | helloworld_deepspeed (top2) |
| :--------: | :--------: | :------------: | :------------------: |
| 8 | 672.75 | 672.24 | 188.27 |
| 16 | 715.86 | 714.95 | 115.43 |
| 24 | 725.95 | 725.04 | 81.02 |
| 32 | 729.02 | 729.02 | OOM |
| 64 | 687.92 | 686.31 | OOM |
| 128 | 619.75 | 619.03 | OOM |
| 256 | 577.08 | 577.49 | OOM |

How to reproduce these results:
```shell
$ python3 -m tutel.examples.helloworld --batch_size=<batch_size>
$ python3 -m tutel.examples.helloworld_ddp --batch_size=<batch_size>
$ python3 -m tutel.examples.helloworld_deepspeed --batch_size=<batch_size>
```

## Reference
You can consult the [paper](https://arxiv.org/pdf/2206.03382.pdf) below for more technical details about Tutel:
```
6 changes: 0 additions & 6 deletions setup.py
@@ -72,12 +72,6 @@ def install(use_cuda, use_nccl):
        ext_libs += ['nccl']
        ext_args += ['-DUSE_NCCL']

    for folder in ('build', 'dist',):
        try:
            shutil.rmtree(os.path.join(root_path, folder))
        except:
            pass

    setup(
        name='tutel',
        version='0.1',
3 changes: 3 additions & 0 deletions tutel/checkpoint/__init__.py
@@ -0,0 +1,3 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

60 changes: 60 additions & 0 deletions tutel/checkpoint/gather.py
@@ -0,0 +1,60 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import argparse
import torch
import re

from tutel.system import apply_rank_size_from_pattern

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_size', type=int, required=True)
    parser.add_argument('--inputs', type=str, required=True)
    parser.add_argument('--output', type=str, required=True)
    args = parser.parse_args()
    args.size = args.input_size

    mutate_size, expert_dict = {}, {}

    input_file = apply_rank_size_from_pattern(args.inputs, rank=0, size=args.size)
    state_dict = torch.load(input_file, map_location=torch.device('cpu'))
    for k in state_dict:
        if k.endswith('._num_global_experts'):
            entry = k[:k.rindex('.')] + '.experts.'
            mutate_size[entry] = int(state_dict[k])

    if not mutate_size:
        raise Exception('No Tutel MoE layer was found; the provided checkpoint may be in a legacy format. Reload this legacy checkpoint with the corresponding application, re-save the model\'s state_dict, and retry with the latest format.')

    for rank in range(args.size):
        input_file = apply_rank_size_from_pattern(args.inputs, rank=rank, size=args.size)
        state_dict = torch.load(input_file, map_location=torch.device('cpu'))
        for k in state_dict:
            for e in mutate_size:
                if k.startswith(e):
                    expert_dict[k] = expert_dict.get(k, [mutate_size[e],]) + [state_dict[k],]

    expert_dict = [(i, k, expert_dict[k]) for i, k in enumerate(expert_dict)]
    for i, k, v in expert_dict:
        num_global_experts, pieces = v[0], v[1:]
        if num_global_experts % args.size == 0:
            expert_dict[i] = torch.concat(pieces, dim=0).contiguous().clone()
            assert expert_dict[i].size(0) == num_global_experts, "Unexpected group size of expert with num_global_experts: %d vs. %d. Maybe you set a wrong --input_size value." % (expert_dict[i].size(0), num_global_experts)
        elif args.size % num_global_experts == 0:
            expert_dict[i] = torch.concat(pieces, dim=0).contiguous()
            expert_dict[i] = expert_dict[i].view([num_global_experts, -1] + list(expert_dict[i].shape)[2:]).clone()
        else:
            raise Exception(f'Neither "global_experts({num_global_experts}) / args.size({args.size})" nor "args.size({args.size}) / global_experts({num_global_experts})" is evenly divisible.')
        expert_dict[i] = (k, expert_dict[i])

    expert_dict = dict(expert_dict)
    for k in state_dict:
        if k in expert_dict:
            state_dict[k] = expert_dict[k]
    torch.save(state_dict, args.output)

if __name__ == "__main__":
    main()
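
To illustrate the merge that `gather.py` performs on each expert parameter, here is a small standalone sketch with made-up shapes (2 ranks, 8 local experts each); it is an illustration only, not part of this commit:

```python
# Illustration only: per-rank expert slices are concatenated along dim 0 into one
# global tensor, mirroring the torch.concat(pieces, dim=0) call in gather.py.
import torch

pieces = [torch.randn(8, 4, 3) for _ in range(2)]  # one expert-parameter slice per rank
merged = torch.cat(pieces, dim=0).contiguous()     # 16 global experts after the merge
assert merged.size(0) == 16
```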

61 changes: 61 additions & 0 deletions tutel/checkpoint/scatter.py
@@ -0,0 +1,61 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import os
import argparse
import torch
import re

from tutel.system import apply_rank_size_from_pattern

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--output_size', type=int, required=True)
    parser.add_argument('--input', type=str, required=True)
    parser.add_argument('--outputs', type=str, required=True)
    args = parser.parse_args()
    args.size = args.output_size

    state_dict = torch.load(args.input, map_location=torch.device('cpu'))
    mutate_size, expert_dict = {}, {}

    for k in state_dict:
        if k.endswith('._num_global_experts'):
            entry = k[:k.rindex('.')] + '.experts.'
            mutate_size[entry] = int(state_dict[k])

    if not mutate_size:
        raise Exception('No Tutel MoE layer was found; the provided checkpoint may be in a legacy format. Reload this legacy checkpoint with the corresponding application, re-save the model\'s state_dict, and retry with the latest format.')

    for k in state_dict:
        for e in mutate_size:
            if k.startswith(e):
                state = state_dict[k]
                shape = state.shape
                if shape[0] % args.size == 0:
                    state = state.view([args.size, shape[0] // args.size] + list(shape)[1:])
                elif args.size % shape[0] == 0:
                    divisor = args.size // shape[0]
                    for i in range(1, len(shape)):
                        if shape[i] <= 1:
                            continue
                        assert shape[i] % divisor == 0, f"The second non-squeezable dimension of a parameter with shape {shape} has to be sliced into {divisor} pieces, but it isn't evenly divisible."
                        state = state.view([args.size,] + list(shape)[1:i] + [shape[i] // divisor,] + list(shape)[i+1:])
                else:
                    raise Exception(f'Neither "global_experts({int(shape[0])}) / args.size({args.size})" nor "args.size({args.size}) / global_experts({int(shape[0])})" is evenly divisible.')
                expert_dict[k] = state

    for rank in range(args.size):
        generate_dict = dict()
        for k in state_dict:
            if k not in expert_dict:
                generate_dict[k] = state_dict[k]
            else:
                generate_dict[k] = expert_dict[k][rank, :].contiguous().clone()

        output_file = apply_rank_size_from_pattern(args.outputs, rank=rank, size=args.size)
        torch.save(generate_dict, output_file)

if __name__ == "__main__":
    main()
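
Conversely, `scatter.py` re-slices a merged expert tensor for the target world size. A standalone sketch of the common case (global experts evenly divisible by the number of GPUs), again with invented shapes and not part of this commit:

```python
# Illustration only: reshape a merged (16, ...) expert tensor so each of 8 ranks
# receives a (2, ...) slice, mirroring the shape[0] % args.size == 0 branch above.
import torch

merged = torch.randn(16, 4, 3)   # 16 global experts, e.g. produced by gather.py
world_size = 8
sliced = merged.view([world_size, merged.size(0) // world_size] + list(merged.shape)[1:])
rank0_slice = sliced[0].contiguous().clone()   # what would be written into rank 0's file
assert rank0_slice.shape == (2, 4, 3)
```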

10 changes: 5 additions & 5 deletions tutel/examples/helloworld.py
@@ -30,7 +30,7 @@
parser.add_argument('--allreduce_degree', type=int, default=1)
parser.add_argument('--num_steps', type=int, default=100)
parser.add_argument('--parallel_type', type=str, default='auto')
parser.add_argument('--save_load_checkpoint', default=False, action='store_true')
parser.add_argument('--checkpoint_path', type=str, default='')
parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument('--use_2dh', default=False, action='store_true')
parser.add_argument('--eval', default=False, action='store_true')
@@ -89,12 +89,12 @@ def forward(self, input):
model = ExampleModel().to(device)
dist_print(model)

if args.save_load_checkpoint:
    checkpoint_path = './distributed-hellworld-%d-in-%d.ckpt' % (parallel_env.global_rank, parallel_env.global_size)
if args.checkpoint_path:
    checkpoint_path = system.apply_rank_size_from_pattern(args.checkpoint_path, rank=parallel_env.global_rank, size=parallel_env.global_size)
    if os.path.exists(checkpoint_path):
        model.load_state_dict(torch.load(checkpoint_path))
    else:
        print('Checkpoint not loaded: file `%s` is not found' % checkpoint_path)
        print('Checkpoint not loaded: file `%s` is not found. The model will be trained from scratch.' % checkpoint_path)

optimizer = torch.optim.SGD(model.parameters(), lr=1e-5)

@@ -145,5 +145,5 @@ def forward(self, input):
average_time /= 10
dist_print('\n[Summary] Average synchronized step_time = %s sec.' % average_time)

if args.save_load_checkpoint:
if args.checkpoint_path:
    torch.save(model.state_dict(), checkpoint_path)
10 changes: 0 additions & 10 deletions tutel/examples/helloworld_ddp_tutel.py
@@ -30,7 +30,6 @@
parser.add_argument('--allreduce_degree', type=int, default=1)
parser.add_argument('--num_steps', type=int, default=100)
parser.add_argument('--parallel_type', type=str, default='auto')
parser.add_argument('--save_load_checkpoint', default=False, action='store_true')
parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
parser.add_argument('--use_2dh', default=False, action='store_true')
parser.add_argument('--eval', default=False, action='store_true')
@@ -88,12 +87,6 @@ def forward(self, input):
model = ExampleModel().to(device)
dist_print(model)

if args.save_load_checkpoint:
    checkpoint_path = './distributed-hellworld-%d-in-%d.ckpt' % (parallel_env.global_rank, parallel_env.global_size)
    if os.path.exists(checkpoint_path):
        model.load_state_dict(torch.load(checkpoint_path))
    else:
        print('Checkpoint not loaded: file `%s` is not found' % checkpoint_path)

optimizer = net.TutelDistributedOptimizer(model.parameters(), group=None, average_shared=True).warp_local(torch.optim.SGD, lr=1e-5)

@@ -134,6 +127,3 @@ def forward(self, input):

average_time /= 10
dist_print('\n[Summary] Average synchronized step_time = %s sec.' % average_time)

if args.save_load_checkpoint:
    torch.save(model.state_dict(), checkpoint_path)
14 changes: 7 additions & 7 deletions tutel/experts/ffn.py
@@ -29,14 +29,14 @@ def update(self, ctx):

        fc1_weight = torch.empty(1, local_experts, hidden_size, model_dim)
        fc2_weight = torch.empty(1, local_experts, hidden_size, self.output_dim)
        fc1_bias = torch.empty(1, local_experts, 1, hidden_size)
        fc2_bias = torch.empty(1, local_experts, 1, (self.output_dim + ctx.sharded_count - 1) // ctx.sharded_count)
        fc1_bias = torch.empty(1, local_experts, hidden_size)
        fc2_bias = torch.empty(1, local_experts, (self.output_dim + ctx.sharded_count - 1) // ctx.sharded_count)

        for i in range(local_experts):
            fc1 = torch.nn.Linear(model_dim, hidden_size)
            fc2 = torch.nn.Linear(hidden_size, self.output_dim)
            fc1_weight[0, i, :, :], fc1_bias[0, i, :, :] = fc1.weight, fc1.bias
            fc2_weight[0, i, :, :], fc2_bias[0, i, :, :] = fc2.weight.t(), fc2.bias[:fc2_bias.size(-1)]
            fc1_weight[0, i, :, :], fc1_bias[0, i, :] = fc1.weight, fc1.bias
            fc2_weight[0, i, :, :], fc2_bias[0, i, :] = fc2.weight.t(), fc2.bias[:fc2_bias.size(-1)]

        self.register_parameter(name='batched_fc1_w', param=torch.nn.Parameter(fc1_weight.squeeze(0)))
        self.register_parameter(name='batched_fc2_w', param=torch.nn.Parameter(fc2_weight.squeeze(0)))
@@ -54,8 +54,8 @@ def forward(self, x, ctx):

        batched_fc1_w = self.batched_fc1_w
        batched_fc2_w = self.batched_fc2_w
        batched_fc1_bias = self.batched_fc1_bias
        batched_fc2_bias = self.batched_fc2_bias
        batched_fc1_bias = self.batched_fc1_bias.unsqueeze(1)
        batched_fc2_bias = self.batched_fc2_bias.unsqueeze(1)

        if ctx.ffn_zero_group is not None:
            if not ctx.use_model_parallel:
@@ -64,7 +64,7 @@ def forward(self, x, ctx):
                batched_fc1_bias = zero_gather(batched_fc1_bias, group=ctx.ffn_zero_group).view(1, 1, -1)

            batched_fc2_bias = zero_gather(batched_fc2_bias, group=ctx.ffn_zero_group)
            batched_fc2_bias = batched_fc2_bias.view(self.batched_fc2_bias.size(0), self.batched_fc2_bias.size(1), -1)
            batched_fc2_bias = batched_fc2_bias.view(self.batched_fc2_bias.size(0), 1, -1)
            if batched_fc2_bias.size(-1) != self.output_dim:
                batched_fc2_bias = batched_fc2_bias[:, :, :self.output_dim]

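
The bias change above switches the stored bias layout to `(local_experts, hidden)` and adds an `unsqueeze(1)` in `forward()`. A small standalone sketch of why that broadcast works, with invented sizes and independent of the Tutel classes:

```python
# Illustration only: a (local_experts, 1, hidden) bias broadcasts over the token
# dimension of a (local_experts, tokens, hidden) activation batch.
import torch

local_experts, tokens, hidden = 8, 5, 32
x = torch.randn(local_experts, tokens, hidden)
bias = torch.randn(local_experts, hidden)   # new storage layout in this commit
y = x + bias.unsqueeze(1)                   # (8, 1, 32) broadcasts across the 5 tokens
assert y.shape == (local_experts, tokens, hidden)
```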
4 changes: 2 additions & 2 deletions tutel/impls/communicate.py
@@ -170,7 +170,7 @@ def simple_reduce_scatter(input, group=None, op=torch.distributed.ReduceOp.SUM):
    input = input.contiguous()
    assert input.size(0) % world_size == 0, "Cannot evenly divide dim length %s into %s slices" % (input.size(0), world_size)
    if not input.is_cuda:
        return simple_split(simple_all_reduce(input, group, op=op))
        return simple_split(simple_all_reduce(input, group, op=op), group=group)
    chunks = list(input.chunk(chunks=world_size, dim=0))
    output = torch.empty_like(chunks[0])
    dist.reduce_scatter(output=output, input_list=chunks, group=group, op=op)
@@ -183,7 +183,7 @@ def simple_all_gather(input, group=None):
    input = input.contiguous()
    output = torch.empty([world_size, input.numel()], device=input.device, dtype=input.dtype)
    tensor_list = list(torch.chunk(output, chunks=world_size, dim=0))
    dist.all_gather(tensor_list=tensor_list, tensor=input, group=group)
    dist.all_gather(tensor_list=tensor_list, tensor=input.view(1, -1), group=group)
    return output.view([-1,] + list(input.shape[1:]))

class AllToAllStatus:
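
The `simple_all_gather` change appears to align the local tensor's shape with the entries of `tensor_list`, since `dist.all_gather` expects them to match. A shape-only sketch of that reasoning (no process group required; illustration only):

```python
# Illustration only: the chunks of the preallocated output have shape (1, numel),
# so viewing the local input as (1, -1) makes the shapes line up for all_gather.
import torch

world_size, n = 4, 6
inp = torch.arange(n, dtype=torch.float32)                 # arbitrary local payload
out = torch.empty([world_size, inp.numel()])
chunks = list(torch.chunk(out, chunks=world_size, dim=0))  # each chunk: shape (1, n)
assert inp.view(1, -1).shape == chunks[0].shape
```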
27 changes: 26 additions & 1 deletion tutel/impls/moe_layer.py
@@ -38,6 +38,31 @@ def global_expert_count(num_local_experts, group=None):
        assert world_size % -num_local_experts == 0, f"Expecting {-num_local_experts} devices to share an expert param, while global device count is {world_size}."
        return world_size // -num_local_experts

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        buff_name = prefix + '_num_global_experts'
        if buff_name not in state_dict:
            logging.warning(f"\033[31mYou are loading a legacy-format checkpoint that contains at least one Tutel MoE layer; this format doesn't support the new Tutel feature that allows the number of experts per checkpoint file to change.\033[0m")
            logging.warning(f"\033[31m  The next time you overwrite it with a new checkpoint, the recorded format will be updated automatically.\033[0m")
            logging.warning(f"\033[31m  However, the new format won't be compatible with earlier Tutel versions unless you force-load it with `model.load_state_dict(.., strict=False)`.\033[0m")
            state_dict[buff_name] = self._num_global_experts
        else:
            state_experts, expect_experts = int(state_dict[buff_name]), self.num_global_experts
            assert state_experts == expect_experts, "Failed to load state from checkpoint: the number of global experts does not match (%s <- %s)" % (expect_experts, state_experts)

        for name, param in self.experts.named_parameters():
            buff_name = prefix + 'experts.' + name
            assert buff_name in state_dict, "Could not find parameter `%s` in state_dict." % buff_name
            if state_dict[buff_name].numel() == param.numel():
                state_dict[buff_name] = state_dict[buff_name].view(param.shape)
        return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        return super().state_dict(destination, prefix, keep_vars)

    @property
    def num_global_experts(self):
        return int(self._num_global_experts)

    def __init__(
        self,
        gate_type,
@@ -71,7 +96,7 @@ def __init__(
        self.skip_moe = (int(os.environ.get('SKIP_MOE', '0')) != 0)

        self.num_local_experts = experts.pop('count_per_node', 1)
        self.num_global_experts = MOELayer.global_expert_count(self.num_local_experts, self.group)
        self.register_buffer('_num_global_experts', torch.tensor(MOELayer.global_expert_count(self.num_local_experts, self.group)))

        self.world_size = C.get_world_size(self.group)
        if self.num_global_experts < self.world_size:
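
Registering `_num_global_experts` as a buffer is what makes the expert count appear in `state_dict()`, which `tutel.checkpoint.gather`/`scatter` look for. A minimal sketch of that mechanism using a toy module (not the Tutel class):

```python
# Illustration only: buffers registered on a module are saved in its state_dict,
# so checkpoint tools can read metadata such as the global expert count.
import torch

class TinyMoE(torch.nn.Module):
    def __init__(self, num_global_experts):
        super().__init__()
        self.register_buffer('_num_global_experts', torch.tensor(num_global_experts))

print(TinyMoE(16).state_dict())   # OrderedDict([('_num_global_experts', tensor(16))])
```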
2 changes: 1 addition & 1 deletion tutel/moe.py
@@ -2,7 +2,7 @@
# Licensed under the MIT license.


# Level-level Ops
# Low-level Ops
from .jit_kernels.gating import fast_cumsum_sub_one
from .impls.fast_dispatch import fast_dispatcher, extract_critical, fast_encode, fast_decode
