@@ -396,34 +396,39 @@ These routines can be used in a training loop as shown in the following snippet.
 Modifying Partitioned States
 ----------------------------
 
-Sometimes, a user may want to modify parameters or optimizer states outside of the regular training loop. This is currently difficult in ZeRO training because of partitioning. To overcome that, DeepSpeed provides the following routines for modifying the fp32 master parameters and the fp32 optimizer states.
+Sometimes, a user may want to modify parameters, gradients, or optimizer states outside of the regular training loop. This is currently difficult in ZeRO training because of partitioning. To overcome that, DeepSpeed provides the following routines for modifying the fp32 master parameters and the fp32 optimizer states.
 
-These routines can be used at any point after initialization of the DeepSpeed engine (i.e., ``deepspeed.initialize()``) as shown in the following snippet.
+The routines for modifying parameters and optimizer states can be used at any point after initialization of the DeepSpeed engine (i.e., ``deepspeed.initialize()``) as shown in the following snippet.
 
 .. code-block:: python
 
     [...]
+    from deepspeed.runtime.zero.utils import is_zero_param
     from deepspeed.utils import safe_set_full_fp32_param, safe_set_full_optimizer_state
     from deepspeed.utils import safe_set_local_fp32_param, safe_set_local_optimizer_state
 
     # Here is an example to zero all the fp32 parameters and optimizer states.
     for n, lp in model.named_parameters():
-        # 1. For zero stage 1 or 2, set the full fp32 and their full optim states
-        zero_tensor = torch.zeros_like(lp)
+        # 1. For zero stage 1, 2, or 3 set the full fp32 and their full optim states
+        zero_tensor = torch.zeros(lp.ds_shape) if is_zero_param(lp) else torch.zeros(lp.shape)
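
The hunk is cut off at this point. As a rough, non-authoritative sketch of how the zeroing example could continue, the loop would write ``zero_tensor`` back through the imported helpers. The sketch assumes the ``safe_set_full_*`` routines take ``(param, value)`` / ``(param, value, state_key)`` and the ``safe_set_local_*`` routines take the local partition instead, that the optimizer keeps Adam-style ``exp_avg`` / ``exp_avg_sq`` states, and that ``lp.ds_tensor`` exposes the local ZeRO-3 partition; none of that is shown in the diff above.

.. code-block:: python

    import torch
    from deepspeed.runtime.zero.utils import is_zero_param
    from deepspeed.utils import (safe_set_full_fp32_param, safe_set_full_optimizer_state,
                                 safe_set_local_fp32_param, safe_set_local_optimizer_state)

    # Sketch only: zero every fp32 master parameter and its optimizer states.
    # Assumes `model` is the engine returned by deepspeed.initialize() and an
    # Adam-style optimizer with "exp_avg" / "exp_avg_sq" states.
    for n, lp in model.named_parameters():
        zero_tensor = torch.zeros(lp.ds_shape) if is_zero_param(lp) else torch.zeros(lp.shape)

        # Full (unpartitioned) view: every rank passes the complete tensor.
        safe_set_full_fp32_param(lp, zero_tensor)
        safe_set_full_optimizer_state(lp, zero_tensor, "exp_avg")
        safe_set_full_optimizer_state(lp, zero_tensor, "exp_avg_sq")

        # ZeRO stage 3 alternative: each rank writes only its local partition
        # (assuming lp.ds_tensor holds that partition).
        if is_zero_param(lp):
            zero_local = torch.zeros(lp.ds_tensor.shape)
            safe_set_local_fp32_param(lp, zero_local)
            safe_set_local_optimizer_state(lp, zero_local, "exp_avg")
            safe_set_local_optimizer_state(lp, zero_local, "exp_avg_sq")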