tuning BLOOMZ 176B #194

Closed · alex-ht opened this issue Mar 20, 2023 · 13 comments

@alex-ht commented Mar 20, 2023

Hi, I'm wondering if I could use peft to finetune 176B BLOOMZ?
I am experiencing poor GPU utilization efficiency with 64 V100s.

@pacman100 (Contributor)

Hello, could you provide more information on your setup, e.g. are you using PEFT with INT8 training, or DeepSpeed with CPU offload?

Also, what is the GPU utilisation?

@alex-ht (Author) commented Mar 20, 2023

I'm using DeepSpeed, with the offload device set to none.
I modified examples/causal_language_modeling/peft_lora_clm_accelerate_ds_zero3_offload.py to load the xP3 dataset and to load the model in fp16:

import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model

# creating model
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16)
peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1)
model = get_peft_model(model, peft_config)
model = model.half()  # cast the LoRA adapters to fp16 as well
model.print_trainable_parameters()

accelerate config:

compute_environment: LOCAL_MACHINE
deepspeed_config:
  deepspeed_config_file: ds_config.json
  deepspeed_multinode_launcher: standard
  zero3_init_flag: true
distributed_type: DEEPSPEED
downcast_bf16: 'no'
dynamo_backend: 'NO'
fsdp_config: {}
machine_rank: 0
main_process_ip: ***
main_process_port: 18049
main_training_function: main
megatron_lm_config: {}
num_machines: 8
num_processes: 64
rdzv_backend: static
same_network: true
use_cpu: false

ds_config.json:

{
   "fp16": {
      "enabled": true,
      "auto_cast": false,
      "loss_scale": 0,
      "initial_scale_power": 12,
      "loss_scale_window": 500,
      "hysteresis": 2,
      "min_loss_scale": 1
   },
   "bf16": {
      "enabled": false
   },
   "zero_optimization": {
      "stage": 3,
      "offload_optimizer": {
         "device": "none"
      },
      "offload_param": {
         "device": "none"
      },
      "overlap_comm": true,
      "contiguous_gradients": true,
      "reduce_bucket_size": 205520896,
      "stage3_prefetch_bucket_size": 184968807,
      "stage3_param_persistence_threshold": 143360,
      "sub_group_size": 1e+6,
      "stage3_max_live_parameters": 1e+6,
      "stage3_max_reuse_distance": 1e+6,
      "stage3_gather_16bit_weights_on_model_save": true
   },
   "steps_per_print": 2000,
   "train_batch_size": "auto",
   "train_micro_batch_size_per_gpu": 1,
   "gradient_accumulation_steps": 1,
   "wall_clock_breakdown": false
}

log file:


+ echo 'accelerate launch          --config_file config.458057.0.yaml     ./peft_lora_bloomz.py         --model-name-or-path bigscience/bloomz         --tokenizer-name-or-path bigscience/tokenizer         --lr 3e-3         --epochs 1         --batch-size 1         --mixed-precision fp16         '
accelerate launch          --config_file config.458057.0.yaml     ./peft_lora_bloomz.py         --model-name-or-path bigscience/bloomz         --tokenizer-name-or-path bigscience/tokenizer         --lr 3e-3         --epochs 1         --batch-size 1         --mixed-precision fp16         
+ srun --jobid 458057 bash -c '$LAUNCHER --config_file config.$SLURM_JOBID.$SLURM_NODEID.yaml $CMD'
Loading cached split indices for dataset at /home/u4005115/alex/peft/examples/bloomz/xP3lora/cache-c49fdf9ee117b070.arrow and /home/u4005115/alex/peft/examples/bloomz/xP3lora/cache-25b57fbede788174.arrow
...
[2023-03-20 20:17:00,146] [INFO] [comm.py:657:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl
Distributed environment: DEEPSPEED  Backend: nccl
Num processes: 64
Process index: 56
Local process index: 0
Device: cuda:0
Mixed precision type: fp16
ds_config: {'fp16': {'enabled': True, 'auto_cast': False, 'loss_scale': 0, 'initial_scale_power': 12, 'loss_scale_window': 500, 'hysteresis': 2, 'min_loss_scale': 1}, 'bf16': {'enabled': False}, 'zero_optimization': {'stage': 3, 'offload_optimizer': {'device': 'none'}, 'offload_param': {'device': 'none'}, 'overlap_comm': True, 'contiguous_gradients': True, 'reduce_bucket_size': 205520896, 'stage3_prefetch_bucket_size': 184968807, 'stage3_param_persistence_threshold': 143360, 'sub_group_size': 1000000.0, 'stage3_max_live_parameters': 1000000.0, 'stage3_max_reuse_distance': 1000000.0, 'stage3_gather_16bit_weights_on_model_save': True}, 'steps_per_print': inf, 'train_batch_size': 'auto', 'train_micro_batch_size_per_gpu': 1, 'gradient_accumulation_steps': 1, 'wall_clock_breakdown': False}

...

NCCL version 2.14.3+cuda11.7
[2023-03-20 20:19:17,994] [INFO] [partition_parameters.py:413:__exit__] finished initializing model with 179.84B parameters

Loading checkpoint shards:   0%|          | 0/72 [00:00<?, ?it/s]
/home/u4005115/.conda/envs/alex-peft2/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py:2387: UserWarning: torch.distributed._all_gather_base is a private function and will be deprecated. Please use torch.distributed.all_gather_into_tensor instead.
  warnings.warn(
Loading checkpoint shards: 100%|██████████| 72/72 [07:05<00:00,  4.25s/it]
Loading checkpoint shards: 100%|██████████| 72/72 [07:05<00:00,  5.91s/it]
/home/u4005115/.conda/envs/alex-peft2/lib/python3.8/site-packages/peft/tuners/lora.py:174: UserWarning: fan_in_fan_out is set to True but the target module is not a Conv1D. Setting fan_in_fan_out to False.
  warnings.warn(
...
trainable params: 32112640 || all params: 176279384064 || trainable%: 0.018216900501729222
trainable params: 32112640 || all params: 176279384064 || trainable%: 0.018216900501729222
...
[2023-03-20 20:33:59,066] [INFO] [logging.py:68:log_dist] [Rank 0] DeepSpeed info: version=0.8.0, git-hash=unknown, git-branch=unknown
trainable params: 32112640 || all params: 176279384064 || trainable%: 0.018216900501729222
...
[2023-03-20 20:34:17,417] [INFO] [logging.py:68:log_dist] [Rank 0] DeepSpeed Flops Profiler Enabled: False
[2023-03-20 20:34:17,418] [INFO] [logging.py:68:log_dist] [Rank 0] Removing param_group that has no 'params' in the client Optimizer
[2023-03-20 20:34:17,418] [INFO] [logging.py:68:log_dist] [Rank 0] Using client Optimizer as basic optimizer
[2023-03-20 20:34:17,525] [INFO] [logging.py:68:log_dist] [Rank 0] DeepSpeed Basic Optimizer = AdamW
[2023-03-20 20:34:17,525] [INFO] [utils.py:52:is_zero_supported_optimizer] Checking ZeRO support for optimizer=AdamW type=<class 'torch.optim.adamw.AdamW'>
[2023-03-20 20:34:17,525] [INFO] [logging.py:68:log_dist] [Rank 0] Creating fp16 ZeRO stage 3 optimizer
[2023-03-20 20:34:17,696] [INFO] [utils.py:831:see_memory_usage] Stage 3 initialize beginning
[2023-03-20 20:34:17,696] [INFO] [utils.py:832:see_memory_usage] MA 5.3 GB         Max_MA 18.63 GB         CA 23.56 GB         Max_CA 31 GB 
[2023-03-20 20:34:17,698] [INFO] [utils.py:840:see_memory_usage] CPU Virtual Memory:  used = 51.16 GB, percent = 6.8%
[2023-03-20 20:34:17,705] [INFO] [stage3.py:114:__init__] Reduce bucket size 205520896
[2023-03-20 20:34:17,705] [INFO] [stage3.py:115:__init__] Prefetch bucket size 184968807
Using /home/u4005115/.cache/torch_extensions/py38_cu117 as PyTorch extensions root...
...
Emitting ninja build file /home/u4005115/.cache/torch_extensions/py38_cu117/utils/build.ninja...
Building extension module utils...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module utils...
...
Time to load utils op: 1.2338910102844238 seconds
...

[2023-03-20 20:34:19,470] [INFO] [utils.py:831:see_memory_usage] DeepSpeedZeRoOffload initialize [begin]
[2023-03-20 20:34:19,471] [INFO] [utils.py:832:see_memory_usage] MA 5.3 GB         Max_MA 5.3 GB         CA 23.56 GB         Max_CA 24 GB 
[2023-03-20 20:34:19,473] [INFO] [utils.py:840:see_memory_usage] CPU Virtual Memory:  used = 50.02 GB, percent = 6.6%
Parameter Offload: Total persistent parameters: 13103104 in 564 params
[2023-03-20 20:34:19,678] [INFO] [utils.py:831:see_memory_usage] DeepSpeedZeRoOffload initialize [end]
[2023-03-20 20:34:19,679] [INFO] [utils.py:832:see_memory_usage] MA 5.24 GB         Max_MA 5.3 GB         CA 23.56 GB         Max_CA 24 GB 
[2023-03-20 20:34:19,681] [INFO] [utils.py:840:see_memory_usage] CPU Virtual Memory:  used = 50.02 GB, percent = 6.6%
[2023-03-20 20:34:19,823] [INFO] [utils.py:831:see_memory_usage] Before creating fp16 partitions
[2023-03-20 20:34:19,824] [INFO] [utils.py:832:see_memory_usage] MA 5.24 GB         Max_MA 5.24 GB         CA 23.56 GB         Max_CA 24 GB 
[2023-03-20 20:34:19,826] [INFO] [utils.py:840:see_memory_usage] CPU Virtual Memory:  used = 50.02 GB, percent = 6.6%
[2023-03-20 20:34:21,186] [INFO] [utils.py:831:see_memory_usage] After creating fp16 partitions: 1
[2023-03-20 20:34:21,186] [INFO] [utils.py:832:see_memory_usage] MA 5.24 GB         Max_MA 5.24 GB         CA 5.5 GB         Max_CA 24 GB 
...
[2023-03-20 20:34:21,742] [INFO] [stage3.py:382:_setup_for_real_optimizer] optimizer state initialized
...
No modifications detected for re-loaded extension module utils, skipping build step...
Loading extension module utils...
Time to load utils op: 0.002360820770263672 seconds
...
Time to load utils op: 0.0030989646911621094 seconds
Loading extension module utils...
DeepSpeedEngine(
  (module): PeftModelForCausalLM(
    (base_model): LoraModel(
      (model): BloomForCausalLM(
        (transformer): BloomModel(
          (word_embeddings): Embedding(250880, 14336)
          (word_embeddings_layernorm): LayerNorm((14336,), eps=1e-05, elementwise_affine=True)
...

          (ln_f): LayerNorm((14336,), eps=1e-05, elementwise_affine=True)
        )
        (lm_head): Linear(in_features=14336, out_features=250880, bias=False)
      )
    )
  )
)
[2023-03-20 20:34:22,190] [INFO] [utils.py:831:see_memory_usage] After initializing ZeRO optimizer
[2023-03-20 20:34:22,190] [INFO] [utils.py:832:see_memory_usage] MA 5.63 GB         Max_MA 5.63 GB         CA 5.9 GB         Max_CA 6 GB 
[2023-03-20 20:34:22,191] [INFO] [utils.py:840:see_memory_usage] CPU Virtual Memory:  used = 49.86 GB, percent = 6.6%
[2023-03-20 20:34:22,191] [INFO] [logging.py:68:log_dist] [Rank 0] DeepSpeed Final Optimizer = AdamW
[2023-03-20 20:34:22,191] [INFO] [logging.py:68:log_dist] [Rank 0] DeepSpeed using client LR scheduler
[2023-03-20 20:34:22,191] [INFO] [logging.py:68:log_dist] [Rank 0] DeepSpeed LR Scheduler = None
[2023-03-20 20:34:22,191] [INFO] [logging.py:68:log_dist] [Rank 0] step=0, skipped=0, lr=[0.003], mom=[(0.9, 0.999)]
[2023-03-20 20:34:22,193] [INFO] [config.py:1008:print] DeepSpeedEngine configuration:
[2023-03-20 20:34:22,194] [INFO] [config.py:1012:print]   activation_checkpointing_config  {
    "partition_activations": false, 
    "contiguous_memory_optimization": false, 
    "cpu_checkpointing": false, 
    "number_checkpoints": null, 
    "synchronize_checkpoint_boundary": false, 
    "profile": false
}
[2023-03-20 20:34:22,194] [INFO] [config.py:1012:print]   aio_config ................... {'block_size': 1048576, 'queue_depth': 8, 'thread_count': 1, 'single_submit': False, 'overlap_events': True}
[2023-03-20 20:34:22,194] [INFO] [config.py:1012:print]   amp_enabled .................. False
[2023-03-20 20:34:22,194] [INFO] [config.py:1012:print]   amp_params ................... False
[2023-03-20 20:34:22,194] [INFO] [config.py:1012:print]   autotuning_config ............ {
    "enabled": false, 
    "start_step": null, 
    "end_step": null, 
    "metric_path": null, 
    "arg_mappings": null, 
    "metric": "throughput", 
    "model_info": null, 
    "results_dir": "autotuning_results", 
    "exps_dir": "autotuning_exps", 
    "overwrite": true, 
    "fast": true, 
    "start_profile_step": 3, 
    "end_profile_step": 5, 
    "tuner_type": "gridsearch", 
    "tuner_early_stopping": 5, 
    "tuner_num_trials": 50, 
    "model_info_path": null, 
    "mp_size": 1, 
    "max_train_batch_size": null, 
    "min_train_batch_size": 1, 
    "max_train_micro_batch_size_per_gpu": 1.024000e+03, 
    "min_train_micro_batch_size_per_gpu": 1, 
    "num_tuning_micro_batch_sizes": 3
}
[2023-03-20 20:34:22,194] [INFO] [config.py:1012:print]   bfloat16_enabled ............. False
[2023-03-20 20:34:22,194] [INFO] [config.py:1012:print]   checkpoint_parallel_write_pipeline  False
[2023-03-20 20:34:22,194] [INFO] [config.py:1012:print]   checkpoint_tag_validation_enabled  True
[2023-03-20 20:34:22,194] [INFO] [config.py:1012:print]   checkpoint_tag_validation_fail  False
[2023-03-20 20:34:22,194] [INFO] [config.py:1012:print]   comms_config ................. <deepspeed.comm.config.DeepSpeedCommsConfig object at 0x2b85a7eb2fd0>
[2023-03-20 20:34:22,194] [INFO] [config.py:1012:print]   communication_data_type ...... None
[2023-03-20 20:34:22,194] [INFO] [config.py:1012:print]   compression_config ........... {'weight_quantization': {'shared_parameters': {'enabled': False, 'quantizer_kernel': False, 'schedule_offset': 0, 'quantize_groups': 1, 'quantize_verbose': False, 'quantization_type': 'symmetric', 'quantize_weight_in_forward': False, 'rounding': 'nearest', 'fp16_mixed_quantize': False, 'quantize_change_ratio': 0.001}, 'different_groups': {}}, 'activation_quantization': {'shared_parameters': {'enabled': False, 'quantization_type': 'symmetric', 'range_calibration': 'dynamic', 'schedule_offset': 1000}, 'different_groups': {}}, 'sparse_pruning': {'shared_parameters': {'enabled': False, 'method': 'l1', 'schedule_offset': 1000}, 'different_groups': {}}, 'row_pruning': {'shared_parameters': {'enabled': False, 'method': 'l1', 'schedule_offset': 1000}, 'different_groups': {}}, 'head_pruning': {'shared_parameters': {'enabled': False, 'method': 'topk', 'schedule_offset': 1000}, 'different_groups': {}}, 'channel_pruning': {'shared_parameters': {'enabled': False, 'method': 'l1', 'schedule_offset': 1000}, 'different_groups': {}}, 'layer_reduction': {'enabled': False}}
[2023-03-20 20:34:22,194] [INFO] [config.py:1012:print]   curriculum_enabled_legacy .... False
[2023-03-20 20:34:22,194] [INFO] [config.py:1012:print]   curriculum_params_legacy ..... False
[2023-03-20 20:34:22,194] [INFO] [config.py:1012:print]   data_efficiency_config ....... {'enabled': False, 'seed': 1234, 'data_sampling': {'enabled': False, 'num_epochs': 1000, 'num_workers': 0, 'curriculum_learning': {'enabled': False}}, 'data_routing': {'enabled': False, 'random_ltd': {'enabled': False, 'layer_token_lr_schedule': {'enabled': False}}}}
[2023-03-20 20:34:22,194] [INFO] [config.py:1012:print]   data_efficiency_enabled ...... False
[2023-03-20 20:34:22,194] [INFO] [config.py:1012:print]   dataloader_drop_last ......... False
[2023-03-20 20:34:22,194] [INFO] [config.py:1012:print]   disable_allgather ............ False
[2023-03-20 20:34:22,194] [INFO] [config.py:1012:print]   dump_state ................... False
[2023-03-20 20:34:22,194] [INFO] [config.py:1012:print]   dynamic_loss_scale_args ...... {'init_scale': 4096, 'scale_window': 500, 'delayed_shift': 2, 'min_scale': 1}
[2023-03-20 20:34:22,194] [INFO] [config.py:1012:print]   eigenvalue_enabled ........... False
[2023-03-20 20:34:22,194] [INFO] [config.py:1012:print]   eigenvalue_gas_boundary_resolution  1
[2023-03-20 20:34:22,194] [INFO] [config.py:1012:print]   eigenvalue_layer_name ........ bert.encoder.layer
[2023-03-20 20:34:22,194] [INFO] [config.py:1012:print]   eigenvalue_layer_num ......... 0
[2023-03-20 20:34:22,194] [INFO] [config.py:1012:print]   eigenvalue_max_iter .......... 100
[2023-03-20 20:34:22,194] [INFO] [config.py:1012:print]   eigenvalue_stability ......... 1e-06
[2023-03-20 20:34:22,195] [INFO] [config.py:1012:print]   eigenvalue_tol ............... 0.01
[2023-03-20 20:34:22,195] [INFO] [config.py:1012:print]   eigenvalue_verbose ........... False
[2023-03-20 20:34:22,195] [INFO] [config.py:1012:print]   elasticity_enabled ........... False
[2023-03-20 20:34:22,195] [INFO] [config.py:1012:print]   flops_profiler_config ........ {
    "enabled": false, 
    "profile_step": 1, 
    "module_depth": -1, 
    "top_modules": 1, 
    "detailed": true, 
    "output_file": null
}
[2023-03-20 20:34:22,195] [INFO] [config.py:1012:print]   fp16_auto_cast ............... False
[2023-03-20 20:34:22,195] [INFO] [config.py:1012:print]   fp16_enabled ................. True
[2023-03-20 20:34:22,195] [INFO] [config.py:1012:print]   fp16_master_weights_and_gradients  False
[2023-03-20 20:34:22,195] [INFO] [config.py:1012:print]   global_rank .................. 0
[2023-03-20 20:34:22,195] [INFO] [config.py:1012:print]   grad_accum_dtype ............. None
[2023-03-20 20:34:22,195] [INFO] [config.py:1012:print]   gradient_accumulation_steps .. 1
[2023-03-20 20:34:22,195] [INFO] [config.py:1012:print]   gradient_clipping ............ 0.0
[2023-03-20 20:34:22,195] [INFO] [config.py:1012:print]   gradient_predivide_factor .... 1.0
[2023-03-20 20:34:22,195] [INFO] [config.py:1012:print]   initial_dynamic_scale ........ 4096
[2023-03-20 20:34:22,195] [INFO] [config.py:1012:print]   load_universal_checkpoint .... False
[2023-03-20 20:34:22,195] [INFO] [config.py:1012:print]   loss_scale ................... 0
[2023-03-20 20:34:22,195] [INFO] [config.py:1012:print]   memory_breakdown ............. False
[2023-03-20 20:34:22,195] [INFO] [config.py:1012:print]   monitor_config ............... <deepspeed.monitor.config.DeepSpeedMonitorConfig object at 0x2b85a7ee4070>
[2023-03-20 20:34:22,195] [INFO] [config.py:1012:print]   nebula_config ................ {
    "enabled": false, 
    "persistent_storage_path": null, 
    "persistent_time_interval": 100, 
    "num_of_version_in_retention": 2, 
    "enable_nebula_load": true, 
    "load_path": null
}
[2023-03-20 20:34:22,195] [INFO] [config.py:1012:print]   optimizer_legacy_fusion ...... False
[2023-03-20 20:34:22,195] [INFO] [config.py:1012:print]   optimizer_name ............... None
[2023-03-20 20:34:22,195] [INFO] [config.py:1012:print]   optimizer_params ............. None
[2023-03-20 20:34:22,195] [INFO] [config.py:1012:print]   pipeline ..................... {'stages': 'auto', 'partition': 'best', 'seed_layers': False, 'activation_checkpoint_interval': 0}
[2023-03-20 20:34:22,195] [INFO] [config.py:1012:print]   pld_enabled .................. False
[2023-03-20 20:34:22,195] [INFO] [config.py:1012:print]   pld_params ................... False
[2023-03-20 20:34:22,195] [INFO] [config.py:1012:print]   prescale_gradients ........... False
[2023-03-20 20:34:22,195] [INFO] [config.py:1012:print]   scheduler_name ............... None
[2023-03-20 20:34:22,195] [INFO] [config.py:1012:print]   scheduler_params ............. None
[2023-03-20 20:34:22,195] [INFO] [config.py:1012:print]   sparse_attention ............. None
[2023-03-20 20:34:22,195] [INFO] [config.py:1012:print]   sparse_gradients_enabled ..... False
[2023-03-20 20:34:22,195] [INFO] [config.py:1012:print]   steps_per_print .............. inf
[2023-03-20 20:34:22,195] [INFO] [config.py:1012:print]   train_batch_size ............. 64
[2023-03-20 20:34:22,195] [INFO] [config.py:1012:print]   train_micro_batch_size_per_gpu  1
[2023-03-20 20:34:22,195] [INFO] [config.py:1012:print]   use_node_local_storage ....... False
[2023-03-20 20:34:22,195] [INFO] [config.py:1012:print]   wall_clock_breakdown ......... False
[2023-03-20 20:34:22,195] [INFO] [config.py:1012:print]   world_size ................... 64
[2023-03-20 20:34:22,195] [INFO] [config.py:1012:print]   zero_allow_untested_optimizer  True
[2023-03-20 20:34:22,195] [INFO] [config.py:1012:print]   zero_config .................. stage=3 contiguous_gradients=True reduce_scatter=True reduce_bucket_size=205520896 allgather_partitions=True allgather_bucket_size=500,000,000 overlap_comm=True load_from_fp32_weights=True elastic_checkpoint=False offload_param=DeepSpeedZeroOffloadParamConfig(device='none', nvme_path=None, buffer_count=5, buffer_size=100,000,000, max_in_cpu=1,000,000,000, pin_memory=False) offload_optimizer=DeepSpeedZeroOffloadOptimizerConfig(device='none', nvme_path=None, buffer_count=4, pin_memory=False, pipeline=False, pipeline_read=False, pipeline_write=False, fast_init=False) sub_group_size=1000000 cpu_offload_param=None cpu_offload_use_pin_memory=None cpu_offload=None prefetch_bucket_size=184968807 param_persistence_threshold=143360 model_persistence_threshold=sys.maxsize max_live_parameters=1000000 max_reuse_distance=1000000 gather_16bit_weights_on_model_save=True stage3_gather_fp16_weights_on_model_save=False ignore_unused_parameters=True legacy_stage1=False round_robin_gradients=False
[2023-03-20 20:34:22,195] [INFO] [config.py:1012:print]   zero_enabled ................. True
[2023-03-20 20:34:22,195] [INFO] [config.py:1012:print]   zero_optimization_stage ...... 3
[2023-03-20 20:34:22,195] [INFO] [config.py:997:print_user_config]   json = {
    "fp16": {
        "enabled": true, 
        "auto_cast": false, 
        "loss_scale": 0, 
        "initial_scale_power": 12, 
        "loss_scale_window": 500, 
        "hysteresis": 2, 
        "min_loss_scale": 1
    }, 
    "bf16": {
        "enabled": false
    }, 
    "zero_optimization": {
        "stage": 3, 
        "offload_optimizer": {
            "device": "none"
        }, 
        "offload_param": {
            "device": "none"
        }, 
        "overlap_comm": true, 
        "contiguous_gradients": true, 
        "reduce_bucket_size": 2.055209e+08, 
        "stage3_prefetch_bucket_size": 1.849688e+08, 
        "stage3_param_persistence_threshold": 1.433600e+05, 
        "sub_group_size": 1.000000e+06, 
        "stage3_max_live_parameters": 1.000000e+06, 
        "stage3_max_reuse_distance": 1.000000e+06, 
        "stage3_gather_16bit_weights_on_model_save": true
    }, 
    "steps_per_print": inf, 
    "train_batch_size": 64, 
    "train_micro_batch_size_per_gpu": 1, 
    "gradient_accumulation_steps": 1, 
    "wall_clock_breakdown": false, 
    "zero_allow_untested_optimizer": true
}
...

/home/u4005115/.conda/envs/alex-peft2/lib/python3.8/site-packages/torch/cuda/memory.py:282: FutureWarning: torch.cuda.reset_max_memory_allocated now calls torch.cuda.reset_peak_memory_stats, which resets /all/ peak memory stats.
  warnings.warn(
...

  0%|          | 0/74939408 [00:00<?, ?it/s]
/home/u4005115/.conda/envs/alex-peft2/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py:2387: UserWarning: torch.distributed._all_gather_base is a private function and will be deprecated. Please use torch.distributed.all_gather_into_tensor instead.
  warnings.warn(
...
[2023-03-20 20:38:19,874] [WARNING] [stage3.py:1939:step] 13 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding torch.cuda.empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
[2023-03-20 20:41:59,356] [WARNING] [stage3.py:1939:step] 1 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding torch.cuda.empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time

  0%|          | 1/74939408 [03:57<4946389:59:14, 237.62s/it]
  0%|          | 1/74939408 [03:57<4945621:54:07, 237.58s/it]
  0%|          | 1/74939408 [03:57<4945738:05:30, 237.59s/it]
  0%|          | 1/74939408 [03:57<4945036:00:00, 237.55s/it]
  0%|          | 1/74939408 [03:57<4946569:45:30, 237.63s/it]
  0%|          | 1/74939408 [03:57<4946731:32:01, 237.64s/it]
  0%|          | 1/74939408 [03:57<4946408:18:56, 237.62s/it]
  0%|          | 1/74939408 [03:57<4943964:52:18, 237.50s/it]
  0%|          | 2/74939408 [07:37<4724042:37:06, 226.94s/it]
  0%|          | 2/74939408 [07:37<4723997:24:53, 226.94s/it]
  0%|          | 2/74939408 [07:37<4724317:43:23, 226.95s/it]

nvidia-smi shows

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.161.03   Driver Version: 470.161.03   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla V100-SXM2...  On   | 00000000:1B:00.0 Off |                    0 |
| N/A   30C    P0    70W / 300W |  27016MiB / 32510MiB |     99%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla V100-SXM2...  On   | 00000000:1C:00.0 Off |                    0 |
| N/A   28C    P0    71W / 300W |  30508MiB / 32510MiB |    100%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   2  Tesla V100-SXM2...  On   | 00000000:3D:00.0 Off |                    0 |
| N/A   28C    P0    67W / 300W |  32426MiB / 32510MiB |     99%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   3  Tesla V100-SXM2...  On   | 00000000:3E:00.0 Off |                    0 |
| N/A   31C    P0    72W / 300W |  26860MiB / 32510MiB |     99%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   4  Tesla V100-SXM2...  On   | 00000000:B1:00.0 Off |                    0 |
| N/A   29C    P0    71W / 300W |  32388MiB / 32510MiB |     99%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   5  Tesla V100-SXM2...  On   | 00000000:B2:00.0 Off |                    0 |
| N/A   30C    P0    72W / 300W |  32392MiB / 32510MiB |     99%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   6  Tesla V100-SXM2...  On   | 00000000:DB:00.0 Off |                    0 |
| N/A   30C    P0    72W / 300W |  26896MiB / 32510MiB |     99%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   7  Tesla V100-SXM2...  On   | 00000000:DC:00.0 Off |                    0 |
| N/A   28C    P0    69W / 300W |  26914MiB / 32510MiB |    100%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A      2273      C   ...nvs/alex-peft2/bin/python    27013MiB |
|    1   N/A  N/A      2274      C   ...nvs/alex-peft2/bin/python    30505MiB |
|    2   N/A  N/A      2275      C   ...nvs/alex-peft2/bin/python    32423MiB |
|    3   N/A  N/A      2276      C   ...nvs/alex-peft2/bin/python    26857MiB |
|    4   N/A  N/A      2278      C   ...nvs/alex-peft2/bin/python    32385MiB |
|    5   N/A  N/A      2279      C   ...nvs/alex-peft2/bin/python    32389MiB |
|    6   N/A  N/A      2281      C   ...nvs/alex-peft2/bin/python    26893MiB |
|    7   N/A  N/A      2283      C   ...nvs/alex-peft2/bin/python    26911MiB |
+-----------------------------------------------------------------------------+

GPU-Util is 100% on every card, but Pwr:Usage stays at about 70W when it should be ~300W.

@pacman100 (Contributor)

Hello @alex-ht, gently pinging @stas as they have a lot of experience with training models at such a large scale.

Possible hypotheses:

  1. Training is not compute-bound, e.g. if the network is slow the GPUs are just waiting for data to come in. Please see "[math] what network throughput is required to handle ZeRO-3 traffic?" (microsoft/DeepSpeed#2928). Please measure the network throughput, e.g. with https://github.com/stas00/toolbox/blob/master/pytorch/all_reduce_bench.py (a minimal probe is sketched below).
  2. A slow DataLoader could be another cause.
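For reference, a minimal all-reduce bandwidth probe could look like the sketch below. This is not the linked all_reduce_bench.py; the payload size, iteration counts, and the simplified bandwidth formula are assumptions. Launch it with torchrun --nproc_per_node=8 on every node:

import os
import time

import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# arbitrary ~0.5 GB fp16 payload
x = torch.ones(256 * 1024 * 1024, dtype=torch.float16, device="cuda")

for _ in range(5):  # warmup iterations
    dist.all_reduce(x)
torch.cuda.synchronize()

iters = 20
start = time.perf_counter()
for _ in range(iters):
    dist.all_reduce(x)
torch.cuda.synchronize()
elapsed = (time.perf_counter() - start) / iters

gb = x.numel() * x.element_size() / 1e9  # payload size in GB (ignores the 2*(n-1)/n ring factor)
if dist.get_rank() == 0:
    print(f"all_reduce of {gb:.2f} GB took {elapsed * 1000:.1f} ms -> ~{gb / elapsed:.1f} GB/s")

dist.destroy_process_group()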

@alex-ht (Author) commented Mar 21, 2023

Thanks for the information, @pacman100!
We have tried Megatron-DeepSpeed and the network seems fast enough. I will check the DataLoader later.

@alex-ht (Author) commented Mar 21, 2023

Should I worry about this warning?
[WARNING] [stage3.py:1939:step] 13 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding torch.cuda.empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time
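For reference, the mitigation the warning itself suggests is to flush the CUDA caching allocator on all ranks at the same point in the loop. A rough sketch around an accelerate-style training loop, reusing the loop variables from the example script (model, accelerator, optimizer, lr_scheduler, train_dataloader); the 50-step interval is an arbitrary assumption:

import torch

# fragment of the training loop from the example script, with the suggested empty_cache call added
for step, batch in enumerate(train_dataloader):
    outputs = model(**batch)
    loss = outputs.loss
    accelerator.backward(loss)
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()
    if step % 50 == 0:
        torch.cuda.empty_cache()  # every rank flushes its allocator cache at the same step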

@alex-ht (Author) commented Mar 21, 2023

ping @stas00

Hello @alex-ht, gently pinging @stas as they have a lot of experience with training models at such a large scale.

Possible hypotheses:

  1. Training is not compute-bound, e.g. if the network is slow the GPUs are just waiting for data to come in. Please see "[math] what network throughput is required to handle ZeRO-3 traffic?" (microsoft/DeepSpeed#2928). Please measure the network throughput, e.g. with https://github.com/stas00/toolbox/blob/master/pytorch/all_reduce_bench.py
  2. A slow DataLoader could be another cause.

@pacman100 (Contributor)

Hello @alex-ht, Stas and I have discussed this internally, and the suggestions above came from him. To add more context from that discussion:

  1. Most likely it is a very slow network, which is a terrible situation for DeepSpeed ZeRO-3, which thrives in a very fast network environment. E.g. Jean Zay's network is 50 Gbps, which is very, very slow, and that's why we couldn't use ZeRO-3 for BLOOM training and had to use the more complicated but more network-efficient Megatron-DeepSpeed.
  2. Also make sure the DataLoader is not a bottleneck: slow disk IO and too few workers could be another cause, and if you are on a cloud the IO could be networked rather than local, which can be pretty bad. There should be enough dedicated CPU cores/threads for each thread/worker, i.e. at least (1 (main process) + num_workers (DataLoader)) * gpus if there are no other threads. So with, say, 2 workers you need at least 3*8 = 24 cores per node, but there could be other threads as well; check with py-spy, which will show you all threads: https://github.com/stas00/toolbox/blob/master/pytorch/torch-distributed-hanging-solutions.md#py-spy.
     If you have fewer CPU cores/threads, the main process's thread could get pre-empted to let, say, a DataLoader thread run, and your GPU can't make progress while it is preempted. Typically PyTorch will warn when this happens; I think it will even force the number of threads down, overriding the user's num_workers setting. If the disk is networked, it will also compete with whatever other traffic you share with other users. (A DataLoader sketch follows after this list.)

So, measuring the network speed and checking for DataLoader bottlenecks should help in this case.
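To illustrate the core-count arithmetic in point 2, a minimal DataLoader setup might look like the sketch below; the dummy dataset and the choice of 2 workers per GPU process are illustrative assumptions, not the thread's actual code.

import torch
from torch.utils.data import DataLoader, TensorDataset

# dummy stand-in for the tokenized xP3 split: 1024 samples of 512 token ids
train_dataset = TensorDataset(torch.randint(0, 250880, (1024, 512)))

num_workers = 2  # per GPU process -> at least (1 + 2) * 8 = 24 dedicated cores on an 8-GPU node
train_dataloader = DataLoader(
    train_dataset,
    batch_size=1,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,          # faster host-to-device copies
    persistent_workers=True,  # keep workers alive between epochs instead of re-forking them
)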

@pacman100 (Contributor)

If there were doubts about whether having the base model frozen might be the cause, Stas explained the following:

=> A GPU's utilization correlates directly with matrix sizes: the larger the matrices, the more efficient the compute, so a larger batch size and sequence length typically lead to better GPU compute. Make sure to enable gradient checkpointing and raise the batch size (a sketch follows below).
If the model is mostly frozen, the forward+backward passes are still the same; only the optimizer step is a bit faster since there are fewer grads to apply. So compute-wise there is little difference compared to an unfrozen model.
The only real saving is memory, since you don't need to allocate grads and optimizer states for the frozen weights.
Specific to DeepSpeed, there is less network traffic since there are fewer optimizer states and grads to sync, but the traffic is still dominated by the Stage-3 parameter traffic.
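A minimal sketch of enabling gradient checkpointing on top of the LoRA setup from earlier in the thread; it reuses the names model_name_or_path and peft_config from that snippet, and the exact call sequence is an assumption rather than a recipe confirmed in this issue.

import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model

model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16)
model.gradient_checkpointing_enable()  # recompute activations in backward to free memory for a larger batch size
model.enable_input_require_grads()     # keep a gradient path through checkpointed blocks while the base weights are frozen
model.config.use_cache = False         # the generation KV cache is incompatible with checkpointing during training

peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1)
model = get_peft_model(model, peft_config)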

@zsc commented Apr 3, 2023

FYI, with some tweaking BLOOM-176B can be LoRA fine-tuned on 8xA100-40G: tloen/alpaca-lora#130 (comment)

@pacman100 (Contributor)

Hello @zsc, that is super cool! Thank you for sharing 🤗

@zsc commented Apr 5, 2023

😄 An update: I'm trying to push the PiPPy people to add support for PEFT, so that we can enjoy true pipeline parallelism that really exploits the multiple GPUs: pytorch/PiPPy#773

@github-actions (bot)

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

github-actions bot closed this as completed on May 7, 2023
@nicosouth

Should I worry about this warning? [WARNING] [stage3.py:1939:step] 13 pytorch allocator cache flushes since last step. this happens when there is high memory pressure and is detrimental to performance. if this is happening frequently consider adjusting settings to reduce memory consumption. If you are unable to make the cache flushes go away consider adding torch.cuda.empty_cache() calls in your training loop to ensure that all ranks flush their caches at the same time

I have the same worry. Have you solved this problem without reducing the batch size?
