
Commit c589284
Update docs
1 parent 6110f70 commit c589284

File tree

1 file changed: +37 -7 lines changed

docs/code-docs/source/zero3.rst

Lines changed: 37 additions & 7 deletions
@@ -369,13 +369,13 @@ These routines can be used in a training loop as shown in the following snippet.

    from deepspeed.utils import safe_get_full_fp32_param, safe_get_full_grad, safe_get_full_optimizer_state
    for n, lp in model.named_parameters():
        # 1. Access the full states
-       # 1) gradient lookup
+       # 1.1) gradient lookup
        # For zero1 and zero2, gradient lookup must be called after `backward` and before `step`
        # For zero3, gradient lookup must be called after `backward`
        hp_grad = safe_get_full_grad(lp)


-       # 2) fp32 and optim states can probably be called anywhere in the training loop, but will be updated after `step`
+       # 1.2) fp32 and optim states can probably be called anywhere in the training loop, but will be updated after `step`
        hp = safe_get_full_fp32_param(lp)
        exp_avg = safe_get_full_optimizer_state(lp, "exp_avg")
        exp_avg_sq = safe_get_full_optimizer_state(lp, "exp_avg_sq")
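
For orientation, here is a minimal sketch of how the read routines in this hunk fit into a full training step; ``model_engine``, ``dataloader``, and ``criterion`` are illustrative names assumed here, not part of the documented snippet.

.. code-block:: python

    # Sketch only: assumes model_engine came from deepspeed.initialize()
    # and that dataloader yields (inputs, labels) pairs.
    from deepspeed.utils import (safe_get_full_fp32_param,
                                 safe_get_full_grad,
                                 safe_get_full_optimizer_state)

    for inputs, labels in dataloader:
        loss = criterion(model_engine(inputs), labels)
        model_engine.backward(loss)

        # Inspect states after `backward`; for ZeRO-1/2 this must also
        # happen before `step`, as the comments above note.
        for n, lp in model_engine.module.named_parameters():
            hp_grad = safe_get_full_grad(lp)                       # full fp32 gradient
            hp = safe_get_full_fp32_param(lp)                      # full fp32 master weight
            exp_avg = safe_get_full_optimizer_state(lp, "exp_avg")

        model_engine.step()
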
@@ -396,34 +396,39 @@ These routines can be used in a training loop as shown in the following snippet.

Modifying Partitioned States
----------------------------

-Sometimes, a user may want to modify parameters or optimizer states outside of the regular training loop. This is currently difficult in ZeRO training because of partitioning. To overcome that, DeepSpeed provides the following routines for modifying the fp32 master parameters and the fp32 optimizer states.
+Sometimes, a user may want to modify parameters, gradients, or optimizer states outside of the regular training loop. This is currently difficult in ZeRO training because of partitioning. To overcome that, DeepSpeed provides the following routines for modifying the fp32 master parameters, the fp32 gradients, and the fp32 optimizer states.

.. autofunction:: deepspeed.utils.safe_set_full_fp32_param

.. autofunction:: deepspeed.utils.safe_set_full_optimizer_state

+.. autofunction:: deepspeed.utils.safe_set_full_grad
+
.. autofunction:: deepspeed.utils.safe_set_local_fp32_param

+.. autofunction:: deepspeed.utils.safe_set_local_grad
+
.. autofunction:: deepspeed.utils.safe_set_local_optimizer_state

-These routines can be used at any point after initialization of the DeepSpeed engine (i.e., ``deepspeed.initialize()``) as shown in the following snippet.
+The routines for modifying parameters and optimizer states can be used at any point after initialization of the DeepSpeed engine (i.e., ``deepspeed.initialize()``) as shown in the following snippet.

.. code-block:: python

    [...]
+   from deepspeed.runtime.zero.utils import is_zero_param
    from deepspeed.utils import safe_set_full_fp32_param, safe_set_full_optimizer_state
    from deepspeed.utils import safe_set_local_fp32_param, safe_set_local_optimizer_state
    # Here is an example to zero all the fp32 parameters and optimizer states.
    for n, lp in model.named_parameters():
-       # 1. For zero stage 1 or 2, set the full fp32 params and their full optim states
-       zero_tensor = torch.zeros_like(lp)
+       # 1. For zero stage 1, 2, or 3, set the full fp32 params and their full optim states
+       zero_tensor = torch.zeros(lp.ds_shape) if is_zero_param(lp) else torch.zeros(lp.shape)

        safe_set_full_fp32_param(lp, zero_tensor)
        safe_set_full_optimizer_state(lp, zero_tensor, "exp_avg")
        safe_set_full_optimizer_state(lp, zero_tensor, "exp_avg_sq")

        # 2. For zero stage 3, each process sets its local fp32 parameters and their local optimizer states individually
-       zero_tensor_local = torch.zeros_like(lp.ds_tensor.shape)
+       zero_tensor_local = torch.zeros(lp.ds_tensor.shape)

        safe_set_local_fp32_param(lp, zero_tensor_local)
        safe_set_local_optimizer_state(lp, zero_tensor_local, "exp_avg")
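
As a usage note, here is one way the full-state setters in this hunk could be wrapped for use outside the training loop. This is a sketch under the assumption of an already-initialized ``model_engine``; the helper name ``zero_out_parameter`` is hypothetical.

.. code-block:: python

    import torch

    from deepspeed.runtime.zero.utils import is_zero_param
    from deepspeed.utils import safe_set_full_fp32_param, safe_set_full_optimizer_state

    def zero_out_parameter(model_engine, target_name):
        # Hypothetical helper: zero one parameter's fp32 master weights and
        # Adam moments. Assumption: invoked identically on every rank.
        for n, lp in model_engine.module.named_parameters():
            if n != target_name:
                continue
            # ZeRO-3 parameters carry their full shape in `ds_shape`;
            # for stages 1 and 2, the local `shape` is already the full shape.
            full_shape = lp.ds_shape if is_zero_param(lp) else lp.shape
            zero_tensor = torch.zeros(full_shape)
            safe_set_full_fp32_param(lp, zero_tensor)
            safe_set_full_optimizer_state(lp, zero_tensor, "exp_avg")
            safe_set_full_optimizer_state(lp, zero_tensor, "exp_avg_sq")

For example, ``zero_out_parameter(model_engine, "lm_head.weight")`` would reset only that layer, assuming a parameter with that name exists.
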
@@ -432,6 +437,31 @@ These routines can be used at any point after initialization of the DeepSpeed en

    [...]


+The routines for modifying gradients can be used after ``backward`` but before ``step`` as shown in the following snippet.
+
+.. code-block:: python
+
+    backward(loss)
+    [...]
+    from deepspeed.runtime.zero.utils import is_zero_param
+    from deepspeed.utils import safe_set_full_grad, safe_set_local_grad
+    # Here is an example of how to zero all the gradients.
+    for n, lp in model.named_parameters():
+        # 1. For zero stage 1, 2, or 3, set the full gradient.
+        zero_tensor = torch.zeros(lp.ds_shape) if is_zero_param(lp) else torch.zeros(lp.shape)
+
+        safe_set_full_grad(lp, zero_tensor)
+
+        # 2. For zero stage 3, each process sets its local gradient partition.
+        zero_tensor_local = torch.zeros(lp.ds_tensor.shape)
+
+        safe_set_local_grad(lp, zero_tensor_local)
+
+    [...]
+    optimizer.step()
+
+
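
To place these gradient setters in a full step, here is a minimal sketch; ``model_engine``, ``dataloader``, and ``criterion`` are assumed names, and the 0.5 scaling is just an arbitrary illustrative edit, not something the commit prescribes.

.. code-block:: python

    from deepspeed.utils import safe_get_full_grad, safe_set_full_grad

    for inputs, labels in dataloader:
        loss = criterion(model_engine(inputs), labels)
        model_engine.backward(loss)

        # Edit gradients here: after `backward`, before `step`.
        for n, lp in model_engine.module.named_parameters():
            hp_grad = safe_get_full_grad(lp)
            if hp_grad is not None:
                # Illustrative edit: scale every full gradient by 0.5.
                safe_set_full_grad(lp, hp_grad * 0.5)

        model_engine.step()
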
GPU Memory Management
---------------------
