5 changes: 3 additions & 2 deletions deepspeed/runtime/engine.py
@@ -722,6 +722,7 @@ def _configure_zero_optimizer(self, optimizer):
zero_stage = self.zero_optimization_stage()
log_dist('Creating fp16 ZeRO stage {} optimizer'.format(zero_stage), ranks=[0])
assert not self.allreduce_always_fp32(), "ZeRO does not support 'fp32_allreduce': true"
timers = self.timers if self.wall_clock_breakdown() else None

if zero_stage == ZERO_OPTIMIZATION_OPTIMIZER_STATES:
assert self.zero_reduce_scatter(), 'Stage 1 only supports reduce scatter mode'
@@ -740,7 +741,7 @@ def _configure_zero_optimizer(self, optimizer):
elif zero_stage == ZERO_OPTIMIZATION_GRADIENTS:
optimizer = FP16_DeepSpeedZeroOptimizer(
optimizer,
timers=self.timers,
timers=timers,
static_loss_scale=self.loss_scale(),
dynamic_loss_scale=self.dynamic_loss_scale(),
dynamic_loss_args=self.dynamic_loss_scale_args(),
@@ -762,7 +763,7 @@ def _configure_zero_optimizer(self, optimizer):
optimizer = FP16_DeepSpeedZeroOptimizer_Stage3(
self.module,
optimizer,
timers=self.timers,
timers=timers,
static_loss_scale=self.loss_scale(),
dynamic_loss_scale=self.dynamic_loss_scale(),
dynamic_loss_args=self.dynamic_loss_scale_args(),
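The engine.py hunks above forward the engine's timer collection to the ZeRO optimizers only when wall_clock_breakdown() is enabled; otherwise the optimizers receive timers=None. Below is a minimal, self-contained sketch of that gating pattern. The class and function names are illustrative stand-ins, not the DeepSpeed API.

import time


class WallClockTimers:
    """Toy timer collection mirroring the timers(name).start()/stop() usage."""

    class _Timer:
        def __init__(self):
            self.elapsed_sec = 0.0
            self._start = None

        def start(self):
            self._start = time.perf_counter()

        def stop(self):
            self.elapsed_sec += time.perf_counter() - self._start

    def __init__(self):
        self._timers = {}

    def __call__(self, name):
        # Return the named timer, creating it on first use.
        return self._timers.setdefault(name, self._Timer())

    def log(self, names):
        for name in names:
            print(f"{name}: {self._timers[name].elapsed_sec * 1000:.3f} ms")


def build_zero_optimizer_timers(wall_clock_breakdown):
    # Mirrors the engine.py change: hand the optimizer a real timer
    # collection only when profiling is requested, otherwise pass None.
    return WallClockTimers() if wall_clock_breakdown else None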
52 changes: 34 additions & 18 deletions deepspeed/runtime/zero/stage2.py
@@ -1326,6 +1326,26 @@ def reset_cpu_buffers(self):
self.norm_for_param_grads = {}
self.local_overflow = False

def log_timers(self, timer_names):
if self.timers is None:
return

self.timers.log(names=list(timer_names))

def start_timers(self, timer_names):
if self.timers is None:
return

for name in timer_names:
self.timers(name).start()

def stop_timers(self, timer_names):
if self.timers is None:
return

for name in timer_names:
self.timers(name).stop()

def step(self, closure=None):
"""
Not supporting closure.
Expand All @@ -1340,7 +1360,10 @@ def step(self, closure=None):
# First compute norm for all group so we know if there is overflow
self.check_overflow()

timers = self.timers
OPTIMIZER_ALLGATHER = 'optimizer_allgather'
OPTIMIZER_GRADIENTS = 'optimizer_gradients'
OPTIMIZER_STEP = 'optimizer_step'
timer_names = [OPTIMIZER_ALLGATHER, OPTIMIZER_GRADIENTS, OPTIMIZER_STEP]

prev_scale = self.loss_scale
self._update_scale(self.overflow)
@@ -1359,15 +1382,11 @@ def step(self, closure=None):
"reducing to {}".format(dist.get_rank(),
prev_scale,
self.loss_scale))
timers('optimizer_gradients').start()
timers('optimizer_gradients').stop()
timers('optimizer_step').start()
timers('optimizer_step').stop()
timers('optimizer_allgather').start()
timers('optimizer_allgather').stop()
self.start_timers(timer_names)
self.stop_timers(timer_names)
return

timers('optimizer_gradients').start()
self.start_timers([OPTIMIZER_GRADIENTS])
norm_groups = []
single_partition_grad_groups = []
skip = False
@@ -1409,10 +1428,9 @@ def step(self, closure=None):
single_partition_grad_groups.append(single_grad_partition)

self.unscale_and_clip_grads(single_partition_grad_groups, norm_groups)
timers('optimizer_gradients').stop()
self.stop_timers([OPTIMIZER_GRADIENTS])

#torch.set_num_threads(12)
timers('optimizer_step').start()
self.start_timers([OPTIMIZER_STEP])
if self.deepspeed_adam_offload:
from deepspeed.ops.adam import DeepSpeedCPUAdam
if type(self.optimizer) == DeepSpeedCPUAdam:
@@ -1436,12 +1454,12 @@ def step(self, closure=None):
for fp16_partitions, fp32_partition in zip(self.parallel_partitioned_fp16_groups, self.single_partition_of_fp32_groups):
fp16_partitions[partition_id].data.copy_(fp32_partition.data)

timers('optimizer_step').stop()
self.stop_timers([OPTIMIZER_STEP])

if self.cpu_offload:
self.reset_cpu_buffers()

timers('optimizer_allgather').start()
self.start_timers([OPTIMIZER_ALLGATHER])
#gather the updated weights from everyone
for group_id, partitioned_params in enumerate(self.parallel_partitioned_fp16_groups):

@@ -1474,7 +1492,7 @@ def step(self, closure=None):
dist.all_gather(shard_list,
shard_list[partition_id],
group=self.dp_process_group)
timers('optimizer_allgather').stop()
self.stop_timers([OPTIMIZER_ALLGATHER])

# TODO: we probably don't need this? just to be safe
for i in range(len(norm_groups)):
@@ -1483,11 +1501,9 @@ def step(self, closure=None):
for p, q in zip(self.fp16_groups[i], updated_params):
p.data = q.data

timers.log(
names=['optimizer_gradients',
'optimizer_step',
'optimizer_allgather'])
self.log_timers(timer_names)
see_memory_usage('After zero_optimizer step')

return

def unscale_and_clip_grads(self, grad_groups_flat, norm_groups):
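The stage2.py hunks replace direct timers('optimizer_...').start()/stop() calls with helpers that silently no-op when self.timers is None, so step() runs the same code path whether or not profiling is enabled. The sketch below isolates that guard pattern: the helper bodies match the diff, but the mixin class and demo optimizer are illustrative, not DeepSpeed code, and assume a timer collection shaped like the example after the engine.py diff above.

class NullSafeTimerMixin:
    """Sketch of the null-safe timer helpers added to the ZeRO optimizers.

    Assumes self.timers is either None or a callable collection where
    timers(name) returns an object with start()/stop(), plus a
    log(names=...) method.
    """

    def start_timers(self, timer_names):
        if self.timers is None:
            return
        for name in timer_names:
            self.timers(name).start()

    def stop_timers(self, timer_names):
        if self.timers is None:
            return
        for name in timer_names:
            self.timers(name).stop()

    def log_timers(self, timer_names):
        if self.timers is None:
            return
        self.timers.log(names=list(timer_names))


class DemoOptimizer(NullSafeTimerMixin):
    def __init__(self, timers=None):
        self.timers = timers

    def step(self):
        names = ['optimizer_gradients', 'optimizer_step', 'optimizer_allgather']
        self.start_timers(names)
        # ... gradient unscaling, optimizer.step(), allgather would go here ...
        self.stop_timers(names)
        self.log_timers(names)


DemoOptimizer(timers=None).step()  # all helpers return immediately; no AttributeError

With timers=None, which is what the engine now passes when the wall-clock breakdown is off, every helper returns immediately, so the optimizer no longer needs a real timer object just to avoid crashes.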
42 changes: 25 additions & 17 deletions deepspeed/runtime/zero/stage3.py
@@ -580,7 +580,7 @@ def __init__(self,
gradient_accumulation_steps=1,
elastic_checkpoint=False):

see_memory_usage("Stage 3 intialize begining", force=True)
see_memory_usage("Stage 3 intialize beginning", force=True)

if dist.get_rank() == 0:
logger.info(f"Reduce bucket size {reduce_bucket_size}")
@@ -628,15 +628,15 @@ def __init__(self,
self.device = torch.cuda.current_device() if not self.cpu_offload else 'cpu'
############################################################################

see_memory_usage("Before Partitioned Parameter Coordinator", force=True)
see_memory_usage("Before Partitioned Parameter Coordinator", force=False)

fetch_stream = torch.cuda.Stream() if self.overlap_comm else None
self.param_coordinator = PartitionedParameterCoordinator(
comm_stream=fetch_stream,
max_reuse_distance_in_numel=int(max_reuse_distance),
max_available_parameters_in_numel=int(max_live_parameters))

see_memory_usage("After Partitioned Parameter Coordinator", force=True)
see_memory_usage("After Partitioned Parameter Coordinator", force=False)

#self.param_coordinator = PartitionedParameterCoordinator(comm_stream=torch.cuda.Stream())
#-------------Stage 3 Setup-------------------#
@@ -711,20 +711,20 @@ def __init__(self,

self.sub_group_to_group_id = {}

see_memory_usage("Before creating fp16 partitions", force=True)
see_memory_usage("Before creating fp16 partitions", force=False)
#self._create_fp16_partitions()
self._create_fp16_partitions_with_defragmentation()
num_fp16_subgroups = len(self.fp16_partitioned_groups_flat)
see_memory_usage(f"After creating fp16 partitions: {num_fp16_subgroups}",
force=True)
force=False)

see_memory_usage("Before creating fp32 partitions", force=True)
see_memory_usage("Before creating fp32 partitions", force=False)
self._create_fp32_partitions()
see_memory_usage("After creating fp32 partitions", force=True)
see_memory_usage("After creating fp32 partitions", force=False)

see_memory_usage("Before initializing optimizer states", force=True)
see_memory_usage("Before initializing optimizer states", force=False)
self.initialize_optimizer_states()
see_memory_usage("After initializing optimizer states", force=True)
see_memory_usage("After initializing optimizer states", force=False)

if dist.get_rank() == 0:
logger.info(f"optimizer state initialized")
@@ -767,11 +767,11 @@ def __init__(self,
#Largest partitioned param
largest_partitioned_param_numel = self._get_largest_partitioned_numel()

see_memory_usage(f"Before Set Grad positions", force=True)
see_memory_usage(f"Before Set Grad positions", force=False)

self.grad_position = {}
self.set_grad_positions()
see_memory_usage(f"Before CPU Offload initialization", force=True)
see_memory_usage(f"Before CPU Offload initialization", force=False)

self.grads_in_partition = None

@@ -785,7 +785,7 @@ def __init__(self,
self.temp_grad_gpu_buffer = torch.zeros(
largest_partitioned_param_numel,
device=torch.cuda.current_device()).half()
see_memory_usage(f"After CPU Offload initialization", force=True)
see_memory_usage(f"After CPU Offload initialization", force=False)

# stores if a partition has been reduced in this step
self.is_partition_reduced = {}
@@ -1614,7 +1614,7 @@ def partition_previous_reduced_grads(self):

see_memory_usage(
f"group {i} before creating {total_size} reduced gradients into partition",
force=True)
force=False)
if self.cpu_offload_use_pin_memory:
self.grads_in_partition.append(
torch.zeros(int(total_size),
@@ -1627,7 +1627,7 @@ def partition_previous_reduced_grads(self):
device=self.device))
see_memory_usage(
f"group {i} after creating {total_size} reduced gradients into partition",
force=True)
force=False)

for param in self.previous_reduced_grads:

@@ -2044,13 +2044,22 @@ def reset_cpu_buffers(self):
self.local_overflow = False

def log_timers(self, timer_names):
if self.timers is None:
return

self.timers.log(names=list(timer_names))

def start_timers(self, timer_names):
if self.timers is None:
return

for name in timer_names:
self.timers(name).start()

def stop_timers(self, timer_names):
if self.timers is None:
return

for name in timer_names:
self.timers(name).stop()

@@ -2210,7 +2219,7 @@ def old_step(self, closure=None):

see_memory_usage('After zero_optimizer step', force=False)
print_rank_0(f"------------------Finishing Step-----------------------",
force=True)
force=False)
return

def _pre_step(self):
@@ -2327,7 +2336,7 @@ def _post_step(self, timer_names=set()):

self.log_timers(timer_names)

see_memory_usage('After zero_optimizer step', force=True)
see_memory_usage('After zero_optimizer step', force=False)
print_rank_0(f"------------------Finishing Step-----------------------")

def step(self, closure=None):
@@ -2342,7 +2351,6 @@ def step(self, closure=None):

norm_groups = self._get_norm_groups()

timers = self.timers
timer_names = set()

timer_names.add('optimizer_step')
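Most of the remaining stage3.py hunks flip see_memory_usage(..., force=True) to force=False, demoting those memory reports from always-on to opt-in. The sketch below is only an assumption about how such a force flag typically gates the report; the MEMORY_DEBUG switch and the function name are hypothetical, not DeepSpeed's implementation.

import torch

MEMORY_DEBUG = False  # hypothetical global switch for opt-in reporting


def see_memory_usage_sketch(message, force=False):
    # Print a report only when explicitly forced or when the global switch
    # is on; force=False calls become cheap no-ops.
    if not (force or MEMORY_DEBUG):
        return
    if torch.cuda.is_available():
        allocated_gb = torch.cuda.memory_allocated() / 2**30
        reserved_gb = torch.cuda.memory_reserved() / 2**30
        print(f"{message} | allocated {allocated_gb:.2f} GB "
              f"| reserved {reserved_gb:.2f} GB")
    else:
        print(f"{message} | CUDA not available")


see_memory_usage_sketch("After zero_optimizer step")  # silent unless MEMORY_DEBUG or force=True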