Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import deepspeed.runtime.lr_schedules as lr_schedules
from deepspeed.utils import logger, log_dist, init_distributed
from deepspeed.utils.timer import ThroughputTimer, SynchronizedWallClockTimer
from deepspeed.utils.debug import debug_extract_module_and_param_names
from deepspeed.runtime.progressive_layer_drop import ProgressiveLayerDrop
from deepspeed.runtime.eigenvalue import Eigenvalue

Expand Down Expand Up @@ -122,6 +123,9 @@ def __init__(self,
self.gas_boundary_ctr = 0
self.dist_backend = "nccl"

# for debug purposes - can then debug print: debug_get_module_name(module)
debug_extract_module_and_param_names(model)

# Set config using config_params for backwards compat
if self.config is None and config_params is not None:
self.config = config_params
Expand Down
21 changes: 14 additions & 7 deletions deepspeed/runtime/zero/partition_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from ..utils import see_memory_usage
from deepspeed.utils import log_dist, init_distributed
from deepspeed.utils.debug import debug_param2name_id_shape, debug_module2name, debug_param2name, debug_param2name_id_shape_status, printflock, log_rank_file

from ..swap_tensor.partitioned_param_swapper import AsyncPartitionedParameterSwapper, PartitionedParamStatus
from ..config import DeepSpeedConfig
Expand All @@ -27,8 +28,14 @@


def print_rank_0(message, debug=False, force=False):
if torch.distributed.get_rank() == 0 and (debug or force):
rank = torch.distributed.get_rank()
if rank == 0 and (debug or force):
print(message)
# other variations
# - print for all ranks w/o interleaving
# printflock(f"[{rank}] {message}")
# - print to log file per rank
# log_rank_file(rank, message)


def is_zero_param(parameter):
Expand Down Expand Up @@ -481,12 +488,12 @@ def _post_init_method(self, module):
force=False)

