Merged
2 changes: 1 addition & 1 deletion DeepSpeedExamples
20 changes: 14 additions & 6 deletions deepspeed/runtime/activation_checkpointing/checkpointing.py
@@ -24,7 +24,7 @@

from deepspeed.runtime.config import DeepSpeedConfig
from deepspeed.utils import logger
from deepspeed.runtime.utils import copy_to_device, move_to_device, see_memory_usage
from deepspeed.runtime.utils import copy_to_device, move_to_device, see_memory_usage, bwc_tensor_model_parallel_rank
from deepspeed.utils.timer import SynchronizedWallClockTimer as Timers

# DeepSpeed Checkpointing Enabled or Disabled
@@ -213,9 +213,12 @@ def model_parallel_cuda_manual_seed(seed):
model parallel regions.
"""
global mpu

tp_rank = bwc_tensor_model_parallel_rank(mpu)

# 2718 is just for fun and any POSITIVE value will work.
offset = seed + 2718
model_parallel_seed = offset + mpu.get_model_parallel_rank()
model_parallel_seed = offset + tp_rank
# Data parallel gets the original seed.
data_parallel_seed = seed

@@ -225,7 +228,7 @@ def model_parallel_cuda_manual_seed(seed):
'model parallel rank {}, and data parallel rank {} with '
'model parallel seed: {} and data parallel seed: {}'.format(
torch.distributed.get_rank(),
mpu.get_model_parallel_rank(),
tp_rank,
mpu.get_data_parallel_rank(),
model_parallel_seed,
data_parallel_seed),
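As a worked example of the seed layout this hunk produces, here is a minimal sketch; the rank comes from `bwc_tensor_model_parallel_rank(mpu)` in the diff but is passed as a plain integer here:

```python
# Sketch of the seed derivation above: each tensor-parallel rank gets a
# distinct CUDA RNG seed, while the data-parallel (default) RNG keeps the
# user-provided seed.
def derive_seeds(seed, tp_rank):
    offset = seed + 2718                    # arbitrary positive offset, as in the diff
    model_parallel_seed = offset + tp_rank  # unique per tensor-parallel rank
    data_parallel_seed = seed               # identical across tensor-parallel ranks
    return model_parallel_seed, data_parallel_seed

# seed=1234 on tp ranks 0..3 -> model-parallel seeds 3952, 3953, 3954, 3955
for rank in range(4):
    print(rank, derive_seeds(1234, rank))
```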
@@ -515,9 +518,14 @@ def save_args_for_backward(*all_args):
global data_offsets, size_offsets
if mp_rank is None:
if mpu is not None:
mp_rank = mpu.get_model_parallel_rank()
mp_size = mpu.get_model_parallel_world_size()
mp_group = mpu.get_model_parallel_group()
if hasattr(mpu, 'get_tensor_model_parallel_rank'):
mp_rank = mpu.get_tensor_model_parallel_rank()
mp_size = mpu.get_tensor_model_parallel_world_size()
mp_group = mpu.get_tensor_model_parallel_group()
else:
mp_rank = mpu.get_model_parallel_rank()
mp_size = mpu.get_model_parallel_world_size()
mp_group = mpu.get_model_parallel_group()
else:
mp_rank = 0
mp_size = 1
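The `bwc_tensor_model_parallel_rank` helper imported above lives in `deepspeed/runtime/utils.py` and is not shown in this diff. A plausible sketch, assuming it uses the same `hasattr` fallback pattern as the hunk above:

```python
# Hypothetical sketch only; the real bwc_tensor_model_parallel_rank in
# deepspeed.runtime.utils is not part of this diff.
def bwc_tensor_model_parallel_rank(mpu=None):
    if mpu is None:
        return 0  # no model parallelism configured
    if hasattr(mpu, 'get_tensor_model_parallel_rank'):
        # newer Megatron-LM splits tensor and pipeline parallel groups
        return mpu.get_tensor_model_parallel_rank()
    # older mpu objects expose only the combined model-parallel group
    return mpu.get_model_parallel_rank()
```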
89 changes: 63 additions & 26 deletions deepspeed/runtime/engine.py
@@ -42,6 +42,7 @@
from deepspeed.runtime.csr_tensor import CSRTensor
import deepspeed.runtime.lr_schedules as lr_schedules
import deepspeed.utils.groups as groups
from deepspeed.runtime.utils import get_grad_norm
from deepspeed.utils import logger, log_dist, init_distributed
from deepspeed.utils.timer import ThroughputTimer, SynchronizedWallClockTimer
from deepspeed.utils.debug import debug_extract_module_and_param_names
@@ -140,6 +141,8 @@ def __init__(self,
self.dist_backend = "nccl"
self.has_moe_layers = False
self.num_experts = None
self._step_applied = False
self._global_grad_norm = None

# for debug purposes - can then debug print: debug_get_module_name(module)
debug_extract_module_and_param_names(model)
@@ -259,6 +262,40 @@ def get_batch_info(self):
"""
return self.train_batch_size, self.train_micro_batch_size_per_gpu, self.gradient_accumulation_steps

def set_train_batch_size(self, train_batch_size):
"""Adjust the global batch size by increasing or decreasing the number of
micro-batches (i.e., gradient accumulation steps). The size of each micro-batch
(i.e., ``train_micro_batch_size_per_gpu``) is not changed.
Args:
train_batch_size (int): The new global batch size for training.
Raises:
ValueError: if ``train_batch_size`` is not divisible by the
configured micro-batch size and data parallelism.
"""
if train_batch_size % (self.train_micro_batch_size_per_gpu() *
self.dp_world_size) != 0:
#print(f'{train_batch_size=} {self.train_micro_batch_size_per_gpu()=} {self.dp_world_size=}')
            raise ValueError(
                'Train batch size must be divisible by micro-batch size * data parallelism')
new_gas = train_batch_size // (self.train_micro_batch_size_per_gpu() *
self.dp_world_size)
# overwrite config
self._config.train_batch_size = train_batch_size
self._config.gradient_accumulation_steps = new_gas

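A worked example of the arithmetic in `set_train_batch_size` (numbers are illustrative, not from the PR):

```python
# With micro-batch 4 and data-parallel world size 8, one optimizer step
# consumes a multiple of 32 samples, so only such global batch sizes are legal.
micro_batch, dp_world_size = 4, 8
train_batch_size = 128
assert train_batch_size % (micro_batch * dp_world_size) == 0
new_gas = train_batch_size // (micro_batch * dp_world_size)
print(new_gas)  # -> 4 gradient accumulation steps
# engine.set_train_batch_size(96) would yield gas=3;
# engine.set_train_batch_size(100) would raise ValueError.
```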
def get_global_grad_norm(self) -> float:
"""Return the 2-norm of all gradients. If there is model parallelism,
the norm will be global.

The computed norm will be cached and reused until the next step() pass.
.. note::
In the presence of model parallelism, this is a collective call
and acts as a barrier among ``mpu.get_model_parallel_group()``.
Returns:
float: norm
"""
return self._global_grad_norm

def checkpoint_tag_validation_enabled(self):
return self._config.checkpoint_tag_validation_enabled

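A hedged usage sketch for the new getter; `engine`, `data_loader`, and `loss_fn` are placeholders from a surrounding training script, and the returned value is whatever the optimizer cached during the last `step()` (see the `_take_model_step` change further down):

```python
# Placeholders: engine is an initialized DeepSpeed engine; data_loader and
# loss_fn come from the training script.
for step, batch in enumerate(data_loader):
    loss = loss_fn(engine(batch))
    engine.backward(loss)
    engine.step()
    if step % 100 == 0:
        # cached from optimizer._global_grad_norm during the last step()
        print(f"global grad norm: {engine.get_global_grad_norm()}")
```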
@@ -1146,6 +1183,18 @@ def is_iterable_style_dataset(obj):
def dataloader_drop_last(self):
return self._config.dataloader_drop_last

def was_step_applied(self) -> bool:
"""Returns True if the latest ``step()`` produced in parameter updates.

Note that a ``False`` return is not an error condition. Steps are frequently
no-ops, such as between gradient accumulation boundaries or when overflows
occur.

Returns:
bool: Whether the latest ``step()`` modified model parameters.
"""
return self._step_applied

def deepspeed_io(self,
dataset,
batch_size=None,
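A short sketch of how a training script can consume `was_step_applied()`; the `ema` helper and the other names are hypothetical:

```python
# Placeholders: engine, data_loader, loss_fn, and ema are assumed to exist.
# Only advance the EMA weights when step() really updated the parameters,
# i.e. not on a gradient-accumulation boundary and not on an fp16 overflow skip.
for batch in data_loader:
    loss = loss_fn(engine(batch))
    engine.backward(loss)
    engine.step()
    if engine.was_step_applied():
        ema.update(engine.module.parameters())
```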
@@ -1432,6 +1481,9 @@ def _take_model_step(self, lr_kwargs, block_eigenvalue={}):
mpu=self.mpu)
self.optimizer.step()

if hasattr(self.optimizer, '_global_grad_norm'):
self._global_grad_norm = self.optimizer._global_grad_norm

# Quantize the updated parameter if there is no overflow
if self.quantizer:
self.quantizer.quantize(
@@ -1454,12 +1506,19 @@ def _take_model_step(self, lr_kwargs, block_eigenvalue={}):
overflow = False
if hasattr(self.optimizer, 'overflow'):
overflow = self.optimizer.overflow
self._step_applied = not overflow

if overflow:
self.skipped_steps += 1
else:
if self.lr_scheduler is not None:
self.lr_scheduler.step(**(lr_kwargs or {}))
try:
self.lr_scheduler.step(**(lr_kwargs or {}))
except TypeError:
# XXX Hack to work with Megatron 2.0 and DeepSpeed pipelines.
# We don't currently have a way to specify lr_kwargs from
# pipe_engine.train_batch()
self.lr_scheduler.step(increment=self.train_batch_size())

if report_progress and (self.global_steps + 1) % self.steps_per_print() == 0:
self._report_progress(self.global_steps + 1)
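The `TypeError` fallback above exists because Megatron-style schedulers take an explicit `increment` rather than arbitrary keyword arguments. A minimal sketch of a scheduler that triggers it (class name and increment value are illustrative):

```python
# A scheduler whose step() requires `increment` rejects the keyword-free call
# the engine tries first, so the engine retries with the train batch size.
class IncrementOnlyScheduler:
    def __init__(self):
        self.num_iters = 0

    def step(self, increment):       # no default -> bare step() raises TypeError
        self.num_iters += increment

sched = IncrementOnlyScheduler()
try:
    sched.step()                     # what the engine attempts first
except TypeError:
    sched.step(increment=32)         # fallback with train_batch_size()
```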
@@ -1479,6 +1538,8 @@ def step(self, lr_kwargs=None):
"init in order to use step"
report_progress = self.global_rank == 0 if self.global_rank else True

self._step_applied = False # assume False, will flip to True

# Update the model when we reach gradient accumulation boundaries
if self.is_gradient_accumulation_boundary():
self.gas_boundary_ctr += 1
@@ -2413,7 +2474,7 @@ def _copy_recovery_script(self, save_path):
script = "zero_to_fp32.py"
src = os.path.join(base_dir, "utils", script)
dst = os.path.join(save_path, script)
logger.info(f"creating recovery script {dst}")
#logger.info(f"creating recovery script {dst}")
copyfile(src, dst)
# make executable
os.chmod(dst, os.stat(dst).st_mode | stat.S_IEXEC)
@@ -2530,27 +2591,3 @@ def save_fp16_model(self, save_dir, save_filename="pytorch_model.bin"):
os.makedirs(save_dir, exist_ok=True)
logger.info(f"Saving model weights to {path}")
torch.save(state_dict, path)

def set_train_batch_size(self, train_batch_size):
"""Adjust the global batch size by increasing or decreasing the size of
each micro-batch (i.e., ``train_micro_batch_size_per_gpu``). The number of
micro-batches (i.e., gradient accumulation steps) is not changed.
Args:
train_batch_size (int): The new global batch size for training.
Raises:
ValueError: if ``train_batch_size`` is not divisible by the
configured gradient_accumulation_steps and data parallelism.
"""

if train_batch_size % (self.gradient_accumulation_steps() *
self.dp_world_size) != 0:
raise ValueError(
f'Train batch size must be divisible by gradient_accumulation_steps * data parallelism'
)

new_micro_bsz = train_batch_size // (self.gradient_accumulation_steps() *
self.dp_world_size)

# overwrite config
self._config.train_batch_size = train_batch_size
self._config.train_micro_batch_size_per_gpu = new_micro_bsz
20 changes: 11 additions & 9 deletions deepspeed/runtime/fp16/fused_optimizer.py
@@ -9,7 +9,7 @@
import math
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from deepspeed.runtime.utils import get_grad_norm, CheckOverflow, get_weight_norm
from deepspeed.runtime.utils import get_global_norm, get_grad_norm, CheckOverflow, get_weight_norm
from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE
from deepspeed.utils import groups, logger, log_dist
import torch.distributed as dist
@@ -47,6 +47,8 @@ def __init__(self,
self.fp16_groups_flat = []
self.fp32_groups_flat = []

self._global_grad_norm = 0.

# loop to deal with groups
for i, param_group in enumerate(self.optimizer.param_groups):
# push this group to list before modify
@@ -163,8 +165,11 @@ def step_fused_adam(self, closure=None):
"scale: {}, reducing to {}".format(prev_scale,
self.cur_scale))
return self.overflow

self._global_grad_norm = get_global_norm(norm_list=norm_groups)

combined_scale = self.unscale_and_clip_grads(grads_groups_flat,
norm_groups,
self._global_grad_norm,
apply_scale=False)
# norm is in fact norm*cur_scale
self.optimizer.step(grads=[[g] for g in grads_groups_flat],
@@ -268,8 +273,10 @@ def step(self, closure=None):

self.stop_timers([COMPUTE_NORM])

self._global_grad_norm = get_global_norm(norm_list=[all_groups_norm])

self.start_timers([UNSCALE_AND_CLIP])
self.unscale_and_clip_grads(grads_groups_flat, [all_groups_norm])
self.unscale_and_clip_grads(grads_groups_flat, self._global_grad_norm)
self.stop_timers([UNSCALE_AND_CLIP])

self.start_timers([BASIC_STEP])
@@ -294,12 +301,7 @@ def step(self, closure=None):

return self.overflow

def unscale_and_clip_grads(self, grad_groups_flat, norm_groups, apply_scale=True):
total_norm = 0.0
for norm in norm_groups:
total_norm += norm**2.0
total_norm = math.sqrt(total_norm)

def unscale_and_clip_grads(self, grad_groups_flat, total_norm, apply_scale=True):
# compute combined scale factor for this group
combined_scale = self.cur_scale
if self.clip_grad > 0.:
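`get_global_norm` is imported from `deepspeed.runtime.utils`; its effect here matches the inline loop it replaces in `unscale_and_clip_grads`, roughly as below (the real helper may handle extra cases such as sentinel or infinite norms):

```python
import math

# Sketch of the norm combination that get_global_norm centralizes; it mirrors
# the per-group loop removed from unscale_and_clip_grads.
def combine_norms(norm_list):
    total_norm = 0.0
    for norm in norm_list:
        total_norm += norm ** 2.0
    return math.sqrt(total_norm)

assert abs(combine_norms([3.0, 4.0]) - 5.0) < 1e-12
```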
17 changes: 8 additions & 9 deletions deepspeed/runtime/fp16/unfused_optimizer.py
@@ -10,7 +10,7 @@
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
import math

from deepspeed.runtime.utils import get_grad_norm, CheckOverflow, get_weight_norm
from deepspeed.runtime.utils import get_global_norm, get_grad_norm, CheckOverflow, get_weight_norm
from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE
from deepspeed.utils import logger

@@ -33,6 +33,7 @@ def __init__(self,
fused_lamb_legacy=False):

self.fused_lamb_legacy = fused_lamb_legacy
self._global_grad_norm = 0.

if torch.distributed.get_rank() == 0:
logger.info(f'Fused Lamb Legacy : {self.fused_lamb_legacy} ')
@@ -163,7 +164,9 @@ def step_fused_lamb(self, closure=None):
self.cur_scale))
return self.overflow

combined_scale = self.unscale_and_clip_grads(norm_groups, apply_scale=False)
self._global_grad_norm = get_global_norm(norm_list=norm_groups)
combined_scale = self.unscale_and_clip_grads(self._global_grad_norm,
apply_scale=False)
self.optimizer.step(grads=grads_groups,
output_params=self.fp16_groups,
scale=combined_scale)
@@ -216,7 +219,8 @@ def step(self, closure=None):
else:
fp32_param.grad = fp16_param.grad.to(fp32_param.dtype)

self.unscale_and_clip_grads(norm_groups)
self._global_grad_norm = get_global_norm(norm_list=norm_groups)
self.unscale_and_clip_grads(self._global_grad_norm)

self.optimizer.step()

@@ -231,12 +235,7 @@ def step(self, closure=None):

return self.overflow

def unscale_and_clip_grads(self, norm_groups, apply_scale=True):
total_norm = 0.0
for norm in norm_groups:
total_norm += norm**2.0
total_norm = math.sqrt(total_norm)

def unscale_and_clip_grads(self, total_norm, apply_scale=True):
# compute combined scale factor for this group
combined_scale = self.cur_scale
if self.clip_grad > 0.:
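Taken together with the engine change (`if hasattr(self.optimizer, '_global_grad_norm'): ...`), both fp16 optimizers now follow a simple contract. A hedged sketch of that contract for a hypothetical wrapper:

```python
# Hypothetical wrapper illustrating the contract: set _global_grad_norm during
# step() and the engine will surface it via get_global_grad_norm() afterwards.
class MyFP16OptimizerWrapper:
    def __init__(self, inner_optimizer):
        self.optimizer = inner_optimizer
        self._global_grad_norm = 0.   # same default the fp16 optimizers use

    def step(self, closure=None):
        norm_groups = [2.0, 3.0]      # per-group grad norms, illustrative only
        self._global_grad_norm = sum(n ** 2 for n in norm_groups) ** 0.5
        self.optimizer.step()
```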