diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 6d8040fe8ec2..5081adb0b021 100755
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -1715,6 +1715,8 @@ def _zero3_consolidated_fp16_state_dict(self):

         Get a full non-partitioned state_dict with fp16 weights on cpu.

+        Important: this function must be called on all ranks and not just rank 0.
+
         This is similar to nn.Module.state_dict (modelled after _save_to_state_dict), but:

         1. consolidates the weights from different partitions on gpu0
diff --git a/docs/_pages/config-json.md b/docs/_pages/config-json.md
index 9a9554cbd75f..81c7d5bd62b9 100755
--- a/docs/_pages/config-json.md
+++ b/docs/_pages/config-json.md
@@ -256,7 +256,8 @@ Enabling and configuring ZeRO memory optimizations
     "stage3_prefetch_bucket_size" : 5e8,
     "stage3_param_persistence_threshold" : 1e6,
     "sub_group_size" : 1e12,
-    "elastic_checkpoint" : [true|false]
+    "elastic_checkpoint" : [true|false],
+    "stage3_gather_fp16_weights_on_model_save": [true|false]
   }
 ```
@@ -351,6 +352,11 @@ Enabling and configuring ZeRO memory optimizations

 | Do not partition parameters smaller than this threshold. Smaller values use less memory, but can greatly increase communication (especially latency-bound messages). | `1e6` |

+***stage3_gather_fp16_weights_on_model_save***: [boolean]
+| Description                                                                                                                                                            | Default |
+| ---------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------- |
+| Consolidate the weights before saving the model by `save_fp16_model()`. Since the weights are partitioned across GPUs, they aren't part of `state_dict`, so this function automatically gathers the weights when this option is enabled and then saves the fp16 model weights. | `False` |
+
 ### Logging

 ***steps\_per\_print***: [integer]
diff --git a/docs/_tutorials/advanced-install.md b/docs/_tutorials/advanced-install.md
index 18f60e864039..4ff02bf7ec84 100644
--- a/docs/_tutorials/advanced-install.md
+++ b/docs/_tutorials/advanced-install.md
@@ -73,6 +73,18 @@ DS_BUILD_OPS=1 pip install deepspeed --global-option="build_ext" --global-option="-j8"

 This should complete the full build 2-3 times faster. You can adjust `-j` to specify how many cpu-cores are to be used during the build. In the example it is set to 8 cores.

+You can also build a binary wheel and install it on multiple machines that have the same type of GPUs and the same software environment (CUDA toolkit, PyTorch, Python, etc.):
+
+```bash
+DS_BUILD_OPS=1 python setup.py build_ext -j8 bdist_wheel
+```
+
+This will create a PyPI binary wheel under `dist`, e.g., `dist/deepspeed-0.3.13+8cd046f-cp38-cp38-linux_x86_64.whl`, which you can then install directly on multiple machines. In our example:
+
+```bash
+pip install dist/deepspeed-0.3.13+8cd046f-cp38-cp38-linux_x86_64.whl
+```
+
 ## Install DeepSpeed from source
diff --git a/docs/_tutorials/zero.md b/docs/_tutorials/zero.md
index 1e9f97b98a11..8f506d25babe 100644
--- a/docs/_tutorials/zero.md
+++ b/docs/_tutorials/zero.md
@@ -260,5 +260,43 @@ for more details.
         self.init_method(self.position_embeddings.weight)
 ```

+## Extracting weights
+
+If you need to take the pretrained weights out of DeepSpeed, here is how you can get the fp16 weights:
+
+- under ZeRO-2, `state_dict` contains the fp16 model weights, and these can be saved normally with `torch.save`.
+- under ZeRO-3, `state_dict` contains just placeholders, since the model weights are partitioned across multiple GPUs. If you want to get these weights, enable:
+
+```
+    "zero_optimization": {
+        "stage3_gather_fp16_weights_on_model_save": true
+    },
+```
+and then save the model using:
+
+```
+    if self.deepspeed:
+        self.deepspeed.save_fp16_model(output_dir, output_file)
+```
+
+Because it requires consolidating the weights on one GPU, this can be slow and memory-demanding, so use this feature only when needed.
+
+Note that if `stage3_gather_fp16_weights_on_model_save` is `False`, no weights will be saved (again, because `state_dict` doesn't have them).
+You can use this method to save ZeRO-2 weights as well.
+
+If you'd like to get the fp32 weights, we supply a special script that can do offline consolidation. It requires no configuration files or GPUs. Here is an example of its usage:
+
+```
+$ cd /path/to/checkpoints_dir
+$ ./zero_to_fp32.py global_step1 pytorch_model.bin
+Processing zero checkpoint at global_step1
+Detected checkpoint of type zero stage 3, world_size: 2
+Saving fp32 state dict to pytorch_model.bin (total_numel=60506624)
+```
+
+The `zero_to_fp32.py` script gets created automatically when you save a checkpoint.
+
+Note: currently this script uses 2x the size of the final checkpoint in general RAM.
+
 Congratulations! You have completed the ZeRO tutorial.
diff --git a/docs/code-docs/source/training.rst b/docs/code-docs/source/training.rst
index d88d755f39cb..52e124fc3b40 100644
--- a/docs/code-docs/source/training.rst
+++ b/docs/code-docs/source/training.rst
@@ -31,3 +31,11 @@ Optimizer Step
 Gradient Accumulation
 ---------------------
 .. autofunction:: deepspeed.DeepSpeedEngine.is_gradient_accumulation_boundary
+
+
+Model Saving
+------------
+.. autofunction:: deepspeed.DeepSpeedEngine.save_fp16_model
+
+
+Additionally, when a DeepSpeed checkpoint is created, a script ``zero_to_fp32.py`` is added there, which can be used to reconstruct the fp32 master weights into a single PyTorch ``state_dict`` file.
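To see how the pieces in this diff fit together end to end, here is a minimal sketch of the fp16 save path. Only `save_fp16_model` and the `stage3_gather_fp16_weights_on_model_save` flag come from this patch; the model, batch size, and the rest of the config are illustrative placeholders, and the script assumes it is launched with the usual `deepspeed` launcher so that every rank executes the save call:

```python
import torch
import deepspeed

# Stand-in model; any nn.Module works here.
model = torch.nn.Linear(512, 512)

# Illustrative minimal config; only the gather flag below is from this patch.
ds_config = {
    "train_batch_size": 8,
    "fp16": {"enabled": True},
    "zero_optimization": {
        "stage": 3,
        # Without this flag, save_fp16_model() has no weights to save,
        # because under ZeRO-3 state_dict holds only placeholders.
        "stage3_gather_fp16_weights_on_model_save": True,
    },
}

engine, _, _, _ = deepspeed.initialize(model=model,
                                       model_parameters=model.parameters(),
                                       config_params=ds_config)

# ... training loop elided ...

# Must run on all ranks, not just rank 0: the fp16 weights are gathered
# from every partition before rank 0 writes the file.
engine.save_fp16_model("outputs", "pytorch_model.bin")
```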
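Similarly, once `zero_to_fp32.py` has produced the consolidated fp32 `state_dict` file, loading it back is plain PyTorch and needs no DeepSpeed at all; the architecture below is again a stand-in for whatever model produced the checkpoint:

```python
import torch

# "pytorch_model.bin" is the file written by zero_to_fp32.py above.
state_dict = torch.load("pytorch_model.bin", map_location="cpu")

# Rebuild the same architecture that was trained (a stand-in here),
# then load the reconstructed fp32 master weights into it.
model = torch.nn.Linear(512, 512)
model.load_state_dict(state_dict)
```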