tuning BLOOMZ 176B #194

Hi, I'm wondering if I could use PEFT to fine-tune the 176B BLOOMZ model. I am experiencing poor GPU utilization with 64 V100s.

Comments

---

Hello, could you provide more information on the setup - for example, are you using PEFT with INT8 training, or DeepSpeed with CPU offload... Also, what is the GPU utilization?
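For reference, the "PEFT with INT8 training" route mentioned above would look roughly like the following; a minimal sketch, assuming recent `transformers`, `peft`, and `bitsandbytes` versions - the checkpoint name, device map, and LoRA hyperparameters are placeholders:

```python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training

# Load the frozen base model with 8-bit weights, sharded across the visible GPUs.
model = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloomz",                                        # placeholder checkpoint
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map="auto",
)

# Prepare the quantized model for training (casts norm layers, enables input grads, etc.).
model = prepare_model_for_kbit_training(model)

# Attach LoRA adapters; only these small low-rank matrices are trained.
peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, r=8, lora_alpha=32, lora_dropout=0.1)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
```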

---

I'm using DeepSpeed, and the offload device is set to `none`.

Creating the model:

```python
import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model

# model_name_or_path points at the BLOOMZ-176B checkpoint
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.float16)
peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1)
model = get_peft_model(model, peft_config)
model = model.half()
model.print_trainable_parameters()
```

accelerate config:

```yaml
compute_environment: LOCAL_MACHINE
deepspeed_config:
  deepspeed_config_file: ds_config.json
  deepspeed_multinode_launcher: standard
  zero3_init_flag: true
distributed_type: DEEPSPEED
downcast_bf16: 'no'
dynamo_backend: 'NO'
fsdp_config: {}
machine_rank: 0
main_process_ip: ***
main_process_port: 18049
main_training_function: main
megatron_lm_config: {}
num_machines: 8
num_processes: 64
rdzv_backend: static
same_network: true
use_cpu: false
```

ds_config.json:

```json
{
"fp16": {
"enabled": true,
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 12,
"loss_scale_window": 500,
"hysteresis": 2,
"min_loss_scale": 1
},
"bf16": {
"enabled": false
},
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "none"
},
"offload_param": {
"device": "none"
},
"overlap_comm": true,
"contiguous_gradients": true,
"reduce_bucket_size": 205520896,
"stage3_prefetch_bucket_size": 184968807,
"stage3_param_persistence_threshold": 143360,
"sub_group_size": 1e+6,
"stage3_max_live_parameters": 1e+6,
"stage3_max_reuse_distance": 1e+6,
"stage3_gather_16bit_weights_on_model_save": true
},
"steps_per_print": 2000,
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": 1,
"gradient_accumulation_steps": 1,
"wall_clock_breakdown": false
}
```

log file:

nvidia-smi shows:

---

Hello @alex-ht, gently pinging @stas as they have a lot of experience with training models at such a large scale. Possible hypothesis:

---

Thanks for the information, @pacman100!

---

Should I worry about this warning?

---

ping @stas00

---

Hello @alex-ht, Stas and I have discussed this internally, and the suggestions above came from him. To give more context based on the discussion we had:

So, measuring the network speed and checking for DL bottlenecks might help in this case.
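Assuming "DL" here refers to the DataLoader, a quick way to check for an input-pipeline bottleneck is to time a pass over the dataloader by itself and compare that with the observed training step time; a minimal sketch, with a dummy dataset standing in for the real tokenized data:

```python
import time
import torch
from torch.utils.data import DataLoader, TensorDataset

# Dummy stand-in for the real tokenized dataset: 4096 samples of 512 token ids.
dataset = TensorDataset(torch.randint(0, 250_000, (4096, 512)))
loader = DataLoader(dataset, batch_size=1, num_workers=4, pin_memory=True)

# Time iteration over the dataloader alone (no forward/backward pass).
# If the per-batch time here is a sizeable fraction of the full training step time,
# the input pipeline rather than GPU compute is the bottleneck.
start = time.perf_counter()
for batch in loader:
    pass
elapsed = time.perf_counter() - start
print(f"dataloader-only: {elapsed:.2f}s total, {elapsed / len(loader) * 1000:.2f} ms per batch")
```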

---

If there were doubts about whether having the base model frozen might be the cause, Stas explained the following: GPU utilization correlates directly with matrix sizes - the larger the matrices, the more efficient the compute - so a larger batch size and sequence length typically lead to better GPU compute. Make sure to enable gradient checkpointing and raise the batch size.
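Building on the model-creation snippet earlier in the thread, enabling gradient checkpointing would look roughly like this; a minimal sketch, assuming a recent `transformers` version (the suggested batch-size value is illustrative):

```python
# Trade compute for memory by checkpointing activations on the base transformer,
# which frees room to raise the per-GPU batch size.
model.gradient_checkpointing_enable()

# With a frozen base model the inputs don't require grads by default; this hook
# ensures gradients still flow back through the checkpointed blocks to the LoRA layers.
model.enable_input_require_grads()

# Then raise the micro batch size, e.g. in ds_config.json (memory permitting):
#   "train_micro_batch_size_per_gpu": 4   # instead of 1
```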

---

FYI, with some tweaking BLOOM-176B can be LoRA fine-tuned on 8xA100-40G: tloen/alpaca-lora#130 (comment)

---

Hello @zsc, that is super cool! Thank you for sharing 🤗

---

😄 An update: I'm trying to push the PiPPy folks to add support for PEFT, so that we can enjoy true pipeline parallelism that really exploits the multiple GPUs. pytorch/PiPPy#773

---

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

---

I have the same concern. Have you solved this problem without reducing the batch size?