Skip to content

Commit adec991

Browse files
tohtanaloadamstjruwase
authored
Add API to get devices of offload states (#6586)
This PR adds an API `deepspeed.runtime.zero.offload_states get_state_devices`, which gets devices of offload states as suggested in this [comment](#6011 (comment)). We could lift this up to `deepspeed.utils` but would need to resolve a circular import: User code -> `deepspeed.utils` -> `deepspeed.utils.offload_states` -> `deepspeed.runtime.zero` -> `deepspeed.runtime.zero.partition_parameters` -> `deepspeed.utils` This will require a significant refactoring as long as we have `OffloadStateTypeEnum` in `deepspeed.runtime.zero`. --------- Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
1 parent d7ca3d8 commit adec991

File tree

5 files changed

+110
-16
lines changed

5 files changed

+110
-16
lines changed

deepspeed/runtime/utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,28 +9,28 @@
99
"""
1010

1111
from collections.abc import Iterable
12-
from deepspeed.moe.utils import is_moe_param
1312
import os
1413
import psutil
1514
import gc
1615
from math import sqrt
1716

17+
from numpy import prod
18+
1819
import torch
19-
from deepspeed import comm as dist
20+
from torch.nn import functional as F
2021
try:
2122
from torch._six import inf
2223
except ModuleNotFoundError:
2324
from torch import inf
2425

26+
from deepspeed import comm as dist
27+
from deepspeed.moe.utils import is_moe_param
2528
from deepspeed.utils import groups, logger
2629
from deepspeed.utils.bwc import (bwc_tensor_model_parallel_rank, bwc_pipeline_parallel_world_size,
2730
bwc_pipeline_parallel_group)
2831
from deepspeed.runtime.constants import PIPE_REPLICATED
29-
from numpy import prod
3032
from deepspeed.accelerator import get_accelerator
31-
3233
from deepspeed.module_inject.policy import transpose
33-
from torch.nn import functional as F
3434

3535
torch_memory_reserved = get_accelerator().memory_reserved
3636
torch_max_memory_reserved = get_accelerator().max_memory_reserved
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
# DeepSpeed Team
5+
6+
from typing import Set
7+
import torch
8+
9+
from deepspeed.accelerator import get_accelerator
10+
from deepspeed.runtime.zero.offload_config import OffloadStateTypeEnum
11+
12+
from deepspeed.utils.tensor_fragment import safe_get_local_fp32_param, safe_get_local_optimizer_state
13+
14+
15+
def _make_offload_state_key(key):
16+
return f"{key}_offload_buffer"
17+
18+
19+
def offload_adam_states(optimizer, device, pin_memory: bool = False, non_blocking: bool = False):
    """Move optimizer states to device. Note that this assumes the state structure of DeepSpeed Adam."""

    def _offload_key(state, key):
        # Lazily allocate (and optionally pin) a persistent buffer on the target
        # device, then copy the live state into it and re-point the state tensor.
        buf_key = _make_offload_state_key(key)
        if buf_key not in state:
            buffer = torch.empty_like(state[key], device=device)
            if pin_memory:
                buffer = get_accelerator().pin_memory(buffer)
            state[buf_key] = buffer
        state[buf_key].copy_(state[key], non_blocking=non_blocking)
        state[key].data = state[buf_key]

    for state in optimizer.state.values():
        for key in ("exp_avg", "exp_avg_sq"):
            if key in state:
                _offload_key(state, key)
36+
37+
38+
def reload_adam_states(optimizer, device, non_blocking: bool = False):
    """Move offloaded optimizer states back to device. Note that this assumes the state structure of DeepSpeed Adam."""

    def _reload_key(state, key):
        buffer = state[_make_offload_state_key(key)]
        state[key].data = buffer.to(device, non_blocking=non_blocking)

    for state in optimizer.state.values():
        for key in ("exp_avg", "exp_avg_sq"):
            if key in state:
                _reload_key(state, key)
49+
50+
51+
def get_state_devices(model, state: OffloadStateTypeEnum) -> Set[torch.device]:
    """Retrieve the devices of the specified state of the model.

    Args:
        model (DeepSpeedEngine): The model whose device allocations are to be checked.
        state (OffloadStateTypeEnum): The specific state for which the devices should be retrieved.

    Returns:
        Set[torch.device]: A set of devices of the specified state. Empty when the
        contiguous gradient buffer has been released by offload_states().
    """
    if state == OffloadStateTypeEnum.hp_params:
        return {safe_get_local_fp32_param(p).device for p in model.parameters()}
    elif state == OffloadStateTypeEnum.lp_params:
        return {p.ds_tensor.device for p in model.parameters()}
    elif state == OffloadStateTypeEnum.lp_grads:
        return {model.optimizer.grad_partitions_flat_buffer.device}
    elif state == OffloadStateTypeEnum.optim_states:
        return {safe_get_local_optimizer_state(p, "exp_avg").device for p in model.parameters()} | \
            {safe_get_local_optimizer_state(p, "exp_avg_sq").device for p in model.parameters()}
    elif state == OffloadStateTypeEnum.contiguous_grad_buffer:
        # Name-mangled private attribute of DeepSpeedZeroOptimizer_Stage3; None
        # once the buffer has been freed by offload_states().
        grad_buffer = model.optimizer._DeepSpeedZeroOptimizer_Stage3__ipg_bucket_flat_buffer
        # Fix: use `is None` (identity test) and return an empty *set*, not `{}`
        # (which is an empty dict and violates the declared return type).
        if grad_buffer is None:
            return set()
        return {grad_buffer.device}

deepspeed/runtime/zero/stage3.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,13 @@
1818
from deepspeed.utils import logger
1919
from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
2020
from deepspeed.runtime.comm.coalesced_collectives import reduce_scatter_coalesced, all_to_all_quant_reduce
21-
from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item, offload_adam_states, reload_adam_states
21+
from deepspeed.runtime.utils import inf, is_model_parallel_parameter, get_only_unique_item
2222
from deepspeed.runtime.zero.partition_parameters import *
2323
from deepspeed.runtime.zero.config import ZeroStageEnum
2424
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum
2525
from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload
2626
from deepspeed.runtime.zero.utils import apply_to_tensors_only, get_mapping_to_flat_buffer
27+
from deepspeed.runtime.zero.offload_states import offload_adam_states, reload_adam_states
2728
from deepspeed.ops.adam import DeepSpeedCPUAdam
2829
from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus
2930
from deepspeed.runtime.swap_tensor.optimizer_utils import OptimizerSwapper

docs/code-docs/source/zero3.rst

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -509,3 +509,19 @@ Below is an example code snippet demonstrating how to offload FP32 parameters an
509509
...
510510
# Load states back to device memory
511511
ds_engine.reload_states()
512+
513+
``deepspeed.runtime.zero.offload_states.get_state_devices`` returns devices of the specified state.
514+
515+
.. code-block:: python
516+
517+
def get_state_devices(model, state: OffloadStateTypeEnum) -> Set[torch.device]:
518+
"""Retrieve the devices of the specified state of the model.
519+
520+
Args:
521+
model (DeepSpeedEngine): The model whose device allocations are to be checked.
522+
state (OffloadStateTypeEnum): The specific state for which the devices should be retrieved.
523+
524+
Returns:
525+
Set[torch.device]: A set of devices of the specified state.
526+
527+
"""

tests/unit/runtime/zero/test_offload_states.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,22 @@
1515
import deepspeed
1616
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum
1717
from deepspeed.utils import safe_get_local_fp32_param, safe_get_local_optimizer_state
18+
from deepspeed.runtime.zero.offload_states import get_state_devices
1819

1920

2021
def validate_device(model, device: torch.device, include) -> None:
    """Assert that every (included) offload state of *model* lives on *device*."""

    def _lives_only_on_device(state) -> bool:
        found = get_state_devices(model, state)
        return len(found) == 1 and device in found

    for state in OffloadStateTypeEnum:
        if include is not None and state not in include:
            continue
        if state == OffloadStateTypeEnum.contiguous_grad_buffer and device == torch.device("cpu"):
            # The contiguous gradient buffer is freed (not moved) when offloading to CPU.
            assert len(get_state_devices(model, state)) == 0, \
                f"State {state} must be removed after offload_states()"
        else:
            assert _lives_only_on_device(state), f"State {state} is not on device {device}"
3134

3235

3336
def run_model(model, config_dict, hidden_dim, dtype, include, pin_memory, non_blocking):

0 commit comments

Comments
 (0)