Tensor Parallel #153

Merged · 32 commits · Aug 29, 2023

Commits
0d4ea37
fix is_first_layer
Aug 22, 2023
3063afb
tensor parallel
Aug 23, 2023
bdc1ed9
Merge branch 'fix_first_layer' into tensor_parallel
Aug 23, 2023
8648f5b
rm unused code
Aug 23, 2023
763b408
refactor nccl group; remove partition_modules in pipe_layer.py
Aug 24, 2023
4c50567
fix by review comment
Aug 24, 2023
825139c
fix topology
Aug 24, 2023
4ff0f41
fix topology
Aug 24, 2023
a5d7ba6
fix
Aug 24, 2023
2951d70
use ParallelEmbedding
Aug 24, 2023
39319e1
overlap parallel linear backward
Aug 24, 2023
df3fd8f
add tp_comm_stream
Aug 24, 2023
99efba3
fix tp
Achazwl Aug 24, 2023
85dd5ab
Merge branch 'tensor_parallel' into tp
Achazwl Aug 24, 2023
76abcb4
Merge pull request #1 from Achazwl/tp
Aug 24, 2023
f1b4fd7
fix load_state_dict
Aug 25, 2023
677a316
test parallel linear
Aug 25, 2023
743253e
mv zero_level to CheckpointBlock
Aug 25, 2023
4e8c462
merge dev
Aug 25, 2023
604ddfe
fix overlap
Aug 25, 2023
0aee817
gather once in atten
Aug 25, 2023
bd0bad0
fix sub grad_input in parallel linear
Aug 25, 2023
50cdcaf
Merge branch 'dev' into tensor_parallel
zkh2016 Aug 26, 2023
15460b6
fix gather_output
Aug 26, 2023
0e0e05c
Merge branch 'tensor_parallel' of https://github.com/zkh2016/BMTrain …
Aug 26, 2023
b44a62e
fix train.py
Aug 26, 2023
100cd55
fused q,k,v
Aug 26, 2023
fa09468
fix row parallel linear
Aug 26, 2023
37bc403
fix cross entropy
Aug 26, 2023
15c2c48
Update setup.py
zkh2016 Aug 28, 2023
42663c8
overlap send communication in pipeline
Aug 28, 2023
207912b
Merge branch 'tensor_parallel' of https://github.com/zkh2016/BMTrain …
Aug 29, 2023
44 changes: 29 additions & 15 deletions bmtrain/block_layer.py
@@ -1,6 +1,6 @@
from typing import Dict, Iterable, Iterator, Union, List

from .utils import round_up
from .utils import (round_up, tp_split_tensor)
from .global_var import config
import torch
from . import nccl
@@ -94,7 +94,8 @@ def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, zero_lev
"total": 0,
"storage_type": storage_type,
"requires_grad": param.requires_grad,
"group": param.group
"group": param.group,
"zero_comm" : param._zero_comm
}

param_shape = param._original_shape
@@ -108,11 +109,14 @@ def __init__(self, inner_module : torch.nn.Module, use_checkpoint=True, zero_lev
offsets = {}
# intialize storage buffers
for kw, val in self._storage_info.items():
val["world_size"] = config["world_size"]
comm = val['zero_comm']
world_size = nccl.commCount(comm)
rank = nccl.commRank(comm)
val["world_size"] = world_size
partition_size = round_up(val["total"], val["world_size"]) // val["world_size"]
val["partition_size"] = partition_size
val["begin"] = config['rank'] * partition_size
val["end"] = (config['rank'] + 1) * partition_size
val["begin"] = rank * partition_size
val["end"] = (rank+1) * partition_size
offsets[kw] = 0
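Note on the hunk above: each storage bucket now derives its partition from its own ZeRO communicator via `nccl.commCount`/`nccl.commRank` instead of the global `config['rank']` and `config['world_size']`. The slicing arithmetic itself is unchanged; a minimal self-contained sketch of it follows, where `round_up` re-implements what `bmtrain.utils.round_up` is assumed to do (pad the total to the next multiple of the world size).

```python
# Sketch of per-communicator ZeRO partition sizing (assumption: round_up pads
# `total` to the next multiple of `world_size`, matching bmtrain.utils.round_up).
def round_up(total: int, multiple: int) -> int:
    return ((total + multiple - 1) // multiple) * multiple

def partition_range(total: int, world_size: int, rank: int):
    """Return (partition_size, begin, end) for the slice owned by `rank`."""
    partition_size = round_up(total, world_size) // world_size
    begin = rank * partition_size
    end = (rank + 1) * partition_size
    return partition_size, begin, end

# Example: a bucket of 10 elements on a 4-way zero communicator; rank 3 owns
# the padded tail slice [9, 12).
print(partition_range(10, 4, 3))  # -> (3, 9, 12)
```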


@@ -302,13 +306,18 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
if key in state_dict:
# load here
input_param = state_dict[key]
param = it['parameter']
tp_mode = param._tp_mode
if input_param.__class__.__name__ == "DistributedTensorWrapper":
input_param = input_param.broadcast()
if input_param.shape != it["shape"]:

verify_shape = torch.Size(it["shape"] if not tp_mode else param._tp_original_shape)
if input_param.shape != verify_shape:
error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
'the shape in current model is {}.'
.format(key, input_param.shape, it["shape"]))
.format(key, input_param.shape, verify_shape))
continue

param_st = it["offset"]
param_end = it["offset"] + it["size"]
kw_name = it["kw_name"]
@@ -322,16 +331,22 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
continue

# copy to buffer
assert input_param.numel() == it["size"]
verify_size = verify_shape.numel()
assert input_param.numel() == verify_size

contiguous_param = input_param.to(it["parameter"].dtype).cuda().contiguous()

tp_split_dim = param._tp_split_dim
if tp_mode and tp_split_dim >= 0:
contiguous_param = tp_split_tensor(contiguous_param, tp_split_dim)

offset_st = max(storage_st - param_st, 0)
offset_end = min(storage_end - param_st, contiguous_param.numel())
assert offset_st < offset_end

to_offset_st = offset_st + param_st - storage_st
to_offset_end = offset_end + param_st - storage_st

# copy to buffer
# PyTorch 1.11 changed the API of storage.__getitem__
d_dtype = self._storage_params[kw_name].dtype
@@ -398,7 +413,7 @@ def init_parameters(self):
param = it["parameter"]
if isinstance(param, DistributedParameter) and param._init_method is not None:
# initialzie here
tmp_tensor = torch.empty(it["shape"], device=param.device, dtype=param.dtype)
tmp_tensor = torch.empty(param._tp_original_shape, device=param.device, dtype=param.dtype)
param._init_method(tmp_tensor)
param_st = it["offset"]
param_end = it["offset"] + it["size"]
@@ -412,16 +427,15 @@ def init_parameters(self):
if param_end <= storage_st:
continue

if param._tp_mode and param._tp_split_dim >= 0:
tmp_tensor = tp_split_tensor(tmp_tensor, param._tp_split_dim)
# copy to buffer
assert tmp_tensor.is_contiguous() and it["size"] == tmp_tensor.numel()

offset_st = max(storage_st - param_st, 0)
offset_end = min(storage_end - param_st, tmp_tensor.numel())
offset_st = max(storage_st - param_st, 0)
offset_end = min(storage_end - param_st, tmp_tensor.numel())
assert offset_st < offset_end

to_offset_st = offset_st + param_st - storage_st
to_offset_end = offset_end + param_st - storage_st

# copy to buffer
# PyTorch 1.11 changed the API of storage.__getitem__
d_dtype = self._storage_params[kw_name].dtype
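The copy logic in `_load_from_state_dict` and `init_parameters` above intersects each parameter's span inside the flattened bucket with this rank's partition `[begin, end)` before copying. A hedged, 1-D sketch of that window arithmetic (the real code additionally handles dtype conversion, storage views, and the TP split):

```python
# Sketch of the window arithmetic used when copying one parameter into this
# rank's slice of a flattened storage bucket. All tensors are treated as 1-D
# here; function and argument names are illustrative.
import torch

def copy_param_into_partition(flat_param: torch.Tensor,
                              partition: torch.Tensor,
                              param_offset: int,     # it["offset"]: start of this param in the bucket
                              storage_begin: int):   # val["begin"]: start of this rank's partition
    param_st, param_end = param_offset, param_offset + flat_param.numel()
    storage_st, storage_end = storage_begin, storage_begin + partition.numel()
    if param_st >= storage_end or param_end <= storage_st:
        return  # no overlap: this rank owns no piece of the parameter
    offset_st = max(storage_st - param_st, 0)                     # window inside the parameter
    offset_end = min(storage_end - param_st, flat_param.numel())
    to_st = offset_st + param_st - storage_st                     # window inside the partition
    to_end = offset_end + param_st - storage_st
    partition[to_st:to_end] = flat_param[offset_st:offset_end]

# Example: a 10-element parameter starting at bucket offset 5, partition [8, 16).
param = torch.arange(10, dtype=torch.float32)
part = torch.zeros(8)
copy_param_into_partition(param, part, param_offset=5, storage_begin=8)
print(part)  # elements 3..9 of the parameter land at the front of the partition
```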
9 changes: 3 additions & 6 deletions bmtrain/checkpointing.py
@@ -39,10 +39,7 @@ def __init__(self, block : 'CheckpointBlock', ctx_dict : dict = None, pipe = Fal
self._param_tensor = {}
self._grad_tensor = {}
self._need_release = False
if pipe:
self.comm = config["zero_comm"]
else:
self.comm = config["comm"]

def enter(self, flag=0, requires_grad=False):
"""
gather parameters
@@ -74,7 +71,7 @@ def enter(self, flag=0, requires_grad=False):
nccl.allGather(
self.block._storage_params[kw].storage(),
self._param_buffer[kw],
self.comm
val['zero_comm']
)
nccl.groupEnd()

@@ -144,7 +141,7 @@ def exit(self, flag=0, backward=False):
self._grad_buffer[kw],
local_param.grad.storage(),
"sum",
self.comm
val['zero_comm']
)
nccl.groupEnd()

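With the `pipe`-dependent `self.comm` gone, the gather in `enter` and the reduce-scatter in `exit` now take each bucket's own `zero_comm`. A sketch of that pattern follows; the standalone helpers are illustrative (the real logic lives inside `CheckpointBlockContext`), and it assumes `bmtrain.nccl` exposes a `groupStart` to pair with the `groupEnd` calls seen in the diff.

```python
from bmtrain import nccl  # assumption: the same wrappers used in checkpointing.py

def gather_block_params(storage_params, storage_info, param_buffer):
    """All-gather every bucket's partition into a full buffer (enter path)."""
    nccl.groupStart()
    for kw, val in storage_info.items():
        nccl.allGather(
            storage_params[kw].storage(),  # this rank's ZeRO partition
            param_buffer[kw],              # gathered full buffer
            val["zero_comm"],              # bucket-specific communicator
        )
    nccl.groupEnd()

def scatter_block_grads(grad_buffer, storage_params, storage_info):
    """Reduce-scatter full gradients back to each rank's partition (exit path)."""
    nccl.groupStart()
    for kw, val in storage_info.items():
        nccl.reduceScatter(
            grad_buffer[kw],                    # gathered full gradient buffer
            storage_params[kw].grad.storage(),  # this rank's gradient partition
            "sum",
            val["zero_comm"],
        )
    nccl.groupEnd()
```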
84 changes: 61 additions & 23 deletions bmtrain/init.py
@@ -10,19 +10,22 @@
from . import nccl
from .synchronize import synchronize


def init_distributed(
init_method : str = "env://",
seed : int = 0,
pipe_size: int = -1,
num_micro_batches: int = None,
tp_size : int = 1,
):
"""Initialize distributed training.
This function will initialize the distributed training, set the random seed and global configurations.
It must be called before any other distributed functions.

Args:
seed (int): The random seed.
pipe_size (int) : pipe_size means that all processes will be divided into pipe_size groups
num_micro_batches (int) : means that the input batchs will be divided into num_micro_batches small batches. used in pipeline mode.
tp_size (int) : tp_size means the size of each of tensor parallel group

**init_distributed** reads the following environment variables:

@@ -70,10 +73,15 @@ def init_distributed(
config["world_size"] = world_size
config["calc_stream"] = torch.cuda.current_stream()
config["load_stream"] = torch.cuda.Stream(priority=-1)
config["tp_comm_stream"] = torch.cuda.Stream(priority=-1)
config["pp_comm_stream"] = torch.cuda.Stream(priority=-1)
config['barrier_stream'] = torch.cuda.Stream()
config["load_event"] = torch.cuda.Event()
config["tp_size"] = tp_size if tp_size > 0 else 1
config["topology"] = topology(config)
config["zero_rank"] = config["topology"].get_group_rank("zero") if pipe_size > 1 else config['rank']
config["zero_rank"] = config['topology'].get_group_rank("zero")
config["tp_rank"] = config['topology'].get_group_rank("tp")
config["tp_zero_rank"] = config['topology'].get_group_rank("tp_zero")
cpus_this_worker = None

all_available_cpus = sorted(list(os.sched_getaffinity(0)))
@@ -102,21 +110,34 @@

unique_id = bytes.fromhex(store.get("BMTRAIN_UNIQUE_ID").decode())
config['comm'] = nccl.commInitRank(unique_id, world_size, rank)
topo = config['topology']

if config['pipe_enabled']:
config["micros"] = num_micro_batches if num_micro_batches else config["pipe_size"]
topo = config['topology']
if topo.stage_id == 0:
unique_id = nccl.getUniqueId()
store.set(f"PIPE_UNIQUE_ID{topo.pipe_idx}", unique_id.hex())
unique_id = bytes.fromhex(store.get(f"PIPE_UNIQUE_ID{topo.pipe_idx}").decode())
config ['pipe_comm'] = nccl.commInitRank(unique_id, pipe_size, topo.stage_id)
if topo.zero_id == 0:
unique_id = nccl.getUniqueId()
store.set(f"ZERO_UNIQUE_ID{topo.zero_idx}", unique_id.hex() )
unique_id = bytes.fromhex(store.get(f"ZERO_UNIQUE_ID{topo.zero_idx}").decode())
config ['zero_comm'] = nccl.commInitRank(unique_id, world_size//pipe_size, topo.zero_id)
else:
config['zero_comm'] = config['comm']

if topo.tp_id == 0:
unique_id = nccl.getUniqueId()
store.set(f"TP_UNIQUE_ID{topo.tp_idx}", unique_id.hex())
unique_id = bytes.fromhex(store.get(f"TP_UNIQUE_ID{topo.tp_idx}").decode())
config['tp_comm'] = nccl.commInitRank(unique_id, tp_size, topo.tp_id)

if topo.tp_zero_id == 0:
unique_id = nccl.getUniqueId()
store.set(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}", unique_id.hex() )
unique_id = bytes.fromhex(store.get(f"TP_ZERO_UNIQUE_ID{topo.tp_zero_idx}").decode())
config['tp_zero_comm'] = nccl.commInitRank(unique_id, world_size//(config['pipe_size'] * config['tp_size']), topo.tp_zero_id)

if topo.zero_id == 0:
unique_id = nccl.getUniqueId()
store.set(f"ZERO_UNIQUE_ID{topo.zero_idx}", unique_id.hex() )
unique_id = bytes.fromhex(store.get(f"ZERO_UNIQUE_ID{topo.zero_idx}").decode())
config ['zero_comm'] = nccl.commInitRank(unique_id, world_size//(config['pipe_size']), topo.zero_id)

for i in range(world_size):
if i == rank:
print_dict("Initialization", {
@@ -129,40 +150,57 @@
"cpus": cpus_this_worker
})
synchronize()

class topology:
def __init__(self,config):
# pipe_idx is the idx of the pipeline in the group
self.rank = config['rank']
pp_size = config["pipe_size"]
tp_size = config["tp_size"]
world_size = config["world_size"]
assert world_size % pp_size == 0, "The nums of GPUs must be divisible by the pipeline parallel size"

dp_size = world_size // pp_size
topo=torch.tensor(range(dp_size*pp_size),dtype=torch.int,device='cuda')
topo=topo.view(pp_size,dp_size)
self.pp_group=topo.transpose(0,1).reshape(-1,pp_size)
self.dp_group=topo
self.stage_id = (self.pp_group == self.rank).nonzero()[0,-1].item()
assert world_size % (pp_size * tp_size) == 0, "The nums of GPUs must be divisible by the pipeline parallel size * tensor parallel size"

dp_size = world_size // (pp_size * tp_size)
config['tp_zero_size'] = dp_size
config['zero_size'] = world_size // pp_size
topo=torch.tensor(range(dp_size*tp_size*pp_size),dtype=torch.int,device='cuda')
topo=topo.view(pp_size,dp_size*tp_size)
self.stages = config['pipe_size']
self.pipe_idx = (self.pp_group == self.rank).nonzero()[0, 0].item() # x axes
self.zero_id = self.pipe_idx
self.zero_idx = self.stage_id

stage_size = world_size // pp_size
for i in range(world_size):
self.pipe_idx = self.rank % stage_size
self.stage_id = self.rank // stage_size
self.tp_id = self.rank % tp_size
self.tp_idx = self.rank // tp_size
self.zero_idx = self.stage_id
self.zero_id = self.pipe_idx
self.tp_zero_idx = self.stage_id * tp_size + self.tp_id
self.tp_zero_id = self.pipe_idx // tp_size

self.next_rank = self.stage_id+1 if self.stage_id < config['pipe_size'] - 1 else -1
self.prev_rank = self.stage_id-1 if self.stage_id > 0 else -1
self.tails = self.pp_group[self.pipe_idx, self.stage_id:].tolist()
self.heads = self.pp_group[self.pipe_idx, :self.stage_id + 1].tolist()


def get_group_id(self,group_name):
if group_name == "pipe":
return self.pipe_idx
elif group_name == "zero":
return self.zero_idx
elif group_name == "tp_zero":
return self.tp_zero_idx
elif group_name == "tp":
return self.tp_idx

def get_group_rank(self,group_name):
if group_name == "pipe":
return self.stage_id
elif group_name == "zero":
return self.zero_id
elif group_name == "tp_zero":
return self.tp_zero_id
elif group_name == "tp":
return self.tp_id

def is_initialized() -> bool:
return config["initialized"]
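The new `topology` class replaces the tensor-reshape bookkeeping with plain index arithmetic over `rank`, `pipe_size`, and `tp_size`. The sketch below reproduces that arithmetic as a standalone function so the group layout is easy to inspect; the dataclass wrapper and the example world size are additions for illustration only.

```python
# Rank -> group-index mapping, copied from the topology class in the diff.
from dataclasses import dataclass

@dataclass
class RankTopo:
    rank: int; stage_id: int; pipe_idx: int
    tp_id: int; tp_idx: int
    zero_id: int; zero_idx: int
    tp_zero_id: int; tp_zero_idx: int

def map_rank(rank: int, world_size: int, pp_size: int, tp_size: int) -> RankTopo:
    stage_size = world_size // pp_size   # ranks per pipeline stage
    pipe_idx = rank % stage_size         # which pipeline this rank belongs to
    stage_id = rank // stage_size        # which stage inside that pipeline
    tp_id = rank % tp_size               # position inside the TP group
    tp_idx = rank // tp_size             # which TP group
    return RankTopo(
        rank, stage_id, pipe_idx, tp_id, tp_idx,
        zero_id=pipe_idx, zero_idx=stage_id,
        tp_zero_id=pipe_idx // tp_size,
        tp_zero_idx=stage_id * tp_size + tp_id,
    )

# Example: 8 GPUs, pipe_size=2, tp_size=2.
for r in range(8):
    print(map_rank(r, world_size=8, pp_size=2, tp_size=2))
```

For 8 GPUs with `pipe_size=2` and `tp_size=2` this yields two stages of four ranks, 2-way tensor-parallel groups inside every stage, and tp_zero groups of size `world_size // (pipe_size * tp_size) = 2`, which matches the size passed to `commInitRank` for `tp_zero_comm` above.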
20 changes: 15 additions & 5 deletions bmtrain/layer.py
@@ -1,6 +1,8 @@
import torch
from .parameter import DistributedParameter
from .global_var import config
import itertools
from .utils import tp_split_tensor

class DistributedModule(torch.nn.Module):
"""
@@ -11,7 +13,7 @@ class DistributedModule(torch.nn.Module):
def __getattr__(self, name: str):
ret = super().__getattr__(name)
# gather distributed parameters if not in CheckpointBlock
if isinstance(ret, DistributedParameter) and not ret._in_checkpoint_block:
if isinstance(ret, DistributedParameter) and not ret._in_checkpoint_block:
return ret.gather()
return ret

@@ -30,8 +32,11 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
"""
for name, param in self._parameters.items():
if param is not None:
if isinstance(param, DistributedParameter) and not param._in_checkpoint_block:
destination[prefix + name] = param.gather().detach().cpu() # sync operation
if isinstance(param, DistributedParameter):#and not param._in_checkpoint_block:
if param._in_checkpoint_block:
destination[prefix + name] = param.tp_gather().detach().cpu() # sync operation
else:
destination[prefix + name] = param.gather_all().detach().cpu() # sync operation
else:
destination[prefix + name] = param if keep_vars else param.detach().cpu()
for name, buf in self._buffers.items():
@@ -81,6 +86,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
for name, param in local_state.items():
key = prefix + name
if key in state_dict:
tp_mode = param._tp_mode
input_param = state_dict[key]
if input_param.__class__.__name__ == "DistributedTensorWrapper":
input_param = input_param.broadcast()
@@ -98,13 +104,17 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
'the shape in current model is {}.'
.format(key, input_param.shape, param.shape))
continue
if not is_param_lazy and isinstance(param, DistributedParameter) and input_param.shape != param._original_shape:
verify_shape = torch.Size(param._original_shape if not tp_mode else param._tp_original_shape)
if not is_param_lazy and isinstance(param, DistributedParameter) and input_param.shape != verify_shape:
error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, '
'the shape in current model is {}.'
.format(key, input_param.shape, param.shape))
.format(key, input_param.shape, verify_shape))
try:
with torch.no_grad():
if isinstance(param, DistributedParameter):
tp_split_dim = param._tp_split_dim
if tp_mode and tp_split_dim >= 0:
input_param = tp_split_tensor(input_param, tp_split_dim)
param._copy_data(input_param)
else:
param.copy_(input_param)
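In `DistributedModule`, saving now writes full tensors (gathered across the tensor-parallel group) and loading verifies against `param._tp_original_shape` before `tp_split_tensor` cuts the checkpoint tensor back down to this rank's shard. A minimal sketch of that round trip, with `torch.cat`/`chunk` standing in for the NCCL gather over `config['tp_comm']` (an assumption for illustration; `tp_size` and `tp_rank` are passed explicitly here, while the real helpers read them from BMTrain's global config):

```python
import torch

def tp_gather_sketch(all_shards, split_dim: int) -> torch.Tensor:
    """Save path: reassemble a full parameter from per-rank shards."""
    return torch.cat(all_shards, dim=split_dim)

def tp_split_sketch(full: torch.Tensor, split_dim: int,
                    tp_size: int, tp_rank: int) -> torch.Tensor:
    """Load path: keep only this rank's shard of a full checkpoint tensor."""
    return full.chunk(tp_size, dim=split_dim)[tp_rank].contiguous()

# Example: two shards of shape (512, 4096), split along dim 0, become one
# (1024, 4096) tensor in the checkpoint and split cleanly back apart on load.
shards = [torch.randn(512, 4096) for _ in range(2)]
full = tp_gather_sketch(shards, split_dim=0)
assert full.shape == (1024, 4096)
assert torch.equal(tp_split_sketch(full, 0, tp_size=2, tp_rank=1), shards[1])
```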