global param_count
for name, param in module.named_parameters(recurse=False):
for param in module.parameters(recurse=False):
param_count += param.numel()
if not is_zero_param(param):
self._convert_to_deepspeed_param(param)
print_rank_0(
f"Partitioning param with ds id {param.ds_id} and shape {param.data.shape}"
f"Partitioning param {debug_param2name_id_shape(param)} module={debug_module2name(module)}"
)
param.partition()
see_memory_usage(
Expand Down Expand Up @@ -797,23 +804,23 @@ def _allgather_param(self, param, async_op=False, hierarchy=0):
assert tensor_size == aligned_param_size, f'param id {param.ds_id} aligned size {aligned_param_size} does not match tensor size {tensor_size}'

print_rank_0(
f"{'--'* hierarchy}---- Before allocating Allgather param with id {param.ds_id} and status {param.ds_status} Partition Size {partition_size} and data shape {param.ds_shape}"
f"{'--'* hierarchy}---- Before allocating allgather param {debug_param2name_id_shape_status(param)} partition size={partition_size}"
)

see_memory_usage(
f'Before allocate allgather param {param.ds_id} {param.ds_status} {aligned_param_size} {partition_size} {param.ds_shape}',
f'Before allocate allgather param {debug_param2name_id_shape_status(param)} partition_size={partition_size} ',
force=False)
flat_tensor = torch.zeros(aligned_param_size,
dtype=param.dtype,
device=param.device).view(-1)
see_memory_usage(
f'After allocate allgather param {param.ds_id} {param.ds_status} {aligned_param_size} {partition_size} {param.ds_shape}',
f'After allocate allgather param {debug_param2name_id_shape_status(param)} {aligned_param_size} {partition_size} ',
force=False)

torch.cuda.synchronize()

print_rank_0(
f"{'--'* hierarchy}----Allgather param with id {param.ds_id} and status {param.ds_status} Partition Size {partition_size} and data shape {param.ds_shape}"
f"{'--'* hierarchy}----allgather param with {debug_param2name_id_shape_status(param)} partition size={partition_size}"
)
# if not flat_tensor.numel() > 100000:
# replicated_tensor = flat_tensor.narrow(0,
Expand Down
35 changes: 22 additions & 13 deletions deepspeed/runtime/zero/stage3.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,18 @@
pg_correctness_test = False

FWD_MODULE_STACK = list()
from deepspeed.utils.debug import debug_module2name_id, debug_param2name_id_numel, debug_param2name_id_shape_device, debug_module2name_class, printflock, log_rank_file


def print_rank_0(message, debug=False, force=False):
if torch.distributed.get_rank() == 0 and (debug or force):
logger.info(message)
def print_rank_0(message, debug=False, force=True):
rank = torch.distributed.get_rank()
if rank == 0 and (debug or force):
print(message)
# other variations
# - print for all ranks w/o interleaving
# printflock(f"[{rank}] {message}")
# - print to log file per rank
# log_rank_file(rank, message)


def input(msg):
Expand Down Expand Up @@ -211,7 +218,7 @@ def get_params_to_prefetch(self, sub_module, numel=2000000):
# tracing failed. The sub_module passed at the step_id must match with the sub_module during tracing
if sub_module.id != self.sub_module_trace[self.step_id]:
print_rank_0(
f"Tracing failed. Prefetching is disabled at sub-module: {sub_module.id}"
f"Tracing failed. Prefetching is disabled at sub-module: {debug_module2name_id(sub_module)}"
)
return []

Expand Down Expand Up @@ -390,11 +397,13 @@ def reset_step(self):
def fetch_sub_module(self, sub_module):
partitioned_params = []
params_in_flight = False
#print_rank_0(f"{'--' * self.hierarchy}Fetching params in module {sub_module.__class__.__name__}")
# print_rank_0(f"{'--' * self.hierarchy}Fetching params in module {sub_module.__class__.__name__}")
params_to_fetch = [
param for _,
param in sub_module.named_parameters(recurse=False)
]
# print([n for n,p in sub_module.named_parameters(recurse=False)])

if hasattr(sub_module, 'ds_external_parameters'):
print_rank_0(
f"{'--' * self.hierarchy}--Fetching external parameters {sub_module.ds_external_parameters()}"
Expand All @@ -407,7 +416,7 @@ def fetch_sub_module(self, sub_module):
for param in params_to_fetch:
param.ds_active_sub_modules += 1
print_rank_0(
f"{'--' * self.hierarchy}--Fetching parameters {param.ds_id} with active sub modules {param.ds_active_sub_modules}"
f"{'--' * self.hierarchy}--Fetching parameters {param.ds_id} {param.ds_shape} with active sub modules {param.ds_active_sub_modules}"
)

if param.ds_status == ZeroParamStatus.AVAILABLE:
Expand Down Expand Up @@ -441,14 +450,14 @@ def fetch_sub_module(self, sub_module):
for _, param in sub_module.named_parameters(recurse=False):
param.ds_status = ZeroParamStatus.AVAILABLE
print_rank_0(
f"Param id {param.ds_id}, Shape {param.shape}, device {param.device} norm {param.norm()}",
f"Param {debug_param2name_id_shape_device(param)} norm={param.norm()}",
force=False)
#print_rank_0(f"After fetching (id, shape, device): {[(param.ds_id, param.shape, param.device) for param in sub_module.named_parameters(recurse=False)]}")

def release_sub_module(self, sub_module):
self.hierarchy -= 1
print_rank_0(
f"{'--' * self.hierarchy}Releasing params in module {sub_module.__class__.__name__}"
f"{'--' * self.hierarchy}Releasing params in module {debug_module2name_class(sub_module)}"
)
params_to_release = [
param for _,
Expand All @@ -468,31 +477,31 @@ def release_sub_module(self, sub_module):
if not param.ds_active_sub_modules and not self._keep_for_later(
sub_module) and not param.ds_persist:
print_rank_0(
f"{'--' * self.hierarchy}--Releasing parameters {param.ds_id} with numel {param.numel()} active sub modules {param.ds_active_sub_modules} and keep for later {self._keep_for_later(sub_module)}",
f"{'--' * self.hierarchy}--Releasing parameter {debug_param2name_id_numel(param)} active sub modules {param.ds_active_sub_modules} and keep for later {self._keep_for_later(sub_module)}",
force=False)

# Keeping track of number of elements that are consumed by available parameters
self._decrement_available_parameter_numel(param.ds_numel)
see_memory_usage(
f"Before releasing param {param.ds_id} with numel {param.numel()}",
f"Before releasing param {debug_param2name_id_numel(param)}",
force=False)
param.partition(hierarchy=self.hierarchy)
see_memory_usage(
f"After releasing param {param.ds_id} has numel {param.numel()} ",
f"After releasing param {debug_param2name_id_numel(param)}",
force=False)

param.ds_status = ZeroParamStatus.NOT_AVAILABLE
else:

print_rank_0(
f"{'--' * self.hierarchy}--Did not release parameters {param.ds_id} with numel {param.numel()} with active sub modules {param.ds_active_sub_modules}, keep for later {self._keep_for_later(sub_module)} and persistence {param.ds_persist}",
f"{'--' * self.hierarchy}--Did not release param {debug_param2name_id_numel(param)} with active sub modules {param.ds_active_sub_modules}, keep for later={self._keep_for_later(sub_module)} and persistence={param.ds_persist}",
force=False)

def release_and_reset_parameter(self, param):
param.ds_active_sub_modules = 0
if param.ds_status == ZeroParamStatus.AVAILABLE:
print_rank_0(
f"Releasing unpartitioned {param.ds_id} active sub-modules {param.ds_active_sub_modules} size {param.ds_numel} and persisitence {param.ds_persist}"
f"Releasing unpartitioned param {debug_param2name_id_numel(param)} active sub-modules {param.ds_active_sub_modules} and persisitence {param.ds_persist}"
)
self._decrement_available_parameter_numel(param.ds_numel)
param.partition()
Expand Down
122 changes: 122 additions & 0 deletions deepspeed/utils/debug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
""" debug utils """

import fcntl

# for debug purposes map module and param objects to their fully qualified names
# Populated by debug_extract_module_and_param_names(); keys are the live
# nn.Module / parameter objects, values are their dotted names within the model.
# Empty until that function is called, in which case lookups fall back to "unknown".
module_names = {}
param_names = {}


def debug_extract_module_and_param_names(model):
    """Record the fully qualified name of every module and parameter of *model*.

    Fills the module-level ``module_names`` / ``param_names`` lookup tables that
    the ``debug_*2name*`` helpers below consult. Call this as soon as the model
    object is acquired so later debug prints can show real names.
    """
    global module_names, param_names
    # XXX: can probably make a map of param2module and vice-versa
    module_names = dict((mod, qualname) for qualname, mod in model.named_modules())
    param_names = dict((par, qualname) for qualname, par in model.named_parameters())


def debug_module2name(module):
    """Return the fully qualified name of *module*, or ``"unknown"`` when the
    lookup table has not been populated yet (see
    ``debug_extract_module_and_param_names``)."""
    # Idiomatic single-lookup form of the membership-test-then-index pattern.
    return module_names.get(module, "unknown")


def debug_module2name_id(module):
    """Return ``name=<qualified name> id=<module.id>`` for *module*."""
    # NOTE(review): module.id appears to be a deepspeed-assigned attribute,
    # not a standard nn.Module field — confirm against the engine code.
    return "name={} id={}".format(debug_module2name(module), module.id)


def debug_module2name_class(module):
    """Return ``name=<qualified name> <ClassName>`` for *module*."""
    return "name={} {}".format(debug_module2name(module), module.__class__.__name__)


def debug_param2name(param):
    """Return the fully qualified name of *param*, or ``"unknown"`` when the
    lookup table has not been populated yet (see
    ``debug_extract_module_and_param_names``)."""
    # Idiomatic single-lookup form of the membership-test-then-index pattern.
    return param_names.get(param, "unknown")


def debug_param2name_id(param):
    """Return ``name=<qualified name> id=<ds_id>`` for a deepspeed param."""
    return "name={} id={}".format(debug_param2name(param), param.ds_id)


def debug_param2name_id_shape(param):
    """Return ``name=... id=... shape=...`` for a deepspeed param."""
    return "name={} id={} shape={}".format(debug_param2name(param),
                                           param.ds_id,
                                           param.data.shape)


def debug_param2name_id_shape_device(param):
    """Return ``name=... id=... shape=... device=...`` for a deepspeed param."""
    return "name={} id={} shape={} device={}".format(debug_param2name(param),
                                                     param.ds_id,
                                                     param.data.shape,
                                                     param.device)


def debug_param2name_id_numel(param):
    """Return ``name=... id=... numel=...`` for a deepspeed param."""
    return "name={} id={} numel={}".format(debug_param2name(param),
                                           param.ds_id,
                                           param.numel())


def debug_param2name_id_shape_status(param):
    """Return ``name=... id=... shape=... status=...`` for a deepspeed param."""
    return "name={} id={} shape={} status={}".format(debug_param2name(param),
                                                     param.ds_id,
                                                     param.data.shape,
                                                     param.ds_status)


def printflock(*msgs):
    """Print *msgs* so that output from concurrent gpu processes doesn't interleave.

    Useful when debugging issues where multiple gpus fail to stay in sync.

    Suggested workflow:

    1. Enable the force debug in say partitioning and zero3 files
    2. Override the usual versions with ::

        def print_rank_0(message, debug=False, force=True):
            rank = torch.distributed.get_rank()
            printflock(f"[{rank}] {message}")
    3. run the program and you get both logs non-interleaved

    Since a single merged stream is hard to read, the ``log_rank_file`` helper
    below — one file per rank, diffable afterwards — is often the better tool.
    """
    # This module's own source file doubles as the lock file: every process
    # imports it, so they all contend on the same advisory flock.
    lockfile = open(__file__, "r")
    with lockfile:
        fcntl.flock(lockfile, fcntl.LOCK_EX)
        try:
            print(*msgs)
        finally:
            fcntl.flock(lockfile, fcntl.LOCK_UN)


# Per-rank log file handles, keyed by rank. Opened lazily on first use and
# kept open for the life of the process (debug-only helper, so the handles
# are deliberately never closed).
fh = {}


def log_rank_file(rank, *msgs):
    """Append *msgs*, one per line, to ``log_rank_{rank}.txt``.

    This is useful for debugging hanging in sync processes. Here is a possible workflow:

    1. Enable the force debug in say partitioning and zero3 files
    2. Override the usual versions of print_rank_0 in those files with ::

        def print_rank_0(message, debug=False, force=True):
            rank = torch.distributed.get_rank()
            log_rank_file(rank, message)

    3. run the program
    4. fix up the expected differences, e.g. different cuda numbers ::

        perl -pi -e 's|cuda:1|cuda:0|' log_rank_*

    5. now diff and see where names and ids diverge - you will find where the gpus don't do the same
    work (e.g. when some layers get conditionally skipped on one gpu but not all)

        diff -u log_rank_0.txt log_rank_1.txt | less

    Bug fix vs the single-handle version: the handle is cached *per rank*, so a
    process that logs for more than one rank no longer funnels every message
    into the first rank's file.
    """
    if rank not in fh:
        fh[rank] = open(f"log_rank_{rank}.txt", "w")
    out = fh[rank]
    for m in msgs:
        out.write(f"{m}\n")
    out.flush()  # flush eagerly so hangs don't lose the last lines