103 changes: 103 additions & 0 deletions deepspeed/runtime/zero/stage2.py
@@ -2000,3 +2000,106 @@ def _handle_overflow(cpu_sum, x, i):
logger.info(
f"rank {rank} detected overflow {cpu_sum} in tensor {i}:{t_i} shape {x.shape}"
)


def estimate_zero2_model_states_mem_needs(total_params,
num_gpus_per_node=1,
num_nodes=1,
cpu_offload=True,
additional_buffer_factor=1.5):

total_gpus = num_nodes * num_gpus_per_node

if cpu_offload:
gpu_mem = 2 * total_params
cpu_mem = total_params * max(4 * total_gpus, 16) * additional_buffer_factor
else:
gpu_mem = 4 * total_params + int(16 * total_params / total_gpus)
cpu_mem = total_params * 4 * num_gpus_per_node * additional_buffer_factor

return int(cpu_mem), int(gpu_mem)
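
# Worked example with hypothetical numbers (illustration only): for total_params=1e9 on
# 1 node x 8 GPUs with cpu_offload=True and the default additional_buffer_factor=1.5,
# gpu_mem = 2 * 1e9 B ~= 1.86GB and cpu_mem = 1e9 * max(4 * 8, 16) * 1.5 = 4.8e10 B ~= 44.70GB.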


def model_to_params(model):
    # count shared parameters only once by keying on each tensor's data_ptr
total_params = sum(
dict((p.data_ptr(),
p.numel()) for p in model.parameters()).values())
return total_params


def estimate_zero2_model_states_mem_needs_all_live(model,
num_gpus_per_node=1,
num_nodes=1,
additional_buffer_factor=1.5):
"""
Print out estimates on memory usage requirements for ZeRO 2 params, optim states and gradients
for a given ``model`` and hardware setup.

If you have an actual model object, use this function and everything will be derived
automatically.

If it's a hypothetical model, use ``estimate_zero2_model_states_mem_needs_all_cold`` where you have to pass
the ``total_params`` explicitly.

Args:
- ``model``: ``nn.Module`` object
- ``num_gpus_per_node``: how many gpus per node (defaults to 1)
        - ``num_nodes``: how many nodes (defaults to 1)
        - ``additional_buffer_factor``: estimation factor (defaults to 1.5)

"""

total_params = model_to_params(model)

estimate_zero2_model_states_mem_needs_all_cold(
total_params=total_params,
num_gpus_per_node=num_gpus_per_node,
num_nodes=num_nodes,
additional_buffer_factor=additional_buffer_factor)


def estimate_zero2_model_states_mem_needs_all_cold(total_params,
num_gpus_per_node=1,
num_nodes=1,
additional_buffer_factor=1.5):
"""
Print out estimates on memory usage requirements for ZeRO 2 params, optim states and gradients
    for a given set of model parameters and hardware setup.

    If it's a hypothetical model, use this function where you have to pass
    the ``total_params`` explicitly.

If you have an actual model object, use ``estimate_zero2_model_states_mem_needs_all_live`` and everything
will be derived automatically.

Args:
- ``total_params``: total model params
- ``num_gpus_per_node``: how many gpus per node (defaults to 1)
        - ``num_nodes``: how many nodes (defaults to 1)
        - ``additional_buffer_factor``: estimation factor (defaults to 1.5)

"""
def format_options(cpu_offload):
enabled = []
enabled.append(f"cpu_offload={1 if cpu_offload else 0}")
return ", ".join(enabled)

nodes_str = "nodes" if num_nodes > 1 else "node"
gpus_str = "GPUs" if num_gpus_per_node > 1 else "GPU"
print(
"Estimated memory needed for params, optim states and gradients for a:\n"
f"HW: Setup with {num_nodes} {nodes_str}, {num_gpus_per_node} {gpus_str} per node.\n"
f"SW: Model with {int(total_params/1e6)}M total params.")
print(" per CPU | per GPU | Options")
for cpu_offload in [True, False]:
cpu_mem, gpu_mem = estimate_zero2_model_states_mem_needs(
total_params=total_params,
num_gpus_per_node=num_gpus_per_node,
num_nodes=num_nodes,
cpu_offload=cpu_offload,
additional_buffer_factor=additional_buffer_factor
)

options_str = format_options(cpu_offload=cpu_offload)
print(f" {cpu_mem/2**30:7.2f}GB | {gpu_mem/2**30:6.2f}GB | {options_str}")
152 changes: 152 additions & 0 deletions deepspeed/runtime/zero/stage3.py
@@ -3258,3 +3258,155 @@ def _handle_overflow(cpu_sum, x, i):
logger.info(
f"rank {rank} detected overflow {cpu_sum} in tensor {i}:{t_i} shape {x.shape}"
)


def estimate_zero3_model_states_mem_needs(total_params,
largest_layer_params,
num_gpus_per_node=1,
num_nodes=1,
cpu_offload=True,
cpu_offload_params=True,
zero_init=True,
additional_buffer_factor=1.5):

total_gpus = num_nodes * num_gpus_per_node
gpus_factor = 1 / num_nodes
largest_layer_memory = (4 * largest_layer_params)

if cpu_offload:
if cpu_offload_params:
gpu_mem = largest_layer_memory

if zero_init:
cpu_mem = total_params * 18 * gpus_factor * additional_buffer_factor
else:
cpu_mem = total_params * max(4 * num_gpus_per_node,
18 * gpus_factor) * additional_buffer_factor

else:
gpu_mem = largest_layer_memory + int(2 * total_params / total_gpus)

if zero_init:
cpu_mem = total_params * 16 * gpus_factor * additional_buffer_factor
else:
cpu_mem = total_params * max(4 * num_gpus_per_node,
16 * gpus_factor) * additional_buffer_factor
else:
gpu_mem = largest_layer_memory + int(18 * total_params / total_gpus)
if zero_init:
cpu_mem = largest_layer_params * 4 * num_gpus_per_node * additional_buffer_factor
else:
cpu_mem = total_params * 4 * num_gpus_per_node * additional_buffer_factor

return int(cpu_mem), int(gpu_mem), largest_layer_memory
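
# Worked example with hypothetical numbers (illustration only): for total_params=1e9 and
# largest_layer_params=1e8 on 1 node x 8 GPUs with cpu_offload=True, cpu_offload_params=True,
# zero_init=True and the default additional_buffer_factor=1.5,
# gpu_mem = 4 * 1e8 B ~= 0.37GB and cpu_mem = 1e9 * 18 * (1 / 1) * 1.5 = 2.7e10 B ~= 25.15GB.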


def model_to_params(model):
    # count shared parameters only once by keying on each tensor's data_ptr
total_params = sum(
dict((p.data_ptr(),
p.numel()) for p in model.parameters()).values())

largest_layer_params = 0
for m in model.modules():
# assuming no shared params within a single layer
layer_params = sum(p.numel() for p in m.parameters(recurse=False))
largest_layer_params = max(largest_layer_params, layer_params)

return total_params, largest_layer_params


def estimate_zero3_model_states_mem_needs_all_live(model,
num_gpus_per_node=1,
num_nodes=1,
additional_buffer_factor=1.5):
"""
Print out estimates on memory usage requirements for ZeRO 3 params, optim states and gradients
for a given ``model`` and hardware setup.

If you have an actual model object, use this function and everything will be derived
automatically.

If it's a hypothetical model, use ``estimate_zero3_model_states_mem_needs_all_cold`` where you have to pass
the ``total_params`` and ``largest_layer_params`` explicitly.

Args:
- ``model``: ``nn.Module`` object
- ``num_gpus_per_node``: how many gpus per node (defaults to 1)
        - ``num_nodes``: how many nodes (defaults to 1)
        - ``additional_buffer_factor``: estimation factor (defaults to 1.5)

"""

total_params, largest_layer_params = model_to_params(model)

estimate_zero3_model_states_mem_needs_all_cold(
total_params=total_params,
largest_layer_params=largest_layer_params,
num_gpus_per_node=num_gpus_per_node,
num_nodes=num_nodes,
additional_buffer_factor=additional_buffer_factor)


def estimate_zero3_model_states_mem_needs_all_cold(total_params,
largest_layer_params,
num_gpus_per_node=1,
num_nodes=1,
additional_buffer_factor=1.5):
"""
Print out estimates on memory usage requirements for ZeRO 3 params, optim states and gradients
    for a given set of model parameters and hardware setup.

If it's a hypothetical model, use this function where you have to pass
the ``total_params`` and ``largest_layer_params`` explicitly.

If you have an actual model object, use ``estimate_zero3_model_states_mem_needs_all_live`` and everything
will be derived automatically.

Args:
- ``total_params``: total model params
- ``largest_layer_params``: largest layer's params
- ``num_gpus_per_node``: how many gpus per node (defaults to 1)
        - ``num_nodes``: how many nodes (defaults to 1)
        - ``additional_buffer_factor``: estimation factor (defaults to 1.5)

"""
def format_options(cpu_offload, cpu_offload_params, zero_init):
enabled = []
enabled.append(f"cpu_offload={1 if cpu_offload else 0}")
enabled.append(f"cpu_offload_params={1 if cpu_offload_params else 0}")
enabled.append(f"zero_init={1 if zero_init else 0}")
return ", ".join(enabled)

nodes_str = "nodes" if num_nodes > 1 else "node"
gpus_str = "GPUs" if num_gpus_per_node > 1 else "GPU"
print(
"Estimated memory needed for params, optim states and gradients for a:\n"
f"HW: Setup with {num_nodes} {nodes_str}, {num_gpus_per_node} {gpus_str} per node.\n"
f"SW: Model with {int(total_params/1e6)}M total params, {int(largest_layer_params/1e6)}M largest layer params."
)
print(" per CPU | per GPU | Options")
for cpu_offload in [True, False]:
for cpu_offload_params in [True, False]:
if not cpu_offload and cpu_offload_params:
continue
for zero_init in [True, False]:
cpu_mem, gpu_mem, largest_layer_memory = estimate_zero3_model_states_mem_needs(
total_params=total_params,
largest_layer_params=largest_layer_params,
num_gpus_per_node=num_gpus_per_node,
num_nodes=num_nodes,
cpu_offload=cpu_offload,
cpu_offload_params=cpu_offload_params,
zero_init=zero_init,
additional_buffer_factor=additional_buffer_factor
)

options_str = format_options(cpu_offload=cpu_offload,
cpu_offload_params=cpu_offload_params,
zero_init=zero_init)
print(
f" {cpu_mem/2**30:7.2f}GB | {gpu_mem/2**30:6.2f}GB | {options_str}")
8 changes: 8 additions & 0 deletions docs/code-docs/source/index.rst
@@ -79,6 +79,14 @@ Flops Profiler

flops-profiler


Memory Usage
------------------
.. toctree::
:maxdepth: 2

memory

Indices and tables
------------------
