[wip] [pipeline parallel] t5 - experiment #9765

Closed
wants to merge 14 commits into from
2 changes: 2 additions & 0 deletions examples/legacy/seq2seq/finetune_trainer.py
@@ -178,6 +178,8 @@ def main():
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
# Set the verbosity to info of the Transformers logger (on main process only):
transformers.utils.logging.enable_default_handler()
transformers.utils.logging.enable_explicit_format()
if is_main_process(training_args.local_rank):
transformers.utils.logging.set_verbosity_info()
logger.info("Training/evaluation parameters %s", training_args)
149 changes: 148 additions & 1 deletion src/transformers/integrations.py
@@ -269,7 +269,153 @@ def rewrite_logs(d):
return new_d


def init_deepspeed(trainer, num_training_steps):

import torch


# Model parallel group that the current rank belongs to.
_MODEL_PARALLEL_GROUP = None
_MODEL_PARALLEL_GROUP_DEVICE_IDS = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None
_DATA_PARALLEL_GROUP_DEVICE_IDS = None

# adapted from Megatron-LM/mpu/
class MPU:
def initialize_model_parallel(self, model_parallel_size_):
"""
Initialize model data parallel groups.

Arguments:
model_parallel_size: number of GPUs used to parallelize model.
**Important**: not the total number of gpus!

Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
use 2 GPUs to parallelize the model. The present function will
create 4 model parallel groups and 2 data parallel groups as:
4 model parallel groups:
[g0, g1], [g2, g3], [g4, g5], [g6, g7]
2 data parallel groups:
[g0, g2, g4, g6], [g1, g3, g5, g7]

Note that for efficiency, the caller should make sure adjacent ranks
are on the same DGX box. For example if we are using 2 DGX-1 boxes
with a total of 16 GPUs, rank 0 to 7 belong to the first box and
ranks 8 to 15 belong to the second box.

Let's say we have a total of 4 GPUs denoted by g0 ... g3 and we
use 2 GPUs to parallelize the model. The present function will
create 2 model parallel groups and 2 data parallel groups as:
2 model parallel groups:
[g0, g1], [g2, g3]
2 data parallel groups:
[g0, g2], [g1, g3]

"""

def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
assert numerator % denominator == 0, "{} is not divisible by {}".format(numerator, denominator)

if torch.distributed.get_rank() == 0:
print("> initializing model parallel with size {}".format(model_parallel_size_))
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
world_size = torch.distributed.get_world_size()
model_parallel_size = min(model_parallel_size_, world_size)
ensure_divisibility(world_size, model_parallel_size)
rank = torch.distributed.get_rank()

print(f"MP size: {model_parallel_size}")
print(f"world_size: {world_size}")
print(f"rank: {rank}")

# Build the data parallel groups.
global _DATA_PARALLEL_GROUP
global _DATA_PARALLEL_GROUP_DEVICE_IDS
assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized"
for i in range(model_parallel_size):
ranks = range(i, world_size, model_parallel_size)
group = torch.distributed.new_group(ranks)
if i == (rank % model_parallel_size):
#print(f"DP ranks: {list(ranks)}")
_DATA_PARALLEL_GROUP = group
_DATA_PARALLEL_GROUP_DEVICE_IDS = list(ranks)

# Build the model parallel groups.
global _MODEL_PARALLEL_GROUP
global _MODEL_PARALLEL_GROUP_DEVICE_IDS
assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized"
for i in range(world_size // model_parallel_size):
ranks = range(i * model_parallel_size, (i + 1) * model_parallel_size)
group = torch.distributed.new_group(ranks)
if i == (rank // model_parallel_size):
#print(f"MP ranks: {list(ranks)}")
_MODEL_PARALLEL_GROUP = group
_MODEL_PARALLEL_GROUP_DEVICE_IDS = list(ranks)

def model_parallel_is_initialized(self):
"""Check if model and data parallel groups are initialized."""
if _MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None:
return False
return True

def get_model_parallel_group_device_ids(self):
"""Get the model parallel device ids of the group the caller rank belongs to."""
assert _MODEL_PARALLEL_GROUP is not None, "model parallel group is not initialized"
return _MODEL_PARALLEL_GROUP_DEVICE_IDS

def get_model_parallel_group(self):
"""Get the model parallel group the caller rank belongs to."""
assert _MODEL_PARALLEL_GROUP is not None, "model parallel group is not initialized"
return _MODEL_PARALLEL_GROUP

def get_data_parallel_group_device_ids(self):
"""Get the data parallel device ids of the group the caller rank belongs to."""
assert _DATA_PARALLEL_GROUP is not None, "data parallel group is not initialized"
return _DATA_PARALLEL_GROUP_DEVICE_IDS

def get_data_parallel_group(self):
"""Get the data parallel group the caller rank belongs to."""
assert _DATA_PARALLEL_GROUP is not None, "data parallel group is not initialized"
return _DATA_PARALLEL_GROUP

def get_model_parallel_world_size(self):
"""Return world size for the model parallel group."""
return torch.distributed.get_world_size(group=self.get_model_parallel_group())

def get_model_parallel_rank(self):
"""Return my rank for the model parallel group."""
return torch.distributed.get_rank(group=self.get_model_parallel_group())

def get_model_parallel_src_rank(self):
"""Calculate the global rank corresponding to a local rank zero
in the model parallel group."""
global_rank = torch.distributed.get_rank()
        local_world_size = self.get_model_parallel_world_size()
return (global_rank // local_world_size) * local_world_size

def get_data_parallel_world_size(self):
"""Return world size for the data parallel group."""
return torch.distributed.get_world_size(group=self.get_data_parallel_group())

def get_data_parallel_rank(self):
"""Return my rank for the data parallel group."""
return torch.distributed.get_rank(group=self.get_data_parallel_group())

def destroy_model_parallel(self):
"""Set the groups to none."""
global _MODEL_PARALLEL_GROUP
global _MODEL_PARALLEL_GROUP_DEVICE_IDS
global _DATA_PARALLEL_GROUP
global _DATA_PARALLEL_GROUP_DEVICE_IDS
_MODEL_PARALLEL_GROUP = None
_MODEL_PARALLEL_GROUP_DEVICE_IDS = None
_DATA_PARALLEL_GROUP = None
_DATA_PARALLEL_GROUP_DEVICE_IDS = None
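
# A minimal illustrative sketch, not part of the patch itself: it reproduces the rank
# arithmetic of initialize_model_parallel() without touching torch.distributed, so it can
# be checked on a single process. World size 8 with model_parallel_size 2 is the example
# from the docstring; the helper name _sketch_parallel_groups is hypothetical.
def _sketch_parallel_groups(world_size=8, model_parallel_size=2):
    # data parallel groups stride over the world with step model_parallel_size
    data_parallel = [list(range(i, world_size, model_parallel_size)) for i in range(model_parallel_size)]
    # model parallel groups are contiguous blocks of model_parallel_size ranks
    model_parallel = [
        list(range(i * model_parallel_size, (i + 1) * model_parallel_size))
        for i in range(world_size // model_parallel_size)
    ]
    return model_parallel, data_parallel

# _sketch_parallel_groups() == ([[0, 1], [2, 3], [4, 5], [6, 7]],
#                               [[0, 2, 4, 6], [1, 3, 5, 7]])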


def init_deepspeed(trainer, num_training_steps, mpu):
"""
    Init DeepSpeed, after converting any relevant Trainer arguments into the DeepSpeed configuration

@@ -415,6 +561,7 @@ def init_deepspeed(trainer, num_training_steps):
model=model,
model_parameters=model_parameters,
config_params=config,
        mpu=mpu,
)

return model, optimizer, lr_scheduler
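
# An illustrative sketch only: a hypothetical call site for the new `mpu` argument.
# It assumes torch.distributed has already been initialized by the launcher and that
# `trainer` and `num_training_steps` come from the usual Trainer setup; the function name
# and the model_parallel_size value are illustrative assumptions, not part of this diff.
def _example_init_deepspeed_with_mpu(trainer, num_training_steps, model_parallel_size=2):
    mpu = MPU()
    mpu.initialize_model_parallel(model_parallel_size)
    # init_deepspeed() above forwards the topology via deepspeed.initialize(..., mpu=mpu),
    # so DeepSpeed can derive its data parallel group from the same partitioning.
    return init_deepspeed(trainer, num_training_steps, mpu)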
67 changes: 67 additions & 0 deletions src/transformers/modeling_utils.py
@@ -1793,3 +1793,70 @@ def forward(self, hidden_states):
return torch.cat(output_chunks, dim=chunk_dim)

return forward_fn(*input_tensors)


def recursive_to(device, item):
"""
Switch any tensors found in `item` to `device`. Currently can handle a single tensor, or any of the nested list,
tuple and dict structures.
"""

if torch.is_tensor(item):
return item.to(device)

elif isinstance(item, list):
for i, x in enumerate(item):
item[i] = recursive_to(device, x)
return item

elif isinstance(item, tuple):
return tuple(recursive_to(device, list(item)))

elif isinstance(item, dict):
for k, v in item.items():
item[k] = recursive_to(device, v)
return item

else:
return item
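

# An illustrative sketch only: moving a nested batch with recursive_to(). It assumes a
# CPU-only run; "cpu" would typically be replaced by the stage's CUDA device. The helper
# name _example_recursive_to is hypothetical.
def _example_recursive_to():
    batch = {
        "input_ids": torch.tensor([[101, 102]]),
        "extras": [torch.tensor([1.0]), ("not-a-tensor",)],
    }
    moved = recursive_to(torch.device("cpu"), batch)
    # tensors inside the dict/list are converted in place, the tuple is rebuilt,
    # and non-tensor leaves such as the string are returned unchanged
    return moved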


# tnone = torch.tensor([float('nan')]*batch_size)
def pipe_none_or_empty_to_torch(x, batch_size, device):
    """Encode `None` and the empty tuple `()` as placeholder tensors (a `-100` sentinel and an
    empty tensor, respectively); tensors and any other values pass through unchanged."""
    tnone = torch.tensor([-100] * batch_size).to(device)
    tempty = torch.empty(0).to(device)
    if x is None:
        return tnone
    if isinstance(x, tuple) and len(x) == 0:
        return tempty
    return x


def pipe_torch_to_none_or_empty(x, batch_size, device):
    """Decode the placeholder tensors produced by `pipe_none_or_empty_to_torch` back into
    `None` / `()`; any other tensor is returned unchanged."""
    tnone = torch.tensor([-100] * batch_size).to(device)
    if torch.is_tensor(x) and x.shape[0] == batch_size:
        if not x.numel():
            return ()
        if x.shape == tnone.shape and (x == tnone).all():
            return None
    return x


def pipe_encode_all(input, batch_size, device):
    """Apply `pipe_none_or_empty_to_torch` to every element so the whole tuple is tensor-only."""
    input = list(input)
for i, x in enumerate(input):
input[i] = pipe_none_or_empty_to_torch(x, batch_size, device)
return tuple(input)


def pipe_decode_all(input, batch_size, device):
    """Apply `pipe_torch_to_none_or_empty` to every element of the tuple."""
    input = list(input)
for i, x in enumerate(input):
input[i] = pipe_torch_to_none_or_empty(x, batch_size, device)
return tuple(input)
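

# An illustrative sketch only: round-tripping a stage's outputs through the helpers above.
# Pipeline stages generally exchange only tensors, hence `None` is encoded into the -100
# sentinel tensor on the sending side and decoded back on the receiving side, while regular
# tensors pass through untouched. The helper name, the CPU device, and batch_size=2 are
# assumptions made for the example.
def _example_pipe_round_trip(batch_size=2, device="cpu"):
    outputs = (torch.ones(batch_size, 4), None)
    encoded = pipe_encode_all(outputs, batch_size, device)   # every element is now a tensor
    decoded = pipe_decode_all(encoded, batch_size, device)   # the sentinel becomes None again
    return decoded  # (tensor of ones, None)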