Changes from all commits (24 commits):
db017fd: Minor tweaks to support Megatron 2.4 + DS 3D (Jun 6, 2021)
407ff0f: pipe partitioning (Jun 6, 2021)
a096d32: re-enable grad buffer partitioning (Jun 11, 2021)
9b4093b: Avoid partitioning small activations (tjruwase, Jun 11, 2021)
182be7b: Merge pull request #4 from ShadenSmith/olruwase/partition_activation (tjruwase, Jun 11, 2021)
3e948df: send/recv (Jun 13, 2021)
b6a2cb3: isend/irecv missing wait (Jun 13, 2021)
6bb63b8: turn off async ops (Jun 14, 2021)
8097690: Merge branch 'megatron2.4-3d-sendrecv' into megatron2.4-3d (Jun 14, 2021)
bd9e953: less verbose load (Jun 26, 2021)
081ddb5: Merge branch 'master' into megatron2.4-3d (jeffra, Jun 30, 2021)
d26c258: added shaden's set_train_batch_size patches, plus formatting (jeffra, Jul 13, 2021)
9dbfdbd: Adds engine.was_step_applied() (#1251) (Jul 26, 2021)
d6945de: Cleaning up tensor/pipe parallel accounting. (#1252) (Jul 26, 2021)
f93e22b: Correctness fix PP+ZeRO for gradient accumulation + updates from mast… (jeffra, Jul 30, 2021)
e9b5dff: dont clear grads in stage 1 code path (jeffra, Jul 31, 2021)
4b35409: prevent none grads from being reduced (jeffra, Jul 31, 2021)
bc17042: fix empty grad zero tests (jeffra, Aug 2, 2021)
6b42882: Use mpu in DeepSpeedConfig() call (#1271) (tjruwase, Aug 9, 2021)
cce85b8: API for obtaining global gradient norm (#1292) (tjruwase, Aug 9, 2021)
e65e511: turn excessive noise off (#1293) (stas00, Aug 11, 2021)
db2f8a0: [zero] restore fp16 params if no zero ckpts available (#1322) (jeffra, Aug 25, 2021)
72ce55a: Fix PP checkpoint bloat (#1324) (tjruwase, Aug 25, 2021)
c7f3bc5: update for cuda-11.4 (#1329) (stas00, Aug 30, 2021)
20 changes: 14 additions & 6 deletions deepspeed/runtime/activation_checkpointing/checkpointing.py
@@ -24,7 +24,7 @@

from deepspeed.runtime.config import DeepSpeedConfig
from deepspeed.utils import logger
from deepspeed.runtime.utils import move_to_device, see_memory_usage
from deepspeed.runtime.utils import move_to_device, see_memory_usage, bwc_tensor_model_parallel_rank
from deepspeed.utils.timer import SynchronizedWallClockTimer as Timers

# DeepSpeed Checkpointing Enabled or Disabled
@@ -213,9 +213,12 @@ def model_parallel_cuda_manual_seed(seed):
model parallel regions.
"""
global mpu

tp_rank = bwc_tensor_model_parallel_rank(mpu)

# 2718 is just for fun and any POSITIVE value will work.
offset = seed + 2718
model_parallel_seed = offset + mpu.get_model_parallel_rank()
model_parallel_seed = offset + tp_rank
# Data parallel gets the original seed.
data_parallel_seed = seed

@@ -225,7 +228,7 @@
'model parallel rank {}, and data parallel rank {} with '
'model parallel seed: {} and data parallel seed: {}'.format(
torch.distributed.get_rank(),
mpu.get_model_parallel_rank(),
tp_rank,
mpu.get_data_parallel_rank(),
model_parallel_seed,
data_parallel_seed),
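
To make the seeding scheme above concrete, here is a hypothetical worked example (the values are assumed, not from the diff):

# Hypothetical illustration of the seeding scheme, for seed=1234 on
# tensor-parallel rank 1.
seed = 1234
tp_rank = 1

offset = seed + 2718                     # 3952; any positive constant works
model_parallel_seed = offset + tp_rank   # 3953, unique per tensor-parallel rank
data_parallel_seed = seed                # 1234, shared across data-parallel ranks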
@@ -384,9 +387,14 @@ def save_args_for_backward(*all_args):
global data_offsets, size_offsets
if mp_rank is None:
if mpu is not None:
mp_rank = mpu.get_model_parallel_rank()
mp_size = mpu.get_model_parallel_world_size()
mp_group = mpu.get_model_parallel_group()
if hasattr(mpu, 'get_tensor_model_parallel_rank'):
mp_rank = mpu.get_tensor_model_parallel_rank()
mp_size = mpu.get_tensor_model_parallel_world_size()
mp_group = mpu.get_tensor_model_parallel_group()
else:
mp_rank = mpu.get_model_parallel_rank()
mp_size = mpu.get_model_parallel_world_size()
mp_group = mpu.get_model_parallel_group()
else:
mp_rank = 0
mp_size = 1
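The newly imported bwc_tensor_model_parallel_rank wraps the same hasattr-based fallback that the second hunk applies inline. A minimal sketch of such a helper, assuming only the two mpu conventions visible in this diff:

def bwc_tensor_model_parallel_rank(mpu=None):
    """Backward-compatible tensor-model-parallel rank lookup.

    Newer Megatron mpu objects expose get_tensor_model_parallel_rank();
    older ones only have get_model_parallel_rank().
    """
    if mpu is None:
        return 0  # no model parallelism configured
    if hasattr(mpu, 'get_tensor_model_parallel_rank'):
        return mpu.get_tensor_model_parallel_rank()
    return mpu.get_model_parallel_rank()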
102 changes: 92 additions & 10 deletions deepspeed/runtime/engine.py
@@ -36,6 +36,7 @@
ZERO_OPTIMIZATION_OPTIMIZER_STATES, ZERO_OPTIMIZATION_GRADIENTS, ZERO_OPTIMIZATION_WEIGHTS
from deepspeed.runtime.csr_tensor import CSRTensor
import deepspeed.runtime.lr_schedules as lr_schedules
from deepspeed.runtime.utils import get_grad_norm
from deepspeed.utils import logger, log_dist, init_distributed
from deepspeed.utils.timer import ThroughputTimer, SynchronizedWallClockTimer
from deepspeed.utils.debug import debug_extract_module_and_param_names
@@ -122,6 +123,8 @@ def __init__(self,
self.block_eigenvalue = None
self.gas_boundary_ctr = 0
self.dist_backend = "nccl"
self._step_applied = False
self._global_grad_norm = None

# for debug purposes - can then debug print: debug_get_module_name(module)
debug_extract_module_and_param_names(model)
@@ -234,6 +237,40 @@ def get_batch_info(self):
"""
return self.train_batch_size, self.train_micro_batch_size_per_gpu, self.gradient_accumulation_steps

def set_train_batch_size(self, train_batch_size):
"""Adjust the global batch size by increasing or decreasing the number of
micro-batches (i.e., gradient accumulation steps). The size of each micro-batch
(i.e., ``train_micro_batch_size_per_gpu``) is not changed.
Args:
train_batch_size (int): The new global batch size for training.
Raises:
ValueError: if ``train_batch_size`` is not divisible by the
configured micro-batch size times the data parallel world size.
"""
if train_batch_size % (self.train_micro_batch_size_per_gpu() *
self.dp_world_size) != 0:
#print(f'{train_batch_size=} {self.train_micro_batch_size_per_gpu()=} {self.dp_world_size=}')
raise ValueError(
'Train batch size must be divisible by micro-batch size times data parallelism')
new_gas = train_batch_size // (self.train_micro_batch_size_per_gpu() *
self.dp_world_size)
# overwrite config
self._config.train_batch_size = train_batch_size
self._config.gradient_accumulation_steps = new_gas
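
A hypothetical usage sketch (the engine object and its setup are assumed): doubling the global batch size, where only gradient accumulation changes.

micro = engine.train_micro_batch_size_per_gpu()
gas = engine.gradient_accumulation_steps()
# Doubling keeps the new size divisible by micro * dp_world_size.
engine.set_train_batch_size(2 * micro * gas * engine.dp_world_size)
assert engine.gradient_accumulation_steps() == 2 * gas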

def get_global_grad_norm(self) -> float:
"""Return the 2-norm of all gradients. If there is model parallelism,
the norm will be global.

The computed norm will be cached and reused until the next step() pass.
.. note::
In the presence of model parallelism, this is a collective call
and acts as a barrier among ``mpu.get_model_parallel_group()``.
Returns:
float: The global 2-norm of the gradients (``None`` until the first ``step()``).
"""
return self._global_grad_norm
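
A hypothetical usage sketch (assumes an initialized engine); note the cached value starts as None and may stay None for optimizers that do not track it:

engine.step()
norm = engine.get_global_grad_norm()
if norm is not None:
    print(f'global gradient norm: {norm:.4f}')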

def checkpoint_tag_validation_enabled(self):
return self._config.checkpoint_tag_validation_enabled

@@ -479,6 +516,9 @@ def steps_per_print(self):
def zero_allgather_partitions(self):
return self._config.zero_config.allgather_partitions

def zero_round_robin_gradients(self):
return self._config.zero_config.round_robin_gradients

def dump_state(self):
return self._config.dump_state

@@ -570,7 +610,7 @@ def _configure_with_arguments(self, args, mpu):
if "OMPI_COMM_WORLD_LOCAL_RANK" in os.environ:
ompi_local_rank = os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK")
local_rank = os.environ.get('LOCAL_RANK', ompi_local_rank)
assert ompi_local_rank == local_rank, f"LOCAL_RANK ({local_rank}) != OMPI_COMM_WORLD_LOCAL_RANK ({mpi_local_rank}), " \
assert ompi_local_rank == local_rank, f"LOCAL_RANK ({local_rank}) != OMPI_COMM_WORLD_LOCAL_RANK ({ompi_local_rank}), " \
"not sure how to proceed as we're seeing conficting local rank info."
os.environ['LOCAL_RANK'] = local_rank

@@ -904,6 +944,15 @@ def _configure_zero_optimizer(self, optimizer):
gradient_predivide=self.gradient_predivide)
elif zero_stage <= ZERO_OPTIMIZATION_GRADIENTS:
overlap_comm = self.zero_overlap_comm()
contiguous_gradients = self.zero_contiguous_gradients()
round_robin_gradients = self.zero_round_robin_gradients()

# Overlap and contiguous grads are meaningless in stage 1 and are ignored
if zero_stage == ZERO_OPTIMIZATION_OPTIMIZER_STATES:
overlap_comm = False
contiguous_gradients = False
round_robin_gradients = False

if isinstance(self.module, PipelineModule):
if overlap_comm:
logger.warning(
@@ -918,7 +967,7 @@
dynamic_loss_scale=self.dynamic_loss_scale(),
dynamic_loss_args=self.dynamic_loss_scale_args(),
clip_grad=self.gradient_clipping(),
contiguous_gradients=self.zero_contiguous_gradients(),
contiguous_gradients=contiguous_gradients,
reduce_bucket_size=self.zero_reduce_bucket_size(),
allgather_bucket_size=self.zero_allgather_bucket_size(),
dp_process_group=self.data_parallel_group,
@@ -930,7 +979,8 @@
gradient_predivide_factor=self.gradient_predivide_factor(),
gradient_accumulation_steps=self.gradient_accumulation_steps(),
ignore_unused_parameters=self.zero_ignore_unused_parameters(),
partition_grads=zero_stage == ZERO_OPTIMIZATION_GRADIENTS)
partition_grads=zero_stage == ZERO_OPTIMIZATION_GRADIENTS,
round_robin_gradients=round_robin_gradients)
elif zero_stage == ZERO_OPTIMIZATION_WEIGHTS:
logger.info("Initializing ZeRO Stage 3") if dist.get_rank() == 0 else None
from deepspeed.runtime.zero.stage3 import FP16_DeepSpeedZeroOptimizer_Stage3
@@ -992,6 +1042,18 @@ def is_iterable_style_dataset(obj):
torch.utils.data.IterableDataset
) # hasattr(obj, "__iter__") should work as well

def was_step_applied(self) -> bool:
"""Returns True if the latest ``step()`` produced in parameter updates.

Note that a ``False`` return is not an error condition. Steps are frequently
no-ops, such as between gradient accumulation boundaries or when overflows
occur.

Returns:
bool: Whether the latest ``step()`` modified model parameters.
"""
return self._step_applied
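
A hypothetical training-loop sketch (names other than the engine API are assumed) showing where the flag is useful:

loss = engine(batch)
engine.backward(loss)
engine.step()
if engine.was_step_applied():
    # Parameters actually changed: not a gradient accumulation no-op and
    # no fp16 overflow, so it is safe to log or checkpoint here.
    maybe_save_checkpoint(engine)  # maybe_save_checkpoint is hypothetical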

def deepspeed_io(self,
dataset,
batch_size=None,
@@ -1129,6 +1191,10 @@ def forward(self, *inputs, **kwargs):
return loss

def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
# Pass the (pipeline parallelism) gradient accumulation boundary flag to the optimizer (required for ZeRO)
self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary(
)

# ZeRO stage 2 communicates during non gradient accumulation boundaries as well
if self.zero_optimization_partition_gradients():
self.optimizer.overlapping_partition_gradients_reduce_epilogue()
@@ -1264,6 +1330,9 @@ def _take_model_step(self, lr_kwargs, block_eigenvalue={}):

self.optimizer.step()

if hasattr(self.optimizer, '_global_grad_norm'):
self._global_grad_norm = self.optimizer._global_grad_norm

# Quantize the updated parameters if there is no overflow
if self.quantizer:
self.quantizer.quantize(
@@ -1286,12 +1355,19 @@ def _take_model_step(self, lr_kwargs, block_eigenvalue={}):
overflow = False
if hasattr(self.optimizer, 'overflow'):
overflow = self.optimizer.overflow
self._step_applied = not overflow

if overflow:
self.skipped_steps += 1
else:
if self.lr_scheduler is not None:
self.lr_scheduler.step(**(lr_kwargs or {}))
try:
self.lr_scheduler.step(**(lr_kwargs or {}))
except TypeError:
# XXX Hack to work with Megatron 2.0 and DeepSpeed pipelines.
# We don't currently have a way to specify lr_kwargs from
# pipe_engine.train_batch()
self.lr_scheduler.step(increment=self.train_batch_size())

if report_progress and (self.global_steps + 1) % self.steps_per_print() == 0:
self._report_progress(self.global_steps + 1)
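
The try/except above bridges two scheduler signatures. A sketch of the assumed shapes (class names here are illustrative, not from Megatron or DeepSpeed):

class TorchStyleScheduler:
    def step(self):
        # Standard PyTorch signature: the engine's first attempt,
        # step(), succeeds and the except branch never runs.
        pass

class MegatronStyleScheduler:
    def step(self, increment):
        # Requires a sample-count increment. Calling step() with no
        # arguments raises TypeError, which triggers the fallback
        # step(increment=self.train_batch_size()) above.
        pass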
@@ -1311,6 +1387,8 @@ def step(self, lr_kwargs=None):
"init in order to use step"
report_progress = self.global_rank == 0 if self.global_rank else True

self._step_applied = False # assume False, will flip to True

# Update the model when we reach gradient accumulation boundaries
if self.is_gradient_accumulation_boundary():
self.gas_boundary_ctr += 1
@@ -1674,9 +1752,12 @@ def load_checkpoint(self,
load_lr_scheduler_states=load_lr_scheduler_states)

if self.zero_optimization() and load_path is not None:
self._load_zero_checkpoint(load_dir,
tag,
load_optimizer_states=load_optimizer_states)
success = self._load_zero_checkpoint(
load_dir,
tag,
load_optimizer_states=load_optimizer_states)
if not success:
self.optimizer._restore_from_fp16_weights()

return load_path, client_states

@@ -1746,7 +1827,7 @@ def _load_checkpoint(self,
def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True):
zero_sd_list = self._get_all_zero_checkpoints(load_dir, tag)
if zero_sd_list is None:
return
return False

self.optimizer.load_state_dict(
state_dict_list=zero_sd_list,
Expand All @@ -1755,6 +1836,7 @@ def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True):
print(
f'loading {len(zero_sd_list)} zero partition checkpoints for rank {self.global_rank}'
)
return True

def _get_mp_rank_zero_checkpoint_names(self, load_dir, tag, mp_rank, dp_world_size):
zero_ckpt_names = []
@@ -1973,7 +2055,7 @@ def _copy_recovery_script(self, save_path):
script = "zero_to_fp32.py"
src = os.path.join(base_dir, "utils", script)
dst = os.path.join(save_path, script)
logger.info(f"creating recovery script {dst}")
#logger.info(f"creating recovery script {dst}")
copyfile(src, dst)
# make executable
os.chmod(dst, os.stat(dst).st_mode | stat.S_IEXEC)
@@ -1986,7 +2068,7 @@ def _save_zero_checkpoint(self, save_path, tag):
ds_version=version)
torch.save(zero_sd, zero_checkpoint_name)
self._copy_recovery_script(save_path)
logger.info('zero checkpoint saved {}'.format(zero_checkpoint_name))
#logger.info('zero checkpoint saved {}'.format(zero_checkpoint_name))

def _zero3_consolidated_fp16_state_dict(self):
"""
20 changes: 11 additions & 9 deletions deepspeed/runtime/fp16/fused_optimizer.py
@@ -9,7 +9,7 @@
import math
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from deepspeed.runtime.utils import get_grad_norm, CheckOverflow, get_weight_norm
from deepspeed.runtime.utils import get_global_norm, get_grad_norm, CheckOverflow, get_weight_norm
from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE
from deepspeed.utils import logger, log_dist

@@ -45,6 +45,8 @@ def __init__(self,
self.fp16_groups_flat = []
self.fp32_groups_flat = []

self._global_grad_norm = 0.

# loop to deal with groups
for i, param_group in enumerate(self.optimizer.param_groups):
# push this group to list before modify
@@ -161,8 +163,11 @@ def step_fused_adam(self, closure=None):
"scale: {}, reducing to {}".format(prev_scale,
self.cur_scale))
return self.overflow

self._global_grad_norm = get_global_norm(norm_list=norm_groups)

combined_scale = self.unscale_and_clip_grads(grads_groups_flat,
norm_groups,
self._global_grad_norm,
apply_scale=False)
# norm is in fact norm*cur_scale
self.optimizer.step(grads=[[g] for g in grads_groups_flat],
@@ -251,8 +256,10 @@ def step(self, closure=None):
all_groups_norm = get_grad_norm(self.fp32_groups_flat, mpu=self.mpu)
self.stop_timers([COMPUTE_NORM])

self._global_grad_norm = get_global_norm(norm_list=[all_groups_norm])

self.start_timers([UNSCALE_AND_CLIP])
self.unscale_and_clip_grads(grads_groups_flat, [all_groups_norm])
self.unscale_and_clip_grads(grads_groups_flat, self._global_grad_norm)
self.stop_timers([UNSCALE_AND_CLIP])

self.start_timers([BASIC_STEP])
@@ -277,12 +284,7 @@

return self.overflow

def unscale_and_clip_grads(self, grad_groups_flat, norm_groups, apply_scale=True):
total_norm = 0.0
for norm in norm_groups:
total_norm += norm**2.0
total_norm = math.sqrt(total_norm)

def unscale_and_clip_grads(self, grad_groups_flat, total_norm, apply_scale=True):
# compute combined scale factor for this group
combined_scale = self.cur_scale
if self.clip_grad > 0.:
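The per-call norm combination deleted above now lives in the newly imported get_global_norm. A minimal sketch consistent with the removed lines:

import math

def get_global_norm(norm_list):
    # Combine per-group 2-norms into a single global 2-norm:
    # sqrt(norm_0**2 + norm_1**2 + ...), exactly as the deleted loop did.
    total_norm = 0.0
    for norm in norm_list:
        total_norm += norm ** 2.0
    return math.sqrt(total_norm)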
17 changes: 8 additions & 9 deletions deepspeed/runtime/fp16/unfused_optimizer.py
@@ -9,7 +9,7 @@
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
import math

from deepspeed.runtime.utils import get_grad_norm, CheckOverflow, get_weight_norm
from deepspeed.runtime.utils import get_global_norm, get_grad_norm, CheckOverflow, get_weight_norm
from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE
from deepspeed.utils import logger

@@ -32,6 +32,7 @@ def __init__(self,
fused_lamb_legacy=False):

self.fused_lamb_legacy = fused_lamb_legacy
self._global_grad_norm = 0.

if torch.distributed.get_rank() == 0:
logger.info(f'Fused Lamb Legacy : {self.fused_lamb_legacy} ')
@@ -148,7 +149,9 @@ def step_fused_lamb(self, closure=None):
self.cur_scale))
return self.overflow

combined_scale = self.unscale_and_clip_grads(norm_groups, apply_scale=False)
self._global_grad_norm = get_global_norm(norm_list=norm_groups)
combined_scale = self.unscale_and_clip_grads(self._global_grad_norm,
apply_scale=False)
self.optimizer.step(grads=grads_groups,
output_params=self.fp16_groups,
scale=combined_scale)
@@ -197,7 +200,8 @@ def step(self, closure=None):
else:
fp32_param.grad = fp16_param.grad.to(fp32_param.dtype)

self.unscale_and_clip_grads(norm_groups)
self._global_grad_norm = get_global_norm(norm_list=norm_groups)
self.unscale_and_clip_grads(self._global_grad_norm)

self.optimizer.step()

@@ -212,12 +216,7 @@

return self.overflow

def unscale_and_clip_grads(self, norm_groups, apply_scale=True):
total_norm = 0.0
for norm in norm_groups:
total_norm += norm**2.0
total_norm = math.sqrt(total_norm)

def unscale_and_clip_grads(self, total_norm, apply_scale=True):
# compute combined scale factor for this group
combined_scale = self.cur_scale
if self.clip_grad > 0.: