ZeRO 2+3 memory estimators #965

Merged
jeffra merged 12 commits into deepspeedai:master from stas00:mem-estimation-utils on Jun 23, 2021

Conversation

stas00 (Collaborator) commented Apr 16, 2021

With @samyam's and @tjruwase's help, I have been working on utilities to estimate how much CPU+GPU RAM is needed for a given model on a given setup.

This PR adds memory estimators for ZeRO 2+3 params, optim states and gradients for a given model and hardware setup:

  • estimate_zero3_model_states_mem_needs_all_live - requires an actual model object
  • estimate_zero3_model_states_mem_needs_all_cold - requires total_params and largest_layer_params
  • estimate_zero2_model_states_mem_needs_all_live - requires an actual model object
  • estimate_zero2_model_states_mem_needs_all_cold - requires total_params and largest_layer_params
  • a detailed description of the main memory usage contributors

ZeRO-3

Let's try a 3B model on just 1 node with 8 GPUs, using the live model:

USE_TF=0 PYTHONPATH=. python -c 'from transformers import AutoModel; \
from deepspeed.runtime.zero.stage3 import estimate_zero3_model_states_mem_needs_all_live; \
model = AutoModel.from_pretrained("t5-3b"); \
estimate_zero3_model_states_mem_needs_all_live(model, num_gpus_per_node=8, num_nodes=1)'

Estimated memory needed for params, optim states and gradients for a:
HW: Setup with 1 node, 8 GPUs per node.
SW: Model with 2851M total params, 32M largest layer params.
  per CPU  |  per GPU |   Options
   71.71GB |   0.12GB | cpu_offload=1, cpu_offload_params=1, zero_init=1
  127.48GB |   0.12GB | cpu_offload=1, cpu_offload_params=1, zero_init=0
   63.74GB |   0.79GB | cpu_offload=1, cpu_offload_params=0, zero_init=1
  127.48GB |   0.79GB | cpu_offload=1, cpu_offload_params=0, zero_init=0
    1.47GB |   6.10GB | cpu_offload=0, cpu_offload_params=0, zero_init=1
  127.48GB |   6.10GB | cpu_offload=0, cpu_offload_params=0, zero_init=0
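
For intuition, here is a back-of-envelope sketch that approximately reproduces the zero_init=1 rows of this table. The constants are my reading of the numbers (roughly 18 bytes per param for fp16 params+grads plus fp32 optimizer states, 4 bytes per param to gather the largest layer, and a 1.5x buffer factor on the CPU side), not a quote of the estimator source:

GB = 2**30
total_params = 2851e6
largest_layer_params = 32e6
num_gpus = 8
buffer = 1.5  # assumed safety factor applied to CPU estimates

# no offload, zero_init=1: each GPU holds a 1/num_gpus shard of the
# ~18 bytes/param model states plus room to gather the largest layer
print(f"{(4 * largest_layer_params + 18 * total_params / num_gpus) / GB:.2f}GB per GPU")  # ~6.1GB

# everything offloaded, zero_init=1: the node's CPU holds the full
# ~18 bytes/param model states, padded by the buffer factor
print(f"{18 * total_params * buffer / GB:.2f}GB per CPU")  # ~71.7GB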

Now, the same without the actual model. This requires knowing total_params and largest_layer_params, but we got those from the run above, so future estimates are much faster since we don't need to load the model:

PYTHONPATH=. python -c 'from deepspeed.runtime.zero.stage3 import estimate_zero3_model_states_mem_needs_all_cold; \
estimate_zero3_model_states_mem_needs_all_cold(total_params=2851e6, largest_layer_params=32e6, num_gpus_per_node=8, num_nodes=1)'

Estimated memory needed for params, optim states and gradients for a:
HW: Setup with 1 node, 8 GPUs per node.
SW: Model with 2851M total params, 32M largest layer params.
  per CPU  |  per GPU |   Options
   71.69GB |   0.12GB | cpu_offload=1, cpu_offload_params=1, zero_init=1
  127.45GB |   0.12GB | cpu_offload=1, cpu_offload_params=1, zero_init=0
   63.72GB |   0.78GB | cpu_offload=1, cpu_offload_params=0, zero_init=1
  127.45GB |   0.78GB | cpu_offload=1, cpu_offload_params=0, zero_init=0
    1.43GB |   6.09GB | cpu_offload=0, cpu_offload_params=0, zero_init=1
  127.45GB |   6.09GB | cpu_offload=0, cpu_offload_params=0, zero_init=0

There is a slight difference due to rounding: the actual live model has a few more params.
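
As an aside, if you need total_params and largest_layer_params for a different model, a minimal PyTorch sketch along these lines should produce them (a hypothetical helper for illustration, not part of this PR):

def model_to_params(model):
    # total parameter count across the whole model
    total_params = sum(p.numel() for p in model.parameters())
    # approximate the largest layer as the module owning the most
    # direct (non-recursive) parameters
    largest_layer_params = max(
        sum(p.numel() for p in m.parameters(recurse=False))
        for m in model.modules())
    return int(total_params), int(largest_layer_params)

For t5-3b this should come out close to the 2851M / 32M figures used above.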

Let's try the 3B model on 8 nodes with 8 GPUs each (cold):

PYTHONPATH=. python -c 'from deepspeed.runtime.zero.stage3 import estimate_zero3_model_states_mem_needs_all_cold; \
estimate_zero3_model_states_mem_needs_all_cold(total_params=2851e6, largest_layer_params=32e6, num_gpus_per_node=8, num_nodes=8)'

Estimated memory needed for params, optim states and gradients for a:
HW: Setup with 8 nodes, 8 GPUs per node.
SW: Model with 2851M total params, 32M largest layer params.
  per CPU  |  per GPU |   Options
    8.96GB |   0.12GB | cpu_offload=1, cpu_offload_params=1, zero_init=1
  127.45GB |   0.12GB | cpu_offload=1, cpu_offload_params=1, zero_init=0
    7.97GB |   0.20GB | cpu_offload=1, cpu_offload_params=0, zero_init=1
  127.45GB |   0.20GB | cpu_offload=1, cpu_offload_params=0, zero_init=0
    1.43GB |   0.87GB | cpu_offload=0, cpu_offload_params=0, zero_init=1
  127.45GB |   0.87GB | cpu_offload=0, cpu_offload_params=0, zero_init=0

Let's try a different setup with just 1 node and 1 GPU:

PYTHONPATH=. python -c 'from deepspeed.runtime.zero.stage3 import estimate_zero3_model_states_mem_needs_all_cold; \
estimate_zero3_model_states_mem_needs_all_cold(total_params=2851e6, largest_layer_params=32e6, num_gpus_per_node=1, num_nodes=1)'

Estimated memory needed for params, optim states and gradients for a:
HW: Setup with 1 node, 1 GPU per node.
SW: Model with 2851M total params, 32M largest layer params.
  per CPU  |  per GPU |   Options
   71.69GB |   0.12GB | cpu_offload=1, cpu_offload_params=1, zero_init=1
   71.69GB |   0.12GB | cpu_offload=1, cpu_offload_params=1, zero_init=0
   63.72GB |   5.43GB | cpu_offload=1, cpu_offload_params=0, zero_init=1
   63.72GB |   5.43GB | cpu_offload=1, cpu_offload_params=0, zero_init=0
    0.18GB |  47.91GB | cpu_offload=0, cpu_offload_params=0, zero_init=1
   15.93GB |  47.91GB | cpu_offload=0, cpu_offload_params=0, zero_init=0

ZeRO-2

Live:

USE_TF=0 PYTHONPATH=. python -c 'from transformers import AutoModel; \
from deepspeed.runtime.zero.stage2 import estimate_zero2_model_states_mem_needs_all_live; \
model = AutoModel.from_pretrained("t5-3b"); \
estimate_zero2_model_states_mem_needs_all_live(model, num_gpus_per_node=8, num_nodes=1)'

Estimated memory needed for params, optim states and gradients for a:
HW: Setup with 1 node, 8 GPUs per node.
SW: Model with 2851M total params.
  per CPU  |  per GPU |   Options
  127.48GB |   5.31GB | cpu_offload=1
  127.48GB |  15.93GB | cpu_offload=0

Cold:

PYTHONPATH=. python -c 'from deepspeed.runtime.zero.stage2 import estimate_zero2_model_states_mem_needs_all_cold; \
estimate_zero2_model_states_mem_needs_all_cold(total_params=2851e6, num_gpus_per_node=8, num_nodes=1)'

Estimated memory needed for params, optim states and gradients for a:
HW: Setup with 1 node, 8 GPUs per node.
SW: Model with 2851M total params.
  per CPU  |  per GPU |   Options
  127.45GB |   5.31GB | cpu_offload=1
  127.45GB |  15.93GB | cpu_offload=0
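
The ZeRO-2 numbers can be approximated the same way. In this sketch the constants are fitted to the table above (again my reading, not the estimator source): with cpu_offload=1 each GPU keeps only the fp16 params; without offload it also keeps fp16 grads plus its shard of roughly 16 bytes/param of optimizer state:

GB = 2**30
total_params = 2851e6
num_gpus = 8
buffer = 1.5  # assumed safety factor applied to CPU estimates

print(f"{2 * total_params / GB:.2f}GB per GPU")  # cpu_offload=1: ~5.31GB
print(f"{(4 * total_params + 16 * total_params / num_gpus) / GB:.2f}GB per GPU")  # cpu_offload=0: ~15.93GB
print(f"{4 * num_gpus * total_params * buffer / GB:.2f}GB per CPU")  # ~127.45GB either way here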

samyam (Contributor) commented Apr 16, 2021

@stas00 The numbers look slightly smaller than what I would expect for a 3B-parameter model, but the code looks fine to me. I am guessing the model is not exactly 3B here but slightly smaller?

stas00 (Collaborator, Author) commented Apr 16, 2021

> @stas00 The numbers look slightly smaller than what I would expect for a 3B-parameter model, but the code looks fine to me. I am guessing the model is not exactly 3B here but slightly smaller?

It's 2851M; I will make the script dump the exact number of params.

stas00 (Collaborator, Author) commented Apr 16, 2021

I'm also thinking that I should probably match the launcher API and ask for num_gpus and num_nodes rather than total_gpus. Does that make sense?

stas00 (Collaborator, Author) commented Apr 17, 2021

  • switched to num_nodes
  • made two versions:
    1. a live one (a real model object is required)
    2. a cold one (the user passes the number of params) - easier for planning
  • added docs

stas00 changed the title from [wip] zero memory estimators to ZeRO 2+3 memory estimators on Jun 4, 2021
stas00 (Collaborator, Author) commented Jun 4, 2021

Added ZeRO-2 estimators and full in-depth docs.

jeffra merged commit 0c1802c into deepspeedai:master on Jun 23, 2021